[split-required] Split final 43 files (500-668 LOC) to complete refactoring

klausur-service (11 files):
- cv_gutter_repair, ocr_pipeline_regression, upload_api
- ocr_pipeline_sessions, smart_spell, nru_worksheet_generator
- ocr_pipeline_overlays, mail/aggregator, zeugnis_api
- cv_syllable_detect, self_rag

backend-lehrer (17 files):
- classroom_engine/suggestions, generators/quiz_generator
- worksheets_api, llm_gateway/comparison, state_engine_api
- classroom/models (→ 4 submodules), services/file_processor
- alerts_agent/api/wizard+digests+routes, content_generators/pdf
- classroom/routes/sessions, llm_gateway/inference
- classroom_engine/analytics, auth/keycloak_auth
- alerts_agent/processing/rule_engine, ai_processor/print_versions

agent-core (5 files):
- brain/memory_store, brain/knowledge_graph, brain/context_manager
- orchestrator/supervisor, sessions/session_manager

admin-lehrer (5 components):
- GridOverlay, StepGridReview, DevOpsPipelineSidebar
- DataFlowDiagram, sbom/wizard/page

website (2 files):
- DependencyMap, lehrer/abitur-archiv

Other: nibis_ingestion, grid_detection_service, export-doclayout-onnx

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Benjamin Admin
2026-04-25 09:41:42 +02:00
parent 451365a312
commit bd4b956e3c
113 changed files with 13790 additions and 14148 deletions


@@ -1,610 +1,35 @@
"""
Gutter Repair — detects and fixes words truncated or blurred at the book gutter.
Gutter Repair — barrel re-export.
When scanning double-page spreads, the binding area (gutter) causes:
1. Blurry/garbled trailing characters ("stammeli" → "stammeln")
2. Words split across lines with a hyphen lost in the gutter
("ve" + "künden""verkünden")
This module analyses grid cells, identifies gutter-edge candidates, and
proposes corrections using pyspellchecker (DE + EN).
The implementation is split into:
cv_gutter_repair_core — spellchecker setup, data types, single-word repair
cv_gutter_repair_grid — grid analysis, suggestion application
License: Apache 2.0 (commercial use permitted)
PRIVACY: All processing happens locally.
"""
import itertools
import logging
import re
import time
import uuid
from dataclasses import dataclass, field, asdict
from typing import Any, Dict, List, Optional, Tuple
logger = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
# Spellchecker setup (lazy, cached)
# ---------------------------------------------------------------------------
_spell_de = None
_spell_en = None
_SPELL_AVAILABLE = False
def _init_spellcheckers():
"""Lazy-load DE + EN spellcheckers (cached across calls)."""
global _spell_de, _spell_en, _SPELL_AVAILABLE
if _spell_de is not None:
return
try:
from spellchecker import SpellChecker
_spell_de = SpellChecker(language='de', distance=1)
_spell_en = SpellChecker(language='en', distance=1)
_SPELL_AVAILABLE = True
logger.info("Gutter repair: spellcheckers loaded (DE + EN)")
except ImportError:
logger.warning("pyspellchecker not installed — gutter repair unavailable")
def _is_known(word: str) -> bool:
"""Check if a word is known in DE or EN dictionary."""
_init_spellcheckers()
if not _SPELL_AVAILABLE:
return False
w = word.lower()
return bool(_spell_de.known([w])) or bool(_spell_en.known([w]))
def _spell_candidates(word: str, lang: str = "both") -> List[str]:
"""Get all plausible spellchecker candidates for a word (deduplicated)."""
_init_spellcheckers()
if not _SPELL_AVAILABLE:
return []
w = word.lower()
seen: set = set()
results: List[str] = []
for checker in ([_spell_de, _spell_en] if lang == "both"
else [_spell_de] if lang == "de"
else [_spell_en]):
if checker is None:
continue
cands = checker.candidates(w)
if cands:
for c in cands:
if c and c != w and c not in seen:
seen.add(c)
results.append(c)
return results
# ---------------------------------------------------------------------------
# Gutter position detection
# ---------------------------------------------------------------------------
# Minimum word length for spell-fix (very short words are often legitimate)
_MIN_WORD_LEN_SPELL = 3
# Minimum word length for hyphen-join candidates (fragments at the gutter
# can be as short as 1-2 chars, e.g. "ve" from "ver-künden")
_MIN_WORD_LEN_HYPHEN = 2
# How close to the right column edge a word must be to count as "gutter-adjacent".
# Expressed as fraction of column width (e.g. 0.75 = rightmost 25%).
_GUTTER_EDGE_THRESHOLD = 0.70
# Small common words / abbreviations that should NOT be repaired
_STOPWORDS = frozenset([
# German
"ab", "an", "am", "da", "er", "es", "im", "in", "ja", "ob", "so", "um",
"zu", "wo", "du", "eh", "ei", "je", "na", "nu", "oh",
# English
"a", "am", "an", "as", "at", "be", "by", "do", "go", "he", "if", "in",
"is", "it", "me", "my", "no", "of", "on", "or", "so", "to", "up", "us",
"we",
])
# IPA / phonetic patterns — skip these cells
_IPA_RE = re.compile(r'[\[\]/ˈˌːʃʒθðŋɑɒæɔəɛɪʊʌ]')
def _is_ipa_text(text: str) -> bool:
"""True if text looks like IPA transcription."""
return bool(_IPA_RE.search(text))
def _word_is_at_gutter_edge(word_bbox: Dict, col_x: float, col_width: float) -> bool:
"""Check if a word's right edge is near the right boundary of its column."""
if col_width <= 0:
return False
word_right = word_bbox.get("left", 0) + word_bbox.get("width", 0)
# Word's right edge within the rightmost portion of the column
relative_pos = (word_right - col_x) / col_width
return relative_pos >= _GUTTER_EDGE_THRESHOLD
# ---------------------------------------------------------------------------
# Suggestion types
# ---------------------------------------------------------------------------
@dataclass
class GutterSuggestion:
"""A single correction suggestion."""
id: str = field(default_factory=lambda: str(uuid.uuid4())[:8])
type: str = "" # "hyphen_join" | "spell_fix"
zone_index: int = 0
row_index: int = 0
col_index: int = 0
col_type: str = ""
cell_id: str = ""
original_text: str = ""
suggested_text: str = ""
# For hyphen_join:
next_row_index: int = -1
next_row_cell_id: str = ""
next_row_text: str = ""
missing_chars: str = ""
display_parts: List[str] = field(default_factory=list)
# Alternatives (other plausible corrections the user can pick from)
alternatives: List[str] = field(default_factory=list)
# Meta:
confidence: float = 0.0
reason: str = "" # "gutter_truncation" | "gutter_blur" | "hyphen_continuation"
def to_dict(self) -> Dict[str, Any]:
return asdict(self)
# ---------------------------------------------------------------------------
# Core repair logic
# ---------------------------------------------------------------------------
_TRAILING_PUNCT_RE = re.compile(r'[.,;:!?\)\]]+$')
def _try_hyphen_join(
word_text: str,
next_word_text: str,
max_missing: int = 3,
) -> Optional[Tuple[str, str, float]]:
"""Try joining two fragments with 0..max_missing interpolated chars.
Strips trailing punctuation from the continuation word before testing
(e.g. "künden,""künden") so dictionary lookup succeeds.
Returns (joined_word, missing_chars, confidence) or None.
"""
base = word_text.rstrip("-").rstrip()
# Strip trailing punctuation from continuation (commas, periods, etc.)
raw_continuation = next_word_text.lstrip()
continuation = _TRAILING_PUNCT_RE.sub('', raw_continuation)
if not base or not continuation:
return None
# 1. Direct join (no missing chars)
direct = base + continuation
if _is_known(direct):
return (direct, "", 0.95)
# 2. Try with 1..max_missing missing characters
# Use common letters, weighted by frequency in German/English
_COMMON_CHARS = "enristaldhgcmobwfkzpvjyxqu"
for n_missing in range(1, max_missing + 1):
for chars in itertools.product(_COMMON_CHARS[:15], repeat=n_missing):
candidate = base + "".join(chars) + continuation
if _is_known(candidate):
missing = "".join(chars)
# Confidence decreases with more missing chars
conf = 0.90 - (n_missing - 1) * 0.10
return (candidate, missing, conf)
return None
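# Worked examples (a sketch: assumes the bundled DE word list contains
# "verkünden" and "wunderschön"; actual dictionary contents may vary):
#   _try_hyphen_join("ve", "künden,")    -> ("verkünden", "r", 0.90)   # one char restored
#   _try_hyphen_join("wunder-", "schön") -> ("wunderschön", "", 0.95)  # direct join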
def _try_spell_fix(
word_text: str, col_type: str = "",
) -> Optional[Tuple[str, float, List[str]]]:
"""Try to fix a single garbled gutter word via spellchecker.
Returns (best_correction, confidence, alternatives_list) or None.
The alternatives list contains other plausible corrections the user
can choose from (e.g. "stammelt" vs "stammeln").
"""
if len(word_text) < _MIN_WORD_LEN_SPELL:
return None
# Strip trailing/leading parentheses and check if the bare word is valid.
# Words like "probieren)" or "(Englisch" are valid words with punctuation,
# not OCR errors. Don't suggest corrections for them.
stripped = word_text.strip("()")
if stripped and _is_known(stripped):
return None
# Determine language priority from column type
if "en" in col_type:
lang = "en"
elif "de" in col_type:
lang = "de"
else:
lang = "both"
candidates = _spell_candidates(word_text, lang=lang)
if not candidates and lang != "both":
candidates = _spell_candidates(word_text, lang="both")
if not candidates:
return None
# Preserve original casing
is_upper = word_text[0].isupper()
def _preserve_case(w: str) -> str:
if is_upper and w:
return w[0].upper() + w[1:]
return w
# Sort candidates by edit distance (closest first)
scored = []
for c in candidates:
dist = _edit_distance(word_text.lower(), c.lower())
scored.append((dist, c))
scored.sort(key=lambda x: x[0])
best_dist, best = scored[0]
best = _preserve_case(best)
conf = max(0.5, 1.0 - best_dist * 0.15)
# Build alternatives (all other candidates, also case-preserved)
alts = [_preserve_case(c) for _, c in scored[1:] if c.lower() != best.lower()]
# Limit to top 5 alternatives
alts = alts[:5]
return (best, conf, alts)
def _edit_distance(a: str, b: str) -> int:
"""Simple Levenshtein distance."""
if len(a) < len(b):
return _edit_distance(b, a)
if len(b) == 0:
return len(a)
prev = list(range(len(b) + 1))
for i, ca in enumerate(a):
curr = [i + 1]
for j, cb in enumerate(b):
cost = 0 if ca == cb else 1
curr.append(min(curr[j] + 1, prev[j + 1] + 1, prev[j] + cost))
prev = curr
return prev[len(b)]
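# Sanity check: "stammeli" and "stammeln" differ only in the final character,
# so _edit_distance("stammeli", "stammeln") == 1.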
# ---------------------------------------------------------------------------
# Grid analysis
# ---------------------------------------------------------------------------
def analyse_grid_for_gutter_repair(
grid_data: Dict[str, Any],
image_width: int = 0,
) -> Dict[str, Any]:
"""Analyse a structured grid and return gutter repair suggestions.
Args:
grid_data: The grid_editor_result from the session (zones→cells structure).
image_width: Image width in pixels (for determining gutter side).
Returns:
Dict with "suggestions" list and "stats".
"""
t0 = time.time()
_init_spellcheckers()
if not _SPELL_AVAILABLE:
return {
"suggestions": [],
"stats": {"error": "pyspellchecker not installed"},
"duration_seconds": 0,
}
zones = grid_data.get("zones", [])
suggestions: List[GutterSuggestion] = []
words_checked = 0
gutter_candidates = 0
for zi, zone in enumerate(zones):
columns = zone.get("columns", [])
cells = zone.get("cells", [])
if not columns or not cells:
continue
# Build column lookup: col_index → {x, width, type}
col_info: Dict[int, Dict] = {}
for col in columns:
ci = col.get("index", col.get("col_index", -1))
col_info[ci] = {
"x": col.get("x_min_px", col.get("x", 0)),
"width": col.get("x_max_px", col.get("width", 0)) - col.get("x_min_px", col.get("x", 0)),
"type": col.get("type", col.get("col_type", "")),
}
# Build row→col→cell lookup
cell_map: Dict[Tuple[int, int], Dict] = {}
for cell in cells:
ri = cell.get("row_index", 0)
ci = cell.get("col_index", 0)
cell_map[(ri, ci)] = cell
# Determine which columns are at the gutter edge.
# For a left page: rightmost content columns.
# For now, check ALL columns — a word is a candidate if it's at the
# right edge of its column AND not a known word.
for (ri, ci), cell in cell_map.items():
text = (cell.get("text") or "").strip()
if not text:
continue
if _is_ipa_text(text):
continue
words_checked += 1
col = col_info.get(ci, {})
col_type = col.get("type", "")
# Get word boxes to check position
word_boxes = cell.get("word_boxes", [])
# Check the LAST word in the cell (rightmost, closest to gutter)
cell_words = text.split()
if not cell_words:
continue
last_word = cell_words[-1]
# Skip stopwords
if last_word.lower().rstrip(".,;:!?-") in _STOPWORDS:
continue
last_word_clean = last_word.rstrip(".,;:!?)(")
if len(last_word_clean) < _MIN_WORD_LEN_HYPHEN:
continue
# Check if the last word is at the gutter edge
is_at_edge = False
if word_boxes:
last_wb = word_boxes[-1]
is_at_edge = _word_is_at_gutter_edge(
last_wb, col.get("x", 0), col.get("width", 1)
)
else:
# No word boxes — use cell bbox
bbox = cell.get("bbox_px", {})
is_at_edge = _word_is_at_gutter_edge(
{"left": bbox.get("x", 0), "width": bbox.get("w", 0)},
col.get("x", 0), col.get("width", 1)
)
if not is_at_edge:
continue
# Word is at gutter edge — check if it's a known word
if _is_known(last_word_clean):
continue
# Check if the word ends with "-" (explicit hyphen break)
ends_with_hyphen = last_word.endswith("-")
# If the word already ends with "-" and the stem (without
# the hyphen) is a known word, this is a VALID line-break
# hyphenation — not a gutter error. Gutter problems cause
# the hyphen to be LOST ("ve" instead of "ver-"), so a
# visible hyphen + known stem = intentional word-wrap.
# Example: "wunder-" → "wunder" is known → skip.
if ends_with_hyphen:
stem = last_word_clean.rstrip("-")
if stem and _is_known(stem):
continue
gutter_candidates += 1
# --- Strategy 1: Hyphen join with next row ---
next_cell = cell_map.get((ri + 1, ci))
if next_cell:
next_text = (next_cell.get("text") or "").strip()
next_words = next_text.split()
if next_words:
first_next = next_words[0]
first_next_clean = _TRAILING_PUNCT_RE.sub('', first_next)
first_alpha = next((c for c in first_next if c.isalpha()), "")
# Also skip if the joined word is known (covers compound
# words where the stem alone might not be in the dictionary)
if ends_with_hyphen and first_next_clean:
direct = last_word_clean.rstrip("-") + first_next_clean
if _is_known(direct):
continue
# Continuation likely if:
# - explicit hyphen, OR
# - next row starts lowercase (= not a new entry)
if ends_with_hyphen or (first_alpha and first_alpha.islower()):
result = _try_hyphen_join(last_word_clean, first_next)
if result:
joined, missing, conf = result
# Build display parts: show hyphenation for original layout
if ends_with_hyphen:
display_p1 = last_word_clean.rstrip("-")
if missing:
display_p1 += missing
display_p1 += "-"
else:
display_p1 = last_word_clean
if missing:
display_p1 += missing + "-"
else:
display_p1 += "-"
suggestion = GutterSuggestion(
type="hyphen_join",
zone_index=zi,
row_index=ri,
col_index=ci,
col_type=col_type,
cell_id=cell.get("cell_id", f"R{ri:02d}_C{ci}"),
original_text=last_word,
suggested_text=joined,
next_row_index=ri + 1,
next_row_cell_id=next_cell.get("cell_id", f"R{ri+1:02d}_C{ci}"),
next_row_text=next_text,
missing_chars=missing,
display_parts=[display_p1, first_next],
confidence=conf,
reason="gutter_truncation" if missing else "hyphen_continuation",
)
suggestions.append(suggestion)
continue # skip spell_fix if hyphen_join found
# --- Strategy 2: Single-word spell fix (only for longer words) ---
fix_result = _try_spell_fix(last_word_clean, col_type)
if fix_result:
corrected, conf, alts = fix_result
suggestion = GutterSuggestion(
type="spell_fix",
zone_index=zi,
row_index=ri,
col_index=ci,
col_type=col_type,
cell_id=cell.get("cell_id", f"R{ri:02d}_C{ci}"),
original_text=last_word,
suggested_text=corrected,
alternatives=alts,
confidence=conf,
reason="gutter_blur",
)
suggestions.append(suggestion)
duration = round(time.time() - t0, 3)
logger.info(
"Gutter repair: checked %d words, %d gutter candidates, %d suggestions (%.2fs)",
words_checked, gutter_candidates, len(suggestions), duration,
)
return {
"suggestions": [s.to_dict() for s in suggestions],
"stats": {
"words_checked": words_checked,
"gutter_candidates": gutter_candidates,
"suggestions_found": len(suggestions),
},
"duration_seconds": duration,
}
def apply_gutter_suggestions(
grid_data: Dict[str, Any],
accepted_ids: List[str],
suggestions: List[Dict[str, Any]],
) -> Dict[str, Any]:
"""Apply accepted gutter repair suggestions to the grid data.
Modifies cells in-place and returns summary of changes.
Args:
grid_data: The grid_editor_result (zones→cells).
accepted_ids: List of suggestion IDs the user accepted.
suggestions: The full suggestions list (from analyse_grid_for_gutter_repair).
Returns:
Dict with "applied_count" and "changes" list.
"""
accepted_set = set(accepted_ids)
accepted_suggestions = [s for s in suggestions if s.get("id") in accepted_set]
zones = grid_data.get("zones", [])
changes: List[Dict[str, Any]] = []
for s in accepted_suggestions:
zi = s.get("zone_index", 0)
ri = s.get("row_index", 0)
ci = s.get("col_index", 0)
stype = s.get("type", "")
if zi >= len(zones):
continue
zone_cells = zones[zi].get("cells", [])
# Find the target cell
target_cell = None
for cell in zone_cells:
if cell.get("row_index") == ri and cell.get("col_index") == ci:
target_cell = cell
break
if not target_cell:
continue
old_text = target_cell.get("text", "")
if stype == "spell_fix":
# Replace the last word in the cell text
original_word = s.get("original_text", "")
corrected = s.get("suggested_text", "")
if original_word and corrected:
# Replace from the right (last occurrence)
idx = old_text.rfind(original_word)
if idx >= 0:
new_text = old_text[:idx] + corrected + old_text[idx + len(original_word):]
target_cell["text"] = new_text
changes.append({
"type": "spell_fix",
"zone_index": zi,
"row_index": ri,
"col_index": ci,
"cell_id": target_cell.get("cell_id", ""),
"old_text": old_text,
"new_text": new_text,
})
elif stype == "hyphen_join":
# Current cell: replace last word with the hyphenated first part
original_word = s.get("original_text", "")
joined = s.get("suggested_text", "")
display_parts = s.get("display_parts", [])
next_ri = s.get("next_row_index", -1)
if not original_word or not joined or not display_parts:
continue
# The first display part is what goes in the current row
first_part = display_parts[0] if display_parts else ""
# Replace the last word in current cell with the restored form.
# The next row is NOT modified — "künden" stays in its row
# because the original book layout has it there. We only fix
# the truncated word in the current row (e.g. "ve" → "ver-").
idx = old_text.rfind(original_word)
if idx >= 0:
new_text = old_text[:idx] + first_part + old_text[idx + len(original_word):]
target_cell["text"] = new_text
changes.append({
"type": "hyphen_join",
"zone_index": zi,
"row_index": ri,
"col_index": ci,
"cell_id": target_cell.get("cell_id", ""),
"old_text": old_text,
"new_text": new_text,
"joined_word": joined,
})
logger.info("Gutter repair applied: %d/%d suggestions", len(changes), len(accepted_suggestions))
return {
"applied_count": len(accepted_suggestions),
"changes": changes,
}
# Core: spellchecker, data types, repair helpers
from cv_gutter_repair_core import ( # noqa: F401
_init_spellcheckers,
_is_known,
_spell_candidates,
_MIN_WORD_LEN_SPELL,
_MIN_WORD_LEN_HYPHEN,
_GUTTER_EDGE_THRESHOLD,
_STOPWORDS,
_IPA_RE,
_is_ipa_text,
_word_is_at_gutter_edge,
GutterSuggestion,
_TRAILING_PUNCT_RE,
_try_hyphen_join,
_try_spell_fix,
_edit_distance,
)
# Grid: analysis and application
from cv_gutter_repair_grid import ( # noqa: F401
analyse_grid_for_gutter_repair,
apply_gutter_suggestions,
)
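# Callers keep their original import path; a typical (hypothetical) usage:
#   from cv_gutter_repair import analyse_grid_for_gutter_repair, apply_gutter_suggestions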


@@ -0,0 +1,275 @@
"""
Gutter Repair Core — spellchecker setup, data types, and single-word repair logic.
Extracted from cv_gutter_repair.py for modularity.
License: Apache 2.0 (commercial use permitted)
PRIVACY: All processing happens locally.
"""
import itertools
import logging
import re
import uuid
from dataclasses import dataclass, field, asdict
from typing import Any, Dict, List, Optional, Tuple
logger = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
# Spellchecker setup (lazy, cached)
# ---------------------------------------------------------------------------
_spell_de = None
_spell_en = None
_SPELL_AVAILABLE = False
def _init_spellcheckers():
"""Lazy-load DE + EN spellcheckers (cached across calls)."""
global _spell_de, _spell_en, _SPELL_AVAILABLE
if _spell_de is not None:
return
try:
from spellchecker import SpellChecker
_spell_de = SpellChecker(language='de', distance=1)
_spell_en = SpellChecker(language='en', distance=1)
_SPELL_AVAILABLE = True
logger.info("Gutter repair: spellcheckers loaded (DE + EN)")
except ImportError:
logger.warning("pyspellchecker not installed — gutter repair unavailable")
def _is_known(word: str) -> bool:
"""Check if a word is known in DE or EN dictionary."""
_init_spellcheckers()
if not _SPELL_AVAILABLE:
return False
w = word.lower()
return bool(_spell_de.known([w])) or bool(_spell_en.known([w]))
def _spell_candidates(word: str, lang: str = "both") -> List[str]:
"""Get all plausible spellchecker candidates for a word (deduplicated)."""
_init_spellcheckers()
if not _SPELL_AVAILABLE:
return []
w = word.lower()
seen: set = set()
results: List[str] = []
for checker in ([_spell_de, _spell_en] if lang == "both"
else [_spell_de] if lang == "de"
else [_spell_en]):
if checker is None:
continue
cands = checker.candidates(w)
if cands:
for c in cands:
if c and c != w and c not in seen:
seen.add(c)
results.append(c)
return results
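# Usage sketch (output depends on the installed pyspellchecker word lists):
#   _spell_candidates("stammeli", lang="de")  # might return ["stammeln", "stammelt", ...]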
# ---------------------------------------------------------------------------
# Gutter position detection
# ---------------------------------------------------------------------------
# Minimum word length for spell-fix (very short words are often legitimate)
_MIN_WORD_LEN_SPELL = 3
# Minimum word length for hyphen-join candidates (fragments at the gutter
# can be as short as 1-2 chars, e.g. "ve" from "ver-künden")
_MIN_WORD_LEN_HYPHEN = 2
# How close to the right column edge a word must be to count as "gutter-adjacent".
# Expressed as fraction of column width (e.g. 0.75 = rightmost 25%).
_GUTTER_EDGE_THRESHOLD = 0.70
# Small common words / abbreviations that should NOT be repaired
_STOPWORDS = frozenset([
# German
"ab", "an", "am", "da", "er", "es", "im", "in", "ja", "ob", "so", "um",
"zu", "wo", "du", "eh", "ei", "je", "na", "nu", "oh",
# English
"a", "am", "an", "as", "at", "be", "by", "do", "go", "he", "if", "in",
"is", "it", "me", "my", "no", "of", "on", "or", "so", "to", "up", "us",
"we",
])
# IPA / phonetic patterns — skip these cells
_IPA_RE = re.compile(r'[\[\]/ˈˌːʃʒθðŋɑɒæɔəɛɪʊʌ]')
def _is_ipa_text(text: str) -> bool:
"""True if text looks like IPA transcription."""
return bool(_IPA_RE.search(text))
def _word_is_at_gutter_edge(word_bbox: Dict, col_x: float, col_width: float) -> bool:
"""Check if a word's right edge is near the right boundary of its column."""
if col_width <= 0:
return False
word_right = word_bbox.get("left", 0) + word_bbox.get("width", 0)
# Word's right edge within the rightmost portion of the column
relative_pos = (word_right - col_x) / col_width
return relative_pos >= _GUTTER_EDGE_THRESHOLD
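# Numeric example: a word box ending at x=380 inside a column spanning x=0..400
# sits at relative position 380/400 = 0.95 >= 0.70, so it counts as gutter-adjacent:
#   _word_is_at_gutter_edge({"left": 320, "width": 60}, 0, 400)  # -> True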
# ---------------------------------------------------------------------------
# Suggestion types
# ---------------------------------------------------------------------------
@dataclass
class GutterSuggestion:
"""A single correction suggestion."""
id: str = field(default_factory=lambda: str(uuid.uuid4())[:8])
type: str = "" # "hyphen_join" | "spell_fix"
zone_index: int = 0
row_index: int = 0
col_index: int = 0
col_type: str = ""
cell_id: str = ""
original_text: str = ""
suggested_text: str = ""
# For hyphen_join:
next_row_index: int = -1
next_row_cell_id: str = ""
next_row_text: str = ""
missing_chars: str = ""
display_parts: List[str] = field(default_factory=list)
# Alternatives (other plausible corrections the user can pick from)
alternatives: List[str] = field(default_factory=list)
# Meta:
confidence: float = 0.0
reason: str = "" # "gutter_truncation" | "gutter_blur" | "hyphen_continuation"
def to_dict(self) -> Dict[str, Any]:
return asdict(self)
# ---------------------------------------------------------------------------
# Core repair logic
# ---------------------------------------------------------------------------
_TRAILING_PUNCT_RE = re.compile(r'[.,;:!?\)\]]+$')
def _try_hyphen_join(
word_text: str,
next_word_text: str,
max_missing: int = 3,
) -> Optional[Tuple[str, str, float]]:
"""Try joining two fragments with 0..max_missing interpolated chars.
Strips trailing punctuation from the continuation word before testing
(e.g. "künden,""künden") so dictionary lookup succeeds.
Returns (joined_word, missing_chars, confidence) or None.
"""
base = word_text.rstrip("-").rstrip()
# Strip trailing punctuation from continuation (commas, periods, etc.)
raw_continuation = next_word_text.lstrip()
continuation = _TRAILING_PUNCT_RE.sub('', raw_continuation)
if not base or not continuation:
return None
# 1. Direct join (no missing chars)
direct = base + continuation
if _is_known(direct):
return (direct, "", 0.95)
# 2. Try with 1..max_missing missing characters
# Use common letters, weighted by frequency in German/English
_COMMON_CHARS = "enristaldhgcmobwfkzpvjyxqu"
for n_missing in range(1, max_missing + 1):
for chars in itertools.product(_COMMON_CHARS[:15], repeat=n_missing):
candidate = base + "".join(chars) + continuation
if _is_known(candidate):
missing = "".join(chars)
# Confidence decreases with more missing chars
conf = 0.90 - (n_missing - 1) * 0.10
return (candidate, missing, conf)
return None
def _try_spell_fix(
word_text: str, col_type: str = "",
) -> Optional[Tuple[str, float, List[str]]]:
"""Try to fix a single garbled gutter word via spellchecker.
Returns (best_correction, confidence, alternatives_list) or None.
The alternatives list contains other plausible corrections the user
can choose from (e.g. "stammelt" vs "stammeln").
"""
if len(word_text) < _MIN_WORD_LEN_SPELL:
return None
# Strip trailing/leading parentheses and check if the bare word is valid.
# Words like "probieren)" or "(Englisch" are valid words with punctuation,
# not OCR errors. Don't suggest corrections for them.
stripped = word_text.strip("()")
if stripped and _is_known(stripped):
return None
# Determine language priority from column type
if "en" in col_type:
lang = "en"
elif "de" in col_type:
lang = "de"
else:
lang = "both"
candidates = _spell_candidates(word_text, lang=lang)
if not candidates and lang != "both":
candidates = _spell_candidates(word_text, lang="both")
if not candidates:
return None
# Preserve original casing
is_upper = word_text[0].isupper()
def _preserve_case(w: str) -> str:
if is_upper and w:
return w[0].upper() + w[1:]
return w
# Sort candidates by edit distance (closest first)
scored = []
for c in candidates:
dist = _edit_distance(word_text.lower(), c.lower())
scored.append((dist, c))
scored.sort(key=lambda x: x[0])
best_dist, best = scored[0]
best = _preserve_case(best)
conf = max(0.5, 1.0 - best_dist * 0.15)
# Build alternatives (all other candidates, also case-preserved)
alts = [_preserve_case(c) for _, c in scored[1:] if c.lower() != best.lower()]
# Limit to top 5 alternatives
alts = alts[:5]
return (best, conf, alts)
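# Usage sketch (candidate set depends on the dictionary data; confidence
# follows max(0.5, 1.0 - edit_distance * 0.15), so distance 1 gives 0.85):
#   _try_spell_fix("stammeli", "column_de")  # might return ("stammeln", 0.85, [...])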
def _edit_distance(a: str, b: str) -> int:
"""Simple Levenshtein distance."""
if len(a) < len(b):
return _edit_distance(b, a)
if len(b) == 0:
return len(a)
prev = list(range(len(b) + 1))
for i, ca in enumerate(a):
curr = [i + 1]
for j, cb in enumerate(b):
cost = 0 if ca == cb else 1
curr.append(min(curr[j] + 1, prev[j + 1] + 1, prev[j] + cost))
prev = curr
return prev[len(b)]


@@ -0,0 +1,356 @@
"""
Gutter Repair Grid — grid analysis and suggestion application.
Extracted from cv_gutter_repair.py for modularity.
License: Apache 2.0 (commercial use permitted)
PRIVACY: All processing happens locally.
"""
import logging
import time
from typing import Any, Dict, List, Tuple
import cv_gutter_repair_core  # module handle: _SPELL_AVAILABLE is set lazily by _init_spellcheckers()
from cv_gutter_repair_core import (
_init_spellcheckers,
_is_ipa_text,
_is_known,
_MIN_WORD_LEN_HYPHEN,
_STOPWORDS,
_TRAILING_PUNCT_RE,
_try_hyphen_join,
_try_spell_fix,
_word_is_at_gutter_edge,
GutterSuggestion,
)
logger = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
# Grid analysis
# ---------------------------------------------------------------------------
def analyse_grid_for_gutter_repair(
grid_data: Dict[str, Any],
image_width: int = 0,
) -> Dict[str, Any]:
"""Analyse a structured grid and return gutter repair suggestions.
Args:
grid_data: The grid_editor_result from the session (zones→cells structure).
image_width: Image width in pixels (for determining gutter side).
Returns:
Dict with "suggestions" list and "stats".
"""
t0 = time.time()
_init_spellcheckers()
if not cv_gutter_repair_core._SPELL_AVAILABLE:  # read via module; a from-import would freeze the pre-init False
return {
"suggestions": [],
"stats": {"error": "pyspellchecker not installed"},
"duration_seconds": 0,
}
zones = grid_data.get("zones", [])
suggestions: List[GutterSuggestion] = []
words_checked = 0
gutter_candidates = 0
for zi, zone in enumerate(zones):
columns = zone.get("columns", [])
cells = zone.get("cells", [])
if not columns or not cells:
continue
# Build column lookup: col_index → {x, width, type}
col_info: Dict[int, Dict] = {}
for col in columns:
ci = col.get("index", col.get("col_index", -1))
col_info[ci] = {
"x": col.get("x_min_px", col.get("x", 0)),
"width": col.get("x_max_px", col.get("width", 0)) - col.get("x_min_px", col.get("x", 0)),
"type": col.get("type", col.get("col_type", "")),
}
# Build row→col→cell lookup
cell_map: Dict[Tuple[int, int], Dict] = {}
for cell in cells:
ri = cell.get("row_index", 0)
ci = cell.get("col_index", 0)
cell_map[(ri, ci)] = cell
# Determine which columns are at the gutter edge.
# For a left page: rightmost content columns.
# For now, check ALL columns — a word is a candidate if it's at the
# right edge of its column AND not a known word.
for (ri, ci), cell in cell_map.items():
text = (cell.get("text") or "").strip()
if not text:
continue
if _is_ipa_text(text):
continue
words_checked += 1
col = col_info.get(ci, {})
col_type = col.get("type", "")
# Get word boxes to check position
word_boxes = cell.get("word_boxes", [])
# Check the LAST word in the cell (rightmost, closest to gutter)
cell_words = text.split()
if not cell_words:
continue
last_word = cell_words[-1]
# Skip stopwords
if last_word.lower().rstrip(".,;:!?-") in _STOPWORDS:
continue
last_word_clean = last_word.rstrip(".,;:!?)(")
if len(last_word_clean) < _MIN_WORD_LEN_HYPHEN:
continue
# Check if the last word is at the gutter edge
is_at_edge = False
if word_boxes:
last_wb = word_boxes[-1]
is_at_edge = _word_is_at_gutter_edge(
last_wb, col.get("x", 0), col.get("width", 1)
)
else:
# No word boxes — use cell bbox
bbox = cell.get("bbox_px", {})
is_at_edge = _word_is_at_gutter_edge(
{"left": bbox.get("x", 0), "width": bbox.get("w", 0)},
col.get("x", 0), col.get("width", 1)
)
if not is_at_edge:
continue
# Word is at gutter edge — check if it's a known word
if _is_known(last_word_clean):
continue
# Check if the word ends with "-" (explicit hyphen break)
ends_with_hyphen = last_word.endswith("-")
# If the word already ends with "-" and the stem (without
# the hyphen) is a known word, this is a VALID line-break
# hyphenation — not a gutter error. Gutter problems cause
# the hyphen to be LOST ("ve" instead of "ver-"), so a
# visible hyphen + known stem = intentional word-wrap.
# Example: "wunder-" → "wunder" is known → skip.
if ends_with_hyphen:
stem = last_word_clean.rstrip("-")
if stem and _is_known(stem):
continue
gutter_candidates += 1
# --- Strategy 1: Hyphen join with next row ---
next_cell = cell_map.get((ri + 1, ci))
if next_cell:
next_text = (next_cell.get("text") or "").strip()
next_words = next_text.split()
if next_words:
first_next = next_words[0]
first_next_clean = _TRAILING_PUNCT_RE.sub('', first_next)
first_alpha = next((c for c in first_next if c.isalpha()), "")
# Also skip if the joined word is known (covers compound
# words where the stem alone might not be in the dictionary)
if ends_with_hyphen and first_next_clean:
direct = last_word_clean.rstrip("-") + first_next_clean
if _is_known(direct):
continue
# Continuation likely if:
# - explicit hyphen, OR
# - next row starts lowercase (= not a new entry)
if ends_with_hyphen or (first_alpha and first_alpha.islower()):
result = _try_hyphen_join(last_word_clean, first_next)
if result:
joined, missing, conf = result
# Build display parts: show hyphenation for original layout
if ends_with_hyphen:
display_p1 = last_word_clean.rstrip("-")
if missing:
display_p1 += missing
display_p1 += "-"
else:
display_p1 = last_word_clean
if missing:
display_p1 += missing + "-"
else:
display_p1 += "-"
suggestion = GutterSuggestion(
type="hyphen_join",
zone_index=zi,
row_index=ri,
col_index=ci,
col_type=col_type,
cell_id=cell.get("cell_id", f"R{ri:02d}_C{ci}"),
original_text=last_word,
suggested_text=joined,
next_row_index=ri + 1,
next_row_cell_id=next_cell.get("cell_id", f"R{ri+1:02d}_C{ci}"),
next_row_text=next_text,
missing_chars=missing,
display_parts=[display_p1, first_next],
confidence=conf,
reason="gutter_truncation" if missing else "hyphen_continuation",
)
suggestions.append(suggestion)
continue # skip spell_fix if hyphen_join found
# --- Strategy 2: Single-word spell fix (only for longer words) ---
fix_result = _try_spell_fix(last_word_clean, col_type)
if fix_result:
corrected, conf, alts = fix_result
suggestion = GutterSuggestion(
type="spell_fix",
zone_index=zi,
row_index=ri,
col_index=ci,
col_type=col_type,
cell_id=cell.get("cell_id", f"R{ri:02d}_C{ci}"),
original_text=last_word,
suggested_text=corrected,
alternatives=alts,
confidence=conf,
reason="gutter_blur",
)
suggestions.append(suggestion)
duration = round(time.time() - t0, 3)
logger.info(
"Gutter repair: checked %d words, %d gutter candidates, %d suggestions (%.2fs)",
words_checked, gutter_candidates, len(suggestions), duration,
)
return {
"suggestions": [s.to_dict() for s in suggestions],
"stats": {
"words_checked": words_checked,
"gutter_candidates": gutter_candidates,
"suggestions_found": len(suggestions),
},
"duration_seconds": duration,
}
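# Minimal input sketch (hypothetical values; only the keys read above matter).
# The "ve" cell sits at the right column edge, is unknown, and row 1 starts
# lowercase, so a hyphen_join suggestion "verkünden" would be produced:
#   grid = {"zones": [{
#       "columns": [{"index": 0, "x_min_px": 0, "x_max_px": 400, "type": "column_de"}],
#       "cells": [
#           {"row_index": 0, "col_index": 0, "cell_id": "R00_C0",
#            "text": "ve", "bbox_px": {"x": 300, "w": 90}},
#           {"row_index": 1, "col_index": 0, "cell_id": "R01_C0", "text": "künden"},
#       ]}]}
#   report = analyse_grid_for_gutter_repair(grid)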
def apply_gutter_suggestions(
grid_data: Dict[str, Any],
accepted_ids: List[str],
suggestions: List[Dict[str, Any]],
) -> Dict[str, Any]:
"""Apply accepted gutter repair suggestions to the grid data.
Modifies cells in-place and returns summary of changes.
Args:
grid_data: The grid_editor_result (zones→cells).
accepted_ids: List of suggestion IDs the user accepted.
suggestions: The full suggestions list (from analyse_grid_for_gutter_repair).
Returns:
Dict with "applied_count" and "changes" list.
"""
accepted_set = set(accepted_ids)
accepted_suggestions = [s for s in suggestions if s.get("id") in accepted_set]
zones = grid_data.get("zones", [])
changes: List[Dict[str, Any]] = []
for s in accepted_suggestions:
zi = s.get("zone_index", 0)
ri = s.get("row_index", 0)
ci = s.get("col_index", 0)
stype = s.get("type", "")
if zi >= len(zones):
continue
zone_cells = zones[zi].get("cells", [])
# Find the target cell
target_cell = None
for cell in zone_cells:
if cell.get("row_index") == ri and cell.get("col_index") == ci:
target_cell = cell
break
if not target_cell:
continue
old_text = target_cell.get("text", "")
if stype == "spell_fix":
# Replace the last word in the cell text
original_word = s.get("original_text", "")
corrected = s.get("suggested_text", "")
if original_word and corrected:
# Replace from the right (last occurrence)
idx = old_text.rfind(original_word)
if idx >= 0:
new_text = old_text[:idx] + corrected + old_text[idx + len(original_word):]
target_cell["text"] = new_text
changes.append({
"type": "spell_fix",
"zone_index": zi,
"row_index": ri,
"col_index": ci,
"cell_id": target_cell.get("cell_id", ""),
"old_text": old_text,
"new_text": new_text,
})
elif stype == "hyphen_join":
# Current cell: replace last word with the hyphenated first part
original_word = s.get("original_text", "")
joined = s.get("suggested_text", "")
display_parts = s.get("display_parts", [])
next_ri = s.get("next_row_index", -1)
if not original_word or not joined or not display_parts:
continue
# The first display part is what goes in the current row
first_part = display_parts[0] if display_parts else ""
# Replace the last word in current cell with the restored form.
# The next row is NOT modified — "künden" stays in its row
# because the original book layout has it there. We only fix
# the truncated word in the current row (e.g. "ve" → "ver-").
idx = old_text.rfind(original_word)
if idx >= 0:
new_text = old_text[:idx] + first_part + old_text[idx + len(original_word):]
target_cell["text"] = new_text
changes.append({
"type": "hyphen_join",
"zone_index": zi,
"row_index": ri,
"col_index": ci,
"cell_id": target_cell.get("cell_id", ""),
"old_text": old_text,
"new_text": new_text,
"joined_word": joined,
})
logger.info("Gutter repair applied: %d/%d suggestions", len(changes), len(accepted_suggestions))
return {
"applied_count": len(accepted_suggestions),
"changes": changes,
}
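# Typical round-trip (suggestion ids are generated per run):
#   report = analyse_grid_for_gutter_repair(grid)
#   accepted = [s["id"] for s in report["suggestions"]]
#   apply_gutter_suggestions(grid, accepted, report["suggestions"])  # mutates grid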


@@ -0,0 +1,231 @@
"""
Syllable Core — hyphenator init, word validation, pipe autocorrect.
Extracted from cv_syllable_detect.py for modularity.
License: Apache 2.0 (commercial use permitted)
PRIVACY: All processing happens locally.
"""
import logging
import re
from typing import Any, Dict, List, Optional, Tuple
logger = logging.getLogger(__name__)
# IPA/phonetic characters -- skip cells containing these
_IPA_RE = re.compile(r'[\[\]ˈˌːʃʒθðŋɑɒæɔəɛɜɪʊʌ]')
# Common German words that should NOT be merged with adjacent tokens.
_STOP_WORDS = frozenset([
# Articles
'der', 'die', 'das', 'dem', 'den', 'des',
'ein', 'eine', 'einem', 'einen', 'einer',
# Pronouns
'du', 'er', 'es', 'sie', 'wir', 'ihr', 'ich', 'man', 'sich',
'dich', 'dir', 'mich', 'mir', 'uns', 'euch', 'ihm', 'ihn',
# Prepositions
'mit', 'von', 'zu', 'für', 'auf', 'in', 'an', 'um', 'am', 'im',
'aus', 'bei', 'nach', 'vor', 'bis', 'durch', 'über', 'unter',
'zwischen', 'ohne', 'gegen',
# Conjunctions
'und', 'oder', 'als', 'wie', 'wenn', 'dass', 'weil', 'aber',
# Adverbs
'auch', 'noch', 'nur', 'schon', 'sehr', 'nicht',
# Verbs
'ist', 'hat', 'wird', 'kann', 'soll', 'muss', 'darf',
'sein', 'haben',
# Other
'kein', 'keine', 'keinem', 'keinen', 'keiner',
])
# Cached hyphenators
_hyph_de = None
_hyph_en = None
# Cached spellchecker (for autocorrect_pipe_artifacts)
_spell_de = None
def _get_hyphenators():
"""Lazy-load pyphen hyphenators (cached across calls)."""
global _hyph_de, _hyph_en
if _hyph_de is not None:
return _hyph_de, _hyph_en
try:
import pyphen
except ImportError:
return None, None
_hyph_de = pyphen.Pyphen(lang='de_DE')
_hyph_en = pyphen.Pyphen(lang='en_US')
return _hyph_de, _hyph_en
def _get_spellchecker():
"""Lazy-load German spellchecker (cached across calls)."""
global _spell_de
if _spell_de is not None:
return _spell_de
try:
from spellchecker import SpellChecker
except ImportError:
return None
_spell_de = SpellChecker(language='de')
return _spell_de
def _is_known_word(word: str, hyph_de, hyph_en) -> bool:
"""Check whether pyphen recognises a word (DE or EN)."""
if len(word) < 2:
return False
return ('|' in hyph_de.inserted(word, hyphen='|')
or '|' in hyph_en.inserted(word, hyphen='|'))
def _is_real_word(word: str) -> bool:
"""Check whether spellchecker knows this word (case-insensitive)."""
spell = _get_spellchecker()
if spell is None:
return False
return word.lower() in spell
def _hyphenate_word(word: str, hyph_de, hyph_en) -> Optional[str]:
"""Try to hyphenate a word using DE then EN dictionary.
Returns word with | separators, or None if not recognized.
"""
hyph = hyph_de.inserted(word, hyphen='|')
if '|' in hyph:
return hyph
hyph = hyph_en.inserted(word, hyphen='|')
if '|' in hyph:
return hyph
return None
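# Usage sketch (break points come from pyphen's pattern files, so the exact
# output may differ by pyphen version):
#   _hyphenate_word("Kaffee", hyph_de, hyph_en)  # expected: "Kaf|fee"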
def _autocorrect_piped_word(word_with_pipes: str) -> Optional[str]:
"""Try to correct a word that has OCR pipe artifacts.
Printed syllable divider lines on dictionary pages confuse OCR:
the vertical stroke is often read as an extra character (commonly
``l``, ``I``, ``1``, ``i``) adjacent to where the pipe appears.
Uses ``spellchecker`` (frequency-based word list) for validation.
Strategy:
1. Strip ``|`` -- if spellchecker knows the result, done.
2. Try deleting each pipe-like character (l, I, 1, i, t).
3. Fall back to spellchecker's own ``correction()`` method.
4. Preserve the original casing of the first letter.
"""
stripped = word_with_pipes.replace('|', '')
if not stripped or len(stripped) < 3:
return stripped # too short to validate
# Step 1: if the stripped word is already a real word, done
if _is_real_word(stripped):
return stripped
# Step 2: try deleting pipe-like characters (most likely artifacts)
_PIPE_LIKE = frozenset('lI1it')
for idx in range(len(stripped)):
if stripped[idx] not in _PIPE_LIKE:
continue
candidate = stripped[:idx] + stripped[idx + 1:]
if len(candidate) >= 3 and _is_real_word(candidate):
return candidate
# Step 3: use spellchecker's built-in correction
spell = _get_spellchecker()
if spell is not None:
suggestion = spell.correction(stripped.lower())
if suggestion and suggestion != stripped.lower():
# Preserve original first-letter case
if stripped[0].isupper():
suggestion = suggestion[0].upper() + suggestion[1:]
return suggestion
return None # could not fix
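# Worked example matching the docstring above (assumes "Zeppelin" is in the DE
# word list): "Ze|plpe|lin" strips to "Zeplpelin" (unknown), then deleting the
# pipe-like 'l' at index 3 yields "Zeppelin", which validates:
#   _autocorrect_piped_word("Ze|plpe|lin")  # -> "Zeppelin"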
def autocorrect_pipe_artifacts(
zones_data: List[Dict], session_id: str,
) -> int:
"""Strip OCR pipe artifacts and correct garbled words in-place.
Printed syllable divider lines on dictionary scans are read by OCR
as ``|`` characters embedded in words (e.g. ``Zel|le``, ``Ze|plpe|lin``).
This function:
1. Strips ``|`` from every word in content cells.
2. Validates with spellchecker (real dictionary lookup).
3. If not recognised, tries deleting pipe-like characters or uses
spellchecker's correction (e.g. ``Zeplpelin`` -> ``Zeppelin``).
4. Updates both word-box texts and cell text.
Returns the number of cells modified.
"""
spell = _get_spellchecker()
if spell is None:
logger.warning("spellchecker not available -- pipe autocorrect limited")
# Fall back: still strip pipes even without spellchecker
pass
modified = 0
for z in zones_data:
for cell in z.get("cells", []):
ct = cell.get("col_type", "")
if not ct.startswith("column_"):
continue
cell_changed = False
# --- Fix word boxes ---
for wb in cell.get("word_boxes", []):
wb_text = wb.get("text", "")
if "|" not in wb_text:
continue
# Separate trailing punctuation
m = re.match(
r'^([^a-zA-ZäöüÄÖÜßẞ]*)'
r'(.*?)'
r'([^a-zA-ZäöüÄÖÜßẞ]*)$',
wb_text,
)
if not m:
continue
lead, core, trail = m.group(1), m.group(2), m.group(3)
if "|" not in core:
continue
corrected = _autocorrect_piped_word(core)
if corrected is not None and corrected != core:
wb["text"] = lead + corrected + trail
cell_changed = True
# --- Rebuild cell text from word boxes ---
if cell_changed:
wbs = cell.get("word_boxes", [])
if wbs:
cell["text"] = " ".join(
(wb.get("text") or "") for wb in wbs
)
modified += 1
# --- Fallback: strip residual | from cell text ---
text = cell.get("text", "")
if "|" in text:
clean = text.replace("|", "")
if clean != text:
cell["text"] = clean
if not cell_changed:
modified += 1
if modified:
logger.info(
"build-grid session %s: autocorrected pipe artifacts in %d cells",
session_id, modified,
)
return modified
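# Usage sketch (session id is hypothetical); mutates zones_data in place:
#   n_fixed = autocorrect_pipe_artifacts(zones_data, session_id="abc123")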


@@ -1,532 +1,32 @@
"""
Syllable divider insertion for dictionary pages.
Syllable divider insertion for dictionary pages — barrel re-export.
For confirmed dictionary pages (is_dictionary=True), processes all content
column cells:
1. Strips existing | dividers for clean normalization
2. Merges pipe-gap spaces (where OCR split a word at a divider position)
3. Applies pyphen syllabification to each word >= 3 alpha chars (DE then EN)
4. Only modifies words that pyphen recognizes — garbled OCR stays as-is
No CV gate needed — the dictionary detection confidence is sufficient.
pyphen uses Hunspell/TeX hyphenation dictionaries and is very reliable.
The implementation is split into:
cv_syllable_core — hyphenator init, word validation, pipe autocorrect
cv_syllable_merge — word gap merging, syllabification, divider insertion
License: Apache 2.0 (commercial use permitted)
PRIVACY: All processing happens locally.
"""
import logging
import re
from typing import Any, Dict, List, Optional, Tuple
import numpy as np
logger = logging.getLogger(__name__)
# IPA/phonetic characters — skip cells containing these
_IPA_RE = re.compile(r'[\[\]ˈˌːʃʒθðŋɑɒæɔəɛɜɪʊʌ]')
# Common German words that should NOT be merged with adjacent tokens.
# These are function words that appear as standalone words between
# headwords/definitions on dictionary pages.
_STOP_WORDS = frozenset([
# Articles
'der', 'die', 'das', 'dem', 'den', 'des',
'ein', 'eine', 'einem', 'einen', 'einer',
# Pronouns
'du', 'er', 'es', 'sie', 'wir', 'ihr', 'ich', 'man', 'sich',
'dich', 'dir', 'mich', 'mir', 'uns', 'euch', 'ihm', 'ihn',
# Prepositions
'mit', 'von', 'zu', 'für', 'auf', 'in', 'an', 'um', 'am', 'im',
'aus', 'bei', 'nach', 'vor', 'bis', 'durch', 'über', 'unter',
'zwischen', 'ohne', 'gegen',
# Conjunctions
'und', 'oder', 'als', 'wie', 'wenn', 'dass', 'weil', 'aber',
# Adverbs
'auch', 'noch', 'nur', 'schon', 'sehr', 'nicht',
# Verbs
'ist', 'hat', 'wird', 'kann', 'soll', 'muss', 'darf',
'sein', 'haben',
# Other
'kein', 'keine', 'keinem', 'keinen', 'keiner',
])
# Cached hyphenators
_hyph_de = None
_hyph_en = None
# Cached spellchecker (for autocorrect_pipe_artifacts)
_spell_de = None
def _get_hyphenators():
"""Lazy-load pyphen hyphenators (cached across calls)."""
global _hyph_de, _hyph_en
if _hyph_de is not None:
return _hyph_de, _hyph_en
try:
import pyphen
except ImportError:
return None, None
_hyph_de = pyphen.Pyphen(lang='de_DE')
_hyph_en = pyphen.Pyphen(lang='en_US')
return _hyph_de, _hyph_en
def _get_spellchecker():
"""Lazy-load German spellchecker (cached across calls)."""
global _spell_de
if _spell_de is not None:
return _spell_de
try:
from spellchecker import SpellChecker
except ImportError:
return None
_spell_de = SpellChecker(language='de')
return _spell_de
def _is_known_word(word: str, hyph_de, hyph_en) -> bool:
"""Check whether pyphen recognises a word (DE or EN)."""
if len(word) < 2:
return False
return ('|' in hyph_de.inserted(word, hyphen='|')
or '|' in hyph_en.inserted(word, hyphen='|'))
def _is_real_word(word: str) -> bool:
"""Check whether spellchecker knows this word (case-insensitive)."""
spell = _get_spellchecker()
if spell is None:
return False
return word.lower() in spell
def _hyphenate_word(word: str, hyph_de, hyph_en) -> Optional[str]:
"""Try to hyphenate a word using DE then EN dictionary.
Returns word with | separators, or None if not recognized.
"""
hyph = hyph_de.inserted(word, hyphen='|')
if '|' in hyph:
return hyph
hyph = hyph_en.inserted(word, hyphen='|')
if '|' in hyph:
return hyph
return None
def _autocorrect_piped_word(word_with_pipes: str) -> Optional[str]:
"""Try to correct a word that has OCR pipe artifacts.
Printed syllable divider lines on dictionary pages confuse OCR:
the vertical stroke is often read as an extra character (commonly
``l``, ``I``, ``1``, ``i``) adjacent to where the pipe appears.
Sometimes OCR reads one divider as ``|`` and another as a letter,
so the garbled character may be far from any detected pipe.
Uses ``spellchecker`` (frequency-based word list) for validation —
unlike pyphen which is a pattern-based hyphenator and accepts
nonsense strings like "Zeplpelin".
Strategy:
1. Strip ``|`` — if spellchecker knows the result, done.
2. Try deleting each pipe-like character (l, I, 1, i, t).
OCR inserts extra chars that resemble vertical strokes.
3. Fall back to spellchecker's own ``correction()`` method.
4. Preserve the original casing of the first letter.
"""
stripped = word_with_pipes.replace('|', '')
if not stripped or len(stripped) < 3:
return stripped # too short to validate
# Step 1: if the stripped word is already a real word, done
if _is_real_word(stripped):
return stripped
# Step 2: try deleting pipe-like characters (most likely artifacts)
_PIPE_LIKE = frozenset('lI1it')
for idx in range(len(stripped)):
if stripped[idx] not in _PIPE_LIKE:
continue
candidate = stripped[:idx] + stripped[idx + 1:]
if len(candidate) >= 3 and _is_real_word(candidate):
return candidate
# Step 3: use spellchecker's built-in correction
spell = _get_spellchecker()
if spell is not None:
suggestion = spell.correction(stripped.lower())
if suggestion and suggestion != stripped.lower():
# Preserve original first-letter case
if stripped[0].isupper():
suggestion = suggestion[0].upper() + suggestion[1:]
return suggestion
return None # could not fix
def autocorrect_pipe_artifacts(
zones_data: List[Dict], session_id: str,
) -> int:
"""Strip OCR pipe artifacts and correct garbled words in-place.
Printed syllable divider lines on dictionary scans are read by OCR
as ``|`` characters embedded in words (e.g. ``Zel|le``, ``Ze|plpe|lin``).
This function:
1. Strips ``|`` from every word in content cells.
2. Validates with spellchecker (real dictionary lookup).
3. If not recognised, tries deleting pipe-like characters or uses
spellchecker's correction (e.g. ``Zeplpelin`` → ``Zeppelin``).
4. Updates both word-box texts and cell text.
Returns the number of cells modified.
"""
spell = _get_spellchecker()
if spell is None:
logger.warning("spellchecker not available — pipe autocorrect limited")
# Fall back: still strip pipes even without spellchecker
pass
modified = 0
for z in zones_data:
for cell in z.get("cells", []):
ct = cell.get("col_type", "")
if not ct.startswith("column_"):
continue
cell_changed = False
# --- Fix word boxes ---
for wb in cell.get("word_boxes", []):
wb_text = wb.get("text", "")
if "|" not in wb_text:
continue
# Separate trailing punctuation
m = re.match(
r'^([^a-zA-ZäöüÄÖÜßẞ]*)'
r'(.*?)'
r'([^a-zA-ZäöüÄÖÜßẞ]*)$',
wb_text,
)
if not m:
continue
lead, core, trail = m.group(1), m.group(2), m.group(3)
if "|" not in core:
continue
corrected = _autocorrect_piped_word(core)
if corrected is not None and corrected != core:
wb["text"] = lead + corrected + trail
cell_changed = True
# --- Rebuild cell text from word boxes ---
if cell_changed:
wbs = cell.get("word_boxes", [])
if wbs:
cell["text"] = " ".join(
(wb.get("text") or "") for wb in wbs
)
modified += 1
# --- Fallback: strip residual | from cell text ---
# (covers cases where word_boxes don't exist or weren't fixed)
text = cell.get("text", "")
if "|" in text:
clean = text.replace("|", "")
if clean != text:
cell["text"] = clean
if not cell_changed:
modified += 1
if modified:
logger.info(
"build-grid session %s: autocorrected pipe artifacts in %d cells",
session_id, modified,
)
return modified
def _try_merge_pipe_gaps(text: str, hyph_de) -> str:
"""Merge fragments separated by single spaces where OCR split at a pipe.
Example: "Kaf fee" -> "Kaffee" (pyphen recognizes the merged word).
Multi-step: "Ka bel jau" -> "Kabel jau" -> "Kabeljau".
Guards against false merges:
- The FIRST token must be pure alpha (word start — no attached punctuation)
- The second token may have trailing punctuation (comma, period) which
stays attached to the merged word: "Kä" + "fer," -> "Käfer,"
- Common German function words (der, die, das, ...) are never merged
- At least one fragment must be very short (<=3 alpha chars)
"""
parts = text.split(' ')
if len(parts) < 2:
return text
result = [parts[0]]
i = 1
while i < len(parts):
prev = result[-1]
curr = parts[i]
# Extract alpha-only core for lookup
prev_alpha = re.sub(r'[^a-zA-ZäöüÄÖÜßẞ]', '', prev)
curr_alpha = re.sub(r'[^a-zA-ZäöüÄÖÜßẞ]', '', curr)
# Guard 1: first token must be pure alpha (word-start fragment)
# second token may have trailing punctuation
# Guard 2: neither alpha core can be a common German function word
# Guard 3: the shorter fragment must be <= 3 chars (pipe-gap signal)
# Guard 4: combined length must be >= 4
should_try = (
prev == prev_alpha # first token: pure alpha (word start)
and prev_alpha and curr_alpha
and prev_alpha.lower() not in _STOP_WORDS
and curr_alpha.lower() not in _STOP_WORDS
and min(len(prev_alpha), len(curr_alpha)) <= 3
and len(prev_alpha) + len(curr_alpha) >= 4
)
if should_try:
merged_alpha = prev_alpha + curr_alpha
hyph = hyph_de.inserted(merged_alpha, hyphen='-')
if '-' in hyph:
# pyphen recognizes merged word — collapse the space
result[-1] = prev + curr
i += 1
continue
result.append(curr)
i += 1
return ' '.join(result)
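# Guard illustration (assuming pyphen recognises "Kaffee"): "Kaf fee" collapses,
# while "der Hund" is left alone because "der" is a stop word:
#   _try_merge_pipe_gaps("Kaf fee", hyph_de)   # -> "Kaffee"
#   _try_merge_pipe_gaps("der Hund", hyph_de)  # -> "der Hund"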
def merge_word_gaps_in_zones(zones_data: List[Dict], session_id: str) -> int:
"""Merge OCR word-gap fragments in cell texts using pyphen validation.
OCR often splits words at syllable boundaries into separate word_boxes,
producing text like "zerknit tert" instead of "zerknittert". This
function tries to merge adjacent fragments in every content cell.
More permissive than ``_try_merge_pipe_gaps`` (threshold 5 instead of 3)
but still guarded by pyphen dictionary lookup and stop-word exclusion.
Returns the number of cells modified.
"""
hyph_de, _ = _get_hyphenators()
if hyph_de is None:
return 0
modified = 0
for z in zones_data:
for cell in z.get("cells", []):
ct = cell.get("col_type", "")
if not ct.startswith("column_"):
continue
text = cell.get("text", "")
if not text or " " not in text:
continue
# Skip IPA cells
text_no_brackets = re.sub(r'\[[^\]]*\]', '', text)
if _IPA_RE.search(text_no_brackets):
continue
new_text = _try_merge_word_gaps(text, hyph_de)
if new_text != text:
cell["text"] = new_text
modified += 1
if modified:
logger.info(
"build-grid session %s: merged word gaps in %d cells",
session_id, modified,
)
return modified
def _try_merge_word_gaps(text: str, hyph_de) -> str:
"""Merge OCR word fragments with relaxed threshold (max_short=5).
Similar to ``_try_merge_pipe_gaps`` but allows slightly longer fragments
(max_short=5 instead of 3). Still requires pyphen to recognize the
merged word.
"""
parts = text.split(' ')
if len(parts) < 2:
return text
result = [parts[0]]
i = 1
while i < len(parts):
prev = result[-1]
curr = parts[i]
prev_alpha = re.sub(r'[^a-zA-ZäöüÄÖÜßẞ]', '', prev)
curr_alpha = re.sub(r'[^a-zA-ZäöüÄÖÜßẞ]', '', curr)
should_try = (
prev == prev_alpha
and prev_alpha and curr_alpha
and prev_alpha.lower() not in _STOP_WORDS
and curr_alpha.lower() not in _STOP_WORDS
and min(len(prev_alpha), len(curr_alpha)) <= 5
and len(prev_alpha) + len(curr_alpha) >= 4
)
if should_try:
merged_alpha = prev_alpha + curr_alpha
hyph = hyph_de.inserted(merged_alpha, hyphen='-')
if '-' in hyph:
result[-1] = prev + curr
i += 1
continue
result.append(curr)
i += 1
return ' '.join(result)
def _syllabify_text(text: str, hyph_de, hyph_en) -> str:
"""Syllabify all significant words in a text string.
1. Strip existing | dividers
2. Merge pipe-gap spaces where possible
3. Apply pyphen to each word >= 3 alphabetic chars
4. Words pyphen doesn't recognize stay as-is (no bad guesses)
"""
if not text:
return text
# Skip cells that contain IPA transcription characters outside brackets.
# Bracket content like [bɪltʃøn] is programmatically inserted and should
# not block syllabification of the surrounding text.
text_no_brackets = re.sub(r'\[[^\]]*\]', '', text)
if _IPA_RE.search(text_no_brackets):
return text
# Phase 1: strip existing pipe dividers for clean normalization
clean = text.replace('|', '')
# Phase 2: merge pipe-gap spaces (OCR fragments from pipe splitting)
clean = _try_merge_pipe_gaps(clean, hyph_de)
# Phase 3: tokenize and syllabify each word
# Split on whitespace and comma/semicolon sequences, keeping separators
tokens = re.split(r'(\s+|[,;:]+\s*)', clean)
result = []
for tok in tokens:
if not tok or re.match(r'^[\s,;:]+$', tok):
result.append(tok)
continue
# Strip trailing/leading punctuation for pyphen lookup
m = re.match(r'^([^a-zA-ZäöüÄÖÜßẞ]*)(.*?)([^a-zA-ZäöüÄÖÜßẞ]*)$', tok)
if not m:
result.append(tok)
continue
lead, word, trail = m.group(1), m.group(2), m.group(3)
if len(word) < 3 or not re.search(r'[a-zA-ZäöüÄÖÜß]', word):
result.append(tok)
continue
hyph = _hyphenate_word(word, hyph_de, hyph_en)
if hyph:
result.append(lead + hyph + trail)
else:
result.append(tok)
return ''.join(result)
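# End-to-end sketch (assuming pyphen hyphenates "Kaffee" as "Kaf|fee" and finds
# no break point in "Tee"):
#   _syllabify_text("Kaf fee, Tee", hyph_de, hyph_en)  # -> "Kaf|fee, Tee"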
def insert_syllable_dividers(
zones_data: List[Dict],
img_bgr: np.ndarray,
session_id: str,
*,
force: bool = False,
col_filter: Optional[set] = None,
) -> int:
"""Insert pipe syllable dividers into dictionary cells.
For dictionary pages: process all content column cells, strip existing
pipes, merge pipe-gap spaces, and re-syllabify using pyphen.
Pre-check: at least 1% of content cells must already contain ``|`` from
OCR. This guards against pages with zero pipe characters (the primary
guard — article_col_index — is checked at the call site).
Args:
force: If True, skip the pipe-ratio pre-check and syllabify all
content words regardless of whether the original has pipe dividers.
col_filter: If set, only process cells whose col_type is in this set.
None means process all content columns.
Returns the number of cells modified.
"""
hyph_de, hyph_en = _get_hyphenators()
if hyph_de is None:
logger.warning("pyphen not installed — skipping syllable insertion")
return 0
# Pre-check: count cells that already have | from OCR.
# Real dictionary pages with printed syllable dividers will have OCR-
# detected pipes in many cells. Pages without syllable dividers will
# have zero — skip those to avoid false syllabification.
if not force:
total_col_cells = 0
cells_with_pipes = 0
for z in zones_data:
for cell in z.get("cells", []):
if cell.get("col_type", "").startswith("column_"):
total_col_cells += 1
if "|" in cell.get("text", ""):
cells_with_pipes += 1
if total_col_cells > 0:
pipe_ratio = cells_with_pipes / total_col_cells
if pipe_ratio < 0.01:
logger.info(
"build-grid session %s: skipping syllable insertion — "
"only %.1f%% of cells have existing pipes (need >=1%%)",
session_id, pipe_ratio * 100,
)
return 0
insertions = 0
for z in zones_data:
for cell in z.get("cells", []):
ct = cell.get("col_type", "")
if not ct.startswith("column_"):
continue
if col_filter is not None and ct not in col_filter:
continue
text = cell.get("text", "")
if not text:
continue
# In auto mode (force=False), only normalize cells that already
# have | from OCR (i.e. printed syllable dividers on the original
# scan). Don't add new syllable marks to other words.
if not force and "|" not in text:
continue
new_text = _syllabify_text(text, hyph_de, hyph_en)
if new_text != text:
cell["text"] = new_text
insertions += 1
if insertions:
logger.info(
"build-grid session %s: syllable dividers inserted/normalized "
"in %d cells (pyphen)",
session_id, insertions,
)
return insertions
# Core: init, validation, autocorrect
from cv_syllable_core import ( # noqa: F401
_IPA_RE,
_STOP_WORDS,
_get_hyphenators,
_get_spellchecker,
_is_known_word,
_is_real_word,
_hyphenate_word,
_autocorrect_piped_word,
autocorrect_pipe_artifacts,
)
# Merge: gap merging, syllabify, insert
from cv_syllable_merge import ( # noqa: F401
_try_merge_pipe_gaps,
merge_word_gaps_in_zones,
_try_merge_word_gaps,
_syllabify_text,
insert_syllable_dividers,
)
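The barrel keeps the public names importable from the original module path, so existing call sites need no changes; a minimal sketch of an unchanged import:

    # still resolves through the re-exports above
    from cv_syllable_detect import insert_syllable_dividers, autocorrect_pipe_artifacts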

View File

@@ -0,0 +1,300 @@
"""
Syllable Merge — word gap merging, syllabification, divider insertion.
Extracted from cv_syllable_detect.py for modularity.
Lizenz: Apache 2.0 (kommerziell nutzbar)
DATENSCHUTZ: Alle Verarbeitung erfolgt lokal.
"""
import logging
import re
from typing import Any, Dict, List, Optional
import numpy as np
from cv_syllable_core import (
_get_hyphenators,
_hyphenate_word,
_IPA_RE,
_STOP_WORDS,
)
logger = logging.getLogger(__name__)
def _try_merge_pipe_gaps(text: str, hyph_de) -> str:
"""Merge fragments separated by single spaces where OCR split at a pipe.
Example: "Kaf fee" -> "Kaffee" (pyphen recognizes the merged word).
Multi-step: "Ka bel jau" -> "Kabel jau" -> "Kabeljau".
Guards against false merges:
- The FIRST token must be pure alpha (word start -- no attached punctuation)
- The second token may have trailing punctuation (comma, period) which
stays attached to the merged word: "Ka" + "fer," -> "Kafer,"
- Common German function words (der, die, das, ...) are never merged
- At least one fragment must be very short (<=3 alpha chars)
"""
parts = text.split(' ')
if len(parts) < 2:
return text
result = [parts[0]]
i = 1
while i < len(parts):
prev = result[-1]
curr = parts[i]
# Extract alpha-only core for lookup
prev_alpha = re.sub(r'[^a-zA-Z\u00e4\u00f6\u00fc\u00c4\u00d6\u00dc\u00df\u1e9e]', '', prev)
curr_alpha = re.sub(r'[^a-zA-Z\u00e4\u00f6\u00fc\u00c4\u00d6\u00dc\u00df\u1e9e]', '', curr)
# Guard 1: first token must be pure alpha (word-start fragment)
# second token may have trailing punctuation
# Guard 2: neither alpha core can be a common German function word
# Guard 3: the shorter fragment must be <= 3 chars (pipe-gap signal)
# Guard 4: combined length must be >= 4
should_try = (
prev == prev_alpha # first token: pure alpha (word start)
and prev_alpha and curr_alpha
and prev_alpha.lower() not in _STOP_WORDS
and curr_alpha.lower() not in _STOP_WORDS
and min(len(prev_alpha), len(curr_alpha)) <= 3
and len(prev_alpha) + len(curr_alpha) >= 4
)
if should_try:
merged_alpha = prev_alpha + curr_alpha
hyph = hyph_de.inserted(merged_alpha, hyphen='-')
if '-' in hyph:
# pyphen recognizes merged word -- collapse the space
result[-1] = prev + curr
i += 1
continue
result.append(curr)
i += 1
return ' '.join(result)
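A minimal sketch of the merge, assuming pyphen is installed with the de_DE dictionary and that _STOP_WORDS (from cv_syllable_core) contains common function words such as "und":

    import pyphen

    hyph_de = pyphen.Pyphen(lang="de_DE")
    _try_merge_pipe_gaps("Kaf fee und Kuchen", hyph_de)
    # -> "Kaffee und Kuchen": "Kaf" + "fee" collapses because pyphen
    #    hyphenates the merged word; "und" is a stop word and never merges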
def merge_word_gaps_in_zones(zones_data: List[Dict], session_id: str) -> int:
"""Merge OCR word-gap fragments in cell texts using pyphen validation.
OCR often splits words at syllable boundaries into separate word_boxes,
producing text like "zerknit tert" instead of "zerknittert". This
function tries to merge adjacent fragments in every content cell.
More permissive than ``_try_merge_pipe_gaps`` (threshold 5 instead of 3)
but still guarded by pyphen dictionary lookup and stop-word exclusion.
Returns the number of cells modified.
"""
hyph_de, _ = _get_hyphenators()
if hyph_de is None:
return 0
modified = 0
for z in zones_data:
for cell in z.get("cells", []):
ct = cell.get("col_type", "")
if not ct.startswith("column_"):
continue
text = cell.get("text", "")
if not text or " " not in text:
continue
# Skip IPA cells
text_no_brackets = re.sub(r'\[[^\]]*\]', '', text)
if _IPA_RE.search(text_no_brackets):
continue
new_text = _try_merge_word_gaps(text, hyph_de)
if new_text != text:
cell["text"] = new_text
modified += 1
if modified:
logger.info(
"build-grid session %s: merged word gaps in %d cells",
session_id, modified,
)
return modified
def _try_merge_word_gaps(text: str, hyph_de) -> str:
"""Merge OCR word fragments with relaxed threshold (max_short=5).
Similar to ``_try_merge_pipe_gaps`` but allows slightly longer fragments
(max_short=5 instead of 3). Still requires pyphen to recognize the
merged word.
"""
parts = text.split(' ')
if len(parts) < 2:
return text
result = [parts[0]]
i = 1
while i < len(parts):
prev = result[-1]
curr = parts[i]
prev_alpha = re.sub(r'[^a-zA-Z\u00e4\u00f6\u00fc\u00c4\u00d6\u00dc\u00df\u1e9e]', '', prev)
curr_alpha = re.sub(r'[^a-zA-Z\u00e4\u00f6\u00fc\u00c4\u00d6\u00dc\u00df\u1e9e]', '', curr)
should_try = (
prev == prev_alpha
and prev_alpha and curr_alpha
and prev_alpha.lower() not in _STOP_WORDS
and curr_alpha.lower() not in _STOP_WORDS
and min(len(prev_alpha), len(curr_alpha)) <= 5
and len(prev_alpha) + len(curr_alpha) >= 4
)
if should_try:
merged_alpha = prev_alpha + curr_alpha
hyph = hyph_de.inserted(merged_alpha, hyphen='-')
if '-' in hyph:
result[-1] = prev + curr
i += 1
continue
result.append(curr)
i += 1
return ' '.join(result)
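The relaxed threshold covers syllable-boundary splits that the pipe-gap pass leaves alone; a sketch under the same pyphen assumptions as above:

    _try_merge_word_gaps("zerknit tert", hyph_de)
    # -> "zerknittert": "tert" has 4 alpha chars, over the pipe-gap limit
    #    of 3 but within the relaxed limit of 5, assuming pyphen finds a
    #    hyphenation point in the merged word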
def _syllabify_text(text: str, hyph_de, hyph_en) -> str:
"""Syllabify all significant words in a text string.
1. Strip existing | dividers
2. Merge pipe-gap spaces where possible
3. Apply pyphen to each word >= 3 alphabetic chars
4. Words pyphen doesn't recognize stay as-is (no bad guesses)
"""
if not text:
return text
# Skip cells that contain IPA transcription characters outside brackets.
text_no_brackets = re.sub(r'\[[^\]]*\]', '', text)
if _IPA_RE.search(text_no_brackets):
return text
# Phase 1: strip existing pipe dividers for clean normalization
clean = text.replace('|', '')
# Phase 2: merge pipe-gap spaces (OCR fragments from pipe splitting)
clean = _try_merge_pipe_gaps(clean, hyph_de)
# Phase 3: tokenize and syllabify each word
# Split on whitespace and comma/semicolon sequences, keeping separators
tokens = re.split(r'(\s+|[,;:]+\s*)', clean)
result = []
for tok in tokens:
if not tok or re.match(r'^[\s,;:]+$', tok):
result.append(tok)
continue
# Strip trailing/leading punctuation for pyphen lookup
m = re.match(r'^([^a-zA-Z\u00e4\u00f6\u00fc\u00c4\u00d6\u00dc\u00df\u1e9e]*)(.*?)([^a-zA-Z\u00e4\u00f6\u00fc\u00c4\u00d6\u00dc\u00df\u1e9e]*)$', tok)
if not m:
result.append(tok)
continue
lead, word, trail = m.group(1), m.group(2), m.group(3)
if len(word) < 3 or not re.search(r'[a-zA-Z\u00e4\u00f6\u00fc\u00c4\u00d6\u00dc\u00df]', word):
result.append(tok)
continue
hyph = _hyphenate_word(word, hyph_de, hyph_en)
if hyph:
result.append(lead + hyph + trail)
else:
result.append(tok)
return ''.join(result)
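A sketch of the three phases, assuming _hyphenate_word (from cv_syllable_core) returns e.g. "Kaf|fee" for a word pyphen recognizes and a falsy value otherwise:

    _syllabify_text("Kaf fee, Tee", hyph_de, hyph_en)
    # phase 2 merges the fragments to "Kaffee"; phase 3 re-syllabifies it
    # to "Kaf|fee"; "Tee" gets no divider and stays as-is
    # -> "Kaf|fee, Tee"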
def insert_syllable_dividers(
zones_data: List[Dict],
img_bgr: np.ndarray,
session_id: str,
*,
force: bool = False,
col_filter: Optional[set] = None,
) -> int:
"""Insert pipe syllable dividers into dictionary cells.
For dictionary pages: process all content column cells, strip existing
pipes, merge pipe-gap spaces, and re-syllabify using pyphen.
Pre-check: at least 1% of content cells must already contain ``|`` from
OCR. This guards against pages with zero pipe characters.
Args:
force: If True, skip the pipe-ratio pre-check and syllabify all
content words regardless of whether the original has pipe dividers.
col_filter: If set, only process cells whose col_type is in this set.
None means process all content columns.
Returns the number of cells modified.
"""
hyph_de, hyph_en = _get_hyphenators()
if hyph_de is None:
logger.warning("pyphen not installed -- skipping syllable insertion")
return 0
# Pre-check: count cells that already have | from OCR.
if not force:
total_col_cells = 0
cells_with_pipes = 0
for z in zones_data:
for cell in z.get("cells", []):
if cell.get("col_type", "").startswith("column_"):
total_col_cells += 1
if "|" in cell.get("text", ""):
cells_with_pipes += 1
if total_col_cells > 0:
pipe_ratio = cells_with_pipes / total_col_cells
if pipe_ratio < 0.01:
logger.info(
"build-grid session %s: skipping syllable insertion -- "
"only %.1f%% of cells have existing pipes (need >=1%%)",
session_id, pipe_ratio * 100,
)
return 0
insertions = 0
for z in zones_data:
for cell in z.get("cells", []):
ct = cell.get("col_type", "")
if not ct.startswith("column_"):
continue
if col_filter is not None and ct not in col_filter:
continue
text = cell.get("text", "")
if not text:
continue
# In auto mode (force=False), only normalize cells that already
# have | from OCR (i.e. printed syllable dividers on the original
# scan). Don't add new syllable marks to other words.
if not force and "|" not in text:
continue
new_text = _syllabify_text(text, hyph_de, hyph_en)
if new_text != text:
cell["text"] = new_text
insertions += 1
if insertions:
logger.info(
"build-grid session %s: syllable dividers inserted/normalized "
"in %d cells (pyphen)",
session_id, insertions,
)
return insertions
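A usage sketch with a hypothetical zones_data shape inferred from the field accesses above (cells, col_type, text); the image argument is a placeholder here:

    import numpy as np

    zones = [{"cells": [
        {"col_type": "column_1", "text": "Kaf|fee"},   # content cell with an OCR pipe
        {"col_type": "header", "text": "Wörterbuch"},  # skipped: not a content column
    ]}]
    img = np.zeros((10, 10, 3), dtype=np.uint8)  # stand-in for the page image
    insert_syllable_dividers(zones, img, "demo-session")
    # returns the number of cells whose text changed (0 if already normalized)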

View File

@@ -1,52 +1,27 @@
"""
Mail Aggregator Service
Mail Aggregator Service — barrel re-export.
All implementation split into:
aggregator_imap — IMAP connection, sync, email parsing
aggregator_smtp — SMTP connection, email sending
Multi-account IMAP aggregation with async support.
"""
import os
import ssl
import email
import asyncio
import logging
import smtplib
from typing import Optional, List, Dict, Any, Tuple
from datetime import datetime, timezone
from email.mime.text import MIMEText
from email.mime.multipart import MIMEMultipart
from email.header import decode_header, make_header
from email.utils import parsedate_to_datetime, parseaddr
from typing import Optional, List, Dict, Any
from .credentials import get_credentials_service, MailCredentials
from .mail_db import (
get_email_accounts,
get_email_account,
update_account_status,
upsert_email,
get_unified_inbox,
)
from .models import (
AccountStatus,
AccountTestResult,
AggregatedEmail,
EmailComposeRequest,
EmailSendResult,
)
from .credentials import get_credentials_service
from .mail_db import get_email_accounts, get_unified_inbox
from .models import AccountTestResult
from .aggregator_imap import IMAPMixin, IMAPConnectionError
from .aggregator_smtp import SMTPMixin, SMTPConnectionError
logger = logging.getLogger(__name__)
class IMAPConnectionError(Exception):
"""Raised when IMAP connection fails."""
pass
class SMTPConnectionError(Exception):
"""Raised when SMTP connection fails."""
pass
class MailAggregator:
class MailAggregator(IMAPMixin, SMTPMixin):
"""
Aggregates emails from multiple IMAP accounts into a unified inbox.
@@ -86,390 +61,29 @@ class MailAggregator:
)
# Test IMAP
try:
import imaplib
if imap_ssl:
imap = imaplib.IMAP4_SSL(imap_host, imap_port)
else:
imap = imaplib.IMAP4(imap_host, imap_port)
imap.login(email_address, password)
result.imap_connected = True
# List folders
status, folders = imap.list()
if status == "OK":
result.folders_found = [
self._parse_folder_name(f) for f in folders if f
]
imap.logout()
except Exception as e:
result.error_message = f"IMAP Error: {str(e)}"
logger.warning(f"IMAP test failed for {email_address}: {e}")
imap_ok, imap_err, folders = await self.test_imap_connection(
imap_host, imap_port, imap_ssl, email_address, password
)
result.imap_connected = imap_ok
if folders:
result.folders_found = folders
if imap_err:
result.error_message = imap_err
# Test SMTP
try:
if smtp_ssl:
smtp = smtplib.SMTP_SSL(smtp_host, smtp_port)
else:
smtp = smtplib.SMTP(smtp_host, smtp_port)
smtp.starttls()
smtp.login(email_address, password)
result.smtp_connected = True
smtp.quit()
except Exception as e:
smtp_error = f"SMTP Error: {str(e)}"
smtp_ok, smtp_err = await self.test_smtp_connection(
smtp_host, smtp_port, smtp_ssl, email_address, password
)
result.smtp_connected = smtp_ok
if smtp_err:
if result.error_message:
result.error_message += f"; {smtp_err}"
else:
result.error_message = smtp_err
result.success = result.imap_connected and result.smtp_connected
return result
def _parse_folder_name(self, folder_response: bytes) -> str:
"""Parse folder name from IMAP LIST response."""
try:
# Format: '(\\HasNoChildren) "/" "INBOX"'
decoded = folder_response.decode("utf-8") if isinstance(folder_response, bytes) else folder_response
parts = decoded.rsplit('" "', 1)
if len(parts) == 2:
return parts[1].rstrip('"')
return decoded
except Exception:
return str(folder_response)
async def sync_account(
self,
account_id: str,
user_id: str,
max_emails: int = 100,
folders: Optional[List[str]] = None,
) -> Tuple[int, int]:
"""
Sync emails from an IMAP account.
Args:
account_id: The account ID
user_id: The user ID
max_emails: Maximum emails to fetch
folders: Specific folders to sync (default: INBOX)
Returns:
Tuple of (new_emails, total_emails)
"""
import imaplib
account = await get_email_account(account_id, user_id)
if not account:
raise ValueError(f"Account not found: {account_id}")
# Get credentials
vault_path = account.get("vault_path", "")
creds = await self._credentials_service.get_credentials(account_id, vault_path)
if not creds:
await update_account_status(account_id, "error", "Credentials not found")
raise IMAPConnectionError("Credentials not found")
new_count = 0
total_count = 0
try:
# Connect to IMAP
if account["imap_ssl"]:
imap = imaplib.IMAP4_SSL(account["imap_host"], account["imap_port"])
else:
imap = imaplib.IMAP4(account["imap_host"], account["imap_port"])
imap.login(creds.email, creds.password)
# Sync specified folders or just INBOX
sync_folders = folders or ["INBOX"]
for folder in sync_folders:
try:
status, _ = imap.select(folder)
if status != "OK":
continue
# Search for recent emails
status, messages = imap.search(None, "ALL")
if status != "OK":
continue
message_ids = messages[0].split()
total_count += len(message_ids)
# Fetch most recent emails
recent_ids = message_ids[-max_emails:] if len(message_ids) > max_emails else message_ids
for msg_id in recent_ids:
try:
email_data = await self._fetch_and_store_email(
imap, msg_id, account_id, user_id, account["tenant_id"], folder
)
if email_data:
new_count += 1
except Exception as e:
logger.warning(f"Failed to fetch email {msg_id}: {e}")
except Exception as e:
logger.warning(f"Failed to sync folder {folder}: {e}")
imap.logout()
# Update account status
await update_account_status(
account_id,
"active",
email_count=total_count,
unread_count=new_count, # Will be recalculated
)
return new_count, total_count
except Exception as e:
logger.error(f"Account sync failed: {e}")
await update_account_status(account_id, "error", str(e))
raise IMAPConnectionError(str(e))
async def _fetch_and_store_email(
self,
imap,
msg_id: bytes,
account_id: str,
user_id: str,
tenant_id: str,
folder: str,
) -> Optional[str]:
"""Fetch a single email and store it in the database."""
try:
status, msg_data = imap.fetch(msg_id, "(RFC822)")
if status != "OK" or not msg_data or not msg_data[0]:
return None
raw_email = msg_data[0][1]
msg = email.message_from_bytes(raw_email)
# Parse headers
message_id = msg.get("Message-ID", str(msg_id))
subject = self._decode_header(msg.get("Subject", ""))
from_header = msg.get("From", "")
sender_name, sender_email = parseaddr(from_header)
sender_name = self._decode_header(sender_name)
# Parse recipients
to_header = msg.get("To", "")
recipients = [addr[1] for addr in email.utils.getaddresses([to_header])]
cc_header = msg.get("Cc", "")
cc = [addr[1] for addr in email.utils.getaddresses([cc_header])]
# Parse dates
date_str = msg.get("Date")
try:
date_sent = parsedate_to_datetime(date_str) if date_str else datetime.now(timezone.utc)
except Exception:
date_sent = datetime.now(timezone.utc)
date_received = datetime.now(timezone.utc)
# Parse body
body_text, body_html, attachments = self._parse_body(msg)
# Create preview
body_preview = (body_text[:200] + "...") if body_text and len(body_text) > 200 else body_text
# Get headers dict
headers = {k: self._decode_header(v) for k, v in msg.items() if k not in ["Body"]}
# Store in database
email_id = await upsert_email(
account_id=account_id,
user_id=user_id,
tenant_id=tenant_id,
message_id=message_id,
subject=subject,
sender_email=sender_email,
sender_name=sender_name,
recipients=recipients,
cc=cc,
body_preview=body_preview,
body_text=body_text,
body_html=body_html,
has_attachments=len(attachments) > 0,
attachments=attachments,
headers=headers,
folder=folder,
date_sent=date_sent,
date_received=date_received,
)
return email_id
except Exception as e:
logger.error(f"Failed to parse email: {e}")
return None
def _decode_header(self, header_value: str) -> str:
"""Decode email header value."""
if not header_value:
return ""
try:
decoded = decode_header(header_value)
return str(make_header(decoded))
except Exception:
return str(header_value)
def _parse_body(self, msg) -> Tuple[Optional[str], Optional[str], List[Dict]]:
"""
Parse email body and attachments.
Returns:
Tuple of (body_text, body_html, attachments)
"""
body_text = None
body_html = None
attachments = []
if msg.is_multipart():
for part in msg.walk():
content_type = part.get_content_type()
content_disposition = str(part.get("Content-Disposition", ""))
# Skip multipart containers
if content_type.startswith("multipart/"):
continue
# Check for attachments
if "attachment" in content_disposition:
filename = part.get_filename()
if filename:
attachments.append({
"filename": self._decode_header(filename),
"content_type": content_type,
"size": len(part.get_payload(decode=True) or b""),
})
continue
# Get body content
try:
payload = part.get_payload(decode=True)
charset = part.get_content_charset() or "utf-8"
if payload:
text = payload.decode(charset, errors="replace")
if content_type == "text/plain" and not body_text:
body_text = text
elif content_type == "text/html" and not body_html:
body_html = text
except Exception as e:
logger.debug(f"Failed to decode body part: {e}")
else:
# Single part message
content_type = msg.get_content_type()
try:
payload = msg.get_payload(decode=True)
charset = msg.get_content_charset() or "utf-8"
if payload:
text = payload.decode(charset, errors="replace")
if content_type == "text/plain":
body_text = text
elif content_type == "text/html":
body_html = text
except Exception as e:
logger.debug(f"Failed to decode body: {e}")
return body_text, body_html, attachments
async def send_email(
self,
account_id: str,
user_id: str,
request: EmailComposeRequest,
) -> EmailSendResult:
"""
Send an email via SMTP.
Args:
account_id: The account to send from
user_id: The user ID
request: The compose request with recipients and content
Returns:
EmailSendResult with success status
"""
account = await get_email_account(account_id, user_id)
if not account:
return EmailSendResult(success=False, error="Account not found")
# Verify the account_id matches
if request.account_id != account_id:
return EmailSendResult(success=False, error="Account mismatch")
# Get credentials
vault_path = account.get("vault_path", "")
creds = await self._credentials_service.get_credentials(account_id, vault_path)
if not creds:
return EmailSendResult(success=False, error="Credentials not found")
try:
# Create message
if request.is_html:
msg = MIMEMultipart("alternative")
msg.attach(MIMEText(request.body, "html"))
else:
msg = MIMEText(request.body, "plain")
msg["Subject"] = request.subject
msg["From"] = account["email"]
msg["To"] = ", ".join(request.to)
if request.cc:
msg["Cc"] = ", ".join(request.cc)
if request.reply_to_message_id:
msg["In-Reply-To"] = request.reply_to_message_id
msg["References"] = request.reply_to_message_id
# Send via SMTP
if account["smtp_ssl"]:
smtp = smtplib.SMTP_SSL(account["smtp_host"], account["smtp_port"])
else:
smtp = smtplib.SMTP(account["smtp_host"], account["smtp_port"])
smtp.starttls()
smtp.login(creds.email, creds.password)
# All recipients
all_recipients = list(request.to)
if request.cc:
all_recipients.extend(request.cc)
if request.bcc:
all_recipients.extend(request.bcc)
smtp.sendmail(account["email"], all_recipients, msg.as_string())
smtp.quit()
return EmailSendResult(
success=True,
message_id=msg.get("Message-ID"),
)
except Exception as e:
logger.error(f"Failed to send email: {e}")
return EmailSendResult(success=False, error=str(e))
async def sync_all_accounts(self, user_id: str, tenant_id: Optional[str] = None) -> Dict[str, Any]:
"""
Sync all accounts for a user.

View File

@@ -0,0 +1,322 @@
"""
Mail Aggregator IMAP — IMAP connection, sync, email parsing.
Extracted from aggregator.py for modularity.
"""
import email
import logging
from typing import Optional, List, Dict, Any, Tuple
from datetime import datetime, timezone
from email.header import decode_header, make_header
from email.utils import parsedate_to_datetime, parseaddr
from .mail_db import upsert_email, update_account_status, get_email_account
logger = logging.getLogger(__name__)
class IMAPConnectionError(Exception):
"""Raised when IMAP connection fails."""
pass
class IMAPMixin:
"""IMAP-related methods for MailAggregator.
Provides connection testing, syncing, and email parsing.
Must be mixed into a class that has ``_credentials_service``.
"""
def _parse_folder_name(self, folder_response: bytes) -> str:
"""Parse folder name from IMAP LIST response."""
try:
# Format: '(\\HasNoChildren) "/" "INBOX"'
decoded = folder_response.decode("utf-8") if isinstance(folder_response, bytes) else folder_response
parts = decoded.rsplit('" "', 1)
if len(parts) == 2:
return parts[1].rstrip('"')
return decoded
except Exception:
return str(folder_response)
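A worked example of the LIST parsing, traced through the code above:

    # On an IMAPMixin instance:
    #   self._parse_folder_name(b'(\\HasNoChildren) "/" "INBOX"')  ->  "INBOX"
    # (decode, rsplit once on '" "', then strip the trailing quote)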
async def test_imap_connection(
self,
imap_host: str,
imap_port: int,
imap_ssl: bool,
email_address: str,
password: str,
) -> Tuple[bool, Optional[str], Optional[List[str]]]:
"""Test IMAP connection. Returns (success, error, folders)."""
try:
import imaplib
if imap_ssl:
imap = imaplib.IMAP4_SSL(imap_host, imap_port)
else:
imap = imaplib.IMAP4(imap_host, imap_port)
imap.login(email_address, password)
# List folders
folders_found = None
status, folders = imap.list()
if status == "OK":
folders_found = [
self._parse_folder_name(f) for f in folders if f
]
imap.logout()
return True, None, folders_found
except Exception as e:
logger.warning(f"IMAP test failed for {email_address}: {e}")
return False, f"IMAP Error: {str(e)}", None
async def sync_account(
self,
account_id: str,
user_id: str,
max_emails: int = 100,
folders: Optional[List[str]] = None,
) -> Tuple[int, int]:
"""
Sync emails from an IMAP account.
Args:
account_id: The account ID
user_id: The user ID
max_emails: Maximum emails to fetch
folders: Specific folders to sync (default: INBOX)
Returns:
Tuple of (new_emails, total_emails)
"""
import imaplib
account = await get_email_account(account_id, user_id)
if not account:
raise ValueError(f"Account not found: {account_id}")
# Get credentials
vault_path = account.get("vault_path", "")
creds = await self._credentials_service.get_credentials(account_id, vault_path)
if not creds:
await update_account_status(account_id, "error", "Credentials not found")
raise IMAPConnectionError("Credentials not found")
new_count = 0
total_count = 0
try:
# Connect to IMAP
if account["imap_ssl"]:
imap = imaplib.IMAP4_SSL(account["imap_host"], account["imap_port"])
else:
imap = imaplib.IMAP4(account["imap_host"], account["imap_port"])
imap.login(creds.email, creds.password)
# Sync specified folders or just INBOX
sync_folders = folders or ["INBOX"]
for folder in sync_folders:
try:
status, _ = imap.select(folder)
if status != "OK":
continue
# Search for recent emails
status, messages = imap.search(None, "ALL")
if status != "OK":
continue
message_ids = messages[0].split()
total_count += len(message_ids)
# Fetch most recent emails
recent_ids = message_ids[-max_emails:] if len(message_ids) > max_emails else message_ids
for msg_id in recent_ids:
try:
email_data = await self._fetch_and_store_email(
imap, msg_id, account_id, user_id, account["tenant_id"], folder
)
if email_data:
new_count += 1
except Exception as e:
logger.warning(f"Failed to fetch email {msg_id}: {e}")
except Exception as e:
logger.warning(f"Failed to sync folder {folder}: {e}")
imap.logout()
# Update account status
await update_account_status(
account_id,
"active",
email_count=total_count,
unread_count=new_count, # Will be recalculated
)
return new_count, total_count
except Exception as e:
logger.error(f"Account sync failed: {e}")
await update_account_status(account_id, "error", str(e))
raise IMAPConnectionError(str(e))
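A hypothetical call from a sync route, with aggregator being an initialized MailAggregator instance:

    new_emails, total_emails = await aggregator.sync_account(
        account_id="acc-123", user_id="user-1", max_emails=50,
    )
    # raises IMAPConnectionError if credentials are missing or login fails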
async def _fetch_and_store_email(
self,
imap,
msg_id: bytes,
account_id: str,
user_id: str,
tenant_id: str,
folder: str,
) -> Optional[str]:
"""Fetch a single email and store it in the database."""
try:
status, msg_data = imap.fetch(msg_id, "(RFC822)")
if status != "OK" or not msg_data or not msg_data[0]:
return None
raw_email = msg_data[0][1]
msg = email.message_from_bytes(raw_email)
# Parse headers
message_id = msg.get("Message-ID", str(msg_id))
subject = self._decode_header(msg.get("Subject", ""))
from_header = msg.get("From", "")
sender_name, sender_email = parseaddr(from_header)
sender_name = self._decode_header(sender_name)
# Parse recipients
to_header = msg.get("To", "")
recipients = [addr[1] for addr in email.utils.getaddresses([to_header])]
cc_header = msg.get("Cc", "")
cc = [addr[1] for addr in email.utils.getaddresses([cc_header])]
# Parse dates
date_str = msg.get("Date")
try:
date_sent = parsedate_to_datetime(date_str) if date_str else datetime.now(timezone.utc)
except Exception:
date_sent = datetime.now(timezone.utc)
date_received = datetime.now(timezone.utc)
# Parse body
body_text, body_html, attachments = self._parse_body(msg)
# Create preview
body_preview = (body_text[:200] + "...") if body_text and len(body_text) > 200 else body_text
# Get headers dict
headers = {k: self._decode_header(v) for k, v in msg.items() if k not in ["Body"]}
# Store in database
email_id = await upsert_email(
account_id=account_id,
user_id=user_id,
tenant_id=tenant_id,
message_id=message_id,
subject=subject,
sender_email=sender_email,
sender_name=sender_name,
recipients=recipients,
cc=cc,
body_preview=body_preview,
body_text=body_text,
body_html=body_html,
has_attachments=len(attachments) > 0,
attachments=attachments,
headers=headers,
folder=folder,
date_sent=date_sent,
date_received=date_received,
)
return email_id
except Exception as e:
logger.error(f"Failed to parse email: {e}")
return None
def _decode_header(self, header_value: str) -> str:
"""Decode email header value."""
if not header_value:
return ""
try:
decoded = decode_header(header_value)
return str(make_header(decoded))
except Exception:
return str(header_value)
def _parse_body(self, msg) -> Tuple[Optional[str], Optional[str], List[Dict]]:
"""
Parse email body and attachments.
Returns:
Tuple of (body_text, body_html, attachments)
"""
body_text = None
body_html = None
attachments = []
if msg.is_multipart():
for part in msg.walk():
content_type = part.get_content_type()
content_disposition = str(part.get("Content-Disposition", ""))
# Skip multipart containers
if content_type.startswith("multipart/"):
continue
# Check for attachments
if "attachment" in content_disposition:
filename = part.get_filename()
if filename:
attachments.append({
"filename": self._decode_header(filename),
"content_type": content_type,
"size": len(part.get_payload(decode=True) or b""),
})
continue
# Get body content
try:
payload = part.get_payload(decode=True)
charset = part.get_content_charset() or "utf-8"
if payload:
text = payload.decode(charset, errors="replace")
if content_type == "text/plain" and not body_text:
body_text = text
elif content_type == "text/html" and not body_html:
body_html = text
except Exception as e:
logger.debug(f"Failed to decode body part: {e}")
else:
# Single part message
content_type = msg.get_content_type()
try:
payload = msg.get_payload(decode=True)
charset = msg.get_content_charset() or "utf-8"
if payload:
text = payload.decode(charset, errors="replace")
if content_type == "text/plain":
body_text = text
elif content_type == "text/html":
body_html = text
except Exception as e:
logger.debug(f"Failed to decode body: {e}")
return body_text, body_html, attachments
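A minimal sketch of the composition contract, assuming both mixin modules are importable; per the docstrings, the host class only has to provide _credentials_service:

    class DemoAggregator(IMAPMixin, SMTPMixin):  # hypothetical host class
        def __init__(self, credentials_service):
            self._credentials_service = credentials_service  # required by both mixins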

View File

@@ -0,0 +1,131 @@
"""
Mail Aggregator SMTP — email sending via SMTP.
Extracted from aggregator.py for modularity.
"""
import logging
import smtplib
from typing import Optional, List, Dict, Any, Tuple
from email.mime.text import MIMEText
from email.mime.multipart import MIMEMultipart
from .mail_db import get_email_account
from .models import EmailComposeRequest, EmailSendResult
logger = logging.getLogger(__name__)
class SMTPConnectionError(Exception):
"""Raised when SMTP connection fails."""
pass
class SMTPMixin:
"""SMTP-related methods for MailAggregator.
Provides SMTP connection testing and email sending.
Must be mixed into a class that has ``_credentials_service``.
"""
async def test_smtp_connection(
self,
smtp_host: str,
smtp_port: int,
smtp_ssl: bool,
email_address: str,
password: str,
) -> Tuple[bool, Optional[str]]:
"""Test SMTP connection. Returns (success, error)."""
try:
if smtp_ssl:
smtp = smtplib.SMTP_SSL(smtp_host, smtp_port)
else:
smtp = smtplib.SMTP(smtp_host, smtp_port)
smtp.starttls()
smtp.login(email_address, password)
smtp.quit()
return True, None
except Exception as e:
logger.warning(f"SMTP test failed for {email_address}: {e}")
return False, f"SMTP Error: {str(e)}"
async def send_email(
self,
account_id: str,
user_id: str,
request: EmailComposeRequest,
) -> EmailSendResult:
"""
Send an email via SMTP.
Args:
account_id: The account to send from
user_id: The user ID
request: The compose request with recipients and content
Returns:
EmailSendResult with success status
"""
account = await get_email_account(account_id, user_id)
if not account:
return EmailSendResult(success=False, error="Account not found")
# Verify the account_id matches
if request.account_id != account_id:
return EmailSendResult(success=False, error="Account mismatch")
# Get credentials
vault_path = account.get("vault_path", "")
creds = await self._credentials_service.get_credentials(account_id, vault_path)
if not creds:
return EmailSendResult(success=False, error="Credentials not found")
try:
# Create message
if request.is_html:
msg = MIMEMultipart("alternative")
msg.attach(MIMEText(request.body, "html"))
else:
msg = MIMEText(request.body, "plain")
msg["Subject"] = request.subject
msg["From"] = account["email"]
msg["To"] = ", ".join(request.to)
if request.cc:
msg["Cc"] = ", ".join(request.cc)
if request.reply_to_message_id:
msg["In-Reply-To"] = request.reply_to_message_id
msg["References"] = request.reply_to_message_id
# Send via SMTP
if account["smtp_ssl"]:
smtp = smtplib.SMTP_SSL(account["smtp_host"], account["smtp_port"])
else:
smtp = smtplib.SMTP(account["smtp_host"], account["smtp_port"])
smtp.starttls()
smtp.login(creds.email, creds.password)
# All recipients
all_recipients = list(request.to)
if request.cc:
all_recipients.extend(request.cc)
if request.bcc:
all_recipients.extend(request.bcc)
smtp.sendmail(account["email"], all_recipients, msg.as_string())
smtp.quit()
return EmailSendResult(
success=True,
message_id=msg.get("Message-ID"),
)
except Exception as e:
logger.error(f"Failed to send email: {e}")
return EmailSendResult(success=False, error=str(e))
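A hypothetical send; the EmailComposeRequest field names are inferred from the attribute accesses above and defaults may differ:

    req = EmailComposeRequest(  # field names inferred, not authoritative
        account_id="acc-123",
        to=["kollegium@schule.example"],
        cc=[], bcc=[],
        subject="Elternabend",
        body="Hallo zusammen, ...",
        is_html=False,
        reply_to_message_id=None,
    )
    result = await aggregator.send_email("acc-123", "user-1", req)
    # result.success is False with result.error set on any failure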

View File

@@ -10,12 +10,11 @@ Unterstützt:
"""
import os
import re
import zipfile
import hashlib
import json
from pathlib import Path
from typing import List, Dict, Optional, Tuple
from typing import List, Dict, Optional
from dataclasses import dataclass, asdict
from datetime import datetime
import asyncio
@@ -23,6 +22,7 @@ import asyncio
# Local imports
from eh_pipeline import chunk_text, generate_embeddings, extract_text_from_pdf, get_vector_size, EMBEDDING_BACKEND
from qdrant_service import QdrantService
from nibis_parsers import parse_filename_old_format, parse_filename_new_format
# Configuration
DOCS_BASE_PATH = Path("/Users/benjaminadmin/projekte/breakpilot-pwa/docs")
@@ -87,15 +87,6 @@ SUBJECT_MAPPING = {
"gespfl": "Gesundheit-Pflege",
}
# Niveau-Mapping
NIVEAU_MAPPING = {
"ea": "eA", # erhöhtes Anforderungsniveau
"ga": "gA", # grundlegendes Anforderungsniveau
"neuga": "gA (neu einsetzend)",
"neuea": "eA (neu einsetzend)",
}
def compute_file_hash(file_path: Path) -> str:
"""Berechnet SHA-256 Hash einer Datei."""
sha256 = hashlib.sha256()
@@ -135,103 +126,6 @@ def extract_zip_files(base_path: Path) -> List[Path]:
return extracted
def parse_filename_old_format(filename: str, file_path: Path) -> Optional[Dict]:
"""
Parst alte Namenskonvention (2016, 2017):
- {Jahr}{Fach}{Niveau}Lehrer/{Jahr}{Fach}{Niveau}A{Nr}L.pdf
- Beispiel: 2016DeutschEALehrer/2016DeutschEAA1L.pdf
"""
# Pattern für Lehrer-Dateien
pattern = r"(\d{4})([A-Za-zäöüÄÖÜ]+)(EA|GA|NeuGA|NeuEA)(?:Lehrer)?.*?(?:A(\d+)|Aufg(\d+))?L?\.pdf$"
match = re.search(pattern, filename, re.IGNORECASE)
if not match:
return None
year = int(match.group(1))
subject_raw = match.group(2).lower()
niveau = match.group(3).upper()
task_num = match.group(4) or match.group(5)
# Prüfe ob es ein Lehrer-Dokument ist (EWH)
is_ewh = "lehrer" in str(file_path).lower() or filename.endswith("L.pdf")
# Extrahiere Variante (Tech, Wirt, CAS, GTR, etc.)
variant = None
variant_patterns = ["Tech", "Wirt", "CAS", "GTR", "Pflicht", "BG", "mitExp", "ohneExp"]
for v in variant_patterns:
if v.lower() in str(file_path).lower():
variant = v
break
return {
"year": year,
"subject": subject_raw,
"niveau": NIVEAU_MAPPING.get(niveau.lower(), niveau),
"task_number": int(task_num) if task_num else None,
"doc_type": "EWH" if is_ewh else "Aufgabe",
"variant": variant,
}
def parse_filename_new_format(filename: str, file_path: Path) -> Optional[Dict]:
"""
Parst neue Namenskonvention (2024, 2025):
- {Jahr}_{Fach}_{niveau}_{Nr}_EWH.pdf
- Beispiel: 2025_Deutsch_eA_I_EWH.pdf
"""
# Pattern für neue Dateien
pattern = r"(\d{4})_([A-Za-zäöüÄÖÜ]+)(?:BG)?_(eA|gA)(?:_([IVX\d]+))?(?:_(.+))?\.pdf$"
match = re.search(pattern, filename, re.IGNORECASE)
if not match:
return None
year = int(match.group(1))
subject_raw = match.group(2).lower()
niveau = match.group(3)
task_id = match.group(4)
suffix = match.group(5) or ""
# Task-Nummer aus römischen Zahlen
task_num = None
if task_id:
roman_map = {"I": 1, "II": 2, "III": 3, "IV": 4, "V": 5}
task_num = roman_map.get(task_id) or (int(task_id) if task_id.isdigit() else None)
# Dokumenttyp
is_ewh = "EWH" in filename or "ewh" in filename.lower()
# Spezielle Dokumenttypen
doc_type = "EWH" if is_ewh else "Aufgabe"
if "Material" in suffix:
doc_type = "Material"
elif "GBU" in suffix:
doc_type = "GBU"
elif "Ergebnis" in suffix:
doc_type = "Ergebnis"
elif "Bewertungsbogen" in suffix:
doc_type = "Bewertungsbogen"
elif "HV" in suffix:
doc_type = "Hörverstehen"
elif "ME" in suffix:
doc_type = "Mediation"
# BG Variante
variant = "BG" if "BG" in filename else None
if "mitExp" in str(file_path):
variant = "mitExp"
return {
"year": year,
"subject": subject_raw,
"niveau": NIVEAU_MAPPING.get(niveau.lower(), niveau),
"task_number": task_num,
"doc_type": doc_type,
"variant": variant,
}
def discover_documents(base_path: Path, ewh_only: bool = True) -> List[NiBiSDocument]:
"""
Findet alle relevanten Dokumente in den za-download Verzeichnissen.

View File

@@ -0,0 +1,113 @@
"""
NiBiS Filename Parsers
Parses old and new naming conventions for NiBiS Abitur documents.
"""
import re
from typing import Dict, Optional
# Niveau-Mapping
NIVEAU_MAPPING = {
"ea": "eA", # erhoehtes Anforderungsniveau
"ga": "gA", # grundlegendes Anforderungsniveau
"neuga": "gA (neu einsetzend)",
"neuea": "eA (neu einsetzend)",
}
def parse_filename_old_format(filename: str, file_path) -> Optional[Dict]:
"""
Parst alte Namenskonvention (2016, 2017):
- {Jahr}{Fach}{Niveau}Lehrer/{Jahr}{Fach}{Niveau}A{Nr}L.pdf
- Beispiel: 2016DeutschEALehrer/2016DeutschEAA1L.pdf
"""
# Pattern fuer Lehrer-Dateien
pattern = r"(\d{4})([A-Za-z\u00e4\u00f6\u00fc\u00c4\u00d6\u00dc]+)(EA|GA|NeuGA|NeuEA)(?:Lehrer)?.*?(?:A(\d+)|Aufg(\d+))?L?\.pdf$"
match = re.search(pattern, filename, re.IGNORECASE)
if not match:
return None
year = int(match.group(1))
subject_raw = match.group(2).lower()
niveau = match.group(3).upper()
task_num = match.group(4) or match.group(5)
# Pruefe ob es ein Lehrer-Dokument ist (EWH)
is_ewh = "lehrer" in str(file_path).lower() or filename.endswith("L.pdf")
# Extrahiere Variante (Tech, Wirt, CAS, GTR, etc.)
variant = None
variant_patterns = ["Tech", "Wirt", "CAS", "GTR", "Pflicht", "BG", "mitExp", "ohneExp"]
for v in variant_patterns:
if v.lower() in str(file_path).lower():
variant = v
break
return {
"year": year,
"subject": subject_raw,
"niveau": NIVEAU_MAPPING.get(niveau.lower(), niveau),
"task_number": int(task_num) if task_num else None,
"doc_type": "EWH" if is_ewh else "Aufgabe",
"variant": variant,
}
def parse_filename_new_format(filename: str, file_path) -> Optional[Dict]:
"""
Parst neue Namenskonvention (2024, 2025):
- {Jahr}_{Fach}_{niveau}_{Nr}_EWH.pdf
- Beispiel: 2025_Deutsch_eA_I_EWH.pdf
"""
# Pattern fuer neue Dateien
pattern = r"(\d{4})_([A-Za-z\u00e4\u00f6\u00fc\u00c4\u00d6\u00dc]+)(?:BG)?_(eA|gA)(?:_([IVX\d]+))?(?:_(.+))?\.pdf$"
match = re.search(pattern, filename, re.IGNORECASE)
if not match:
return None
year = int(match.group(1))
subject_raw = match.group(2).lower()
niveau = match.group(3)
task_id = match.group(4)
suffix = match.group(5) or ""
# Task-Nummer aus roemischen Zahlen
task_num = None
if task_id:
roman_map = {"I": 1, "II": 2, "III": 3, "IV": 4, "V": 5}
task_num = roman_map.get(task_id) or (int(task_id) if task_id.isdigit() else None)
# Dokumenttyp
is_ewh = "EWH" in filename or "ewh" in filename.lower()
# Spezielle Dokumenttypen
doc_type = "EWH" if is_ewh else "Aufgabe"
if "Material" in suffix:
doc_type = "Material"
elif "GBU" in suffix:
doc_type = "GBU"
elif "Ergebnis" in suffix:
doc_type = "Ergebnis"
elif "Bewertungsbogen" in suffix:
doc_type = "Bewertungsbogen"
elif "HV" in suffix:
doc_type = "Hoerverstehen"
elif "ME" in suffix:
doc_type = "Mediation"
# BG Variante
variant = "BG" if "BG" in filename else None
if "mitExp" in str(file_path):
variant = "mitExp"
return {
"year": year,
"subject": subject_raw,
"niveau": NIVEAU_MAPPING.get(niveau.lower(), niveau),
"task_number": task_num,
"doc_type": doc_type,
"variant": variant,
}
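A sketch of both parsers on the documented example filenames (paths hypothetical):

    from pathlib import Path

    parse_filename_old_format(
        "2016DeutschEAA1L.pdf", Path("2016DeutschEALehrer/2016DeutschEAA1L.pdf")
    )
    # -> {"year": 2016, "subject": "deutsch", "niveau": "eA",
    #     "task_number": 1, "doc_type": "EWH", "variant": None}

    parse_filename_new_format(
        "2025_Deutsch_eA_I_EWH.pdf", Path("za-download/2025_Deutsch_eA_I_EWH.pdf")
    )
    # -> {"year": 2025, "subject": "deutsch", "niveau": "eA",
    #     "task_number": 1, "doc_type": "EWH", "variant": None}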

View File

@@ -1,557 +1,26 @@
"""
NRU Worksheet Generator - Generate vocabulary worksheets in NRU format.
NRU Worksheet Generator — barrel re-export.
Format:
- Page 1 (Vokabeln): 3-column table
- Column 1: English vocabulary
- Column 2: Empty (child writes German translation)
- Column 3: Empty (child writes corrected English after parent review)
- Page 2 (Lernsätze): Full-width table
- Row 1: German sentence (pre-filled)
- Row 2-3: Empty lines (child writes English translation)
All implementation split into:
nru_worksheet_models — data classes, entry separation
nru_worksheet_html — HTML generation
nru_worksheet_pdf — PDF generation
Per scanned page, we generate 2 worksheet pages.
"""
import io
import logging
from typing import List, Dict, Tuple
from dataclasses import dataclass
# Models
from nru_worksheet_models import ( # noqa: F401
VocabEntry,
SentenceEntry,
separate_vocab_and_sentences,
)
logger = logging.getLogger(__name__)
# HTML generation
from nru_worksheet_html import ( # noqa: F401
generate_nru_html,
generate_nru_worksheet_html,
)
@dataclass
class VocabEntry:
english: str
german: str
source_page: int = 1
@dataclass
class SentenceEntry:
german: str
english: str # For solution sheet
source_page: int = 1
def separate_vocab_and_sentences(entries: List[Dict]) -> Tuple[List[VocabEntry], List[SentenceEntry]]:
"""
Separate vocabulary entries into single words/phrases and full sentences.
Sentences are identified by:
- Ending with punctuation (. ! ?)
- Being longer than 50 characters
- Having more than 5 words with a capitalized word after the first
"""
vocab_list = []
sentence_list = []
for entry in entries:
english = entry.get("english", "").strip()
german = entry.get("german", "").strip()
source_page = entry.get("source_page", 1)
if not english or not german:
continue
# Detect if this is a sentence
is_sentence = (
english.endswith('.') or
english.endswith('!') or
english.endswith('?') or
len(english) > 50 or
(len(english.split()) > 5 and any(w[0].isupper() for w in english.split()[1:] if w))
)
if is_sentence:
sentence_list.append(SentenceEntry(
german=german,
english=english,
source_page=source_page
))
else:
vocab_list.append(VocabEntry(
english=english,
german=german,
source_page=source_page
))
return vocab_list, sentence_list
def generate_nru_html(
vocab_list: List[VocabEntry],
sentence_list: List[SentenceEntry],
page_number: int,
title: str = "Vokabeltest",
show_solutions: bool = False,
line_height_px: int = 28
) -> str:
"""
Generate HTML for NRU-format worksheet.
Returns HTML for 2 pages:
- Page 1: Vocabulary table (3 columns)
- Page 2: Sentence practice (full width)
"""
# Filter by page
page_vocab = [v for v in vocab_list if v.source_page == page_number]
page_sentences = [s for s in sentence_list if s.source_page == page_number]
html = f"""<!DOCTYPE html>
<html>
<head>
<meta charset="UTF-8">
<style>
@page {{
size: A4;
margin: 1.5cm 2cm;
}}
* {{
box-sizing: border-box;
}}
body {{
font-family: Arial, Helvetica, sans-serif;
font-size: 12pt;
line-height: 1.4;
margin: 0;
padding: 0;
}}
.page {{
page-break-after: always;
min-height: 100%;
}}
.page:last-child {{
page-break-after: avoid;
}}
h1 {{
font-size: 16pt;
margin: 0 0 8px 0;
text-align: center;
}}
.header {{
margin-bottom: 15px;
}}
.name-line {{
font-size: 11pt;
margin-bottom: 10px;
}}
/* Vocabulary Table - 3 columns */
.vocab-table {{
width: 100%;
border-collapse: collapse;
table-layout: fixed;
}}
.vocab-table th {{
background: #f0f0f0;
border: 1px solid #333;
padding: 6px 8px;
font-weight: bold;
font-size: 11pt;
text-align: left;
}}
.vocab-table td {{
border: 1px solid #333;
padding: 4px 8px;
height: {line_height_px}px;
vertical-align: middle;
}}
.vocab-table .col-english {{ width: 35%; }}
.vocab-table .col-german {{ width: 35%; }}
.vocab-table .col-correction {{ width: 30%; }}
.vocab-answer {{
color: #0066cc;
font-style: italic;
}}
/* Sentence Table - full width */
.sentence-table {{
width: 100%;
border-collapse: collapse;
margin-bottom: 15px;
}}
.sentence-table td {{
border: 1px solid #333;
padding: 6px 10px;
}}
.sentence-header {{
background: #f5f5f5;
font-weight: normal;
min-height: 30px;
}}
.sentence-line {{
height: {line_height_px + 4}px;
}}
.sentence-answer {{
color: #0066cc;
font-style: italic;
font-size: 11pt;
}}
.page-info {{
font-size: 9pt;
color: #666;
text-align: right;
margin-top: 10px;
}}
</style>
</head>
<body>
"""
# ========== PAGE 1: VOCABULARY TABLE ==========
if page_vocab:
html += f"""
<div class="page">
<div class="header">
<h1>{title} - Vokabeln (Seite {page_number})</h1>
<div class="name-line">Name: _________________________ Datum: _____________</div>
</div>
<table class="vocab-table">
<thead>
<tr>
<th class="col-english">Englisch</th>
<th class="col-german">Deutsch</th>
<th class="col-correction">Korrektur</th>
</tr>
</thead>
<tbody>
"""
for v in page_vocab:
if show_solutions:
html += f"""
<tr>
<td>{v.english}</td>
<td class="vocab-answer">{v.german}</td>
<td></td>
</tr>
"""
else:
html += f"""
<tr>
<td>{v.english}</td>
<td></td>
<td></td>
</tr>
"""
html += """
</tbody>
</table>
<div class="page-info">Vokabeln aus Unit</div>
</div>
"""
# ========== PAGE 2: SENTENCE PRACTICE ==========
if page_sentences:
html += f"""
<div class="page">
<div class="header">
<h1>{title} - Lernsaetze (Seite {page_number})</h1>
<div class="name-line">Name: _________________________ Datum: _____________</div>
</div>
"""
for s in page_sentences:
html += f"""
<table class="sentence-table">
<tr>
<td class="sentence-header">{s.german}</td>
</tr>
"""
if show_solutions:
html += f"""
<tr>
<td class="sentence-line sentence-answer">{s.english}</td>
</tr>
<tr>
<td class="sentence-line"></td>
</tr>
"""
else:
html += """
<tr>
<td class="sentence-line"></td>
</tr>
<tr>
<td class="sentence-line"></td>
</tr>
"""
html += """
</table>
"""
html += """
<div class="page-info">Lernsaetze aus Unit</div>
</div>
"""
html += """
</body>
</html>
"""
return html
def generate_nru_worksheet_html(
entries: List[Dict],
title: str = "Vokabeltest",
show_solutions: bool = False,
specific_pages: List[int] = None
) -> str:
"""
Generate complete NRU worksheet HTML for all pages.
Args:
entries: List of vocabulary entries with source_page
title: Worksheet title
show_solutions: Whether to show answers
specific_pages: List of specific page numbers to include (1-indexed)
Returns:
Complete HTML document
"""
# Separate into vocab and sentences
vocab_list, sentence_list = separate_vocab_and_sentences(entries)
# Get unique page numbers
all_pages = set()
for v in vocab_list:
all_pages.add(v.source_page)
for s in sentence_list:
all_pages.add(s.source_page)
# Filter to specific pages if requested
if specific_pages:
all_pages = all_pages.intersection(set(specific_pages))
pages_sorted = sorted(all_pages)
logger.info(f"Generating NRU worksheet for pages {pages_sorted}")
logger.info(f"Total vocab: {len(vocab_list)}, Total sentences: {len(sentence_list)}")
# Generate HTML for each page
combined_html = """<!DOCTYPE html>
<html>
<head>
<meta charset="UTF-8">
<style>
@page {
size: A4;
margin: 1.5cm 2cm;
}
* {
box-sizing: border-box;
}
body {
font-family: Arial, Helvetica, sans-serif;
font-size: 12pt;
line-height: 1.4;
margin: 0;
padding: 0;
}
.page {
page-break-after: always;
min-height: 100%;
}
.page:last-child {
page-break-after: avoid;
}
h1 {
font-size: 16pt;
margin: 0 0 8px 0;
text-align: center;
}
.header {
margin-bottom: 15px;
}
.name-line {
font-size: 11pt;
margin-bottom: 10px;
}
/* Vocabulary Table - 3 columns */
.vocab-table {
width: 100%;
border-collapse: collapse;
table-layout: fixed;
}
.vocab-table th {
background: #f0f0f0;
border: 1px solid #333;
padding: 6px 8px;
font-weight: bold;
font-size: 11pt;
text-align: left;
}
.vocab-table td {
border: 1px solid #333;
padding: 4px 8px;
height: 28px;
vertical-align: middle;
}
.vocab-table .col-english { width: 35%; }
.vocab-table .col-german { width: 35%; }
.vocab-table .col-correction { width: 30%; }
.vocab-answer {
color: #0066cc;
font-style: italic;
}
/* Sentence Table - full width */
.sentence-table {
width: 100%;
border-collapse: collapse;
margin-bottom: 15px;
}
.sentence-table td {
border: 1px solid #333;
padding: 6px 10px;
}
.sentence-header {
background: #f5f5f5;
font-weight: normal;
min-height: 30px;
}
.sentence-line {
height: 32px;
}
.sentence-answer {
color: #0066cc;
font-style: italic;
font-size: 11pt;
}
.page-info {
font-size: 9pt;
color: #666;
text-align: right;
margin-top: 10px;
}
</style>
</head>
<body>
"""
for page_num in pages_sorted:
page_vocab = [v for v in vocab_list if v.source_page == page_num]
page_sentences = [s for s in sentence_list if s.source_page == page_num]
# PAGE 1: VOCABULARY TABLE
if page_vocab:
combined_html += f"""
<div class="page">
<div class="header">
<h1>{title} - Vokabeln (Seite {page_num})</h1>
<div class="name-line">Name: _________________________ Datum: _____________</div>
</div>
<table class="vocab-table">
<thead>
<tr>
<th class="col-english">Englisch</th>
<th class="col-german">Deutsch</th>
<th class="col-correction">Korrektur</th>
</tr>
</thead>
<tbody>
"""
for v in page_vocab:
if show_solutions:
combined_html += f"""
<tr>
<td>{v.english}</td>
<td class="vocab-answer">{v.german}</td>
<td></td>
</tr>
"""
else:
combined_html += f"""
<tr>
<td>{v.english}</td>
<td></td>
<td></td>
</tr>
"""
combined_html += f"""
</tbody>
</table>
<div class="page-info">{title} - Seite {page_num}</div>
</div>
"""
# PAGE 2: SENTENCE PRACTICE
if page_sentences:
combined_html += f"""
<div class="page">
<div class="header">
<h1>{title} - Lernsaetze (Seite {page_num})</h1>
<div class="name-line">Name: _________________________ Datum: _____________</div>
</div>
"""
for s in page_sentences:
combined_html += f"""
<table class="sentence-table">
<tr>
<td class="sentence-header">{s.german}</td>
</tr>
"""
if show_solutions:
combined_html += f"""
<tr>
<td class="sentence-line sentence-answer">{s.english}</td>
</tr>
<tr>
<td class="sentence-line"></td>
</tr>
"""
else:
combined_html += """
<tr>
<td class="sentence-line"></td>
</tr>
<tr>
<td class="sentence-line"></td>
</tr>
"""
combined_html += """
</table>
"""
combined_html += f"""
<div class="page-info">{title} - Seite {page_num}</div>
</div>
"""
combined_html += """
</body>
</html>
"""
return combined_html
async def generate_nru_pdf(entries: List[Dict], title: str = "Vokabeltest", include_solutions: bool = True) -> Tuple[bytes, bytes]:
"""
Generate NRU worksheet PDFs.
Returns:
Tuple of (worksheet_pdf_bytes, solution_pdf_bytes); the solution is None when include_solutions=False
"""
from weasyprint import HTML
# Generate worksheet HTML
worksheet_html = generate_nru_worksheet_html(entries, title, show_solutions=False)
worksheet_pdf = HTML(string=worksheet_html).write_pdf()
# Generate solution HTML
solution_pdf = None
if include_solutions:
solution_html = generate_nru_worksheet_html(entries, title, show_solutions=True)
solution_pdf = HTML(string=solution_html).write_pdf()
return worksheet_pdf, solution_pdf
# PDF generation
from nru_worksheet_pdf import generate_nru_pdf # noqa: F401

View File

@@ -0,0 +1,466 @@
"""
NRU Worksheet HTML — HTML generation for vocabulary worksheets.
Extracted from nru_worksheet_generator.py for modularity.
"""
import logging
from typing import List, Dict
from nru_worksheet_models import VocabEntry, SentenceEntry, separate_vocab_and_sentences
logger = logging.getLogger(__name__)
def generate_nru_html(
vocab_list: List[VocabEntry],
sentence_list: List[SentenceEntry],
page_number: int,
title: str = "Vokabeltest",
show_solutions: bool = False,
line_height_px: int = 28
) -> str:
"""
Generate HTML for NRU-format worksheet.
Returns HTML for 2 pages:
- Page 1: Vocabulary table (3 columns)
- Page 2: Sentence practice (full width)
"""
# Filter by page
page_vocab = [v for v in vocab_list if v.source_page == page_number]
page_sentences = [s for s in sentence_list if s.source_page == page_number]
html = f"""<!DOCTYPE html>
<html>
<head>
<meta charset="UTF-8">
<style>
@page {{
size: A4;
margin: 1.5cm 2cm;
}}
* {{
box-sizing: border-box;
}}
body {{
font-family: Arial, Helvetica, sans-serif;
font-size: 12pt;
line-height: 1.4;
margin: 0;
padding: 0;
}}
.page {{
page-break-after: always;
min-height: 100%;
}}
.page:last-child {{
page-break-after: avoid;
}}
h1 {{
font-size: 16pt;
margin: 0 0 8px 0;
text-align: center;
}}
.header {{
margin-bottom: 15px;
}}
.name-line {{
font-size: 11pt;
margin-bottom: 10px;
}}
/* Vocabulary Table - 3 columns */
.vocab-table {{
width: 100%;
border-collapse: collapse;
table-layout: fixed;
}}
.vocab-table th {{
background: #f0f0f0;
border: 1px solid #333;
padding: 6px 8px;
font-weight: bold;
font-size: 11pt;
text-align: left;
}}
.vocab-table td {{
border: 1px solid #333;
padding: 4px 8px;
height: {line_height_px}px;
vertical-align: middle;
}}
.vocab-table .col-english {{ width: 35%; }}
.vocab-table .col-german {{ width: 35%; }}
.vocab-table .col-correction {{ width: 30%; }}
.vocab-answer {{
color: #0066cc;
font-style: italic;
}}
/* Sentence Table - full width */
.sentence-table {{
width: 100%;
border-collapse: collapse;
margin-bottom: 15px;
}}
.sentence-table td {{
border: 1px solid #333;
padding: 6px 10px;
}}
.sentence-header {{
background: #f5f5f5;
font-weight: normal;
min-height: 30px;
}}
.sentence-line {{
height: {line_height_px + 4}px;
}}
.sentence-answer {{
color: #0066cc;
font-style: italic;
font-size: 11pt;
}}
.page-info {{
font-size: 9pt;
color: #666;
text-align: right;
margin-top: 10px;
}}
</style>
</head>
<body>
"""
# ========== PAGE 1: VOCABULARY TABLE ==========
if page_vocab:
html += f"""
<div class="page">
<div class="header">
<h1>{title} - Vokabeln (Seite {page_number})</h1>
<div class="name-line">Name: _________________________ Datum: _____________</div>
</div>
<table class="vocab-table">
<thead>
<tr>
<th class="col-english">Englisch</th>
<th class="col-german">Deutsch</th>
<th class="col-correction">Korrektur</th>
</tr>
</thead>
<tbody>
"""
for v in page_vocab:
if show_solutions:
html += f"""
<tr>
<td>{v.english}</td>
<td class="vocab-answer">{v.german}</td>
<td></td>
</tr>
"""
else:
html += f"""
<tr>
<td>{v.english}</td>
<td></td>
<td></td>
</tr>
"""
html += """
</tbody>
</table>
<div class="page-info">Vokabeln aus Unit</div>
</div>
"""
# ========== PAGE 2: SENTENCE PRACTICE ==========
if page_sentences:
html += f"""
<div class="page">
<div class="header">
<h1>{title} - Lernsaetze (Seite {page_number})</h1>
<div class="name-line">Name: _________________________ Datum: _____________</div>
</div>
"""
for s in page_sentences:
html += f"""
<table class="sentence-table">
<tr>
<td class="sentence-header">{s.german}</td>
</tr>
"""
if show_solutions:
html += f"""
<tr>
<td class="sentence-line sentence-answer">{s.english}</td>
</tr>
<tr>
<td class="sentence-line"></td>
</tr>
"""
else:
html += """
<tr>
<td class="sentence-line"></td>
</tr>
<tr>
<td class="sentence-line"></td>
</tr>
"""
html += """
</table>
"""
html += """
<div class="page-info">Lernsaetze aus Unit</div>
</div>
"""
html += """
</body>
</html>
"""
return html
def generate_nru_worksheet_html(
entries: List[Dict],
title: str = "Vokabeltest",
show_solutions: bool = False,
specific_pages: List[int] = None
) -> str:
"""
Generate complete NRU worksheet HTML for all pages.
Args:
entries: List of vocabulary entries with source_page
title: Worksheet title
show_solutions: Whether to show answers
specific_pages: List of specific page numbers to include (1-indexed)
Returns:
Complete HTML document
"""
# Separate into vocab and sentences
vocab_list, sentence_list = separate_vocab_and_sentences(entries)
# Get unique page numbers
all_pages = set()
for v in vocab_list:
all_pages.add(v.source_page)
for s in sentence_list:
all_pages.add(s.source_page)
# Filter to specific pages if requested
if specific_pages:
all_pages = all_pages.intersection(set(specific_pages))
pages_sorted = sorted(all_pages)
logger.info(f"Generating NRU worksheet for pages {pages_sorted}")
logger.info(f"Total vocab: {len(vocab_list)}, Total sentences: {len(sentence_list)}")
# Generate HTML for each page
combined_html = """<!DOCTYPE html>
<html>
<head>
<meta charset="UTF-8">
<style>
@page {
size: A4;
margin: 1.5cm 2cm;
}
* {
box-sizing: border-box;
}
body {
font-family: Arial, Helvetica, sans-serif;
font-size: 12pt;
line-height: 1.4;
margin: 0;
padding: 0;
}
.page {
page-break-after: always;
min-height: 100%;
}
.page:last-child {
page-break-after: avoid;
}
h1 {
font-size: 16pt;
margin: 0 0 8px 0;
text-align: center;
}
.header {
margin-bottom: 15px;
}
.name-line {
font-size: 11pt;
margin-bottom: 10px;
}
/* Vocabulary Table - 3 columns */
.vocab-table {
width: 100%;
border-collapse: collapse;
table-layout: fixed;
}
.vocab-table th {
background: #f0f0f0;
border: 1px solid #333;
padding: 6px 8px;
font-weight: bold;
font-size: 11pt;
text-align: left;
}
.vocab-table td {
border: 1px solid #333;
padding: 4px 8px;
height: 28px;
vertical-align: middle;
}
.vocab-table .col-english { width: 35%; }
.vocab-table .col-german { width: 35%; }
.vocab-table .col-correction { width: 30%; }
.vocab-answer {
color: #0066cc;
font-style: italic;
}
/* Sentence Table - full width */
.sentence-table {
width: 100%;
border-collapse: collapse;
margin-bottom: 15px;
}
.sentence-table td {
border: 1px solid #333;
padding: 6px 10px;
}
.sentence-header {
background: #f5f5f5;
font-weight: normal;
min-height: 30px;
}
.sentence-line {
height: 32px;
}
.sentence-answer {
color: #0066cc;
font-style: italic;
font-size: 11pt;
}
.page-info {
font-size: 9pt;
color: #666;
text-align: right;
margin-top: 10px;
}
</style>
</head>
<body>
"""
for page_num in pages_sorted:
page_vocab = [v for v in vocab_list if v.source_page == page_num]
page_sentences = [s for s in sentence_list if s.source_page == page_num]
# PAGE 1: VOCABULARY TABLE
if page_vocab:
combined_html += f"""
<div class="page">
<div class="header">
<h1>{title} - Vokabeln (Seite {page_num})</h1>
<div class="name-line">Name: _________________________ Datum: _____________</div>
</div>
<table class="vocab-table">
<thead>
<tr>
<th class="col-english">Englisch</th>
<th class="col-german">Deutsch</th>
<th class="col-correction">Korrektur</th>
</tr>
</thead>
<tbody>
"""
for v in page_vocab:
if show_solutions:
combined_html += f"""
<tr>
<td>{v.english}</td>
<td class="vocab-answer">{v.german}</td>
<td></td>
</tr>
"""
else:
combined_html += f"""
<tr>
<td>{v.english}</td>
<td></td>
<td></td>
</tr>
"""
combined_html += f"""
</tbody>
</table>
<div class="page-info">{title} - Seite {page_num}</div>
</div>
"""
# PAGE 2: SENTENCE PRACTICE
if page_sentences:
combined_html += f"""
<div class="page">
<div class="header">
<h1>{title} - Lernsaetze (Seite {page_num})</h1>
<div class="name-line">Name: _________________________ Datum: _____________</div>
</div>
"""
for s in page_sentences:
combined_html += f"""
<table class="sentence-table">
<tr>
<td class="sentence-header">{s.german}</td>
</tr>
"""
if show_solutions:
combined_html += f"""
<tr>
<td class="sentence-line sentence-answer">{s.english}</td>
</tr>
<tr>
<td class="sentence-line"></td>
</tr>
"""
else:
combined_html += """
<tr>
<td class="sentence-line"></td>
</tr>
<tr>
<td class="sentence-line"></td>
</tr>
"""
combined_html += """
</table>
"""
combined_html += f"""
<div class="page-info">{title} - Seite {page_num}</div>
</div>
"""
combined_html += """
</body>
</html>
"""
return combined_html
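# Usage sketch (illustrative, not part of the module): entry dicts carry the
# keys consumed by separate_vocab_and_sentences(); the result is a complete
# HTML document ready for printing or PDF conversion.
if __name__ == "__main__":  # pragma: no cover
    _demo = [
        {"english": "the bridge", "german": "die Bruecke", "source_page": 12},
        {"english": "We crossed the bridge.",
         "german": "Wir ueberquerten die Bruecke.", "source_page": 12},
    ]
    _html = generate_nru_worksheet_html(_demo, title="Unit 3", show_solutions=True)
    print(f"{len(_html)} chars of worksheet HTML")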

View File

@@ -0,0 +1,70 @@
"""
NRU Worksheet Models — data classes and entry separation logic.
Extracted from nru_worksheet_generator.py for modularity.
"""
import logging
from typing import List, Dict, Tuple
from dataclasses import dataclass
logger = logging.getLogger(__name__)
@dataclass
class VocabEntry:
english: str
german: str
source_page: int = 1
@dataclass
class SentenceEntry:
german: str
english: str # For solution sheet
source_page: int = 1
def separate_vocab_and_sentences(entries: List[Dict]) -> Tuple[List[VocabEntry], List[SentenceEntry]]:
"""
Separate vocabulary entries into single words/phrases and full sentences.
Sentences are identified by:
- Ending with punctuation (. ! ?)
- Being longer than 40 characters
- Containing multiple words with capital letters mid-sentence
"""
vocab_list = []
sentence_list = []
for entry in entries:
english = entry.get("english", "").strip()
german = entry.get("german", "").strip()
source_page = entry.get("source_page", 1)
if not english or not german:
continue
# Detect if this is a sentence
is_sentence = (
english.endswith('.') or
english.endswith('!') or
english.endswith('?') or
len(english) > 50 or
(len(english.split()) > 5 and any(w[0].isupper() for w in english.split()[1:] if w))
)
if is_sentence:
sentence_list.append(SentenceEntry(
german=german,
english=english,
source_page=source_page
))
else:
vocab_list.append(VocabEntry(
english=english,
german=german,
source_page=source_page
))
return vocab_list, sentence_list
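# Classification sketch (illustrative, not part of the module): short phrases
# stay vocabulary; anything sentence-like lands in the sentence list.
if __name__ == "__main__":  # pragma: no cover
    _vocab, _sentences = separate_vocab_and_sentences([
        {"english": "the bridge", "german": "die Bruecke", "source_page": 3},
        {"english": "We crossed the bridge yesterday.",
         "german": "Wir ueberquerten gestern die Bruecke.", "source_page": 3},
    ])
    assert len(_vocab) == 1 and len(_sentences) == 1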

View File

@@ -0,0 +1,31 @@
"""
NRU Worksheet PDF — PDF generation using weasyprint.
Extracted from nru_worksheet_generator.py for modularity.
"""
from typing import Dict, List, Optional, Tuple
from nru_worksheet_html import generate_nru_worksheet_html
async def generate_nru_pdf(entries: List[Dict], title: str = "Vokabeltest", include_solutions: bool = True) -> Tuple[bytes, Optional[bytes]]:
"""
Generate NRU worksheet PDFs.
Returns:
Tuple of (worksheet_pdf_bytes, solution_pdf_bytes);
solution_pdf_bytes is None when include_solutions is False.
"""
from weasyprint import HTML
# Generate worksheet HTML
worksheet_html = generate_nru_worksheet_html(entries, title, show_solutions=False)
worksheet_pdf = HTML(string=worksheet_html).write_pdf()
# Generate solution HTML
solution_pdf = None
if include_solutions:
solution_html = generate_nru_worksheet_html(entries, title, show_solutions=True)
solution_pdf = HTML(string=solution_html).write_pdf()
return worksheet_pdf, solution_pdf
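# Usage sketch (illustrative, not part of the module): both PDFs in one call;
# the solution PDF is None when include_solutions=False. Requires weasyprint
# at runtime.
if __name__ == "__main__":  # pragma: no cover
    import asyncio
    _ws, _sol = asyncio.run(generate_nru_pdf(
        [{"english": "the bridge", "german": "die Bruecke", "source_page": 1}],
        title="Vokabeltest",
    ))
    print(len(_ws), len(_sol) if _sol else 0)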

View File

@@ -0,0 +1,333 @@
"""
Overlay rendering for columns, rows, and words (grid-based overlays).
Extracted from ocr_pipeline_overlays.py for modularity.
Lizenz: Apache 2.0
DATENSCHUTZ: Alle Verarbeitung erfolgt lokal.
"""
import logging
from typing import Any, Dict, List
import cv2
import numpy as np
from fastapi import HTTPException
from fastapi.responses import Response
from ocr_pipeline_common import _get_base_image_png
from ocr_pipeline_session_store import get_session_db
from ocr_pipeline_rows import _draw_box_exclusion_overlay
logger = logging.getLogger(__name__)
async def _get_columns_overlay(session_id: str) -> Response:
"""Generate cropped (or dewarped) image with column borders drawn on it."""
session = await get_session_db(session_id)
if not session:
raise HTTPException(status_code=404, detail=f"Session {session_id} not found")
column_result = session.get("column_result")
if not column_result or not column_result.get("columns"):
raise HTTPException(status_code=404, detail="No column data available")
# Load best available base image (cropped > dewarped > original)
base_png = await _get_base_image_png(session_id)
if not base_png:
raise HTTPException(status_code=404, detail="No base image available")
arr = np.frombuffer(base_png, dtype=np.uint8)
img = cv2.imdecode(arr, cv2.IMREAD_COLOR)
if img is None:
raise HTTPException(status_code=500, detail="Failed to decode image")
# Color map for region types (BGR)
colors = {
"column_en": (255, 180, 0), # Blue
"column_de": (0, 200, 0), # Green
"column_example": (0, 140, 255), # Orange
"column_text": (200, 200, 0), # Cyan/Turquoise
"page_ref": (200, 0, 200), # Purple
"column_marker": (0, 0, 220), # Red
"column_ignore": (180, 180, 180), # Light Gray
"header": (128, 128, 128), # Gray
"footer": (128, 128, 128), # Gray
"margin_top": (100, 100, 100), # Dark Gray
"margin_bottom": (100, 100, 100), # Dark Gray
}
overlay = img.copy()
for col in column_result["columns"]:
x, y = col["x"], col["y"]
w, h = col["width"], col["height"]
color = colors.get(col.get("type", ""), (200, 200, 200))
# Semi-transparent fill
cv2.rectangle(overlay, (x, y), (x + w, y + h), color, -1)
# Solid border
cv2.rectangle(img, (x, y), (x + w, y + h), color, 3)
# Label with confidence
label = col.get("type", "unknown").replace("column_", "").upper()
conf = col.get("classification_confidence")
if conf is not None and conf < 1.0:
label = f"{label} {int(conf * 100)}%"
cv2.putText(img, label, (x + 10, y + 30),
cv2.FONT_HERSHEY_SIMPLEX, 0.8, color, 2)
# Blend overlay at 20% opacity
cv2.addWeighted(overlay, 0.2, img, 0.8, 0, img)
# Draw detected box boundaries as dashed rectangles
zones = column_result.get("zones") or []
for zone in zones:
if zone.get("zone_type") == "box" and zone.get("box"):
box = zone["box"]
bx, by = box["x"], box["y"]
bw, bh = box["width"], box["height"]
box_color = (0, 200, 255) # Yellow (BGR)
# Draw dashed rectangle by drawing short line segments
dash_len = 15
for edge_x in range(bx, bx + bw, dash_len * 2):
end_x = min(edge_x + dash_len, bx + bw)
cv2.line(img, (edge_x, by), (end_x, by), box_color, 2)
cv2.line(img, (edge_x, by + bh), (end_x, by + bh), box_color, 2)
for edge_y in range(by, by + bh, dash_len * 2):
end_y = min(edge_y + dash_len, by + bh)
cv2.line(img, (bx, edge_y), (bx, end_y), box_color, 2)
cv2.line(img, (bx + bw, edge_y), (bx + bw, end_y), box_color, 2)
cv2.putText(img, "BOX", (bx + 10, by + bh - 10),
cv2.FONT_HERSHEY_SIMPLEX, 0.7, box_color, 2)
# Red semi-transparent overlay for box zones
_draw_box_exclusion_overlay(img, zones)
success, result_png = cv2.imencode(".png", img)
if not success:
raise HTTPException(status_code=500, detail="Failed to encode overlay image")
return Response(content=result_png.tobytes(), media_type="image/png")
async def _get_rows_overlay(session_id: str) -> Response:
"""Generate cropped (or dewarped) image with row bands drawn on it."""
session = await get_session_db(session_id)
if not session:
raise HTTPException(status_code=404, detail=f"Session {session_id} not found")
row_result = session.get("row_result")
if not row_result or not row_result.get("rows"):
raise HTTPException(status_code=404, detail="No row data available")
# Load best available base image (cropped > dewarped > original)
base_png = await _get_base_image_png(session_id)
if not base_png:
raise HTTPException(status_code=404, detail="No base image available")
arr = np.frombuffer(base_png, dtype=np.uint8)
img = cv2.imdecode(arr, cv2.IMREAD_COLOR)
if img is None:
raise HTTPException(status_code=500, detail="Failed to decode image")
# Color map for row types (BGR)
row_colors = {
"content": (255, 180, 0), # Blue
"header": (128, 128, 128), # Gray
"footer": (128, 128, 128), # Gray
"margin_top": (100, 100, 100), # Dark Gray
"margin_bottom": (100, 100, 100), # Dark Gray
}
overlay = img.copy()
for row in row_result["rows"]:
x, y = row["x"], row["y"]
w, h = row["width"], row["height"]
row_type = row.get("row_type", "content")
color = row_colors.get(row_type, (200, 200, 200))
# Semi-transparent fill
cv2.rectangle(overlay, (x, y), (x + w, y + h), color, -1)
# Solid border
cv2.rectangle(img, (x, y), (x + w, y + h), color, 2)
# Label
idx = row.get("index", 0)
label = f"R{idx} {row_type.upper()}"
wc = row.get("word_count", 0)
if wc:
label = f"{label} ({wc}w)"
cv2.putText(img, label, (x + 5, y + 18),
cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 1)
# Blend overlay at 15% opacity
cv2.addWeighted(overlay, 0.15, img, 0.85, 0, img)
# Draw zone separator lines if zones exist
column_result = session.get("column_result") or {}
zones = column_result.get("zones") or []
if zones:
img_w_px = img.shape[1]
zone_color = (0, 200, 255) # Yellow (BGR)
dash_len = 20
for zone in zones:
if zone.get("zone_type") == "box":
zy = zone["y"]
zh = zone["height"]
for line_y in [zy, zy + zh]:
for sx in range(0, img_w_px, dash_len * 2):
ex = min(sx + dash_len, img_w_px)
cv2.line(img, (sx, line_y), (ex, line_y), zone_color, 2)
# Red semi-transparent overlay for box zones
_draw_box_exclusion_overlay(img, zones)
success, result_png = cv2.imencode(".png", img)
if not success:
raise HTTPException(status_code=500, detail="Failed to encode overlay image")
return Response(content=result_png.tobytes(), media_type="image/png")
async def _get_words_overlay(session_id: str) -> Response:
"""Generate cropped (or dewarped) image with cell grid drawn on it."""
session = await get_session_db(session_id)
if not session:
raise HTTPException(status_code=404, detail=f"Session {session_id} not found")
word_result = session.get("word_result")
if not word_result:
raise HTTPException(status_code=404, detail="No word data available")
# Support both new cell-based and legacy entry-based formats
cells = word_result.get("cells")
if not cells and not word_result.get("entries"):
raise HTTPException(status_code=404, detail="No word data available")
# Load best available base image (cropped > dewarped > original)
base_png = await _get_base_image_png(session_id)
if not base_png:
raise HTTPException(status_code=404, detail="No base image available")
arr = np.frombuffer(base_png, dtype=np.uint8)
img = cv2.imdecode(arr, cv2.IMREAD_COLOR)
if img is None:
raise HTTPException(status_code=500, detail="Failed to decode image")
img_h, img_w = img.shape[:2]
overlay = img.copy()
if cells:
# New cell-based overlay: color by column index
col_palette = [
(255, 180, 0), # Blue (BGR)
(0, 200, 0), # Green
(0, 140, 255), # Orange
(200, 100, 200), # Purple
(200, 200, 0), # Cyan
(100, 200, 200), # Yellow-ish
]
for cell in cells:
bbox = cell.get("bbox_px", {})
cx = bbox.get("x", 0)
cy = bbox.get("y", 0)
cw = bbox.get("w", 0)
ch = bbox.get("h", 0)
if cw <= 0 or ch <= 0:
continue
col_idx = cell.get("col_index", 0)
color = col_palette[col_idx % len(col_palette)]
# Cell rectangle border
cv2.rectangle(img, (cx, cy), (cx + cw, cy + ch), color, 1)
# Semi-transparent fill
cv2.rectangle(overlay, (cx, cy), (cx + cw, cy + ch), color, -1)
# Cell-ID label (top-left corner)
cell_id = cell.get("cell_id", "")
cv2.putText(img, cell_id, (cx + 2, cy + 10),
cv2.FONT_HERSHEY_SIMPLEX, 0.28, color, 1)
# Text label (bottom of cell)
text = cell.get("text", "")
if text:
conf = cell.get("confidence", 0)
if conf >= 70:
text_color = (0, 180, 0)
elif conf >= 50:
text_color = (0, 180, 220)
else:
text_color = (0, 0, 220)
label = text.replace('\n', ' ')[:30]
cv2.putText(img, label, (cx + 3, cy + ch - 4),
cv2.FONT_HERSHEY_SIMPLEX, 0.35, text_color, 1)
else:
# Legacy fallback: entry-based overlay (for old sessions)
column_result = session.get("column_result")
row_result = session.get("row_result")
col_colors = {
"column_en": (255, 180, 0),
"column_de": (0, 200, 0),
"column_example": (0, 140, 255),
}
columns = []
if column_result and column_result.get("columns"):
columns = [c for c in column_result["columns"]
if c.get("type", "").startswith("column_")]
content_rows_data = []
if row_result and row_result.get("rows"):
content_rows_data = [r for r in row_result["rows"]
if r.get("row_type") == "content"]
for col in columns:
col_type = col.get("type", "")
color = col_colors.get(col_type, (200, 200, 200))
cx, cw = col["x"], col["width"]
for row in content_rows_data:
ry, rh = row["y"], row["height"]
cv2.rectangle(img, (cx, ry), (cx + cw, ry + rh), color, 1)
cv2.rectangle(overlay, (cx, ry), (cx + cw, ry + rh), color, -1)
entries = word_result["entries"]
entry_by_row: Dict[int, Dict] = {}
for entry in entries:
entry_by_row[entry.get("row_index", -1)] = entry
for row_idx, row in enumerate(content_rows_data):
entry = entry_by_row.get(row_idx)
if not entry:
continue
conf = entry.get("confidence", 0)
text_color = (0, 180, 0) if conf >= 70 else (0, 180, 220) if conf >= 50 else (0, 0, 220)
ry, rh = row["y"], row["height"]
for col in columns:
col_type = col.get("type", "")
cx, cw = col["x"], col["width"]
field = {"column_en": "english", "column_de": "german", "column_example": "example"}.get(col_type, "")
text = entry.get(field, "") if field else ""
if text:
label = text.replace('\n', ' ')[:30]
cv2.putText(img, label, (cx + 3, ry + rh - 4),
cv2.FONT_HERSHEY_SIMPLEX, 0.35, text_color, 1)
# Blend overlay at 10% opacity
cv2.addWeighted(overlay, 0.1, img, 0.9, 0, img)
# Red semi-transparent overlay for box zones
column_result = session.get("column_result") or {}
zones = column_result.get("zones") or []
_draw_box_exclusion_overlay(img, zones)
success, result_png = cv2.imencode(".png", img)
if not success:
raise HTTPException(status_code=500, detail="Failed to encode overlay image")
return Response(content=result_png.tobytes(), media_type="image/png")
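# Refactoring sketch (illustrative; the helper below is hypothetical and not
# used above): the dashed-rectangle loops in these renderers repeat the same
# segment-stitching pattern and could be consolidated.
def _draw_dashed_rect_sketch(img, x: int, y: int, w: int, h: int,
                             color, dash: int = 15, thickness: int = 2) -> None:
    """Draw a dashed rectangle by stitching short line segments."""
    for sx in range(x, x + w, dash * 2):
        ex = min(sx + dash, x + w)
        cv2.line(img, (sx, y), (ex, y), color, thickness)
        cv2.line(img, (sx, y + h), (ex, y + h), color, thickness)
    for sy in range(y, y + h, dash * 2):
        ey = min(sy + dash, y + h)
        cv2.line(img, (x, sy), (x, ey), color, thickness)
        cv2.line(img, (x + w, sy), (x + w, ey), color, thickness)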

View File

@@ -0,0 +1,205 @@
"""
Overlay rendering for structure detection (boxes, zones, colors, graphics).
Extracted from ocr_pipeline_overlays.py for modularity.
Lizenz: Apache 2.0
DATENSCHUTZ: Alle Verarbeitung erfolgt lokal.
"""
import logging
from typing import Any, Dict, List
import cv2
import numpy as np
from fastapi import HTTPException
from fastapi.responses import Response
from ocr_pipeline_common import _get_base_image_png
from ocr_pipeline_session_store import get_session_db
from cv_color_detect import _COLOR_HEX, _COLOR_RANGES
from cv_box_detect import detect_boxes, split_page_into_zones
logger = logging.getLogger(__name__)
async def _get_structure_overlay(session_id: str) -> Response:
"""Generate overlay image showing detected boxes, zones, and color regions."""
base_png = await _get_base_image_png(session_id)
if not base_png:
raise HTTPException(status_code=404, detail="No base image available")
arr = np.frombuffer(base_png, dtype=np.uint8)
img = cv2.imdecode(arr, cv2.IMREAD_COLOR)
if img is None:
raise HTTPException(status_code=500, detail="Failed to decode image")
h, w = img.shape[:2]
# Get structure result (run detection if not cached)
session = await get_session_db(session_id)
structure = (session or {}).get("structure_result")
if not structure:
# Run detection on-the-fly
margin = int(min(w, h) * 0.03)
content_x, content_y = margin, margin
content_w_px = w - 2 * margin
content_h_px = h - 2 * margin
boxes = detect_boxes(img, content_x, content_w_px, content_y, content_h_px)
zones = split_page_into_zones(content_x, content_y, content_w_px, content_h_px, boxes)
structure = {
"boxes": [
{"x": b.x, "y": b.y, "w": b.width, "h": b.height,
"confidence": b.confidence, "border_thickness": b.border_thickness}
for b in boxes
],
"zones": [
{"index": z.index, "zone_type": z.zone_type,
"y": z.y, "h": z.height, "x": z.x, "w": z.width}
for z in zones
],
}
overlay = img.copy()
# --- Draw zone boundaries ---
zone_colors = {
"content": (200, 200, 200), # light gray
"box": (255, 180, 0), # blue-ish (BGR)
}
for zone in structure.get("zones", []):
zx = zone["x"]
zy = zone["y"]
zw = zone["w"]
zh = zone["h"]
color = zone_colors.get(zone["zone_type"], (200, 200, 200))
# Draw zone boundary as dashed line
dash_len = 12
for edge_x in range(zx, zx + zw, dash_len * 2):
end_x = min(edge_x + dash_len, zx + zw)
cv2.line(img, (edge_x, zy), (end_x, zy), color, 1)
cv2.line(img, (edge_x, zy + zh), (end_x, zy + zh), color, 1)
# Zone label
zone_label = f"Zone {zone['index']} ({zone['zone_type']})"
cv2.putText(img, zone_label, (zx + 5, zy + 15),
cv2.FONT_HERSHEY_SIMPLEX, 0.45, color, 1)
# --- Draw detected boxes ---
# Color map for box backgrounds (BGR)
bg_hex_to_bgr = {
"#dc2626": (38, 38, 220), # red
"#2563eb": (235, 99, 37), # blue
"#16a34a": (74, 163, 22), # green
"#ea580c": (12, 88, 234), # orange
"#9333ea": (234, 51, 147), # purple
"#ca8a04": (4, 138, 202), # yellow
"#6b7280": (128, 114, 107), # gray
}
for box_data in structure.get("boxes", []):
bx = box_data["x"]
by = box_data["y"]
bw = box_data["w"]
bh = box_data["h"]
conf = box_data.get("confidence", 0)
thickness = box_data.get("border_thickness", 0)
bg_hex = box_data.get("bg_color_hex", "#6b7280")
bg_name = box_data.get("bg_color_name", "")
# Box fill color
fill_bgr = bg_hex_to_bgr.get(bg_hex, (128, 114, 107))
# Semi-transparent fill
cv2.rectangle(overlay, (bx, by), (bx + bw, by + bh), fill_bgr, -1)
# Solid border
border_color = fill_bgr
cv2.rectangle(img, (bx, by), (bx + bw, by + bh), border_color, 3)
# Label
label = f"BOX"
if bg_name and bg_name not in ("unknown", "white"):
label += f" ({bg_name})"
if thickness > 0:
label += f" border={thickness}px"
label += f" {int(conf * 100)}%"
cv2.putText(img, label, (bx + 8, by + 22),
cv2.FONT_HERSHEY_SIMPLEX, 0.55, (255, 255, 255), 2)
cv2.putText(img, label, (bx + 8, by + 22),
cv2.FONT_HERSHEY_SIMPLEX, 0.55, border_color, 1)
# Blend overlay at 15% opacity
cv2.addWeighted(overlay, 0.15, img, 0.85, 0, img)
# --- Draw color regions (HSV masks) ---
hsv = cv2.cvtColor(
cv2.imdecode(np.frombuffer(base_png, dtype=np.uint8), cv2.IMREAD_COLOR),
cv2.COLOR_BGR2HSV,
)
color_bgr_map = {
"red": (0, 0, 255),
"orange": (0, 140, 255),
"yellow": (0, 200, 255),
"green": (0, 200, 0),
"blue": (255, 150, 0),
"purple": (200, 0, 200),
}
for color_name, ranges in _COLOR_RANGES.items():
mask = np.zeros((h, w), dtype=np.uint8)
for lower, upper in ranges:
mask = cv2.bitwise_or(mask, cv2.inRange(hsv, lower, upper))
# Only draw if there are significant colored pixels
if np.sum(mask > 0) < 100:
continue
# Draw colored contours
contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
draw_color = color_bgr_map.get(color_name, (200, 200, 200))
for cnt in contours:
area = cv2.contourArea(cnt)
if area < 20:
continue
cv2.drawContours(img, [cnt], -1, draw_color, 2)
# --- Draw graphic elements ---
graphics_data = structure.get("graphics", [])
shape_icons = {
"image": "IMAGE",
"illustration": "ILLUST",
}
for gfx in graphics_data:
gx, gy = gfx["x"], gfx["y"]
gw, gh = gfx["w"], gfx["h"]
shape = gfx.get("shape", "icon")
color_hex = gfx.get("color_hex", "#6b7280")
conf = gfx.get("confidence", 0)
# Pick draw color based on element color (BGR)
gfx_bgr = bg_hex_to_bgr.get(color_hex, (128, 114, 107))
# Draw bounding box (dashed style via short segments)
dash = 6
for seg_x in range(gx, gx + gw, dash * 2):
end_x = min(seg_x + dash, gx + gw)
cv2.line(img, (seg_x, gy), (end_x, gy), gfx_bgr, 2)
cv2.line(img, (seg_x, gy + gh), (end_x, gy + gh), gfx_bgr, 2)
for seg_y in range(gy, gy + gh, dash * 2):
end_y = min(seg_y + dash, gy + gh)
cv2.line(img, (gx, seg_y), (gx, end_y), gfx_bgr, 2)
cv2.line(img, (gx + gw, seg_y), (gx + gw, end_y), gfx_bgr, 2)
# Label
icon = shape_icons.get(shape, shape.upper()[:5])
label = f"{icon} {int(conf * 100)}%"
# White background for readability
(tw, th), _ = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, 0.4, 1)
lx = gx + 2
ly = max(gy - 4, th + 4)
cv2.rectangle(img, (lx - 1, ly - th - 2), (lx + tw + 2, ly + 3), (255, 255, 255), -1)
cv2.putText(img, label, (lx, ly), cv2.FONT_HERSHEY_SIMPLEX, 0.4, gfx_bgr, 1)
# Encode result
_, png_buf = cv2.imencode(".png", img)
return Response(content=png_buf.tobytes(), media_type="image/png")
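# Blending note (illustrative sketch, not part of the module): the
# semi-transparent fills above rely on cv2.addWeighted, which computes
# dst = alpha*src1 + beta*src2 + gamma per pixel; with alpha=0.15 and
# beta=0.85 a filled copy is blended back onto the base image.
if __name__ == "__main__":  # pragma: no cover
    _base = np.full((4, 4, 3), 200, np.uint8)
    _fill = np.zeros_like(_base)
    cv2.addWeighted(_fill, 0.15, _base, 0.85, 0, _base)
    assert int(_base[0, 0, 0]) == 170  # 0.15*0 + 0.85*200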

View File

@@ -1,34 +1,23 @@
"""
Overlay image rendering for OCR pipeline.
Overlay image rendering for OCR pipeline — barrel re-export.
Generates visual overlays for structure, columns, rows, and words
detection results.
All implementation split into:
ocr_pipeline_overlay_structure — structure overlay (boxes, zones, colors, graphics)
ocr_pipeline_overlay_grid — columns, rows, words overlays
Lizenz: Apache 2.0
DATENSCHUTZ: Alle Verarbeitung erfolgt lokal.
"""
import logging
from dataclasses import asdict
from typing import Any, Dict, List, Optional
import cv2
import numpy as np
from fastapi import HTTPException
from fastapi.responses import Response
from ocr_pipeline_common import (
_cache,
_get_base_image_png,
_load_session_to_cache,
_get_cached,
from ocr_pipeline_overlay_structure import _get_structure_overlay # noqa: F401
from ocr_pipeline_overlay_grid import ( # noqa: F401
_get_columns_overlay,
_get_rows_overlay,
_get_words_overlay,
)
from ocr_pipeline_session_store import get_session_db, get_session_image
from cv_color_detect import _COLOR_HEX, _COLOR_RANGES
from cv_box_detect import detect_boxes, split_page_into_zones
from ocr_pipeline_rows import _draw_box_exclusion_overlay
logger = logging.getLogger(__name__)
async def render_overlay(overlay_type: str, session_id: str) -> Response:
@@ -43,505 +32,3 @@ async def render_overlay(overlay_type: str, session_id: str) -> Response:
return await _get_words_overlay(session_id)
else:
raise HTTPException(status_code=400, detail=f"Unknown overlay type: {overlay_type}")
async def _get_structure_overlay(session_id: str) -> Response:
"""Generate overlay image showing detected boxes, zones, and color regions."""
base_png = await _get_base_image_png(session_id)
if not base_png:
raise HTTPException(status_code=404, detail="No base image available")
arr = np.frombuffer(base_png, dtype=np.uint8)
img = cv2.imdecode(arr, cv2.IMREAD_COLOR)
if img is None:
raise HTTPException(status_code=500, detail="Failed to decode image")
h, w = img.shape[:2]
# Get structure result (run detection if not cached)
session = await get_session_db(session_id)
structure = (session or {}).get("structure_result")
if not structure:
# Run detection on-the-fly
margin = int(min(w, h) * 0.03)
content_x, content_y = margin, margin
content_w_px = w - 2 * margin
content_h_px = h - 2 * margin
boxes = detect_boxes(img, content_x, content_w_px, content_y, content_h_px)
zones = split_page_into_zones(content_x, content_y, content_w_px, content_h_px, boxes)
structure = {
"boxes": [
{"x": b.x, "y": b.y, "w": b.width, "h": b.height,
"confidence": b.confidence, "border_thickness": b.border_thickness}
for b in boxes
],
"zones": [
{"index": z.index, "zone_type": z.zone_type,
"y": z.y, "h": z.height, "x": z.x, "w": z.width}
for z in zones
],
}
overlay = img.copy()
# --- Draw zone boundaries ---
zone_colors = {
"content": (200, 200, 200), # light gray
"box": (255, 180, 0), # blue-ish (BGR)
}
for zone in structure.get("zones", []):
zx = zone["x"]
zy = zone["y"]
zw = zone["w"]
zh = zone["h"]
color = zone_colors.get(zone["zone_type"], (200, 200, 200))
# Draw zone boundary as dashed line
dash_len = 12
for edge_x in range(zx, zx + zw, dash_len * 2):
end_x = min(edge_x + dash_len, zx + zw)
cv2.line(img, (edge_x, zy), (end_x, zy), color, 1)
cv2.line(img, (edge_x, zy + zh), (end_x, zy + zh), color, 1)
# Zone label
zone_label = f"Zone {zone['index']} ({zone['zone_type']})"
cv2.putText(img, zone_label, (zx + 5, zy + 15),
cv2.FONT_HERSHEY_SIMPLEX, 0.45, color, 1)
# --- Draw detected boxes ---
# Color map for box backgrounds (BGR)
bg_hex_to_bgr = {
"#dc2626": (38, 38, 220), # red
"#2563eb": (235, 99, 37), # blue
"#16a34a": (74, 163, 22), # green
"#ea580c": (12, 88, 234), # orange
"#9333ea": (234, 51, 147), # purple
"#ca8a04": (4, 138, 202), # yellow
"#6b7280": (128, 114, 107), # gray
}
for box_data in structure.get("boxes", []):
bx = box_data["x"]
by = box_data["y"]
bw = box_data["w"]
bh = box_data["h"]
conf = box_data.get("confidence", 0)
thickness = box_data.get("border_thickness", 0)
bg_hex = box_data.get("bg_color_hex", "#6b7280")
bg_name = box_data.get("bg_color_name", "")
# Box fill color
fill_bgr = bg_hex_to_bgr.get(bg_hex, (128, 114, 107))
# Semi-transparent fill
cv2.rectangle(overlay, (bx, by), (bx + bw, by + bh), fill_bgr, -1)
# Solid border
border_color = fill_bgr
cv2.rectangle(img, (bx, by), (bx + bw, by + bh), border_color, 3)
# Label
label = f"BOX"
if bg_name and bg_name not in ("unknown", "white"):
label += f" ({bg_name})"
if thickness > 0:
label += f" border={thickness}px"
label += f" {int(conf * 100)}%"
cv2.putText(img, label, (bx + 8, by + 22),
cv2.FONT_HERSHEY_SIMPLEX, 0.55, (255, 255, 255), 2)
cv2.putText(img, label, (bx + 8, by + 22),
cv2.FONT_HERSHEY_SIMPLEX, 0.55, border_color, 1)
# Blend overlay at 15% opacity
cv2.addWeighted(overlay, 0.15, img, 0.85, 0, img)
# --- Draw color regions (HSV masks) ---
hsv = cv2.cvtColor(
cv2.imdecode(np.frombuffer(base_png, dtype=np.uint8), cv2.IMREAD_COLOR),
cv2.COLOR_BGR2HSV,
)
color_bgr_map = {
"red": (0, 0, 255),
"orange": (0, 140, 255),
"yellow": (0, 200, 255),
"green": (0, 200, 0),
"blue": (255, 150, 0),
"purple": (200, 0, 200),
}
for color_name, ranges in _COLOR_RANGES.items():
mask = np.zeros((h, w), dtype=np.uint8)
for lower, upper in ranges:
mask = cv2.bitwise_or(mask, cv2.inRange(hsv, lower, upper))
# Only draw if there are significant colored pixels
if np.sum(mask > 0) < 100:
continue
# Draw colored contours
contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
draw_color = color_bgr_map.get(color_name, (200, 200, 200))
for cnt in contours:
area = cv2.contourArea(cnt)
if area < 20:
continue
cv2.drawContours(img, [cnt], -1, draw_color, 2)
# --- Draw graphic elements ---
graphics_data = structure.get("graphics", [])
shape_icons = {
"image": "IMAGE",
"illustration": "ILLUST",
}
for gfx in graphics_data:
gx, gy = gfx["x"], gfx["y"]
gw, gh = gfx["w"], gfx["h"]
shape = gfx.get("shape", "icon")
color_hex = gfx.get("color_hex", "#6b7280")
conf = gfx.get("confidence", 0)
# Pick draw color based on element color (BGR)
gfx_bgr = bg_hex_to_bgr.get(color_hex, (128, 114, 107))
# Draw bounding box (dashed style via short segments)
dash = 6
for seg_x in range(gx, gx + gw, dash * 2):
end_x = min(seg_x + dash, gx + gw)
cv2.line(img, (seg_x, gy), (end_x, gy), gfx_bgr, 2)
cv2.line(img, (seg_x, gy + gh), (end_x, gy + gh), gfx_bgr, 2)
for seg_y in range(gy, gy + gh, dash * 2):
end_y = min(seg_y + dash, gy + gh)
cv2.line(img, (gx, seg_y), (gx, end_y), gfx_bgr, 2)
cv2.line(img, (gx + gw, seg_y), (gx + gw, end_y), gfx_bgr, 2)
# Label
icon = shape_icons.get(shape, shape.upper()[:5])
label = f"{icon} {int(conf * 100)}%"
# White background for readability
(tw, th), _ = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, 0.4, 1)
lx = gx + 2
ly = max(gy - 4, th + 4)
cv2.rectangle(img, (lx - 1, ly - th - 2), (lx + tw + 2, ly + 3), (255, 255, 255), -1)
cv2.putText(img, label, (lx, ly), cv2.FONT_HERSHEY_SIMPLEX, 0.4, gfx_bgr, 1)
# Encode result
_, png_buf = cv2.imencode(".png", img)
return Response(content=png_buf.tobytes(), media_type="image/png")
async def _get_columns_overlay(session_id: str) -> Response:
"""Generate cropped (or dewarped) image with column borders drawn on it."""
session = await get_session_db(session_id)
if not session:
raise HTTPException(status_code=404, detail=f"Session {session_id} not found")
column_result = session.get("column_result")
if not column_result or not column_result.get("columns"):
raise HTTPException(status_code=404, detail="No column data available")
# Load best available base image (cropped > dewarped > original)
base_png = await _get_base_image_png(session_id)
if not base_png:
raise HTTPException(status_code=404, detail="No base image available")
arr = np.frombuffer(base_png, dtype=np.uint8)
img = cv2.imdecode(arr, cv2.IMREAD_COLOR)
if img is None:
raise HTTPException(status_code=500, detail="Failed to decode image")
# Color map for region types (BGR)
colors = {
"column_en": (255, 180, 0), # Blue
"column_de": (0, 200, 0), # Green
"column_example": (0, 140, 255), # Orange
"column_text": (200, 200, 0), # Cyan/Turquoise
"page_ref": (200, 0, 200), # Purple
"column_marker": (0, 0, 220), # Red
"column_ignore": (180, 180, 180), # Light Gray
"header": (128, 128, 128), # Gray
"footer": (128, 128, 128), # Gray
"margin_top": (100, 100, 100), # Dark Gray
"margin_bottom": (100, 100, 100), # Dark Gray
}
overlay = img.copy()
for col in column_result["columns"]:
x, y = col["x"], col["y"]
w, h = col["width"], col["height"]
color = colors.get(col.get("type", ""), (200, 200, 200))
# Semi-transparent fill
cv2.rectangle(overlay, (x, y), (x + w, y + h), color, -1)
# Solid border
cv2.rectangle(img, (x, y), (x + w, y + h), color, 3)
# Label with confidence
label = col.get("type", "unknown").replace("column_", "").upper()
conf = col.get("classification_confidence")
if conf is not None and conf < 1.0:
label = f"{label} {int(conf * 100)}%"
cv2.putText(img, label, (x + 10, y + 30),
cv2.FONT_HERSHEY_SIMPLEX, 0.8, color, 2)
# Blend overlay at 20% opacity
cv2.addWeighted(overlay, 0.2, img, 0.8, 0, img)
# Draw detected box boundaries as dashed rectangles
zones = column_result.get("zones") or []
for zone in zones:
if zone.get("zone_type") == "box" and zone.get("box"):
box = zone["box"]
bx, by = box["x"], box["y"]
bw, bh = box["width"], box["height"]
box_color = (0, 200, 255) # Yellow (BGR)
# Draw dashed rectangle by drawing short line segments
dash_len = 15
for edge_x in range(bx, bx + bw, dash_len * 2):
end_x = min(edge_x + dash_len, bx + bw)
cv2.line(img, (edge_x, by), (end_x, by), box_color, 2)
cv2.line(img, (edge_x, by + bh), (end_x, by + bh), box_color, 2)
for edge_y in range(by, by + bh, dash_len * 2):
end_y = min(edge_y + dash_len, by + bh)
cv2.line(img, (bx, edge_y), (bx, end_y), box_color, 2)
cv2.line(img, (bx + bw, edge_y), (bx + bw, end_y), box_color, 2)
cv2.putText(img, "BOX", (bx + 10, by + bh - 10),
cv2.FONT_HERSHEY_SIMPLEX, 0.7, box_color, 2)
# Red semi-transparent overlay for box zones
_draw_box_exclusion_overlay(img, zones)
success, result_png = cv2.imencode(".png", img)
if not success:
raise HTTPException(status_code=500, detail="Failed to encode overlay image")
return Response(content=result_png.tobytes(), media_type="image/png")
# ---------------------------------------------------------------------------
# Row Detection Endpoints
# ---------------------------------------------------------------------------
async def _get_rows_overlay(session_id: str) -> Response:
"""Generate cropped (or dewarped) image with row bands drawn on it."""
session = await get_session_db(session_id)
if not session:
raise HTTPException(status_code=404, detail=f"Session {session_id} not found")
row_result = session.get("row_result")
if not row_result or not row_result.get("rows"):
raise HTTPException(status_code=404, detail="No row data available")
# Load best available base image (cropped > dewarped > original)
base_png = await _get_base_image_png(session_id)
if not base_png:
raise HTTPException(status_code=404, detail="No base image available")
arr = np.frombuffer(base_png, dtype=np.uint8)
img = cv2.imdecode(arr, cv2.IMREAD_COLOR)
if img is None:
raise HTTPException(status_code=500, detail="Failed to decode image")
# Color map for row types (BGR)
row_colors = {
"content": (255, 180, 0), # Blue
"header": (128, 128, 128), # Gray
"footer": (128, 128, 128), # Gray
"margin_top": (100, 100, 100), # Dark Gray
"margin_bottom": (100, 100, 100), # Dark Gray
}
overlay = img.copy()
for row in row_result["rows"]:
x, y = row["x"], row["y"]
w, h = row["width"], row["height"]
row_type = row.get("row_type", "content")
color = row_colors.get(row_type, (200, 200, 200))
# Semi-transparent fill
cv2.rectangle(overlay, (x, y), (x + w, y + h), color, -1)
# Solid border
cv2.rectangle(img, (x, y), (x + w, y + h), color, 2)
# Label
idx = row.get("index", 0)
label = f"R{idx} {row_type.upper()}"
wc = row.get("word_count", 0)
if wc:
label = f"{label} ({wc}w)"
cv2.putText(img, label, (x + 5, y + 18),
cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 1)
# Blend overlay at 15% opacity
cv2.addWeighted(overlay, 0.15, img, 0.85, 0, img)
# Draw zone separator lines if zones exist
column_result = session.get("column_result") or {}
zones = column_result.get("zones") or []
if zones:
img_w_px = img.shape[1]
zone_color = (0, 200, 255) # Yellow (BGR)
dash_len = 20
for zone in zones:
if zone.get("zone_type") == "box":
zy = zone["y"]
zh = zone["height"]
for line_y in [zy, zy + zh]:
for sx in range(0, img_w_px, dash_len * 2):
ex = min(sx + dash_len, img_w_px)
cv2.line(img, (sx, line_y), (ex, line_y), zone_color, 2)
# Red semi-transparent overlay for box zones
_draw_box_exclusion_overlay(img, zones)
success, result_png = cv2.imencode(".png", img)
if not success:
raise HTTPException(status_code=500, detail="Failed to encode overlay image")
return Response(content=result_png.tobytes(), media_type="image/png")
async def _get_words_overlay(session_id: str) -> Response:
"""Generate cropped (or dewarped) image with cell grid drawn on it."""
session = await get_session_db(session_id)
if not session:
raise HTTPException(status_code=404, detail=f"Session {session_id} not found")
word_result = session.get("word_result")
if not word_result:
raise HTTPException(status_code=404, detail="No word data available")
# Support both new cell-based and legacy entry-based formats
cells = word_result.get("cells")
if not cells and not word_result.get("entries"):
raise HTTPException(status_code=404, detail="No word data available")
# Load best available base image (cropped > dewarped > original)
base_png = await _get_base_image_png(session_id)
if not base_png:
raise HTTPException(status_code=404, detail="No base image available")
arr = np.frombuffer(base_png, dtype=np.uint8)
img = cv2.imdecode(arr, cv2.IMREAD_COLOR)
if img is None:
raise HTTPException(status_code=500, detail="Failed to decode image")
img_h, img_w = img.shape[:2]
overlay = img.copy()
if cells:
# New cell-based overlay: color by column index
col_palette = [
(255, 180, 0), # Blue (BGR)
(0, 200, 0), # Green
(0, 140, 255), # Orange
(200, 100, 200), # Purple
(200, 200, 0), # Cyan
(100, 200, 200), # Yellow-ish
]
for cell in cells:
bbox = cell.get("bbox_px", {})
cx = bbox.get("x", 0)
cy = bbox.get("y", 0)
cw = bbox.get("w", 0)
ch = bbox.get("h", 0)
if cw <= 0 or ch <= 0:
continue
col_idx = cell.get("col_index", 0)
color = col_palette[col_idx % len(col_palette)]
# Cell rectangle border
cv2.rectangle(img, (cx, cy), (cx + cw, cy + ch), color, 1)
# Semi-transparent fill
cv2.rectangle(overlay, (cx, cy), (cx + cw, cy + ch), color, -1)
# Cell-ID label (top-left corner)
cell_id = cell.get("cell_id", "")
cv2.putText(img, cell_id, (cx + 2, cy + 10),
cv2.FONT_HERSHEY_SIMPLEX, 0.28, color, 1)
# Text label (bottom of cell)
text = cell.get("text", "")
if text:
conf = cell.get("confidence", 0)
if conf >= 70:
text_color = (0, 180, 0)
elif conf >= 50:
text_color = (0, 180, 220)
else:
text_color = (0, 0, 220)
label = text.replace('\n', ' ')[:30]
cv2.putText(img, label, (cx + 3, cy + ch - 4),
cv2.FONT_HERSHEY_SIMPLEX, 0.35, text_color, 1)
else:
# Legacy fallback: entry-based overlay (for old sessions)
column_result = session.get("column_result")
row_result = session.get("row_result")
col_colors = {
"column_en": (255, 180, 0),
"column_de": (0, 200, 0),
"column_example": (0, 140, 255),
}
columns = []
if column_result and column_result.get("columns"):
columns = [c for c in column_result["columns"]
if c.get("type", "").startswith("column_")]
content_rows_data = []
if row_result and row_result.get("rows"):
content_rows_data = [r for r in row_result["rows"]
if r.get("row_type") == "content"]
for col in columns:
col_type = col.get("type", "")
color = col_colors.get(col_type, (200, 200, 200))
cx, cw = col["x"], col["width"]
for row in content_rows_data:
ry, rh = row["y"], row["height"]
cv2.rectangle(img, (cx, ry), (cx + cw, ry + rh), color, 1)
cv2.rectangle(overlay, (cx, ry), (cx + cw, ry + rh), color, -1)
entries = word_result["entries"]
entry_by_row: Dict[int, Dict] = {}
for entry in entries:
entry_by_row[entry.get("row_index", -1)] = entry
for row_idx, row in enumerate(content_rows_data):
entry = entry_by_row.get(row_idx)
if not entry:
continue
conf = entry.get("confidence", 0)
text_color = (0, 180, 0) if conf >= 70 else (0, 180, 220) if conf >= 50 else (0, 0, 220)
ry, rh = row["y"], row["height"]
for col in columns:
col_type = col.get("type", "")
cx, cw = col["x"], col["width"]
field = {"column_en": "english", "column_de": "german", "column_example": "example"}.get(col_type, "")
text = entry.get(field, "") if field else ""
if text:
label = text.replace('\n', ' ')[:30]
cv2.putText(img, label, (cx + 3, ry + rh - 4),
cv2.FONT_HERSHEY_SIMPLEX, 0.35, text_color, 1)
# Blend overlay at 10% opacity
cv2.addWeighted(overlay, 0.1, img, 0.9, 0, img)
# Red semi-transparent overlay for box zones
column_result = session.get("column_result") or {}
zones = column_result.get("zones") or []
_draw_box_exclusion_overlay(img, zones)
success, result_png = cv2.imencode(".png", img)
if not success:
raise HTTPException(status_code=500, detail="Failed to encode overlay image")
return Response(content=result_png.tobytes(), media_type="image/png")

View File

@@ -1,607 +1,22 @@
"""
OCR Pipeline Regression Tests — Ground Truth comparison system.
OCR Pipeline Regression Tests — barrel re-export.
Allows marking sessions as "ground truth" and re-running build_grid()
to detect regressions after code changes.
All implementation split into:
ocr_pipeline_regression_helpers — DB persistence, snapshot, comparison
ocr_pipeline_regression_endpoints — FastAPI routes
Lizenz: Apache 2.0
DATENSCHUTZ: Alle Verarbeitung erfolgt lokal.
"""
import json
import logging
import os
import time
import uuid
from datetime import datetime, timezone
from typing import Any, Dict, List, Optional
from fastapi import APIRouter, HTTPException, Query
from grid_editor_api import _build_grid_core
from ocr_pipeline_session_store import (
get_pool,
get_session_db,
list_ground_truth_sessions_db,
update_session_db,
# Helpers (used by grid_editor_api_grid.py)
from ocr_pipeline_regression_helpers import ( # noqa: F401
_init_regression_table,
_persist_regression_run,
_extract_cells_for_comparison,
_build_reference_snapshot,
compare_grids,
)
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/api/v1/ocr-pipeline", tags=["regression"])
# ---------------------------------------------------------------------------
# DB persistence for regression runs
# ---------------------------------------------------------------------------
async def _init_regression_table():
"""Ensure regression_runs table exists (idempotent)."""
pool = await get_pool()
async with pool.acquire() as conn:
migration_path = os.path.join(
os.path.dirname(__file__),
"migrations/008_regression_runs.sql",
)
if os.path.exists(migration_path):
with open(migration_path, "r") as f:
sql = f.read()
await conn.execute(sql)
async def _persist_regression_run(
status: str,
summary: dict,
results: list,
duration_ms: int,
triggered_by: str = "manual",
) -> str:
"""Save a regression run to the database. Returns the run ID."""
try:
await _init_regression_table()
pool = await get_pool()
run_id = str(uuid.uuid4())
async with pool.acquire() as conn:
await conn.execute(
"""
INSERT INTO regression_runs
(id, status, total, passed, failed, errors, duration_ms, results, triggered_by)
VALUES ($1, $2, $3, $4, $5, $6, $7, $8::jsonb, $9)
""",
run_id,
status,
summary.get("total", 0),
summary.get("passed", 0),
summary.get("failed", 0),
summary.get("errors", 0),
duration_ms,
json.dumps(results),
triggered_by,
)
logger.info("Regression run %s persisted: %s", run_id, status)
return run_id
except Exception as e:
logger.warning("Failed to persist regression run: %s", e)
return ""
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _extract_cells_for_comparison(grid_result: dict) -> List[Dict[str, Any]]:
"""Extract a flat list of cells from a grid_editor_result for comparison.
Only keeps fields relevant for comparison: cell_id, row_index, col_index,
col_type, text. Ignores confidence, bbox, word_boxes, duration, is_bold.
"""
cells = []
for zone in grid_result.get("zones", []):
for cell in zone.get("cells", []):
cells.append({
"cell_id": cell.get("cell_id", ""),
"row_index": cell.get("row_index"),
"col_index": cell.get("col_index"),
"col_type": cell.get("col_type", ""),
"text": cell.get("text", ""),
})
return cells
def _build_reference_snapshot(
grid_result: dict,
pipeline: Optional[str] = None,
) -> dict:
"""Build a ground-truth reference snapshot from a grid_editor_result."""
cells = _extract_cells_for_comparison(grid_result)
total_zones = len(grid_result.get("zones", []))
total_columns = sum(len(z.get("columns", [])) for z in grid_result.get("zones", []))
total_rows = sum(len(z.get("rows", [])) for z in grid_result.get("zones", []))
snapshot = {
"saved_at": datetime.now(timezone.utc).isoformat(),
"version": 1,
"pipeline": pipeline,
"summary": {
"total_zones": total_zones,
"total_columns": total_columns,
"total_rows": total_rows,
"total_cells": len(cells),
},
"cells": cells,
}
return snapshot
def compare_grids(reference: dict, current: dict) -> dict:
"""Compare a reference grid snapshot with a newly computed one.
Returns a diff report with:
- status: "pass" or "fail"
- structural_diffs: changes in zone/row/column counts
- cell_diffs: list of individual cell changes
"""
ref_summary = reference.get("summary", {})
cur_summary = current.get("summary", {})
structural_diffs = []
for key in ("total_zones", "total_columns", "total_rows", "total_cells"):
ref_val = ref_summary.get(key, 0)
cur_val = cur_summary.get(key, 0)
if ref_val != cur_val:
structural_diffs.append({
"field": key,
"reference": ref_val,
"current": cur_val,
})
# Build cell lookup by cell_id
ref_cells = {c["cell_id"]: c for c in reference.get("cells", [])}
cur_cells = {c["cell_id"]: c for c in current.get("cells", [])}
cell_diffs: List[Dict[str, Any]] = []
# Check for missing cells (in reference but not in current)
for cell_id in ref_cells:
if cell_id not in cur_cells:
cell_diffs.append({
"type": "cell_missing",
"cell_id": cell_id,
"reference_text": ref_cells[cell_id].get("text", ""),
})
# Check for added cells (in current but not in reference)
for cell_id in cur_cells:
if cell_id not in ref_cells:
cell_diffs.append({
"type": "cell_added",
"cell_id": cell_id,
"current_text": cur_cells[cell_id].get("text", ""),
})
# Check for changes in shared cells
for cell_id in ref_cells:
if cell_id not in cur_cells:
continue
ref_cell = ref_cells[cell_id]
cur_cell = cur_cells[cell_id]
if ref_cell.get("text", "") != cur_cell.get("text", ""):
cell_diffs.append({
"type": "text_change",
"cell_id": cell_id,
"reference": ref_cell.get("text", ""),
"current": cur_cell.get("text", ""),
})
if ref_cell.get("col_type", "") != cur_cell.get("col_type", ""):
cell_diffs.append({
"type": "col_type_change",
"cell_id": cell_id,
"reference": ref_cell.get("col_type", ""),
"current": cur_cell.get("col_type", ""),
})
status = "pass" if not structural_diffs and not cell_diffs else "fail"
return {
"status": status,
"structural_diffs": structural_diffs,
"cell_diffs": cell_diffs,
"summary": {
"structural_changes": len(structural_diffs),
"cells_missing": sum(1 for d in cell_diffs if d["type"] == "cell_missing"),
"cells_added": sum(1 for d in cell_diffs if d["type"] == "cell_added"),
"text_changes": sum(1 for d in cell_diffs if d["type"] == "text_change"),
"col_type_changes": sum(1 for d in cell_diffs if d["type"] == "col_type_change"),
},
}
# ---------------------------------------------------------------------------
# Endpoints
# ---------------------------------------------------------------------------
@router.post("/sessions/{session_id}/mark-ground-truth")
async def mark_ground_truth(
session_id: str,
pipeline: Optional[str] = Query(None, description="Pipeline used: kombi, pipeline, paddle-direct"),
):
"""Save the current build-grid result as ground-truth reference."""
session = await get_session_db(session_id)
if not session:
raise HTTPException(status_code=404, detail=f"Session {session_id} not found")
grid_result = session.get("grid_editor_result")
if not grid_result or not grid_result.get("zones"):
raise HTTPException(
status_code=400,
detail="No grid_editor_result found. Run build-grid first.",
)
# Auto-detect pipeline from word_result if not provided
if not pipeline:
wr = session.get("word_result") or {}
engine = wr.get("ocr_engine", "")
if engine in ("kombi", "rapid_kombi"):
pipeline = "kombi"
elif engine == "paddle_direct":
pipeline = "paddle-direct"
else:
pipeline = "pipeline"
reference = _build_reference_snapshot(grid_result, pipeline=pipeline)
# Merge into existing ground_truth JSONB
gt = session.get("ground_truth") or {}
gt["build_grid_reference"] = reference
await update_session_db(session_id, ground_truth=gt, current_step=11)
# Compare with auto-snapshot if available (shows what the user corrected)
auto_snapshot = gt.get("auto_grid_snapshot")
correction_diff = None
if auto_snapshot:
correction_diff = compare_grids(auto_snapshot, reference)
logger.info(
"Ground truth marked for session %s: %d cells (corrections: %s)",
session_id,
len(reference["cells"]),
correction_diff["summary"] if correction_diff else "no auto-snapshot",
)
return {
"status": "ok",
"session_id": session_id,
"cells_saved": len(reference["cells"]),
"summary": reference["summary"],
"correction_diff": correction_diff,
}
@router.delete("/sessions/{session_id}/mark-ground-truth")
async def unmark_ground_truth(session_id: str):
"""Remove the ground-truth reference from a session."""
session = await get_session_db(session_id)
if not session:
raise HTTPException(status_code=404, detail=f"Session {session_id} not found")
gt = session.get("ground_truth") or {}
if "build_grid_reference" not in gt:
raise HTTPException(status_code=404, detail="No ground truth reference found")
del gt["build_grid_reference"]
await update_session_db(session_id, ground_truth=gt)
logger.info("Ground truth removed for session %s", session_id)
return {"status": "ok", "session_id": session_id}
@router.get("/sessions/{session_id}/correction-diff")
async def get_correction_diff(session_id: str):
"""Compare automatic OCR grid with manually corrected ground truth.
Returns a diff showing exactly which cells the user corrected,
broken down by col_type (english, german, ipa, etc.).
"""
session = await get_session_db(session_id)
if not session:
raise HTTPException(status_code=404, detail=f"Session {session_id} not found")
gt = session.get("ground_truth") or {}
auto_snapshot = gt.get("auto_grid_snapshot")
reference = gt.get("build_grid_reference")
if not auto_snapshot:
raise HTTPException(
status_code=404,
detail="No auto_grid_snapshot found. Re-run build-grid to create one.",
)
if not reference:
raise HTTPException(
status_code=404,
detail="No ground truth reference found. Mark as ground truth first.",
)
diff = compare_grids(auto_snapshot, reference)
# Enrich with per-col_type breakdown
col_type_stats: Dict[str, Dict[str, int]] = {}
for cell_diff in diff.get("cell_diffs", []):
if cell_diff["type"] != "text_change":
continue
# Find col_type from reference cells
cell_id = cell_diff["cell_id"]
ref_cell = next(
(c for c in reference.get("cells", []) if c["cell_id"] == cell_id),
None,
)
ct = ref_cell.get("col_type", "unknown") if ref_cell else "unknown"
if ct not in col_type_stats:
col_type_stats[ct] = {"total": 0, "corrected": 0}
col_type_stats[ct]["corrected"] += 1
# Count total cells per col_type from reference
for cell in reference.get("cells", []):
ct = cell.get("col_type", "unknown")
if ct not in col_type_stats:
col_type_stats[ct] = {"total": 0, "corrected": 0}
col_type_stats[ct]["total"] += 1
# Calculate accuracy per col_type
for ct, stats in col_type_stats.items():
total = stats["total"]
corrected = stats["corrected"]
stats["accuracy_pct"] = round((total - corrected) / total * 100, 1) if total > 0 else 100.0
diff["col_type_breakdown"] = col_type_stats
return diff
@router.get("/ground-truth-sessions")
async def list_ground_truth_sessions():
"""List all sessions that have a ground-truth reference."""
sessions = await list_ground_truth_sessions_db()
result = []
for s in sessions:
gt = s.get("ground_truth") or {}
ref = gt.get("build_grid_reference", {})
result.append({
"session_id": s["id"],
"name": s.get("name", ""),
"filename": s.get("filename", ""),
"document_category": s.get("document_category"),
"pipeline": ref.get("pipeline"),
"saved_at": ref.get("saved_at"),
"summary": ref.get("summary", {}),
})
return {"sessions": result, "count": len(result)}
@router.post("/sessions/{session_id}/regression/run")
async def run_single_regression(session_id: str):
"""Re-run build_grid for a single session and compare to ground truth."""
session = await get_session_db(session_id)
if not session:
raise HTTPException(status_code=404, detail=f"Session {session_id} not found")
gt = session.get("ground_truth") or {}
reference = gt.get("build_grid_reference")
if not reference:
raise HTTPException(
status_code=400,
detail="No ground truth reference found for this session",
)
# Re-compute grid without persisting
try:
new_result = await _build_grid_core(session_id, session)
except Exception as e:
return {
"session_id": session_id,
"name": session.get("name", ""),
"status": "error",
"error": str(e),
}
new_snapshot = _build_reference_snapshot(new_result)
diff = compare_grids(reference, new_snapshot)
logger.info(
"Regression test session %s: %s (%d structural, %d cell diffs)",
session_id, diff["status"],
diff["summary"]["structural_changes"],
sum(v for k, v in diff["summary"].items() if k != "structural_changes"),
)
return {
"session_id": session_id,
"name": session.get("name", ""),
"status": diff["status"],
"diff": diff,
"reference_summary": reference.get("summary", {}),
"current_summary": new_snapshot.get("summary", {}),
}
@router.post("/regression/run")
async def run_all_regressions(
triggered_by: str = Query("manual", description="Who triggered: manual, script, ci"),
):
"""Re-run build_grid for ALL ground-truth sessions and compare."""
start_time = time.monotonic()
sessions = await list_ground_truth_sessions_db()
if not sessions:
return {
"status": "pass",
"message": "No ground truth sessions found",
"results": [],
"summary": {"total": 0, "passed": 0, "failed": 0, "errors": 0},
}
results = []
passed = 0
failed = 0
errors = 0
for s in sessions:
session_id = s["id"]
gt = s.get("ground_truth") or {}
reference = gt.get("build_grid_reference")
if not reference:
continue
# Re-load full session (list query may not include all JSONB fields)
full_session = await get_session_db(session_id)
if not full_session:
results.append({
"session_id": session_id,
"name": s.get("name", ""),
"status": "error",
"error": "Session not found during re-load",
})
errors += 1
continue
try:
new_result = await _build_grid_core(session_id, full_session)
except Exception as e:
results.append({
"session_id": session_id,
"name": s.get("name", ""),
"status": "error",
"error": str(e),
})
errors += 1
continue
new_snapshot = _build_reference_snapshot(new_result)
diff = compare_grids(reference, new_snapshot)
entry = {
"session_id": session_id,
"name": s.get("name", ""),
"status": diff["status"],
"diff_summary": diff["summary"],
"reference_summary": reference.get("summary", {}),
"current_summary": new_snapshot.get("summary", {}),
}
# Include full diffs only for failures (keep response compact)
if diff["status"] == "fail":
entry["structural_diffs"] = diff["structural_diffs"]
entry["cell_diffs"] = diff["cell_diffs"]
failed += 1
else:
passed += 1
results.append(entry)
overall = "pass" if failed == 0 and errors == 0 else "fail"
duration_ms = int((time.monotonic() - start_time) * 1000)
summary = {
"total": len(results),
"passed": passed,
"failed": failed,
"errors": errors,
}
logger.info(
"Regression suite: %s%d passed, %d failed, %d errors (of %d) in %dms",
overall, passed, failed, errors, len(results), duration_ms,
)
# Persist to DB
run_id = await _persist_regression_run(
status=overall,
summary=summary,
results=results,
duration_ms=duration_ms,
triggered_by=triggered_by,
)
return {
"status": overall,
"run_id": run_id,
"duration_ms": duration_ms,
"results": results,
"summary": summary,
}
@router.get("/regression/history")
async def get_regression_history(
limit: int = Query(20, ge=1, le=100),
):
"""Get recent regression run history from the database."""
try:
await _init_regression_table()
pool = await get_pool()
async with pool.acquire() as conn:
rows = await conn.fetch(
"""
SELECT id, run_at, status, total, passed, failed, errors,
duration_ms, triggered_by
FROM regression_runs
ORDER BY run_at DESC
LIMIT $1
""",
limit,
)
return {
"runs": [
{
"id": str(row["id"]),
"run_at": row["run_at"].isoformat() if row["run_at"] else None,
"status": row["status"],
"total": row["total"],
"passed": row["passed"],
"failed": row["failed"],
"errors": row["errors"],
"duration_ms": row["duration_ms"],
"triggered_by": row["triggered_by"],
}
for row in rows
],
"count": len(rows),
}
except Exception as e:
logger.warning("Failed to fetch regression history: %s", e)
return {"runs": [], "count": 0, "error": str(e)}
@router.get("/regression/history/{run_id}")
async def get_regression_run_detail(run_id: str):
"""Get detailed results of a specific regression run."""
try:
await _init_regression_table()
pool = await get_pool()
async with pool.acquire() as conn:
row = await conn.fetchrow(
"SELECT * FROM regression_runs WHERE id = $1",
run_id,
)
if not row:
raise HTTPException(status_code=404, detail="Run not found")
return {
"id": str(row["id"]),
"run_at": row["run_at"].isoformat() if row["run_at"] else None,
"status": row["status"],
"total": row["total"],
"passed": row["passed"],
"failed": row["failed"],
"errors": row["errors"],
"duration_ms": row["duration_ms"],
"triggered_by": row["triggered_by"],
"results": json.loads(row["results"]) if row["results"] else [],
}
except HTTPException:
raise
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
# Endpoints (router used by ocr_pipeline_api.py)
from ocr_pipeline_regression_endpoints import router # noqa: F401
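# Wiring sketch (illustrative; the real app setup lives in ocr_pipeline_api.py):
# the re-exported router mounts like any other APIRouter.
if __name__ == "__main__":  # pragma: no cover
    from fastapi import FastAPI
    _app = FastAPI()
    _app.include_router(router)
    print(sorted(r.path for r in _app.routes))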

View File

@@ -0,0 +1,421 @@
"""
OCR Pipeline Regression Endpoints — FastAPI routes for ground truth and regression.
Extracted from ocr_pipeline_regression.py for modularity.
Lizenz: Apache 2.0
DATENSCHUTZ: Alle Verarbeitung erfolgt lokal.
"""
import json
import logging
import time
from typing import Any, Dict, Optional
from fastapi import APIRouter, HTTPException, Query
from grid_editor_api import _build_grid_core
from ocr_pipeline_session_store import (
get_session_db,
list_ground_truth_sessions_db,
update_session_db,
)
from ocr_pipeline_regression_helpers import (
_build_reference_snapshot,
_init_regression_table,
_persist_regression_run,
compare_grids,
get_pool,
)
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/api/v1/ocr-pipeline", tags=["regression"])
# ---------------------------------------------------------------------------
# Endpoints
# ---------------------------------------------------------------------------
@router.post("/sessions/{session_id}/mark-ground-truth")
async def mark_ground_truth(
session_id: str,
pipeline: Optional[str] = Query(None, description="Pipeline used: kombi, pipeline, paddle-direct"),
):
"""Save the current build-grid result as ground-truth reference."""
session = await get_session_db(session_id)
if not session:
raise HTTPException(status_code=404, detail=f"Session {session_id} not found")
grid_result = session.get("grid_editor_result")
if not grid_result or not grid_result.get("zones"):
raise HTTPException(
status_code=400,
detail="No grid_editor_result found. Run build-grid first.",
)
# Auto-detect pipeline from word_result if not provided
if not pipeline:
wr = session.get("word_result") or {}
engine = wr.get("ocr_engine", "")
if engine in ("kombi", "rapid_kombi"):
pipeline = "kombi"
elif engine == "paddle_direct":
pipeline = "paddle-direct"
else:
pipeline = "pipeline"
reference = _build_reference_snapshot(grid_result, pipeline=pipeline)
# Merge into existing ground_truth JSONB
gt = session.get("ground_truth") or {}
gt["build_grid_reference"] = reference
await update_session_db(session_id, ground_truth=gt, current_step=11)
# Compare with auto-snapshot if available (shows what the user corrected)
auto_snapshot = gt.get("auto_grid_snapshot")
correction_diff = None
if auto_snapshot:
correction_diff = compare_grids(auto_snapshot, reference)
logger.info(
"Ground truth marked for session %s: %d cells (corrections: %s)",
session_id,
len(reference["cells"]),
correction_diff["summary"] if correction_diff else "no auto-snapshot",
)
return {
"status": "ok",
"session_id": session_id,
"cells_saved": len(reference["cells"]),
"summary": reference["summary"],
"correction_diff": correction_diff,
}
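# Client sketch (host, port, and session id are placeholders, not taken from
# this codebase):
#
#     import httpx
#     resp = httpx.post(
#         "http://localhost:8000/api/v1/ocr-pipeline/sessions/<id>/mark-ground-truth",
#         params={"pipeline": "kombi"},
#     )
#     # -> {"status": "ok", "session_id": "<id>", "cells_saved": 42,
#     #     "summary": {...}, "correction_diff": {...} or None}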
@router.delete("/sessions/{session_id}/mark-ground-truth")
async def unmark_ground_truth(session_id: str):
"""Remove the ground-truth reference from a session."""
session = await get_session_db(session_id)
if not session:
raise HTTPException(status_code=404, detail=f"Session {session_id} not found")
gt = session.get("ground_truth") or {}
if "build_grid_reference" not in gt:
raise HTTPException(status_code=404, detail="No ground truth reference found")
del gt["build_grid_reference"]
await update_session_db(session_id, ground_truth=gt)
logger.info("Ground truth removed for session %s", session_id)
return {"status": "ok", "session_id": session_id}
@router.get("/sessions/{session_id}/correction-diff")
async def get_correction_diff(session_id: str):
"""Compare automatic OCR grid with manually corrected ground truth.
Returns a diff showing exactly which cells the user corrected,
broken down by col_type (english, german, ipa, etc.).
"""
session = await get_session_db(session_id)
if not session:
raise HTTPException(status_code=404, detail=f"Session {session_id} not found")
gt = session.get("ground_truth") or {}
auto_snapshot = gt.get("auto_grid_snapshot")
reference = gt.get("build_grid_reference")
if not auto_snapshot:
raise HTTPException(
status_code=404,
detail="No auto_grid_snapshot found. Re-run build-grid to create one.",
)
if not reference:
raise HTTPException(
status_code=404,
detail="No ground truth reference found. Mark as ground truth first.",
)
diff = compare_grids(auto_snapshot, reference)
# Enrich with per-col_type breakdown
col_type_stats: Dict[str, Dict[str, int]] = {}
for cell_diff in diff.get("cell_diffs", []):
if cell_diff["type"] != "text_change":
continue
# Find col_type from reference cells
cell_id = cell_diff["cell_id"]
ref_cell = next(
(c for c in reference.get("cells", []) if c["cell_id"] == cell_id),
None,
)
ct = ref_cell.get("col_type", "unknown") if ref_cell else "unknown"
if ct not in col_type_stats:
col_type_stats[ct] = {"total": 0, "corrected": 0}
col_type_stats[ct]["corrected"] += 1
# Count total cells per col_type from reference
for cell in reference.get("cells", []):
ct = cell.get("col_type", "unknown")
if ct not in col_type_stats:
col_type_stats[ct] = {"total": 0, "corrected": 0}
col_type_stats[ct]["total"] += 1
# Calculate accuracy per col_type
for ct, stats in col_type_stats.items():
total = stats["total"]
corrected = stats["corrected"]
stats["accuracy_pct"] = round((total - corrected) / total * 100, 1) if total > 0 else 100.0
diff["col_type_breakdown"] = col_type_stats
return diff
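# Shape of the enriched diff (values illustrative):
#
#     {
#         "status": "fail",
#         "structural_diffs": [...],
#         "cell_diffs": [...],
#         "col_type_breakdown": {
#             "german":  {"total": 40, "corrected": 3, "accuracy_pct": 92.5},
#             "english": {"total": 40, "corrected": 0, "accuracy_pct": 100.0},
#         },
#     }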
@router.get("/ground-truth-sessions")
async def list_ground_truth_sessions():
"""List all sessions that have a ground-truth reference."""
sessions = await list_ground_truth_sessions_db()
result = []
for s in sessions:
gt = s.get("ground_truth") or {}
ref = gt.get("build_grid_reference", {})
result.append({
"session_id": s["id"],
"name": s.get("name", ""),
"filename": s.get("filename", ""),
"document_category": s.get("document_category"),
"pipeline": ref.get("pipeline"),
"saved_at": ref.get("saved_at"),
"summary": ref.get("summary", {}),
})
return {"sessions": result, "count": len(result)}
@router.post("/sessions/{session_id}/regression/run")
async def run_single_regression(session_id: str):
"""Re-run build_grid for a single session and compare to ground truth."""
session = await get_session_db(session_id)
if not session:
raise HTTPException(status_code=404, detail=f"Session {session_id} not found")
gt = session.get("ground_truth") or {}
reference = gt.get("build_grid_reference")
if not reference:
raise HTTPException(
status_code=400,
detail="No ground truth reference found for this session",
)
# Re-compute grid without persisting
try:
new_result = await _build_grid_core(session_id, session)
    except Exception as e:
return {
"session_id": session_id,
"name": session.get("name", ""),
"status": "error",
"error": str(e),
}
new_snapshot = _build_reference_snapshot(new_result)
diff = compare_grids(reference, new_snapshot)
logger.info(
"Regression test session %s: %s (%d structural, %d cell diffs)",
session_id, diff["status"],
diff["summary"]["structural_changes"],
sum(v for k, v in diff["summary"].items() if k != "structural_changes"),
)
return {
"session_id": session_id,
"name": session.get("name", ""),
"status": diff["status"],
"diff": diff,
"reference_summary": reference.get("summary", {}),
"current_summary": new_snapshot.get("summary", {}),
}
@router.post("/regression/run")
async def run_all_regressions(
triggered_by: str = Query("manual", description="Who triggered: manual, script, ci"),
):
"""Re-run build_grid for ALL ground-truth sessions and compare."""
start_time = time.monotonic()
sessions = await list_ground_truth_sessions_db()
if not sessions:
return {
"status": "pass",
"message": "No ground truth sessions found",
"results": [],
"summary": {"total": 0, "passed": 0, "failed": 0, "errors": 0},
}
results = []
passed = 0
failed = 0
errors = 0
for s in sessions:
session_id = s["id"]
gt = s.get("ground_truth") or {}
reference = gt.get("build_grid_reference")
if not reference:
continue
# Re-load full session (list query may not include all JSONB fields)
full_session = await get_session_db(session_id)
if not full_session:
results.append({
"session_id": session_id,
"name": s.get("name", ""),
"status": "error",
"error": "Session not found during re-load",
})
errors += 1
continue
try:
new_result = await _build_grid_core(session_id, full_session)
        except Exception as e:
results.append({
"session_id": session_id,
"name": s.get("name", ""),
"status": "error",
"error": str(e),
})
errors += 1
continue
new_snapshot = _build_reference_snapshot(new_result)
diff = compare_grids(reference, new_snapshot)
entry = {
"session_id": session_id,
"name": s.get("name", ""),
"status": diff["status"],
"diff_summary": diff["summary"],
"reference_summary": reference.get("summary", {}),
"current_summary": new_snapshot.get("summary", {}),
}
# Include full diffs only for failures (keep response compact)
if diff["status"] == "fail":
entry["structural_diffs"] = diff["structural_diffs"]
entry["cell_diffs"] = diff["cell_diffs"]
failed += 1
else:
passed += 1
results.append(entry)
overall = "pass" if failed == 0 and errors == 0 else "fail"
duration_ms = int((time.monotonic() - start_time) * 1000)
summary = {
"total": len(results),
"passed": passed,
"failed": failed,
"errors": errors,
}
logger.info(
"Regression suite: %s%d passed, %d failed, %d errors (of %d) in %dms",
overall, passed, failed, errors, len(results), duration_ms,
)
# Persist to DB
run_id = await _persist_regression_run(
status=overall,
summary=summary,
results=results,
duration_ms=duration_ms,
triggered_by=triggered_by,
)
return {
"status": overall,
"run_id": run_id,
"duration_ms": duration_ms,
"results": results,
"summary": summary,
}
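# CI sketch (base URL assumed; the triggered_by values mirror the Query
# description above):
#
#     import httpx
#     resp = httpx.post(
#         "http://localhost:8000/api/v1/ocr-pipeline/regression/run",
#         params={"triggered_by": "ci"},
#         timeout=600.0,  # the suite re-runs build_grid per session and can be slow
#     )
#     assert resp.json()["status"] == "pass"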
@router.get("/regression/history")
async def get_regression_history(
limit: int = Query(20, ge=1, le=100),
):
"""Get recent regression run history from the database."""
try:
await _init_regression_table()
pool = await get_pool()
async with pool.acquire() as conn:
rows = await conn.fetch(
"""
SELECT id, run_at, status, total, passed, failed, errors,
duration_ms, triggered_by
FROM regression_runs
ORDER BY run_at DESC
LIMIT $1
""",
limit,
)
return {
"runs": [
{
"id": str(row["id"]),
"run_at": row["run_at"].isoformat() if row["run_at"] else None,
"status": row["status"],
"total": row["total"],
"passed": row["passed"],
"failed": row["failed"],
"errors": row["errors"],
"duration_ms": row["duration_ms"],
"triggered_by": row["triggered_by"],
}
for row in rows
],
"count": len(rows),
}
except Exception as e:
logger.warning("Failed to fetch regression history: %s", e)
return {"runs": [], "count": 0, "error": str(e)}
@router.get("/regression/history/{run_id}")
async def get_regression_run_detail(run_id: str):
"""Get detailed results of a specific regression run."""
try:
await _init_regression_table()
pool = await get_pool()
async with pool.acquire() as conn:
row = await conn.fetchrow(
"SELECT * FROM regression_runs WHERE id = $1",
run_id,
)
if not row:
raise HTTPException(status_code=404, detail="Run not found")
return {
"id": str(row["id"]),
"run_at": row["run_at"].isoformat() if row["run_at"] else None,
"status": row["status"],
"total": row["total"],
"passed": row["passed"],
"failed": row["failed"],
"errors": row["errors"],
"duration_ms": row["duration_ms"],
"triggered_by": row["triggered_by"],
"results": json.loads(row["results"]) if row["results"] else [],
}
except HTTPException:
raise
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))

View File

@@ -0,0 +1,207 @@
"""
OCR Pipeline Regression Helpers — DB persistence, snapshot building, comparison.
Extracted from ocr_pipeline_regression.py for modularity.
Lizenz: Apache 2.0
DATENSCHUTZ: Alle Verarbeitung erfolgt lokal.
"""
import json
import logging
import os
import uuid
from datetime import datetime, timezone
from typing import Any, Dict, List, Optional
from ocr_pipeline_session_store import get_pool
logger = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
# DB persistence for regression runs
# ---------------------------------------------------------------------------
async def _init_regression_table():
"""Ensure regression_runs table exists (idempotent)."""
pool = await get_pool()
async with pool.acquire() as conn:
migration_path = os.path.join(
os.path.dirname(__file__),
"migrations/008_regression_runs.sql",
)
if os.path.exists(migration_path):
with open(migration_path, "r") as f:
sql = f.read()
await conn.execute(sql)
async def _persist_regression_run(
status: str,
summary: dict,
results: list,
duration_ms: int,
triggered_by: str = "manual",
) -> str:
"""Save a regression run to the database. Returns the run ID."""
try:
await _init_regression_table()
pool = await get_pool()
run_id = str(uuid.uuid4())
async with pool.acquire() as conn:
await conn.execute(
"""
INSERT INTO regression_runs
(id, status, total, passed, failed, errors, duration_ms, results, triggered_by)
VALUES ($1, $2, $3, $4, $5, $6, $7, $8::jsonb, $9)
""",
run_id,
status,
summary.get("total", 0),
summary.get("passed", 0),
summary.get("failed", 0),
summary.get("errors", 0),
duration_ms,
json.dumps(results),
triggered_by,
)
logger.info("Regression run %s persisted: %s", run_id, status)
return run_id
except Exception as e:
logger.warning("Failed to persist regression run: %s", e)
return ""
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _extract_cells_for_comparison(grid_result: dict) -> List[Dict[str, Any]]:
"""Extract a flat list of cells from a grid_editor_result for comparison.
Only keeps fields relevant for comparison: cell_id, row_index, col_index,
col_type, text. Ignores confidence, bbox, word_boxes, duration, is_bold.
"""
cells = []
for zone in grid_result.get("zones", []):
for cell in zone.get("cells", []):
cells.append({
"cell_id": cell.get("cell_id", ""),
"row_index": cell.get("row_index"),
"col_index": cell.get("col_index"),
"col_type": cell.get("col_type", ""),
"text": cell.get("text", ""),
})
return cells
def _build_reference_snapshot(
grid_result: dict,
pipeline: Optional[str] = None,
) -> dict:
"""Build a ground-truth reference snapshot from a grid_editor_result."""
cells = _extract_cells_for_comparison(grid_result)
total_zones = len(grid_result.get("zones", []))
total_columns = sum(len(z.get("columns", [])) for z in grid_result.get("zones", []))
total_rows = sum(len(z.get("rows", [])) for z in grid_result.get("zones", []))
snapshot = {
"saved_at": datetime.now(timezone.utc).isoformat(),
"version": 1,
"pipeline": pipeline,
"summary": {
"total_zones": total_zones,
"total_columns": total_columns,
"total_rows": total_rows,
"total_cells": len(cells),
},
"cells": cells,
}
return snapshot
def compare_grids(reference: dict, current: dict) -> dict:
"""Compare a reference grid snapshot with a newly computed one.
Returns a diff report with:
- status: "pass" or "fail"
- structural_diffs: changes in zone/row/column counts
- cell_diffs: list of individual cell changes
"""
ref_summary = reference.get("summary", {})
cur_summary = current.get("summary", {})
structural_diffs = []
for key in ("total_zones", "total_columns", "total_rows", "total_cells"):
ref_val = ref_summary.get(key, 0)
cur_val = cur_summary.get(key, 0)
if ref_val != cur_val:
structural_diffs.append({
"field": key,
"reference": ref_val,
"current": cur_val,
})
# Build cell lookup by cell_id
ref_cells = {c["cell_id"]: c for c in reference.get("cells", [])}
cur_cells = {c["cell_id"]: c for c in current.get("cells", [])}
cell_diffs: List[Dict[str, Any]] = []
# Check for missing cells (in reference but not in current)
for cell_id in ref_cells:
if cell_id not in cur_cells:
cell_diffs.append({
"type": "cell_missing",
"cell_id": cell_id,
"reference_text": ref_cells[cell_id].get("text", ""),
})
# Check for added cells (in current but not in reference)
for cell_id in cur_cells:
if cell_id not in ref_cells:
cell_diffs.append({
"type": "cell_added",
"cell_id": cell_id,
"current_text": cur_cells[cell_id].get("text", ""),
})
# Check for changes in shared cells
for cell_id in ref_cells:
if cell_id not in cur_cells:
continue
ref_cell = ref_cells[cell_id]
cur_cell = cur_cells[cell_id]
if ref_cell.get("text", "") != cur_cell.get("text", ""):
cell_diffs.append({
"type": "text_change",
"cell_id": cell_id,
"reference": ref_cell.get("text", ""),
"current": cur_cell.get("text", ""),
})
if ref_cell.get("col_type", "") != cur_cell.get("col_type", ""):
cell_diffs.append({
"type": "col_type_change",
"cell_id": cell_id,
"reference": ref_cell.get("col_type", ""),
"current": cur_cell.get("col_type", ""),
})
status = "pass" if not structural_diffs and not cell_diffs else "fail"
return {
"status": status,
"structural_diffs": structural_diffs,
"cell_diffs": cell_diffs,
"summary": {
"structural_changes": len(structural_diffs),
"cells_missing": sum(1 for d in cell_diffs if d["type"] == "cell_missing"),
"cells_added": sum(1 for d in cell_diffs if d["type"] == "cell_added"),
"text_changes": sum(1 for d in cell_diffs if d["type"] == "text_change"),
"col_type_changes": sum(1 for d in cell_diffs if d["type"] == "col_type_change"),
},
}
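# Minimal worked example with synthetic snapshots:
#
#     ref = {"summary": {"total_cells": 1}, "cells": [
#         {"cell_id": "z0r0c0", "text": "gehen", "col_type": "german"}]}
#     cur = {"summary": {"total_cells": 1}, "cells": [
#         {"cell_id": "z0r0c0", "text": "gehn", "col_type": "german"}]}
#     compare_grids(ref, cur)["summary"]
#     # -> {"structural_changes": 0, "cells_missing": 0, "cells_added": 0,
#     #     "text_changes": 1, "col_type_changes": 0}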

View File

@@ -1,597 +1,20 @@
"""
OCR Pipeline Sessions API - Session management and image serving endpoints.
OCR Pipeline Sessions API — barrel re-export.
Extracted from ocr_pipeline_api.py for modularity.
Handles: CRUD for sessions, thumbnails, pipeline logs, categories,
image serving (with overlay dispatch), and document type detection.
All implementation split into:
ocr_pipeline_sessions_crud — session CRUD, box sessions
ocr_pipeline_sessions_images — image serving, thumbnails, doc-type detection
Lizenz: Apache 2.0
DATENSCHUTZ: Alle Verarbeitung erfolgt lokal.
"""
import logging
import time
import uuid
from typing import Any, Dict, Optional
from fastapi import APIRouter
import cv2
import numpy as np
from fastapi import APIRouter, File, Form, HTTPException, Query, UploadFile
from fastapi.responses import Response
from ocr_pipeline_sessions_crud import router as _crud_router # noqa: F401
from ocr_pipeline_sessions_images import router as _images_router # noqa: F401
from cv_vocab_pipeline import (
create_ocr_image,
detect_document_type,
render_image_high_res,
render_pdf_high_res,
)
from ocr_pipeline_common import (
VALID_DOCUMENT_CATEGORIES,
UpdateSessionRequest,
_append_pipeline_log,
_cache,
_get_base_image_png,
_get_cached,
_load_session_to_cache,
)
from ocr_pipeline_overlays import render_overlay
from ocr_pipeline_session_store import (
create_session_db,
delete_all_sessions_db,
delete_session_db,
get_session_db,
get_session_image,
get_sub_sessions,
list_sessions_db,
update_session_db,
)
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/api/v1/ocr-pipeline", tags=["ocr-pipeline"])
# ---------------------------------------------------------------------------
# Session Management Endpoints
# ---------------------------------------------------------------------------
@router.get("/sessions")
async def list_sessions(include_sub_sessions: bool = False):
"""List OCR pipeline sessions.
By default, sub-sessions (box regions) are hidden.
Pass ?include_sub_sessions=true to show them.
"""
sessions = await list_sessions_db(include_sub_sessions=include_sub_sessions)
return {"sessions": sessions}
@router.post("/sessions")
async def create_session(
file: UploadFile = File(...),
name: Optional[str] = Form(None),
):
"""Upload a PDF or image file and create a pipeline session.
For multi-page PDFs (> 1 page), each page becomes its own session
grouped under a ``document_group_id``. The response includes a
``pages`` array with one entry per page/session.
"""
file_data = await file.read()
filename = file.filename or "upload"
content_type = file.content_type or ""
is_pdf = content_type == "application/pdf" or filename.lower().endswith(".pdf")
session_name = name or filename
# --- Multi-page PDF handling ---
if is_pdf:
try:
import fitz # PyMuPDF
pdf_doc = fitz.open(stream=file_data, filetype="pdf")
page_count = pdf_doc.page_count
pdf_doc.close()
except Exception as e:
raise HTTPException(status_code=400, detail=f"Could not read PDF: {e}")
if page_count > 1:
return await _create_multi_page_sessions(
file_data, filename, session_name, page_count,
)
# --- Single page (image or 1-page PDF) ---
session_id = str(uuid.uuid4())
try:
if is_pdf:
img_bgr = render_pdf_high_res(file_data, page_number=0, zoom=3.0)
else:
img_bgr = render_image_high_res(file_data)
except Exception as e:
raise HTTPException(status_code=400, detail=f"Could not process file: {e}")
# Encode original as PNG bytes
success, png_buf = cv2.imencode(".png", img_bgr)
if not success:
raise HTTPException(status_code=500, detail="Failed to encode image")
original_png = png_buf.tobytes()
# Persist to DB
await create_session_db(
session_id=session_id,
name=session_name,
filename=filename,
original_png=original_png,
)
# Cache BGR array for immediate processing
_cache[session_id] = {
"id": session_id,
"filename": filename,
"name": session_name,
"original_bgr": img_bgr,
"oriented_bgr": None,
"cropped_bgr": None,
"deskewed_bgr": None,
"dewarped_bgr": None,
"orientation_result": None,
"crop_result": None,
"deskew_result": None,
"dewarp_result": None,
"ground_truth": {},
"current_step": 1,
}
logger.info(f"OCR Pipeline: created session {session_id} from {filename} "
f"({img_bgr.shape[1]}x{img_bgr.shape[0]})")
return {
"session_id": session_id,
"filename": filename,
"name": session_name,
"image_width": img_bgr.shape[1],
"image_height": img_bgr.shape[0],
"original_image_url": f"/api/v1/ocr-pipeline/sessions/{session_id}/image/original",
}
async def _create_multi_page_sessions(
pdf_data: bytes,
filename: str,
base_name: str,
page_count: int,
) -> dict:
"""Create one session per PDF page, grouped by document_group_id."""
document_group_id = str(uuid.uuid4())
pages = []
for page_idx in range(page_count):
session_id = str(uuid.uuid4())
page_name = f"{base_name} — Seite {page_idx + 1}"
try:
img_bgr = render_pdf_high_res(pdf_data, page_number=page_idx, zoom=3.0)
except Exception as e:
logger.warning(f"Failed to render PDF page {page_idx + 1}: {e}")
continue
ok, png_buf = cv2.imencode(".png", img_bgr)
if not ok:
continue
page_png = png_buf.tobytes()
await create_session_db(
session_id=session_id,
name=page_name,
filename=filename,
original_png=page_png,
document_group_id=document_group_id,
page_number=page_idx + 1,
)
_cache[session_id] = {
"id": session_id,
"filename": filename,
"name": page_name,
"original_bgr": img_bgr,
"oriented_bgr": None,
"cropped_bgr": None,
"deskewed_bgr": None,
"dewarped_bgr": None,
"orientation_result": None,
"crop_result": None,
"deskew_result": None,
"dewarp_result": None,
"ground_truth": {},
"current_step": 1,
}
h, w = img_bgr.shape[:2]
pages.append({
"session_id": session_id,
"name": page_name,
"page_number": page_idx + 1,
"image_width": w,
"image_height": h,
"original_image_url": f"/api/v1/ocr-pipeline/sessions/{session_id}/image/original",
})
logger.info(
f"OCR Pipeline: created page session {session_id} "
f"(page {page_idx + 1}/{page_count}) from {filename} ({w}x{h})"
)
# Include session_id pointing to first page for backwards compatibility
# (frontends that expect a single session_id will navigate to page 1)
first_session_id = pages[0]["session_id"] if pages else None
return {
"session_id": first_session_id,
"document_group_id": document_group_id,
"filename": filename,
"name": base_name,
"page_count": page_count,
"pages": pages,
}
@router.get("/sessions/{session_id}")
async def get_session_info(session_id: str):
"""Get session info including deskew/dewarp/column results for step navigation."""
session = await get_session_db(session_id)
if not session:
raise HTTPException(status_code=404, detail=f"Session {session_id} not found")
# Get image dimensions from original PNG
original_png = await get_session_image(session_id, "original")
if original_png:
arr = np.frombuffer(original_png, dtype=np.uint8)
img = cv2.imdecode(arr, cv2.IMREAD_COLOR)
        img_w, img_h = (img.shape[1], img.shape[0]) if img is not None else (0, 0)
else:
img_w, img_h = 0, 0
result = {
"session_id": session["id"],
"filename": session.get("filename", ""),
"name": session.get("name", ""),
"image_width": img_w,
"image_height": img_h,
"original_image_url": f"/api/v1/ocr-pipeline/sessions/{session_id}/image/original",
"current_step": session.get("current_step", 1),
"document_category": session.get("document_category"),
"doc_type": session.get("doc_type"),
}
if session.get("orientation_result"):
result["orientation_result"] = session["orientation_result"]
if session.get("crop_result"):
result["crop_result"] = session["crop_result"]
if session.get("deskew_result"):
result["deskew_result"] = session["deskew_result"]
if session.get("dewarp_result"):
result["dewarp_result"] = session["dewarp_result"]
if session.get("column_result"):
result["column_result"] = session["column_result"]
if session.get("row_result"):
result["row_result"] = session["row_result"]
if session.get("word_result"):
result["word_result"] = session["word_result"]
if session.get("doc_type_result"):
result["doc_type_result"] = session["doc_type_result"]
if session.get("structure_result"):
result["structure_result"] = session["structure_result"]
if session.get("grid_editor_result"):
# Include summary only to keep response small
gr = session["grid_editor_result"]
result["grid_editor_result"] = {
"summary": gr.get("summary", {}),
"zones_count": len(gr.get("zones", [])),
"edited": gr.get("edited", False),
}
if session.get("ground_truth"):
result["ground_truth"] = session["ground_truth"]
# Box sub-session info (zone_type='box' from column detection — NOT page-split)
if session.get("parent_session_id"):
result["parent_session_id"] = session["parent_session_id"]
result["box_index"] = session.get("box_index")
else:
# Check for box sub-sessions (column detection creates these)
subs = await get_sub_sessions(session_id)
if subs:
result["sub_sessions"] = [
{"id": s["id"], "name": s.get("name"), "box_index": s.get("box_index")}
for s in subs
]
return result
@router.put("/sessions/{session_id}")
async def update_session(session_id: str, req: UpdateSessionRequest):
"""Update session name and/or document category."""
kwargs: Dict[str, Any] = {}
if req.name is not None:
kwargs["name"] = req.name
if req.document_category is not None:
if req.document_category not in VALID_DOCUMENT_CATEGORIES:
raise HTTPException(
status_code=400,
detail=f"Invalid category '{req.document_category}'. Valid: {sorted(VALID_DOCUMENT_CATEGORIES)}",
)
kwargs["document_category"] = req.document_category
if not kwargs:
raise HTTPException(status_code=400, detail="Nothing to update")
updated = await update_session_db(session_id, **kwargs)
if not updated:
raise HTTPException(status_code=404, detail=f"Session {session_id} not found")
return {"session_id": session_id, **kwargs}
@router.delete("/sessions/{session_id}")
async def delete_session(session_id: str):
"""Delete a session."""
_cache.pop(session_id, None)
deleted = await delete_session_db(session_id)
if not deleted:
raise HTTPException(status_code=404, detail=f"Session {session_id} not found")
return {"session_id": session_id, "deleted": True}
@router.delete("/sessions")
async def delete_all_sessions():
"""Delete ALL sessions (cleanup)."""
_cache.clear()
count = await delete_all_sessions_db()
return {"deleted_count": count}
@router.post("/sessions/{session_id}/create-box-sessions")
async def create_box_sessions(session_id: str):
"""Create sub-sessions for each detected box region.
Crops box regions from the cropped/dewarped image and creates
independent sub-sessions that can be processed through the pipeline.
"""
session = await get_session_db(session_id)
if not session:
raise HTTPException(status_code=404, detail=f"Session {session_id} not found")
column_result = session.get("column_result")
if not column_result:
raise HTTPException(status_code=400, detail="Column detection must be completed first")
zones = column_result.get("zones") or []
box_zones = [z for z in zones if z.get("zone_type") == "box" and z.get("box")]
if not box_zones:
return {"session_id": session_id, "sub_sessions": [], "message": "No boxes detected"}
# Check for existing sub-sessions
existing = await get_sub_sessions(session_id)
if existing:
return {
"session_id": session_id,
"sub_sessions": [{"id": s["id"], "box_index": s.get("box_index")} for s in existing],
"message": f"{len(existing)} sub-session(s) already exist",
}
# Load base image
base_png = await get_session_image(session_id, "cropped")
if not base_png:
base_png = await get_session_image(session_id, "dewarped")
if not base_png:
raise HTTPException(status_code=400, detail="No base image available")
arr = np.frombuffer(base_png, dtype=np.uint8)
img = cv2.imdecode(arr, cv2.IMREAD_COLOR)
if img is None:
raise HTTPException(status_code=500, detail="Failed to decode image")
parent_name = session.get("name", "Session")
created = []
for i, zone in enumerate(box_zones):
box = zone["box"]
bx, by = box["x"], box["y"]
bw, bh = box["width"], box["height"]
# Crop box region with small padding
pad = 5
y1 = max(0, by - pad)
y2 = min(img.shape[0], by + bh + pad)
x1 = max(0, bx - pad)
x2 = min(img.shape[1], bx + bw + pad)
crop = img[y1:y2, x1:x2]
# Encode as PNG
success, png_buf = cv2.imencode(".png", crop)
if not success:
logger.warning(f"Failed to encode box {i} crop for session {session_id}")
continue
sub_id = str(uuid.uuid4())
sub_name = f"{parent_name} — Box {i + 1}"
await create_session_db(
session_id=sub_id,
name=sub_name,
filename=session.get("filename", "box-crop.png"),
original_png=png_buf.tobytes(),
parent_session_id=session_id,
box_index=i,
)
# Cache the BGR for immediate processing
# Promote original to cropped so column/row/word detection finds it
box_bgr = crop.copy()
_cache[sub_id] = {
"id": sub_id,
"filename": session.get("filename", "box-crop.png"),
"name": sub_name,
"parent_session_id": session_id,
"original_bgr": box_bgr,
"oriented_bgr": None,
"cropped_bgr": box_bgr,
"deskewed_bgr": None,
"dewarped_bgr": None,
"orientation_result": None,
"crop_result": None,
"deskew_result": None,
"dewarp_result": None,
"ground_truth": {},
"current_step": 1,
}
created.append({
"id": sub_id,
"name": sub_name,
"box_index": i,
"box": box,
"image_width": crop.shape[1],
"image_height": crop.shape[0],
})
logger.info(f"Created box sub-session {sub_id} for session {session_id} "
f"(box {i}, {crop.shape[1]}x{crop.shape[0]})")
return {
"session_id": session_id,
"sub_sessions": created,
"total": len(created),
}
@router.get("/sessions/{session_id}/thumbnail")
async def get_session_thumbnail(session_id: str, size: int = Query(default=80, ge=16, le=400)):
"""Return a small thumbnail of the original image."""
original_png = await get_session_image(session_id, "original")
if not original_png:
raise HTTPException(status_code=404, detail=f"Session {session_id} not found or no image")
arr = np.frombuffer(original_png, dtype=np.uint8)
img = cv2.imdecode(arr, cv2.IMREAD_COLOR)
if img is None:
raise HTTPException(status_code=500, detail="Failed to decode image")
h, w = img.shape[:2]
scale = size / max(h, w)
new_w, new_h = int(w * scale), int(h * scale)
thumb = cv2.resize(img, (new_w, new_h), interpolation=cv2.INTER_AREA)
_, png_bytes = cv2.imencode(".png", thumb)
return Response(content=png_bytes.tobytes(), media_type="image/png",
headers={"Cache-Control": "public, max-age=3600"})
@router.get("/sessions/{session_id}/pipeline-log")
async def get_pipeline_log(session_id: str):
"""Get the pipeline execution log for a session."""
session = await get_session_db(session_id)
if not session:
raise HTTPException(status_code=404, detail=f"Session {session_id} not found")
return {"session_id": session_id, "pipeline_log": session.get("pipeline_log") or {"steps": []}}
@router.get("/categories")
async def list_categories():
"""List valid document categories."""
return {"categories": sorted(VALID_DOCUMENT_CATEGORIES)}
# ---------------------------------------------------------------------------
# Image Endpoints
# ---------------------------------------------------------------------------
@router.get("/sessions/{session_id}/image/{image_type}")
async def get_image(session_id: str, image_type: str):
"""Serve session images: original, deskewed, dewarped, binarized, structure-overlay, columns-overlay, or rows-overlay."""
valid_types = {"original", "oriented", "cropped", "deskewed", "dewarped", "binarized", "structure-overlay", "columns-overlay", "rows-overlay", "words-overlay", "clean"}
if image_type not in valid_types:
raise HTTPException(status_code=400, detail=f"Unknown image type: {image_type}")
if image_type == "structure-overlay":
return await render_overlay("structure", session_id)
if image_type == "columns-overlay":
return await render_overlay("columns", session_id)
if image_type == "rows-overlay":
return await render_overlay("rows", session_id)
if image_type == "words-overlay":
return await render_overlay("words", session_id)
    # Try cache first for fast serving (binarized images are cached as PNG)
    cached = _cache.get(session_id)
    if cached and image_type == "binarized" and cached.get("binarized_png"):
        return Response(content=cached["binarized_png"], media_type="image/png")
# Load from DB — for cropped/dewarped, fall back through the chain
if image_type in ("cropped", "dewarped"):
data = await _get_base_image_png(session_id)
else:
data = await get_session_image(session_id, image_type)
if not data:
raise HTTPException(status_code=404, detail=f"Image '{image_type}' not available yet")
return Response(content=data, media_type="image/png")
# ---------------------------------------------------------------------------
# Document Type Detection (between Dewarp and Columns)
# ---------------------------------------------------------------------------
@router.post("/sessions/{session_id}/detect-type")
async def detect_type(session_id: str):
"""Detect document type (vocab_table, full_text, generic_table).
Should be called after crop (clean image available).
Falls back to dewarped if crop was skipped.
Stores result in session for frontend to decide pipeline flow.
"""
if session_id not in _cache:
await _load_session_to_cache(session_id)
cached = _get_cached(session_id)
img_bgr = cached.get("cropped_bgr") if cached.get("cropped_bgr") is not None else cached.get("dewarped_bgr")
if img_bgr is None:
raise HTTPException(status_code=400, detail="Crop or dewarp must be completed first")
t0 = time.time()
ocr_img = create_ocr_image(img_bgr)
result = detect_document_type(ocr_img, img_bgr)
duration = time.time() - t0
result_dict = {
"doc_type": result.doc_type,
"confidence": result.confidence,
"pipeline": result.pipeline,
"skip_steps": result.skip_steps,
"features": result.features,
"duration_seconds": round(duration, 2),
}
# Persist to DB
await update_session_db(
session_id,
doc_type=result.doc_type,
doc_type_result=result_dict,
)
cached["doc_type_result"] = result_dict
logger.info(f"OCR Pipeline: detect-type session {session_id}: "
f"{result.doc_type} (confidence={result.confidence}, {duration:.2f}s)")
await _append_pipeline_log(session_id, "detect_type", {
"doc_type": result.doc_type,
"pipeline": result.pipeline,
"confidence": result.confidence,
**{k: v for k, v in (result.features or {}).items() if isinstance(v, (int, float, str, bool))},
}, duration_ms=int(duration * 1000))
return {"session_id": session_id, **result_dict}
# Composite router (used by ocr_pipeline_api.py)
router = APIRouter()
router.include_router(_crud_router)
router.include_router(_images_router)

View File

@@ -0,0 +1,449 @@
"""
OCR Pipeline Sessions CRUD — session create, read, update, delete, box sessions.
Extracted from ocr_pipeline_sessions.py for modularity.
Lizenz: Apache 2.0
DATENSCHUTZ: Alle Verarbeitung erfolgt lokal.
"""
import logging
import uuid
from typing import Any, Dict, Optional
import cv2
import numpy as np
from fastapi import APIRouter, File, Form, HTTPException, Query, UploadFile
from cv_vocab_pipeline import render_image_high_res, render_pdf_high_res
from ocr_pipeline_common import (
VALID_DOCUMENT_CATEGORIES,
UpdateSessionRequest,
_cache,
)
from ocr_pipeline_session_store import (
create_session_db,
delete_all_sessions_db,
delete_session_db,
get_session_db,
get_session_image,
get_sub_sessions,
list_sessions_db,
update_session_db,
)
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/api/v1/ocr-pipeline", tags=["ocr-pipeline"])
# ---------------------------------------------------------------------------
# Session Management Endpoints
# ---------------------------------------------------------------------------
@router.get("/sessions")
async def list_sessions(include_sub_sessions: bool = False):
"""List OCR pipeline sessions.
By default, sub-sessions (box regions) are hidden.
Pass ?include_sub_sessions=true to show them.
"""
sessions = await list_sessions_db(include_sub_sessions=include_sub_sessions)
return {"sessions": sessions}
@router.post("/sessions")
async def create_session(
file: UploadFile = File(...),
name: Optional[str] = Form(None),
):
"""Upload a PDF or image file and create a pipeline session.
For multi-page PDFs (> 1 page), each page becomes its own session
grouped under a ``document_group_id``. The response includes a
``pages`` array with one entry per page/session.
"""
file_data = await file.read()
filename = file.filename or "upload"
content_type = file.content_type or ""
is_pdf = content_type == "application/pdf" or filename.lower().endswith(".pdf")
session_name = name or filename
# --- Multi-page PDF handling ---
if is_pdf:
try:
import fitz # PyMuPDF
pdf_doc = fitz.open(stream=file_data, filetype="pdf")
page_count = pdf_doc.page_count
pdf_doc.close()
except Exception as e:
raise HTTPException(status_code=400, detail=f"Could not read PDF: {e}")
if page_count > 1:
return await _create_multi_page_sessions(
file_data, filename, session_name, page_count,
)
# --- Single page (image or 1-page PDF) ---
session_id = str(uuid.uuid4())
try:
if is_pdf:
img_bgr = render_pdf_high_res(file_data, page_number=0, zoom=3.0)
else:
img_bgr = render_image_high_res(file_data)
except Exception as e:
raise HTTPException(status_code=400, detail=f"Could not process file: {e}")
# Encode original as PNG bytes
success, png_buf = cv2.imencode(".png", img_bgr)
if not success:
raise HTTPException(status_code=500, detail="Failed to encode image")
original_png = png_buf.tobytes()
# Persist to DB
await create_session_db(
session_id=session_id,
name=session_name,
filename=filename,
original_png=original_png,
)
# Cache BGR array for immediate processing
_cache[session_id] = {
"id": session_id,
"filename": filename,
"name": session_name,
"original_bgr": img_bgr,
"oriented_bgr": None,
"cropped_bgr": None,
"deskewed_bgr": None,
"dewarped_bgr": None,
"orientation_result": None,
"crop_result": None,
"deskew_result": None,
"dewarp_result": None,
"ground_truth": {},
"current_step": 1,
}
logger.info(f"OCR Pipeline: created session {session_id} from {filename} "
f"({img_bgr.shape[1]}x{img_bgr.shape[0]})")
return {
"session_id": session_id,
"filename": filename,
"name": session_name,
"image_width": img_bgr.shape[1],
"image_height": img_bgr.shape[0],
"original_image_url": f"/api/v1/ocr-pipeline/sessions/{session_id}/image/original",
}
async def _create_multi_page_sessions(
pdf_data: bytes,
filename: str,
base_name: str,
page_count: int,
) -> dict:
"""Create one session per PDF page, grouped by document_group_id."""
document_group_id = str(uuid.uuid4())
pages = []
for page_idx in range(page_count):
session_id = str(uuid.uuid4())
page_name = f"{base_name} — Seite {page_idx + 1}"
try:
img_bgr = render_pdf_high_res(pdf_data, page_number=page_idx, zoom=3.0)
except Exception as e:
logger.warning(f"Failed to render PDF page {page_idx + 1}: {e}")
continue
ok, png_buf = cv2.imencode(".png", img_bgr)
if not ok:
continue
page_png = png_buf.tobytes()
await create_session_db(
session_id=session_id,
name=page_name,
filename=filename,
original_png=page_png,
document_group_id=document_group_id,
page_number=page_idx + 1,
)
_cache[session_id] = {
"id": session_id,
"filename": filename,
"name": page_name,
"original_bgr": img_bgr,
"oriented_bgr": None,
"cropped_bgr": None,
"deskewed_bgr": None,
"dewarped_bgr": None,
"orientation_result": None,
"crop_result": None,
"deskew_result": None,
"dewarp_result": None,
"ground_truth": {},
"current_step": 1,
}
h, w = img_bgr.shape[:2]
pages.append({
"session_id": session_id,
"name": page_name,
"page_number": page_idx + 1,
"image_width": w,
"image_height": h,
"original_image_url": f"/api/v1/ocr-pipeline/sessions/{session_id}/image/original",
})
logger.info(
f"OCR Pipeline: created page session {session_id} "
f"(page {page_idx + 1}/{page_count}) from {filename} ({w}x{h})"
)
# Include session_id pointing to first page for backwards compatibility
# (frontends that expect a single session_id will navigate to page 1)
first_session_id = pages[0]["session_id"] if pages else None
return {
"session_id": first_session_id,
"document_group_id": document_group_id,
"filename": filename,
"name": base_name,
"page_count": page_count,
"pages": pages,
}
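# Response sketch for a 2-page upload (values illustrative):
#
#     {
#         "session_id": "<session id of page 1>",
#         "document_group_id": "<uuid>",
#         "filename": "scan.pdf",
#         "page_count": 2,
#         "pages": [
#             {"session_id": "...", "page_number": 1, "image_width": 2480, ...},
#             {"session_id": "...", "page_number": 2, "image_width": 2480, ...},
#         ],
#     }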
@router.get("/sessions/{session_id}")
async def get_session_info(session_id: str):
"""Get session info including deskew/dewarp/column results for step navigation."""
session = await get_session_db(session_id)
if not session:
raise HTTPException(status_code=404, detail=f"Session {session_id} not found")
# Get image dimensions from original PNG
original_png = await get_session_image(session_id, "original")
if original_png:
arr = np.frombuffer(original_png, dtype=np.uint8)
img = cv2.imdecode(arr, cv2.IMREAD_COLOR)
        img_w, img_h = (img.shape[1], img.shape[0]) if img is not None else (0, 0)
else:
img_w, img_h = 0, 0
result = {
"session_id": session["id"],
"filename": session.get("filename", ""),
"name": session.get("name", ""),
"image_width": img_w,
"image_height": img_h,
"original_image_url": f"/api/v1/ocr-pipeline/sessions/{session_id}/image/original",
"current_step": session.get("current_step", 1),
"document_category": session.get("document_category"),
"doc_type": session.get("doc_type"),
}
if session.get("orientation_result"):
result["orientation_result"] = session["orientation_result"]
if session.get("crop_result"):
result["crop_result"] = session["crop_result"]
if session.get("deskew_result"):
result["deskew_result"] = session["deskew_result"]
if session.get("dewarp_result"):
result["dewarp_result"] = session["dewarp_result"]
if session.get("column_result"):
result["column_result"] = session["column_result"]
if session.get("row_result"):
result["row_result"] = session["row_result"]
if session.get("word_result"):
result["word_result"] = session["word_result"]
if session.get("doc_type_result"):
result["doc_type_result"] = session["doc_type_result"]
if session.get("structure_result"):
result["structure_result"] = session["structure_result"]
if session.get("grid_editor_result"):
# Include summary only to keep response small
gr = session["grid_editor_result"]
result["grid_editor_result"] = {
"summary": gr.get("summary", {}),
"zones_count": len(gr.get("zones", [])),
"edited": gr.get("edited", False),
}
if session.get("ground_truth"):
result["ground_truth"] = session["ground_truth"]
# Box sub-session info (zone_type='box' from column detection — NOT page-split)
if session.get("parent_session_id"):
result["parent_session_id"] = session["parent_session_id"]
result["box_index"] = session.get("box_index")
else:
# Check for box sub-sessions (column detection creates these)
subs = await get_sub_sessions(session_id)
if subs:
result["sub_sessions"] = [
{"id": s["id"], "name": s.get("name"), "box_index": s.get("box_index")}
for s in subs
]
return result
@router.put("/sessions/{session_id}")
async def update_session(session_id: str, req: UpdateSessionRequest):
"""Update session name and/or document category."""
kwargs: Dict[str, Any] = {}
if req.name is not None:
kwargs["name"] = req.name
if req.document_category is not None:
if req.document_category not in VALID_DOCUMENT_CATEGORIES:
raise HTTPException(
status_code=400,
detail=f"Invalid category '{req.document_category}'. Valid: {sorted(VALID_DOCUMENT_CATEGORIES)}",
)
kwargs["document_category"] = req.document_category
if not kwargs:
raise HTTPException(status_code=400, detail="Nothing to update")
updated = await update_session_db(session_id, **kwargs)
if not updated:
raise HTTPException(status_code=404, detail=f"Session {session_id} not found")
return {"session_id": session_id, **kwargs}
@router.delete("/sessions/{session_id}")
async def delete_session(session_id: str):
"""Delete a session."""
_cache.pop(session_id, None)
deleted = await delete_session_db(session_id)
if not deleted:
raise HTTPException(status_code=404, detail=f"Session {session_id} not found")
return {"session_id": session_id, "deleted": True}
@router.delete("/sessions")
async def delete_all_sessions():
"""Delete ALL sessions (cleanup)."""
_cache.clear()
count = await delete_all_sessions_db()
return {"deleted_count": count}
@router.post("/sessions/{session_id}/create-box-sessions")
async def create_box_sessions(session_id: str):
"""Create sub-sessions for each detected box region.
Crops box regions from the cropped/dewarped image and creates
independent sub-sessions that can be processed through the pipeline.
"""
session = await get_session_db(session_id)
if not session:
raise HTTPException(status_code=404, detail=f"Session {session_id} not found")
column_result = session.get("column_result")
if not column_result:
raise HTTPException(status_code=400, detail="Column detection must be completed first")
zones = column_result.get("zones") or []
box_zones = [z for z in zones if z.get("zone_type") == "box" and z.get("box")]
if not box_zones:
return {"session_id": session_id, "sub_sessions": [], "message": "No boxes detected"}
# Check for existing sub-sessions
existing = await get_sub_sessions(session_id)
if existing:
return {
"session_id": session_id,
"sub_sessions": [{"id": s["id"], "box_index": s.get("box_index")} for s in existing],
"message": f"{len(existing)} sub-session(s) already exist",
}
# Load base image
base_png = await get_session_image(session_id, "cropped")
if not base_png:
base_png = await get_session_image(session_id, "dewarped")
if not base_png:
raise HTTPException(status_code=400, detail="No base image available")
arr = np.frombuffer(base_png, dtype=np.uint8)
img = cv2.imdecode(arr, cv2.IMREAD_COLOR)
if img is None:
raise HTTPException(status_code=500, detail="Failed to decode image")
parent_name = session.get("name", "Session")
created = []
for i, zone in enumerate(box_zones):
box = zone["box"]
bx, by = box["x"], box["y"]
bw, bh = box["width"], box["height"]
# Crop box region with small padding
pad = 5
y1 = max(0, by - pad)
y2 = min(img.shape[0], by + bh + pad)
x1 = max(0, bx - pad)
x2 = min(img.shape[1], bx + bw + pad)
crop = img[y1:y2, x1:x2]
# Encode as PNG
success, png_buf = cv2.imencode(".png", crop)
if not success:
logger.warning(f"Failed to encode box {i} crop for session {session_id}")
continue
sub_id = str(uuid.uuid4())
sub_name = f"{parent_name} — Box {i + 1}"
await create_session_db(
session_id=sub_id,
name=sub_name,
filename=session.get("filename", "box-crop.png"),
original_png=png_buf.tobytes(),
parent_session_id=session_id,
box_index=i,
)
# Cache the BGR for immediate processing
# Promote original to cropped so column/row/word detection finds it
box_bgr = crop.copy()
_cache[sub_id] = {
"id": sub_id,
"filename": session.get("filename", "box-crop.png"),
"name": sub_name,
"parent_session_id": session_id,
"original_bgr": box_bgr,
"oriented_bgr": None,
"cropped_bgr": box_bgr,
"deskewed_bgr": None,
"dewarped_bgr": None,
"orientation_result": None,
"crop_result": None,
"deskew_result": None,
"dewarp_result": None,
"ground_truth": {},
"current_step": 1,
}
created.append({
"id": sub_id,
"name": sub_name,
"box_index": i,
"box": box,
"image_width": crop.shape[1],
"image_height": crop.shape[0],
})
logger.info(f"Created box sub-session {sub_id} for session {session_id} "
f"(box {i}, {crop.shape[1]}x{crop.shape[0]})")
return {
"session_id": session_id,
"sub_sessions": created,
"total": len(created),
}
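# The crop is cached as both original_bgr and cropped_bgr so each sub-session
# can enter column/row/word detection directly, skipping the geometry steps.
# Illustrative client flow (base URL assumed):
#
#     import httpx
#     subs = httpx.post(
#         "http://localhost:8000/api/v1/ocr-pipeline/sessions/<id>/create-box-sessions"
#     ).json()["sub_sessions"]
#     for sub in subs:
#         httpx.post(
#             f"http://localhost:8000/api/v1/ocr-pipeline/sessions/{sub['id']}/detect-type"
#         )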

View File

@@ -0,0 +1,176 @@
"""
OCR Pipeline Sessions Images — image serving, thumbnails, pipeline log,
categories, and document type detection.
Extracted from ocr_pipeline_sessions.py for modularity.
Lizenz: Apache 2.0
DATENSCHUTZ: Alle Verarbeitung erfolgt lokal.
"""
import logging
import time
from typing import Any, Dict
import cv2
import numpy as np
from fastapi import APIRouter, HTTPException, Query
from fastapi.responses import Response
from cv_vocab_pipeline import create_ocr_image, detect_document_type
from ocr_pipeline_common import (
VALID_DOCUMENT_CATEGORIES,
_append_pipeline_log,
_cache,
_get_base_image_png,
_get_cached,
_load_session_to_cache,
)
from ocr_pipeline_overlays import render_overlay
from ocr_pipeline_session_store import (
get_session_db,
get_session_image,
update_session_db,
)
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/api/v1/ocr-pipeline", tags=["ocr-pipeline"])
# ---------------------------------------------------------------------------
# Thumbnail & Log Endpoints
# ---------------------------------------------------------------------------
@router.get("/sessions/{session_id}/thumbnail")
async def get_session_thumbnail(session_id: str, size: int = Query(default=80, ge=16, le=400)):
"""Return a small thumbnail of the original image."""
original_png = await get_session_image(session_id, "original")
if not original_png:
raise HTTPException(status_code=404, detail=f"Session {session_id} not found or no image")
arr = np.frombuffer(original_png, dtype=np.uint8)
img = cv2.imdecode(arr, cv2.IMREAD_COLOR)
if img is None:
raise HTTPException(status_code=500, detail="Failed to decode image")
h, w = img.shape[:2]
scale = size / max(h, w)
new_w, new_h = int(w * scale), int(h * scale)
thumb = cv2.resize(img, (new_w, new_h), interpolation=cv2.INTER_AREA)
_, png_bytes = cv2.imencode(".png", thumb)
return Response(content=png_bytes.tobytes(), media_type="image/png",
headers={"Cache-Control": "public, max-age=3600"})
@router.get("/sessions/{session_id}/pipeline-log")
async def get_pipeline_log(session_id: str):
"""Get the pipeline execution log for a session."""
session = await get_session_db(session_id)
if not session:
raise HTTPException(status_code=404, detail=f"Session {session_id} not found")
return {"session_id": session_id, "pipeline_log": session.get("pipeline_log") or {"steps": []}}
@router.get("/categories")
async def list_categories():
"""List valid document categories."""
return {"categories": sorted(VALID_DOCUMENT_CATEGORIES)}
# ---------------------------------------------------------------------------
# Image Endpoints
# ---------------------------------------------------------------------------
@router.get("/sessions/{session_id}/image/{image_type}")
async def get_image(session_id: str, image_type: str):
"""Serve session images: original, deskewed, dewarped, binarized, structure-overlay, columns-overlay, or rows-overlay."""
valid_types = {"original", "oriented", "cropped", "deskewed", "dewarped", "binarized", "structure-overlay", "columns-overlay", "rows-overlay", "words-overlay", "clean"}
if image_type not in valid_types:
raise HTTPException(status_code=400, detail=f"Unknown image type: {image_type}")
if image_type == "structure-overlay":
return await render_overlay("structure", session_id)
if image_type == "columns-overlay":
return await render_overlay("columns", session_id)
if image_type == "rows-overlay":
return await render_overlay("rows", session_id)
if image_type == "words-overlay":
return await render_overlay("words", session_id)
    # Try cache first for fast serving (binarized images are cached as PNG)
    cached = _cache.get(session_id)
    if cached and image_type == "binarized" and cached.get("binarized_png"):
        return Response(content=cached["binarized_png"], media_type="image/png")
# Load from DB — for cropped/dewarped, fall back through the chain
if image_type in ("cropped", "dewarped"):
data = await _get_base_image_png(session_id)
else:
data = await get_session_image(session_id, image_type)
if not data:
raise HTTPException(status_code=404, detail=f"Image '{image_type}' not available yet")
return Response(content=data, media_type="image/png")
# ---------------------------------------------------------------------------
# Document Type Detection (between Dewarp and Columns)
# ---------------------------------------------------------------------------
@router.post("/sessions/{session_id}/detect-type")
async def detect_type(session_id: str):
"""Detect document type (vocab_table, full_text, generic_table).
Should be called after crop (clean image available).
Falls back to dewarped if crop was skipped.
Stores result in session for frontend to decide pipeline flow.
"""
if session_id not in _cache:
await _load_session_to_cache(session_id)
cached = _get_cached(session_id)
img_bgr = cached.get("cropped_bgr") if cached.get("cropped_bgr") is not None else cached.get("dewarped_bgr")
if img_bgr is None:
raise HTTPException(status_code=400, detail="Crop or dewarp must be completed first")
t0 = time.time()
ocr_img = create_ocr_image(img_bgr)
result = detect_document_type(ocr_img, img_bgr)
duration = time.time() - t0
result_dict = {
"doc_type": result.doc_type,
"confidence": result.confidence,
"pipeline": result.pipeline,
"skip_steps": result.skip_steps,
"features": result.features,
"duration_seconds": round(duration, 2),
}
# Persist to DB
await update_session_db(
session_id,
doc_type=result.doc_type,
doc_type_result=result_dict,
)
cached["doc_type_result"] = result_dict
logger.info(f"OCR Pipeline: detect-type session {session_id}: "
f"{result.doc_type} (confidence={result.confidence}, {duration:.2f}s)")
await _append_pipeline_log(session_id, "detect_type", {
"doc_type": result.doc_type,
"pipeline": result.pipeline,
"confidence": result.confidence,
**{k: v for k, v in (result.features or {}).items() if isinstance(v, (int, float, str, bool))},
}, duration_ms=int(duration * 1000))
return {"session_id": session_id, **result_dict}

View File

@@ -1,529 +1,38 @@
"""
Self-RAG / Corrective RAG Module
Self-RAG / Corrective RAG Module — barrel re-export.
Implements self-reflective RAG that can:
1. Grade retrieved documents for relevance
2. Decide if more retrieval is needed
3. Reformulate queries if initial retrieval fails
4. Filter irrelevant passages before generation
5. Grade answers for groundedness and hallucination
All implementation split into:
self_rag_grading — document relevance grading, filtering, decisions
self_rag_retrieval — query reformulation, retrieval loop, info
IMPORTANT: Self-RAG is DISABLED by default for privacy reasons!
When enabled, search queries and retrieved documents are sent to OpenAI API.
Based on research:
- Self-RAG (Asai et al., 2023): Learning to retrieve, generate, and critique
- Corrective RAG (Yan et al., 2024): Self-correcting retrieval augmented generation
This is especially useful for German educational documents where:
- Queries may use informal language
- Documents use formal/technical terminology
- Context must be precisely matched to scoring criteria
- Self-RAG (Asai et al., 2023)
- Corrective RAG (Yan et al., 2024)
"""
import os
from typing import List, Dict, Optional, Tuple
from enum import Enum
import httpx
# Configuration
# IMPORTANT: Self-RAG is DISABLED by default for privacy reasons!
# When enabled, search queries and retrieved documents are sent to OpenAI API
# for relevance grading and query reformulation. This exposes user data to third parties.
# Only enable if you have explicit user consent for data processing.
SELF_RAG_ENABLED = os.getenv("SELF_RAG_ENABLED", "false").lower() == "true"
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY", "")
SELF_RAG_MODEL = os.getenv("SELF_RAG_MODEL", "gpt-4o-mini")
# Thresholds for self-reflection
RELEVANCE_THRESHOLD = float(os.getenv("SELF_RAG_RELEVANCE_THRESHOLD", "0.6"))
GROUNDING_THRESHOLD = float(os.getenv("SELF_RAG_GROUNDING_THRESHOLD", "0.7"))
MAX_RETRIEVAL_ATTEMPTS = int(os.getenv("SELF_RAG_MAX_ATTEMPTS", "2"))
class RetrievalDecision(Enum):
"""Decision after grading retrieval."""
SUFFICIENT = "sufficient" # Context is good, proceed to generation
NEEDS_MORE = "needs_more" # Need to retrieve more documents
REFORMULATE = "reformulate" # Query needs reformulation
FALLBACK = "fallback" # Use fallback (no good context found)
class SelfRAGError(Exception):
"""Error during Self-RAG processing."""
pass
async def grade_document_relevance(
query: str,
document: str,
) -> Tuple[float, str]:
"""
Grade whether a document is relevant to the query.
Returns a score between 0 (irrelevant) and 1 (highly relevant)
along with an explanation.
"""
if not OPENAI_API_KEY:
# Fallback: simple keyword overlap
query_words = set(query.lower().split())
doc_words = set(document.lower().split())
overlap = len(query_words & doc_words) / max(len(query_words), 1)
return min(overlap * 2, 1.0), "Keyword-based relevance (no LLM)"
prompt = f"""Bewerte, ob das folgende Dokument relevant für die Suchanfrage ist.
SUCHANFRAGE: {query}
DOKUMENT:
{document[:2000]}
Ist dieses Dokument relevant, um die Anfrage zu beantworten?
Berücksichtige:
- Thematische Übereinstimmung
- Enthält das Dokument spezifische Informationen zur Anfrage?
- Würde dieses Dokument bei der Beantwortung helfen?
Antworte im Format:
SCORE: [0.0-1.0]
BEGRÜNDUNG: [Kurze Erklärung]"""
try:
async with httpx.AsyncClient() as client:
response = await client.post(
"https://api.openai.com/v1/chat/completions",
headers={
"Authorization": f"Bearer {OPENAI_API_KEY}",
"Content-Type": "application/json"
},
json={
"model": SELF_RAG_MODEL,
"messages": [{"role": "user", "content": prompt}],
"max_tokens": 150,
"temperature": 0.0,
},
timeout=30.0
)
if response.status_code != 200:
return 0.5, f"API error: {response.status_code}"
result = response.json()["choices"][0]["message"]["content"]
import re
score_match = re.search(r'SCORE:\s*([\d.]+)', result)
score = float(score_match.group(1)) if score_match else 0.5
reason_match = re.search(r'BEGRÜNDUNG:\s*(.+)', result, re.DOTALL)
reason = reason_match.group(1).strip() if reason_match else result
return min(max(score, 0.0), 1.0), reason
except Exception as e:
return 0.5, f"Grading error: {str(e)}"
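# Worked example of the no-API-key fallback: query "Abitur Bewertung Mathe"
# against document "Bewertung der Abitur-Klausur". The lowercased overlap is
# {"bewertung"} ("abitur-klausur" stays one token), so overlap = 1/3 and the
# returned score is min(2/3, 1.0) ≈ 0.67.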
async def grade_documents_batch(
query: str,
documents: List[str],
) -> List[Tuple[float, str]]:
"""
Grade multiple documents for relevance.
Returns list of (score, reason) tuples.
"""
results = []
for doc in documents:
score, reason = await grade_document_relevance(query, doc)
results.append((score, reason))
return results
async def filter_relevant_documents(
query: str,
documents: List[Dict],
threshold: float = RELEVANCE_THRESHOLD,
) -> Tuple[List[Dict], List[Dict]]:
"""
Filter documents by relevance, separating relevant from irrelevant.
Args:
query: The search query
documents: List of document dicts with 'text' field
threshold: Minimum relevance score to keep
Returns:
Tuple of (relevant_docs, filtered_out_docs)
"""
relevant = []
filtered = []
for doc in documents:
text = doc.get("text", "")
score, reason = await grade_document_relevance(query, text)
doc_with_grade = doc.copy()
doc_with_grade["relevance_score"] = score
doc_with_grade["relevance_reason"] = reason
if score >= threshold:
relevant.append(doc_with_grade)
else:
filtered.append(doc_with_grade)
# Sort relevant by score
relevant.sort(key=lambda x: x.get("relevance_score", 0), reverse=True)
return relevant, filtered
async def decide_retrieval_strategy(
query: str,
documents: List[Dict],
attempt: int = 1,
) -> Tuple[RetrievalDecision, Dict]:
"""
Decide what to do based on current retrieval results.
Args:
query: The search query
documents: Retrieved documents with relevance scores
attempt: Current retrieval attempt number
Returns:
Tuple of (decision, metadata)
"""
if not documents:
if attempt >= MAX_RETRIEVAL_ATTEMPTS:
return RetrievalDecision.FALLBACK, {"reason": "No documents found after max attempts"}
return RetrievalDecision.REFORMULATE, {"reason": "No documents retrieved"}
# Check average relevance
scores = [doc.get("relevance_score", 0.5) for doc in documents]
avg_score = sum(scores) / len(scores)
max_score = max(scores)
if max_score >= RELEVANCE_THRESHOLD and avg_score >= RELEVANCE_THRESHOLD * 0.7:
return RetrievalDecision.SUFFICIENT, {
"avg_relevance": avg_score,
"max_relevance": max_score,
"doc_count": len(documents),
}
if attempt >= MAX_RETRIEVAL_ATTEMPTS:
if max_score >= RELEVANCE_THRESHOLD * 0.5:
# At least some relevant context, proceed with caution
return RetrievalDecision.SUFFICIENT, {
"avg_relevance": avg_score,
"warning": "Low relevance after max attempts",
}
return RetrievalDecision.FALLBACK, {"reason": "Max attempts reached, low relevance"}
if avg_score < 0.3:
return RetrievalDecision.REFORMULATE, {
"reason": "Very low relevance, query reformulation needed",
"avg_relevance": avg_score,
}
return RetrievalDecision.NEEDS_MORE, {
"reason": "Moderate relevance, retrieving more documents",
"avg_relevance": avg_score,
}
async def reformulate_query(
original_query: str,
context: Optional[str] = None,
previous_results_summary: Optional[str] = None,
) -> str:
"""
Reformulate a query to improve retrieval.
Uses LLM to generate a better query based on:
- Original query
- Optional context (subject, niveau, etc.)
- Summary of why previous retrieval failed
"""
if not OPENAI_API_KEY:
# Simple reformulation: expand abbreviations, add synonyms
reformulated = original_query
expansions = {
"EA": "erhöhtes Anforderungsniveau",
"eA": "erhöhtes Anforderungsniveau",
"GA": "grundlegendes Anforderungsniveau",
"gA": "grundlegendes Anforderungsniveau",
"AFB": "Anforderungsbereich",
"Abi": "Abitur",
}
for abbr, expansion in expansions.items():
if abbr in original_query:
reformulated = reformulated.replace(abbr, f"{abbr} ({expansion})")
return reformulated
prompt = f"""Du bist ein Experte für deutsche Bildungsstandards und Prüfungsanforderungen.
Die folgende Suchanfrage hat keine guten Ergebnisse geliefert:
ORIGINAL: {original_query}
{f"KONTEXT: {context}" if context else ""}
{f"PROBLEM MIT VORHERIGEN ERGEBNISSEN: {previous_results_summary}" if previous_results_summary else ""}
Formuliere die Anfrage so um, dass sie:
1. Formellere/technischere Begriffe verwendet (wie in offiziellen Dokumenten)
2. Relevante Synonyme oder verwandte Begriffe einschließt
3. Spezifischer auf Erwartungshorizonte/Bewertungskriterien ausgerichtet ist
Antworte NUR mit der umformulierten Suchanfrage, ohne Erklärung."""
try:
async with httpx.AsyncClient() as client:
response = await client.post(
"https://api.openai.com/v1/chat/completions",
headers={
"Authorization": f"Bearer {OPENAI_API_KEY}",
"Content-Type": "application/json"
},
json={
"model": SELF_RAG_MODEL,
"messages": [{"role": "user", "content": prompt}],
"max_tokens": 100,
"temperature": 0.3,
},
timeout=30.0
)
if response.status_code != 200:
return original_query
return response.json()["choices"][0]["message"]["content"].strip()
except Exception:
return original_query
async def grade_answer_groundedness(
answer: str,
contexts: List[str],
) -> Tuple[float, List[str]]:
"""
Grade whether an answer is grounded in the provided contexts.
Returns:
Tuple of (grounding_score, list of unsupported claims)
"""
if not OPENAI_API_KEY:
return 0.5, ["LLM not configured for grounding check"]
context_text = "\n---\n".join(contexts[:5])
prompt = f"""Analysiere, ob die folgende Antwort vollständig durch die Kontexte gestützt wird.
KONTEXTE:
{context_text}
ANTWORT:
{answer}
Identifiziere:
1. Welche Aussagen sind durch die Kontexte belegt?
2. Welche Aussagen sind NICHT belegt (potenzielle Halluzinationen)?
Antworte im Format:
SCORE: [0.0-1.0] (1.0 = vollständig belegt)
NICHT_BELEGT: [Liste der nicht belegten Aussagen, eine pro Zeile, oder "Keine"]"""
try:
async with httpx.AsyncClient() as client:
response = await client.post(
"https://api.openai.com/v1/chat/completions",
headers={
"Authorization": f"Bearer {OPENAI_API_KEY}",
"Content-Type": "application/json"
},
json={
"model": SELF_RAG_MODEL,
"messages": [{"role": "user", "content": prompt}],
"max_tokens": 300,
"temperature": 0.0,
},
timeout=30.0
)
if response.status_code != 200:
return 0.5, [f"API error: {response.status_code}"]
result = response.json()["choices"][0]["message"]["content"]
import re
score_match = re.search(r'SCORE:\s*([\d.]+)', result)
score = float(score_match.group(1)) if score_match else 0.5
unsupported_match = re.search(r'NICHT_BELEGT:\s*(.+)', result, re.DOTALL)
unsupported_text = unsupported_match.group(1).strip() if unsupported_match else ""
if unsupported_text.lower() == "keine":
unsupported = []
else:
unsupported = [line.strip() for line in unsupported_text.split("\n") if line.strip()]
return min(max(score, 0.0), 1.0), unsupported
except Exception as e:
return 0.5, [f"Grounding check error: {str(e)}"]
async def self_rag_retrieve(
query: str,
search_func,
subject: Optional[str] = None,
niveau: Optional[str] = None,
initial_top_k: int = 10,
final_top_k: int = 5,
**search_kwargs
) -> Dict:
"""
Perform Self-RAG enhanced retrieval with reflection and correction.
This implements a retrieval loop that:
1. Retrieves initial documents
2. Grades them for relevance
3. Decides if more retrieval is needed
4. Reformulates query if necessary
5. Returns filtered, high-quality context
Args:
query: The search query
search_func: Async function to perform the actual search
subject: Optional subject context
niveau: Optional niveau context
initial_top_k: Number of documents for initial retrieval
final_top_k: Maximum documents to return
**search_kwargs: Additional args for search_func
Returns:
Dict with results, metadata, and reflection trace
"""
if not SELF_RAG_ENABLED:
# Fall back to simple search
results = await search_func(query=query, limit=final_top_k, **search_kwargs)
return {
"results": results,
"self_rag_enabled": False,
"query_used": query,
}
trace = []
current_query = query
attempt = 1
while attempt <= MAX_RETRIEVAL_ATTEMPTS:
# Step 1: Retrieve documents
results = await search_func(query=current_query, limit=initial_top_k, **search_kwargs)
trace.append({
"attempt": attempt,
"query": current_query,
"retrieved_count": len(results) if results else 0,
})
if not results:
attempt += 1
if attempt <= MAX_RETRIEVAL_ATTEMPTS:
current_query = await reformulate_query(
query,
context=f"Fach: {subject}" if subject else None,
previous_results_summary="Keine Dokumente gefunden"
)
trace[-1]["action"] = "reformulate"
trace[-1]["new_query"] = current_query
continue
# Step 2: Grade documents for relevance
relevant, filtered = await filter_relevant_documents(current_query, results)
trace[-1]["relevant_count"] = len(relevant)
trace[-1]["filtered_count"] = len(filtered)
# Step 3: Decide what to do
decision, decision_meta = await decide_retrieval_strategy(
current_query, relevant, attempt
)
trace[-1]["decision"] = decision.value
trace[-1]["decision_meta"] = decision_meta
if decision == RetrievalDecision.SUFFICIENT:
# We have good context, return it
return {
"results": relevant[:final_top_k],
"self_rag_enabled": True,
"query_used": current_query,
"original_query": query if current_query != query else None,
"attempts": attempt,
"decision": decision.value,
"trace": trace,
"filtered_out_count": len(filtered),
}
elif decision == RetrievalDecision.REFORMULATE:
# Reformulate and try again
avg_score = decision_meta.get("avg_relevance", 0)
current_query = await reformulate_query(
query,
context=f"Fach: {subject}" if subject else None,
previous_results_summary=f"Durchschnittliche Relevanz: {avg_score:.2f}"
)
trace[-1]["action"] = "reformulate"
trace[-1]["new_query"] = current_query
elif decision == RetrievalDecision.NEEDS_MORE:
# Retrieve more with expanded query
current_query = f"{current_query} Bewertungskriterien Anforderungen"
trace[-1]["action"] = "expand_query"
trace[-1]["new_query"] = current_query
elif decision == RetrievalDecision.FALLBACK:
# Return what we have, even if not ideal
return {
"results": (relevant or results)[:final_top_k],
"self_rag_enabled": True,
"query_used": current_query,
"original_query": query if current_query != query else None,
"attempts": attempt,
"decision": decision.value,
"warning": "Fallback mode - low relevance context",
"trace": trace,
}
attempt += 1
# Max attempts reached
return {
"results": results[:final_top_k] if results else [],
"self_rag_enabled": True,
"query_used": current_query,
"original_query": query if current_query != query else None,
"attempts": attempt - 1,
"decision": "max_attempts",
"warning": "Max retrieval attempts reached",
"trace": trace,
}
def get_self_rag_info() -> dict:
"""Get information about Self-RAG configuration."""
return {
"enabled": SELF_RAG_ENABLED,
"llm_configured": bool(OPENAI_API_KEY),
"model": SELF_RAG_MODEL,
"relevance_threshold": RELEVANCE_THRESHOLD,
"grounding_threshold": GROUNDING_THRESHOLD,
"max_retrieval_attempts": MAX_RETRIEVAL_ATTEMPTS,
"features": [
"document_grading",
"relevance_filtering",
"query_reformulation",
"answer_grounding_check",
"retrieval_decision",
],
"sends_data_externally": True, # ALWAYS true when enabled - documents sent to OpenAI
"privacy_warning": "When enabled, queries and documents are sent to OpenAI API for grading",
"default_enabled": False, # Disabled by default for privacy
}
# Grading: relevance, filtering, decisions, groundedness
from self_rag_grading import ( # noqa: F401
SELF_RAG_ENABLED,
OPENAI_API_KEY,
SELF_RAG_MODEL,
RELEVANCE_THRESHOLD,
GROUNDING_THRESHOLD,
MAX_RETRIEVAL_ATTEMPTS,
RetrievalDecision,
SelfRAGError,
grade_document_relevance,
grade_documents_batch,
filter_relevant_documents,
decide_retrieval_strategy,
grade_answer_groundedness,
)
# Retrieval: reformulation, loop, info
from self_rag_retrieval import ( # noqa: F401
reformulate_query,
self_rag_retrieve,
get_self_rag_info,
)
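# Usage stays the same for downstream callers (a sketch, assuming this barrel
# module stays on the import path):
#   from self_rag import self_rag_retrieve, RetrievalDecision, get_self_rag_info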

View File

@@ -0,0 +1,285 @@
"""
Self-RAG Grading — document relevance grading, filtering, retrieval decisions.
Extracted from self_rag.py for modularity.
Based on research:
- Self-RAG (Asai et al., 2023)
- Corrective RAG (Yan et al., 2024)
"""
import os
from typing import List, Dict, Optional, Tuple
from enum import Enum
import httpx
# Configuration
SELF_RAG_ENABLED = os.getenv("SELF_RAG_ENABLED", "false").lower() == "true"
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY", "")
SELF_RAG_MODEL = os.getenv("SELF_RAG_MODEL", "gpt-4o-mini")
# Thresholds for self-reflection
RELEVANCE_THRESHOLD = float(os.getenv("SELF_RAG_RELEVANCE_THRESHOLD", "0.6"))
GROUNDING_THRESHOLD = float(os.getenv("SELF_RAG_GROUNDING_THRESHOLD", "0.7"))
MAX_RETRIEVAL_ATTEMPTS = int(os.getenv("SELF_RAG_MAX_ATTEMPTS", "2"))
class RetrievalDecision(Enum):
"""Decision after grading retrieval."""
SUFFICIENT = "sufficient" # Context is good, proceed to generation
NEEDS_MORE = "needs_more" # Need to retrieve more documents
REFORMULATE = "reformulate" # Query needs reformulation
FALLBACK = "fallback" # Use fallback (no good context found)
class SelfRAGError(Exception):
"""Error during Self-RAG processing."""
pass
async def grade_document_relevance(
query: str,
document: str,
) -> Tuple[float, str]:
"""
Grade whether a document is relevant to the query.
Returns a score between 0 (irrelevant) and 1 (highly relevant)
along with an explanation.
"""
if not OPENAI_API_KEY:
# Fallback: simple keyword overlap
query_words = set(query.lower().split())
doc_words = set(document.lower().split())
overlap = len(query_words & doc_words) / max(len(query_words), 1)
return min(overlap * 2, 1.0), "Keyword-based relevance (no LLM)"
prompt = f"""Bewerte, ob das folgende Dokument relevant fuer die Suchanfrage ist.
SUCHANFRAGE: {query}
DOKUMENT:
{document[:2000]}
Ist dieses Dokument relevant, um die Anfrage zu beantworten?
Beruecksichtige:
- Thematische Uebereinstimmung
- Enthaelt das Dokument spezifische Informationen zur Anfrage?
- Wuerde dieses Dokument bei der Beantwortung helfen?
Antworte im Format:
SCORE: [0.0-1.0]
BEGRUENDUNG: [Kurze Erklaerung]"""
try:
async with httpx.AsyncClient() as client:
response = await client.post(
"https://api.openai.com/v1/chat/completions",
headers={
"Authorization": f"Bearer {OPENAI_API_KEY}",
"Content-Type": "application/json"
},
json={
"model": SELF_RAG_MODEL,
"messages": [{"role": "user", "content": prompt}],
"max_tokens": 150,
"temperature": 0.0,
},
timeout=30.0
)
if response.status_code != 200:
return 0.5, f"API error: {response.status_code}"
result = response.json()["choices"][0]["message"]["content"]
import re
score_match = re.search(r'SCORE:\s*([\d.]+)', result)
score = float(score_match.group(1)) if score_match else 0.5
reason_match = re.search(r'BEGRUENDUNG:\s*(.+)', result, re.DOTALL)
reason = reason_match.group(1).strip() if reason_match else result
return min(max(score, 0.0), 1.0), reason
except Exception as e:
return 0.5, f"Grading error: {str(e)}"
async def grade_documents_batch(
query: str,
documents: List[str],
) -> List[Tuple[float, str]]:
"""
Grade multiple documents for relevance.
Returns list of (score, reason) tuples.
"""
results = []
for doc in documents:
score, reason = await grade_document_relevance(query, doc)
results.append((score, reason))
return results
async def filter_relevant_documents(
query: str,
documents: List[Dict],
threshold: float = RELEVANCE_THRESHOLD,
) -> Tuple[List[Dict], List[Dict]]:
"""
Filter documents by relevance, separating relevant from irrelevant.
Args:
query: The search query
documents: List of document dicts with 'text' field
threshold: Minimum relevance score to keep
Returns:
Tuple of (relevant_docs, filtered_out_docs)
"""
relevant = []
filtered = []
for doc in documents:
text = doc.get("text", "")
score, reason = await grade_document_relevance(query, text)
doc_with_grade = doc.copy()
doc_with_grade["relevance_score"] = score
doc_with_grade["relevance_reason"] = reason
if score >= threshold:
relevant.append(doc_with_grade)
else:
filtered.append(doc_with_grade)
# Sort relevant by score
relevant.sort(key=lambda x: x.get("relevance_score", 0), reverse=True)
return relevant, filtered
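# Usage sketch (document shape assumed to match the search layer's output):
#   relevant, dropped = await filter_relevant_documents(
#       "Erwartungshorizont Analysis", [{"text": "..."}], threshold=0.6)
#   # every returned doc gains "relevance_score" and "relevance_reason" fields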
async def decide_retrieval_strategy(
query: str,
documents: List[Dict],
attempt: int = 1,
) -> Tuple[RetrievalDecision, Dict]:
"""
Decide what to do based on current retrieval results.
Args:
query: The search query
documents: Retrieved documents with relevance scores
attempt: Current retrieval attempt number
Returns:
Tuple of (decision, metadata)
"""
if not documents:
if attempt >= MAX_RETRIEVAL_ATTEMPTS:
return RetrievalDecision.FALLBACK, {"reason": "No documents found after max attempts"}
return RetrievalDecision.REFORMULATE, {"reason": "No documents retrieved"}
# Check average relevance
scores = [doc.get("relevance_score", 0.5) for doc in documents]
avg_score = sum(scores) / len(scores)
max_score = max(scores)
if max_score >= RELEVANCE_THRESHOLD and avg_score >= RELEVANCE_THRESHOLD * 0.7:
return RetrievalDecision.SUFFICIENT, {
"avg_relevance": avg_score,
"max_relevance": max_score,
"doc_count": len(documents),
}
if attempt >= MAX_RETRIEVAL_ATTEMPTS:
if max_score >= RELEVANCE_THRESHOLD * 0.5:
# At least some relevant context, proceed with caution
return RetrievalDecision.SUFFICIENT, {
"avg_relevance": avg_score,
"warning": "Low relevance after max attempts",
}
return RetrievalDecision.FALLBACK, {"reason": "Max attempts reached, low relevance"}
if avg_score < 0.3:
return RetrievalDecision.REFORMULATE, {
"reason": "Very low relevance, query reformulation needed",
"avg_relevance": avg_score,
}
return RetrievalDecision.NEEDS_MORE, {
"reason": "Moderate relevance, retrieving more documents",
"avg_relevance": avg_score,
}
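# Worked example with the default thresholds (RELEVANCE_THRESHOLD = 0.6):
# scores [0.7, 0.4, 0.2] give max 0.7 >= 0.6 and avg 0.43 >= 0.42, so the
# decision is SUFFICIENT; scores [0.25, 0.2] give avg 0.225 < 0.3, so on a
# non-final attempt the decision is REFORMULATE.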
async def grade_answer_groundedness(
answer: str,
contexts: List[str],
) -> Tuple[float, List[str]]:
"""
Grade whether an answer is grounded in the provided contexts.
Returns:
Tuple of (grounding_score, list of unsupported claims)
"""
if not OPENAI_API_KEY:
return 0.5, ["LLM not configured for grounding check"]
context_text = "\n---\n".join(contexts[:5])
prompt = f"""Analysiere, ob die folgende Antwort vollstaendig durch die Kontexte gestuetzt wird.
KONTEXTE:
{context_text}
ANTWORT:
{answer}
Identifiziere:
1. Welche Aussagen sind durch die Kontexte belegt?
2. Welche Aussagen sind NICHT belegt (potenzielle Halluzinationen)?
Antworte im Format:
SCORE: [0.0-1.0] (1.0 = vollstaendig belegt)
NICHT_BELEGT: [Liste der nicht belegten Aussagen, eine pro Zeile, oder "Keine"]"""
try:
async with httpx.AsyncClient() as client:
response = await client.post(
"https://api.openai.com/v1/chat/completions",
headers={
"Authorization": f"Bearer {OPENAI_API_KEY}",
"Content-Type": "application/json"
},
json={
"model": SELF_RAG_MODEL,
"messages": [{"role": "user", "content": prompt}],
"max_tokens": 300,
"temperature": 0.0,
},
timeout=30.0
)
if response.status_code != 200:
return 0.5, [f"API error: {response.status_code}"]
result = response.json()["choices"][0]["message"]["content"]
import re
score_match = re.search(r'SCORE:\s*([\d.]+)', result)
score = float(score_match.group(1)) if score_match else 0.5
unsupported_match = re.search(r'NICHT_BELEGT:\s*(.+)', result, re.DOTALL)
unsupported_text = unsupported_match.group(1).strip() if unsupported_match else ""
if unsupported_text.lower() == "keine":
unsupported = []
else:
unsupported = [line.strip() for line in unsupported_text.split("\n") if line.strip()]
return min(max(score, 0.0), 1.0), unsupported
except Exception as e:
return 0.5, [f"Grounding check error: {str(e)}"]

View File

@@ -0,0 +1,255 @@
"""
Self-RAG Retrieval — query reformulation, retrieval loop, info.
Extracted from self_rag.py for modularity.
IMPORTANT: Self-RAG is DISABLED by default for privacy reasons!
When enabled, search queries and retrieved documents are sent to OpenAI API
for relevance grading and query reformulation.
"""
import os
from typing import List, Dict, Optional
import httpx
from self_rag_grading import (
SELF_RAG_ENABLED,
OPENAI_API_KEY,
SELF_RAG_MODEL,
RELEVANCE_THRESHOLD,
GROUNDING_THRESHOLD,
MAX_RETRIEVAL_ATTEMPTS,
RetrievalDecision,
filter_relevant_documents,
decide_retrieval_strategy,
)
async def reformulate_query(
original_query: str,
context: Optional[str] = None,
previous_results_summary: Optional[str] = None,
) -> str:
"""
Reformulate a query to improve retrieval.
Uses LLM to generate a better query based on:
- Original query
- Optional context (subject, niveau, etc.)
- Summary of why previous retrieval failed
"""
if not OPENAI_API_KEY:
# Simple reformulation: expand abbreviations, add synonyms
reformulated = original_query
expansions = {
"EA": "erhoehtes Anforderungsniveau",
"eA": "erhoehtes Anforderungsniveau",
"GA": "grundlegendes Anforderungsniveau",
"gA": "grundlegendes Anforderungsniveau",
"AFB": "Anforderungsbereich",
"Abi": "Abitur",
}
for abbr, expansion in expansions.items():
if abbr in original_query:
reformulated = reformulated.replace(abbr, f"{abbr} ({expansion})")
return reformulated
prompt = f"""Du bist ein Experte fuer deutsche Bildungsstandards und Pruefungsanforderungen.
Die folgende Suchanfrage hat keine guten Ergebnisse geliefert:
ORIGINAL: {original_query}
{f"KONTEXT: {context}" if context else ""}
{f"PROBLEM MIT VORHERIGEN ERGEBNISSEN: {previous_results_summary}" if previous_results_summary else ""}
Formuliere die Anfrage so um, dass sie:
1. Formellere/technischere Begriffe verwendet (wie in offiziellen Dokumenten)
2. Relevante Synonyme oder verwandte Begriffe einschliesst
3. Spezifischer auf Erwartungshorizonte/Bewertungskriterien ausgerichtet ist
Antworte NUR mit der umformulierten Suchanfrage, ohne Erklaerung."""
try:
async with httpx.AsyncClient() as client:
response = await client.post(
"https://api.openai.com/v1/chat/completions",
headers={
"Authorization": f"Bearer {OPENAI_API_KEY}",
"Content-Type": "application/json"
},
json={
"model": SELF_RAG_MODEL,
"messages": [{"role": "user", "content": prompt}],
"max_tokens": 100,
"temperature": 0.3,
},
timeout=30.0
)
if response.status_code != 200:
return original_query
return response.json()["choices"][0]["message"]["content"].strip()
except Exception:
return original_query
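# Fallback example (no OPENAI_API_KEY set): "EA Mathe Abi" becomes
# "EA (erhoehtes Anforderungsniveau) Mathe Abi (Abitur)" via the abbreviation
# table above; with a key configured, the LLM rewrites the query instead.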
async def self_rag_retrieve(
query: str,
search_func,
subject: Optional[str] = None,
niveau: Optional[str] = None,
initial_top_k: int = 10,
final_top_k: int = 5,
**search_kwargs
) -> Dict:
"""
Perform Self-RAG enhanced retrieval with reflection and correction.
This implements a retrieval loop that:
1. Retrieves initial documents
2. Grades them for relevance
3. Decides if more retrieval is needed
4. Reformulates query if necessary
5. Returns filtered, high-quality context
Args:
query: The search query
search_func: Async function to perform the actual search
subject: Optional subject context
niveau: Optional niveau context
initial_top_k: Number of documents for initial retrieval
final_top_k: Maximum documents to return
**search_kwargs: Additional args for search_func
Returns:
Dict with results, metadata, and reflection trace
"""
if not SELF_RAG_ENABLED:
# Fall back to simple search
results = await search_func(query=query, limit=final_top_k, **search_kwargs)
return {
"results": results,
"self_rag_enabled": False,
"query_used": query,
}
trace = []
current_query = query
attempt = 1
while attempt <= MAX_RETRIEVAL_ATTEMPTS:
# Step 1: Retrieve documents
results = await search_func(query=current_query, limit=initial_top_k, **search_kwargs)
trace.append({
"attempt": attempt,
"query": current_query,
"retrieved_count": len(results) if results else 0,
})
if not results:
attempt += 1
if attempt <= MAX_RETRIEVAL_ATTEMPTS:
current_query = await reformulate_query(
query,
context=f"Fach: {subject}" if subject else None,
previous_results_summary="Keine Dokumente gefunden"
)
trace[-1]["action"] = "reformulate"
trace[-1]["new_query"] = current_query
continue
# Step 2: Grade documents for relevance
relevant, filtered = await filter_relevant_documents(current_query, results)
trace[-1]["relevant_count"] = len(relevant)
trace[-1]["filtered_count"] = len(filtered)
# Step 3: Decide what to do
decision, decision_meta = await decide_retrieval_strategy(
current_query, relevant, attempt
)
trace[-1]["decision"] = decision.value
trace[-1]["decision_meta"] = decision_meta
if decision == RetrievalDecision.SUFFICIENT:
# We have good context, return it
return {
"results": relevant[:final_top_k],
"self_rag_enabled": True,
"query_used": current_query,
"original_query": query if current_query != query else None,
"attempts": attempt,
"decision": decision.value,
"trace": trace,
"filtered_out_count": len(filtered),
}
elif decision == RetrievalDecision.REFORMULATE:
# Reformulate and try again
avg_score = decision_meta.get("avg_relevance", 0)
current_query = await reformulate_query(
query,
context=f"Fach: {subject}" if subject else None,
previous_results_summary=f"Durchschnittliche Relevanz: {avg_score:.2f}"
)
trace[-1]["action"] = "reformulate"
trace[-1]["new_query"] = current_query
elif decision == RetrievalDecision.NEEDS_MORE:
# Retrieve more with expanded query
current_query = f"{current_query} Bewertungskriterien Anforderungen"
trace[-1]["action"] = "expand_query"
trace[-1]["new_query"] = current_query
elif decision == RetrievalDecision.FALLBACK:
# Return what we have, even if not ideal
return {
"results": (relevant or results)[:final_top_k],
"self_rag_enabled": True,
"query_used": current_query,
"original_query": query if current_query != query else None,
"attempts": attempt,
"decision": decision.value,
"warning": "Fallback mode - low relevance context",
"trace": trace,
}
attempt += 1
# Max attempts reached
return {
"results": results[:final_top_k] if results else [],
"self_rag_enabled": True,
"query_used": current_query,
"original_query": query if current_query != query else None,
"attempts": attempt - 1,
"decision": "max_attempts",
"warning": "Max retrieval attempts reached",
"trace": trace,
}
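# Wiring sketch: search_func is any async callable returning dicts with a
# "text" field. "vector_search" below is a hypothetical stand-in, not part
# of this codebase:
#   async def vector_search(query: str, limit: int = 10, **kw) -> List[Dict]:
#       ...  # e.g. query a vector store, return [{"text": "..."}, ...]
#   result = await self_rag_retrieve("EA Mathematik Erwartungshorizont",
#                                    vector_search, subject="Mathematik")
#   result["results"]   # filtered docs; result["trace"] holds the reflection log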
def get_self_rag_info() -> dict:
"""Get information about Self-RAG configuration."""
return {
"enabled": SELF_RAG_ENABLED,
"llm_configured": bool(OPENAI_API_KEY),
"model": SELF_RAG_MODEL,
"relevance_threshold": RELEVANCE_THRESHOLD,
"grounding_threshold": GROUNDING_THRESHOLD,
"max_retrieval_attempts": MAX_RETRIEVAL_ATTEMPTS,
"features": [
"document_grading",
"relevance_filtering",
"query_reformulation",
"answer_grounding_check",
"retrieval_decision",
],
"sends_data_externally": True, # ALWAYS true when enabled
"privacy_warning": "When enabled, queries and documents are sent to OpenAI API for grading",
"default_enabled": False, # Disabled by default for privacy
}

View File

@@ -0,0 +1,164 @@
"""
Grid Detection Models v4
Data classes for OCR grid detection results.
Coordinates use percentage (0-100) and mm (A4 format).
"""
from enum import Enum
from dataclasses import dataclass, field
from typing import List, Dict, Any
# A4 dimensions
A4_WIDTH_MM = 210.0
A4_HEIGHT_MM = 297.0
# Column margin (1mm)
COLUMN_MARGIN_MM = 1.0
COLUMN_MARGIN_PCT = (COLUMN_MARGIN_MM / A4_WIDTH_MM) * 100
class CellStatus(str, Enum):
EMPTY = "empty"
RECOGNIZED = "recognized"
PROBLEMATIC = "problematic"
MANUAL = "manual"
class ColumnType(str, Enum):
ENGLISH = "english"
GERMAN = "german"
EXAMPLE = "example"
UNKNOWN = "unknown"
@dataclass
class OCRRegion:
"""A word/phrase detected by OCR with bounding box coordinates in percentage (0-100)."""
text: str
confidence: float
x: float # X position as percentage of page width
y: float # Y position as percentage of page height
width: float # Width as percentage of page width
height: float # Height as percentage of page height
@property
def x_mm(self) -> float:
return round(self.x / 100 * A4_WIDTH_MM, 1)
@property
def y_mm(self) -> float:
return round(self.y / 100 * A4_HEIGHT_MM, 1)
@property
def width_mm(self) -> float:
return round(self.width / 100 * A4_WIDTH_MM, 1)
@property
def height_mm(self) -> float:
return round(self.height / 100 * A4_HEIGHT_MM, 2)
@property
def center_x(self) -> float:
return self.x + self.width / 2
@property
def center_y(self) -> float:
return self.y + self.height / 2
@property
def right(self) -> float:
return self.x + self.width
@property
def bottom(self) -> float:
return self.y + self.height
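# Worked example of the percent-to-mm mapping on A4 (210 x 297 mm):
#   r = OCRRegion(text="cat", confidence=0.93, x=25.0, y=10.0, width=20.0, height=3.0)
#   r.x_mm == 52.5      # 25% of 210 mm
#   r.y_mm == 29.7      # 10% of 297 mm
#   r.center_x == 35.0  # still in percent space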
@dataclass
class GridCell:
"""A cell in the detected grid with coordinates in percentage (0-100)."""
row: int
col: int
x: float
y: float
width: float
height: float
text: str = ""
confidence: float = 0.0
status: CellStatus = CellStatus.EMPTY
column_type: ColumnType = ColumnType.UNKNOWN
logical_row: int = 0
logical_col: int = 0
is_continuation: bool = False
@property
def x_mm(self) -> float:
return round(self.x / 100 * A4_WIDTH_MM, 1)
@property
def y_mm(self) -> float:
return round(self.y / 100 * A4_HEIGHT_MM, 1)
@property
def width_mm(self) -> float:
return round(self.width / 100 * A4_WIDTH_MM, 1)
@property
def height_mm(self) -> float:
return round(self.height / 100 * A4_HEIGHT_MM, 2)
def to_dict(self) -> dict:
return {
"row": self.row,
"col": self.col,
"x": round(self.x, 2),
"y": round(self.y, 2),
"width": round(self.width, 2),
"height": round(self.height, 2),
"x_mm": self.x_mm,
"y_mm": self.y_mm,
"width_mm": self.width_mm,
"height_mm": self.height_mm,
"text": self.text,
"confidence": self.confidence,
"status": self.status.value,
"column_type": self.column_type.value,
"logical_row": self.logical_row,
"logical_col": self.logical_col,
"is_continuation": self.is_continuation,
}
@dataclass
class GridResult:
"""Result of grid detection."""
rows: int = 0
columns: int = 0
cells: List[List[GridCell]] = field(default_factory=list)
column_types: List[str] = field(default_factory=list)
column_boundaries: List[float] = field(default_factory=list)
row_boundaries: List[float] = field(default_factory=list)
deskew_angle: float = 0.0
stats: Dict[str, Any] = field(default_factory=dict)
def to_dict(self) -> dict:
cells_dicts = []
for row_cells in self.cells:
cells_dicts.append([c.to_dict() for c in row_cells])
return {
"rows": self.rows,
"columns": self.columns,
"cells": cells_dicts,
"column_types": self.column_types,
"column_boundaries": [round(b, 2) for b in self.column_boundaries],
"row_boundaries": [round(b, 2) for b in self.row_boundaries],
"deskew_angle": round(self.deskew_angle, 2),
"stats": self.stats,
"page_dimensions": {
"width_mm": A4_WIDTH_MM,
"height_mm": A4_HEIGHT_MM,
"format": "A4",
},
}
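# Construction sketch (values illustrative):
#   cell = GridCell(row=0, col=0, x=5.0, y=10.0, width=45.0, height=4.0,
#                   text="cat", confidence=0.91, status=CellStatus.RECOGNIZED,
#                   column_type=ColumnType.ENGLISH)
#   grid = GridResult(rows=1, columns=1, cells=[[cell]],
#                     column_types=[ColumnType.ENGLISH.value])
#   grid.to_dict()["cells"][0][0]["x_mm"]  # -> 10.5 (5% of 210 mm)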

View File

@@ -10,166 +10,21 @@ Lizenz: Apache 2.0 (kommerziell nutzbar)
import math
import logging
from enum import Enum
from dataclasses import dataclass, field
from typing import List, Optional, Dict, Any, Tuple
from typing import List
from .grid_detection_models import (
A4_WIDTH_MM,
A4_HEIGHT_MM,
COLUMN_MARGIN_MM,
CellStatus,
ColumnType,
OCRRegion,
GridCell,
GridResult,
)
logger = logging.getLogger(__name__)
# A4 dimensions
A4_WIDTH_MM = 210.0
A4_HEIGHT_MM = 297.0
# Column margin (1mm)
COLUMN_MARGIN_MM = 1.0
COLUMN_MARGIN_PCT = (COLUMN_MARGIN_MM / A4_WIDTH_MM) * 100
class CellStatus(str, Enum):
EMPTY = "empty"
RECOGNIZED = "recognized"
PROBLEMATIC = "problematic"
MANUAL = "manual"
class ColumnType(str, Enum):
ENGLISH = "english"
GERMAN = "german"
EXAMPLE = "example"
UNKNOWN = "unknown"
@dataclass
class OCRRegion:
"""A word/phrase detected by OCR with bounding box coordinates in percentage (0-100)."""
text: str
confidence: float
x: float # X position as percentage of page width
y: float # Y position as percentage of page height
width: float # Width as percentage of page width
height: float # Height as percentage of page height
@property
def x_mm(self) -> float:
return round(self.x / 100 * A4_WIDTH_MM, 1)
@property
def y_mm(self) -> float:
return round(self.y / 100 * A4_HEIGHT_MM, 1)
@property
def width_mm(self) -> float:
return round(self.width / 100 * A4_WIDTH_MM, 1)
@property
def height_mm(self) -> float:
return round(self.height / 100 * A4_HEIGHT_MM, 2)
@property
def center_x(self) -> float:
return self.x + self.width / 2
@property
def center_y(self) -> float:
return self.y + self.height / 2
@property
def right(self) -> float:
return self.x + self.width
@property
def bottom(self) -> float:
return self.y + self.height
@dataclass
class GridCell:
"""A cell in the detected grid with coordinates in percentage (0-100)."""
row: int
col: int
x: float
y: float
width: float
height: float
text: str = ""
confidence: float = 0.0
status: CellStatus = CellStatus.EMPTY
column_type: ColumnType = ColumnType.UNKNOWN
logical_row: int = 0
logical_col: int = 0
is_continuation: bool = False
@property
def x_mm(self) -> float:
return round(self.x / 100 * A4_WIDTH_MM, 1)
@property
def y_mm(self) -> float:
return round(self.y / 100 * A4_HEIGHT_MM, 1)
@property
def width_mm(self) -> float:
return round(self.width / 100 * A4_WIDTH_MM, 1)
@property
def height_mm(self) -> float:
return round(self.height / 100 * A4_HEIGHT_MM, 2)
def to_dict(self) -> dict:
return {
"row": self.row,
"col": self.col,
"x": round(self.x, 2),
"y": round(self.y, 2),
"width": round(self.width, 2),
"height": round(self.height, 2),
"x_mm": self.x_mm,
"y_mm": self.y_mm,
"width_mm": self.width_mm,
"height_mm": self.height_mm,
"text": self.text,
"confidence": self.confidence,
"status": self.status.value,
"column_type": self.column_type.value,
"logical_row": self.logical_row,
"logical_col": self.logical_col,
"is_continuation": self.is_continuation,
}
@dataclass
class GridResult:
"""Result of grid detection."""
rows: int = 0
columns: int = 0
cells: List[List[GridCell]] = field(default_factory=list)
column_types: List[str] = field(default_factory=list)
column_boundaries: List[float] = field(default_factory=list)
row_boundaries: List[float] = field(default_factory=list)
deskew_angle: float = 0.0
stats: Dict[str, Any] = field(default_factory=dict)
def to_dict(self) -> dict:
cells_dicts = []
for row_cells in self.cells:
cells_dicts.append([c.to_dict() for c in row_cells])
return {
"rows": self.rows,
"columns": self.columns,
"cells": cells_dicts,
"column_types": self.column_types,
"column_boundaries": [round(b, 2) for b in self.column_boundaries],
"row_boundaries": [round(b, 2) for b in self.row_boundaries],
"deskew_angle": round(self.deskew_angle, 2),
"stats": self.stats,
"page_dimensions": {
"width_mm": A4_WIDTH_MM,
"height_mm": A4_HEIGHT_MM,
"format": "A4",
},
}
class GridDetectionService:
"""Detect grid/table structure from OCR bounding-box regions."""
@@ -184,7 +39,7 @@ class GridDetectionService:
"""Calculate page skew angle from OCR region positions.
Uses left-edge alignment of regions to detect consistent tilt.
Returns angle in degrees, clamped to ±5°.
Returns angle in degrees, clamped to +/-5 degrees.
"""
if len(regions) < 3:
return 0.0
@@ -229,12 +84,12 @@ class GridDetectionService:
slope = (n * sum_xy - sum_y * sum_x) / denom
# Convert slope to angle (slope is dx/dy in percent space)
# Adjust for aspect ratio: A4 is 210/297 ≈ 0.707
# Adjust for aspect ratio: A4 is 210/297 ~ 0.707
aspect = A4_WIDTH_MM / A4_HEIGHT_MM
angle_rad = math.atan(slope * aspect)
angle_deg = math.degrees(angle_rad)
# Clamp to ±5°
# Clamp to +/-5 degrees
return max(-5.0, min(5.0, round(angle_deg, 2)))
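# Numeric sketch of the conversion above (values illustrative): a slope of
# 0.02 (regions drift 0.02% right per 1% down the page) with aspect
# 210/297 ~ 0.707 gives degrees(atan(0.02 * 0.707)) ~ 0.81, well inside
# the +/-5 degree clamp.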
def apply_deskew_to_regions(self, regions: List[OCRRegion], angle: float) -> List[OCRRegion]:

View File

@@ -1,594 +1,25 @@
"""
SmartSpellChecker — Language-aware OCR post-correction without LLMs.
SmartSpellChecker — barrel re-export.
Uses pyspellchecker (MIT) with dual EN+DE dictionaries for:
- Automatic language detection per word (dual-dictionary heuristic)
- OCR error correction (digit↔letter, umlauts, transpositions)
- Context-based disambiguation (a/I, l/I) via bigram lookup
- Mixed-language support for example sentences
All implementation split into:
smart_spell_core — init, data types, language detection, word correction
smart_spell_text — full text correction, boundary repair, context split
Lizenz: Apache 2.0 (kommerziell nutzbar)
"""
import logging
import re
from dataclasses import dataclass, field
from typing import Dict, List, Literal, Optional, Set, Tuple
logger = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
# Init
# ---------------------------------------------------------------------------
try:
from spellchecker import SpellChecker as _SpellChecker
_en_spell = _SpellChecker(language='en', distance=1)
_de_spell = _SpellChecker(language='de', distance=1)
_AVAILABLE = True
except ImportError:
_AVAILABLE = False
logger.warning("pyspellchecker not installed — SmartSpellChecker disabled")
Lang = Literal["en", "de", "both", "unknown"]
# ---------------------------------------------------------------------------
# Bigram context for a/I disambiguation
# ---------------------------------------------------------------------------
# Words that commonly follow "I" (subject pronoun → verb/modal)
_I_FOLLOWERS: frozenset = frozenset({
"am", "was", "have", "had", "do", "did", "will", "would", "can",
"could", "should", "shall", "may", "might", "must",
"think", "know", "see", "want", "need", "like", "love", "hate",
"go", "went", "come", "came", "say", "said", "get", "got",
"make", "made", "take", "took", "give", "gave", "tell", "told",
"feel", "felt", "find", "found", "believe", "hope", "wish",
"remember", "forget", "understand", "mean", "meant",
"don't", "didn't", "can't", "won't", "couldn't", "wouldn't",
"shouldn't", "haven't", "hadn't", "isn't", "wasn't",
"really", "just", "also", "always", "never", "often", "sometimes",
})
# Words that commonly follow "a" (article → noun/adjective)
_A_FOLLOWERS: frozenset = frozenset({
"lot", "few", "little", "bit", "good", "bad", "great", "new", "old",
"long", "short", "big", "small", "large", "huge", "tiny",
"nice", "beautiful", "wonderful", "terrible", "horrible",
"man", "woman", "boy", "girl", "child", "dog", "cat", "bird",
"book", "car", "house", "room", "school", "teacher", "student",
"day", "week", "month", "year", "time", "place", "way",
"friend", "family", "person", "problem", "question", "story",
"very", "really", "quite", "rather", "pretty", "single",
})
# Digit→letter substitutions (OCR confusion)
_DIGIT_SUBS: Dict[str, List[str]] = {
'0': ['o', 'O'],
'1': ['l', 'I'],
'5': ['s', 'S'],
'6': ['g', 'G'],
'8': ['b', 'B'],
'|': ['I', 'l'],
'/': ['l'], # italic 'l' misread as slash (e.g. "p/" → "pl")
}
_SUSPICIOUS_CHARS = frozenset(_DIGIT_SUBS.keys())
# Umlaut confusion: OCR drops dots (ü→u, ä→a, ö→o)
_UMLAUT_MAP = {
'a': 'ä', 'o': 'ö', 'u': 'ü', 'i': 'ü',
'A': 'Ä', 'O': 'Ö', 'U': 'Ü', 'I': 'Ü',
}
# Tokenizer — includes | and / so OCR artifacts like "p/" are treated as words
_TOKEN_RE = re.compile(r"([A-Za-zÄÖÜäöüß'|/]+)([^A-Za-zÄÖÜäöüß'|/]*)")
# ---------------------------------------------------------------------------
# Data types
# ---------------------------------------------------------------------------
@dataclass
class CorrectionResult:
original: str
corrected: str
lang_detected: Lang
changed: bool
changes: List[str] = field(default_factory=list)
# ---------------------------------------------------------------------------
# Core class
# ---------------------------------------------------------------------------
class SmartSpellChecker:
"""Language-aware OCR spell checker using pyspellchecker (no LLM)."""
def __init__(self):
if not _AVAILABLE:
raise RuntimeError("pyspellchecker not installed")
self.en = _en_spell
self.de = _de_spell
# --- Language detection ---
def detect_word_lang(self, word: str) -> Lang:
"""Detect language of a single word using dual-dict heuristic."""
w = word.lower().strip(".,;:!?\"'()")
if not w:
return "unknown"
in_en = bool(self.en.known([w]))
in_de = bool(self.de.known([w]))
if in_en and in_de:
return "both"
if in_en:
return "en"
if in_de:
return "de"
return "unknown"
def detect_text_lang(self, text: str) -> Lang:
"""Detect dominant language of a text string (sentence/phrase)."""
words = re.findall(r"[A-Za-zÄÖÜäöüß]+", text)
if not words:
return "unknown"
en_count = 0
de_count = 0
for w in words:
lang = self.detect_word_lang(w)
if lang == "en":
en_count += 1
elif lang == "de":
de_count += 1
# "both" doesn't count for either
if en_count > de_count:
return "en"
if de_count > en_count:
return "de"
if en_count == de_count and en_count > 0:
return "both"
return "unknown"
# --- Single-word correction ---
def _known(self, word: str) -> bool:
"""True if word is known in EN or DE dictionary, or is a known abbreviation."""
w = word.lower()
if bool(self.en.known([w])) or bool(self.de.known([w])):
return True
# Also accept known abbreviations (sth, sb, adj, etc.)
try:
from cv_ocr_engines import _KNOWN_ABBREVIATIONS
if w in _KNOWN_ABBREVIATIONS:
return True
except ImportError:
pass
return False
def _word_freq(self, word: str) -> float:
"""Get word frequency (max of EN and DE)."""
w = word.lower()
return max(self.en.word_usage_frequency(w), self.de.word_usage_frequency(w))
def _known_in(self, word: str, lang: str) -> bool:
"""True if word is known in a specific language dictionary."""
w = word.lower()
spell = self.en if lang == "en" else self.de
return bool(spell.known([w]))
def correct_word(self, word: str, lang: str = "en",
prev_word: str = "", next_word: str = "") -> Optional[str]:
"""Correct a single word for the given language.
Returns None if no correction needed, or the corrected string.
Args:
word: The word to check/correct
lang: Expected language ("en" or "de")
prev_word: Previous word (for context)
next_word: Next word (for context)
"""
if not word or not word.strip():
return None
# Skip numbers, abbreviations with dots, very short tokens
if word.isdigit() or '.' in word:
return None
# Skip IPA/phonetic content in brackets
if '[' in word or ']' in word:
return None
has_suspicious = any(ch in _SUSPICIOUS_CHARS for ch in word)
# 1. Already known → no fix
if self._known(word):
# But check a/I disambiguation for single-char words
if word.lower() in ('l', '|') and next_word:
return self._disambiguate_a_I(word, next_word)
return None
# 2. Digit/pipe substitution
if has_suspicious:
if word == '|':
return 'I'
# Try single-char substitutions
for i, ch in enumerate(word):
if ch not in _DIGIT_SUBS:
continue
for replacement in _DIGIT_SUBS[ch]:
candidate = word[:i] + replacement + word[i + 1:]
if self._known(candidate):
return candidate
# Try multi-char substitution (e.g., "sch00l" → "school")
multi = self._try_multi_digit_sub(word)
if multi:
return multi
# 3. Umlaut correction (German)
if lang == "de" and len(word) >= 3 and word.isalpha():
umlaut_fix = self._try_umlaut_fix(word)
if umlaut_fix:
return umlaut_fix
# 4. General spell correction
if not has_suspicious and len(word) >= 3 and word.isalpha():
# Safety: don't correct if the word is valid in the OTHER language
# (either directly or via umlaut fix)
other_lang = "de" if lang == "en" else "en"
if self._known_in(word, other_lang):
return None
if other_lang == "de" and self._try_umlaut_fix(word):
return None # has a valid DE umlaut variant → don't touch
spell = self.en if lang == "en" else self.de
correction = spell.correction(word.lower())
if correction and correction != word.lower():
if word[0].isupper():
correction = correction[0].upper() + correction[1:]
if self._known(correction):
return correction
return None
# --- Multi-digit substitution ---
def _try_multi_digit_sub(self, word: str) -> Optional[str]:
"""Try replacing multiple digits simultaneously using BFS."""
positions = [(i, ch) for i, ch in enumerate(word) if ch in _DIGIT_SUBS]
if not positions or len(positions) > 4:
return None
# BFS over substitution combinations
queue = [list(word)]
for pos, ch in positions:
next_queue = []
for current in queue:
# Keep original
next_queue.append(current[:])
# Try each substitution
for repl in _DIGIT_SUBS[ch]:
variant = current[:]
variant[pos] = repl
next_queue.append(variant)
queue = next_queue
# Check which combinations produce known words
for combo in queue:
candidate = "".join(combo)
if candidate != word and self._known(candidate):
return candidate
return None
# --- Umlaut fix ---
def _try_umlaut_fix(self, word: str) -> Optional[str]:
"""Try single-char umlaut substitutions for German words."""
for i, ch in enumerate(word):
if ch in _UMLAUT_MAP:
candidate = word[:i] + _UMLAUT_MAP[ch] + word[i + 1:]
if self._known(candidate):
return candidate
return None
# --- Boundary repair (shifted word boundaries) ---
def _try_boundary_repair(self, word1: str, word2: str) -> Optional[Tuple[str, str]]:
"""Fix shifted word boundaries between adjacent tokens.
OCR sometimes shifts the boundary: "at sth.""ats th."
Try moving 1-2 chars from end of word1 to start of word2 and vice versa.
Returns (fixed_word1, fixed_word2) or None.
"""
# Import known abbreviations for vocabulary context
try:
from cv_ocr_engines import _KNOWN_ABBREVIATIONS
except ImportError:
_KNOWN_ABBREVIATIONS = set()
# Strip trailing punctuation for checking, preserve for result
w2_stripped = word2.rstrip(".,;:!?")
w2_punct = word2[len(w2_stripped):]
# Try shifting 1-2 chars from word1 → word2
for shift in (1, 2):
if len(word1) <= shift:
continue
new_w1 = word1[:-shift]
new_w2_base = word1[-shift:] + w2_stripped
w1_ok = self._known(new_w1) or new_w1.lower() in _KNOWN_ABBREVIATIONS
w2_ok = self._known(new_w2_base) or new_w2_base.lower() in _KNOWN_ABBREVIATIONS
if w1_ok and w2_ok:
return (new_w1, new_w2_base + w2_punct)
# Try shifting 1-2 chars from word2 → word1
for shift in (1, 2):
if len(w2_stripped) <= shift:
continue
new_w1 = word1 + w2_stripped[:shift]
new_w2_base = w2_stripped[shift:]
w1_ok = self._known(new_w1) or new_w1.lower() in _KNOWN_ABBREVIATIONS
w2_ok = self._known(new_w2_base) or new_w2_base.lower() in _KNOWN_ABBREVIATIONS
if w1_ok and w2_ok:
return (new_w1, new_w2_base + w2_punct)
return None
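# Example of a repaired shift (matches the docstring above): for "ats" + "th."
# the trailing dot is stripped, the final "s" of "ats" moves onto "th", and
# both "at" and the abbreviation "sth" are known, so the pair becomes
# ("at", "sth.").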
# --- Context-based word split for ambiguous merges ---
# Patterns where a valid word is actually "a" + adjective/noun
_ARTICLE_SPLIT_CANDIDATES = {
# word → (article, remainder) — only when followed by a compatible word
"anew": ("a", "new"),
"areal": ("a", "real"),
"alive": None, # genuinely one word, never split
"alone": None,
"aware": None,
"alike": None,
"apart": None,
"aside": None,
"above": None,
"about": None,
"among": None,
"along": None,
}
def _try_context_split(self, word: str, next_word: str,
prev_word: str) -> Optional[str]:
"""Split words like 'anew''a new' when context indicates a merge.
Only splits when:
- The word is in the split candidates list
- The following word makes sense as a noun (for "a + adj + noun" pattern)
- OR the word is unknown and can be split into article + known word
"""
w_lower = word.lower()
# Check explicit candidates
if w_lower in self._ARTICLE_SPLIT_CANDIDATES:
split = self._ARTICLE_SPLIT_CANDIDATES[w_lower]
if split is None:
return None # explicitly marked as "don't split"
article, remainder = split
# Only split if followed by a word (noun pattern)
if next_word and next_word[0].islower():
return f"{article} {remainder}"
# Also split if remainder + next_word makes a common phrase
if next_word and self._known(next_word):
return f"{article} {remainder}"
# Generic: if word starts with 'a' and rest is a known adjective/word
if (len(word) >= 4 and word[0].lower() == 'a'
and not self._known(word) # only for UNKNOWN words
and self._known(word[1:])):
return f"a {word[1:]}"
return None
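# Examples: "anew house" splits to "a new house" (explicit candidate followed
# by a lowercase word), "alone" is marked never-split, and an unknown token
# like "asmall" falls through to the generic article split, "a small".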
# --- a/I disambiguation ---
def _disambiguate_a_I(self, token: str, next_word: str) -> Optional[str]:
"""Disambiguate 'a' vs 'I' (and OCR variants like 'l', '|')."""
nw = next_word.lower().strip(".,;:!?")
if nw in _I_FOLLOWERS:
return "I"
if nw in _A_FOLLOWERS:
return "a"
# No reliable fallback beyond the bigram lists: capitalisation of the
# next word is ambiguous between English sentence structure and German
# noun capitalisation, so an unmatched context is left unchanged.
return None  # uncertain, don't change
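# Examples: OCR token "l" before "think" resolves to "I" (_I_FOLLOWERS),
# "l" before "lot" resolves to "a" (_A_FOLLOWERS); before any other word
# the token is left untouched.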
# --- Full text correction ---
def correct_text(self, text: str, lang: str = "en") -> CorrectionResult:
"""Correct a full text string (field value).
Three passes:
1. Boundary repair — fix shifted word boundaries between adjacent tokens
2. Context split — split ambiguous merges (anew → a new)
3. Per-word correction — spell check individual words
Args:
text: The text to correct
lang: Expected language ("en" or "de")
"""
if not text or not text.strip():
return CorrectionResult(text, text, "unknown", False)
detected = self.detect_text_lang(text) if lang == "auto" else lang
effective_lang = detected if detected in ("en", "de") else "en"
changes: List[str] = []
tokens = list(_TOKEN_RE.finditer(text))
# Extract token list: [(word, separator), ...]
token_list: List[List[str]] = [] # [[word, sep], ...]
for m in tokens:
token_list.append([m.group(1), m.group(2)])
# --- Pass 1: Boundary repair between adjacent unknown words ---
# Import abbreviations for the heuristic below
try:
from cv_ocr_engines import _KNOWN_ABBREVIATIONS as _ABBREVS
except ImportError:
_ABBREVS = set()
for i in range(len(token_list) - 1):
w1 = token_list[i][0]
w2_raw = token_list[i + 1][0]
# Skip boundary repair for IPA/bracket content
# Brackets may be in the token OR in the adjacent separators
sep_before_w1 = token_list[i - 1][1] if i > 0 else ""
sep_after_w1 = token_list[i][1]
sep_after_w2 = token_list[i + 1][1]
has_bracket = (
'[' in w1 or ']' in w1 or '[' in w2_raw or ']' in w2_raw
or ']' in sep_after_w1 # w1 text was inside [brackets]
or '[' in sep_after_w1 # w2 starts a bracket
or ']' in sep_after_w2 # w2 text was inside [brackets]
or '[' in sep_before_w1 # w1 starts a bracket
)
if has_bracket:
continue
# Include trailing punct from separator in w2 for abbreviation matching
w2_with_punct = w2_raw + token_list[i + 1][1].rstrip(" ")
# Try boundary repair — always, even if both words are valid.
# Use word-frequency scoring to decide if repair is better.
repair = self._try_boundary_repair(w1, w2_with_punct)
if not repair and w2_with_punct != w2_raw:
repair = self._try_boundary_repair(w1, w2_raw)
if repair:
new_w1, new_w2_full = repair
new_w2_base = new_w2_full.rstrip(".,;:!?")
# Frequency-based scoring: product of word frequencies
# Higher product = more common word pair = better
old_freq = self._word_freq(w1) * self._word_freq(w2_raw)
new_freq = self._word_freq(new_w1) * self._word_freq(new_w2_base)
# Abbreviation bonus: if repair produces a known abbreviation
has_abbrev = new_w1.lower() in _ABBREVS or new_w2_base.lower() in _ABBREVS
if has_abbrev:
# Accept abbreviation repair ONLY if at least one of the
# original words is rare/unknown (prevents "Can I" → "Ca nI"
# where both original words are common and correct).
# "Rare" = frequency < 1e-6 (covers "ats", "th" but not "Can", "I")
RARE_THRESHOLD = 1e-6
orig_both_common = (
self._word_freq(w1) > RARE_THRESHOLD
and self._word_freq(w2_raw) > RARE_THRESHOLD
)
if not orig_both_common:
new_freq = max(new_freq, old_freq * 10)
else:
has_abbrev = False # both originals common → don't trust
# Accept if repair produces a more frequent word pair
# (threshold: at least 5x more frequent to avoid false positives)
if new_freq > old_freq * 5:
new_w2_punct = new_w2_full[len(new_w2_base):]
changes.append(f"{w1} {w2_raw}{new_w1} {new_w2_base}")
token_list[i][0] = new_w1
token_list[i + 1][0] = new_w2_base
if new_w2_punct:
token_list[i + 1][1] = new_w2_punct + token_list[i + 1][1].lstrip(".,;:!?")
# --- Pass 2: Context split (anew → a new) ---
expanded: List[List[str]] = []
for i, (word, sep) in enumerate(token_list):
next_word = token_list[i + 1][0] if i + 1 < len(token_list) else ""
prev_word = token_list[i - 1][0] if i > 0 else ""
split = self._try_context_split(word, next_word, prev_word)
if split and split != word:
changes.append(f"{word}{split}")
expanded.append([split, sep])
else:
expanded.append([word, sep])
token_list = expanded
# --- Pass 3: Per-word correction ---
parts: List[str] = []
# Preserve any leading text before the first token match
# (e.g., "(= " before "I won and he lost.")
first_start = tokens[0].start() if tokens else 0
if first_start > 0:
parts.append(text[:first_start])
for i, (word, sep) in enumerate(token_list):
# Skip words inside IPA brackets (brackets land in separators)
prev_sep = token_list[i - 1][1] if i > 0 else ""
if '[' in prev_sep or ']' in sep:
parts.append(word)
parts.append(sep)
continue
next_word = token_list[i + 1][0] if i + 1 < len(token_list) else ""
prev_word = token_list[i - 1][0] if i > 0 else ""
correction = self.correct_word(
word, lang=effective_lang,
prev_word=prev_word, next_word=next_word,
)
if correction and correction != word:
changes.append(f"{word}{correction}")
parts.append(correction)
else:
parts.append(word)
parts.append(sep)
# Append any trailing text
last_end = tokens[-1].end() if tokens else 0
if last_end < len(text):
parts.append(text[last_end:])
corrected = "".join(parts)
return CorrectionResult(
original=text,
corrected=corrected,
lang_detected=detected,
changed=corrected != text,
changes=changes,
)
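# Usage sketch (corrections depend on the loaded dictionaries, so the outputs
# below are expected rather than guaranteed):
#   checker = SmartSpellChecker()
#   res = checker.correct_text("| think sch00l is great", lang="en")
#   # expected: res.corrected == "I think school is great"
#   # res.changes would then list "|→I" and "sch00l→school"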
# --- Vocabulary entry correction ---
def correct_vocab_entry(self, english: str, german: str,
example: str = "") -> Dict[str, CorrectionResult]:
"""Correct a full vocabulary entry (EN + DE + example).
Uses column position to determine language — the most reliable signal.
"""
results = {}
results["english"] = self.correct_text(english, lang="en")
results["german"] = self.correct_text(german, lang="de")
if example:
# For examples, auto-detect language
results["example"] = self.correct_text(example, lang="auto")
return results
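# Usage sketch: the column position fixes the language, e.g.
#   results = checker.correct_vocab_entry(english="sch00l", german="Schule",
#                                         example="I go to sch00l.")
# checks the EN column with the EN dictionary, the DE column with DE, and
# auto-detects the example's language before correcting it.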
# Core: data types, lang detection (re-exported for tests)
from smart_spell_core import ( # noqa: F401
_AVAILABLE,
_DIGIT_SUBS,
_SUSPICIOUS_CHARS,
_UMLAUT_MAP,
_TOKEN_RE,
_I_FOLLOWERS,
_A_FOLLOWERS,
CorrectionResult,
Lang,
)
# Text: SmartSpellChecker class (the main public API)
from smart_spell_text import SmartSpellChecker # noqa: F401

View File

@@ -0,0 +1,298 @@
"""
SmartSpellChecker Core — init, data types, language detection, word correction.
Extracted from smart_spell.py for modularity.
Lizenz: Apache 2.0 (kommerziell nutzbar)
"""
import logging
import re
from dataclasses import dataclass, field
from typing import Dict, List, Literal, Optional, Set, Tuple
logger = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
# Init
# ---------------------------------------------------------------------------
try:
from spellchecker import SpellChecker as _SpellChecker
_en_spell = _SpellChecker(language='en', distance=1)
_de_spell = _SpellChecker(language='de', distance=1)
_AVAILABLE = True
except ImportError:
_AVAILABLE = False
logger.warning("pyspellchecker not installed — SmartSpellChecker disabled")
Lang = Literal["en", "de", "both", "unknown"]
# ---------------------------------------------------------------------------
# Bigram context for a/I disambiguation
# ---------------------------------------------------------------------------
# Words that commonly follow "I" (subject pronoun -> verb/modal)
_I_FOLLOWERS: frozenset = frozenset({
"am", "was", "have", "had", "do", "did", "will", "would", "can",
"could", "should", "shall", "may", "might", "must",
"think", "know", "see", "want", "need", "like", "love", "hate",
"go", "went", "come", "came", "say", "said", "get", "got",
"make", "made", "take", "took", "give", "gave", "tell", "told",
"feel", "felt", "find", "found", "believe", "hope", "wish",
"remember", "forget", "understand", "mean", "meant",
"don't", "didn't", "can't", "won't", "couldn't", "wouldn't",
"shouldn't", "haven't", "hadn't", "isn't", "wasn't",
"really", "just", "also", "always", "never", "often", "sometimes",
})
# Words that commonly follow "a" (article -> noun/adjective)
_A_FOLLOWERS: frozenset = frozenset({
"lot", "few", "little", "bit", "good", "bad", "great", "new", "old",
"long", "short", "big", "small", "large", "huge", "tiny",
"nice", "beautiful", "wonderful", "terrible", "horrible",
"man", "woman", "boy", "girl", "child", "dog", "cat", "bird",
"book", "car", "house", "room", "school", "teacher", "student",
"day", "week", "month", "year", "time", "place", "way",
"friend", "family", "person", "problem", "question", "story",
"very", "really", "quite", "rather", "pretty", "single",
})
# Digit->letter substitutions (OCR confusion)
_DIGIT_SUBS: Dict[str, List[str]] = {
'0': ['o', 'O'],
'1': ['l', 'I'],
'5': ['s', 'S'],
'6': ['g', 'G'],
'8': ['b', 'B'],
'|': ['I', 'l'],
'/': ['l'], # italic 'l' misread as slash (e.g. "p/" -> "pl")
}
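# Example confusions this table repairs (when the substitution yields a known
# word): "sch00l" -> "school", "1ike" -> "like", "p/ay" -> "play".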
_SUSPICIOUS_CHARS = frozenset(_DIGIT_SUBS.keys())
# Umlaut confusion: OCR drops the dots (ä→a, ö→o, ü→u); map plain vowels back
_UMLAUT_MAP = {
'a': '\u00e4', 'o': '\u00f6', 'u': '\u00fc', 'i': '\u00fc',
'A': '\u00c4', 'O': '\u00d6', 'U': '\u00dc', 'I': '\u00dc',
}
# Tokenizer -- includes | and / so OCR artifacts like "p/" are treated as words
_TOKEN_RE = re.compile(r"([A-Za-z\u00c4\u00d6\u00dc\u00e4\u00f6\u00fc\u00df'|/]+)([^A-Za-z\u00c4\u00d6\u00dc\u00e4\u00f6\u00fc\u00df'|/]*)")
# ---------------------------------------------------------------------------
# Data types
# ---------------------------------------------------------------------------
@dataclass
class CorrectionResult:
original: str
corrected: str
lang_detected: Lang
changed: bool
changes: List[str] = field(default_factory=list)
# ---------------------------------------------------------------------------
# Core class — language detection and word-level correction
# ---------------------------------------------------------------------------
class _SmartSpellCoreBase:
"""Base class with language detection and single-word correction.
Not intended for direct use — SmartSpellChecker inherits from this.
"""
def __init__(self):
if not _AVAILABLE:
raise RuntimeError("pyspellchecker not installed")
self.en = _en_spell
self.de = _de_spell
# --- Language detection ---
def detect_word_lang(self, word: str) -> Lang:
"""Detect language of a single word using dual-dict heuristic."""
w = word.lower().strip(".,;:!?\"'()")
if not w:
return "unknown"
in_en = bool(self.en.known([w]))
in_de = bool(self.de.known([w]))
if in_en and in_de:
return "both"
if in_en:
return "en"
if in_de:
return "de"
return "unknown"
def detect_text_lang(self, text: str) -> Lang:
"""Detect dominant language of a text string (sentence/phrase)."""
words = re.findall(r"[A-Za-z\u00c4\u00d6\u00dc\u00e4\u00f6\u00fc\u00df]+", text)
if not words:
return "unknown"
en_count = 0
de_count = 0
for w in words:
lang = self.detect_word_lang(w)
if lang == "en":
en_count += 1
elif lang == "de":
de_count += 1
# "both" doesn't count for either
if en_count > de_count:
return "en"
if de_count > en_count:
return "de"
if en_count == de_count and en_count > 0:
return "both"
return "unknown"
# --- Single-word correction ---
def _known(self, word: str) -> bool:
"""True if word is known in EN or DE dictionary, or is a known abbreviation."""
w = word.lower()
if bool(self.en.known([w])) or bool(self.de.known([w])):
return True
# Also accept known abbreviations (sth, sb, adj, etc.)
try:
from cv_ocr_engines import _KNOWN_ABBREVIATIONS
if w in _KNOWN_ABBREVIATIONS:
return True
except ImportError:
pass
return False
def _word_freq(self, word: str) -> float:
"""Get word frequency (max of EN and DE)."""
w = word.lower()
return max(self.en.word_usage_frequency(w), self.de.word_usage_frequency(w))
def _known_in(self, word: str, lang: str) -> bool:
"""True if word is known in a specific language dictionary."""
w = word.lower()
spell = self.en if lang == "en" else self.de
return bool(spell.known([w]))
def correct_word(self, word: str, lang: str = "en",
prev_word: str = "", next_word: str = "") -> Optional[str]:
"""Correct a single word for the given language.
Returns None if no correction needed, or the corrected string.
"""
if not word or not word.strip():
return None
# Skip numbers, abbreviations with dots, very short tokens
if word.isdigit() or '.' in word:
return None
# Skip IPA/phonetic content in brackets
if '[' in word or ']' in word:
return None
has_suspicious = any(ch in _SUSPICIOUS_CHARS for ch in word)
# 1. Already known -> no fix
if self._known(word):
# But check a/I disambiguation for single-char words
if word.lower() in ('l', '|') and next_word:
return self._disambiguate_a_I(word, next_word)
return None
# 2. Digit/pipe substitution
if has_suspicious:
if word == '|':
return 'I'
# Try single-char substitutions
for i, ch in enumerate(word):
if ch not in _DIGIT_SUBS:
continue
for replacement in _DIGIT_SUBS[ch]:
candidate = word[:i] + replacement + word[i + 1:]
if self._known(candidate):
return candidate
# Try multi-char substitution (e.g., "sch00l" -> "school")
multi = self._try_multi_digit_sub(word)
if multi:
return multi
# 3. Umlaut correction (German)
if lang == "de" and len(word) >= 3 and word.isalpha():
umlaut_fix = self._try_umlaut_fix(word)
if umlaut_fix:
return umlaut_fix
# 4. General spell correction
if not has_suspicious and len(word) >= 3 and word.isalpha():
# Safety: don't correct if the word is valid in the OTHER language
other_lang = "de" if lang == "en" else "en"
if self._known_in(word, other_lang):
return None
if other_lang == "de" and self._try_umlaut_fix(word):
return None # has a valid DE umlaut variant -> don't touch
spell = self.en if lang == "en" else self.de
correction = spell.correction(word.lower())
if correction and correction != word.lower():
if word[0].isupper():
correction = correction[0].upper() + correction[1:]
if self._known(correction):
return correction
return None
# --- Multi-digit substitution ---
def _try_multi_digit_sub(self, word: str) -> Optional[str]:
"""Try replacing multiple digits simultaneously using BFS."""
positions = [(i, ch) for i, ch in enumerate(word) if ch in _DIGIT_SUBS]
if not positions or len(positions) > 4:
return None
# BFS over substitution combinations
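# Worked example (hypothetical): for "sch00l", positions = [(3, '0'), (4, '0')].
# Each round keeps the unmodified variant and branches into the substitutions,
# so the final queue holds all combinations ("sch0ol", "schoOl", "school", ...);
# the first combination that differs from the input and is known is returned.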
queue = [list(word)]
for pos, ch in positions:
next_queue = []
for current in queue:
# Keep original
next_queue.append(current[:])
# Try each substitution
for repl in _DIGIT_SUBS[ch]:
variant = current[:]
variant[pos] = repl
next_queue.append(variant)
queue = next_queue
# Check which combinations produce known words
for combo in queue:
candidate = "".join(combo)
if candidate != word and self._known(candidate):
return candidate
return None
# --- Umlaut fix ---
def _try_umlaut_fix(self, word: str) -> Optional[str]:
"""Try single-char umlaut substitutions for German words."""
for i, ch in enumerate(word):
if ch in _UMLAUT_MAP:
candidate = word[:i] + _UMLAUT_MAP[ch] + word[i + 1:]
if self._known(candidate):
return candidate
return None
# --- a/I disambiguation ---
def _disambiguate_a_I(self, token: str, next_word: str) -> Optional[str]:
"""Disambiguate 'a' vs 'I' (and OCR variants like 'l', '|')."""
nw = next_word.lower().strip(".,;:!?")
if nw in _I_FOLLOWERS:
return "I"
if nw in _A_FOLLOWERS:
return "a"
return None # uncertain, don't change

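A hedged sketch of the core heuristics in isolation (outputs depend on the installed pyspellchecker dictionaries; _disambiguate_a_I is internal and called here only for illustration):

from smart_spell import SmartSpellChecker  # concrete subclass of _SmartSpellCoreBase

checker = SmartSpellChecker()
checker.detect_word_lang("the")                       # typically "en"
checker.detect_word_lang("und")                       # typically "de"
checker.detect_text_lang("der Hund lief nach Hause")  # expected "de" (majority vote)
checker._disambiguate_a_I("l", next_word="think")     # "I"  ("think" is an I-follower)
checker._disambiguate_a_I("l", next_word="lot")       # "a"  ("lot" is an a-follower)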
View File

@@ -0,0 +1,289 @@
"""
SmartSpellChecker Text — full text correction, boundary repair, context split.
Extracted from smart_spell.py for modularity.
Lizenz: Apache 2.0 (kommerziell nutzbar)
"""
import re
from typing import Dict, List, Optional, Tuple
from smart_spell_core import (
_SmartSpellCoreBase,
_TOKEN_RE,
CorrectionResult,
Lang,
)
class SmartSpellChecker(_SmartSpellCoreBase):
"""Language-aware OCR spell checker using pyspellchecker (no LLM).
Inherits single-word correction from _SmartSpellCoreBase.
Adds text-level passes: boundary repair, context split, full correction.
"""
# --- Boundary repair (shifted word boundaries) ---
def _try_boundary_repair(self, word1: str, word2: str) -> Optional[Tuple[str, str]]:
"""Fix shifted word boundaries between adjacent tokens.
OCR sometimes shifts the boundary: "at sth." -> "ats th."
Try moving 1-2 chars from end of word1 to start of word2 and vice versa.
Returns (fixed_word1, fixed_word2) or None.
"""
# Import known abbreviations for vocabulary context
try:
from cv_ocr_engines import _KNOWN_ABBREVIATIONS
except ImportError:
_KNOWN_ABBREVIATIONS = set()
# Strip trailing punctuation for checking, preserve for result
w2_stripped = word2.rstrip(".,;:!?")
w2_punct = word2[len(w2_stripped):]
# Try shifting 1-2 chars from word1 -> word2
for shift in (1, 2):
if len(word1) <= shift:
continue
new_w1 = word1[:-shift]
new_w2_base = word1[-shift:] + w2_stripped
w1_ok = self._known(new_w1) or new_w1.lower() in _KNOWN_ABBREVIATIONS
w2_ok = self._known(new_w2_base) or new_w2_base.lower() in _KNOWN_ABBREVIATIONS
if w1_ok and w2_ok:
return (new_w1, new_w2_base + w2_punct)
# Try shifting 1-2 chars from word2 -> word1
for shift in (1, 2):
if len(w2_stripped) <= shift:
continue
new_w1 = word1 + w2_stripped[:shift]
new_w2_base = w2_stripped[shift:]
w1_ok = self._known(new_w1) or new_w1.lower() in _KNOWN_ABBREVIATIONS
w2_ok = self._known(new_w2_base) or new_w2_base.lower() in _KNOWN_ABBREVIATIONS
if w1_ok and w2_ok:
return (new_w1, new_w2_base + w2_punct)
return None
# --- Context-based word split for ambiguous merges ---
# Patterns where a valid word is actually "a" + adjective/noun
_ARTICLE_SPLIT_CANDIDATES = {
# word -> (article, remainder) -- only when followed by a compatible word
"anew": ("a", "new"),
"areal": ("a", "real"),
"alive": None, # genuinely one word, never split
"alone": None,
"aware": None,
"alike": None,
"apart": None,
"aside": None,
"above": None,
"about": None,
"among": None,
"along": None,
}
def _try_context_split(self, word: str, next_word: str,
prev_word: str) -> Optional[str]:
"""Split words like 'anew' -> 'a new' when context indicates a merge.
Only splits when:
- The word is in the split candidates list
- The following word makes sense as a noun (for "a + adj + noun" pattern)
- OR the word is unknown and can be split into article + known word
"""
w_lower = word.lower()
# Check explicit candidates
if w_lower in self._ARTICLE_SPLIT_CANDIDATES:
split = self._ARTICLE_SPLIT_CANDIDATES[w_lower]
if split is None:
return None # explicitly marked as "don't split"
article, remainder = split
# Only split if followed by a word (noun pattern)
if next_word and next_word[0].islower():
return f"{article} {remainder}"
# Also split if remainder + next_word makes a common phrase
if next_word and self._known(next_word):
return f"{article} {remainder}"
# Generic: if word starts with 'a' and rest is a known adjective/word
if (len(word) >= 4 and word[0].lower() == 'a'
and not self._known(word) # only for UNKNOWN words
and self._known(word[1:])):
return f"a {word[1:]}"
return None
# --- Full text correction ---
def correct_text(self, text: str, lang: str = "en") -> CorrectionResult:
"""Correct a full text string (field value).
Three passes:
1. Boundary repair -- fix shifted word boundaries between adjacent tokens
2. Context split -- split ambiguous merges (anew -> a new)
3. Per-word correction -- spell check individual words
"""
if not text or not text.strip():
return CorrectionResult(text, text, "unknown", False)
detected = self.detect_text_lang(text) if lang == "auto" else lang
effective_lang = detected if detected in ("en", "de") else "en"
changes: List[str] = []
tokens = list(_TOKEN_RE.finditer(text))
# Extract token list: [(word, separator), ...]
token_list: List[List[str]] = [] # [[word, sep], ...]
for m in tokens:
token_list.append([m.group(1), m.group(2)])
# --- Pass 1: Boundary repair between adjacent unknown words ---
# Import abbreviations for the heuristic below
try:
from cv_ocr_engines import _KNOWN_ABBREVIATIONS as _ABBREVS
except ImportError:
_ABBREVS = set()
for i in range(len(token_list) - 1):
w1 = token_list[i][0]
w2_raw = token_list[i + 1][0]
# Skip boundary repair for IPA/bracket content
# Brackets may be in the token OR in the adjacent separators
sep_before_w1 = token_list[i - 1][1] if i > 0 else ""
sep_after_w1 = token_list[i][1]
sep_after_w2 = token_list[i + 1][1]
has_bracket = (
'[' in w1 or ']' in w1 or '[' in w2_raw or ']' in w2_raw
or ']' in sep_after_w1 # w1 text was inside [brackets]
or '[' in sep_after_w1 # w2 starts a bracket
or ']' in sep_after_w2 # w2 text was inside [brackets]
or '[' in sep_before_w1 # w1 starts a bracket
)
if has_bracket:
continue
# Include trailing punct from separator in w2 for abbreviation matching
w2_with_punct = w2_raw + token_list[i + 1][1].rstrip(" ")
# Try boundary repair -- always, even if both words are valid.
# Use word-frequency scoring to decide if repair is better.
repair = self._try_boundary_repair(w1, w2_with_punct)
if not repair and w2_with_punct != w2_raw:
repair = self._try_boundary_repair(w1, w2_raw)
if repair:
new_w1, new_w2_full = repair
new_w2_base = new_w2_full.rstrip(".,;:!?")
# Frequency-based scoring: product of word frequencies
# Higher product = more common word pair = better
old_freq = self._word_freq(w1) * self._word_freq(w2_raw)
new_freq = self._word_freq(new_w1) * self._word_freq(new_w2_base)
# Abbreviation bonus: if repair produces a known abbreviation
has_abbrev = new_w1.lower() in _ABBREVS or new_w2_base.lower() in _ABBREVS
if has_abbrev:
# Accept abbreviation repair ONLY if at least one of the
# original words is rare/unknown (prevents "Can I" -> "Ca nI"
# where both original words are common and correct).
RARE_THRESHOLD = 1e-6
orig_both_common = (
self._word_freq(w1) > RARE_THRESHOLD
and self._word_freq(w2_raw) > RARE_THRESHOLD
)
if not orig_both_common:
new_freq = max(new_freq, old_freq * 10)
else:
has_abbrev = False # both originals common -> don't trust
# Accept if repair produces a more frequent word pair
# (threshold: at least 5x more frequent to avoid false positives)
if new_freq > old_freq * 5:
new_w2_punct = new_w2_full[len(new_w2_base):]
changes.append(f"{w1} {w2_raw}\u2192{new_w1} {new_w2_base}")
token_list[i][0] = new_w1
token_list[i + 1][0] = new_w2_base
if new_w2_punct:
token_list[i + 1][1] = new_w2_punct + token_list[i + 1][1].lstrip(".,;:!?")
# --- Pass 2: Context split (anew -> a new) ---
expanded: List[List[str]] = []
for i, (word, sep) in enumerate(token_list):
next_word = token_list[i + 1][0] if i + 1 < len(token_list) else ""
prev_word = token_list[i - 1][0] if i > 0 else ""
split = self._try_context_split(word, next_word, prev_word)
if split and split != word:
changes.append(f"{word}\u2192{split}")
expanded.append([split, sep])
else:
expanded.append([word, sep])
token_list = expanded
# --- Pass 3: Per-word correction ---
parts: List[str] = []
# Preserve any leading text before the first token match
first_start = tokens[0].start() if tokens else 0
if first_start > 0:
parts.append(text[:first_start])
for i, (word, sep) in enumerate(token_list):
# Skip words inside IPA brackets (brackets land in separators)
prev_sep = token_list[i - 1][1] if i > 0 else ""
if '[' in prev_sep or ']' in sep:
parts.append(word)
parts.append(sep)
continue
next_word = token_list[i + 1][0] if i + 1 < len(token_list) else ""
prev_word = token_list[i - 1][0] if i > 0 else ""
correction = self.correct_word(
word, lang=effective_lang,
prev_word=prev_word, next_word=next_word,
)
if correction and correction != word:
changes.append(f"{word}\u2192{correction}")
parts.append(correction)
else:
parts.append(word)
parts.append(sep)
# Append any trailing text
last_end = tokens[-1].end() if tokens else 0
if last_end < len(text):
parts.append(text[last_end:])
corrected = "".join(parts)
return CorrectionResult(
original=text,
corrected=corrected,
lang_detected=detected,
changed=corrected != text,
changes=changes,
)
# --- Vocabulary entry correction ---
def correct_vocab_entry(self, english: str, german: str,
example: str = "") -> Dict[str, CorrectionResult]:
"""Correct a full vocabulary entry (EN + DE + example).
Uses column position to determine language -- the most reliable signal.
"""
results = {}
results["english"] = self.correct_text(english, lang="en")
results["german"] = self.correct_text(german, lang="de")
if example:
# For examples, auto-detect language
results["example"] = self.correct_text(example, lang="auto")
return results

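A hedged end-to-end sketch of the three passes (hypothetical inputs; actual corrections depend on the pyspellchecker dictionaries and, for abbreviation handling, on cv_ocr_engines being importable):

from smart_spell import SmartSpellChecker

checker = SmartSpellChecker()
# Pass 2, context split: "anew" is a valid word, so plain spell checking would
# keep it; the following lowercase word is what licenses the split.
r = checker.correct_text("anew car", lang="en")
print(r.corrected, r.changes)  # expected: "a new car" ["anew→a new"]
# Pass 3, per-word correction of a plain misspelling.
print(checker.correct_text("I waant to go", lang="en").corrected)  # likely "I want to go"
# Column position fixes the language for vocabulary entries.
entry = checker.correct_vocab_entry(english="to go", german="gehen")
print(entry["german"].changed)  # False: both sides already valid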
View File

@@ -1,602 +1,29 @@
"""
Mobile Upload API for Klausur-Service
Mobile Upload API — barrel re-export.
All implementation split into:
upload_api_chunked — chunked upload endpoints (init, chunk, finalize, simple, status, cancel, list)
upload_api_mobile — mobile HTML upload page
Provides chunked upload endpoints for large PDF files (100MB+) from mobile devices.
DSGVO-konform: Data stays local in WLAN, no external transmission.
"""
import os
import uuid
import shutil
import hashlib
from pathlib import Path
from datetime import datetime, timezone
from typing import Dict, Optional
from fastapi import APIRouter, HTTPException, UploadFile, File, Form
from fastapi.responses import HTMLResponse
from pydantic import BaseModel
# Configuration
UPLOAD_DIR = Path(os.getenv("UPLOAD_DIR", "/app/uploads"))
CHUNK_DIR = Path(os.getenv("CHUNK_DIR", "/app/chunks"))
EH_UPLOAD_DIR = Path(os.getenv("EH_UPLOAD_DIR", "/app/eh-uploads"))
# Ensure directories exist
UPLOAD_DIR.mkdir(parents=True, exist_ok=True)
CHUNK_DIR.mkdir(parents=True, exist_ok=True)
EH_UPLOAD_DIR.mkdir(parents=True, exist_ok=True)
# In-memory storage for upload sessions (for simplicity)
# In production, use Redis or database
_upload_sessions: Dict[str, dict] = {}
router = APIRouter(prefix="/api/v1/upload", tags=["Mobile Upload"])
class InitUploadRequest(BaseModel):
filename: str
filesize: int
chunks: int
destination: str = "klausur" # "klausur" or "rag"
class InitUploadResponse(BaseModel):
upload_id: str
chunk_size: int
total_chunks: int
message: str
class ChunkUploadResponse(BaseModel):
upload_id: str
chunk_index: int
received: bool
chunks_received: int
total_chunks: int
class FinalizeResponse(BaseModel):
upload_id: str
filename: str
filepath: str
filesize: int
checksum: str
message: str
@router.post("/init", response_model=InitUploadResponse)
async def init_upload(request: InitUploadRequest):
"""
Initialize a chunked upload session.
Returns an upload_id that must be used for subsequent chunk uploads.
"""
upload_id = str(uuid.uuid4())
# Create session directory
session_dir = CHUNK_DIR / upload_id
session_dir.mkdir(parents=True, exist_ok=True)
# Store session info
_upload_sessions[upload_id] = {
"filename": request.filename,
"filesize": request.filesize,
"total_chunks": request.chunks,
"received_chunks": set(),
"destination": request.destination,
"session_dir": str(session_dir),
"created_at": datetime.now(timezone.utc).isoformat(),
}
return InitUploadResponse(
upload_id=upload_id,
chunk_size=5 * 1024 * 1024, # 5 MB
total_chunks=request.chunks,
message="Upload-Session erstellt"
)
@router.post("/chunk", response_model=ChunkUploadResponse)
async def upload_chunk(
chunk: UploadFile = File(...),
upload_id: str = Form(...),
chunk_index: int = Form(...)
):
"""
Upload a single chunk of a file.
Chunks are stored temporarily until finalize is called.
"""
if upload_id not in _upload_sessions:
raise HTTPException(status_code=404, detail="Upload-Session nicht gefunden")
session = _upload_sessions[upload_id]
if chunk_index < 0 or chunk_index >= session["total_chunks"]:
raise HTTPException(
status_code=400,
detail=f"Ungueltiger Chunk-Index: {chunk_index}"
)
# Save chunk
chunk_path = Path(session["session_dir"]) / f"chunk_{chunk_index:05d}"
with open(chunk_path, "wb") as f:
content = await chunk.read()
f.write(content)
# Track received chunks
session["received_chunks"].add(chunk_index)
return ChunkUploadResponse(
upload_id=upload_id,
chunk_index=chunk_index,
received=True,
chunks_received=len(session["received_chunks"]),
total_chunks=session["total_chunks"]
)
@router.post("/finalize", response_model=FinalizeResponse)
async def finalize_upload(upload_id: str = Form(...)):
"""
Finalize the upload by combining all chunks into a single file.
Validates that all chunks were received and calculates checksum.
"""
if upload_id not in _upload_sessions:
raise HTTPException(status_code=404, detail="Upload-Session nicht gefunden")
session = _upload_sessions[upload_id]
# Check if all chunks received
if len(session["received_chunks"]) != session["total_chunks"]:
missing = session["total_chunks"] - len(session["received_chunks"])
raise HTTPException(
status_code=400,
detail=f"Nicht alle Chunks empfangen. Fehlend: {missing}"
)
# Determine destination directory
if session["destination"] == "rag":
dest_dir = EH_UPLOAD_DIR
else:
dest_dir = UPLOAD_DIR
# Generate unique filename
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
safe_filename = session["filename"].replace(" ", "_")
final_filename = f"{timestamp}_{safe_filename}"
final_path = dest_dir / final_filename
# Combine chunks
hasher = hashlib.sha256()
total_size = 0
with open(final_path, "wb") as outfile:
for i in range(session["total_chunks"]):
chunk_path = Path(session["session_dir"]) / f"chunk_{i:05d}"
if not chunk_path.exists():
raise HTTPException(
status_code=500,
detail=f"Chunk {i} nicht gefunden"
)
with open(chunk_path, "rb") as infile:
data = infile.read()
outfile.write(data)
hasher.update(data)
total_size += len(data)
# Clean up chunks
shutil.rmtree(session["session_dir"], ignore_errors=True)
del _upload_sessions[upload_id]
checksum = hasher.hexdigest()
return FinalizeResponse(
upload_id=upload_id,
filename=final_filename,
filepath=str(final_path),
filesize=total_size,
checksum=checksum,
message="Upload erfolgreich abgeschlossen"
)
@router.post("/simple")
async def simple_upload(
file: UploadFile = File(...),
destination: str = Form("klausur")
):
"""
Simple single-request upload for smaller files (<10MB).
For larger files, use the chunked upload endpoints.
"""
# Determine destination directory
if destination == "rag":
dest_dir = EH_UPLOAD_DIR
else:
dest_dir = UPLOAD_DIR
# Generate unique filename
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
safe_filename = file.filename.replace(" ", "_") if file.filename else "upload.pdf"
final_filename = f"{timestamp}_{safe_filename}"
final_path = dest_dir / final_filename
# Calculate checksum while writing
hasher = hashlib.sha256()
total_size = 0
with open(final_path, "wb") as f:
while True:
chunk = await file.read(1024 * 1024) # Read 1MB at a time
if not chunk:
break
f.write(chunk)
hasher.update(chunk)
total_size += len(chunk)
return {
"filename": final_filename,
"filepath": str(final_path),
"filesize": total_size,
"checksum": hasher.hexdigest(),
"message": "Upload erfolgreich"
}
@router.get("/status/{upload_id}")
async def get_upload_status(upload_id: str):
"""
Get the status of an ongoing upload.
"""
if upload_id not in _upload_sessions:
raise HTTPException(status_code=404, detail="Upload-Session nicht gefunden")
session = _upload_sessions[upload_id]
return {
"upload_id": upload_id,
"filename": session["filename"],
"total_chunks": session["total_chunks"],
"received_chunks": len(session["received_chunks"]),
"progress_percent": round(
len(session["received_chunks"]) / session["total_chunks"] * 100, 1
),
"destination": session["destination"],
"created_at": session["created_at"]
}
@router.delete("/cancel/{upload_id}")
async def cancel_upload(upload_id: str):
"""
Cancel an ongoing upload and clean up temporary files.
"""
if upload_id not in _upload_sessions:
raise HTTPException(status_code=404, detail="Upload-Session nicht gefunden")
session = _upload_sessions[upload_id]
# Clean up chunks
shutil.rmtree(session["session_dir"], ignore_errors=True)
del _upload_sessions[upload_id]
return {"message": "Upload abgebrochen", "upload_id": upload_id}
@router.get("/list")
async def list_uploads(destination: str = "klausur"):
"""
List all uploaded files in the specified destination.
"""
if destination == "rag":
dest_dir = EH_UPLOAD_DIR
else:
dest_dir = UPLOAD_DIR
files = []
for f in dest_dir.iterdir():
if f.is_file() and f.suffix.lower() == ".pdf":
stat = f.stat()
files.append({
"filename": f.name,
"size": stat.st_size,
"modified": datetime.fromtimestamp(stat.st_mtime).isoformat(),
})
files.sort(key=lambda x: x["modified"], reverse=True)
return {
"destination": destination,
"count": len(files),
"files": files[:50] # Limit to 50 most recent
}
@router.get("/mobile", response_class=HTMLResponse)
async def mobile_upload_page():
"""
Serve the mobile upload page directly from the klausur-service.
This allows mobile devices to upload without needing the Next.js website.
"""
from fastapi.responses import HTMLResponse
html_content = '''<!DOCTYPE html>
<html lang="de">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0, maximum-scale=1.0, user-scalable=no">
<meta name="apple-mobile-web-app-capable" content="yes">
<title>BreakPilot Upload</title>
<style>
* { margin: 0; padding: 0; box-sizing: border-box; }
body {
font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, sans-serif;
background: linear-gradient(135deg, #1e293b 0%, #0f172a 100%);
color: white;
min-height: 100vh;
padding: 16px;
}
.header {
display: flex;
justify-content: space-between;
align-items: center;
padding: 16px;
border-bottom: 1px solid #334155;
margin-bottom: 24px;
}
.header h1 { font-size: 20px; color: #60a5fa; }
.badge { font-size: 10px; background: #1e293b; padding: 4px 8px; border-radius: 4px; color: #94a3b8; }
.destination-selector {
display: flex;
gap: 8px;
margin-bottom: 24px;
}
.dest-btn {
flex: 1;
padding: 14px;
border: none;
border-radius: 10px;
font-size: 14px;
font-weight: 600;
cursor: pointer;
transition: all 0.2s;
}
.dest-btn.active-klausur { background: #2563eb; color: white; box-shadow: 0 4px 15px rgba(37, 99, 235, 0.3); }
.dest-btn.active-rag { background: #7c3aed; color: white; box-shadow: 0 4px 15px rgba(124, 58, 237, 0.3); }
.dest-btn:not(.active-klausur):not(.active-rag) { background: #1e293b; color: #94a3b8; }
.upload-zone {
border: 2px dashed #475569;
border-radius: 16px;
padding: 40px 20px;
text-align: center;
margin-bottom: 24px;
transition: all 0.2s;
position: relative;
}
.upload-zone.dragover { border-color: #60a5fa; background: rgba(96, 165, 250, 0.1); transform: scale(1.02); }
.upload-zone input[type="file"] {
position: absolute;
inset: 0;
opacity: 0;
cursor: pointer;
}
.upload-icon {
width: 64px;
height: 64px;
background: #334155;
border-radius: 50%;
display: flex;
align-items: center;
justify-content: center;
margin: 0 auto 16px;
font-size: 28px;
}
.upload-title { font-size: 18px; margin-bottom: 8px; }
.upload-subtitle { font-size: 14px; color: #94a3b8; margin-bottom: 16px; }
.upload-hint { font-size: 12px; color: #64748b; }
.file-list { margin-bottom: 24px; }
.file-item {
background: #1e293b;
border-radius: 12px;
padding: 16px;
margin-bottom: 12px;
}
.file-item.error { border: 2px solid rgba(239, 68, 68, 0.5); }
.file-item.complete { border: 2px solid rgba(34, 197, 94, 0.3); }
.file-header { display: flex; justify-content: space-between; align-items: flex-start; margin-bottom: 8px; }
.file-name { font-weight: 500; word-break: break-all; }
.file-size { font-size: 14px; color: #94a3b8; }
.remove-btn { background: none; border: none; color: #94a3b8; font-size: 20px; cursor: pointer; padding: 4px; }
.progress-bar { height: 6px; background: #334155; border-radius: 3px; overflow: hidden; margin-top: 12px; }
.progress-fill { height: 100%; background: linear-gradient(90deg, #3b82f6, #60a5fa); transition: width 0.3s; }
.progress-text { font-size: 12px; color: #94a3b8; margin-top: 4px; }
.status-complete { display: flex; align-items: center; gap: 8px; color: #22c55e; font-size: 14px; margin-top: 12px; }
.status-error { display: flex; align-items: center; gap: 8px; color: #ef4444; font-size: 14px; margin-top: 12px; }
.info-box {
background: rgba(30, 41, 59, 0.5);
border-radius: 12px;
padding: 16px;
font-size: 14px;
color: #94a3b8;
}
.info-box h3 { color: #cbd5e1; margin-bottom: 8px; font-size: 14px; }
.info-box ul { padding-left: 20px; }
.info-box li { margin-bottom: 4px; }
.server-info { text-align: center; font-size: 12px; color: #64748b; margin-top: 16px; }
.stats { display: flex; justify-content: space-between; font-size: 14px; color: #94a3b8; padding: 0 8px; margin-bottom: 12px; }
</style>
</head>
<body>
<header class="header">
<h1>BreakPilot Upload</h1>
<span class="badge">DSGVO-konform</span>
</header>
<div class="destination-selector">
<button class="dest-btn active-klausur" id="btn-klausur" onclick="setDestination('klausur')">Klausuren</button>
<button class="dest-btn" id="btn-rag" onclick="setDestination('rag')">Erwartungshorizonte</button>
</div>
<div class="upload-zone" id="upload-zone">
<input type="file" accept=".pdf" multiple onchange="handleFiles(this.files)">
<div class="upload-icon">&#x2601;</div>
<div class="upload-title">PDF-Dateien hochladen</div>
<div class="upload-subtitle">Tippen zum Auswaehlen oder hierher ziehen</div>
<div class="upload-hint">Grosse Dateien bis 200 MB werden automatisch in Teilen hochgeladen</div>
</div>
<div class="stats" id="stats" style="display: none;">
<span id="completed-count">0 von 0 fertig</span>
<span id="total-size">0 B gesamt</span>
</div>
<div class="file-list" id="file-list"></div>
<div class="info-box">
<h3>Hinweise:</h3>
<ul>
<li>Die Dateien werden lokal im WLAN uebertragen</li>
<li>Keine Daten werden ins Internet gesendet</li>
<li>Unterstuetzte Formate: PDF</li>
</ul>
</div>
<div class="server-info" id="server-info">Server: wird ermittelt...</div>
<script>
const CHUNK_SIZE = 5 * 1024 * 1024;
let destination = 'klausur';
let files = [];
const serverUrl = window.location.origin;
document.getElementById('server-info').textContent = 'Server: ' + serverUrl;
function setDestination(dest) {
destination = dest;
document.querySelectorAll('.dest-btn').forEach(btn => {
btn.classList.remove('active-klausur', 'active-rag');
});
if (dest === 'klausur') {
document.getElementById('btn-klausur').classList.add('active-klausur');
} else {
document.getElementById('btn-rag').classList.add('active-rag');
}
}
function formatSize(bytes) {
if (bytes === 0) return '0 B';
const k = 1024;
const sizes = ['B', 'KB', 'MB', 'GB'];
const i = Math.floor(Math.log(bytes) / Math.log(k));
return parseFloat((bytes / Math.pow(k, i)).toFixed(1)) + ' ' + sizes[i];
}
function updateStats() {
const completed = files.filter(f => f.status === 'complete').length;
const total = files.reduce((sum, f) => sum + f.size, 0);
document.getElementById('completed-count').textContent = completed + ' von ' + files.length + ' fertig';
document.getElementById('total-size').textContent = formatSize(total) + ' gesamt';
document.getElementById('stats').style.display = files.length > 0 ? 'flex' : 'none';
}
function renderFiles() {
const list = document.getElementById('file-list');
list.innerHTML = files.map(f => {
let statusHtml = '';
if (f.status === 'uploading' || f.status === 'pending') {
statusHtml = '<div class="progress-bar"><div class="progress-fill" style="width: ' + f.progress + '%"></div></div><div class="progress-text">' + f.progress + '% hochgeladen</div>';
} else if (f.status === 'complete') {
statusHtml = '<div class="status-complete">&#x2713; Erfolgreich hochgeladen</div>';
} else if (f.status === 'error') {
statusHtml = '<div class="status-error">&#x26A0; ' + (f.error || 'Fehler beim Hochladen') + '</div>';
}
return '<div class="file-item ' + f.status + '"><div class="file-header"><div><div class="file-name">' + f.name + '</div><div class="file-size">' + formatSize(f.size) + '</div></div><button class="remove-btn" onclick="removeFile(\\'' + f.id + '\\')">&times;</button></div>' + statusHtml + '</div>';
}).join('');
updateStats();
}
function removeFile(id) {
files = files.filter(f => f.id !== id);
renderFiles();
}
async function uploadFile(file, fileId) {
const updateProgress = (progress) => {
const f = files.find(f => f.id === fileId);
if (f) { f.progress = progress; renderFiles(); }
};
const setStatus = (status, error) => {
const f = files.find(f => f.id === fileId);
if (f) { f.status = status; if (error) f.error = error; renderFiles(); }
};
try {
setStatus('uploading');
if (file.size > 10 * 1024 * 1024) {
// Chunked upload
const totalChunks = Math.ceil(file.size / CHUNK_SIZE);
const initRes = await fetch(serverUrl + '/api/v1/upload/init', {
method: 'POST',
headers: { 'Content-Type': 'application/json' },
body: JSON.stringify({ filename: file.name, filesize: file.size, chunks: totalChunks, destination: destination })
});
if (!initRes.ok) throw new Error('Konnte Upload nicht starten');
const { upload_id } = await initRes.json();
for (let i = 0; i < totalChunks; i++) {
const start = i * CHUNK_SIZE;
const end = Math.min(start + CHUNK_SIZE, file.size);
const chunk = file.slice(start, end);
const formData = new FormData();
formData.append('chunk', chunk);
formData.append('upload_id', upload_id);
formData.append('chunk_index', i.toString());
const chunkRes = await fetch(serverUrl + '/api/v1/upload/chunk', { method: 'POST', body: formData });
if (!chunkRes.ok) throw new Error('Fehler bei Teil ' + (i + 1));
updateProgress(Math.round(((i + 1) / totalChunks) * 100));
}
const finalizeForm = new FormData();
finalizeForm.append('upload_id', upload_id);
const finalRes = await fetch(serverUrl + '/api/v1/upload/finalize', { method: 'POST', body: finalizeForm });
if (!finalRes.ok) throw new Error('Fehler beim Abschliessen');
} else {
// Simple upload
const formData = new FormData();
formData.append('file', file);
formData.append('destination', destination);
const res = await fetch(serverUrl + '/api/v1/upload/simple', { method: 'POST', body: formData });
if (!res.ok) throw new Error('Upload fehlgeschlagen');
updateProgress(100);
}
setStatus('complete');
} catch (e) {
setStatus('error', e.message);
}
}
function handleFiles(fileList) {
const newFiles = Array.from(fileList).filter(f => f.type === 'application/pdf');
newFiles.forEach(file => {
const id = Math.random().toString(36).substr(2, 9);
files.push({ id, name: file.name, size: file.size, progress: 0, status: 'pending', file });
renderFiles();
uploadFile(file, id);
});
}
// Drag & Drop
const zone = document.getElementById('upload-zone');
zone.addEventListener('dragover', e => { e.preventDefault(); zone.classList.add('dragover'); });
zone.addEventListener('dragleave', e => { e.preventDefault(); zone.classList.remove('dragover'); });
zone.addEventListener('drop', e => { e.preventDefault(); zone.classList.remove('dragover'); handleFiles(e.dataTransfer.files); });
</script>
</body>
</html>'''
return HTMLResponse(content=html_content)
from fastapi import APIRouter
from upload_api_chunked import ( # noqa: F401
router as _chunked_router,
UPLOAD_DIR,
CHUNK_DIR,
EH_UPLOAD_DIR,
_upload_sessions,
InitUploadRequest,
InitUploadResponse,
ChunkUploadResponse,
FinalizeResponse,
)
from upload_api_mobile import router as _mobile_router # noqa: F401
# Composite router that includes both sub-routers
router = APIRouter()
router.include_router(_chunked_router)
router.include_router(_mobile_router)

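A minimal sketch (hypothetical app module; the real service wires this up in its own entry point) showing that the barrel still exposes a single router, so existing mounts are unaffected:

from fastapi import FastAPI
from upload_api import router as upload_router  # same import path as before the split

app = FastAPI()
app.include_router(upload_router)  # serves /api/v1/upload/* including /mobile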
View File

@@ -0,0 +1,320 @@
"""
Chunked Upload API — init, chunk, finalize, simple upload, status, cancel, list.
Extracted from upload_api.py for modularity.
DSGVO-konform: Data stays local in WLAN, no external transmission.
"""
import os
import uuid
import shutil
import hashlib
from pathlib import Path
from datetime import datetime, timezone
from typing import Dict, Optional
from fastapi import APIRouter, HTTPException, UploadFile, File, Form
from pydantic import BaseModel
# Configuration
UPLOAD_DIR = Path(os.getenv("UPLOAD_DIR", "/app/uploads"))
CHUNK_DIR = Path(os.getenv("CHUNK_DIR", "/app/chunks"))
EH_UPLOAD_DIR = Path(os.getenv("EH_UPLOAD_DIR", "/app/eh-uploads"))
# Ensure directories exist
UPLOAD_DIR.mkdir(parents=True, exist_ok=True)
CHUNK_DIR.mkdir(parents=True, exist_ok=True)
EH_UPLOAD_DIR.mkdir(parents=True, exist_ok=True)
# In-memory storage for upload sessions (for simplicity)
# In production, use Redis or database
_upload_sessions: Dict[str, dict] = {}
router = APIRouter(prefix="/api/v1/upload", tags=["Mobile Upload"])
class InitUploadRequest(BaseModel):
filename: str
filesize: int
chunks: int
destination: str = "klausur" # "klausur" or "rag"
class InitUploadResponse(BaseModel):
upload_id: str
chunk_size: int
total_chunks: int
message: str
class ChunkUploadResponse(BaseModel):
upload_id: str
chunk_index: int
received: bool
chunks_received: int
total_chunks: int
class FinalizeResponse(BaseModel):
upload_id: str
filename: str
filepath: str
filesize: int
checksum: str
message: str
@router.post("/init", response_model=InitUploadResponse)
async def init_upload(request: InitUploadRequest):
"""
Initialize a chunked upload session.
Returns an upload_id that must be used for subsequent chunk uploads.
"""
upload_id = str(uuid.uuid4())
# Create session directory
session_dir = CHUNK_DIR / upload_id
session_dir.mkdir(parents=True, exist_ok=True)
# Store session info
_upload_sessions[upload_id] = {
"filename": request.filename,
"filesize": request.filesize,
"total_chunks": request.chunks,
"received_chunks": set(),
"destination": request.destination,
"session_dir": str(session_dir),
"created_at": datetime.now(timezone.utc).isoformat(),
}
return InitUploadResponse(
upload_id=upload_id,
chunk_size=5 * 1024 * 1024, # 5 MB
total_chunks=request.chunks,
message="Upload-Session erstellt"
)
@router.post("/chunk", response_model=ChunkUploadResponse)
async def upload_chunk(
chunk: UploadFile = File(...),
upload_id: str = Form(...),
chunk_index: int = Form(...)
):
"""
Upload a single chunk of a file.
Chunks are stored temporarily until finalize is called.
"""
if upload_id not in _upload_sessions:
raise HTTPException(status_code=404, detail="Upload-Session nicht gefunden")
session = _upload_sessions[upload_id]
if chunk_index < 0 or chunk_index >= session["total_chunks"]:
raise HTTPException(
status_code=400,
detail=f"Ungueltiger Chunk-Index: {chunk_index}"
)
# Save chunk
chunk_path = Path(session["session_dir"]) / f"chunk_{chunk_index:05d}"
with open(chunk_path, "wb") as f:
content = await chunk.read()
f.write(content)
# Track received chunks
session["received_chunks"].add(chunk_index)
return ChunkUploadResponse(
upload_id=upload_id,
chunk_index=chunk_index,
received=True,
chunks_received=len(session["received_chunks"]),
total_chunks=session["total_chunks"]
)
@router.post("/finalize", response_model=FinalizeResponse)
async def finalize_upload(upload_id: str = Form(...)):
"""
Finalize the upload by combining all chunks into a single file.
Validates that all chunks were received and calculates checksum.
"""
if upload_id not in _upload_sessions:
raise HTTPException(status_code=404, detail="Upload-Session nicht gefunden")
session = _upload_sessions[upload_id]
# Check if all chunks received
if len(session["received_chunks"]) != session["total_chunks"]:
missing = session["total_chunks"] - len(session["received_chunks"])
raise HTTPException(
status_code=400,
detail=f"Nicht alle Chunks empfangen. Fehlend: {missing}"
)
# Determine destination directory
if session["destination"] == "rag":
dest_dir = EH_UPLOAD_DIR
else:
dest_dir = UPLOAD_DIR
# Generate unique filename
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
safe_filename = session["filename"].replace(" ", "_")
final_filename = f"{timestamp}_{safe_filename}"
final_path = dest_dir / final_filename
# Combine chunks
hasher = hashlib.sha256()
total_size = 0
with open(final_path, "wb") as outfile:
for i in range(session["total_chunks"]):
chunk_path = Path(session["session_dir"]) / f"chunk_{i:05d}"
if not chunk_path.exists():
raise HTTPException(
status_code=500,
detail=f"Chunk {i} nicht gefunden"
)
with open(chunk_path, "rb") as infile:
data = infile.read()
outfile.write(data)
hasher.update(data)
total_size += len(data)
# Clean up chunks
shutil.rmtree(session["session_dir"], ignore_errors=True)
del _upload_sessions[upload_id]
checksum = hasher.hexdigest()
return FinalizeResponse(
upload_id=upload_id,
filename=final_filename,
filepath=str(final_path),
filesize=total_size,
checksum=checksum,
message="Upload erfolgreich abgeschlossen"
)
@router.post("/simple")
async def simple_upload(
file: UploadFile = File(...),
destination: str = Form("klausur")
):
"""
Simple single-request upload for smaller files (<10MB).
For larger files, use the chunked upload endpoints.
"""
# Determine destination directory
if destination == "rag":
dest_dir = EH_UPLOAD_DIR
else:
dest_dir = UPLOAD_DIR
# Generate unique filename
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
safe_filename = file.filename.replace(" ", "_") if file.filename else "upload.pdf"
final_filename = f"{timestamp}_{safe_filename}"
final_path = dest_dir / final_filename
# Calculate checksum while writing
hasher = hashlib.sha256()
total_size = 0
with open(final_path, "wb") as f:
while True:
chunk = await file.read(1024 * 1024) # Read 1MB at a time
if not chunk:
break
f.write(chunk)
hasher.update(chunk)
total_size += len(chunk)
return {
"filename": final_filename,
"filepath": str(final_path),
"filesize": total_size,
"checksum": hasher.hexdigest(),
"message": "Upload erfolgreich"
}
@router.get("/status/{upload_id}")
async def get_upload_status(upload_id: str):
"""
Get the status of an ongoing upload.
"""
if upload_id not in _upload_sessions:
raise HTTPException(status_code=404, detail="Upload-Session nicht gefunden")
session = _upload_sessions[upload_id]
return {
"upload_id": upload_id,
"filename": session["filename"],
"total_chunks": session["total_chunks"],
"received_chunks": len(session["received_chunks"]),
"progress_percent": round(
len(session["received_chunks"]) / session["total_chunks"] * 100, 1
),
"destination": session["destination"],
"created_at": session["created_at"]
}
@router.delete("/cancel/{upload_id}")
async def cancel_upload(upload_id: str):
"""
Cancel an ongoing upload and clean up temporary files.
"""
if upload_id not in _upload_sessions:
raise HTTPException(status_code=404, detail="Upload-Session nicht gefunden")
session = _upload_sessions[upload_id]
# Clean up chunks
shutil.rmtree(session["session_dir"], ignore_errors=True)
del _upload_sessions[upload_id]
return {"message": "Upload abgebrochen", "upload_id": upload_id}
@router.get("/list")
async def list_uploads(destination: str = "klausur"):
"""
List all uploaded files in the specified destination.
"""
if destination == "rag":
dest_dir = EH_UPLOAD_DIR
else:
dest_dir = UPLOAD_DIR
files = []
for f in dest_dir.iterdir():
if f.is_file() and f.suffix.lower() == ".pdf":
stat = f.stat()
files.append({
"filename": f.name,
"size": stat.st_size,
"modified": datetime.fromtimestamp(stat.st_mtime).isoformat(),
})
files.sort(key=lambda x: x["modified"], reverse=True)
return {
"destination": destination,
"count": len(files),
"files": files[:50] # Limit to 50 most recent
}

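A hedged client-side sketch of the chunk protocol (assuming the requests package and a service reachable at localhost:8000, both hypothetical here; the field names match the File/Form parameters above):

import math
import requests

BASE = "http://localhost:8000/api/v1/upload"  # hypothetical host
CHUNK = 5 * 1024 * 1024                       # matches chunk_size returned by /init

def upload_pdf(path: str, destination: str = "klausur") -> dict:
    with open(path, "rb") as fh:
        data = fh.read()
    total = math.ceil(len(data) / CHUNK)
    # 1. init: register filename/size/chunk count, receive an upload_id
    init = requests.post(f"{BASE}/init", json={
        "filename": path.rsplit("/", 1)[-1], "filesize": len(data),
        "chunks": total, "destination": destination,
    }).json()
    uid = init["upload_id"]
    # 2. chunk: send each 5 MB slice as multipart form data
    for i in range(total):
        requests.post(f"{BASE}/chunk",
                      files={"chunk": data[i * CHUNK:(i + 1) * CHUNK]},
                      data={"upload_id": uid, "chunk_index": str(i)}).raise_for_status()
    # 3. finalize: the server reassembles the chunks and returns the sha256 checksum
    return requests.post(f"{BASE}/finalize", data={"upload_id": uid}).json()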
View File

@@ -0,0 +1,292 @@
"""
Mobile Upload HTML Page — serves the mobile upload UI directly from klausur-service.
Extracted from upload_api.py for modularity.
DSGVO-konform: Data stays local in WLAN, no external transmission.
"""
from fastapi import APIRouter
from fastapi.responses import HTMLResponse
router = APIRouter(prefix="/api/v1/upload", tags=["Mobile Upload"])
@router.get("/mobile", response_class=HTMLResponse)
async def mobile_upload_page():
"""
Serve the mobile upload page directly from the klausur-service.
This allows mobile devices to upload without needing the Next.js website.
"""
html_content = '''<!DOCTYPE html>
<html lang="de">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0, maximum-scale=1.0, user-scalable=no">
<meta name="apple-mobile-web-app-capable" content="yes">
<title>BreakPilot Upload</title>
<style>
* { margin: 0; padding: 0; box-sizing: border-box; }
body {
font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, sans-serif;
background: linear-gradient(135deg, #1e293b 0%, #0f172a 100%);
color: white;
min-height: 100vh;
padding: 16px;
}
.header {
display: flex;
justify-content: space-between;
align-items: center;
padding: 16px;
border-bottom: 1px solid #334155;
margin-bottom: 24px;
}
.header h1 { font-size: 20px; color: #60a5fa; }
.badge { font-size: 10px; background: #1e293b; padding: 4px 8px; border-radius: 4px; color: #94a3b8; }
.destination-selector {
display: flex;
gap: 8px;
margin-bottom: 24px;
}
.dest-btn {
flex: 1;
padding: 14px;
border: none;
border-radius: 10px;
font-size: 14px;
font-weight: 600;
cursor: pointer;
transition: all 0.2s;
}
.dest-btn.active-klausur { background: #2563eb; color: white; box-shadow: 0 4px 15px rgba(37, 99, 235, 0.3); }
.dest-btn.active-rag { background: #7c3aed; color: white; box-shadow: 0 4px 15px rgba(124, 58, 237, 0.3); }
.dest-btn:not(.active-klausur):not(.active-rag) { background: #1e293b; color: #94a3b8; }
.upload-zone {
border: 2px dashed #475569;
border-radius: 16px;
padding: 40px 20px;
text-align: center;
margin-bottom: 24px;
transition: all 0.2s;
position: relative;
}
.upload-zone.dragover { border-color: #60a5fa; background: rgba(96, 165, 250, 0.1); transform: scale(1.02); }
.upload-zone input[type="file"] {
position: absolute;
inset: 0;
opacity: 0;
cursor: pointer;
}
.upload-icon {
width: 64px;
height: 64px;
background: #334155;
border-radius: 50%;
display: flex;
align-items: center;
justify-content: center;
margin: 0 auto 16px;
font-size: 28px;
}
.upload-title { font-size: 18px; margin-bottom: 8px; }
.upload-subtitle { font-size: 14px; color: #94a3b8; margin-bottom: 16px; }
.upload-hint { font-size: 12px; color: #64748b; }
.file-list { margin-bottom: 24px; }
.file-item {
background: #1e293b;
border-radius: 12px;
padding: 16px;
margin-bottom: 12px;
}
.file-item.error { border: 2px solid rgba(239, 68, 68, 0.5); }
.file-item.complete { border: 2px solid rgba(34, 197, 94, 0.3); }
.file-header { display: flex; justify-content: space-between; align-items: flex-start; margin-bottom: 8px; }
.file-name { font-weight: 500; word-break: break-all; }
.file-size { font-size: 14px; color: #94a3b8; }
.remove-btn { background: none; border: none; color: #94a3b8; font-size: 20px; cursor: pointer; padding: 4px; }
.progress-bar { height: 6px; background: #334155; border-radius: 3px; overflow: hidden; margin-top: 12px; }
.progress-fill { height: 100%; background: linear-gradient(90deg, #3b82f6, #60a5fa); transition: width 0.3s; }
.progress-text { font-size: 12px; color: #94a3b8; margin-top: 4px; }
.status-complete { display: flex; align-items: center; gap: 8px; color: #22c55e; font-size: 14px; margin-top: 12px; }
.status-error { display: flex; align-items: center; gap: 8px; color: #ef4444; font-size: 14px; margin-top: 12px; }
.info-box {
background: rgba(30, 41, 59, 0.5);
border-radius: 12px;
padding: 16px;
font-size: 14px;
color: #94a3b8;
}
.info-box h3 { color: #cbd5e1; margin-bottom: 8px; font-size: 14px; }
.info-box ul { padding-left: 20px; }
.info-box li { margin-bottom: 4px; }
.server-info { text-align: center; font-size: 12px; color: #64748b; margin-top: 16px; }
.stats { display: flex; justify-content: space-between; font-size: 14px; color: #94a3b8; padding: 0 8px; margin-bottom: 12px; }
</style>
</head>
<body>
<header class="header">
<h1>BreakPilot Upload</h1>
<span class="badge">DSGVO-konform</span>
</header>
<div class="destination-selector">
<button class="dest-btn active-klausur" id="btn-klausur" onclick="setDestination('klausur')">Klausuren</button>
<button class="dest-btn" id="btn-rag" onclick="setDestination('rag')">Erwartungshorizonte</button>
</div>
<div class="upload-zone" id="upload-zone">
<input type="file" accept=".pdf" multiple onchange="handleFiles(this.files)">
<div class="upload-icon">&#x2601;</div>
<div class="upload-title">PDF-Dateien hochladen</div>
<div class="upload-subtitle">Tippen zum Auswaehlen oder hierher ziehen</div>
<div class="upload-hint">Grosse Dateien bis 200 MB werden automatisch in Teilen hochgeladen</div>
</div>
<div class="stats" id="stats" style="display: none;">
<span id="completed-count">0 von 0 fertig</span>
<span id="total-size">0 B gesamt</span>
</div>
<div class="file-list" id="file-list"></div>
<div class="info-box">
<h3>Hinweise:</h3>
<ul>
<li>Die Dateien werden lokal im WLAN uebertragen</li>
<li>Keine Daten werden ins Internet gesendet</li>
<li>Unterstuetzte Formate: PDF</li>
</ul>
</div>
<div class="server-info" id="server-info">Server: wird ermittelt...</div>
<script>
const CHUNK_SIZE = 5 * 1024 * 1024;
let destination = 'klausur';
let files = [];
const serverUrl = window.location.origin;
document.getElementById('server-info').textContent = 'Server: ' + serverUrl;
function setDestination(dest) {
destination = dest;
document.querySelectorAll('.dest-btn').forEach(btn => {
btn.classList.remove('active-klausur', 'active-rag');
});
if (dest === 'klausur') {
document.getElementById('btn-klausur').classList.add('active-klausur');
} else {
document.getElementById('btn-rag').classList.add('active-rag');
}
}
function formatSize(bytes) {
if (bytes === 0) return '0 B';
const k = 1024;
const sizes = ['B', 'KB', 'MB', 'GB'];
const i = Math.floor(Math.log(bytes) / Math.log(k));
return parseFloat((bytes / Math.pow(k, i)).toFixed(1)) + ' ' + sizes[i];
}
function updateStats() {
const completed = files.filter(f => f.status === 'complete').length;
const total = files.reduce((sum, f) => sum + f.size, 0);
document.getElementById('completed-count').textContent = completed + ' von ' + files.length + ' fertig';
document.getElementById('total-size').textContent = formatSize(total) + ' gesamt';
document.getElementById('stats').style.display = files.length > 0 ? 'flex' : 'none';
}
function renderFiles() {
const list = document.getElementById('file-list');
list.innerHTML = files.map(f => {
let statusHtml = '';
if (f.status === 'uploading' || f.status === 'pending') {
statusHtml = '<div class="progress-bar"><div class="progress-fill" style="width: ' + f.progress + '%"></div></div><div class="progress-text">' + f.progress + '% hochgeladen</div>';
} else if (f.status === 'complete') {
statusHtml = '<div class="status-complete">&#x2713; Erfolgreich hochgeladen</div>';
} else if (f.status === 'error') {
statusHtml = '<div class="status-error">&#x26A0; ' + (f.error || 'Fehler beim Hochladen') + '</div>';
}
return '<div class="file-item ' + f.status + '"><div class="file-header"><div><div class="file-name">' + f.name + '</div><div class="file-size">' + formatSize(f.size) + '</div></div><button class="remove-btn" onclick="removeFile(\\'' + f.id + '\\')">&times;</button></div>' + statusHtml + '</div>';
}).join('');
updateStats();
}
function removeFile(id) {
files = files.filter(f => f.id !== id);
renderFiles();
}
async function uploadFile(file, fileId) {
const updateProgress = (progress) => {
const f = files.find(f => f.id === fileId);
if (f) { f.progress = progress; renderFiles(); }
};
const setStatus = (status, error) => {
const f = files.find(f => f.id === fileId);
if (f) { f.status = status; if (error) f.error = error; renderFiles(); }
};
try {
setStatus('uploading');
if (file.size > 10 * 1024 * 1024) {
// Chunked upload
const totalChunks = Math.ceil(file.size / CHUNK_SIZE);
const initRes = await fetch(serverUrl + '/api/v1/upload/init', {
method: 'POST',
headers: { 'Content-Type': 'application/json' },
body: JSON.stringify({ filename: file.name, filesize: file.size, chunks: totalChunks, destination: destination })
});
if (!initRes.ok) throw new Error('Konnte Upload nicht starten');
const { upload_id } = await initRes.json();
for (let i = 0; i < totalChunks; i++) {
const start = i * CHUNK_SIZE;
const end = Math.min(start + CHUNK_SIZE, file.size);
const chunk = file.slice(start, end);
const formData = new FormData();
formData.append('chunk', chunk);
formData.append('upload_id', upload_id);
formData.append('chunk_index', i.toString());
const chunkRes = await fetch(serverUrl + '/api/v1/upload/chunk', { method: 'POST', body: formData });
if (!chunkRes.ok) throw new Error('Fehler bei Teil ' + (i + 1));
updateProgress(Math.round(((i + 1) / totalChunks) * 100));
}
const finalizeForm = new FormData();
finalizeForm.append('upload_id', upload_id);
const finalRes = await fetch(serverUrl + '/api/v1/upload/finalize', { method: 'POST', body: finalizeForm });
if (!finalRes.ok) throw new Error('Fehler beim Abschliessen');
} else {
// Simple upload
const formData = new FormData();
formData.append('file', file);
formData.append('destination', destination);
const res = await fetch(serverUrl + '/api/v1/upload/simple', { method: 'POST', body: formData });
if (!res.ok) throw new Error('Upload fehlgeschlagen');
updateProgress(100);
}
setStatus('complete');
} catch (e) {
setStatus('error', e.message);
}
}
function handleFiles(fileList) {
const newFiles = Array.from(fileList).filter(f => f.type === 'application/pdf');
newFiles.forEach(file => {
const id = Math.random().toString(36).substr(2, 9);
files.push({ id, name: file.name, size: file.size, progress: 0, status: 'pending', file });
renderFiles();
uploadFile(file, id);
});
}
// Drag & Drop
const zone = document.getElementById('upload-zone');
zone.addEventListener('dragover', e => { e.preventDefault(); zone.classList.add('dragover'); });
zone.addEventListener('dragleave', e => { e.preventDefault(); zone.classList.remove('dragover'); });
zone.addEventListener('drop', e => { e.preventDefault(); zone.classList.remove('dragover'); handleFiles(e.dataTransfer.files); });
</script>
</body>
</html>'''
return HTMLResponse(content=html_content)

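A matching sketch for the small-file path (same hypothetical host; this mirrors what the page's JavaScript does for files at or under 10 MB):

import requests

with open("klausur.pdf", "rb") as fh:  # hypothetical file
    resp = requests.post(
        "http://localhost:8000/api/v1/upload/simple",
        files={"file": ("klausur.pdf", fh, "application/pdf")},
        data={"destination": "klausur"},
    )
print(resp.json()["checksum"])  # sha256 computed server-side while streaming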
View File

@@ -1,537 +1,19 @@
"""
Zeugnis Rights-Aware Crawler - API Endpoints
Zeugnis Rights-Aware Crawler — barrel re-export.
All implementation split into:
zeugnis_api_sources — sources, seed URLs, initialization
zeugnis_api_docs — documents, crawler, statistics, audit
FastAPI router for managing zeugnis sources, documents, and crawler operations.
"""
from datetime import datetime, timedelta
from typing import Optional, List
from fastapi import APIRouter, HTTPException, BackgroundTasks, Query
from pydantic import BaseModel
from fastapi import APIRouter
from zeugnis_models import (
ZeugnisSource, ZeugnisSourceCreate, ZeugnisSourceVerify,
SeedUrl, SeedUrlCreate,
ZeugnisDocument, ZeugnisStats,
CrawlerStatus, CrawlRequest, CrawlQueueItem,
UsageEvent, AuditExport,
LicenseType, CrawlStatus, DocType, EventType,
BUNDESLAENDER, TRAINING_PERMISSIONS,
generate_id, get_training_allowed, get_bundesland_name, get_license_for_bundesland,
)
from zeugnis_crawler import (
start_crawler, stop_crawler, get_crawler_status,
)
from metrics_db import (
get_zeugnis_sources, upsert_zeugnis_source,
get_zeugnis_documents, get_zeugnis_stats,
log_zeugnis_event, get_pool,
)
from zeugnis_api_sources import router as _sources_router # noqa: F401
from zeugnis_api_docs import router as _docs_router # noqa: F401
router = APIRouter(prefix="/api/v1/admin/zeugnis", tags=["Zeugnis Crawler"])
# =============================================================================
# Sources Endpoints
# =============================================================================
@router.get("/sources", response_model=List[dict])
async def list_sources():
"""Get all zeugnis sources (Bundesländer)."""
sources = await get_zeugnis_sources()
if not sources:
# Return default sources if none exist
return [
{
"id": None,
"bundesland": code,
"name": info["name"],
"base_url": None,
"license_type": str(get_license_for_bundesland(code).value),
"training_allowed": get_training_allowed(code),
"verified_by": None,
"verified_at": None,
"created_at": None,
"updated_at": None,
}
for code, info in BUNDESLAENDER.items()
]
return sources
@router.post("/sources", response_model=dict)
async def create_source(source: ZeugnisSourceCreate):
"""Create or update a zeugnis source."""
source_id = generate_id()
success = await upsert_zeugnis_source(
id=source_id,
bundesland=source.bundesland,
name=source.name,
license_type=source.license_type.value,
training_allowed=source.training_allowed,
base_url=source.base_url,
)
if not success:
raise HTTPException(status_code=500, detail="Failed to create source")
return {"id": source_id, "success": True}
@router.put("/sources/{source_id}/verify", response_model=dict)
async def verify_source(source_id: str, verification: ZeugnisSourceVerify):
"""Verify a source's license status."""
pool = await get_pool()
if not pool:
raise HTTPException(status_code=503, detail="Database not available")
try:
async with pool.acquire() as conn:
await conn.execute(
"""
UPDATE zeugnis_sources
SET license_type = $2,
training_allowed = $3,
verified_by = $4,
verified_at = NOW(),
updated_at = NOW()
WHERE id = $1
""",
source_id, verification.license_type.value,
verification.training_allowed, verification.verified_by
)
return {"success": True, "source_id": source_id}
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@router.get("/sources/{bundesland}", response_model=dict)
async def get_source_by_bundesland(bundesland: str):
"""Get source details for a specific Bundesland."""
pool = await get_pool()
if not pool:
# Return default info
if bundesland not in BUNDESLAENDER:
raise HTTPException(status_code=404, detail=f"Bundesland not found: {bundesland}")
return {
"bundesland": bundesland,
"name": get_bundesland_name(bundesland),
"training_allowed": get_training_allowed(bundesland),
"license_type": get_license_for_bundesland(bundesland).value,
"document_count": 0,
}
try:
async with pool.acquire() as conn:
source = await conn.fetchrow(
"SELECT * FROM zeugnis_sources WHERE bundesland = $1",
bundesland
)
if source:
doc_count = await conn.fetchval(
"""
SELECT COUNT(*) FROM zeugnis_documents d
JOIN zeugnis_seed_urls u ON d.seed_url_id = u.id
WHERE u.source_id = $1
""",
source["id"]
)
return {**dict(source), "document_count": doc_count or 0}
# Return default
return {
"bundesland": bundesland,
"name": get_bundesland_name(bundesland),
"training_allowed": get_training_allowed(bundesland),
"license_type": get_license_for_bundesland(bundesland).value,
"document_count": 0,
}
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
# =============================================================================
# Seed URLs Endpoints
# =============================================================================
@router.get("/sources/{source_id}/urls", response_model=List[dict])
async def list_seed_urls(source_id: str):
"""Get all seed URLs for a source."""
pool = await get_pool()
if not pool:
return []
try:
async with pool.acquire() as conn:
rows = await conn.fetch(
"SELECT * FROM zeugnis_seed_urls WHERE source_id = $1 ORDER BY created_at",
source_id
)
return [dict(r) for r in rows]
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@router.post("/sources/{source_id}/urls", response_model=dict)
async def add_seed_url(source_id: str, seed_url: SeedUrlCreate):
"""Add a new seed URL to a source."""
pool = await get_pool()
if not pool:
raise HTTPException(status_code=503, detail="Database not available")
url_id = generate_id()
try:
async with pool.acquire() as conn:
await conn.execute(
"""
INSERT INTO zeugnis_seed_urls (id, source_id, url, doc_type, status)
VALUES ($1, $2, $3, $4, 'pending')
""",
url_id, source_id, seed_url.url, seed_url.doc_type.value
)
return {"id": url_id, "success": True}
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@router.delete("/urls/{url_id}", response_model=dict)
async def delete_seed_url(url_id: str):
"""Delete a seed URL."""
pool = await get_pool()
if not pool:
raise HTTPException(status_code=503, detail="Database not available")
try:
async with pool.acquire() as conn:
await conn.execute(
"DELETE FROM zeugnis_seed_urls WHERE id = $1",
url_id
)
return {"success": True}
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
# =============================================================================
# Documents Endpoints
# =============================================================================
@router.get("/documents", response_model=List[dict])
async def list_documents(
bundesland: Optional[str] = None,
limit: int = Query(100, le=500),
offset: int = 0,
):
"""Get all zeugnis documents with optional filtering."""
documents = await get_zeugnis_documents(bundesland=bundesland, limit=limit, offset=offset)
return documents
@router.get("/documents/{document_id}", response_model=dict)
async def get_document(document_id: str):
"""Get details for a specific document."""
pool = await get_pool()
if not pool:
raise HTTPException(status_code=503, detail="Database not available")
try:
async with pool.acquire() as conn:
doc = await conn.fetchrow(
"""
SELECT d.*, s.bundesland, s.name as source_name
FROM zeugnis_documents d
JOIN zeugnis_seed_urls u ON d.seed_url_id = u.id
JOIN zeugnis_sources s ON u.source_id = s.id
WHERE d.id = $1
""",
document_id
)
if not doc:
raise HTTPException(status_code=404, detail="Document not found")
# Log view event
await log_zeugnis_event(document_id, EventType.VIEWED.value)
return dict(doc)
except HTTPException:
raise
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@router.get("/documents/{document_id}/versions", response_model=List[dict])
async def get_document_versions(document_id: str):
"""Get version history for a document."""
pool = await get_pool()
if not pool:
raise HTTPException(status_code=503, detail="Database not available")
try:
async with pool.acquire() as conn:
rows = await conn.fetch(
"""
SELECT * FROM zeugnis_document_versions
WHERE document_id = $1
ORDER BY version DESC
""",
document_id
)
return [dict(r) for r in rows]
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
# =============================================================================
# Crawler Control Endpoints
# =============================================================================
@router.get("/crawler/status", response_model=dict)
async def crawler_status():
"""Get current crawler status."""
return get_crawler_status()
@router.post("/crawler/start", response_model=dict)
async def start_crawl(request: CrawlRequest, background_tasks: BackgroundTasks):
"""Start the crawler."""
success = await start_crawler(
bundesland=request.bundesland,
source_id=request.source_id,
)
if not success:
raise HTTPException(status_code=409, detail="Crawler already running")
return {"success": True, "message": "Crawler started"}
@router.post("/crawler/stop", response_model=dict)
async def stop_crawl():
"""Stop the crawler."""
success = await stop_crawler()
if not success:
raise HTTPException(status_code=409, detail="Crawler not running")
return {"success": True, "message": "Crawler stopped"}
@router.get("/crawler/queue", response_model=List[dict])
async def get_queue():
"""Get the crawler queue."""
pool = await get_pool()
if not pool:
return []
try:
async with pool.acquire() as conn:
rows = await conn.fetch(
"""
SELECT q.*, s.bundesland, s.name as source_name
FROM zeugnis_crawler_queue q
JOIN zeugnis_sources s ON q.source_id = s.id
ORDER BY q.priority DESC, q.created_at
"""
)
return [dict(r) for r in rows]
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@router.post("/crawler/queue", response_model=dict)
async def add_to_queue(request: CrawlRequest):
"""Add a source to the crawler queue."""
pool = await get_pool()
if not pool:
raise HTTPException(status_code=503, detail="Database not available")
queue_id = generate_id()
try:
async with pool.acquire() as conn:
# Get source ID if bundesland provided
source_id = request.source_id
if not source_id and request.bundesland:
source = await conn.fetchrow(
"SELECT id FROM zeugnis_sources WHERE bundesland = $1",
request.bundesland
)
if source:
source_id = source["id"]
if not source_id:
raise HTTPException(status_code=400, detail="Source not found")
await conn.execute(
"""
INSERT INTO zeugnis_crawler_queue (id, source_id, priority, status)
VALUES ($1, $2, $3, 'pending')
""",
queue_id, source_id, request.priority
)
return {"id": queue_id, "success": True}
except HTTPException:
raise
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
# =============================================================================
# Statistics Endpoints
# =============================================================================
@router.get("/stats", response_model=dict)
async def get_stats():
"""Get zeugnis crawler statistics."""
stats = await get_zeugnis_stats()
return stats
@router.get("/stats/bundesland", response_model=List[dict])
async def get_bundesland_stats():
"""Get statistics per Bundesland."""
pool = await get_pool()
# Build stats from BUNDESLAENDER with DB data if available
stats = []
for code, info in BUNDESLAENDER.items():
stat = {
"bundesland": code,
"name": info["name"],
"training_allowed": get_training_allowed(code),
"document_count": 0,
"indexed_count": 0,
"last_crawled": None,
}
if pool:
try:
async with pool.acquire() as conn:
row = await conn.fetchrow(
"""
SELECT
COUNT(d.id) as doc_count,
COUNT(CASE WHEN d.indexed_in_qdrant THEN 1 END) as indexed_count,
MAX(u.last_crawled) as last_crawled
FROM zeugnis_sources s
LEFT JOIN zeugnis_seed_urls u ON s.id = u.source_id
LEFT JOIN zeugnis_documents d ON u.id = d.seed_url_id
WHERE s.bundesland = $1
GROUP BY s.id
""",
code
)
if row:
stat["document_count"] = row["doc_count"] or 0
stat["indexed_count"] = row["indexed_count"] or 0
stat["last_crawled"] = row["last_crawled"].isoformat() if row["last_crawled"] else None
except Exception:
pass
stats.append(stat)
return stats
# =============================================================================
# Audit Endpoints
# =============================================================================
@router.get("/audit/events", response_model=List[dict])
async def get_audit_events(
document_id: Optional[str] = None,
event_type: Optional[str] = None,
limit: int = Query(100, le=1000),
days: int = Query(30, le=365),
):
"""Get audit events with optional filtering."""
pool = await get_pool()
if not pool:
return []
try:
since = datetime.now() - timedelta(days=days)
async with pool.acquire() as conn:
query = """
SELECT * FROM zeugnis_usage_events
WHERE created_at >= $1
"""
params = [since]
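            # Placeholders are positional ($1, $2, ...): each optional filter derives
            # its index from len(params), so the filters compose in any combination.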
if document_id:
query += " AND document_id = $2"
params.append(document_id)
if event_type:
query += f" AND event_type = ${len(params) + 1}"
params.append(event_type)
query += f" ORDER BY created_at DESC LIMIT ${len(params) + 1}"
params.append(limit)
rows = await conn.fetch(query, *params)
return [dict(r) for r in rows]
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@router.get("/audit/export", response_model=dict)
async def export_audit(
days: int = Query(30, le=365),
requested_by: str = Query(..., description="User requesting the export"),
):
"""Export audit data for GDPR compliance."""
pool = await get_pool()
if not pool:
raise HTTPException(status_code=503, detail="Database not available")
try:
since = datetime.now() - timedelta(days=days)
async with pool.acquire() as conn:
rows = await conn.fetch(
"""
SELECT * FROM zeugnis_usage_events
WHERE created_at >= $1
ORDER BY created_at DESC
""",
since
)
doc_count = await conn.fetchval(
"SELECT COUNT(DISTINCT document_id) FROM zeugnis_usage_events WHERE created_at >= $1",
since
)
return {
"export_date": datetime.now().isoformat(),
"requested_by": requested_by,
"events": [dict(r) for r in rows],
"document_count": doc_count or 0,
"date_range_start": since.isoformat(),
"date_range_end": datetime.now().isoformat(),
}
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
# =============================================================================
# Initialization Endpoint
# =============================================================================
@router.post("/init", response_model=dict)
async def initialize_sources():
"""Initialize default sources from BUNDESLAENDER."""
pool = await get_pool()
if not pool:
raise HTTPException(status_code=503, detail="Database not available")
created = 0
try:
for code, info in BUNDESLAENDER.items():
source_id = generate_id()
success = await upsert_zeugnis_source(
id=source_id,
bundesland=code,
name=info["name"],
license_type=get_license_for_bundesland(code).value,
training_allowed=get_training_allowed(code),
)
if success:
created += 1
return {"success": True, "sources_created": created}
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
# Composite router (used by main.py)
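# Both sub-routers declare the same "/api/v1/admin/zeugnis" prefix, so the public
# paths are unchanged by the split.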
router = APIRouter()
router.include_router(_sources_router)
router.include_router(_docs_router)


@@ -0,0 +1,321 @@
"""
Zeugnis API Docs — documents, crawler control, statistics, audit endpoints.
Extracted from zeugnis_api.py for modularity.
"""
from datetime import datetime, timedelta
from typing import Optional, List
from fastapi import APIRouter, HTTPException, BackgroundTasks, Query
from zeugnis_models import (
CrawlRequest, EventType,
BUNDESLAENDER,
    generate_id, get_training_allowed,
)
from zeugnis_crawler import (
start_crawler, stop_crawler, get_crawler_status,
)
from metrics_db import (
get_zeugnis_documents, get_zeugnis_stats,
log_zeugnis_event, get_pool,
)
router = APIRouter(prefix="/api/v1/admin/zeugnis", tags=["Zeugnis Crawler"])
# =============================================================================
# Documents Endpoints
# =============================================================================
@router.get("/documents", response_model=List[dict])
async def list_documents(
bundesland: Optional[str] = None,
limit: int = Query(100, le=500),
offset: int = 0,
):
"""Get all zeugnis documents with optional filtering."""
documents = await get_zeugnis_documents(bundesland=bundesland, limit=limit, offset=offset)
return documents
@router.get("/documents/{document_id}", response_model=dict)
async def get_document(document_id: str):
"""Get details for a specific document."""
pool = await get_pool()
if not pool:
raise HTTPException(status_code=503, detail="Database not available")
try:
async with pool.acquire() as conn:
doc = await conn.fetchrow(
"""
SELECT d.*, s.bundesland, s.name as source_name
FROM zeugnis_documents d
JOIN zeugnis_seed_urls u ON d.seed_url_id = u.id
JOIN zeugnis_sources s ON u.source_id = s.id
WHERE d.id = $1
""",
document_id
)
if not doc:
raise HTTPException(status_code=404, detail="Document not found")
# Log view event
await log_zeugnis_event(document_id, EventType.VIEWED.value)
return dict(doc)
except HTTPException:
raise
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@router.get("/documents/{document_id}/versions", response_model=List[dict])
async def get_document_versions(document_id: str):
"""Get version history for a document."""
pool = await get_pool()
if not pool:
raise HTTPException(status_code=503, detail="Database not available")
try:
async with pool.acquire() as conn:
rows = await conn.fetch(
"""
SELECT * FROM zeugnis_document_versions
WHERE document_id = $1
ORDER BY version DESC
""",
document_id
)
return [dict(r) for r in rows]
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
# =============================================================================
# Crawler Control Endpoints
# =============================================================================
@router.get("/crawler/status", response_model=dict)
async def crawler_status():
"""Get current crawler status."""
return get_crawler_status()
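# Start/stop are guarded: starting while already running (or stopping while idle)
# returns 409 Conflict rather than silently succeeding.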
@router.post("/crawler/start", response_model=dict)
async def start_crawl(request: CrawlRequest, background_tasks: BackgroundTasks):
"""Start the crawler."""
success = await start_crawler(
bundesland=request.bundesland,
source_id=request.source_id,
)
if not success:
raise HTTPException(status_code=409, detail="Crawler already running")
return {"success": True, "message": "Crawler started"}
@router.post("/crawler/stop", response_model=dict)
async def stop_crawl():
"""Stop the crawler."""
success = await stop_crawler()
if not success:
raise HTTPException(status_code=409, detail="Crawler not running")
return {"success": True, "message": "Crawler stopped"}
@router.get("/crawler/queue", response_model=List[dict])
async def get_queue():
"""Get the crawler queue."""
pool = await get_pool()
if not pool:
return []
try:
async with pool.acquire() as conn:
rows = await conn.fetch(
"""
SELECT q.*, s.bundesland, s.name as source_name
FROM zeugnis_crawler_queue q
JOIN zeugnis_sources s ON q.source_id = s.id
ORDER BY q.priority DESC, q.created_at
"""
)
return [dict(r) for r in rows]
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@router.post("/crawler/queue", response_model=dict)
async def add_to_queue(request: CrawlRequest):
"""Add a source to the crawler queue."""
pool = await get_pool()
if not pool:
raise HTTPException(status_code=503, detail="Database not available")
queue_id = generate_id()
try:
async with pool.acquire() as conn:
# Get source ID if bundesland provided
source_id = request.source_id
if not source_id and request.bundesland:
source = await conn.fetchrow(
"SELECT id FROM zeugnis_sources WHERE bundesland = $1",
request.bundesland
)
if source:
source_id = source["id"]
if not source_id:
raise HTTPException(status_code=400, detail="Source not found")
await conn.execute(
"""
INSERT INTO zeugnis_crawler_queue (id, source_id, priority, status)
VALUES ($1, $2, $3, 'pending')
""",
queue_id, source_id, request.priority
)
return {"id": queue_id, "success": True}
except HTTPException:
raise
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
# =============================================================================
# Statistics Endpoints
# =============================================================================
@router.get("/stats", response_model=dict)
async def get_stats():
"""Get zeugnis crawler statistics."""
stats = await get_zeugnis_stats()
return stats
@router.get("/stats/bundesland", response_model=List[dict])
async def get_bundesland_stats():
"""Get statistics per Bundesland."""
pool = await get_pool()
# Build stats from BUNDESLAENDER with DB data if available
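    # One aggregate query per Bundesland is acceptable for 16 entries; switch to a
    # single GROUP BY over all sources if this ever needs to scale.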
stats = []
for code, info in BUNDESLAENDER.items():
stat = {
"bundesland": code,
"name": info["name"],
"training_allowed": get_training_allowed(code),
"document_count": 0,
"indexed_count": 0,
"last_crawled": None,
}
if pool:
try:
async with pool.acquire() as conn:
row = await conn.fetchrow(
"""
SELECT
COUNT(d.id) as doc_count,
COUNT(CASE WHEN d.indexed_in_qdrant THEN 1 END) as indexed_count,
MAX(u.last_crawled) as last_crawled
FROM zeugnis_sources s
LEFT JOIN zeugnis_seed_urls u ON s.id = u.source_id
LEFT JOIN zeugnis_documents d ON u.id = d.seed_url_id
WHERE s.bundesland = $1
GROUP BY s.id
""",
code
)
if row:
stat["document_count"] = row["doc_count"] or 0
stat["indexed_count"] = row["indexed_count"] or 0
stat["last_crawled"] = row["last_crawled"].isoformat() if row["last_crawled"] else None
except Exception:
pass
stats.append(stat)
return stats
# =============================================================================
# Audit Endpoints
# =============================================================================
@router.get("/audit/events", response_model=List[dict])
async def get_audit_events(
document_id: Optional[str] = None,
event_type: Optional[str] = None,
limit: int = Query(100, le=1000),
days: int = Query(30, le=365),
):
"""Get audit events with optional filtering."""
pool = await get_pool()
if not pool:
return []
try:
since = datetime.now() - timedelta(days=days)
async with pool.acquire() as conn:
query = """
SELECT * FROM zeugnis_usage_events
WHERE created_at >= $1
"""
params = [since]
if document_id:
query += " AND document_id = $2"
params.append(document_id)
if event_type:
query += f" AND event_type = ${len(params) + 1}"
params.append(event_type)
query += f" ORDER BY created_at DESC LIMIT ${len(params) + 1}"
params.append(limit)
rows = await conn.fetch(query, *params)
return [dict(r) for r in rows]
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@router.get("/audit/export", response_model=dict)
async def export_audit(
days: int = Query(30, le=365),
requested_by: str = Query(..., description="User requesting the export"),
):
"""Export audit data for GDPR compliance."""
pool = await get_pool()
if not pool:
raise HTTPException(status_code=503, detail="Database not available")
try:
since = datetime.now() - timedelta(days=days)
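        # datetime.now() is naive local time; if created_at is stored as
        # timestamptz, consider datetime.now(timezone.utc) to avoid offset skew.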
async with pool.acquire() as conn:
rows = await conn.fetch(
"""
SELECT * FROM zeugnis_usage_events
WHERE created_at >= $1
ORDER BY created_at DESC
""",
since
)
doc_count = await conn.fetchval(
"SELECT COUNT(DISTINCT document_id) FROM zeugnis_usage_events WHERE created_at >= $1",
since
)
return {
"export_date": datetime.now().isoformat(),
"requested_by": requested_by,
"events": [dict(r) for r in rows],
"document_count": doc_count or 0,
"date_range_start": since.isoformat(),
"date_range_end": datetime.now().isoformat(),
}
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))


@@ -0,0 +1,232 @@
"""
Zeugnis API Sources — source and seed URL management endpoints.
Extracted from zeugnis_api.py for modularity.
"""
from typing import List
from fastapi import APIRouter, HTTPException
from zeugnis_models import (
ZeugnisSourceCreate, ZeugnisSourceVerify,
SeedUrlCreate,
BUNDESLAENDER,
generate_id, get_training_allowed, get_bundesland_name, get_license_for_bundesland,
)
from metrics_db import (
get_zeugnis_sources, upsert_zeugnis_source, get_pool,
)
router = APIRouter(prefix="/api/v1/admin/zeugnis", tags=["Zeugnis Crawler"])
# =============================================================================
# Sources Endpoints
# =============================================================================
@router.get("/sources", response_model=List[dict])
async def list_sources():
"""Get all zeugnis sources (Bundeslaender)."""
sources = await get_zeugnis_sources()
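    # Fall back to the static BUNDESLAENDER table when the DB has no rows yet,
    # so the admin UI can always render the full list of Länder.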
if not sources:
# Return default sources if none exist
return [
{
"id": None,
"bundesland": code,
"name": info["name"],
"base_url": None,
"license_type": str(get_license_for_bundesland(code).value),
"training_allowed": get_training_allowed(code),
"verified_by": None,
"verified_at": None,
"created_at": None,
"updated_at": None,
}
for code, info in BUNDESLAENDER.items()
]
return sources
@router.post("/sources", response_model=dict)
async def create_source(source: ZeugnisSourceCreate):
"""Create or update a zeugnis source."""
source_id = generate_id()
success = await upsert_zeugnis_source(
id=source_id,
bundesland=source.bundesland,
name=source.name,
license_type=source.license_type.value,
training_allowed=source.training_allowed,
base_url=source.base_url,
)
if not success:
raise HTTPException(status_code=500, detail="Failed to create source")
return {"id": source_id, "success": True}
@router.put("/sources/{source_id}/verify", response_model=dict)
async def verify_source(source_id: str, verification: ZeugnisSourceVerify):
"""Verify a source's license status."""
pool = await get_pool()
if not pool:
raise HTTPException(status_code=503, detail="Database not available")
try:
async with pool.acquire() as conn:
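            # verified_at/updated_at are stamped with NOW() server-side, so audit
            # timestamps come from the database clock, not the app host.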
await conn.execute(
"""
UPDATE zeugnis_sources
SET license_type = $2,
training_allowed = $3,
verified_by = $4,
verified_at = NOW(),
updated_at = NOW()
WHERE id = $1
""",
source_id, verification.license_type.value,
verification.training_allowed, verification.verified_by
)
return {"success": True, "source_id": source_id}
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@router.get("/sources/{bundesland}", response_model=dict)
async def get_source_by_bundesland(bundesland: str):
"""Get source details for a specific Bundesland."""
pool = await get_pool()
if not pool:
# Return default info
if bundesland not in BUNDESLAENDER:
raise HTTPException(status_code=404, detail=f"Bundesland not found: {bundesland}")
return {
"bundesland": bundesland,
"name": get_bundesland_name(bundesland),
"training_allowed": get_training_allowed(bundesland),
"license_type": get_license_for_bundesland(bundesland).value,
"document_count": 0,
}
try:
async with pool.acquire() as conn:
source = await conn.fetchrow(
"SELECT * FROM zeugnis_sources WHERE bundesland = $1",
bundesland
)
if source:
doc_count = await conn.fetchval(
"""
SELECT COUNT(*) FROM zeugnis_documents d
JOIN zeugnis_seed_urls u ON d.seed_url_id = u.id
WHERE u.source_id = $1
""",
source["id"]
)
return {**dict(source), "document_count": doc_count or 0}
# Return default
return {
"bundesland": bundesland,
"name": get_bundesland_name(bundesland),
"training_allowed": get_training_allowed(bundesland),
"license_type": get_license_for_bundesland(bundesland).value,
"document_count": 0,
}
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
# =============================================================================
# Seed URLs Endpoints
# =============================================================================
@router.get("/sources/{source_id}/urls", response_model=List[dict])
async def list_seed_urls(source_id: str):
"""Get all seed URLs for a source."""
pool = await get_pool()
if not pool:
return []
try:
async with pool.acquire() as conn:
rows = await conn.fetch(
"SELECT * FROM zeugnis_seed_urls WHERE source_id = $1 ORDER BY created_at",
source_id
)
return [dict(r) for r in rows]
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@router.post("/sources/{source_id}/urls", response_model=dict)
async def add_seed_url(source_id: str, seed_url: SeedUrlCreate):
"""Add a new seed URL to a source."""
pool = await get_pool()
if not pool:
raise HTTPException(status_code=503, detail="Database not available")
url_id = generate_id()
try:
async with pool.acquire() as conn:
await conn.execute(
"""
INSERT INTO zeugnis_seed_urls (id, source_id, url, doc_type, status)
VALUES ($1, $2, $3, $4, 'pending')
""",
url_id, source_id, seed_url.url, seed_url.doc_type.value
)
return {"id": url_id, "success": True}
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@router.delete("/urls/{url_id}", response_model=dict)
async def delete_seed_url(url_id: str):
"""Delete a seed URL."""
pool = await get_pool()
if not pool:
raise HTTPException(status_code=503, detail="Database not available")
try:
async with pool.acquire() as conn:
await conn.execute(
"DELETE FROM zeugnis_seed_urls WHERE id = $1",
url_id
)
return {"success": True}
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
# =============================================================================
# Initialization Endpoint
# =============================================================================
@router.post("/init", response_model=dict)
async def initialize_sources():
"""Initialize default sources from BUNDESLAENDER."""
pool = await get_pool()
if not pool:
raise HTTPException(status_code=503, detail="Database not available")
created = 0
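    # Assumed idempotent: upsert_zeugnis_source is expected to update an existing
    # row for the same Bundesland rather than insert a duplicate on re-runs.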
try:
for code, info in BUNDESLAENDER.items():
source_id = generate_id()
success = await upsert_zeugnis_source(
id=source_id,
bundesland=code,
name=info["name"],
license_type=get_license_for_bundesland(code).value,
training_allowed=get_training_allowed(code),
)
if success:
created += 1
return {"success": True, "sources_created": created}
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))