[split-required] Split final 43 files (500-668 LOC) to complete refactoring
klausur-service (11 files):
- cv_gutter_repair, ocr_pipeline_regression, upload_api
- ocr_pipeline_sessions, smart_spell, nru_worksheet_generator
- ocr_pipeline_overlays, mail/aggregator, zeugnis_api
- cv_syllable_detect, self_rag

backend-lehrer (17 files):
- classroom_engine/suggestions, generators/quiz_generator
- worksheets_api, llm_gateway/comparison, state_engine_api
- classroom/models (→ 4 submodules), services/file_processor
- alerts_agent/api/wizard+digests+routes, content_generators/pdf
- classroom/routes/sessions, llm_gateway/inference
- classroom_engine/analytics, auth/keycloak_auth
- alerts_agent/processing/rule_engine, ai_processor/print_versions

agent-core (5 files):
- brain/memory_store, brain/knowledge_graph, brain/context_manager
- orchestrator/supervisor, sessions/session_manager

admin-lehrer (5 components):
- GridOverlay, StepGridReview, DevOpsPipelineSidebar
- DataFlowDiagram, sbom/wizard/page

website (2 files):
- DependencyMap, lehrer/abitur-archiv

Other: nibis_ingestion, grid_detection_service, export-doclayout-onnx

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
klausur-service/backend/cv_gutter_repair.py
@@ -1,610 +1,35 @@
"""
Gutter Repair — detects and fixes words truncated or blurred at the book gutter.
Gutter Repair — barrel re-export.

When scanning double-page spreads, the binding area (gutter) causes:
1. Blurry/garbled trailing characters ("stammeli" → "stammeln")
2. Words split across lines with a hyphen lost in the gutter
   ("ve" + "künden" → "verkünden")

This module analyses grid cells, identifies gutter-edge candidates, and
proposes corrections using pyspellchecker (DE + EN).
All implementation split into:
    cv_gutter_repair_core — spellchecker setup, data types, single-word repair
    cv_gutter_repair_grid — grid analysis, suggestion application

Lizenz: Apache 2.0 (kommerziell nutzbar)
DATENSCHUTZ: Alle Verarbeitung erfolgt lokal.
"""
|
||||
|
||||
import itertools
|
||||
import logging
|
||||
import re
|
||||
import time
|
||||
import uuid
|
||||
from dataclasses import dataclass, field, asdict
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Spellchecker setup (lazy, cached)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_spell_de = None
|
||||
_spell_en = None
|
||||
_SPELL_AVAILABLE = False
|
||||
|
||||
def _init_spellcheckers():
|
||||
"""Lazy-load DE + EN spellcheckers (cached across calls)."""
|
||||
global _spell_de, _spell_en, _SPELL_AVAILABLE
|
||||
if _spell_de is not None:
|
||||
return
|
||||
try:
|
||||
from spellchecker import SpellChecker
|
||||
_spell_de = SpellChecker(language='de', distance=1)
|
||||
_spell_en = SpellChecker(language='en', distance=1)
|
||||
_SPELL_AVAILABLE = True
|
||||
logger.info("Gutter repair: spellcheckers loaded (DE + EN)")
|
||||
except ImportError:
|
||||
logger.warning("pyspellchecker not installed — gutter repair unavailable")
|
||||
|
||||
|
||||
def _is_known(word: str) -> bool:
|
||||
"""Check if a word is known in DE or EN dictionary."""
|
||||
_init_spellcheckers()
|
||||
if not _SPELL_AVAILABLE:
|
||||
return False
|
||||
w = word.lower()
|
||||
return bool(_spell_de.known([w])) or bool(_spell_en.known([w]))
|
||||
|
||||
|
||||
def _spell_candidates(word: str, lang: str = "both") -> List[str]:
|
||||
"""Get all plausible spellchecker candidates for a word (deduplicated)."""
|
||||
_init_spellcheckers()
|
||||
if not _SPELL_AVAILABLE:
|
||||
return []
|
||||
w = word.lower()
|
||||
seen: set = set()
|
||||
results: List[str] = []
|
||||
|
||||
for checker in ([_spell_de, _spell_en] if lang == "both"
|
||||
else [_spell_de] if lang == "de"
|
||||
else [_spell_en]):
|
||||
if checker is None:
|
||||
continue
|
||||
cands = checker.candidates(w)
|
||||
if cands:
|
||||
for c in cands:
|
||||
if c and c != w and c not in seen:
|
||||
seen.add(c)
|
||||
results.append(c)
|
||||
|
||||
return results
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Gutter position detection
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
# Minimum word length for spell-fix (very short words are often legitimate)
|
||||
_MIN_WORD_LEN_SPELL = 3
|
||||
|
||||
# Minimum word length for hyphen-join candidates (fragments at the gutter
|
||||
# can be as short as 1-2 chars, e.g. "ve" from "ver-künden")
|
||||
_MIN_WORD_LEN_HYPHEN = 2
|
||||
|
||||
# How close to the right column edge a word must be to count as "gutter-adjacent".
|
||||
# Expressed as fraction of column width (e.g. 0.75 = rightmost 25%).
|
||||
_GUTTER_EDGE_THRESHOLD = 0.70
|
||||
|
||||
# Small common words / abbreviations that should NOT be repaired
|
||||
_STOPWORDS = frozenset([
|
||||
# German
|
||||
"ab", "an", "am", "da", "er", "es", "im", "in", "ja", "ob", "so", "um",
|
||||
"zu", "wo", "du", "eh", "ei", "je", "na", "nu", "oh",
|
||||
# English
|
||||
"a", "am", "an", "as", "at", "be", "by", "do", "go", "he", "if", "in",
|
||||
"is", "it", "me", "my", "no", "of", "on", "or", "so", "to", "up", "us",
|
||||
"we",
|
||||
])
|
||||
|
||||
# IPA / phonetic patterns — skip these cells
|
||||
_IPA_RE = re.compile(r'[\[\]/ˈˌːʃʒθðŋɑɒæɔəɛɪʊʌ]')
|
||||
|
||||
|
||||
def _is_ipa_text(text: str) -> bool:
|
||||
"""True if text looks like IPA transcription."""
|
||||
return bool(_IPA_RE.search(text))
|
||||
|
||||
|
||||
def _word_is_at_gutter_edge(word_bbox: Dict, col_x: float, col_width: float) -> bool:
|
||||
"""Check if a word's right edge is near the right boundary of its column."""
|
||||
if col_width <= 0:
|
||||
return False
|
||||
word_right = word_bbox.get("left", 0) + word_bbox.get("width", 0)
|
||||
col_right = col_x + col_width
|
||||
# Word's right edge within the rightmost portion of the column
|
||||
relative_pos = (word_right - col_x) / col_width
|
||||
return relative_pos >= _GUTTER_EDGE_THRESHOLD
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Suggestion types
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@dataclass
|
||||
class GutterSuggestion:
|
||||
"""A single correction suggestion."""
|
||||
id: str = field(default_factory=lambda: str(uuid.uuid4())[:8])
|
||||
type: str = "" # "hyphen_join" | "spell_fix"
|
||||
zone_index: int = 0
|
||||
row_index: int = 0
|
||||
col_index: int = 0
|
||||
col_type: str = ""
|
||||
cell_id: str = ""
|
||||
original_text: str = ""
|
||||
suggested_text: str = ""
|
||||
# For hyphen_join:
|
||||
next_row_index: int = -1
|
||||
next_row_cell_id: str = ""
|
||||
next_row_text: str = ""
|
||||
missing_chars: str = ""
|
||||
display_parts: List[str] = field(default_factory=list)
|
||||
# Alternatives (other plausible corrections the user can pick from)
|
||||
alternatives: List[str] = field(default_factory=list)
|
||||
# Meta:
|
||||
confidence: float = 0.0
|
||||
reason: str = "" # "gutter_truncation" | "gutter_blur" | "hyphen_continuation"
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
return asdict(self)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Core repair logic
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_TRAILING_PUNCT_RE = re.compile(r'[.,;:!?\)\]]+$')
|
||||
|
||||
|
||||
def _try_hyphen_join(
|
||||
word_text: str,
|
||||
next_word_text: str,
|
||||
max_missing: int = 3,
|
||||
) -> Optional[Tuple[str, str, float]]:
|
||||
"""Try joining two fragments with 0..max_missing interpolated chars.
|
||||
|
||||
Strips trailing punctuation from the continuation word before testing
|
||||
(e.g. "künden," → "künden") so dictionary lookup succeeds.
|
||||
|
||||
Returns (joined_word, missing_chars, confidence) or None.
|
||||
"""
|
||||
base = word_text.rstrip("-").rstrip()
|
||||
# Strip trailing punctuation from continuation (commas, periods, etc.)
|
||||
raw_continuation = next_word_text.lstrip()
|
||||
continuation = _TRAILING_PUNCT_RE.sub('', raw_continuation)
|
||||
|
||||
if not base or not continuation:
|
||||
return None
|
||||
|
||||
# 1. Direct join (no missing chars)
|
||||
direct = base + continuation
|
||||
if _is_known(direct):
|
||||
return (direct, "", 0.95)
|
||||
|
||||
# 2. Try with 1..max_missing missing characters
|
||||
# Use common letters, weighted by frequency in German/English
|
||||
_COMMON_CHARS = "enristaldhgcmobwfkzpvjyxqu"
|
||||
|
||||
for n_missing in range(1, max_missing + 1):
|
||||
for chars in itertools.product(_COMMON_CHARS[:15], repeat=n_missing):
|
||||
candidate = base + "".join(chars) + continuation
|
||||
if _is_known(candidate):
|
||||
missing = "".join(chars)
|
||||
# Confidence decreases with more missing chars
|
||||
conf = 0.90 - (n_missing - 1) * 0.10
|
||||
return (candidate, missing, conf)
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def _try_spell_fix(
|
||||
word_text: str, col_type: str = "",
|
||||
) -> Optional[Tuple[str, float, List[str]]]:
|
||||
"""Try to fix a single garbled gutter word via spellchecker.
|
||||
|
||||
Returns (best_correction, confidence, alternatives_list) or None.
|
||||
The alternatives list contains other plausible corrections the user
|
||||
can choose from (e.g. "stammelt" vs "stammeln").
|
||||
"""
|
||||
if len(word_text) < _MIN_WORD_LEN_SPELL:
|
||||
return None
|
||||
|
||||
# Strip trailing/leading parentheses and check if the bare word is valid.
|
||||
# Words like "probieren)" or "(Englisch" are valid words with punctuation,
|
||||
# not OCR errors. Don't suggest corrections for them.
|
||||
stripped = word_text.strip("()")
|
||||
if stripped and _is_known(stripped):
|
||||
return None
|
||||
|
||||
# Determine language priority from column type
|
||||
if "en" in col_type:
|
||||
lang = "en"
|
||||
elif "de" in col_type:
|
||||
lang = "de"
|
||||
else:
|
||||
lang = "both"
|
||||
|
||||
candidates = _spell_candidates(word_text, lang=lang)
|
||||
if not candidates and lang != "both":
|
||||
candidates = _spell_candidates(word_text, lang="both")
|
||||
|
||||
if not candidates:
|
||||
return None
|
||||
|
||||
# Preserve original casing
|
||||
is_upper = word_text[0].isupper()
|
||||
|
||||
def _preserve_case(w: str) -> str:
|
||||
if is_upper and w:
|
||||
return w[0].upper() + w[1:]
|
||||
return w
|
||||
|
||||
# Sort candidates by edit distance (closest first)
|
||||
scored = []
|
||||
for c in candidates:
|
||||
dist = _edit_distance(word_text.lower(), c.lower())
|
||||
scored.append((dist, c))
|
||||
scored.sort(key=lambda x: x[0])
|
||||
|
||||
best_dist, best = scored[0]
|
||||
best = _preserve_case(best)
|
||||
conf = max(0.5, 1.0 - best_dist * 0.15)
|
||||
|
||||
# Build alternatives (all other candidates, also case-preserved)
|
||||
alts = [_preserve_case(c) for _, c in scored[1:] if c.lower() != best.lower()]
|
||||
# Limit to top 5 alternatives
|
||||
alts = alts[:5]
|
||||
|
||||
return (best, conf, alts)
|
||||
|
||||
|
||||
def _edit_distance(a: str, b: str) -> int:
|
||||
"""Simple Levenshtein distance."""
|
||||
if len(a) < len(b):
|
||||
return _edit_distance(b, a)
|
||||
if len(b) == 0:
|
||||
return len(a)
|
||||
prev = list(range(len(b) + 1))
|
||||
for i, ca in enumerate(a):
|
||||
curr = [i + 1]
|
||||
for j, cb in enumerate(b):
|
||||
cost = 0 if ca == cb else 1
|
||||
curr.append(min(curr[j] + 1, prev[j + 1] + 1, prev[j] + cost))
|
||||
prev = curr
|
||||
return prev[len(b)]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Grid analysis
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def analyse_grid_for_gutter_repair(
|
||||
grid_data: Dict[str, Any],
|
||||
image_width: int = 0,
|
||||
) -> Dict[str, Any]:
|
||||
"""Analyse a structured grid and return gutter repair suggestions.
|
||||
|
||||
Args:
|
||||
grid_data: The grid_editor_result from the session (zones→cells structure).
|
||||
image_width: Image width in pixels (for determining gutter side).
|
||||
|
||||
Returns:
|
||||
Dict with "suggestions" list and "stats".
|
||||
"""
|
||||
t0 = time.time()
|
||||
_init_spellcheckers()
|
||||
|
||||
if not _SPELL_AVAILABLE:
|
||||
return {
|
||||
"suggestions": [],
|
||||
"stats": {"error": "pyspellchecker not installed"},
|
||||
"duration_seconds": 0,
|
||||
}
|
||||
|
||||
zones = grid_data.get("zones", [])
|
||||
suggestions: List[GutterSuggestion] = []
|
||||
words_checked = 0
|
||||
gutter_candidates = 0
|
||||
|
||||
for zi, zone in enumerate(zones):
|
||||
columns = zone.get("columns", [])
|
||||
cells = zone.get("cells", [])
|
||||
if not columns or not cells:
|
||||
continue
|
||||
|
||||
# Build column lookup: col_index → {x, width, type}
|
||||
col_info: Dict[int, Dict] = {}
|
||||
for col in columns:
|
||||
ci = col.get("index", col.get("col_index", -1))
|
||||
col_info[ci] = {
|
||||
"x": col.get("x_min_px", col.get("x", 0)),
|
||||
"width": col.get("x_max_px", col.get("width", 0)) - col.get("x_min_px", col.get("x", 0)),
|
||||
"type": col.get("type", col.get("col_type", "")),
|
||||
}
|
||||
|
||||
# Build row→col→cell lookup
|
||||
cell_map: Dict[Tuple[int, int], Dict] = {}
|
||||
max_row = 0
|
||||
for cell in cells:
|
||||
ri = cell.get("row_index", 0)
|
||||
ci = cell.get("col_index", 0)
|
||||
cell_map[(ri, ci)] = cell
|
||||
if ri > max_row:
|
||||
max_row = ri
|
||||
|
||||
# Determine which columns are at the gutter edge.
|
||||
# For a left page: rightmost content columns.
|
||||
# For now, check ALL columns — a word is a candidate if it's at the
|
||||
# right edge of its column AND not a known word.
|
||||
for (ri, ci), cell in cell_map.items():
|
||||
text = (cell.get("text") or "").strip()
|
||||
if not text:
|
||||
continue
|
||||
if _is_ipa_text(text):
|
||||
continue
|
||||
|
||||
words_checked += 1
|
||||
col = col_info.get(ci, {})
|
||||
col_type = col.get("type", "")
|
||||
|
||||
# Get word boxes to check position
|
||||
word_boxes = cell.get("word_boxes", [])
|
||||
|
||||
# Check the LAST word in the cell (rightmost, closest to gutter)
|
||||
cell_words = text.split()
|
||||
if not cell_words:
|
||||
continue
|
||||
|
||||
last_word = cell_words[-1]
|
||||
|
||||
# Skip stopwords
|
||||
if last_word.lower().rstrip(".,;:!?-") in _STOPWORDS:
|
||||
continue
|
||||
|
||||
last_word_clean = last_word.rstrip(".,;:!?)(")
|
||||
if len(last_word_clean) < _MIN_WORD_LEN_HYPHEN:
|
||||
continue
|
||||
|
||||
# Check if the last word is at the gutter edge
|
||||
is_at_edge = False
|
||||
if word_boxes:
|
||||
last_wb = word_boxes[-1]
|
||||
is_at_edge = _word_is_at_gutter_edge(
|
||||
last_wb, col.get("x", 0), col.get("width", 1)
|
||||
)
|
||||
else:
|
||||
# No word boxes — use cell bbox
|
||||
bbox = cell.get("bbox_px", {})
|
||||
is_at_edge = _word_is_at_gutter_edge(
|
||||
{"left": bbox.get("x", 0), "width": bbox.get("w", 0)},
|
||||
col.get("x", 0), col.get("width", 1)
|
||||
)
|
||||
|
||||
if not is_at_edge:
|
||||
continue
|
||||
|
||||
# Word is at gutter edge — check if it's a known word
|
||||
if _is_known(last_word_clean):
|
||||
continue
|
||||
|
||||
# Check if the word ends with "-" (explicit hyphen break)
|
||||
ends_with_hyphen = last_word.endswith("-")
|
||||
|
||||
# If the word already ends with "-" and the stem (without
|
||||
# the hyphen) is a known word, this is a VALID line-break
|
||||
# hyphenation — not a gutter error. Gutter problems cause
|
||||
# the hyphen to be LOST ("ve" instead of "ver-"), so a
|
||||
# visible hyphen + known stem = intentional word-wrap.
|
||||
# Example: "wunder-" → "wunder" is known → skip.
|
||||
if ends_with_hyphen:
|
||||
stem = last_word_clean.rstrip("-")
|
||||
if stem and _is_known(stem):
|
||||
continue
|
||||
|
||||
gutter_candidates += 1
|
||||
|
||||
# --- Strategy 1: Hyphen join with next row ---
|
||||
next_cell = cell_map.get((ri + 1, ci))
|
||||
if next_cell:
|
||||
next_text = (next_cell.get("text") or "").strip()
|
||||
next_words = next_text.split()
|
||||
if next_words:
|
||||
first_next = next_words[0]
|
||||
first_next_clean = _TRAILING_PUNCT_RE.sub('', first_next)
|
||||
first_alpha = next((c for c in first_next if c.isalpha()), "")
|
||||
|
||||
# Also skip if the joined word is known (covers compound
|
||||
# words where the stem alone might not be in the dictionary)
|
||||
if ends_with_hyphen and first_next_clean:
|
||||
direct = last_word_clean.rstrip("-") + first_next_clean
|
||||
if _is_known(direct):
|
||||
continue
|
||||
|
||||
# Continuation likely if:
|
||||
# - explicit hyphen, OR
|
||||
# - next row starts lowercase (= not a new entry)
|
||||
if ends_with_hyphen or (first_alpha and first_alpha.islower()):
|
||||
result = _try_hyphen_join(last_word_clean, first_next)
|
||||
if result:
|
||||
joined, missing, conf = result
|
||||
# Build display parts: show hyphenation for original layout
|
||||
if ends_with_hyphen:
|
||||
display_p1 = last_word_clean.rstrip("-")
|
||||
if missing:
|
||||
display_p1 += missing
|
||||
display_p1 += "-"
|
||||
else:
|
||||
display_p1 = last_word_clean
|
||||
if missing:
|
||||
display_p1 += missing + "-"
|
||||
else:
|
||||
display_p1 += "-"
|
||||
|
||||
suggestion = GutterSuggestion(
|
||||
type="hyphen_join",
|
||||
zone_index=zi,
|
||||
row_index=ri,
|
||||
col_index=ci,
|
||||
col_type=col_type,
|
||||
cell_id=cell.get("cell_id", f"R{ri:02d}_C{ci}"),
|
||||
original_text=last_word,
|
||||
suggested_text=joined,
|
||||
next_row_index=ri + 1,
|
||||
next_row_cell_id=next_cell.get("cell_id", f"R{ri+1:02d}_C{ci}"),
|
||||
next_row_text=next_text,
|
||||
missing_chars=missing,
|
||||
display_parts=[display_p1, first_next],
|
||||
confidence=conf,
|
||||
reason="gutter_truncation" if missing else "hyphen_continuation",
|
||||
)
|
||||
suggestions.append(suggestion)
|
||||
continue # skip spell_fix if hyphen_join found
|
||||
|
||||
# --- Strategy 2: Single-word spell fix (only for longer words) ---
|
||||
fix_result = _try_spell_fix(last_word_clean, col_type)
|
||||
if fix_result:
|
||||
corrected, conf, alts = fix_result
|
||||
suggestion = GutterSuggestion(
|
||||
type="spell_fix",
|
||||
zone_index=zi,
|
||||
row_index=ri,
|
||||
col_index=ci,
|
||||
col_type=col_type,
|
||||
cell_id=cell.get("cell_id", f"R{ri:02d}_C{ci}"),
|
||||
original_text=last_word,
|
||||
suggested_text=corrected,
|
||||
alternatives=alts,
|
||||
confidence=conf,
|
||||
reason="gutter_blur",
|
||||
)
|
||||
suggestions.append(suggestion)
|
||||
|
||||
duration = round(time.time() - t0, 3)
|
||||
|
||||
logger.info(
|
||||
"Gutter repair: checked %d words, %d gutter candidates, %d suggestions (%.2fs)",
|
||||
words_checked, gutter_candidates, len(suggestions), duration,
|
||||
)
|
||||
|
||||
return {
|
||||
"suggestions": [s.to_dict() for s in suggestions],
|
||||
"stats": {
|
||||
"words_checked": words_checked,
|
||||
"gutter_candidates": gutter_candidates,
|
||||
"suggestions_found": len(suggestions),
|
||||
},
|
||||
"duration_seconds": duration,
|
||||
}
|
||||
|
||||
|
||||
def apply_gutter_suggestions(
|
||||
grid_data: Dict[str, Any],
|
||||
accepted_ids: List[str],
|
||||
suggestions: List[Dict[str, Any]],
|
||||
) -> Dict[str, Any]:
|
||||
"""Apply accepted gutter repair suggestions to the grid data.
|
||||
|
||||
Modifies cells in-place and returns summary of changes.
|
||||
|
||||
Args:
|
||||
grid_data: The grid_editor_result (zones→cells).
|
||||
accepted_ids: List of suggestion IDs the user accepted.
|
||||
suggestions: The full suggestions list (from analyse_grid_for_gutter_repair).
|
||||
|
||||
Returns:
|
||||
Dict with "applied_count" and "changes" list.
|
||||
"""
|
||||
accepted_set = set(accepted_ids)
|
||||
accepted_suggestions = [s for s in suggestions if s.get("id") in accepted_set]
|
||||
|
||||
zones = grid_data.get("zones", [])
|
||||
changes: List[Dict[str, Any]] = []
|
||||
|
||||
for s in accepted_suggestions:
|
||||
zi = s.get("zone_index", 0)
|
||||
ri = s.get("row_index", 0)
|
||||
ci = s.get("col_index", 0)
|
||||
stype = s.get("type", "")
|
||||
|
||||
if zi >= len(zones):
|
||||
continue
|
||||
zone_cells = zones[zi].get("cells", [])
|
||||
|
||||
# Find the target cell
|
||||
target_cell = None
|
||||
for cell in zone_cells:
|
||||
if cell.get("row_index") == ri and cell.get("col_index") == ci:
|
||||
target_cell = cell
|
||||
break
|
||||
|
||||
if not target_cell:
|
||||
continue
|
||||
|
||||
old_text = target_cell.get("text", "")
|
||||
|
||||
if stype == "spell_fix":
|
||||
# Replace the last word in the cell text
|
||||
original_word = s.get("original_text", "")
|
||||
corrected = s.get("suggested_text", "")
|
||||
if original_word and corrected:
|
||||
# Replace from the right (last occurrence)
|
||||
idx = old_text.rfind(original_word)
|
||||
if idx >= 0:
|
||||
new_text = old_text[:idx] + corrected + old_text[idx + len(original_word):]
|
||||
target_cell["text"] = new_text
|
||||
changes.append({
|
||||
"type": "spell_fix",
|
||||
"zone_index": zi,
|
||||
"row_index": ri,
|
||||
"col_index": ci,
|
||||
"cell_id": target_cell.get("cell_id", ""),
|
||||
"old_text": old_text,
|
||||
"new_text": new_text,
|
||||
})
|
||||
|
||||
elif stype == "hyphen_join":
|
||||
# Current cell: replace last word with the hyphenated first part
|
||||
original_word = s.get("original_text", "")
|
||||
joined = s.get("suggested_text", "")
|
||||
display_parts = s.get("display_parts", [])
|
||||
next_ri = s.get("next_row_index", -1)
|
||||
|
||||
if not original_word or not joined or not display_parts:
|
||||
continue
|
||||
|
||||
# The first display part is what goes in the current row
|
||||
first_part = display_parts[0] if display_parts else ""
|
||||
|
||||
# Replace the last word in current cell with the restored form.
|
||||
# The next row is NOT modified — "künden" stays in its row
|
||||
# because the original book layout has it there. We only fix
|
||||
# the truncated word in the current row (e.g. "ve" → "ver-").
|
||||
idx = old_text.rfind(original_word)
|
||||
if idx >= 0:
|
||||
new_text = old_text[:idx] + first_part + old_text[idx + len(original_word):]
|
||||
target_cell["text"] = new_text
|
||||
changes.append({
|
||||
"type": "hyphen_join",
|
||||
"zone_index": zi,
|
||||
"row_index": ri,
|
||||
"col_index": ci,
|
||||
"cell_id": target_cell.get("cell_id", ""),
|
||||
"old_text": old_text,
|
||||
"new_text": new_text,
|
||||
"joined_word": joined,
|
||||
})
|
||||
|
||||
logger.info("Gutter repair applied: %d/%d suggestions", len(changes), len(accepted_suggestions))
|
||||
|
||||
return {
|
||||
"applied_count": len(accepted_suggestions),
|
||||
"changes": changes,
|
||||
}
|
||||
# Core: spellchecker, data types, repair helpers
from cv_gutter_repair_core import (  # noqa: F401
    _init_spellcheckers,
    _is_known,
    _spell_candidates,
    _MIN_WORD_LEN_SPELL,
    _MIN_WORD_LEN_HYPHEN,
    _GUTTER_EDGE_THRESHOLD,
    _STOPWORDS,
    _IPA_RE,
    _is_ipa_text,
    _word_is_at_gutter_edge,
    GutterSuggestion,
    _TRAILING_PUNCT_RE,
    _try_hyphen_join,
    _try_spell_fix,
    _edit_distance,
)

# Grid: analysis and application
from cv_gutter_repair_grid import (  # noqa: F401
    analyse_grid_for_gutter_repair,
    apply_gutter_suggestions,
)
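# Usage sketch (illustrative only — "session" is a placeholder for wherever the
# grid_editor_result is stored; the names below are exactly the re-exports above):
#
#   from cv_gutter_repair import analyse_grid_for_gutter_repair, apply_gutter_suggestions
#
#   analysis = analyse_grid_for_gutter_repair(session["grid_editor_result"])
#   accepted = [s["id"] for s in analysis["suggestions"] if s["confidence"] >= 0.9]
#   apply_gutter_suggestions(session["grid_editor_result"], accepted, analysis["suggestions"])
#
# Existing callers that import from cv_gutter_repair keep working unchanged.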

klausur-service/backend/cv_gutter_repair_core.py (new file, 275 lines)
@@ -0,0 +1,275 @@
"""
Gutter Repair Core — spellchecker setup, data types, and single-word repair logic.

Extracted from cv_gutter_repair.py for modularity.

Lizenz: Apache 2.0 (kommerziell nutzbar)
DATENSCHUTZ: Alle Verarbeitung erfolgt lokal.
"""
|
||||
|
||||
import itertools
|
||||
import logging
|
||||
import re
|
||||
import uuid
|
||||
from dataclasses import dataclass, field, asdict
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Spellchecker setup (lazy, cached)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_spell_de = None
|
||||
_spell_en = None
|
||||
_SPELL_AVAILABLE = False
|
||||
|
||||
def _init_spellcheckers():
|
||||
"""Lazy-load DE + EN spellcheckers (cached across calls)."""
|
||||
global _spell_de, _spell_en, _SPELL_AVAILABLE
|
||||
if _spell_de is not None:
|
||||
return
|
||||
try:
|
||||
from spellchecker import SpellChecker
|
||||
_spell_de = SpellChecker(language='de', distance=1)
|
||||
_spell_en = SpellChecker(language='en', distance=1)
|
||||
_SPELL_AVAILABLE = True
|
||||
logger.info("Gutter repair: spellcheckers loaded (DE + EN)")
|
||||
except ImportError:
|
||||
logger.warning("pyspellchecker not installed — gutter repair unavailable")
|
||||
|
||||
|
||||
def _is_known(word: str) -> bool:
|
||||
"""Check if a word is known in DE or EN dictionary."""
|
||||
_init_spellcheckers()
|
||||
if not _SPELL_AVAILABLE:
|
||||
return False
|
||||
w = word.lower()
|
||||
return bool(_spell_de.known([w])) or bool(_spell_en.known([w]))
|
||||
|
||||
|
||||
def _spell_candidates(word: str, lang: str = "both") -> List[str]:
|
||||
"""Get all plausible spellchecker candidates for a word (deduplicated)."""
|
||||
_init_spellcheckers()
|
||||
if not _SPELL_AVAILABLE:
|
||||
return []
|
||||
w = word.lower()
|
||||
seen: set = set()
|
||||
results: List[str] = []
|
||||
|
||||
for checker in ([_spell_de, _spell_en] if lang == "both"
|
||||
else [_spell_de] if lang == "de"
|
||||
else [_spell_en]):
|
||||
if checker is None:
|
||||
continue
|
||||
cands = checker.candidates(w)
|
||||
if cands:
|
||||
for c in cands:
|
||||
if c and c != w and c not in seen:
|
||||
seen.add(c)
|
||||
results.append(c)
|
||||
|
||||
return results
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Gutter position detection
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
# Minimum word length for spell-fix (very short words are often legitimate)
|
||||
_MIN_WORD_LEN_SPELL = 3
|
||||
|
||||
# Minimum word length for hyphen-join candidates (fragments at the gutter
|
||||
# can be as short as 1-2 chars, e.g. "ve" from "ver-künden")
|
||||
_MIN_WORD_LEN_HYPHEN = 2
|
||||
|
||||
# How close to the right column edge a word must be to count as "gutter-adjacent".
|
||||
# Expressed as fraction of column width (e.g. 0.75 = rightmost 25%).
|
||||
_GUTTER_EDGE_THRESHOLD = 0.70
|
||||
|
||||
# Small common words / abbreviations that should NOT be repaired
|
||||
_STOPWORDS = frozenset([
|
||||
# German
|
||||
"ab", "an", "am", "da", "er", "es", "im", "in", "ja", "ob", "so", "um",
|
||||
"zu", "wo", "du", "eh", "ei", "je", "na", "nu", "oh",
|
||||
# English
|
||||
"a", "am", "an", "as", "at", "be", "by", "do", "go", "he", "if", "in",
|
||||
"is", "it", "me", "my", "no", "of", "on", "or", "so", "to", "up", "us",
|
||||
"we",
|
||||
])
|
||||
|
||||
# IPA / phonetic patterns — skip these cells
|
||||
_IPA_RE = re.compile(r'[\[\]/ˈˌːʃʒθðŋɑɒæɔəɛɪʊʌ]')
|
||||
|
||||
|
||||
def _is_ipa_text(text: str) -> bool:
|
||||
"""True if text looks like IPA transcription."""
|
||||
return bool(_IPA_RE.search(text))
|
||||
|
||||
|
||||
def _word_is_at_gutter_edge(word_bbox: Dict, col_x: float, col_width: float) -> bool:
|
||||
"""Check if a word's right edge is near the right boundary of its column."""
|
||||
if col_width <= 0:
|
||||
return False
|
||||
word_right = word_bbox.get("left", 0) + word_bbox.get("width", 0)
|
||||
col_right = col_x + col_width
|
||||
# Word's right edge within the rightmost portion of the column
|
||||
relative_pos = (word_right - col_x) / col_width
|
||||
return relative_pos >= _GUTTER_EDGE_THRESHOLD
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Suggestion types
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@dataclass
|
||||
class GutterSuggestion:
|
||||
"""A single correction suggestion."""
|
||||
id: str = field(default_factory=lambda: str(uuid.uuid4())[:8])
|
||||
type: str = "" # "hyphen_join" | "spell_fix"
|
||||
zone_index: int = 0
|
||||
row_index: int = 0
|
||||
col_index: int = 0
|
||||
col_type: str = ""
|
||||
cell_id: str = ""
|
||||
original_text: str = ""
|
||||
suggested_text: str = ""
|
||||
# For hyphen_join:
|
||||
next_row_index: int = -1
|
||||
next_row_cell_id: str = ""
|
||||
next_row_text: str = ""
|
||||
missing_chars: str = ""
|
||||
display_parts: List[str] = field(default_factory=list)
|
||||
# Alternatives (other plausible corrections the user can pick from)
|
||||
alternatives: List[str] = field(default_factory=list)
|
||||
# Meta:
|
||||
confidence: float = 0.0
|
||||
reason: str = "" # "gutter_truncation" | "gutter_blur" | "hyphen_continuation"
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
return asdict(self)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Core repair logic
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_TRAILING_PUNCT_RE = re.compile(r'[.,;:!?\)\]]+$')
|
||||
|
||||
|
||||
def _try_hyphen_join(
|
||||
word_text: str,
|
||||
next_word_text: str,
|
||||
max_missing: int = 3,
|
||||
) -> Optional[Tuple[str, str, float]]:
|
||||
"""Try joining two fragments with 0..max_missing interpolated chars.
|
||||
|
||||
Strips trailing punctuation from the continuation word before testing
|
||||
(e.g. "künden," → "künden") so dictionary lookup succeeds.
|
||||
|
||||
Returns (joined_word, missing_chars, confidence) or None.
|
||||
"""
|
||||
base = word_text.rstrip("-").rstrip()
|
||||
# Strip trailing punctuation from continuation (commas, periods, etc.)
|
||||
raw_continuation = next_word_text.lstrip()
|
||||
continuation = _TRAILING_PUNCT_RE.sub('', raw_continuation)
|
||||
|
||||
if not base or not continuation:
|
||||
return None
|
||||
|
||||
# 1. Direct join (no missing chars)
|
||||
direct = base + continuation
|
||||
if _is_known(direct):
|
||||
return (direct, "", 0.95)
|
||||
|
||||
# 2. Try with 1..max_missing missing characters
|
||||
# Use common letters, weighted by frequency in German/English
|
||||
_COMMON_CHARS = "enristaldhgcmobwfkzpvjyxqu"
|
||||
|
||||
for n_missing in range(1, max_missing + 1):
|
||||
for chars in itertools.product(_COMMON_CHARS[:15], repeat=n_missing):
|
||||
candidate = base + "".join(chars) + continuation
|
||||
if _is_known(candidate):
|
||||
missing = "".join(chars)
|
||||
# Confidence decreases with more missing chars
|
||||
conf = 0.90 - (n_missing - 1) * 0.10
|
||||
return (candidate, missing, conf)
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def _try_spell_fix(
|
||||
word_text: str, col_type: str = "",
|
||||
) -> Optional[Tuple[str, float, List[str]]]:
|
||||
"""Try to fix a single garbled gutter word via spellchecker.
|
||||
|
||||
Returns (best_correction, confidence, alternatives_list) or None.
|
||||
The alternatives list contains other plausible corrections the user
|
||||
can choose from (e.g. "stammelt" vs "stammeln").
|
||||
"""
|
||||
if len(word_text) < _MIN_WORD_LEN_SPELL:
|
||||
return None
|
||||
|
||||
# Strip trailing/leading parentheses and check if the bare word is valid.
|
||||
# Words like "probieren)" or "(Englisch" are valid words with punctuation,
|
||||
# not OCR errors. Don't suggest corrections for them.
|
||||
stripped = word_text.strip("()")
|
||||
if stripped and _is_known(stripped):
|
||||
return None
|
||||
|
||||
# Determine language priority from column type
|
||||
if "en" in col_type:
|
||||
lang = "en"
|
||||
elif "de" in col_type:
|
||||
lang = "de"
|
||||
else:
|
||||
lang = "both"
|
||||
|
||||
candidates = _spell_candidates(word_text, lang=lang)
|
||||
if not candidates and lang != "both":
|
||||
candidates = _spell_candidates(word_text, lang="both")
|
||||
|
||||
if not candidates:
|
||||
return None
|
||||
|
||||
# Preserve original casing
|
||||
is_upper = word_text[0].isupper()
|
||||
|
||||
def _preserve_case(w: str) -> str:
|
||||
if is_upper and w:
|
||||
return w[0].upper() + w[1:]
|
||||
return w
|
||||
|
||||
# Sort candidates by edit distance (closest first)
|
||||
scored = []
|
||||
for c in candidates:
|
||||
dist = _edit_distance(word_text.lower(), c.lower())
|
||||
scored.append((dist, c))
|
||||
scored.sort(key=lambda x: x[0])
|
||||
|
||||
best_dist, best = scored[0]
|
||||
best = _preserve_case(best)
|
||||
conf = max(0.5, 1.0 - best_dist * 0.15)
|
||||
|
||||
# Build alternatives (all other candidates, also case-preserved)
|
||||
alts = [_preserve_case(c) for _, c in scored[1:] if c.lower() != best.lower()]
|
||||
# Limit to top 5 alternatives
|
||||
alts = alts[:5]
|
||||
|
||||
return (best, conf, alts)
|
||||
|
||||
|
||||
def _edit_distance(a: str, b: str) -> int:
|
||||
"""Simple Levenshtein distance."""
|
||||
if len(a) < len(b):
|
||||
return _edit_distance(b, a)
|
||||
if len(b) == 0:
|
||||
return len(a)
|
||||
prev = list(range(len(b) + 1))
|
||||
for i, ca in enumerate(a):
|
||||
curr = [i + 1]
|
||||
for j, cb in enumerate(b):
|
||||
cost = 0 if ca == cb else 1
|
||||
curr.append(min(curr[j] + 1, prev[j + 1] + 1, prev[j] + cost))
|
||||
prev = curr
|
||||
return prev[len(b)]
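# Rough behaviour of the repair helpers (illustrative; actual results depend on
# the pyspellchecker word lists that are installed):
#
#   _edit_distance("stammeli", "stammeln")   # → 1
#   _try_spell_fix("stammeli")               # may yield ("stammeln", 0.85, [...alternatives])
#   _try_hyphen_join("ve", "künden,")        # may yield ("verkünden", "r", 0.90)
#   _try_hyphen_join("wunder-", "bar")       # direct join → ("wunderbar", "", 0.95) if known
#   _try_hyphen_join("xq", "zzz")            # no dictionary hit → None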

klausur-service/backend/cv_gutter_repair_grid.py (new file, 356 lines)
@@ -0,0 +1,356 @@
"""
Gutter Repair Grid — grid analysis and suggestion application.

Extracted from cv_gutter_repair.py for modularity.

Lizenz: Apache 2.0 (kommerziell nutzbar)
DATENSCHUTZ: Alle Verarbeitung erfolgt lokal.
"""

import logging
import time
from typing import Any, Dict, List, Tuple

# NOTE: _SPELL_AVAILABLE is rebound inside cv_gutter_repair_core by
# _init_spellcheckers(); a `from ... import _SPELL_AVAILABLE` would freeze the
# initial False value in this module, so the flag is read via the module object.
import cv_gutter_repair_core as _core
from cv_gutter_repair_core import (
    _init_spellcheckers,
    _is_ipa_text,
    _is_known,
    _MIN_WORD_LEN_HYPHEN,
    _STOPWORDS,
    _TRAILING_PUNCT_RE,
    _try_hyphen_join,
    _try_spell_fix,
    _word_is_at_gutter_edge,
    GutterSuggestion,
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Grid analysis
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def analyse_grid_for_gutter_repair(
|
||||
grid_data: Dict[str, Any],
|
||||
image_width: int = 0,
|
||||
) -> Dict[str, Any]:
|
||||
"""Analyse a structured grid and return gutter repair suggestions.
|
||||
|
||||
Args:
|
||||
grid_data: The grid_editor_result from the session (zones→cells structure).
|
||||
image_width: Image width in pixels (for determining gutter side).
|
||||
|
||||
Returns:
|
||||
Dict with "suggestions" list and "stats".
|
||||
"""
|
||||
t0 = time.time()
|
||||
_init_spellcheckers()
|
||||
|
||||
if not _core._SPELL_AVAILABLE:
|
||||
return {
|
||||
"suggestions": [],
|
||||
"stats": {"error": "pyspellchecker not installed"},
|
||||
"duration_seconds": 0,
|
||||
}
|
||||
|
||||
zones = grid_data.get("zones", [])
|
||||
suggestions: List[GutterSuggestion] = []
|
||||
words_checked = 0
|
||||
gutter_candidates = 0
|
||||
|
||||
for zi, zone in enumerate(zones):
|
||||
columns = zone.get("columns", [])
|
||||
cells = zone.get("cells", [])
|
||||
if not columns or not cells:
|
||||
continue
|
||||
|
||||
# Build column lookup: col_index → {x, width, type}
|
||||
col_info: Dict[int, Dict] = {}
|
||||
for col in columns:
|
||||
ci = col.get("index", col.get("col_index", -1))
|
||||
col_info[ci] = {
|
||||
"x": col.get("x_min_px", col.get("x", 0)),
|
||||
"width": col.get("x_max_px", col.get("width", 0)) - col.get("x_min_px", col.get("x", 0)),
|
||||
"type": col.get("type", col.get("col_type", "")),
|
||||
}
|
||||
|
||||
# Build row→col→cell lookup
|
||||
cell_map: Dict[Tuple[int, int], Dict] = {}
|
||||
max_row = 0
|
||||
for cell in cells:
|
||||
ri = cell.get("row_index", 0)
|
||||
ci = cell.get("col_index", 0)
|
||||
cell_map[(ri, ci)] = cell
|
||||
if ri > max_row:
|
||||
max_row = ri
|
||||
|
||||
# Determine which columns are at the gutter edge.
|
||||
# For a left page: rightmost content columns.
|
||||
# For now, check ALL columns — a word is a candidate if it's at the
|
||||
# right edge of its column AND not a known word.
|
||||
for (ri, ci), cell in cell_map.items():
|
||||
text = (cell.get("text") or "").strip()
|
||||
if not text:
|
||||
continue
|
||||
if _is_ipa_text(text):
|
||||
continue
|
||||
|
||||
words_checked += 1
|
||||
col = col_info.get(ci, {})
|
||||
col_type = col.get("type", "")
|
||||
|
||||
# Get word boxes to check position
|
||||
word_boxes = cell.get("word_boxes", [])
|
||||
|
||||
# Check the LAST word in the cell (rightmost, closest to gutter)
|
||||
cell_words = text.split()
|
||||
if not cell_words:
|
||||
continue
|
||||
|
||||
last_word = cell_words[-1]
|
||||
|
||||
# Skip stopwords
|
||||
if last_word.lower().rstrip(".,;:!?-") in _STOPWORDS:
|
||||
continue
|
||||
|
||||
last_word_clean = last_word.rstrip(".,;:!?)(")
|
||||
if len(last_word_clean) < _MIN_WORD_LEN_HYPHEN:
|
||||
continue
|
||||
|
||||
# Check if the last word is at the gutter edge
|
||||
is_at_edge = False
|
||||
if word_boxes:
|
||||
last_wb = word_boxes[-1]
|
||||
is_at_edge = _word_is_at_gutter_edge(
|
||||
last_wb, col.get("x", 0), col.get("width", 1)
|
||||
)
|
||||
else:
|
||||
# No word boxes — use cell bbox
|
||||
bbox = cell.get("bbox_px", {})
|
||||
is_at_edge = _word_is_at_gutter_edge(
|
||||
{"left": bbox.get("x", 0), "width": bbox.get("w", 0)},
|
||||
col.get("x", 0), col.get("width", 1)
|
||||
)
|
||||
|
||||
if not is_at_edge:
|
||||
continue
|
||||
|
||||
# Word is at gutter edge — check if it's a known word
|
||||
if _is_known(last_word_clean):
|
||||
continue
|
||||
|
||||
# Check if the word ends with "-" (explicit hyphen break)
|
||||
ends_with_hyphen = last_word.endswith("-")
|
||||
|
||||
# If the word already ends with "-" and the stem (without
|
||||
# the hyphen) is a known word, this is a VALID line-break
|
||||
# hyphenation — not a gutter error. Gutter problems cause
|
||||
# the hyphen to be LOST ("ve" instead of "ver-"), so a
|
||||
# visible hyphen + known stem = intentional word-wrap.
|
||||
# Example: "wunder-" → "wunder" is known → skip.
|
||||
if ends_with_hyphen:
|
||||
stem = last_word_clean.rstrip("-")
|
||||
if stem and _is_known(stem):
|
||||
continue
|
||||
|
||||
gutter_candidates += 1
|
||||
|
||||
# --- Strategy 1: Hyphen join with next row ---
|
||||
next_cell = cell_map.get((ri + 1, ci))
|
||||
if next_cell:
|
||||
next_text = (next_cell.get("text") or "").strip()
|
||||
next_words = next_text.split()
|
||||
if next_words:
|
||||
first_next = next_words[0]
|
||||
first_next_clean = _TRAILING_PUNCT_RE.sub('', first_next)
|
||||
first_alpha = next((c for c in first_next if c.isalpha()), "")
|
||||
|
||||
# Also skip if the joined word is known (covers compound
|
||||
# words where the stem alone might not be in the dictionary)
|
||||
if ends_with_hyphen and first_next_clean:
|
||||
direct = last_word_clean.rstrip("-") + first_next_clean
|
||||
if _is_known(direct):
|
||||
continue
|
||||
|
||||
# Continuation likely if:
|
||||
# - explicit hyphen, OR
|
||||
# - next row starts lowercase (= not a new entry)
|
||||
if ends_with_hyphen or (first_alpha and first_alpha.islower()):
|
||||
result = _try_hyphen_join(last_word_clean, first_next)
|
||||
if result:
|
||||
joined, missing, conf = result
|
||||
# Build display parts: show hyphenation for original layout
|
||||
if ends_with_hyphen:
|
||||
display_p1 = last_word_clean.rstrip("-")
|
||||
if missing:
|
||||
display_p1 += missing
|
||||
display_p1 += "-"
|
||||
else:
|
||||
display_p1 = last_word_clean
|
||||
if missing:
|
||||
display_p1 += missing + "-"
|
||||
else:
|
||||
display_p1 += "-"
|
||||
|
||||
suggestion = GutterSuggestion(
|
||||
type="hyphen_join",
|
||||
zone_index=zi,
|
||||
row_index=ri,
|
||||
col_index=ci,
|
||||
col_type=col_type,
|
||||
cell_id=cell.get("cell_id", f"R{ri:02d}_C{ci}"),
|
||||
original_text=last_word,
|
||||
suggested_text=joined,
|
||||
next_row_index=ri + 1,
|
||||
next_row_cell_id=next_cell.get("cell_id", f"R{ri+1:02d}_C{ci}"),
|
||||
next_row_text=next_text,
|
||||
missing_chars=missing,
|
||||
display_parts=[display_p1, first_next],
|
||||
confidence=conf,
|
||||
reason="gutter_truncation" if missing else "hyphen_continuation",
|
||||
)
|
||||
suggestions.append(suggestion)
|
||||
continue # skip spell_fix if hyphen_join found
|
||||
|
||||
# --- Strategy 2: Single-word spell fix (only for longer words) ---
|
||||
fix_result = _try_spell_fix(last_word_clean, col_type)
|
||||
if fix_result:
|
||||
corrected, conf, alts = fix_result
|
||||
suggestion = GutterSuggestion(
|
||||
type="spell_fix",
|
||||
zone_index=zi,
|
||||
row_index=ri,
|
||||
col_index=ci,
|
||||
col_type=col_type,
|
||||
cell_id=cell.get("cell_id", f"R{ri:02d}_C{ci}"),
|
||||
original_text=last_word,
|
||||
suggested_text=corrected,
|
||||
alternatives=alts,
|
||||
confidence=conf,
|
||||
reason="gutter_blur",
|
||||
)
|
||||
suggestions.append(suggestion)
|
||||
|
||||
duration = round(time.time() - t0, 3)
|
||||
|
||||
logger.info(
|
||||
"Gutter repair: checked %d words, %d gutter candidates, %d suggestions (%.2fs)",
|
||||
words_checked, gutter_candidates, len(suggestions), duration,
|
||||
)
|
||||
|
||||
return {
|
||||
"suggestions": [s.to_dict() for s in suggestions],
|
||||
"stats": {
|
||||
"words_checked": words_checked,
|
||||
"gutter_candidates": gutter_candidates,
|
||||
"suggestions_found": len(suggestions),
|
||||
},
|
||||
"duration_seconds": duration,
|
||||
}
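# Minimal shape of the grid_data this function reads (key names are taken from
# the lookups above; the concrete values are made up for illustration):
#
#   grid_data = {"zones": [{
#       "columns": [{"index": 0, "x_min_px": 40, "x_max_px": 320, "type": "column_de"}],
#       "cells": [
#           {"row_index": 0, "col_index": 0, "cell_id": "R00_C0", "text": "ve",
#            "bbox_px": {"x": 250, "w": 60}},
#           {"row_index": 1, "col_index": 0, "cell_id": "R01_C0", "text": "künden"},
#       ],
#   }]}
#   analyse_grid_for_gutter_repair(grid_data)
#   # → {"suggestions": [...], "stats": {...}, "duration_seconds": ...}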
|
||||
|
||||
|
||||
def apply_gutter_suggestions(
|
||||
grid_data: Dict[str, Any],
|
||||
accepted_ids: List[str],
|
||||
suggestions: List[Dict[str, Any]],
|
||||
) -> Dict[str, Any]:
|
||||
"""Apply accepted gutter repair suggestions to the grid data.
|
||||
|
||||
Modifies cells in-place and returns summary of changes.
|
||||
|
||||
Args:
|
||||
grid_data: The grid_editor_result (zones→cells).
|
||||
accepted_ids: List of suggestion IDs the user accepted.
|
||||
suggestions: The full suggestions list (from analyse_grid_for_gutter_repair).
|
||||
|
||||
Returns:
|
||||
Dict with "applied_count" and "changes" list.
|
||||
"""
|
||||
accepted_set = set(accepted_ids)
|
||||
accepted_suggestions = [s for s in suggestions if s.get("id") in accepted_set]
|
||||
|
||||
zones = grid_data.get("zones", [])
|
||||
changes: List[Dict[str, Any]] = []
|
||||
|
||||
for s in accepted_suggestions:
|
||||
zi = s.get("zone_index", 0)
|
||||
ri = s.get("row_index", 0)
|
||||
ci = s.get("col_index", 0)
|
||||
stype = s.get("type", "")
|
||||
|
||||
if zi >= len(zones):
|
||||
continue
|
||||
zone_cells = zones[zi].get("cells", [])
|
||||
|
||||
# Find the target cell
|
||||
target_cell = None
|
||||
for cell in zone_cells:
|
||||
if cell.get("row_index") == ri and cell.get("col_index") == ci:
|
||||
target_cell = cell
|
||||
break
|
||||
|
||||
if not target_cell:
|
||||
continue
|
||||
|
||||
old_text = target_cell.get("text", "")
|
||||
|
||||
if stype == "spell_fix":
|
||||
# Replace the last word in the cell text
|
||||
original_word = s.get("original_text", "")
|
||||
corrected = s.get("suggested_text", "")
|
||||
if original_word and corrected:
|
||||
# Replace from the right (last occurrence)
|
||||
idx = old_text.rfind(original_word)
|
||||
if idx >= 0:
|
||||
new_text = old_text[:idx] + corrected + old_text[idx + len(original_word):]
|
||||
target_cell["text"] = new_text
|
||||
changes.append({
|
||||
"type": "spell_fix",
|
||||
"zone_index": zi,
|
||||
"row_index": ri,
|
||||
"col_index": ci,
|
||||
"cell_id": target_cell.get("cell_id", ""),
|
||||
"old_text": old_text,
|
||||
"new_text": new_text,
|
||||
})
|
||||
|
||||
elif stype == "hyphen_join":
|
||||
# Current cell: replace last word with the hyphenated first part
|
||||
original_word = s.get("original_text", "")
|
||||
joined = s.get("suggested_text", "")
|
||||
display_parts = s.get("display_parts", [])
|
||||
next_ri = s.get("next_row_index", -1)
|
||||
|
||||
if not original_word or not joined or not display_parts:
|
||||
continue
|
||||
|
||||
# The first display part is what goes in the current row
|
||||
first_part = display_parts[0] if display_parts else ""
|
||||
|
||||
# Replace the last word in current cell with the restored form.
|
||||
# The next row is NOT modified — "künden" stays in its row
|
||||
# because the original book layout has it there. We only fix
|
||||
# the truncated word in the current row (e.g. "ve" → "ver-").
|
||||
idx = old_text.rfind(original_word)
|
||||
if idx >= 0:
|
||||
new_text = old_text[:idx] + first_part + old_text[idx + len(original_word):]
|
||||
target_cell["text"] = new_text
|
||||
changes.append({
|
||||
"type": "hyphen_join",
|
||||
"zone_index": zi,
|
||||
"row_index": ri,
|
||||
"col_index": ci,
|
||||
"cell_id": target_cell.get("cell_id", ""),
|
||||
"old_text": old_text,
|
||||
"new_text": new_text,
|
||||
"joined_word": joined,
|
||||
})
|
||||
|
||||
logger.info("Gutter repair applied: %d/%d suggestions", len(changes), len(accepted_suggestions))
|
||||
|
||||
return {
|
||||
"applied_count": len(accepted_suggestions),
|
||||
"changes": changes,
|
||||
}
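# Typical round trip (sketch — the suggestion IDs come from the analyse step):
#
#   analysis = analyse_grid_for_gutter_repair(grid_data)
#   accepted_ids = [s["id"] for s in analysis["suggestions"]]   # e.g. user accepts all
#   report = apply_gutter_suggestions(grid_data, accepted_ids, analysis["suggestions"])
#   # report["changes"] lists every cell whose "text" was rewritten in-place.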

klausur-service/backend/cv_syllable_core.py (new file, 231 lines)
@@ -0,0 +1,231 @@
"""
Syllable Core — hyphenator init, word validation, pipe autocorrect.

Extracted from cv_syllable_detect.py for modularity.

Lizenz: Apache 2.0 (kommerziell nutzbar)
DATENSCHUTZ: Alle Verarbeitung erfolgt lokal.
"""
|
||||
|
||||
import logging
|
||||
import re
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# IPA/phonetic characters -- skip cells containing these
|
||||
_IPA_RE = re.compile(r'[\[\]ˈˌːʃʒθðŋɑɒæɔəɛɜɪʊʌ]')
|
||||
|
||||
# Common German words that should NOT be merged with adjacent tokens.
|
||||
_STOP_WORDS = frozenset([
|
||||
# Articles
|
||||
'der', 'die', 'das', 'dem', 'den', 'des',
|
||||
'ein', 'eine', 'einem', 'einen', 'einer',
|
||||
# Pronouns
|
||||
'du', 'er', 'es', 'sie', 'wir', 'ihr', 'ich', 'man', 'sich',
|
||||
'dich', 'dir', 'mich', 'mir', 'uns', 'euch', 'ihm', 'ihn',
|
||||
# Prepositions
|
||||
'mit', 'von', 'zu', 'für', 'auf', 'in', 'an', 'um', 'am', 'im',
'aus', 'bei', 'nach', 'vor', 'bis', 'durch', 'über', 'unter',
|
||||
'zwischen', 'ohne', 'gegen',
|
||||
# Conjunctions
|
||||
'und', 'oder', 'als', 'wie', 'wenn', 'dass', 'weil', 'aber',
|
||||
# Adverbs
|
||||
'auch', 'noch', 'nur', 'schon', 'sehr', 'nicht',
|
||||
# Verbs
|
||||
'ist', 'hat', 'wird', 'kann', 'soll', 'muss', 'darf',
|
||||
'sein', 'haben',
|
||||
# Other
|
||||
'kein', 'keine', 'keinem', 'keinen', 'keiner',
|
||||
])
|
||||
|
||||
# Cached hyphenators
|
||||
_hyph_de = None
|
||||
_hyph_en = None
|
||||
|
||||
# Cached spellchecker (for autocorrect_pipe_artifacts)
|
||||
_spell_de = None
|
||||
|
||||
|
||||
def _get_hyphenators():
|
||||
"""Lazy-load pyphen hyphenators (cached across calls)."""
|
||||
global _hyph_de, _hyph_en
|
||||
if _hyph_de is not None:
|
||||
return _hyph_de, _hyph_en
|
||||
try:
|
||||
import pyphen
|
||||
except ImportError:
|
||||
return None, None
|
||||
_hyph_de = pyphen.Pyphen(lang='de_DE')
|
||||
_hyph_en = pyphen.Pyphen(lang='en_US')
|
||||
return _hyph_de, _hyph_en
|
||||
|
||||
|
||||
def _get_spellchecker():
|
||||
"""Lazy-load German spellchecker (cached across calls)."""
|
||||
global _spell_de
|
||||
if _spell_de is not None:
|
||||
return _spell_de
|
||||
try:
|
||||
from spellchecker import SpellChecker
|
||||
except ImportError:
|
||||
return None
|
||||
_spell_de = SpellChecker(language='de')
|
||||
return _spell_de
|
||||
|
||||
|
||||
def _is_known_word(word: str, hyph_de, hyph_en) -> bool:
|
||||
"""Check whether pyphen recognises a word (DE or EN)."""
|
||||
if len(word) < 2:
|
||||
return False
|
||||
return ('|' in hyph_de.inserted(word, hyphen='|')
|
||||
or '|' in hyph_en.inserted(word, hyphen='|'))
|
||||
|
||||
|
||||
def _is_real_word(word: str) -> bool:
|
||||
"""Check whether spellchecker knows this word (case-insensitive)."""
|
||||
spell = _get_spellchecker()
|
||||
if spell is None:
|
||||
return False
|
||||
return word.lower() in spell
|
||||
|
||||
|
||||
def _hyphenate_word(word: str, hyph_de, hyph_en) -> Optional[str]:
|
||||
"""Try to hyphenate a word using DE then EN dictionary.
|
||||
|
||||
Returns word with | separators, or None if not recognized.
|
||||
"""
|
||||
hyph = hyph_de.inserted(word, hyphen='|')
|
||||
if '|' in hyph:
|
||||
return hyph
|
||||
hyph = hyph_en.inserted(word, hyphen='|')
|
||||
if '|' in hyph:
|
||||
return hyph
|
||||
return None
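# pyphen usage sketch (the exact break points depend on the installed
# hyphenation dictionaries):
#
#   hyph_de, hyph_en = _get_hyphenators()
#   hyph_de.inserted("Zeppelin", hyphen='|')        # typically "Zep|pe|lin"
#   _hyphenate_word("Zeppelin", hyph_de, hyph_en)   # same, falling back to EN if DE finds no break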
|
||||
|
||||
|
||||
def _autocorrect_piped_word(word_with_pipes: str) -> Optional[str]:
|
||||
"""Try to correct a word that has OCR pipe artifacts.
|
||||
|
||||
Printed syllable divider lines on dictionary pages confuse OCR:
|
||||
the vertical stroke is often read as an extra character (commonly
|
||||
``l``, ``I``, ``1``, ``i``) adjacent to where the pipe appears.
|
||||
|
||||
Uses ``spellchecker`` (frequency-based word list) for validation.
|
||||
|
||||
Strategy:
|
||||
1. Strip ``|`` -- if spellchecker knows the result, done.
|
||||
2. Try deleting each pipe-like character (l, I, 1, i, t).
|
||||
3. Fall back to spellchecker's own ``correction()`` method.
|
||||
4. Preserve the original casing of the first letter.
|
||||
"""
|
||||
stripped = word_with_pipes.replace('|', '')
|
||||
if not stripped or len(stripped) < 3:
|
||||
return stripped # too short to validate
|
||||
|
||||
# Step 1: if the stripped word is already a real word, done
|
||||
if _is_real_word(stripped):
|
||||
return stripped
|
||||
|
||||
# Step 2: try deleting pipe-like characters (most likely artifacts)
|
||||
_PIPE_LIKE = frozenset('lI1it')
|
||||
for idx in range(len(stripped)):
|
||||
if stripped[idx] not in _PIPE_LIKE:
|
||||
continue
|
||||
candidate = stripped[:idx] + stripped[idx + 1:]
|
||||
if len(candidate) >= 3 and _is_real_word(candidate):
|
||||
return candidate
|
||||
|
||||
# Step 3: use spellchecker's built-in correction
|
||||
spell = _get_spellchecker()
|
||||
if spell is not None:
|
||||
suggestion = spell.correction(stripped.lower())
|
||||
if suggestion and suggestion != stripped.lower():
|
||||
# Preserve original first-letter case
|
||||
if stripped[0].isupper():
|
||||
suggestion = suggestion[0].upper() + suggestion[1:]
|
||||
return suggestion
|
||||
|
||||
return None # could not fix
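# Illustrative behaviour (dictionary-dependent):
#
#   _autocorrect_piped_word("Zel|le")       # "Zelle" is a known word → "Zelle"
#   _autocorrect_piped_word("Ze|plpe|lin")  # stray pipe-like 'l' removed → "Zeppelin"
#   _autocorrect_piped_word("X|y")          # too short to validate → "Xy" (returned as-is)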
|
||||
|
||||
|
||||
def autocorrect_pipe_artifacts(
|
||||
zones_data: List[Dict], session_id: str,
|
||||
) -> int:
|
||||
"""Strip OCR pipe artifacts and correct garbled words in-place.
|
||||
|
||||
Printed syllable divider lines on dictionary scans are read by OCR
|
||||
as ``|`` characters embedded in words (e.g. ``Zel|le``, ``Ze|plpe|lin``).
|
||||
This function:
|
||||
|
||||
1. Strips ``|`` from every word in content cells.
|
||||
2. Validates with spellchecker (real dictionary lookup).
|
||||
3. If not recognised, tries deleting pipe-like characters or uses
|
||||
spellchecker's correction (e.g. ``Zeplpelin`` -> ``Zeppelin``).
|
||||
4. Updates both word-box texts and cell text.
|
||||
|
||||
Returns the number of cells modified.
|
||||
"""
|
||||
spell = _get_spellchecker()
|
||||
if spell is None:
|
||||
logger.warning("spellchecker not available -- pipe autocorrect limited")
|
||||
# Fall back: still strip pipes even without spellchecker
|
||||
pass
|
||||
|
||||
modified = 0
|
||||
for z in zones_data:
|
||||
for cell in z.get("cells", []):
|
||||
ct = cell.get("col_type", "")
|
||||
if not ct.startswith("column_"):
|
||||
continue
|
||||
|
||||
cell_changed = False
|
||||
|
||||
# --- Fix word boxes ---
|
||||
for wb in cell.get("word_boxes", []):
|
||||
wb_text = wb.get("text", "")
|
||||
if "|" not in wb_text:
|
||||
continue
|
||||
|
||||
# Separate trailing punctuation
|
||||
m = re.match(
|
||||
r'^([^a-zA-Z\u00e4\u00f6\u00fc\u00c4\u00d6\u00dc\u00df\u1e9e]*)'
|
||||
r'(.*?)'
|
||||
r'([^a-zA-Z\u00e4\u00f6\u00fc\u00c4\u00d6\u00dc\u00df\u1e9e]*)$',
|
||||
wb_text,
|
||||
)
|
||||
if not m:
|
||||
continue
|
||||
lead, core, trail = m.group(1), m.group(2), m.group(3)
|
||||
if "|" not in core:
|
||||
continue
|
||||
|
||||
corrected = _autocorrect_piped_word(core)
|
||||
if corrected is not None and corrected != core:
|
||||
wb["text"] = lead + corrected + trail
|
||||
cell_changed = True
|
||||
|
||||
# --- Rebuild cell text from word boxes ---
|
||||
if cell_changed:
|
||||
wbs = cell.get("word_boxes", [])
|
||||
if wbs:
|
||||
cell["text"] = " ".join(
|
||||
(wb.get("text") or "") for wb in wbs
|
||||
)
|
||||
modified += 1
|
||||
|
||||
# --- Fallback: strip residual | from cell text ---
|
||||
text = cell.get("text", "")
|
||||
if "|" in text:
|
||||
clean = text.replace("|", "")
|
||||
if clean != text:
|
||||
cell["text"] = clean
|
||||
if not cell_changed:
|
||||
modified += 1
|
||||
|
||||
if modified:
|
||||
logger.info(
|
||||
"build-grid session %s: autocorrected pipe artifacts in %d cells",
|
||||
session_id, modified,
|
||||
)
|
||||
return modified
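# Sketch of the in-place effect (cell keys follow the lookups above; a cell is
# only touched when its col_type starts with "column_", and this assumes the DE
# word list knows "Zelle"):
#
#   zones = [{"cells": [{"col_type": "column_1", "text": "Zel|le",
#                        "word_boxes": [{"text": "Zel|le"}]}]}]
#   autocorrect_pipe_artifacts(zones, session_id="demo")   # → 1
#   zones[0]["cells"][0]["text"]                            # → "Zelle"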

klausur-service/backend/cv_syllable_detect.py
@@ -1,532 +1,32 @@
"""
Syllable divider insertion for dictionary pages.
Syllable divider insertion for dictionary pages — barrel re-export.

For confirmed dictionary pages (is_dictionary=True), processes all content
column cells:
1. Strips existing | dividers for clean normalization
2. Merges pipe-gap spaces (where OCR split a word at a divider position)
3. Applies pyphen syllabification to each word >= 3 alpha chars (DE then EN)
4. Only modifies words that pyphen recognizes — garbled OCR stays as-is

No CV gate needed — the dictionary detection confidence is sufficient.
pyphen uses Hunspell/TeX hyphenation dictionaries and is very reliable.
All implementation split into:
    cv_syllable_core — hyphenator init, word validation, pipe autocorrect
    cv_syllable_merge — word gap merging, syllabification, divider insertion

Lizenz: Apache 2.0 (kommerziell nutzbar)
DATENSCHUTZ: Alle Verarbeitung erfolgt lokal.
"""
|
||||
|
||||
import logging
|
||||
import re
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
import numpy as np
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# IPA/phonetic characters — skip cells containing these
|
||||
_IPA_RE = re.compile(r'[\[\]ˈˌːʃʒθðŋɑɒæɔəɛɜɪʊʌ]')
|
||||
|
||||
# Common German words that should NOT be merged with adjacent tokens.
|
||||
# These are function words that appear as standalone words between
|
||||
# headwords/definitions on dictionary pages.
|
||||
_STOP_WORDS = frozenset([
|
||||
# Articles
|
||||
'der', 'die', 'das', 'dem', 'den', 'des',
|
||||
'ein', 'eine', 'einem', 'einen', 'einer',
|
||||
# Pronouns
|
||||
'du', 'er', 'es', 'sie', 'wir', 'ihr', 'ich', 'man', 'sich',
|
||||
'dich', 'dir', 'mich', 'mir', 'uns', 'euch', 'ihm', 'ihn',
|
||||
# Prepositions
|
||||
'mit', 'von', 'zu', 'für', 'auf', 'in', 'an', 'um', 'am', 'im',
|
||||
'aus', 'bei', 'nach', 'vor', 'bis', 'durch', 'über', 'unter',
|
||||
'zwischen', 'ohne', 'gegen',
|
||||
# Conjunctions
|
||||
'und', 'oder', 'als', 'wie', 'wenn', 'dass', 'weil', 'aber',
|
||||
# Adverbs
|
||||
'auch', 'noch', 'nur', 'schon', 'sehr', 'nicht',
|
||||
# Verbs
|
||||
'ist', 'hat', 'wird', 'kann', 'soll', 'muss', 'darf',
|
||||
'sein', 'haben',
|
||||
# Other
|
||||
'kein', 'keine', 'keinem', 'keinen', 'keiner',
|
||||
])
|
||||
|
||||
# Cached hyphenators
|
||||
_hyph_de = None
|
||||
_hyph_en = None
|
||||
|
||||
# Cached spellchecker (for autocorrect_pipe_artifacts)
|
||||
_spell_de = None
|
||||
|
||||
|
||||
def _get_hyphenators():
|
||||
"""Lazy-load pyphen hyphenators (cached across calls)."""
|
||||
global _hyph_de, _hyph_en
|
||||
if _hyph_de is not None:
|
||||
return _hyph_de, _hyph_en
|
||||
try:
|
||||
import pyphen
|
||||
except ImportError:
|
||||
return None, None
|
||||
_hyph_de = pyphen.Pyphen(lang='de_DE')
|
||||
_hyph_en = pyphen.Pyphen(lang='en_US')
|
||||
return _hyph_de, _hyph_en
|
||||
|
||||
|
||||
def _get_spellchecker():
|
||||
"""Lazy-load German spellchecker (cached across calls)."""
|
||||
global _spell_de
|
||||
if _spell_de is not None:
|
||||
return _spell_de
|
||||
try:
|
||||
from spellchecker import SpellChecker
|
||||
except ImportError:
|
||||
return None
|
||||
_spell_de = SpellChecker(language='de')
|
||||
return _spell_de
|
||||
|
||||
|
||||
def _is_known_word(word: str, hyph_de, hyph_en) -> bool:
|
||||
"""Check whether pyphen recognises a word (DE or EN)."""
|
||||
if len(word) < 2:
|
||||
return False
|
||||
return ('|' in hyph_de.inserted(word, hyphen='|')
|
||||
or '|' in hyph_en.inserted(word, hyphen='|'))
|
||||
|
||||
|
||||
def _is_real_word(word: str) -> bool:
|
||||
"""Check whether spellchecker knows this word (case-insensitive)."""
|
||||
spell = _get_spellchecker()
|
||||
if spell is None:
|
||||
return False
|
||||
return word.lower() in spell
|
||||
|
||||
|
||||
def _hyphenate_word(word: str, hyph_de, hyph_en) -> Optional[str]:
|
||||
"""Try to hyphenate a word using DE then EN dictionary.
|
||||
|
||||
Returns word with | separators, or None if not recognized.
|
||||
"""
|
||||
hyph = hyph_de.inserted(word, hyphen='|')
|
||||
if '|' in hyph:
|
||||
return hyph
|
||||
hyph = hyph_en.inserted(word, hyphen='|')
|
||||
if '|' in hyph:
|
||||
return hyph
|
||||
return None
|
||||
|
||||
|
||||
def _autocorrect_piped_word(word_with_pipes: str) -> Optional[str]:
|
||||
"""Try to correct a word that has OCR pipe artifacts.
|
||||
|
||||
Printed syllable divider lines on dictionary pages confuse OCR:
|
||||
the vertical stroke is often read as an extra character (commonly
|
||||
``l``, ``I``, ``1``, ``i``) adjacent to where the pipe appears.
|
||||
Sometimes OCR reads one divider as ``|`` and another as a letter,
|
||||
so the garbled character may be far from any detected pipe.
|
||||
|
||||
Uses ``spellchecker`` (frequency-based word list) for validation —
|
||||
unlike pyphen which is a pattern-based hyphenator and accepts
|
||||
nonsense strings like "Zeplpelin".
|
||||
|
||||
Strategy:
|
||||
1. Strip ``|`` — if spellchecker knows the result, done.
|
||||
2. Try deleting each pipe-like character (l, I, 1, i, t).
|
||||
OCR inserts extra chars that resemble vertical strokes.
|
||||
3. Fall back to spellchecker's own ``correction()`` method.
|
||||
4. Preserve the original casing of the first letter.
|
||||
"""
|
||||
stripped = word_with_pipes.replace('|', '')
|
||||
if not stripped or len(stripped) < 3:
|
||||
return stripped # too short to validate
|
||||
|
||||
# Step 1: if the stripped word is already a real word, done
|
||||
if _is_real_word(stripped):
|
||||
return stripped
|
||||
|
||||
# Step 2: try deleting pipe-like characters (most likely artifacts)
|
||||
_PIPE_LIKE = frozenset('lI1it')
|
||||
for idx in range(len(stripped)):
|
||||
if stripped[idx] not in _PIPE_LIKE:
|
||||
continue
|
||||
candidate = stripped[:idx] + stripped[idx + 1:]
|
||||
if len(candidate) >= 3 and _is_real_word(candidate):
|
||||
return candidate
|
||||
|
||||
# Step 3: use spellchecker's built-in correction
|
||||
spell = _get_spellchecker()
|
||||
if spell is not None:
|
||||
suggestion = spell.correction(stripped.lower())
|
||||
if suggestion and suggestion != stripped.lower():
|
||||
# Preserve original first-letter case
|
||||
if stripped[0].isupper():
|
||||
suggestion = suggestion[0].upper() + suggestion[1:]
|
||||
return suggestion
|
||||
|
||||
return None # could not fix
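# --- Illustrative behaviour (editor's sketch, not part of the module) -------
# Assuming pyspellchecker's German frequency list knows the target words,
# the repair described above is expected to act roughly like this:
#
#     _autocorrect_piped_word("Zel|le")       # -> "Zelle"    (step 1: stripping
#                                             #    the pipes yields a known word)
#     _autocorrect_piped_word("Ze|plpe|lin")  # -> "Zeppelin" (step 2/3: drop the
#                                             #    stray pipe-like "l" or fall back
#                                             #    to spell.correction())
#
# Exact results depend on the installed dictionary data.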
|
||||
|
||||
|
||||
def autocorrect_pipe_artifacts(
|
||||
zones_data: List[Dict], session_id: str,
|
||||
) -> int:
|
||||
"""Strip OCR pipe artifacts and correct garbled words in-place.
|
||||
|
||||
Printed syllable divider lines on dictionary scans are read by OCR
|
||||
as ``|`` characters embedded in words (e.g. ``Zel|le``, ``Ze|plpe|lin``).
|
||||
This function:
|
||||
|
||||
1. Strips ``|`` from every word in content cells.
|
||||
2. Validates with spellchecker (real dictionary lookup).
|
||||
3. If not recognised, tries deleting pipe-like characters or uses
|
||||
spellchecker's correction (e.g. ``Zeplpelin`` → ``Zeppelin``).
|
||||
4. Updates both word-box texts and cell text.
|
||||
|
||||
Returns the number of cells modified.
|
||||
"""
|
||||
spell = _get_spellchecker()
|
||||
if spell is None:
|
||||
logger.warning("spellchecker not available — pipe autocorrect limited")
|
||||
# Fall back: still strip pipes even without spellchecker
|
||||
pass
|
||||
|
||||
modified = 0
|
||||
for z in zones_data:
|
||||
for cell in z.get("cells", []):
|
||||
ct = cell.get("col_type", "")
|
||||
if not ct.startswith("column_"):
|
||||
continue
|
||||
|
||||
cell_changed = False
|
||||
|
||||
# --- Fix word boxes ---
|
||||
for wb in cell.get("word_boxes", []):
|
||||
wb_text = wb.get("text", "")
|
||||
if "|" not in wb_text:
|
||||
continue
|
||||
|
||||
# Separate trailing punctuation
|
||||
m = re.match(
|
||||
r'^([^a-zA-ZäöüÄÖÜßẞ]*)'
|
||||
r'(.*?)'
|
||||
r'([^a-zA-ZäöüÄÖÜßẞ]*)$',
|
||||
wb_text,
|
||||
)
|
||||
if not m:
|
||||
continue
|
||||
lead, core, trail = m.group(1), m.group(2), m.group(3)
|
||||
if "|" not in core:
|
||||
continue
|
||||
|
||||
corrected = _autocorrect_piped_word(core)
|
||||
if corrected is not None and corrected != core:
|
||||
wb["text"] = lead + corrected + trail
|
||||
cell_changed = True
|
||||
|
||||
# --- Rebuild cell text from word boxes ---
|
||||
if cell_changed:
|
||||
wbs = cell.get("word_boxes", [])
|
||||
if wbs:
|
||||
cell["text"] = " ".join(
|
||||
(wb.get("text") or "") for wb in wbs
|
||||
)
|
||||
modified += 1
|
||||
|
||||
# --- Fallback: strip residual | from cell text ---
|
||||
# (covers cases where word_boxes don't exist or weren't fixed)
|
||||
text = cell.get("text", "")
|
||||
if "|" in text:
|
||||
clean = text.replace("|", "")
|
||||
if clean != text:
|
||||
cell["text"] = clean
|
||||
if not cell_changed:
|
||||
modified += 1
|
||||
|
||||
if modified:
|
||||
logger.info(
|
||||
"build-grid session %s: autocorrected pipe artifacts in %d cells",
|
||||
session_id, modified,
|
||||
)
|
||||
return modified
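# --- Editor's sketch of the zones_data shape this pass assumes --------------
# Field names are taken from the lookups above; the values are invented purely
# for illustration:
#
#     zones_data = [{
#         "cells": [{
#             "col_type": "column_1",
#             "text": "Zel|le, die",
#             "word_boxes": [{"text": "Zel|le,"}, {"text": "die"}],
#         }],
#     }]
#     autocorrect_pipe_artifacts(zones_data, session_id="demo")  # -> 1 cell modified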
|
||||
|
||||
|
||||
def _try_merge_pipe_gaps(text: str, hyph_de) -> str:
|
||||
"""Merge fragments separated by single spaces where OCR split at a pipe.
|
||||
|
||||
Example: "Kaf fee" -> "Kaffee" (pyphen recognizes the merged word).
|
||||
Multi-step: "Ka bel jau" -> "Kabel jau" -> "Kabeljau".
|
||||
|
||||
Guards against false merges:
|
||||
- The FIRST token must be pure alpha (word start — no attached punctuation)
|
||||
- The second token may have trailing punctuation (comma, period) which
|
||||
stays attached to the merged word: "Kä" + "fer," -> "Käfer,"
|
||||
- Common German function words (der, die, das, ...) are never merged
|
||||
- At least one fragment must be very short (<=3 alpha chars)
|
||||
"""
|
||||
parts = text.split(' ')
|
||||
if len(parts) < 2:
|
||||
return text
|
||||
|
||||
result = [parts[0]]
|
||||
i = 1
|
||||
while i < len(parts):
|
||||
prev = result[-1]
|
||||
curr = parts[i]
|
||||
|
||||
# Extract alpha-only core for lookup
|
||||
prev_alpha = re.sub(r'[^a-zA-ZäöüÄÖÜßẞ]', '', prev)
|
||||
curr_alpha = re.sub(r'[^a-zA-ZäöüÄÖÜßẞ]', '', curr)
|
||||
|
||||
# Guard 1: first token must be pure alpha (word-start fragment)
|
||||
# second token may have trailing punctuation
|
||||
# Guard 2: neither alpha core can be a common German function word
|
||||
# Guard 3: the shorter fragment must be <= 3 chars (pipe-gap signal)
|
||||
# Guard 4: combined length must be >= 4
|
||||
should_try = (
|
||||
prev == prev_alpha # first token: pure alpha (word start)
|
||||
and prev_alpha and curr_alpha
|
||||
and prev_alpha.lower() not in _STOP_WORDS
|
||||
and curr_alpha.lower() not in _STOP_WORDS
|
||||
and min(len(prev_alpha), len(curr_alpha)) <= 3
|
||||
and len(prev_alpha) + len(curr_alpha) >= 4
|
||||
)
|
||||
|
||||
if should_try:
|
||||
merged_alpha = prev_alpha + curr_alpha
|
||||
hyph = hyph_de.inserted(merged_alpha, hyphen='-')
|
||||
if '-' in hyph:
|
||||
# pyphen recognizes merged word — collapse the space
|
||||
result[-1] = prev + curr
|
||||
i += 1
|
||||
continue
|
||||
|
||||
result.append(curr)
|
||||
i += 1
|
||||
|
||||
return ' '.join(result)
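# --- Illustrative behaviour (editor's sketch) --------------------------------
# With pyphen's de_DE patterns loaded, the merge pass is expected to behave like:
#
#     hyph_de, _ = _get_hyphenators()
#     _try_merge_pipe_gaps("Kaf fee", hyph_de)   # -> "Kaffee"
#     _try_merge_pipe_gaps("der Hund", hyph_de)  # -> "der Hund" ("der" is a stop word)
#
# Actual outcomes depend on the hyphenation dictionary shipped with pyphen.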
|
||||
|
||||
|
||||
def merge_word_gaps_in_zones(zones_data: List[Dict], session_id: str) -> int:
|
||||
"""Merge OCR word-gap fragments in cell texts using pyphen validation.
|
||||
|
||||
OCR often splits words at syllable boundaries into separate word_boxes,
|
||||
producing text like "zerknit tert" instead of "zerknittert". This
|
||||
function tries to merge adjacent fragments in every content cell.
|
||||
|
||||
More permissive than ``_try_merge_pipe_gaps`` (threshold 5 instead of 3)
|
||||
but still guarded by pyphen dictionary lookup and stop-word exclusion.
|
||||
|
||||
Returns the number of cells modified.
|
||||
"""
|
||||
hyph_de, _ = _get_hyphenators()
|
||||
if hyph_de is None:
|
||||
return 0
|
||||
|
||||
modified = 0
|
||||
for z in zones_data:
|
||||
for cell in z.get("cells", []):
|
||||
ct = cell.get("col_type", "")
|
||||
if not ct.startswith("column_"):
|
||||
continue
|
||||
text = cell.get("text", "")
|
||||
if not text or " " not in text:
|
||||
continue
|
||||
|
||||
# Skip IPA cells
|
||||
text_no_brackets = re.sub(r'\[[^\]]*\]', '', text)
|
||||
if _IPA_RE.search(text_no_brackets):
|
||||
continue
|
||||
|
||||
new_text = _try_merge_word_gaps(text, hyph_de)
|
||||
if new_text != text:
|
||||
cell["text"] = new_text
|
||||
modified += 1
|
||||
|
||||
if modified:
|
||||
logger.info(
|
||||
"build-grid session %s: merged word gaps in %d cells",
|
||||
session_id, modified,
|
||||
)
|
||||
return modified
|
||||
|
||||
|
||||
def _try_merge_word_gaps(text: str, hyph_de) -> str:
|
||||
"""Merge OCR word fragments with relaxed threshold (max_short=5).
|
||||
|
||||
Similar to ``_try_merge_pipe_gaps`` but allows slightly longer fragments
|
||||
(max_short=5 instead of 3). Still requires pyphen to recognize the
|
||||
merged word.
|
||||
"""
|
||||
parts = text.split(' ')
|
||||
if len(parts) < 2:
|
||||
return text
|
||||
|
||||
result = [parts[0]]
|
||||
i = 1
|
||||
while i < len(parts):
|
||||
prev = result[-1]
|
||||
curr = parts[i]
|
||||
|
||||
prev_alpha = re.sub(r'[^a-zA-ZäöüÄÖÜßẞ]', '', prev)
|
||||
curr_alpha = re.sub(r'[^a-zA-ZäöüÄÖÜßẞ]', '', curr)
|
||||
|
||||
should_try = (
|
||||
prev == prev_alpha
|
||||
and prev_alpha and curr_alpha
|
||||
and prev_alpha.lower() not in _STOP_WORDS
|
||||
and curr_alpha.lower() not in _STOP_WORDS
|
||||
and min(len(prev_alpha), len(curr_alpha)) <= 5
|
||||
and len(prev_alpha) + len(curr_alpha) >= 4
|
||||
)
|
||||
|
||||
if should_try:
|
||||
merged_alpha = prev_alpha + curr_alpha
|
||||
hyph = hyph_de.inserted(merged_alpha, hyphen='-')
|
||||
if '-' in hyph:
|
||||
result[-1] = prev + curr
|
||||
i += 1
|
||||
continue
|
||||
|
||||
result.append(curr)
|
||||
i += 1
|
||||
|
||||
return ' '.join(result)
|
||||
|
||||
|
||||
def _syllabify_text(text: str, hyph_de, hyph_en) -> str:
|
||||
"""Syllabify all significant words in a text string.
|
||||
|
||||
1. Strip existing | dividers
|
||||
2. Merge pipe-gap spaces where possible
|
||||
3. Apply pyphen to each word >= 3 alphabetic chars
|
||||
4. Words pyphen doesn't recognize stay as-is (no bad guesses)
|
||||
"""
|
||||
if not text:
|
||||
return text
|
||||
|
||||
# Skip cells that contain IPA transcription characters outside brackets.
|
||||
# Bracket content like [bɪltʃøn] is programmatically inserted and should
|
||||
# not block syllabification of the surrounding text.
|
||||
text_no_brackets = re.sub(r'\[[^\]]*\]', '', text)
|
||||
if _IPA_RE.search(text_no_brackets):
|
||||
return text
|
||||
|
||||
# Phase 1: strip existing pipe dividers for clean normalization
|
||||
clean = text.replace('|', '')
|
||||
|
||||
# Phase 2: merge pipe-gap spaces (OCR fragments from pipe splitting)
|
||||
clean = _try_merge_pipe_gaps(clean, hyph_de)
|
||||
|
||||
# Phase 3: tokenize and syllabify each word
|
||||
# Split on whitespace and comma/semicolon sequences, keeping separators
|
||||
tokens = re.split(r'(\s+|[,;:]+\s*)', clean)
|
||||
|
||||
result = []
|
||||
for tok in tokens:
|
||||
if not tok or re.match(r'^[\s,;:]+$', tok):
|
||||
result.append(tok)
|
||||
continue
|
||||
|
||||
# Strip trailing/leading punctuation for pyphen lookup
|
||||
m = re.match(r'^([^a-zA-ZäöüÄÖÜßẞ]*)(.*?)([^a-zA-ZäöüÄÖÜßẞ]*)$', tok)
|
||||
if not m:
|
||||
result.append(tok)
|
||||
continue
|
||||
lead, word, trail = m.group(1), m.group(2), m.group(3)
|
||||
|
||||
if len(word) < 3 or not re.search(r'[a-zA-ZäöüÄÖÜß]', word):
|
||||
result.append(tok)
|
||||
continue
|
||||
|
||||
hyph = _hyphenate_word(word, hyph_de, hyph_en)
|
||||
if hyph:
|
||||
result.append(lead + hyph + trail)
|
||||
else:
|
||||
result.append(tok)
|
||||
|
||||
return ''.join(result)
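# --- Illustrative behaviour (editor's sketch) --------------------------------
# For a cell that already carries OCR pipes, the normalization is expected to
# produce something like the following (exact break points depend on pyphen):
#
#     hyph_de, hyph_en = _get_hyphenators()
#     _syllabify_text("Zel|le, Kaf fee", hyph_de, hyph_en)
#     # -> "Zel|le, Kaf|fee"   (strip pipes, merge "Kaf fee", re-insert dividers)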
|
||||
|
||||
|
||||
def insert_syllable_dividers(
|
||||
zones_data: List[Dict],
|
||||
img_bgr: np.ndarray,
|
||||
session_id: str,
|
||||
*,
|
||||
force: bool = False,
|
||||
col_filter: Optional[set] = None,
|
||||
) -> int:
|
||||
"""Insert pipe syllable dividers into dictionary cells.
|
||||
|
||||
For dictionary pages: process all content column cells, strip existing
|
||||
pipes, merge pipe-gap spaces, and re-syllabify using pyphen.
|
||||
|
||||
Pre-check: at least 1% of content cells must already contain ``|`` from
|
||||
OCR. This guards against pages with zero pipe characters (the primary
|
||||
guard — article_col_index — is checked at the call site).
|
||||
|
||||
Args:
|
||||
force: If True, skip the pipe-ratio pre-check and syllabify all
|
||||
content words regardless of whether the original has pipe dividers.
|
||||
col_filter: If set, only process cells whose col_type is in this set.
|
||||
None means process all content columns.
|
||||
|
||||
Returns the number of cells modified.
|
||||
"""
|
||||
hyph_de, hyph_en = _get_hyphenators()
|
||||
if hyph_de is None:
|
||||
logger.warning("pyphen not installed — skipping syllable insertion")
|
||||
return 0
|
||||
|
||||
# Pre-check: count cells that already have | from OCR.
|
||||
# Real dictionary pages with printed syllable dividers will have OCR-
|
||||
# detected pipes in many cells. Pages without syllable dividers will
|
||||
# have zero — skip those to avoid false syllabification.
|
||||
if not force:
|
||||
total_col_cells = 0
|
||||
cells_with_pipes = 0
|
||||
for z in zones_data:
|
||||
for cell in z.get("cells", []):
|
||||
if cell.get("col_type", "").startswith("column_"):
|
||||
total_col_cells += 1
|
||||
if "|" in cell.get("text", ""):
|
||||
cells_with_pipes += 1
|
||||
|
||||
if total_col_cells > 0:
|
||||
pipe_ratio = cells_with_pipes / total_col_cells
|
||||
if pipe_ratio < 0.01:
|
||||
logger.info(
|
||||
"build-grid session %s: skipping syllable insertion — "
|
||||
"only %.1f%% of cells have existing pipes (need >=1%%)",
|
||||
session_id, pipe_ratio * 100,
|
||||
)
|
||||
return 0
|
||||
|
||||
insertions = 0
|
||||
for z in zones_data:
|
||||
for cell in z.get("cells", []):
|
||||
ct = cell.get("col_type", "")
|
||||
if not ct.startswith("column_"):
|
||||
continue
|
||||
if col_filter is not None and ct not in col_filter:
|
||||
continue
|
||||
text = cell.get("text", "")
|
||||
if not text:
|
||||
continue
|
||||
|
||||
# In auto mode (force=False), only normalize cells that already
|
||||
# have | from OCR (i.e. printed syllable dividers on the original
|
||||
# scan). Don't add new syllable marks to other words.
|
||||
if not force and "|" not in text:
|
||||
continue
|
||||
|
||||
new_text = _syllabify_text(text, hyph_de, hyph_en)
|
||||
if new_text != text:
|
||||
cell["text"] = new_text
|
||||
insertions += 1
|
||||
|
||||
if insertions:
|
||||
logger.info(
|
||||
"build-grid session %s: syllable dividers inserted/normalized "
|
||||
"in %d cells (pyphen)",
|
||||
session_id, insertions,
|
||||
)
|
||||
return insertions
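# --- Illustrative call (editor's sketch) --------------------------------------
# A typical invocation from the build-grid pipeline; the argument values here
# are invented for illustration:
#
#     n = insert_syllable_dividers(
#         zones_data, img_bgr, session_id,
#         force=False,              # auto mode: only cells that already contain |
#         col_filter={"column_1"},  # optionally restrict to one content column
#     )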
|
||||
# Core: init, validation, autocorrect
from cv_syllable_core import (  # noqa: F401
    _IPA_RE,
    _STOP_WORDS,
    _get_hyphenators,
    _get_spellchecker,
    _is_known_word,
    _is_real_word,
    _hyphenate_word,
    _autocorrect_piped_word,
    autocorrect_pipe_artifacts,
)

# Merge: gap merging, syllabify, insert
from cv_syllable_merge import (  # noqa: F401
    _try_merge_pipe_gaps,
    merge_word_gaps_in_zones,
    _try_merge_word_gaps,
    _syllabify_text,
    insert_syllable_dividers,
)
klausur-service/backend/cv_syllable_merge.py (new file, 300 lines)
@@ -0,0 +1,300 @@
|
||||
"""
|
||||
Syllable Merge — word gap merging, syllabification, divider insertion.
|
||||
|
||||
Extracted from cv_syllable_detect.py for modularity.
|
||||
|
||||
Lizenz: Apache 2.0 (kommerziell nutzbar)
|
||||
DATENSCHUTZ: Alle Verarbeitung erfolgt lokal.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import re
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
import numpy as np
|
||||
|
||||
from cv_syllable_core import (
|
||||
_get_hyphenators,
|
||||
_hyphenate_word,
|
||||
_IPA_RE,
|
||||
_STOP_WORDS,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _try_merge_pipe_gaps(text: str, hyph_de) -> str:
|
||||
"""Merge fragments separated by single spaces where OCR split at a pipe.
|
||||
|
||||
Example: "Kaf fee" -> "Kaffee" (pyphen recognizes the merged word).
|
||||
Multi-step: "Ka bel jau" -> "Kabel jau" -> "Kabeljau".
|
||||
|
||||
Guards against false merges:
|
||||
- The FIRST token must be pure alpha (word start -- no attached punctuation)
|
||||
- The second token may have trailing punctuation (comma, period) which
|
||||
stays attached to the merged word: "Kä" + "fer," -> "Käfer,"
|
||||
- Common German function words (der, die, das, ...) are never merged
|
||||
- At least one fragment must be very short (<=3 alpha chars)
|
||||
"""
|
||||
parts = text.split(' ')
|
||||
if len(parts) < 2:
|
||||
return text
|
||||
|
||||
result = [parts[0]]
|
||||
i = 1
|
||||
while i < len(parts):
|
||||
prev = result[-1]
|
||||
curr = parts[i]
|
||||
|
||||
# Extract alpha-only core for lookup
|
||||
prev_alpha = re.sub(r'[^a-zA-ZäöüÄÖÜßẞ]', '', prev)
curr_alpha = re.sub(r'[^a-zA-ZäöüÄÖÜßẞ]', '', curr)
|
||||
|
||||
# Guard 1: first token must be pure alpha (word-start fragment)
|
||||
# second token may have trailing punctuation
|
||||
# Guard 2: neither alpha core can be a common German function word
|
||||
# Guard 3: the shorter fragment must be <= 3 chars (pipe-gap signal)
|
||||
# Guard 4: combined length must be >= 4
|
||||
should_try = (
|
||||
prev == prev_alpha # first token: pure alpha (word start)
|
||||
and prev_alpha and curr_alpha
|
||||
and prev_alpha.lower() not in _STOP_WORDS
|
||||
and curr_alpha.lower() not in _STOP_WORDS
|
||||
and min(len(prev_alpha), len(curr_alpha)) <= 3
|
||||
and len(prev_alpha) + len(curr_alpha) >= 4
|
||||
)
|
||||
|
||||
if should_try:
|
||||
merged_alpha = prev_alpha + curr_alpha
|
||||
hyph = hyph_de.inserted(merged_alpha, hyphen='-')
|
||||
if '-' in hyph:
|
||||
# pyphen recognizes merged word -- collapse the space
|
||||
result[-1] = prev + curr
|
||||
i += 1
|
||||
continue
|
||||
|
||||
result.append(curr)
|
||||
i += 1
|
||||
|
||||
return ' '.join(result)
|
||||
|
||||
|
||||
def merge_word_gaps_in_zones(zones_data: List[Dict], session_id: str) -> int:
|
||||
"""Merge OCR word-gap fragments in cell texts using pyphen validation.
|
||||
|
||||
OCR often splits words at syllable boundaries into separate word_boxes,
|
||||
producing text like "zerknit tert" instead of "zerknittert". This
|
||||
function tries to merge adjacent fragments in every content cell.
|
||||
|
||||
More permissive than ``_try_merge_pipe_gaps`` (threshold 5 instead of 3)
|
||||
but still guarded by pyphen dictionary lookup and stop-word exclusion.
|
||||
|
||||
Returns the number of cells modified.
|
||||
"""
|
||||
hyph_de, _ = _get_hyphenators()
|
||||
if hyph_de is None:
|
||||
return 0
|
||||
|
||||
modified = 0
|
||||
for z in zones_data:
|
||||
for cell in z.get("cells", []):
|
||||
ct = cell.get("col_type", "")
|
||||
if not ct.startswith("column_"):
|
||||
continue
|
||||
text = cell.get("text", "")
|
||||
if not text or " " not in text:
|
||||
continue
|
||||
|
||||
# Skip IPA cells
|
||||
text_no_brackets = re.sub(r'\[[^\]]*\]', '', text)
|
||||
if _IPA_RE.search(text_no_brackets):
|
||||
continue
|
||||
|
||||
new_text = _try_merge_word_gaps(text, hyph_de)
|
||||
if new_text != text:
|
||||
cell["text"] = new_text
|
||||
modified += 1
|
||||
|
||||
if modified:
|
||||
logger.info(
|
||||
"build-grid session %s: merged word gaps in %d cells",
|
||||
session_id, modified,
|
||||
)
|
||||
return modified
|
||||
|
||||
|
||||
def _try_merge_word_gaps(text: str, hyph_de) -> str:
|
||||
"""Merge OCR word fragments with relaxed threshold (max_short=5).
|
||||
|
||||
Similar to ``_try_merge_pipe_gaps`` but allows slightly longer fragments
|
||||
(max_short=5 instead of 3). Still requires pyphen to recognize the
|
||||
merged word.
|
||||
"""
|
||||
parts = text.split(' ')
|
||||
if len(parts) < 2:
|
||||
return text
|
||||
|
||||
result = [parts[0]]
|
||||
i = 1
|
||||
while i < len(parts):
|
||||
prev = result[-1]
|
||||
curr = parts[i]
|
||||
|
||||
prev_alpha = re.sub(r'[^a-zA-ZäöüÄÖÜßẞ]', '', prev)
curr_alpha = re.sub(r'[^a-zA-ZäöüÄÖÜßẞ]', '', curr)
|
||||
|
||||
should_try = (
|
||||
prev == prev_alpha
|
||||
and prev_alpha and curr_alpha
|
||||
and prev_alpha.lower() not in _STOP_WORDS
|
||||
and curr_alpha.lower() not in _STOP_WORDS
|
||||
and min(len(prev_alpha), len(curr_alpha)) <= 5
|
||||
and len(prev_alpha) + len(curr_alpha) >= 4
|
||||
)
|
||||
|
||||
if should_try:
|
||||
merged_alpha = prev_alpha + curr_alpha
|
||||
hyph = hyph_de.inserted(merged_alpha, hyphen='-')
|
||||
if '-' in hyph:
|
||||
result[-1] = prev + curr
|
||||
i += 1
|
||||
continue
|
||||
|
||||
result.append(curr)
|
||||
i += 1
|
||||
|
||||
return ' '.join(result)
|
||||
|
||||
|
||||
def _syllabify_text(text: str, hyph_de, hyph_en) -> str:
|
||||
"""Syllabify all significant words in a text string.
|
||||
|
||||
1. Strip existing | dividers
|
||||
2. Merge pipe-gap spaces where possible
|
||||
3. Apply pyphen to each word >= 3 alphabetic chars
|
||||
4. Words pyphen doesn't recognize stay as-is (no bad guesses)
|
||||
"""
|
||||
if not text:
|
||||
return text
|
||||
|
||||
# Skip cells that contain IPA transcription characters outside brackets.
|
||||
text_no_brackets = re.sub(r'\[[^\]]*\]', '', text)
|
||||
if _IPA_RE.search(text_no_brackets):
|
||||
return text
|
||||
|
||||
# Phase 1: strip existing pipe dividers for clean normalization
|
||||
clean = text.replace('|', '')
|
||||
|
||||
# Phase 2: merge pipe-gap spaces (OCR fragments from pipe splitting)
|
||||
clean = _try_merge_pipe_gaps(clean, hyph_de)
|
||||
|
||||
# Phase 3: tokenize and syllabify each word
|
||||
# Split on whitespace and comma/semicolon sequences, keeping separators
|
||||
tokens = re.split(r'(\s+|[,;:]+\s*)', clean)
|
||||
|
||||
result = []
|
||||
for tok in tokens:
|
||||
if not tok or re.match(r'^[\s,;:]+$', tok):
|
||||
result.append(tok)
|
||||
continue
|
||||
|
||||
# Strip trailing/leading punctuation for pyphen lookup
|
||||
m = re.match(r'^([^a-zA-ZäöüÄÖÜßẞ]*)(.*?)([^a-zA-ZäöüÄÖÜßẞ]*)$', tok)
|
||||
if not m:
|
||||
result.append(tok)
|
||||
continue
|
||||
lead, word, trail = m.group(1), m.group(2), m.group(3)
|
||||
|
||||
if len(word) < 3 or not re.search(r'[a-zA-ZäöüÄÖÜß]', word):
|
||||
result.append(tok)
|
||||
continue
|
||||
|
||||
hyph = _hyphenate_word(word, hyph_de, hyph_en)
|
||||
if hyph:
|
||||
result.append(lead + hyph + trail)
|
||||
else:
|
||||
result.append(tok)
|
||||
|
||||
return ''.join(result)
|
||||
|
||||
|
||||
def insert_syllable_dividers(
|
||||
zones_data: List[Dict],
|
||||
img_bgr: np.ndarray,
|
||||
session_id: str,
|
||||
*,
|
||||
force: bool = False,
|
||||
col_filter: Optional[set] = None,
|
||||
) -> int:
|
||||
"""Insert pipe syllable dividers into dictionary cells.
|
||||
|
||||
For dictionary pages: process all content column cells, strip existing
|
||||
pipes, merge pipe-gap spaces, and re-syllabify using pyphen.
|
||||
|
||||
Pre-check: at least 1% of content cells must already contain ``|`` from
|
||||
OCR. This guards against pages with zero pipe characters.
|
||||
|
||||
Args:
|
||||
force: If True, skip the pipe-ratio pre-check and syllabify all
|
||||
content words regardless of whether the original has pipe dividers.
|
||||
col_filter: If set, only process cells whose col_type is in this set.
|
||||
None means process all content columns.
|
||||
|
||||
Returns the number of cells modified.
|
||||
"""
|
||||
hyph_de, hyph_en = _get_hyphenators()
|
||||
if hyph_de is None:
|
||||
logger.warning("pyphen not installed -- skipping syllable insertion")
|
||||
return 0
|
||||
|
||||
# Pre-check: count cells that already have | from OCR.
|
||||
if not force:
|
||||
total_col_cells = 0
|
||||
cells_with_pipes = 0
|
||||
for z in zones_data:
|
||||
for cell in z.get("cells", []):
|
||||
if cell.get("col_type", "").startswith("column_"):
|
||||
total_col_cells += 1
|
||||
if "|" in cell.get("text", ""):
|
||||
cells_with_pipes += 1
|
||||
|
||||
if total_col_cells > 0:
|
||||
pipe_ratio = cells_with_pipes / total_col_cells
|
||||
if pipe_ratio < 0.01:
|
||||
logger.info(
|
||||
"build-grid session %s: skipping syllable insertion -- "
|
||||
"only %.1f%% of cells have existing pipes (need >=1%%)",
|
||||
session_id, pipe_ratio * 100,
|
||||
)
|
||||
return 0
|
||||
|
||||
insertions = 0
|
||||
for z in zones_data:
|
||||
for cell in z.get("cells", []):
|
||||
ct = cell.get("col_type", "")
|
||||
if not ct.startswith("column_"):
|
||||
continue
|
||||
if col_filter is not None and ct not in col_filter:
|
||||
continue
|
||||
text = cell.get("text", "")
|
||||
if not text:
|
||||
continue
|
||||
|
||||
# In auto mode (force=False), only normalize cells that already
|
||||
# have | from OCR (i.e. printed syllable dividers on the original
|
||||
# scan). Don't add new syllable marks to other words.
|
||||
if not force and "|" not in text:
|
||||
continue
|
||||
|
||||
new_text = _syllabify_text(text, hyph_de, hyph_en)
|
||||
if new_text != text:
|
||||
cell["text"] = new_text
|
||||
insertions += 1
|
||||
|
||||
if insertions:
|
||||
logger.info(
|
||||
"build-grid session %s: syllable dividers inserted/normalized "
|
||||
"in %d cells (pyphen)",
|
||||
session_id, insertions,
|
||||
)
|
||||
return insertions
|
||||
@@ -1,52 +1,27 @@
|
||||
"""
|
||||
Mail Aggregator Service
|
||||
Mail Aggregator Service — barrel re-export.
|
||||
|
||||
All implementation split into:
|
||||
aggregator_imap — IMAP connection, sync, email parsing
|
||||
aggregator_smtp — SMTP connection, email sending
|
||||
|
||||
Multi-account IMAP aggregation with async support.
|
||||
"""
|
||||
|
||||
import os
|
||||
import ssl
|
||||
import email
|
||||
import asyncio
|
||||
import logging
|
||||
import smtplib
|
||||
from typing import Optional, List, Dict, Any, Tuple
|
||||
from datetime import datetime, timezone
|
||||
from email.mime.text import MIMEText
|
||||
from email.mime.multipart import MIMEMultipart
|
||||
from email.header import decode_header, make_header
|
||||
from email.utils import parsedate_to_datetime, parseaddr
|
||||
from typing import Optional, List, Dict, Any
|
||||
|
||||
from .credentials import get_credentials_service, MailCredentials
|
||||
from .mail_db import (
|
||||
get_email_accounts,
|
||||
get_email_account,
|
||||
update_account_status,
|
||||
upsert_email,
|
||||
get_unified_inbox,
|
||||
)
|
||||
from .models import (
|
||||
AccountStatus,
|
||||
AccountTestResult,
|
||||
AggregatedEmail,
|
||||
EmailComposeRequest,
|
||||
EmailSendResult,
|
||||
)
|
||||
from .credentials import get_credentials_service
|
||||
from .mail_db import get_email_accounts, get_unified_inbox
|
||||
from .models import AccountTestResult
|
||||
from .aggregator_imap import IMAPMixin, IMAPConnectionError
|
||||
from .aggregator_smtp import SMTPMixin, SMTPConnectionError
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class IMAPConnectionError(Exception):
|
||||
"""Raised when IMAP connection fails."""
|
||||
pass
|
||||
|
||||
|
||||
class SMTPConnectionError(Exception):
|
||||
"""Raised when SMTP connection fails."""
|
||||
pass
|
||||
|
||||
|
||||
class MailAggregator:
|
||||
class MailAggregator(IMAPMixin, SMTPMixin):
|
||||
"""
|
||||
Aggregates emails from multiple IMAP accounts into a unified inbox.
|
||||
|
||||
@@ -86,390 +61,29 @@ class MailAggregator:
|
||||
)
|
||||
|
||||
# Test IMAP
|
||||
try:
|
||||
import imaplib
|
||||
|
||||
if imap_ssl:
|
||||
imap = imaplib.IMAP4_SSL(imap_host, imap_port)
|
||||
else:
|
||||
imap = imaplib.IMAP4(imap_host, imap_port)
|
||||
|
||||
imap.login(email_address, password)
|
||||
result.imap_connected = True
|
||||
|
||||
# List folders
|
||||
status, folders = imap.list()
|
||||
if status == "OK":
|
||||
result.folders_found = [
|
||||
self._parse_folder_name(f) for f in folders if f
|
||||
]
|
||||
|
||||
imap.logout()
|
||||
|
||||
except Exception as e:
|
||||
result.error_message = f"IMAP Error: {str(e)}"
|
||||
logger.warning(f"IMAP test failed for {email_address}: {e}")
|
||||
imap_ok, imap_err, folders = await self.test_imap_connection(
|
||||
imap_host, imap_port, imap_ssl, email_address, password
|
||||
)
|
||||
result.imap_connected = imap_ok
|
||||
if folders:
|
||||
result.folders_found = folders
|
||||
if imap_err:
|
||||
result.error_message = imap_err
|
||||
|
||||
# Test SMTP
|
||||
try:
|
||||
if smtp_ssl:
|
||||
smtp = smtplib.SMTP_SSL(smtp_host, smtp_port)
|
||||
else:
|
||||
smtp = smtplib.SMTP(smtp_host, smtp_port)
|
||||
smtp.starttls()
|
||||
|
||||
smtp.login(email_address, password)
|
||||
result.smtp_connected = True
|
||||
smtp.quit()
|
||||
|
||||
except Exception as e:
|
||||
smtp_error = f"SMTP Error: {str(e)}"
|
||||
smtp_ok, smtp_err = await self.test_smtp_connection(
|
||||
smtp_host, smtp_port, smtp_ssl, email_address, password
|
||||
)
|
||||
result.smtp_connected = smtp_ok
|
||||
if smtp_err:
|
||||
if result.error_message:
|
||||
result.error_message += f"; {smtp_error}"
|
||||
result.error_message += f"; {smtp_err}"
|
||||
else:
|
||||
result.error_message = smtp_error
|
||||
logger.warning(f"SMTP test failed for {email_address}: {e}")
|
||||
result.error_message = smtp_err
|
||||
|
||||
result.success = result.imap_connected and result.smtp_connected
|
||||
return result
|
||||
|
||||
def _parse_folder_name(self, folder_response: bytes) -> str:
|
||||
"""Parse folder name from IMAP LIST response."""
|
||||
try:
|
||||
# Format: '(\\HasNoChildren) "/" "INBOX"'
|
||||
decoded = folder_response.decode("utf-8") if isinstance(folder_response, bytes) else folder_response
|
||||
parts = decoded.rsplit('" "', 1)
|
||||
if len(parts) == 2:
|
||||
return parts[1].rstrip('"')
|
||||
return decoded
|
||||
except Exception:
|
||||
return str(folder_response)
|
||||
|
||||
async def sync_account(
|
||||
self,
|
||||
account_id: str,
|
||||
user_id: str,
|
||||
max_emails: int = 100,
|
||||
folders: Optional[List[str]] = None,
|
||||
) -> Tuple[int, int]:
|
||||
"""
|
||||
Sync emails from an IMAP account.
|
||||
|
||||
Args:
|
||||
account_id: The account ID
|
||||
user_id: The user ID
|
||||
max_emails: Maximum emails to fetch
|
||||
folders: Specific folders to sync (default: INBOX)
|
||||
|
||||
Returns:
|
||||
Tuple of (new_emails, total_emails)
|
||||
"""
|
||||
import imaplib
|
||||
|
||||
account = await get_email_account(account_id, user_id)
|
||||
if not account:
|
||||
raise ValueError(f"Account not found: {account_id}")
|
||||
|
||||
# Get credentials
|
||||
vault_path = account.get("vault_path", "")
|
||||
creds = await self._credentials_service.get_credentials(account_id, vault_path)
|
||||
if not creds:
|
||||
await update_account_status(account_id, "error", "Credentials not found")
|
||||
raise IMAPConnectionError("Credentials not found")
|
||||
|
||||
new_count = 0
|
||||
total_count = 0
|
||||
|
||||
try:
|
||||
# Connect to IMAP
|
||||
if account["imap_ssl"]:
|
||||
imap = imaplib.IMAP4_SSL(account["imap_host"], account["imap_port"])
|
||||
else:
|
||||
imap = imaplib.IMAP4(account["imap_host"], account["imap_port"])
|
||||
|
||||
imap.login(creds.email, creds.password)
|
||||
|
||||
# Sync specified folders or just INBOX
|
||||
sync_folders = folders or ["INBOX"]
|
||||
|
||||
for folder in sync_folders:
|
||||
try:
|
||||
status, _ = imap.select(folder)
|
||||
if status != "OK":
|
||||
continue
|
||||
|
||||
# Search for recent emails
|
||||
status, messages = imap.search(None, "ALL")
|
||||
if status != "OK":
|
||||
continue
|
||||
|
||||
message_ids = messages[0].split()
|
||||
total_count += len(message_ids)
|
||||
|
||||
# Fetch most recent emails
|
||||
recent_ids = message_ids[-max_emails:] if len(message_ids) > max_emails else message_ids
|
||||
|
||||
for msg_id in recent_ids:
|
||||
try:
|
||||
email_data = await self._fetch_and_store_email(
|
||||
imap, msg_id, account_id, user_id, account["tenant_id"], folder
|
||||
)
|
||||
if email_data:
|
||||
new_count += 1
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to fetch email {msg_id}: {e}")
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to sync folder {folder}: {e}")
|
||||
|
||||
imap.logout()
|
||||
|
||||
# Update account status
|
||||
await update_account_status(
|
||||
account_id,
|
||||
"active",
|
||||
email_count=total_count,
|
||||
unread_count=new_count, # Will be recalculated
|
||||
)
|
||||
|
||||
return new_count, total_count
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Account sync failed: {e}")
|
||||
await update_account_status(account_id, "error", str(e))
|
||||
raise IMAPConnectionError(str(e))
|
||||
|
||||
async def _fetch_and_store_email(
|
||||
self,
|
||||
imap,
|
||||
msg_id: bytes,
|
||||
account_id: str,
|
||||
user_id: str,
|
||||
tenant_id: str,
|
||||
folder: str,
|
||||
) -> Optional[str]:
|
||||
"""Fetch a single email and store it in the database."""
|
||||
try:
|
||||
status, msg_data = imap.fetch(msg_id, "(RFC822)")
|
||||
if status != "OK" or not msg_data or not msg_data[0]:
|
||||
return None
|
||||
|
||||
raw_email = msg_data[0][1]
|
||||
msg = email.message_from_bytes(raw_email)
|
||||
|
||||
# Parse headers
|
||||
message_id = msg.get("Message-ID", str(msg_id))
|
||||
subject = self._decode_header(msg.get("Subject", ""))
|
||||
from_header = msg.get("From", "")
|
||||
sender_name, sender_email = parseaddr(from_header)
|
||||
sender_name = self._decode_header(sender_name)
|
||||
|
||||
# Parse recipients
|
||||
to_header = msg.get("To", "")
|
||||
recipients = [addr[1] for addr in email.utils.getaddresses([to_header])]
|
||||
|
||||
cc_header = msg.get("Cc", "")
|
||||
cc = [addr[1] for addr in email.utils.getaddresses([cc_header])]
|
||||
|
||||
# Parse dates
|
||||
date_str = msg.get("Date")
|
||||
try:
|
||||
date_sent = parsedate_to_datetime(date_str) if date_str else datetime.now(timezone.utc)
|
||||
except Exception:
|
||||
date_sent = datetime.now(timezone.utc)
|
||||
|
||||
date_received = datetime.now(timezone.utc)
|
||||
|
||||
# Parse body
|
||||
body_text, body_html, attachments = self._parse_body(msg)
|
||||
|
||||
# Create preview
|
||||
body_preview = (body_text[:200] + "...") if body_text and len(body_text) > 200 else body_text
|
||||
|
||||
# Get headers dict
|
||||
headers = {k: self._decode_header(v) for k, v in msg.items() if k not in ["Body"]}
|
||||
|
||||
# Store in database
|
||||
email_id = await upsert_email(
|
||||
account_id=account_id,
|
||||
user_id=user_id,
|
||||
tenant_id=tenant_id,
|
||||
message_id=message_id,
|
||||
subject=subject,
|
||||
sender_email=sender_email,
|
||||
sender_name=sender_name,
|
||||
recipients=recipients,
|
||||
cc=cc,
|
||||
body_preview=body_preview,
|
||||
body_text=body_text,
|
||||
body_html=body_html,
|
||||
has_attachments=len(attachments) > 0,
|
||||
attachments=attachments,
|
||||
headers=headers,
|
||||
folder=folder,
|
||||
date_sent=date_sent,
|
||||
date_received=date_received,
|
||||
)
|
||||
|
||||
return email_id
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to parse email: {e}")
|
||||
return None
|
||||
|
||||
def _decode_header(self, header_value: str) -> str:
|
||||
"""Decode email header value."""
|
||||
if not header_value:
|
||||
return ""
|
||||
try:
|
||||
decoded = decode_header(header_value)
|
||||
return str(make_header(decoded))
|
||||
except Exception:
|
||||
return str(header_value)
|
||||
|
||||
def _parse_body(self, msg) -> Tuple[Optional[str], Optional[str], List[Dict]]:
|
||||
"""
|
||||
Parse email body and attachments.
|
||||
|
||||
Returns:
|
||||
Tuple of (body_text, body_html, attachments)
|
||||
"""
|
||||
body_text = None
|
||||
body_html = None
|
||||
attachments = []
|
||||
|
||||
if msg.is_multipart():
|
||||
for part in msg.walk():
|
||||
content_type = part.get_content_type()
|
||||
content_disposition = str(part.get("Content-Disposition", ""))
|
||||
|
||||
# Skip multipart containers
|
||||
if content_type.startswith("multipart/"):
|
||||
continue
|
||||
|
||||
# Check for attachments
|
||||
if "attachment" in content_disposition:
|
||||
filename = part.get_filename()
|
||||
if filename:
|
||||
attachments.append({
|
||||
"filename": self._decode_header(filename),
|
||||
"content_type": content_type,
|
||||
"size": len(part.get_payload(decode=True) or b""),
|
||||
})
|
||||
continue
|
||||
|
||||
# Get body content
|
||||
try:
|
||||
payload = part.get_payload(decode=True)
|
||||
charset = part.get_content_charset() or "utf-8"
|
||||
|
||||
if payload:
|
||||
text = payload.decode(charset, errors="replace")
|
||||
|
||||
if content_type == "text/plain" and not body_text:
|
||||
body_text = text
|
||||
elif content_type == "text/html" and not body_html:
|
||||
body_html = text
|
||||
except Exception as e:
|
||||
logger.debug(f"Failed to decode body part: {e}")
|
||||
|
||||
else:
|
||||
# Single part message
|
||||
content_type = msg.get_content_type()
|
||||
try:
|
||||
payload = msg.get_payload(decode=True)
|
||||
charset = msg.get_content_charset() or "utf-8"
|
||||
|
||||
if payload:
|
||||
text = payload.decode(charset, errors="replace")
|
||||
|
||||
if content_type == "text/plain":
|
||||
body_text = text
|
||||
elif content_type == "text/html":
|
||||
body_html = text
|
||||
except Exception as e:
|
||||
logger.debug(f"Failed to decode body: {e}")
|
||||
|
||||
return body_text, body_html, attachments
|
||||
|
||||
async def send_email(
|
||||
self,
|
||||
account_id: str,
|
||||
user_id: str,
|
||||
request: EmailComposeRequest,
|
||||
) -> EmailSendResult:
|
||||
"""
|
||||
Send an email via SMTP.
|
||||
|
||||
Args:
|
||||
account_id: The account to send from
|
||||
user_id: The user ID
|
||||
request: The compose request with recipients and content
|
||||
|
||||
Returns:
|
||||
EmailSendResult with success status
|
||||
"""
|
||||
account = await get_email_account(account_id, user_id)
|
||||
if not account:
|
||||
return EmailSendResult(success=False, error="Account not found")
|
||||
|
||||
# Verify the account_id matches
|
||||
if request.account_id != account_id:
|
||||
return EmailSendResult(success=False, error="Account mismatch")
|
||||
|
||||
# Get credentials
|
||||
vault_path = account.get("vault_path", "")
|
||||
creds = await self._credentials_service.get_credentials(account_id, vault_path)
|
||||
if not creds:
|
||||
return EmailSendResult(success=False, error="Credentials not found")
|
||||
|
||||
try:
|
||||
# Create message
|
||||
if request.is_html:
|
||||
msg = MIMEMultipart("alternative")
|
||||
msg.attach(MIMEText(request.body, "html"))
|
||||
else:
|
||||
msg = MIMEText(request.body, "plain")
|
||||
|
||||
msg["Subject"] = request.subject
|
||||
msg["From"] = account["email"]
|
||||
msg["To"] = ", ".join(request.to)
|
||||
|
||||
if request.cc:
|
||||
msg["Cc"] = ", ".join(request.cc)
|
||||
|
||||
if request.reply_to_message_id:
|
||||
msg["In-Reply-To"] = request.reply_to_message_id
|
||||
msg["References"] = request.reply_to_message_id
|
||||
|
||||
# Send via SMTP
|
||||
if account["smtp_ssl"]:
|
||||
smtp = smtplib.SMTP_SSL(account["smtp_host"], account["smtp_port"])
|
||||
else:
|
||||
smtp = smtplib.SMTP(account["smtp_host"], account["smtp_port"])
|
||||
smtp.starttls()
|
||||
|
||||
smtp.login(creds.email, creds.password)
|
||||
|
||||
# All recipients
|
||||
all_recipients = list(request.to)
|
||||
if request.cc:
|
||||
all_recipients.extend(request.cc)
|
||||
if request.bcc:
|
||||
all_recipients.extend(request.bcc)
|
||||
|
||||
smtp.sendmail(account["email"], all_recipients, msg.as_string())
|
||||
smtp.quit()
|
||||
|
||||
return EmailSendResult(
|
||||
success=True,
|
||||
message_id=msg.get("Message-ID"),
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to send email: {e}")
|
||||
return EmailSendResult(success=False, error=str(e))
|
||||
|
||||
async def sync_all_accounts(self, user_id: str, tenant_id: Optional[str] = None) -> Dict[str, Any]:
|
||||
"""
|
||||
Sync all accounts for a user.
|
||||
|
||||
klausur-service/backend/mail/aggregator_imap.py (new file, 322 lines)
@@ -0,0 +1,322 @@
|
||||
"""
|
||||
Mail Aggregator IMAP — IMAP connection, sync, email parsing.
|
||||
|
||||
Extracted from aggregator.py for modularity.
|
||||
"""
|
||||
|
||||
import email
|
||||
import logging
|
||||
from typing import Optional, List, Dict, Any, Tuple
|
||||
from datetime import datetime, timezone
|
||||
from email.header import decode_header, make_header
|
||||
from email.utils import parsedate_to_datetime, parseaddr
|
||||
|
||||
from .mail_db import upsert_email, update_account_status, get_email_account
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class IMAPConnectionError(Exception):
|
||||
"""Raised when IMAP connection fails."""
|
||||
pass
|
||||
|
||||
|
||||
class IMAPMixin:
|
||||
"""IMAP-related methods for MailAggregator.
|
||||
|
||||
Provides connection testing, syncing, and email parsing.
|
||||
Must be mixed into a class that has ``_credentials_service``.
|
||||
"""
|
||||
|
||||
def _parse_folder_name(self, folder_response: bytes) -> str:
|
||||
"""Parse folder name from IMAP LIST response."""
|
||||
try:
|
||||
# Format: '(\\HasNoChildren) "/" "INBOX"'
|
||||
decoded = folder_response.decode("utf-8") if isinstance(folder_response, bytes) else folder_response
|
||||
parts = decoded.rsplit('" "', 1)
|
||||
if len(parts) == 2:
|
||||
return parts[1].rstrip('"')
|
||||
return decoded
|
||||
except Exception:
|
||||
return str(folder_response)
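# --- Illustrative behaviour (editor's sketch) ---------------------------------
# The LIST response format mentioned in the comment above parses as:
#
#     self._parse_folder_name(b'(\\HasNoChildren) "/" "INBOX"')    # -> 'INBOX'
#     self._parse_folder_name(b'(\\HasChildren) "/" "Sent Items"') # -> 'Sent Items'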
|
||||
|
||||
async def test_imap_connection(
|
||||
self,
|
||||
imap_host: str,
|
||||
imap_port: int,
|
||||
imap_ssl: bool,
|
||||
email_address: str,
|
||||
password: str,
|
||||
) -> Tuple[bool, Optional[str], Optional[List[str]]]:
|
||||
"""Test IMAP connection. Returns (success, error, folders)."""
|
||||
try:
|
||||
import imaplib
|
||||
|
||||
if imap_ssl:
|
||||
imap = imaplib.IMAP4_SSL(imap_host, imap_port)
|
||||
else:
|
||||
imap = imaplib.IMAP4(imap_host, imap_port)
|
||||
|
||||
imap.login(email_address, password)
|
||||
|
||||
# List folders
|
||||
folders_found = None
|
||||
status, folders = imap.list()
|
||||
if status == "OK":
|
||||
folders_found = [
|
||||
self._parse_folder_name(f) for f in folders if f
|
||||
]
|
||||
|
||||
imap.logout()
|
||||
return True, None, folders_found
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"IMAP test failed for {email_address}: {e}")
|
||||
return False, f"IMAP Error: {str(e)}", None
|
||||
|
||||
async def sync_account(
|
||||
self,
|
||||
account_id: str,
|
||||
user_id: str,
|
||||
max_emails: int = 100,
|
||||
folders: Optional[List[str]] = None,
|
||||
) -> Tuple[int, int]:
|
||||
"""
|
||||
Sync emails from an IMAP account.
|
||||
|
||||
Args:
|
||||
account_id: The account ID
|
||||
user_id: The user ID
|
||||
max_emails: Maximum emails to fetch
|
||||
folders: Specific folders to sync (default: INBOX)
|
||||
|
||||
Returns:
|
||||
Tuple of (new_emails, total_emails)
|
||||
"""
|
||||
import imaplib
|
||||
|
||||
account = await get_email_account(account_id, user_id)
|
||||
if not account:
|
||||
raise ValueError(f"Account not found: {account_id}")
|
||||
|
||||
# Get credentials
|
||||
vault_path = account.get("vault_path", "")
|
||||
creds = await self._credentials_service.get_credentials(account_id, vault_path)
|
||||
if not creds:
|
||||
await update_account_status(account_id, "error", "Credentials not found")
|
||||
raise IMAPConnectionError("Credentials not found")
|
||||
|
||||
new_count = 0
|
||||
total_count = 0
|
||||
|
||||
try:
|
||||
# Connect to IMAP
|
||||
if account["imap_ssl"]:
|
||||
imap = imaplib.IMAP4_SSL(account["imap_host"], account["imap_port"])
|
||||
else:
|
||||
imap = imaplib.IMAP4(account["imap_host"], account["imap_port"])
|
||||
|
||||
imap.login(creds.email, creds.password)
|
||||
|
||||
# Sync specified folders or just INBOX
|
||||
sync_folders = folders or ["INBOX"]
|
||||
|
||||
for folder in sync_folders:
|
||||
try:
|
||||
status, _ = imap.select(folder)
|
||||
if status != "OK":
|
||||
continue
|
||||
|
||||
# Search for recent emails
|
||||
status, messages = imap.search(None, "ALL")
|
||||
if status != "OK":
|
||||
continue
|
||||
|
||||
message_ids = messages[0].split()
|
||||
total_count += len(message_ids)
|
||||
|
||||
# Fetch most recent emails
|
||||
recent_ids = message_ids[-max_emails:] if len(message_ids) > max_emails else message_ids
|
||||
|
||||
for msg_id in recent_ids:
|
||||
try:
|
||||
email_data = await self._fetch_and_store_email(
|
||||
imap, msg_id, account_id, user_id, account["tenant_id"], folder
|
||||
)
|
||||
if email_data:
|
||||
new_count += 1
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to fetch email {msg_id}: {e}")
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to sync folder {folder}: {e}")
|
||||
|
||||
imap.logout()
|
||||
|
||||
# Update account status
|
||||
await update_account_status(
|
||||
account_id,
|
||||
"active",
|
||||
email_count=total_count,
|
||||
unread_count=new_count, # Will be recalculated
|
||||
)
|
||||
|
||||
return new_count, total_count
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Account sync failed: {e}")
|
||||
await update_account_status(account_id, "error", str(e))
|
||||
raise IMAPConnectionError(str(e))
|
||||
|
||||
async def _fetch_and_store_email(
|
||||
self,
|
||||
imap,
|
||||
msg_id: bytes,
|
||||
account_id: str,
|
||||
user_id: str,
|
||||
tenant_id: str,
|
||||
folder: str,
|
||||
) -> Optional[str]:
|
||||
"""Fetch a single email and store it in the database."""
|
||||
try:
|
||||
status, msg_data = imap.fetch(msg_id, "(RFC822)")
|
||||
if status != "OK" or not msg_data or not msg_data[0]:
|
||||
return None
|
||||
|
||||
raw_email = msg_data[0][1]
|
||||
msg = email.message_from_bytes(raw_email)
|
||||
|
||||
# Parse headers
|
||||
message_id = msg.get("Message-ID", str(msg_id))
|
||||
subject = self._decode_header(msg.get("Subject", ""))
|
||||
from_header = msg.get("From", "")
|
||||
sender_name, sender_email = parseaddr(from_header)
|
||||
sender_name = self._decode_header(sender_name)
|
||||
|
||||
# Parse recipients
|
||||
to_header = msg.get("To", "")
|
||||
recipients = [addr[1] for addr in email.utils.getaddresses([to_header])]
|
||||
|
||||
cc_header = msg.get("Cc", "")
|
||||
cc = [addr[1] for addr in email.utils.getaddresses([cc_header])]
|
||||
|
||||
# Parse dates
|
||||
date_str = msg.get("Date")
|
||||
try:
|
||||
date_sent = parsedate_to_datetime(date_str) if date_str else datetime.now(timezone.utc)
|
||||
except Exception:
|
||||
date_sent = datetime.now(timezone.utc)
|
||||
|
||||
date_received = datetime.now(timezone.utc)
|
||||
|
||||
# Parse body
|
||||
body_text, body_html, attachments = self._parse_body(msg)
|
||||
|
||||
# Create preview
|
||||
body_preview = (body_text[:200] + "...") if body_text and len(body_text) > 200 else body_text
|
||||
|
||||
# Get headers dict
|
||||
headers = {k: self._decode_header(v) for k, v in msg.items() if k not in ["Body"]}
|
||||
|
||||
# Store in database
|
||||
email_id = await upsert_email(
|
||||
account_id=account_id,
|
||||
user_id=user_id,
|
||||
tenant_id=tenant_id,
|
||||
message_id=message_id,
|
||||
subject=subject,
|
||||
sender_email=sender_email,
|
||||
sender_name=sender_name,
|
||||
recipients=recipients,
|
||||
cc=cc,
|
||||
body_preview=body_preview,
|
||||
body_text=body_text,
|
||||
body_html=body_html,
|
||||
has_attachments=len(attachments) > 0,
|
||||
attachments=attachments,
|
||||
headers=headers,
|
||||
folder=folder,
|
||||
date_sent=date_sent,
|
||||
date_received=date_received,
|
||||
)
|
||||
|
||||
return email_id
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to parse email: {e}")
|
||||
return None
|
||||
|
||||
def _decode_header(self, header_value: str) -> str:
|
||||
"""Decode email header value."""
|
||||
if not header_value:
|
||||
return ""
|
||||
try:
|
||||
decoded = decode_header(header_value)
|
||||
return str(make_header(decoded))
|
||||
except Exception:
|
||||
return str(header_value)
|
||||
|
||||
def _parse_body(self, msg) -> Tuple[Optional[str], Optional[str], List[Dict]]:
|
||||
"""
|
||||
Parse email body and attachments.
|
||||
|
||||
Returns:
|
||||
Tuple of (body_text, body_html, attachments)
|
||||
"""
|
||||
body_text = None
|
||||
body_html = None
|
||||
attachments = []
|
||||
|
||||
if msg.is_multipart():
|
||||
for part in msg.walk():
|
||||
content_type = part.get_content_type()
|
||||
content_disposition = str(part.get("Content-Disposition", ""))
|
||||
|
||||
# Skip multipart containers
|
||||
if content_type.startswith("multipart/"):
|
||||
continue
|
||||
|
||||
# Check for attachments
|
||||
if "attachment" in content_disposition:
|
||||
filename = part.get_filename()
|
||||
if filename:
|
||||
attachments.append({
|
||||
"filename": self._decode_header(filename),
|
||||
"content_type": content_type,
|
||||
"size": len(part.get_payload(decode=True) or b""),
|
||||
})
|
||||
continue
|
||||
|
||||
# Get body content
|
||||
try:
|
||||
payload = part.get_payload(decode=True)
|
||||
charset = part.get_content_charset() or "utf-8"
|
||||
|
||||
if payload:
|
||||
text = payload.decode(charset, errors="replace")
|
||||
|
||||
if content_type == "text/plain" and not body_text:
|
||||
body_text = text
|
||||
elif content_type == "text/html" and not body_html:
|
||||
body_html = text
|
||||
except Exception as e:
|
||||
logger.debug(f"Failed to decode body part: {e}")
|
||||
|
||||
else:
|
||||
# Single part message
|
||||
content_type = msg.get_content_type()
|
||||
try:
|
||||
payload = msg.get_payload(decode=True)
|
||||
charset = msg.get_content_charset() or "utf-8"
|
||||
|
||||
if payload:
|
||||
text = payload.decode(charset, errors="replace")
|
||||
|
||||
if content_type == "text/plain":
|
||||
body_text = text
|
||||
elif content_type == "text/html":
|
||||
body_html = text
|
||||
except Exception as e:
|
||||
logger.debug(f"Failed to decode body: {e}")
|
||||
|
||||
return body_text, body_html, attachments
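# --- Editor's sketch: how the split modules are expected to compose -----------
# The barrel aggregator.py mixes both mixins back into MailAggregator; the
# constructor shown here is an assumption (the mixins only require that
# self._credentials_service exists):
#
#     class MailAggregator(IMAPMixin, SMTPMixin):
#         def __init__(self):
#             self._credentials_service = get_credentials_service()
#
#     agg = MailAggregator()
#     new, total = await agg.sync_account(account_id, user_id, max_emails=50)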
|
||||
klausur-service/backend/mail/aggregator_smtp.py (new file, 131 lines)
@@ -0,0 +1,131 @@
|
||||
"""
|
||||
Mail Aggregator SMTP — email sending via SMTP.
|
||||
|
||||
Extracted from aggregator.py for modularity.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import smtplib
|
||||
from typing import Optional, List, Dict, Any
|
||||
from email.mime.text import MIMEText
|
||||
from email.mime.multipart import MIMEMultipart
|
||||
|
||||
from .mail_db import get_email_account
|
||||
from .models import EmailComposeRequest, EmailSendResult
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class SMTPConnectionError(Exception):
|
||||
"""Raised when SMTP connection fails."""
|
||||
pass
|
||||
|
||||
|
||||
class SMTPMixin:
|
||||
"""SMTP-related methods for MailAggregator.
|
||||
|
||||
Provides SMTP connection testing and email sending.
|
||||
Must be mixed into a class that has ``_credentials_service``.
|
||||
"""
|
||||
|
||||
async def test_smtp_connection(
|
||||
self,
|
||||
smtp_host: str,
|
||||
smtp_port: int,
|
||||
smtp_ssl: bool,
|
||||
email_address: str,
|
||||
password: str,
|
||||
) -> tuple:
|
||||
"""Test SMTP connection. Returns (success, error)."""
|
||||
try:
|
||||
if smtp_ssl:
|
||||
smtp = smtplib.SMTP_SSL(smtp_host, smtp_port)
|
||||
else:
|
||||
smtp = smtplib.SMTP(smtp_host, smtp_port)
|
||||
smtp.starttls()
|
||||
|
||||
smtp.login(email_address, password)
|
||||
smtp.quit()
|
||||
return True, None
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"SMTP test failed for {email_address}: {e}")
|
||||
return False, f"SMTP Error: {str(e)}"
|
||||
|
||||
async def send_email(
|
||||
self,
|
||||
account_id: str,
|
||||
user_id: str,
|
||||
request: EmailComposeRequest,
|
||||
) -> EmailSendResult:
|
||||
"""
|
||||
Send an email via SMTP.
|
||||
|
||||
Args:
|
||||
account_id: The account to send from
|
||||
user_id: The user ID
|
||||
request: The compose request with recipients and content
|
||||
|
||||
Returns:
|
||||
EmailSendResult with success status
|
||||
"""
|
||||
account = await get_email_account(account_id, user_id)
|
||||
if not account:
|
||||
return EmailSendResult(success=False, error="Account not found")
|
||||
|
||||
# Verify the account_id matches
|
||||
if request.account_id != account_id:
|
||||
return EmailSendResult(success=False, error="Account mismatch")
|
||||
|
||||
# Get credentials
|
||||
vault_path = account.get("vault_path", "")
|
||||
creds = await self._credentials_service.get_credentials(account_id, vault_path)
|
||||
if not creds:
|
||||
return EmailSendResult(success=False, error="Credentials not found")
|
||||
|
||||
try:
|
||||
# Create message
|
||||
if request.is_html:
|
||||
msg = MIMEMultipart("alternative")
|
||||
msg.attach(MIMEText(request.body, "html"))
|
||||
else:
|
||||
msg = MIMEText(request.body, "plain")
|
||||
|
||||
msg["Subject"] = request.subject
|
||||
msg["From"] = account["email"]
|
||||
msg["To"] = ", ".join(request.to)
|
||||
|
||||
if request.cc:
|
||||
msg["Cc"] = ", ".join(request.cc)
|
||||
|
||||
if request.reply_to_message_id:
|
||||
msg["In-Reply-To"] = request.reply_to_message_id
|
||||
msg["References"] = request.reply_to_message_id
|
||||
|
||||
# Send via SMTP
|
||||
if account["smtp_ssl"]:
|
||||
smtp = smtplib.SMTP_SSL(account["smtp_host"], account["smtp_port"])
|
||||
else:
|
||||
smtp = smtplib.SMTP(account["smtp_host"], account["smtp_port"])
|
||||
smtp.starttls()
|
||||
|
||||
smtp.login(creds.email, creds.password)
|
||||
|
||||
# All recipients
|
||||
all_recipients = list(request.to)
|
||||
if request.cc:
|
||||
all_recipients.extend(request.cc)
|
||||
if request.bcc:
|
||||
all_recipients.extend(request.bcc)
|
||||
|
||||
smtp.sendmail(account["email"], all_recipients, msg.as_string())
|
||||
smtp.quit()
|
||||
|
||||
return EmailSendResult(
|
||||
success=True,
|
||||
message_id=msg.get("Message-ID"),
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to send email: {e}")
|
||||
return EmailSendResult(success=False, error=str(e))
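# --- Illustrative call (editor's sketch) --------------------------------------
# The request fields used above (account_id, to, cc, bcc, subject, body,
# is_html, reply_to_message_id) suggest a call along these lines; the model
# constructor and all values are assumptions:
#
#     result = await agg.send_email(
#         account_id, user_id,
#         EmailComposeRequest(
#             account_id=account_id,
#             to=["eltern@example.org"],
#             subject="Elternabend",
#             body="<p>Hallo!</p>",
#             is_html=True,
#         ),
#     )
#     assert result.success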
|
||||
@@ -10,12 +10,11 @@ Unterstützt:
|
||||
"""
|
||||
|
||||
import os
|
||||
import re
|
||||
import zipfile
|
||||
import hashlib
|
||||
import json
|
||||
from pathlib import Path
|
||||
from typing import List, Dict, Optional, Tuple
|
||||
from typing import List, Dict, Optional
|
||||
from dataclasses import dataclass, asdict
|
||||
from datetime import datetime
|
||||
import asyncio
|
||||
@@ -23,6 +22,7 @@ import asyncio
|
||||
# Local imports
|
||||
from eh_pipeline import chunk_text, generate_embeddings, extract_text_from_pdf, get_vector_size, EMBEDDING_BACKEND
|
||||
from qdrant_service import QdrantService
|
||||
from nibis_parsers import parse_filename_old_format, parse_filename_new_format
|
||||
|
||||
# Configuration
|
||||
DOCS_BASE_PATH = Path("/Users/benjaminadmin/projekte/breakpilot-pwa/docs")
|
||||
@@ -87,15 +87,6 @@ SUBJECT_MAPPING = {
|
||||
"gespfl": "Gesundheit-Pflege",
|
||||
}
|
||||
|
||||
# Niveau-Mapping
|
||||
NIVEAU_MAPPING = {
|
||||
"ea": "eA", # erhöhtes Anforderungsniveau
|
||||
"ga": "gA", # grundlegendes Anforderungsniveau
|
||||
"neuga": "gA (neu einsetzend)",
|
||||
"neuea": "eA (neu einsetzend)",
|
||||
}
|
||||
|
||||
|
||||
def compute_file_hash(file_path: Path) -> str:
|
||||
"""Berechnet SHA-256 Hash einer Datei."""
|
||||
sha256 = hashlib.sha256()
|
||||
@@ -135,103 +126,6 @@ def extract_zip_files(base_path: Path) -> List[Path]:
|
||||
return extracted
|
||||
|
||||
|
||||
def parse_filename_old_format(filename: str, file_path: Path) -> Optional[Dict]:
|
||||
"""
|
||||
Parst alte Namenskonvention (2016, 2017):
|
||||
- {Jahr}{Fach}{Niveau}Lehrer/{Jahr}{Fach}{Niveau}A{Nr}L.pdf
|
||||
- Beispiel: 2016DeutschEALehrer/2016DeutschEAA1L.pdf
|
||||
"""
|
||||
# Pattern für Lehrer-Dateien
|
||||
pattern = r"(\d{4})([A-Za-zäöüÄÖÜ]+)(EA|GA|NeuGA|NeuEA)(?:Lehrer)?.*?(?:A(\d+)|Aufg(\d+))?L?\.pdf$"
|
||||
|
||||
match = re.search(pattern, filename, re.IGNORECASE)
|
||||
if not match:
|
||||
return None
|
||||
|
||||
year = int(match.group(1))
|
||||
subject_raw = match.group(2).lower()
|
||||
niveau = match.group(3).upper()
|
||||
task_num = match.group(4) or match.group(5)
|
||||
|
||||
# Prüfe ob es ein Lehrer-Dokument ist (EWH)
|
||||
is_ewh = "lehrer" in str(file_path).lower() or filename.endswith("L.pdf")
|
||||
|
||||
# Extrahiere Variante (Tech, Wirt, CAS, GTR, etc.)
|
||||
variant = None
|
||||
variant_patterns = ["Tech", "Wirt", "CAS", "GTR", "Pflicht", "BG", "mitExp", "ohneExp"]
|
||||
for v in variant_patterns:
|
||||
if v.lower() in str(file_path).lower():
|
||||
variant = v
|
||||
break
|
||||
|
||||
return {
|
||||
"year": year,
|
||||
"subject": subject_raw,
|
||||
"niveau": NIVEAU_MAPPING.get(niveau.lower(), niveau),
|
||||
"task_number": int(task_num) if task_num else None,
|
||||
"doc_type": "EWH" if is_ewh else "Aufgabe",
|
||||
"variant": variant,
|
||||
}
|
||||
|
||||
|
||||
def parse_filename_new_format(filename: str, file_path: Path) -> Optional[Dict]:
|
||||
"""
|
||||
Parst neue Namenskonvention (2024, 2025):
|
||||
- {Jahr}_{Fach}_{niveau}_{Nr}_EWH.pdf
|
||||
- Beispiel: 2025_Deutsch_eA_I_EWH.pdf
|
||||
"""
|
||||
# Pattern für neue Dateien
|
||||
pattern = r"(\d{4})_([A-Za-zäöüÄÖÜ]+)(?:BG)?_(eA|gA)(?:_([IVX\d]+))?(?:_(.+))?\.pdf$"
|
||||
|
||||
match = re.search(pattern, filename, re.IGNORECASE)
|
||||
if not match:
|
||||
return None
|
||||
|
||||
year = int(match.group(1))
|
||||
subject_raw = match.group(2).lower()
|
||||
niveau = match.group(3)
|
||||
task_id = match.group(4)
|
||||
suffix = match.group(5) or ""
|
||||
|
||||
# Task-Nummer aus römischen Zahlen
|
||||
task_num = None
|
||||
if task_id:
|
||||
roman_map = {"I": 1, "II": 2, "III": 3, "IV": 4, "V": 5}
|
||||
task_num = roman_map.get(task_id) or (int(task_id) if task_id.isdigit() else None)
|
||||
|
||||
# Dokumenttyp
|
||||
is_ewh = "EWH" in filename or "ewh" in filename.lower()
|
||||
|
||||
# Spezielle Dokumenttypen
|
||||
doc_type = "EWH" if is_ewh else "Aufgabe"
|
||||
if "Material" in suffix:
|
||||
doc_type = "Material"
|
||||
elif "GBU" in suffix:
|
||||
doc_type = "GBU"
|
||||
elif "Ergebnis" in suffix:
|
||||
doc_type = "Ergebnis"
|
||||
elif "Bewertungsbogen" in suffix:
|
||||
doc_type = "Bewertungsbogen"
|
||||
elif "HV" in suffix:
|
||||
doc_type = "Hörverstehen"
|
||||
elif "ME" in suffix:
|
||||
doc_type = "Mediation"
|
||||
|
||||
# BG Variante
|
||||
variant = "BG" if "BG" in filename else None
|
||||
if "mitExp" in str(file_path):
|
||||
variant = "mitExp"
|
||||
|
||||
return {
|
||||
"year": year,
|
||||
"subject": subject_raw,
|
||||
"niveau": NIVEAU_MAPPING.get(niveau.lower(), niveau),
|
||||
"task_number": task_num,
|
||||
"doc_type": doc_type,
|
||||
"variant": variant,
|
||||
}
|
||||
|
||||
|
||||
def discover_documents(base_path: Path, ewh_only: bool = True) -> List[NiBiSDocument]:
|
||||
"""
|
||||
Findet alle relevanten Dokumente in den za-download Verzeichnissen.
|
||||
|
||||
klausur-service/backend/nibis_parsers.py (new file, 113 lines)
@@ -0,0 +1,113 @@
|
||||
"""
|
||||
NiBiS Filename Parsers
|
||||
|
||||
Parses old and new naming conventions for NiBiS Abitur documents.
|
||||
"""
|
||||
|
||||
import re
|
||||
from typing import Dict, Optional
|
||||
|
||||
# Niveau-Mapping
|
||||
NIVEAU_MAPPING = {
|
||||
"ea": "eA", # erhoehtes Anforderungsniveau
|
||||
"ga": "gA", # grundlegendes Anforderungsniveau
|
||||
"neuga": "gA (neu einsetzend)",
|
||||
"neuea": "eA (neu einsetzend)",
|
||||
}
|
||||
|
||||
|
||||
def parse_filename_old_format(filename: str, file_path) -> Optional[Dict]:
|
||||
"""
|
||||
Parst alte Namenskonvention (2016, 2017):
|
||||
- {Jahr}{Fach}{Niveau}Lehrer/{Jahr}{Fach}{Niveau}A{Nr}L.pdf
|
||||
- Beispiel: 2016DeutschEALehrer/2016DeutschEAA1L.pdf
|
||||
"""
|
||||
# Pattern fuer Lehrer-Dateien
|
||||
pattern = r"(\d{4})([A-Za-z\u00e4\u00f6\u00fc\u00c4\u00d6\u00dc]+)(EA|GA|NeuGA|NeuEA)(?:Lehrer)?.*?(?:A(\d+)|Aufg(\d+))?L?\.pdf$"
|
||||
|
||||
match = re.search(pattern, filename, re.IGNORECASE)
|
||||
if not match:
|
||||
return None
|
||||
|
||||
year = int(match.group(1))
|
||||
subject_raw = match.group(2).lower()
|
||||
niveau = match.group(3).upper()
|
||||
task_num = match.group(4) or match.group(5)
|
||||
|
||||
# Pruefe ob es ein Lehrer-Dokument ist (EWH)
|
||||
is_ewh = "lehrer" in str(file_path).lower() or filename.endswith("L.pdf")
|
||||
|
||||
# Extrahiere Variante (Tech, Wirt, CAS, GTR, etc.)
|
||||
variant = None
|
||||
variant_patterns = ["Tech", "Wirt", "CAS", "GTR", "Pflicht", "BG", "mitExp", "ohneExp"]
|
||||
for v in variant_patterns:
|
||||
if v.lower() in str(file_path).lower():
|
||||
variant = v
|
||||
break
|
||||
|
||||
return {
|
||||
"year": year,
|
||||
"subject": subject_raw,
|
||||
"niveau": NIVEAU_MAPPING.get(niveau.lower(), niveau),
|
||||
"task_number": int(task_num) if task_num else None,
|
||||
"doc_type": "EWH" if is_ewh else "Aufgabe",
|
||||
"variant": variant,
|
||||
}
|
||||
|
||||
|
||||
def parse_filename_new_format(filename: str, file_path) -> Optional[Dict]:
|
||||
"""
|
||||
Parst neue Namenskonvention (2024, 2025):
|
||||
- {Jahr}_{Fach}_{niveau}_{Nr}_EWH.pdf
|
||||
- Beispiel: 2025_Deutsch_eA_I_EWH.pdf
|
||||
"""
|
||||
# Pattern fuer neue Dateien
|
||||
pattern = r"(\d{4})_([A-Za-z\u00e4\u00f6\u00fc\u00c4\u00d6\u00dc]+)(?:BG)?_(eA|gA)(?:_([IVX\d]+))?(?:_(.+))?\.pdf$"
|
||||
|
||||
match = re.search(pattern, filename, re.IGNORECASE)
|
||||
if not match:
|
||||
return None
|
||||
|
||||
year = int(match.group(1))
|
||||
subject_raw = match.group(2).lower()
|
||||
niveau = match.group(3)
|
||||
task_id = match.group(4)
|
||||
suffix = match.group(5) or ""
|
||||
|
||||
# Task-Nummer aus roemischen Zahlen
|
||||
task_num = None
|
||||
if task_id:
|
||||
roman_map = {"I": 1, "II": 2, "III": 3, "IV": 4, "V": 5}
|
||||
task_num = roman_map.get(task_id) or (int(task_id) if task_id.isdigit() else None)
|
||||
|
||||
# Dokumenttyp
|
||||
is_ewh = "EWH" in filename or "ewh" in filename.lower()
|
||||
|
||||
# Spezielle Dokumenttypen
|
||||
doc_type = "EWH" if is_ewh else "Aufgabe"
|
||||
if "Material" in suffix:
|
||||
doc_type = "Material"
|
||||
elif "GBU" in suffix:
|
||||
doc_type = "GBU"
|
||||
elif "Ergebnis" in suffix:
|
||||
doc_type = "Ergebnis"
|
||||
elif "Bewertungsbogen" in suffix:
|
||||
doc_type = "Bewertungsbogen"
|
||||
elif "HV" in suffix:
|
||||
doc_type = "Hoerverstehen"
|
||||
elif "ME" in suffix:
|
||||
doc_type = "Mediation"
|
||||
|
||||
# BG Variante
|
||||
variant = "BG" if "BG" in filename else None
|
||||
if "mitExp" in str(file_path):
|
||||
variant = "mitExp"
|
||||
|
||||
return {
|
||||
"year": year,
|
||||
"subject": subject_raw,
|
||||
"niveau": NIVEAU_MAPPING.get(niveau.lower(), niveau),
|
||||
"task_number": task_num,
|
||||
"doc_type": doc_type,
|
||||
"variant": variant,
|
||||
}
|
||||
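# Illustrative usage sketch for the two parsers above. The example filenames come
# from the docstrings; the za-download paths are hypothetical, and no particular
# output dict is asserted here.
if __name__ == "__main__":
    from pathlib import Path

    old_meta = parse_filename_old_format(
        "2016DeutschEAA1L.pdf",
        Path("za-download/2016DeutschEALehrer/2016DeutschEAA1L.pdf"),
    )
    new_meta = parse_filename_new_format(
        "2025_Deutsch_eA_I_EWH.pdf",
        Path("za-download/2025/2025_Deutsch_eA_I_EWH.pdf"),
    )
    print(old_meta)
    print(new_meta)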
@@ -1,557 +1,26 @@
|
||||
"""
|
||||
NRU Worksheet Generator - Generate vocabulary worksheets in NRU format.
|
||||
NRU Worksheet Generator — barrel re-export.
|
||||
|
||||
Format:
|
||||
- Page 1 (Vokabeln): 3-column table
|
||||
- Column 1: English vocabulary
|
||||
- Column 2: Empty (child writes German translation)
|
||||
- Column 3: Empty (child writes corrected English after parent review)
|
||||
|
||||
- Page 2 (Lernsätze): Full-width table
|
||||
- Row 1: German sentence (pre-filled)
|
||||
- Row 2-3: Empty lines (child writes English translation)
|
||||
All implementation split into:
|
||||
nru_worksheet_models — data classes, entry separation
|
||||
nru_worksheet_html — HTML generation
|
||||
nru_worksheet_pdf — PDF generation
|
||||
|
||||
Per scanned page, we generate 2 worksheet pages.
|
||||
"""
|
||||
|
||||
import io
|
||||
import logging
|
||||
from typing import List, Dict, Tuple
|
||||
from dataclasses import dataclass
|
||||
# Models
|
||||
from nru_worksheet_models import ( # noqa: F401
|
||||
VocabEntry,
|
||||
SentenceEntry,
|
||||
separate_vocab_and_sentences,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
# HTML generation
|
||||
from nru_worksheet_html import ( # noqa: F401
|
||||
generate_nru_html,
|
||||
generate_nru_worksheet_html,
|
||||
)
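# Illustrative sketch of the worksheet format described in the docstring above:
# entry dicts carry "english", "german" and "source_page"; single words land in the
# page-1 vocabulary table, sentences in the page-2 Lernsätze table. The entry values
# are invented.
def _example_nru_usage() -> str:
    entries = [
        {"english": "the gutter", "german": "der Bundsteg", "source_page": 1},
        {"english": "We went swimming yesterday.", "german": "Wir waren gestern schwimmen.", "source_page": 1},
    ]
    return generate_nru_worksheet_html(entries, title="Unit 3", show_solutions=False)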
|
||||
|
||||
|
||||
@dataclass
|
||||
class VocabEntry:
|
||||
english: str
|
||||
german: str
|
||||
source_page: int = 1
|
||||
|
||||
|
||||
@dataclass
|
||||
class SentenceEntry:
|
||||
german: str
|
||||
english: str # For solution sheet
|
||||
source_page: int = 1
|
||||
|
||||
|
||||
def separate_vocab_and_sentences(entries: List[Dict]) -> Tuple[List[VocabEntry], List[SentenceEntry]]:
|
||||
"""
|
||||
Separate vocabulary entries into single words/phrases and full sentences.
|
||||
|
||||
Sentences are identified by:
|
||||
- Ending with punctuation (. ! ?)
|
||||
- Being longer than 40 characters
|
||||
- Containing multiple words with capital letters mid-sentence
|
||||
"""
|
||||
vocab_list = []
|
||||
sentence_list = []
|
||||
|
||||
for entry in entries:
|
||||
english = entry.get("english", "").strip()
|
||||
german = entry.get("german", "").strip()
|
||||
source_page = entry.get("source_page", 1)
|
||||
|
||||
if not english or not german:
|
||||
continue
|
||||
|
||||
# Detect if this is a sentence
|
||||
is_sentence = (
|
||||
english.endswith('.') or
|
||||
english.endswith('!') or
|
||||
english.endswith('?') or
|
||||
len(english) > 50 or
|
||||
(len(english.split()) > 5 and any(w[0].isupper() for w in english.split()[1:] if w))
|
||||
)
|
||||
|
||||
if is_sentence:
|
||||
sentence_list.append(SentenceEntry(
|
||||
german=german,
|
||||
english=english,
|
||||
source_page=source_page
|
||||
))
|
||||
else:
|
||||
vocab_list.append(VocabEntry(
|
||||
english=english,
|
||||
german=german,
|
||||
source_page=source_page
|
||||
))
|
||||
|
||||
return vocab_list, sentence_list
|
||||
|
||||
|
||||
def generate_nru_html(
|
||||
vocab_list: List[VocabEntry],
|
||||
sentence_list: List[SentenceEntry],
|
||||
page_number: int,
|
||||
title: str = "Vokabeltest",
|
||||
show_solutions: bool = False,
|
||||
line_height_px: int = 28
|
||||
) -> str:
|
||||
"""
|
||||
Generate HTML for NRU-format worksheet.
|
||||
|
||||
Returns HTML for 2 pages:
|
||||
- Page 1: Vocabulary table (3 columns)
|
||||
- Page 2: Sentence practice (full width)
|
||||
"""
|
||||
|
||||
# Filter by page
|
||||
page_vocab = [v for v in vocab_list if v.source_page == page_number]
|
||||
page_sentences = [s for s in sentence_list if s.source_page == page_number]
|
||||
|
||||
html = f"""<!DOCTYPE html>
|
||||
<html>
|
||||
<head>
|
||||
<meta charset="UTF-8">
|
||||
<style>
|
||||
@page {{
|
||||
size: A4;
|
||||
margin: 1.5cm 2cm;
|
||||
}}
|
||||
* {{
|
||||
box-sizing: border-box;
|
||||
}}
|
||||
body {{
|
||||
font-family: Arial, Helvetica, sans-serif;
|
||||
font-size: 12pt;
|
||||
line-height: 1.4;
|
||||
margin: 0;
|
||||
padding: 0;
|
||||
}}
|
||||
.page {{
|
||||
page-break-after: always;
|
||||
min-height: 100%;
|
||||
}}
|
||||
.page:last-child {{
|
||||
page-break-after: avoid;
|
||||
}}
|
||||
h1 {{
|
||||
font-size: 16pt;
|
||||
margin: 0 0 8px 0;
|
||||
text-align: center;
|
||||
}}
|
||||
.header {{
|
||||
margin-bottom: 15px;
|
||||
}}
|
||||
.name-line {{
|
||||
font-size: 11pt;
|
||||
margin-bottom: 10px;
|
||||
}}
|
||||
|
||||
/* Vocabulary Table - 3 columns */
|
||||
.vocab-table {{
|
||||
width: 100%;
|
||||
border-collapse: collapse;
|
||||
table-layout: fixed;
|
||||
}}
|
||||
.vocab-table th {{
|
||||
background: #f0f0f0;
|
||||
border: 1px solid #333;
|
||||
padding: 6px 8px;
|
||||
font-weight: bold;
|
||||
font-size: 11pt;
|
||||
text-align: left;
|
||||
}}
|
||||
.vocab-table td {{
|
||||
border: 1px solid #333;
|
||||
padding: 4px 8px;
|
||||
height: {line_height_px}px;
|
||||
vertical-align: middle;
|
||||
}}
|
||||
.vocab-table .col-english {{ width: 35%; }}
|
||||
.vocab-table .col-german {{ width: 35%; }}
|
||||
.vocab-table .col-correction {{ width: 30%; }}
|
||||
.vocab-answer {{
|
||||
color: #0066cc;
|
||||
font-style: italic;
|
||||
}}
|
||||
|
||||
/* Sentence Table - full width */
|
||||
.sentence-table {{
|
||||
width: 100%;
|
||||
border-collapse: collapse;
|
||||
margin-bottom: 15px;
|
||||
}}
|
||||
.sentence-table td {{
|
||||
border: 1px solid #333;
|
||||
padding: 6px 10px;
|
||||
}}
|
||||
.sentence-header {{
|
||||
background: #f5f5f5;
|
||||
font-weight: normal;
|
||||
min-height: 30px;
|
||||
}}
|
||||
.sentence-line {{
|
||||
height: {line_height_px + 4}px;
|
||||
}}
|
||||
.sentence-answer {{
|
||||
color: #0066cc;
|
||||
font-style: italic;
|
||||
font-size: 11pt;
|
||||
}}
|
||||
|
||||
.page-info {{
|
||||
font-size: 9pt;
|
||||
color: #666;
|
||||
text-align: right;
|
||||
margin-top: 10px;
|
||||
}}
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
"""
|
||||
|
||||
# ========== PAGE 1: VOCABULARY TABLE ==========
|
||||
if page_vocab:
|
||||
html += f"""
|
||||
<div class="page">
|
||||
<div class="header">
|
||||
<h1>{title} - Vokabeln (Seite {page_number})</h1>
|
||||
<div class="name-line">Name: _________________________ Datum: _____________</div>
|
||||
</div>
|
||||
|
||||
<table class="vocab-table">
|
||||
<thead>
|
||||
<tr>
|
||||
<th class="col-english">Englisch</th>
|
||||
<th class="col-german">Deutsch</th>
|
||||
<th class="col-correction">Korrektur</th>
|
||||
</tr>
|
||||
</thead>
|
||||
<tbody>
|
||||
"""
|
||||
for v in page_vocab:
|
||||
if show_solutions:
|
||||
html += f"""
|
||||
<tr>
|
||||
<td>{v.english}</td>
|
||||
<td class="vocab-answer">{v.german}</td>
|
||||
<td></td>
|
||||
</tr>
|
||||
"""
|
||||
else:
|
||||
html += f"""
|
||||
<tr>
|
||||
<td>{v.english}</td>
|
||||
<td></td>
|
||||
<td></td>
|
||||
</tr>
|
||||
"""
|
||||
|
||||
html += """
|
||||
</tbody>
|
||||
</table>
|
||||
<div class="page-info">Vokabeln aus Unit</div>
|
||||
</div>
|
||||
"""
|
||||
|
||||
# ========== PAGE 2: SENTENCE PRACTICE ==========
|
||||
if page_sentences:
|
||||
html += f"""
|
||||
<div class="page">
|
||||
<div class="header">
|
||||
<h1>{title} - Lernsaetze (Seite {page_number})</h1>
|
||||
<div class="name-line">Name: _________________________ Datum: _____________</div>
|
||||
</div>
|
||||
"""
|
||||
for s in page_sentences:
|
||||
html += f"""
|
||||
<table class="sentence-table">
|
||||
<tr>
|
||||
<td class="sentence-header">{s.german}</td>
|
||||
</tr>
|
||||
"""
|
||||
if show_solutions:
|
||||
html += f"""
|
||||
<tr>
|
||||
<td class="sentence-line sentence-answer">{s.english}</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td class="sentence-line"></td>
|
||||
</tr>
|
||||
"""
|
||||
else:
|
||||
html += """
|
||||
<tr>
|
||||
<td class="sentence-line"></td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td class="sentence-line"></td>
|
||||
</tr>
|
||||
"""
|
||||
html += """
|
||||
</table>
|
||||
"""
|
||||
|
||||
html += """
|
||||
<div class="page-info">Lernsaetze aus Unit</div>
|
||||
</div>
|
||||
"""
|
||||
|
||||
html += """
|
||||
</body>
|
||||
</html>
|
||||
"""
|
||||
return html
|
||||
|
||||
|
||||
def generate_nru_worksheet_html(
|
||||
entries: List[Dict],
|
||||
title: str = "Vokabeltest",
|
||||
show_solutions: bool = False,
|
||||
specific_pages: List[int] = None
|
||||
) -> str:
|
||||
"""
|
||||
Generate complete NRU worksheet HTML for all pages.
|
||||
|
||||
Args:
|
||||
entries: List of vocabulary entries with source_page
|
||||
title: Worksheet title
|
||||
show_solutions: Whether to show answers
|
||||
specific_pages: List of specific page numbers to include (1-indexed)
|
||||
|
||||
Returns:
|
||||
Complete HTML document
|
||||
"""
|
||||
# Separate into vocab and sentences
|
||||
vocab_list, sentence_list = separate_vocab_and_sentences(entries)
|
||||
|
||||
# Get unique page numbers
|
||||
all_pages = set()
|
||||
for v in vocab_list:
|
||||
all_pages.add(v.source_page)
|
||||
for s in sentence_list:
|
||||
all_pages.add(s.source_page)
|
||||
|
||||
# Filter to specific pages if requested
|
||||
if specific_pages:
|
||||
all_pages = all_pages.intersection(set(specific_pages))
|
||||
|
||||
pages_sorted = sorted(all_pages)
|
||||
|
||||
logger.info(f"Generating NRU worksheet for pages {pages_sorted}")
|
||||
logger.info(f"Total vocab: {len(vocab_list)}, Total sentences: {len(sentence_list)}")
|
||||
|
||||
# Generate HTML for each page
|
||||
combined_html = """<!DOCTYPE html>
|
||||
<html>
|
||||
<head>
|
||||
<meta charset="UTF-8">
|
||||
<style>
|
||||
@page {
|
||||
size: A4;
|
||||
margin: 1.5cm 2cm;
|
||||
}
|
||||
* {
|
||||
box-sizing: border-box;
|
||||
}
|
||||
body {
|
||||
font-family: Arial, Helvetica, sans-serif;
|
||||
font-size: 12pt;
|
||||
line-height: 1.4;
|
||||
margin: 0;
|
||||
padding: 0;
|
||||
}
|
||||
.page {
|
||||
page-break-after: always;
|
||||
min-height: 100%;
|
||||
}
|
||||
.page:last-child {
|
||||
page-break-after: avoid;
|
||||
}
|
||||
h1 {
|
||||
font-size: 16pt;
|
||||
margin: 0 0 8px 0;
|
||||
text-align: center;
|
||||
}
|
||||
.header {
|
||||
margin-bottom: 15px;
|
||||
}
|
||||
.name-line {
|
||||
font-size: 11pt;
|
||||
margin-bottom: 10px;
|
||||
}
|
||||
|
||||
/* Vocabulary Table - 3 columns */
|
||||
.vocab-table {
|
||||
width: 100%;
|
||||
border-collapse: collapse;
|
||||
table-layout: fixed;
|
||||
}
|
||||
.vocab-table th {
|
||||
background: #f0f0f0;
|
||||
border: 1px solid #333;
|
||||
padding: 6px 8px;
|
||||
font-weight: bold;
|
||||
font-size: 11pt;
|
||||
text-align: left;
|
||||
}
|
||||
.vocab-table td {
|
||||
border: 1px solid #333;
|
||||
padding: 4px 8px;
|
||||
height: 28px;
|
||||
vertical-align: middle;
|
||||
}
|
||||
.vocab-table .col-english { width: 35%; }
|
||||
.vocab-table .col-german { width: 35%; }
|
||||
.vocab-table .col-correction { width: 30%; }
|
||||
.vocab-answer {
|
||||
color: #0066cc;
|
||||
font-style: italic;
|
||||
}
|
||||
|
||||
/* Sentence Table - full width */
|
||||
.sentence-table {
|
||||
width: 100%;
|
||||
border-collapse: collapse;
|
||||
margin-bottom: 15px;
|
||||
}
|
||||
.sentence-table td {
|
||||
border: 1px solid #333;
|
||||
padding: 6px 10px;
|
||||
}
|
||||
.sentence-header {
|
||||
background: #f5f5f5;
|
||||
font-weight: normal;
|
||||
min-height: 30px;
|
||||
}
|
||||
.sentence-line {
|
||||
height: 32px;
|
||||
}
|
||||
.sentence-answer {
|
||||
color: #0066cc;
|
||||
font-style: italic;
|
||||
font-size: 11pt;
|
||||
}
|
||||
|
||||
.page-info {
|
||||
font-size: 9pt;
|
||||
color: #666;
|
||||
text-align: right;
|
||||
margin-top: 10px;
|
||||
}
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
"""
|
||||
|
||||
for page_num in pages_sorted:
|
||||
page_vocab = [v for v in vocab_list if v.source_page == page_num]
|
||||
page_sentences = [s for s in sentence_list if s.source_page == page_num]
|
||||
|
||||
# PAGE 1: VOCABULARY TABLE
|
||||
if page_vocab:
|
||||
combined_html += f"""
|
||||
<div class="page">
|
||||
<div class="header">
|
||||
<h1>{title} - Vokabeln (Seite {page_num})</h1>
|
||||
<div class="name-line">Name: _________________________ Datum: _____________</div>
|
||||
</div>
|
||||
|
||||
<table class="vocab-table">
|
||||
<thead>
|
||||
<tr>
|
||||
<th class="col-english">Englisch</th>
|
||||
<th class="col-german">Deutsch</th>
|
||||
<th class="col-correction">Korrektur</th>
|
||||
</tr>
|
||||
</thead>
|
||||
<tbody>
|
||||
"""
|
||||
for v in page_vocab:
|
||||
if show_solutions:
|
||||
combined_html += f"""
|
||||
<tr>
|
||||
<td>{v.english}</td>
|
||||
<td class="vocab-answer">{v.german}</td>
|
||||
<td></td>
|
||||
</tr>
|
||||
"""
|
||||
else:
|
||||
combined_html += f"""
|
||||
<tr>
|
||||
<td>{v.english}</td>
|
||||
<td></td>
|
||||
<td></td>
|
||||
</tr>
|
||||
"""
|
||||
|
||||
combined_html += f"""
|
||||
</tbody>
|
||||
</table>
|
||||
<div class="page-info">{title} - Seite {page_num}</div>
|
||||
</div>
|
||||
"""
|
||||
|
||||
# PAGE 2: SENTENCE PRACTICE
|
||||
if page_sentences:
|
||||
combined_html += f"""
|
||||
<div class="page">
|
||||
<div class="header">
|
||||
<h1>{title} - Lernsaetze (Seite {page_num})</h1>
|
||||
<div class="name-line">Name: _________________________ Datum: _____________</div>
|
||||
</div>
|
||||
"""
|
||||
for s in page_sentences:
|
||||
combined_html += f"""
|
||||
<table class="sentence-table">
|
||||
<tr>
|
||||
<td class="sentence-header">{s.german}</td>
|
||||
</tr>
|
||||
"""
|
||||
if show_solutions:
|
||||
combined_html += f"""
|
||||
<tr>
|
||||
<td class="sentence-line sentence-answer">{s.english}</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td class="sentence-line"></td>
|
||||
</tr>
|
||||
"""
|
||||
else:
|
||||
combined_html += """
|
||||
<tr>
|
||||
<td class="sentence-line"></td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td class="sentence-line"></td>
|
||||
</tr>
|
||||
"""
|
||||
combined_html += """
|
||||
</table>
|
||||
"""
|
||||
|
||||
combined_html += f"""
|
||||
<div class="page-info">{title} - Seite {page_num}</div>
|
||||
</div>
|
||||
"""
|
||||
|
||||
combined_html += """
|
||||
</body>
|
||||
</html>
|
||||
"""
|
||||
return combined_html
|
||||
|
||||
|
||||
async def generate_nru_pdf(entries: List[Dict], title: str = "Vokabeltest", include_solutions: bool = True) -> Tuple[bytes, bytes]:
|
||||
"""
|
||||
Generate NRU worksheet PDFs.
|
||||
|
||||
Returns:
|
||||
Tuple of (worksheet_pdf_bytes, solution_pdf_bytes)
|
||||
"""
|
||||
from weasyprint import HTML
|
||||
|
||||
# Generate worksheet HTML
|
||||
worksheet_html = generate_nru_worksheet_html(entries, title, show_solutions=False)
|
||||
worksheet_pdf = HTML(string=worksheet_html).write_pdf()
|
||||
|
||||
# Generate solution HTML
|
||||
solution_pdf = None
|
||||
if include_solutions:
|
||||
solution_html = generate_nru_worksheet_html(entries, title, show_solutions=True)
|
||||
solution_pdf = HTML(string=solution_html).write_pdf()
|
||||
|
||||
return worksheet_pdf, solution_pdf
|
||||
# PDF generation
|
||||
from nru_worksheet_pdf import generate_nru_pdf # noqa: F401
|
||||
|
||||
klausur-service/backend/nru_worksheet_html.py (new file, 466 lines)
@@ -0,0 +1,466 @@
|
||||
"""
|
||||
NRU Worksheet HTML — HTML generation for vocabulary worksheets.
|
||||
|
||||
Extracted from nru_worksheet_generator.py for modularity.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import List, Dict
|
||||
|
||||
from nru_worksheet_models import VocabEntry, SentenceEntry, separate_vocab_and_sentences
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def generate_nru_html(
|
||||
vocab_list: List[VocabEntry],
|
||||
sentence_list: List[SentenceEntry],
|
||||
page_number: int,
|
||||
title: str = "Vokabeltest",
|
||||
show_solutions: bool = False,
|
||||
line_height_px: int = 28
|
||||
) -> str:
|
||||
"""
|
||||
Generate HTML for NRU-format worksheet.
|
||||
|
||||
Returns HTML for 2 pages:
|
||||
- Page 1: Vocabulary table (3 columns)
|
||||
- Page 2: Sentence practice (full width)
|
||||
"""
|
||||
|
||||
# Filter by page
|
||||
page_vocab = [v for v in vocab_list if v.source_page == page_number]
|
||||
page_sentences = [s for s in sentence_list if s.source_page == page_number]
|
||||
|
||||
html = f"""<!DOCTYPE html>
|
||||
<html>
|
||||
<head>
|
||||
<meta charset="UTF-8">
|
||||
<style>
|
||||
@page {{
|
||||
size: A4;
|
||||
margin: 1.5cm 2cm;
|
||||
}}
|
||||
* {{
|
||||
box-sizing: border-box;
|
||||
}}
|
||||
body {{
|
||||
font-family: Arial, Helvetica, sans-serif;
|
||||
font-size: 12pt;
|
||||
line-height: 1.4;
|
||||
margin: 0;
|
||||
padding: 0;
|
||||
}}
|
||||
.page {{
|
||||
page-break-after: always;
|
||||
min-height: 100%;
|
||||
}}
|
||||
.page:last-child {{
|
||||
page-break-after: avoid;
|
||||
}}
|
||||
h1 {{
|
||||
font-size: 16pt;
|
||||
margin: 0 0 8px 0;
|
||||
text-align: center;
|
||||
}}
|
||||
.header {{
|
||||
margin-bottom: 15px;
|
||||
}}
|
||||
.name-line {{
|
||||
font-size: 11pt;
|
||||
margin-bottom: 10px;
|
||||
}}
|
||||
|
||||
/* Vocabulary Table - 3 columns */
|
||||
.vocab-table {{
|
||||
width: 100%;
|
||||
border-collapse: collapse;
|
||||
table-layout: fixed;
|
||||
}}
|
||||
.vocab-table th {{
|
||||
background: #f0f0f0;
|
||||
border: 1px solid #333;
|
||||
padding: 6px 8px;
|
||||
font-weight: bold;
|
||||
font-size: 11pt;
|
||||
text-align: left;
|
||||
}}
|
||||
.vocab-table td {{
|
||||
border: 1px solid #333;
|
||||
padding: 4px 8px;
|
||||
height: {line_height_px}px;
|
||||
vertical-align: middle;
|
||||
}}
|
||||
.vocab-table .col-english {{ width: 35%; }}
|
||||
.vocab-table .col-german {{ width: 35%; }}
|
||||
.vocab-table .col-correction {{ width: 30%; }}
|
||||
.vocab-answer {{
|
||||
color: #0066cc;
|
||||
font-style: italic;
|
||||
}}
|
||||
|
||||
/* Sentence Table - full width */
|
||||
.sentence-table {{
|
||||
width: 100%;
|
||||
border-collapse: collapse;
|
||||
margin-bottom: 15px;
|
||||
}}
|
||||
.sentence-table td {{
|
||||
border: 1px solid #333;
|
||||
padding: 6px 10px;
|
||||
}}
|
||||
.sentence-header {{
|
||||
background: #f5f5f5;
|
||||
font-weight: normal;
|
||||
min-height: 30px;
|
||||
}}
|
||||
.sentence-line {{
|
||||
height: {line_height_px + 4}px;
|
||||
}}
|
||||
.sentence-answer {{
|
||||
color: #0066cc;
|
||||
font-style: italic;
|
||||
font-size: 11pt;
|
||||
}}
|
||||
|
||||
.page-info {{
|
||||
font-size: 9pt;
|
||||
color: #666;
|
||||
text-align: right;
|
||||
margin-top: 10px;
|
||||
}}
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
"""
|
||||
|
||||
# ========== PAGE 1: VOCABULARY TABLE ==========
|
||||
if page_vocab:
|
||||
html += f"""
|
||||
<div class="page">
|
||||
<div class="header">
|
||||
<h1>{title} - Vokabeln (Seite {page_number})</h1>
|
||||
<div class="name-line">Name: _________________________ Datum: _____________</div>
|
||||
</div>
|
||||
|
||||
<table class="vocab-table">
|
||||
<thead>
|
||||
<tr>
|
||||
<th class="col-english">Englisch</th>
|
||||
<th class="col-german">Deutsch</th>
|
||||
<th class="col-correction">Korrektur</th>
|
||||
</tr>
|
||||
</thead>
|
||||
<tbody>
|
||||
"""
|
||||
for v in page_vocab:
|
||||
if show_solutions:
|
||||
html += f"""
|
||||
<tr>
|
||||
<td>{v.english}</td>
|
||||
<td class="vocab-answer">{v.german}</td>
|
||||
<td></td>
|
||||
</tr>
|
||||
"""
|
||||
else:
|
||||
html += f"""
|
||||
<tr>
|
||||
<td>{v.english}</td>
|
||||
<td></td>
|
||||
<td></td>
|
||||
</tr>
|
||||
"""
|
||||
|
||||
html += """
|
||||
</tbody>
|
||||
</table>
|
||||
<div class="page-info">Vokabeln aus Unit</div>
|
||||
</div>
|
||||
"""
|
||||
|
||||
# ========== PAGE 2: SENTENCE PRACTICE ==========
|
||||
if page_sentences:
|
||||
html += f"""
|
||||
<div class="page">
|
||||
<div class="header">
|
||||
<h1>{title} - Lernsaetze (Seite {page_number})</h1>
|
||||
<div class="name-line">Name: _________________________ Datum: _____________</div>
|
||||
</div>
|
||||
"""
|
||||
for s in page_sentences:
|
||||
html += f"""
|
||||
<table class="sentence-table">
|
||||
<tr>
|
||||
<td class="sentence-header">{s.german}</td>
|
||||
</tr>
|
||||
"""
|
||||
if show_solutions:
|
||||
html += f"""
|
||||
<tr>
|
||||
<td class="sentence-line sentence-answer">{s.english}</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td class="sentence-line"></td>
|
||||
</tr>
|
||||
"""
|
||||
else:
|
||||
html += """
|
||||
<tr>
|
||||
<td class="sentence-line"></td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td class="sentence-line"></td>
|
||||
</tr>
|
||||
"""
|
||||
html += """
|
||||
</table>
|
||||
"""
|
||||
|
||||
html += """
|
||||
<div class="page-info">Lernsaetze aus Unit</div>
|
||||
</div>
|
||||
"""
|
||||
|
||||
html += """
|
||||
</body>
|
||||
</html>
|
||||
"""
|
||||
return html
|
||||
|
||||
|
||||
def generate_nru_worksheet_html(
|
||||
entries: List[Dict],
|
||||
title: str = "Vokabeltest",
|
||||
show_solutions: bool = False,
|
||||
specific_pages: List[int] = None
|
||||
) -> str:
|
||||
"""
|
||||
Generate complete NRU worksheet HTML for all pages.
|
||||
|
||||
Args:
|
||||
entries: List of vocabulary entries with source_page
|
||||
title: Worksheet title
|
||||
show_solutions: Whether to show answers
|
||||
specific_pages: List of specific page numbers to include (1-indexed)
|
||||
|
||||
Returns:
|
||||
Complete HTML document
|
||||
"""
|
||||
# Separate into vocab and sentences
|
||||
vocab_list, sentence_list = separate_vocab_and_sentences(entries)
|
||||
|
||||
# Get unique page numbers
|
||||
all_pages = set()
|
||||
for v in vocab_list:
|
||||
all_pages.add(v.source_page)
|
||||
for s in sentence_list:
|
||||
all_pages.add(s.source_page)
|
||||
|
||||
# Filter to specific pages if requested
|
||||
if specific_pages:
|
||||
all_pages = all_pages.intersection(set(specific_pages))
|
||||
|
||||
pages_sorted = sorted(all_pages)
|
||||
|
||||
logger.info(f"Generating NRU worksheet for pages {pages_sorted}")
|
||||
logger.info(f"Total vocab: {len(vocab_list)}, Total sentences: {len(sentence_list)}")
|
||||
|
||||
# Generate HTML for each page
|
||||
combined_html = """<!DOCTYPE html>
|
||||
<html>
|
||||
<head>
|
||||
<meta charset="UTF-8">
|
||||
<style>
|
||||
@page {
|
||||
size: A4;
|
||||
margin: 1.5cm 2cm;
|
||||
}
|
||||
* {
|
||||
box-sizing: border-box;
|
||||
}
|
||||
body {
|
||||
font-family: Arial, Helvetica, sans-serif;
|
||||
font-size: 12pt;
|
||||
line-height: 1.4;
|
||||
margin: 0;
|
||||
padding: 0;
|
||||
}
|
||||
.page {
|
||||
page-break-after: always;
|
||||
min-height: 100%;
|
||||
}
|
||||
.page:last-child {
|
||||
page-break-after: avoid;
|
||||
}
|
||||
h1 {
|
||||
font-size: 16pt;
|
||||
margin: 0 0 8px 0;
|
||||
text-align: center;
|
||||
}
|
||||
.header {
|
||||
margin-bottom: 15px;
|
||||
}
|
||||
.name-line {
|
||||
font-size: 11pt;
|
||||
margin-bottom: 10px;
|
||||
}
|
||||
|
||||
/* Vocabulary Table - 3 columns */
|
||||
.vocab-table {
|
||||
width: 100%;
|
||||
border-collapse: collapse;
|
||||
table-layout: fixed;
|
||||
}
|
||||
.vocab-table th {
|
||||
background: #f0f0f0;
|
||||
border: 1px solid #333;
|
||||
padding: 6px 8px;
|
||||
font-weight: bold;
|
||||
font-size: 11pt;
|
||||
text-align: left;
|
||||
}
|
||||
.vocab-table td {
|
||||
border: 1px solid #333;
|
||||
padding: 4px 8px;
|
||||
height: 28px;
|
||||
vertical-align: middle;
|
||||
}
|
||||
.vocab-table .col-english { width: 35%; }
|
||||
.vocab-table .col-german { width: 35%; }
|
||||
.vocab-table .col-correction { width: 30%; }
|
||||
.vocab-answer {
|
||||
color: #0066cc;
|
||||
font-style: italic;
|
||||
}
|
||||
|
||||
/* Sentence Table - full width */
|
||||
.sentence-table {
|
||||
width: 100%;
|
||||
border-collapse: collapse;
|
||||
margin-bottom: 15px;
|
||||
}
|
||||
.sentence-table td {
|
||||
border: 1px solid #333;
|
||||
padding: 6px 10px;
|
||||
}
|
||||
.sentence-header {
|
||||
background: #f5f5f5;
|
||||
font-weight: normal;
|
||||
min-height: 30px;
|
||||
}
|
||||
.sentence-line {
|
||||
height: 32px;
|
||||
}
|
||||
.sentence-answer {
|
||||
color: #0066cc;
|
||||
font-style: italic;
|
||||
font-size: 11pt;
|
||||
}
|
||||
|
||||
.page-info {
|
||||
font-size: 9pt;
|
||||
color: #666;
|
||||
text-align: right;
|
||||
margin-top: 10px;
|
||||
}
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
"""
|
||||
|
||||
for page_num in pages_sorted:
|
||||
page_vocab = [v for v in vocab_list if v.source_page == page_num]
|
||||
page_sentences = [s for s in sentence_list if s.source_page == page_num]
|
||||
|
||||
# PAGE 1: VOCABULARY TABLE
|
||||
if page_vocab:
|
||||
combined_html += f"""
|
||||
<div class="page">
|
||||
<div class="header">
|
||||
<h1>{title} - Vokabeln (Seite {page_num})</h1>
|
||||
<div class="name-line">Name: _________________________ Datum: _____________</div>
|
||||
</div>
|
||||
|
||||
<table class="vocab-table">
|
||||
<thead>
|
||||
<tr>
|
||||
<th class="col-english">Englisch</th>
|
||||
<th class="col-german">Deutsch</th>
|
||||
<th class="col-correction">Korrektur</th>
|
||||
</tr>
|
||||
</thead>
|
||||
<tbody>
|
||||
"""
|
||||
for v in page_vocab:
|
||||
if show_solutions:
|
||||
combined_html += f"""
|
||||
<tr>
|
||||
<td>{v.english}</td>
|
||||
<td class="vocab-answer">{v.german}</td>
|
||||
<td></td>
|
||||
</tr>
|
||||
"""
|
||||
else:
|
||||
combined_html += f"""
|
||||
<tr>
|
||||
<td>{v.english}</td>
|
||||
<td></td>
|
||||
<td></td>
|
||||
</tr>
|
||||
"""
|
||||
|
||||
combined_html += f"""
|
||||
</tbody>
|
||||
</table>
|
||||
<div class="page-info">{title} - Seite {page_num}</div>
|
||||
</div>
|
||||
"""
|
||||
|
||||
# PAGE 2: SENTENCE PRACTICE
|
||||
if page_sentences:
|
||||
combined_html += f"""
|
||||
<div class="page">
|
||||
<div class="header">
|
||||
<h1>{title} - Lernsaetze (Seite {page_num})</h1>
|
||||
<div class="name-line">Name: _________________________ Datum: _____________</div>
|
||||
</div>
|
||||
"""
|
||||
for s in page_sentences:
|
||||
combined_html += f"""
|
||||
<table class="sentence-table">
|
||||
<tr>
|
||||
<td class="sentence-header">{s.german}</td>
|
||||
</tr>
|
||||
"""
|
||||
if show_solutions:
|
||||
combined_html += f"""
|
||||
<tr>
|
||||
<td class="sentence-line sentence-answer">{s.english}</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td class="sentence-line"></td>
|
||||
</tr>
|
||||
"""
|
||||
else:
|
||||
combined_html += """
|
||||
<tr>
|
||||
<td class="sentence-line"></td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td class="sentence-line"></td>
|
||||
</tr>
|
||||
"""
|
||||
combined_html += """
|
||||
</table>
|
||||
"""
|
||||
|
||||
combined_html += f"""
|
||||
<div class="page-info">{title} - Seite {page_num}</div>
|
||||
</div>
|
||||
"""
|
||||
|
||||
combined_html += """
|
||||
</body>
|
||||
</html>
|
||||
"""
|
||||
return combined_html
|
||||
klausur-service/backend/nru_worksheet_models.py (new file, 70 lines)
@@ -0,0 +1,70 @@
"""
NRU Worksheet Models — data classes and entry separation logic.

Extracted from nru_worksheet_generator.py for modularity.
"""

import logging
from typing import List, Dict, Tuple
from dataclasses import dataclass

logger = logging.getLogger(__name__)


@dataclass
class VocabEntry:
    english: str
    german: str
    source_page: int = 1


@dataclass
class SentenceEntry:
    german: str
    english: str  # For solution sheet
    source_page: int = 1


def separate_vocab_and_sentences(entries: List[Dict]) -> Tuple[List[VocabEntry], List[SentenceEntry]]:
    """
    Separate vocabulary entries into single words/phrases and full sentences.

    Sentences are identified by:
    - Ending with punctuation (. ! ?)
    - Being longer than 50 characters
    - Containing multiple words with capital letters mid-sentence
    """
    vocab_list = []
    sentence_list = []

    for entry in entries:
        english = entry.get("english", "").strip()
        german = entry.get("german", "").strip()
        source_page = entry.get("source_page", 1)

        if not english or not german:
            continue

        # Detect if this is a sentence
        is_sentence = (
            english.endswith('.') or
            english.endswith('!') or
            english.endswith('?') or
            len(english) > 50 or
            (len(english.split()) > 5 and any(w[0].isupper() for w in english.split()[1:] if w))
        )

        if is_sentence:
            sentence_list.append(SentenceEntry(
                german=german,
                english=english,
                source_page=source_page
            ))
        else:
            vocab_list.append(VocabEntry(
                english=english,
                german=german,
                source_page=source_page
            ))

    return vocab_list, sentence_list
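# Illustrative sketch of the separation heuristic above; the two entries are invented.
if __name__ == "__main__":
    _vocab, _sentences = separate_vocab_and_sentences([
        {"english": "to announce", "german": "verkünden", "source_page": 2},
        {"english": "We announced the winner yesterday.", "german": "Wir haben gestern den Gewinner verkündet.", "source_page": 2},
    ])
    # "to announce" has no trailing punctuation and stays a VocabEntry;
    # the second entry ends with "." and becomes a SentenceEntry.
    print(len(_vocab), len(_sentences))  # -> 1 1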
klausur-service/backend/nru_worksheet_pdf.py (new file, 31 lines)
@@ -0,0 +1,31 @@
"""
NRU Worksheet PDF — PDF generation using weasyprint.

Extracted from nru_worksheet_generator.py for modularity.
"""

from typing import List, Dict, Tuple

from nru_worksheet_html import generate_nru_worksheet_html


async def generate_nru_pdf(entries: List[Dict], title: str = "Vokabeltest", include_solutions: bool = True) -> Tuple[bytes, bytes]:
    """
    Generate NRU worksheet PDFs.

    Returns:
        Tuple of (worksheet_pdf_bytes, solution_pdf_bytes)
    """
    from weasyprint import HTML

    # Generate worksheet HTML
    worksheet_html = generate_nru_worksheet_html(entries, title, show_solutions=False)
    worksheet_pdf = HTML(string=worksheet_html).write_pdf()

    # Generate solution HTML
    solution_pdf = None
    if include_solutions:
        solution_html = generate_nru_worksheet_html(entries, title, show_solutions=True)
        solution_pdf = HTML(string=solution_html).write_pdf()

    return worksheet_pdf, solution_pdf
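# Illustrative sketch: driving the async PDF generator from a script (weasyprint must
# be installed). The entry values and output file names are invented.
if __name__ == "__main__":
    import asyncio

    async def _demo() -> None:
        entries = [{"english": "the worksheet", "german": "das Arbeitsblatt", "source_page": 1}]
        worksheet_pdf, solution_pdf = await generate_nru_pdf(entries, title="Demo", include_solutions=True)
        with open("demo_worksheet.pdf", "wb") as f:
            f.write(worksheet_pdf)
        if solution_pdf:
            with open("demo_solution.pdf", "wb") as f:
                f.write(solution_pdf)

    asyncio.run(_demo())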
klausur-service/backend/ocr_pipeline_overlay_grid.py (new file, 333 lines)
@@ -0,0 +1,333 @@
|
||||
"""
|
||||
Overlay rendering for columns, rows, and words (grid-based overlays).
|
||||
|
||||
Extracted from ocr_pipeline_overlays.py for modularity.
|
||||
|
||||
Lizenz: Apache 2.0
|
||||
DATENSCHUTZ: Alle Verarbeitung erfolgt lokal.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Any, Dict, List
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
from fastapi import HTTPException
|
||||
from fastapi.responses import Response
|
||||
|
||||
from ocr_pipeline_common import _get_base_image_png
|
||||
from ocr_pipeline_session_store import get_session_db
|
||||
from ocr_pipeline_rows import _draw_box_exclusion_overlay
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def _get_columns_overlay(session_id: str) -> Response:
|
||||
"""Generate cropped (or dewarped) image with column borders drawn on it."""
|
||||
session = await get_session_db(session_id)
|
||||
if not session:
|
||||
raise HTTPException(status_code=404, detail=f"Session {session_id} not found")
|
||||
|
||||
column_result = session.get("column_result")
|
||||
if not column_result or not column_result.get("columns"):
|
||||
raise HTTPException(status_code=404, detail="No column data available")
|
||||
|
||||
# Load best available base image (cropped > dewarped > original)
|
||||
base_png = await _get_base_image_png(session_id)
|
||||
if not base_png:
|
||||
raise HTTPException(status_code=404, detail="No base image available")
|
||||
|
||||
arr = np.frombuffer(base_png, dtype=np.uint8)
|
||||
img = cv2.imdecode(arr, cv2.IMREAD_COLOR)
|
||||
if img is None:
|
||||
raise HTTPException(status_code=500, detail="Failed to decode image")
|
||||
|
||||
# Color map for region types (BGR)
|
||||
colors = {
|
||||
"column_en": (255, 180, 0), # Blue
|
||||
"column_de": (0, 200, 0), # Green
|
||||
"column_example": (0, 140, 255), # Orange
|
||||
"column_text": (200, 200, 0), # Cyan/Turquoise
|
||||
"page_ref": (200, 0, 200), # Purple
|
||||
"column_marker": (0, 0, 220), # Red
|
||||
"column_ignore": (180, 180, 180), # Light Gray
|
||||
"header": (128, 128, 128), # Gray
|
||||
"footer": (128, 128, 128), # Gray
|
||||
"margin_top": (100, 100, 100), # Dark Gray
|
||||
"margin_bottom": (100, 100, 100), # Dark Gray
|
||||
}
|
||||
|
||||
overlay = img.copy()
|
||||
for col in column_result["columns"]:
|
||||
x, y = col["x"], col["y"]
|
||||
w, h = col["width"], col["height"]
|
||||
color = colors.get(col.get("type", ""), (200, 200, 200))
|
||||
|
||||
# Semi-transparent fill
|
||||
cv2.rectangle(overlay, (x, y), (x + w, y + h), color, -1)
|
||||
|
||||
# Solid border
|
||||
cv2.rectangle(img, (x, y), (x + w, y + h), color, 3)
|
||||
|
||||
# Label with confidence
|
||||
label = col.get("type", "unknown").replace("column_", "").upper()
|
||||
conf = col.get("classification_confidence")
|
||||
if conf is not None and conf < 1.0:
|
||||
label = f"{label} {int(conf * 100)}%"
|
||||
cv2.putText(img, label, (x + 10, y + 30),
|
||||
cv2.FONT_HERSHEY_SIMPLEX, 0.8, color, 2)
|
||||
|
||||
# Blend overlay at 20% opacity
|
||||
cv2.addWeighted(overlay, 0.2, img, 0.8, 0, img)
|
||||
|
||||
# Draw detected box boundaries as dashed rectangles
|
||||
zones = column_result.get("zones") or []
|
||||
for zone in zones:
|
||||
if zone.get("zone_type") == "box" and zone.get("box"):
|
||||
box = zone["box"]
|
||||
bx, by = box["x"], box["y"]
|
||||
bw, bh = box["width"], box["height"]
|
||||
box_color = (0, 200, 255) # Yellow (BGR)
|
||||
# Draw dashed rectangle by drawing short line segments
|
||||
dash_len = 15
|
||||
for edge_x in range(bx, bx + bw, dash_len * 2):
|
||||
end_x = min(edge_x + dash_len, bx + bw)
|
||||
cv2.line(img, (edge_x, by), (end_x, by), box_color, 2)
|
||||
cv2.line(img, (edge_x, by + bh), (end_x, by + bh), box_color, 2)
|
||||
for edge_y in range(by, by + bh, dash_len * 2):
|
||||
end_y = min(edge_y + dash_len, by + bh)
|
||||
cv2.line(img, (bx, edge_y), (bx, end_y), box_color, 2)
|
||||
cv2.line(img, (bx + bw, edge_y), (bx + bw, end_y), box_color, 2)
|
||||
cv2.putText(img, "BOX", (bx + 10, by + bh - 10),
|
||||
cv2.FONT_HERSHEY_SIMPLEX, 0.7, box_color, 2)
|
||||
|
||||
# Red semi-transparent overlay for box zones
|
||||
_draw_box_exclusion_overlay(img, zones)
|
||||
|
||||
success, result_png = cv2.imencode(".png", img)
|
||||
if not success:
|
||||
raise HTTPException(status_code=500, detail="Failed to encode overlay image")
|
||||
|
||||
return Response(content=result_png.tobytes(), media_type="image/png")
|
||||
|
||||
|
||||
async def _get_rows_overlay(session_id: str) -> Response:
|
||||
"""Generate cropped (or dewarped) image with row bands drawn on it."""
|
||||
session = await get_session_db(session_id)
|
||||
if not session:
|
||||
raise HTTPException(status_code=404, detail=f"Session {session_id} not found")
|
||||
|
||||
row_result = session.get("row_result")
|
||||
if not row_result or not row_result.get("rows"):
|
||||
raise HTTPException(status_code=404, detail="No row data available")
|
||||
|
||||
# Load best available base image (cropped > dewarped > original)
|
||||
base_png = await _get_base_image_png(session_id)
|
||||
if not base_png:
|
||||
raise HTTPException(status_code=404, detail="No base image available")
|
||||
|
||||
arr = np.frombuffer(base_png, dtype=np.uint8)
|
||||
img = cv2.imdecode(arr, cv2.IMREAD_COLOR)
|
||||
if img is None:
|
||||
raise HTTPException(status_code=500, detail="Failed to decode image")
|
||||
|
||||
# Color map for row types (BGR)
|
||||
row_colors = {
|
||||
"content": (255, 180, 0), # Blue
|
||||
"header": (128, 128, 128), # Gray
|
||||
"footer": (128, 128, 128), # Gray
|
||||
"margin_top": (100, 100, 100), # Dark Gray
|
||||
"margin_bottom": (100, 100, 100), # Dark Gray
|
||||
}
|
||||
|
||||
overlay = img.copy()
|
||||
for row in row_result["rows"]:
|
||||
x, y = row["x"], row["y"]
|
||||
w, h = row["width"], row["height"]
|
||||
row_type = row.get("row_type", "content")
|
||||
color = row_colors.get(row_type, (200, 200, 200))
|
||||
|
||||
# Semi-transparent fill
|
||||
cv2.rectangle(overlay, (x, y), (x + w, y + h), color, -1)
|
||||
|
||||
# Solid border
|
||||
cv2.rectangle(img, (x, y), (x + w, y + h), color, 2)
|
||||
|
||||
# Label
|
||||
idx = row.get("index", 0)
|
||||
label = f"R{idx} {row_type.upper()}"
|
||||
wc = row.get("word_count", 0)
|
||||
if wc:
|
||||
label = f"{label} ({wc}w)"
|
||||
cv2.putText(img, label, (x + 5, y + 18),
|
||||
cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 1)
|
||||
|
||||
# Blend overlay at 15% opacity
|
||||
cv2.addWeighted(overlay, 0.15, img, 0.85, 0, img)
|
||||
|
||||
# Draw zone separator lines if zones exist
|
||||
column_result = session.get("column_result") or {}
|
||||
zones = column_result.get("zones") or []
|
||||
if zones:
|
||||
img_w_px = img.shape[1]
|
||||
zone_color = (0, 200, 255) # Yellow (BGR)
|
||||
dash_len = 20
|
||||
for zone in zones:
|
||||
if zone.get("zone_type") == "box":
|
||||
zy = zone["y"]
|
||||
zh = zone["height"]
|
||||
for line_y in [zy, zy + zh]:
|
||||
for sx in range(0, img_w_px, dash_len * 2):
|
||||
ex = min(sx + dash_len, img_w_px)
|
||||
cv2.line(img, (sx, line_y), (ex, line_y), zone_color, 2)
|
||||
|
||||
# Red semi-transparent overlay for box zones
|
||||
_draw_box_exclusion_overlay(img, zones)
|
||||
|
||||
success, result_png = cv2.imencode(".png", img)
|
||||
if not success:
|
||||
raise HTTPException(status_code=500, detail="Failed to encode overlay image")
|
||||
|
||||
return Response(content=result_png.tobytes(), media_type="image/png")
|
||||
|
||||
|
||||
async def _get_words_overlay(session_id: str) -> Response:
|
||||
"""Generate cropped (or dewarped) image with cell grid drawn on it."""
|
||||
session = await get_session_db(session_id)
|
||||
if not session:
|
||||
raise HTTPException(status_code=404, detail=f"Session {session_id} not found")
|
||||
|
||||
word_result = session.get("word_result")
|
||||
if not word_result:
|
||||
raise HTTPException(status_code=404, detail="No word data available")
|
||||
|
||||
# Support both new cell-based and legacy entry-based formats
|
||||
cells = word_result.get("cells")
|
||||
if not cells and not word_result.get("entries"):
|
||||
raise HTTPException(status_code=404, detail="No word data available")
|
||||
|
||||
# Load best available base image (cropped > dewarped > original)
|
||||
base_png = await _get_base_image_png(session_id)
|
||||
if not base_png:
|
||||
raise HTTPException(status_code=404, detail="No base image available")
|
||||
|
||||
arr = np.frombuffer(base_png, dtype=np.uint8)
|
||||
img = cv2.imdecode(arr, cv2.IMREAD_COLOR)
|
||||
if img is None:
|
||||
raise HTTPException(status_code=500, detail="Failed to decode image")
|
||||
|
||||
img_h, img_w = img.shape[:2]
|
||||
|
||||
overlay = img.copy()
|
||||
|
||||
if cells:
|
||||
# New cell-based overlay: color by column index
|
||||
col_palette = [
|
||||
(255, 180, 0), # Blue (BGR)
|
||||
(0, 200, 0), # Green
|
||||
(0, 140, 255), # Orange
|
||||
(200, 100, 200), # Purple
|
||||
(200, 200, 0), # Cyan
|
||||
(100, 200, 200), # Yellow-ish
|
||||
]
|
||||
|
||||
for cell in cells:
|
||||
bbox = cell.get("bbox_px", {})
|
||||
cx = bbox.get("x", 0)
|
||||
cy = bbox.get("y", 0)
|
||||
cw = bbox.get("w", 0)
|
||||
ch = bbox.get("h", 0)
|
||||
if cw <= 0 or ch <= 0:
|
||||
continue
|
||||
|
||||
col_idx = cell.get("col_index", 0)
|
||||
color = col_palette[col_idx % len(col_palette)]
|
||||
|
||||
# Cell rectangle border
|
||||
cv2.rectangle(img, (cx, cy), (cx + cw, cy + ch), color, 1)
|
||||
# Semi-transparent fill
|
||||
cv2.rectangle(overlay, (cx, cy), (cx + cw, cy + ch), color, -1)
|
||||
|
||||
# Cell-ID label (top-left corner)
|
||||
cell_id = cell.get("cell_id", "")
|
||||
cv2.putText(img, cell_id, (cx + 2, cy + 10),
|
||||
cv2.FONT_HERSHEY_SIMPLEX, 0.28, color, 1)
|
||||
|
||||
# Text label (bottom of cell)
|
||||
text = cell.get("text", "")
|
||||
if text:
|
||||
conf = cell.get("confidence", 0)
|
||||
if conf >= 70:
|
||||
text_color = (0, 180, 0)
|
||||
elif conf >= 50:
|
||||
text_color = (0, 180, 220)
|
||||
else:
|
||||
text_color = (0, 0, 220)
|
||||
|
||||
label = text.replace('\n', ' ')[:30]
|
||||
cv2.putText(img, label, (cx + 3, cy + ch - 4),
|
||||
cv2.FONT_HERSHEY_SIMPLEX, 0.35, text_color, 1)
|
||||
else:
|
||||
# Legacy fallback: entry-based overlay (for old sessions)
|
||||
column_result = session.get("column_result")
|
||||
row_result = session.get("row_result")
|
||||
col_colors = {
|
||||
"column_en": (255, 180, 0),
|
||||
"column_de": (0, 200, 0),
|
||||
"column_example": (0, 140, 255),
|
||||
}
|
||||
|
||||
columns = []
|
||||
if column_result and column_result.get("columns"):
|
||||
columns = [c for c in column_result["columns"]
|
||||
if c.get("type", "").startswith("column_")]
|
||||
|
||||
content_rows_data = []
|
||||
if row_result and row_result.get("rows"):
|
||||
content_rows_data = [r for r in row_result["rows"]
|
||||
if r.get("row_type") == "content"]
|
||||
|
||||
for col in columns:
|
||||
col_type = col.get("type", "")
|
||||
color = col_colors.get(col_type, (200, 200, 200))
|
||||
cx, cw = col["x"], col["width"]
|
||||
for row in content_rows_data:
|
||||
ry, rh = row["y"], row["height"]
|
||||
cv2.rectangle(img, (cx, ry), (cx + cw, ry + rh), color, 1)
|
||||
cv2.rectangle(overlay, (cx, ry), (cx + cw, ry + rh), color, -1)
|
||||
|
||||
entries = word_result["entries"]
|
||||
entry_by_row: Dict[int, Dict] = {}
|
||||
for entry in entries:
|
||||
entry_by_row[entry.get("row_index", -1)] = entry
|
||||
|
||||
for row_idx, row in enumerate(content_rows_data):
|
||||
entry = entry_by_row.get(row_idx)
|
||||
if not entry:
|
||||
continue
|
||||
conf = entry.get("confidence", 0)
|
||||
text_color = (0, 180, 0) if conf >= 70 else (0, 180, 220) if conf >= 50 else (0, 0, 220)
|
||||
ry, rh = row["y"], row["height"]
|
||||
for col in columns:
|
||||
col_type = col.get("type", "")
|
||||
cx, cw = col["x"], col["width"]
|
||||
field = {"column_en": "english", "column_de": "german", "column_example": "example"}.get(col_type, "")
|
||||
text = entry.get(field, "") if field else ""
|
||||
if text:
|
||||
label = text.replace('\n', ' ')[:30]
|
||||
cv2.putText(img, label, (cx + 3, ry + rh - 4),
|
||||
cv2.FONT_HERSHEY_SIMPLEX, 0.35, text_color, 1)
|
||||
|
||||
# Blend overlay at 10% opacity
|
||||
cv2.addWeighted(overlay, 0.1, img, 0.9, 0, img)
|
||||
|
||||
# Red semi-transparent overlay for box zones
|
||||
column_result = session.get("column_result") or {}
|
||||
zones = column_result.get("zones") or []
|
||||
_draw_box_exclusion_overlay(img, zones)
|
||||
|
||||
success, result_png = cv2.imencode(".png", img)
|
||||
if not success:
|
||||
raise HTTPException(status_code=500, detail="Failed to encode overlay image")
|
||||
|
||||
return Response(content=result_png.tobytes(), media_type="image/png")
|
||||
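# Illustrative sketch: one way an app might expose the grid overlay helpers above
# over HTTP. The route paths and the standalone FastAPI() instance are assumptions;
# in the real service these functions are presumably called from existing routers.
from fastapi import FastAPI

_example_app = FastAPI()

@_example_app.get("/ocr/sessions/{session_id}/overlays/columns")
async def _columns_overlay_route(session_id: str):
    return await _get_columns_overlay(session_id)

@_example_app.get("/ocr/sessions/{session_id}/overlays/rows")
async def _rows_overlay_route(session_id: str):
    return await _get_rows_overlay(session_id)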
klausur-service/backend/ocr_pipeline_overlay_structure.py (new file, 205 lines)
@@ -0,0 +1,205 @@
|
||||
"""
|
||||
Overlay rendering for structure detection (boxes, zones, colors, graphics).
|
||||
|
||||
Extracted from ocr_pipeline_overlays.py for modularity.
|
||||
|
||||
Lizenz: Apache 2.0
|
||||
DATENSCHUTZ: Alle Verarbeitung erfolgt lokal.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Any, Dict, List
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
from fastapi import HTTPException
|
||||
from fastapi.responses import Response
|
||||
|
||||
from ocr_pipeline_common import _get_base_image_png
|
||||
from ocr_pipeline_session_store import get_session_db
|
||||
from cv_color_detect import _COLOR_HEX, _COLOR_RANGES
|
||||
from cv_box_detect import detect_boxes, split_page_into_zones
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def _get_structure_overlay(session_id: str) -> Response:
|
||||
"""Generate overlay image showing detected boxes, zones, and color regions."""
|
||||
base_png = await _get_base_image_png(session_id)
|
||||
if not base_png:
|
||||
raise HTTPException(status_code=404, detail="No base image available")
|
||||
|
||||
arr = np.frombuffer(base_png, dtype=np.uint8)
|
||||
img = cv2.imdecode(arr, cv2.IMREAD_COLOR)
|
||||
if img is None:
|
||||
raise HTTPException(status_code=500, detail="Failed to decode image")
|
||||
|
||||
h, w = img.shape[:2]
|
||||
|
||||
# Get structure result (run detection if not cached)
|
||||
session = await get_session_db(session_id)
|
||||
structure = (session or {}).get("structure_result")
|
||||
|
||||
if not structure:
|
||||
# Run detection on-the-fly
|
||||
margin = int(min(w, h) * 0.03)
|
||||
content_x, content_y = margin, margin
|
||||
content_w_px = w - 2 * margin
|
||||
content_h_px = h - 2 * margin
|
||||
boxes = detect_boxes(img, content_x, content_w_px, content_y, content_h_px)
|
||||
zones = split_page_into_zones(content_x, content_y, content_w_px, content_h_px, boxes)
|
||||
structure = {
|
||||
"boxes": [
|
||||
{"x": b.x, "y": b.y, "w": b.width, "h": b.height,
|
||||
"confidence": b.confidence, "border_thickness": b.border_thickness}
|
||||
for b in boxes
|
||||
],
|
||||
"zones": [
|
||||
{"index": z.index, "zone_type": z.zone_type,
|
||||
"y": z.y, "h": z.height, "x": z.x, "w": z.width}
|
||||
for z in zones
|
||||
],
|
||||
}
|
||||
|
||||
overlay = img.copy()
|
||||
|
||||
# --- Draw zone boundaries ---
|
||||
zone_colors = {
|
||||
"content": (200, 200, 200), # light gray
|
||||
"box": (255, 180, 0), # blue-ish (BGR)
|
||||
}
|
||||
for zone in structure.get("zones", []):
|
||||
zx = zone["x"]
|
||||
zy = zone["y"]
|
||||
zw = zone["w"]
|
||||
zh = zone["h"]
|
||||
color = zone_colors.get(zone["zone_type"], (200, 200, 200))
|
||||
|
||||
# Draw zone boundary as dashed line
|
||||
dash_len = 12
|
||||
for edge_x in range(zx, zx + zw, dash_len * 2):
|
||||
end_x = min(edge_x + dash_len, zx + zw)
|
||||
cv2.line(img, (edge_x, zy), (end_x, zy), color, 1)
|
||||
cv2.line(img, (edge_x, zy + zh), (end_x, zy + zh), color, 1)
|
||||
|
||||
# Zone label
|
||||
zone_label = f"Zone {zone['index']} ({zone['zone_type']})"
|
||||
cv2.putText(img, zone_label, (zx + 5, zy + 15),
|
||||
cv2.FONT_HERSHEY_SIMPLEX, 0.45, color, 1)
|
||||
|
||||
# --- Draw detected boxes ---
|
||||
# Color map for box backgrounds (BGR)
|
||||
bg_hex_to_bgr = {
|
||||
"#dc2626": (38, 38, 220), # red
|
||||
"#2563eb": (235, 99, 37), # blue
|
||||
"#16a34a": (74, 163, 22), # green
|
||||
"#ea580c": (12, 88, 234), # orange
|
||||
"#9333ea": (234, 51, 147), # purple
|
||||
"#ca8a04": (4, 138, 202), # yellow
|
||||
"#6b7280": (128, 114, 107), # gray
|
||||
}
|
||||
|
||||
for box_data in structure.get("boxes", []):
|
||||
bx = box_data["x"]
|
||||
by = box_data["y"]
|
||||
bw = box_data["w"]
|
||||
bh = box_data["h"]
|
||||
conf = box_data.get("confidence", 0)
|
||||
thickness = box_data.get("border_thickness", 0)
|
||||
bg_hex = box_data.get("bg_color_hex", "#6b7280")
|
||||
bg_name = box_data.get("bg_color_name", "")
|
||||
|
||||
# Box fill color
|
||||
fill_bgr = bg_hex_to_bgr.get(bg_hex, (128, 114, 107))
|
||||
|
||||
# Semi-transparent fill
|
||||
cv2.rectangle(overlay, (bx, by), (bx + bw, by + bh), fill_bgr, -1)
|
||||
|
||||
# Solid border
|
||||
border_color = fill_bgr
|
||||
cv2.rectangle(img, (bx, by), (bx + bw, by + bh), border_color, 3)
|
||||
|
||||
# Label
|
||||
label = f"BOX"
|
||||
if bg_name and bg_name not in ("unknown", "white"):
|
||||
label += f" ({bg_name})"
|
||||
if thickness > 0:
|
||||
label += f" border={thickness}px"
|
||||
label += f" {int(conf * 100)}%"
|
||||
cv2.putText(img, label, (bx + 8, by + 22),
|
||||
cv2.FONT_HERSHEY_SIMPLEX, 0.55, (255, 255, 255), 2)
|
||||
cv2.putText(img, label, (bx + 8, by + 22),
|
||||
cv2.FONT_HERSHEY_SIMPLEX, 0.55, border_color, 1)
|
||||
|
||||
# Blend overlay at 15% opacity
|
||||
cv2.addWeighted(overlay, 0.15, img, 0.85, 0, img)
|
||||
|
||||
# --- Draw color regions (HSV masks) ---
|
||||
hsv = cv2.cvtColor(
|
||||
cv2.imdecode(np.frombuffer(base_png, dtype=np.uint8), cv2.IMREAD_COLOR),
|
||||
cv2.COLOR_BGR2HSV,
|
||||
)
|
||||
color_bgr_map = {
|
||||
"red": (0, 0, 255),
|
||||
"orange": (0, 140, 255),
|
||||
"yellow": (0, 200, 255),
|
||||
"green": (0, 200, 0),
|
||||
"blue": (255, 150, 0),
|
||||
"purple": (200, 0, 200),
|
||||
}
|
||||
for color_name, ranges in _COLOR_RANGES.items():
|
||||
mask = np.zeros((h, w), dtype=np.uint8)
|
||||
for lower, upper in ranges:
|
||||
mask = cv2.bitwise_or(mask, cv2.inRange(hsv, lower, upper))
|
||||
# Only draw if there are significant colored pixels
|
||||
if np.sum(mask > 0) < 100:
|
||||
continue
|
||||
# Draw colored contours
|
||||
contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
|
||||
draw_color = color_bgr_map.get(color_name, (200, 200, 200))
|
||||
for cnt in contours:
|
||||
area = cv2.contourArea(cnt)
|
||||
if area < 20:
|
||||
continue
|
||||
cv2.drawContours(img, [cnt], -1, draw_color, 2)
|
||||
|
||||
# --- Draw graphic elements ---
|
||||
graphics_data = structure.get("graphics", [])
|
||||
shape_icons = {
|
||||
"image": "IMAGE",
|
||||
"illustration": "ILLUST",
|
||||
}
|
||||
for gfx in graphics_data:
|
||||
gx, gy = gfx["x"], gfx["y"]
|
||||
gw, gh = gfx["w"], gfx["h"]
|
||||
shape = gfx.get("shape", "icon")
|
||||
color_hex = gfx.get("color_hex", "#6b7280")
|
||||
conf = gfx.get("confidence", 0)
|
||||
|
||||
# Pick draw color based on element color (BGR)
|
||||
gfx_bgr = bg_hex_to_bgr.get(color_hex, (128, 114, 107))
|
||||
|
||||
# Draw bounding box (dashed style via short segments)
|
||||
dash = 6
|
||||
for seg_x in range(gx, gx + gw, dash * 2):
|
||||
end_x = min(seg_x + dash, gx + gw)
|
||||
cv2.line(img, (seg_x, gy), (end_x, gy), gfx_bgr, 2)
|
||||
cv2.line(img, (seg_x, gy + gh), (end_x, gy + gh), gfx_bgr, 2)
|
||||
for seg_y in range(gy, gy + gh, dash * 2):
|
||||
end_y = min(seg_y + dash, gy + gh)
|
||||
cv2.line(img, (gx, seg_y), (gx, end_y), gfx_bgr, 2)
|
||||
cv2.line(img, (gx + gw, seg_y), (gx + gw, end_y), gfx_bgr, 2)
|
||||
|
||||
# Label
|
||||
icon = shape_icons.get(shape, shape.upper()[:5])
|
||||
label = f"{icon} {int(conf * 100)}%"
|
||||
# White background for readability
|
||||
(tw, th), _ = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, 0.4, 1)
|
||||
lx = gx + 2
|
||||
ly = max(gy - 4, th + 4)
|
||||
cv2.rectangle(img, (lx - 1, ly - th - 2), (lx + tw + 2, ly + 3), (255, 255, 255), -1)
|
||||
cv2.putText(img, label, (lx, ly), cv2.FONT_HERSHEY_SIMPLEX, 0.4, gfx_bgr, 1)
|
||||
|
||||
# Encode result
|
||||
_, png_buf = cv2.imencode(".png", img)
|
||||
return Response(content=png_buf.tobytes(), media_type="image/png")
|
||||
@@ -1,34 +1,23 @@
|
||||
"""
|
||||
Overlay image rendering for OCR pipeline.
|
||||
Overlay image rendering for OCR pipeline — barrel re-export.
|
||||
|
||||
Generates visual overlays for structure, columns, rows, and words
|
||||
detection results.
|
||||
All implementation split into:
|
||||
ocr_pipeline_overlay_structure — structure overlay (boxes, zones, colors, graphics)
|
||||
ocr_pipeline_overlay_grid — columns, rows, words overlays
|
||||
|
||||
Lizenz: Apache 2.0
|
||||
DATENSCHUTZ: Alle Verarbeitung erfolgt lokal.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from dataclasses import asdict
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
from fastapi import HTTPException
|
||||
from fastapi.responses import Response
|
||||
|
||||
from ocr_pipeline_common import (
|
||||
_cache,
|
||||
_get_base_image_png,
|
||||
_load_session_to_cache,
|
||||
_get_cached,
|
||||
from ocr_pipeline_overlay_structure import _get_structure_overlay # noqa: F401
|
||||
from ocr_pipeline_overlay_grid import ( # noqa: F401
|
||||
_get_columns_overlay,
|
||||
_get_rows_overlay,
|
||||
_get_words_overlay,
|
||||
)
|
||||
from ocr_pipeline_session_store import get_session_db, get_session_image
|
||||
from cv_color_detect import _COLOR_HEX, _COLOR_RANGES
|
||||
from cv_box_detect import detect_boxes, split_page_into_zones
|
||||
from ocr_pipeline_rows import _draw_box_exclusion_overlay
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def render_overlay(overlay_type: str, session_id: str) -> Response:
|
||||
@@ -43,505 +32,3 @@ async def render_overlay(overlay_type: str, session_id: str) -> Response:
|
||||
return await _get_words_overlay(session_id)
|
||||
else:
|
||||
raise HTTPException(status_code=400, detail=f"Unknown overlay type: {overlay_type}")
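# Illustrative usage of the dispatcher (route wiring assumed, not shown in this hunk):
#
#     png_response = await render_overlay("structure", session_id)
#     png_response = await render_overlay("words", session_id)
#
# "words" is visible in the branch above; the remaining overlay types ("structure",
# "columns", "rows") are inferred from the overlay imports at the top of the module.
# Anything else raises the 400 above.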
|
||||
|
||||
|
||||
async def _get_structure_overlay(session_id: str) -> Response:
|
||||
"""Generate overlay image showing detected boxes, zones, and color regions."""
|
||||
base_png = await _get_base_image_png(session_id)
|
||||
if not base_png:
|
||||
raise HTTPException(status_code=404, detail="No base image available")
|
||||
|
||||
arr = np.frombuffer(base_png, dtype=np.uint8)
|
||||
img = cv2.imdecode(arr, cv2.IMREAD_COLOR)
|
||||
if img is None:
|
||||
raise HTTPException(status_code=500, detail="Failed to decode image")
|
||||
|
||||
h, w = img.shape[:2]
|
||||
|
||||
# Get structure result (run detection if not cached)
|
||||
session = await get_session_db(session_id)
|
||||
structure = (session or {}).get("structure_result")
|
||||
|
||||
if not structure:
|
||||
# Run detection on-the-fly
|
||||
margin = int(min(w, h) * 0.03)
|
||||
content_x, content_y = margin, margin
|
||||
content_w_px = w - 2 * margin
|
||||
content_h_px = h - 2 * margin
|
||||
boxes = detect_boxes(img, content_x, content_w_px, content_y, content_h_px)
|
||||
zones = split_page_into_zones(content_x, content_y, content_w_px, content_h_px, boxes)
|
||||
structure = {
|
||||
"boxes": [
|
||||
{"x": b.x, "y": b.y, "w": b.width, "h": b.height,
|
||||
"confidence": b.confidence, "border_thickness": b.border_thickness}
|
||||
for b in boxes
|
||||
],
|
||||
"zones": [
|
||||
{"index": z.index, "zone_type": z.zone_type,
|
||||
"y": z.y, "h": z.height, "x": z.x, "w": z.width}
|
||||
for z in zones
|
||||
],
|
||||
}
|
||||
|
||||
overlay = img.copy()
|
||||
|
||||
# --- Draw zone boundaries ---
|
||||
zone_colors = {
|
||||
"content": (200, 200, 200), # light gray
|
||||
"box": (255, 180, 0), # blue-ish (BGR)
|
||||
}
|
||||
for zone in structure.get("zones", []):
|
||||
zx = zone["x"]
|
||||
zy = zone["y"]
|
||||
zw = zone["w"]
|
||||
zh = zone["h"]
|
||||
color = zone_colors.get(zone["zone_type"], (200, 200, 200))
|
||||
|
||||
# Draw zone boundary as dashed line
|
||||
dash_len = 12
|
||||
for edge_x in range(zx, zx + zw, dash_len * 2):
|
||||
end_x = min(edge_x + dash_len, zx + zw)
|
||||
cv2.line(img, (edge_x, zy), (end_x, zy), color, 1)
|
||||
cv2.line(img, (edge_x, zy + zh), (end_x, zy + zh), color, 1)
|
||||
|
||||
# Zone label
|
||||
zone_label = f"Zone {zone['index']} ({zone['zone_type']})"
|
||||
cv2.putText(img, zone_label, (zx + 5, zy + 15),
|
||||
cv2.FONT_HERSHEY_SIMPLEX, 0.45, color, 1)
|
||||
|
||||
# --- Draw detected boxes ---
|
||||
# Color map for box backgrounds (BGR)
|
||||
bg_hex_to_bgr = {
|
||||
"#dc2626": (38, 38, 220), # red
|
||||
"#2563eb": (235, 99, 37), # blue
|
||||
"#16a34a": (74, 163, 22), # green
|
||||
"#ea580c": (12, 88, 234), # orange
|
||||
"#9333ea": (234, 51, 147), # purple
|
||||
"#ca8a04": (4, 138, 202), # yellow
|
||||
"#6b7280": (128, 114, 107), # gray
|
||||
}
|
||||
|
||||
for box_data in structure.get("boxes", []):
|
||||
bx = box_data["x"]
|
||||
by = box_data["y"]
|
||||
bw = box_data["w"]
|
||||
bh = box_data["h"]
|
||||
conf = box_data.get("confidence", 0)
|
||||
thickness = box_data.get("border_thickness", 0)
|
||||
bg_hex = box_data.get("bg_color_hex", "#6b7280")
|
||||
bg_name = box_data.get("bg_color_name", "")
|
||||
|
||||
# Box fill color
|
||||
fill_bgr = bg_hex_to_bgr.get(bg_hex, (128, 114, 107))
|
||||
|
||||
# Semi-transparent fill
|
||||
cv2.rectangle(overlay, (bx, by), (bx + bw, by + bh), fill_bgr, -1)
|
||||
|
||||
# Solid border
|
||||
border_color = fill_bgr
|
||||
cv2.rectangle(img, (bx, by), (bx + bw, by + bh), border_color, 3)
|
||||
|
||||
# Label
|
||||
label = f"BOX"
|
||||
if bg_name and bg_name not in ("unknown", "white"):
|
||||
label += f" ({bg_name})"
|
||||
if thickness > 0:
|
||||
label += f" border={thickness}px"
|
||||
label += f" {int(conf * 100)}%"
|
||||
cv2.putText(img, label, (bx + 8, by + 22),
|
||||
cv2.FONT_HERSHEY_SIMPLEX, 0.55, (255, 255, 255), 2)
|
||||
cv2.putText(img, label, (bx + 8, by + 22),
|
||||
cv2.FONT_HERSHEY_SIMPLEX, 0.55, border_color, 1)
|
||||
|
||||
# Blend overlay at 15% opacity
|
||||
cv2.addWeighted(overlay, 0.15, img, 0.85, 0, img)
|
||||
|
||||
# --- Draw color regions (HSV masks) ---
|
||||
hsv = cv2.cvtColor(
|
||||
cv2.imdecode(np.frombuffer(base_png, dtype=np.uint8), cv2.IMREAD_COLOR),
|
||||
cv2.COLOR_BGR2HSV,
|
||||
)
|
||||
color_bgr_map = {
|
||||
"red": (0, 0, 255),
|
||||
"orange": (0, 140, 255),
|
||||
"yellow": (0, 200, 255),
|
||||
"green": (0, 200, 0),
|
||||
"blue": (255, 150, 0),
|
||||
"purple": (200, 0, 200),
|
||||
}
|
||||
for color_name, ranges in _COLOR_RANGES.items():
|
||||
mask = np.zeros((h, w), dtype=np.uint8)
|
||||
for lower, upper in ranges:
|
||||
mask = cv2.bitwise_or(mask, cv2.inRange(hsv, lower, upper))
|
||||
# Only draw if there are significant colored pixels
|
||||
if np.sum(mask > 0) < 100:
|
||||
continue
|
||||
# Draw colored contours
|
||||
contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
|
||||
draw_color = color_bgr_map.get(color_name, (200, 200, 200))
|
||||
for cnt in contours:
|
||||
area = cv2.contourArea(cnt)
|
||||
if area < 20:
|
||||
continue
|
||||
cv2.drawContours(img, [cnt], -1, draw_color, 2)
|
||||
|
||||
# --- Draw graphic elements ---
|
||||
graphics_data = structure.get("graphics", [])
|
||||
shape_icons = {
|
||||
"image": "IMAGE",
|
||||
"illustration": "ILLUST",
|
||||
}
|
||||
for gfx in graphics_data:
|
||||
gx, gy = gfx["x"], gfx["y"]
|
||||
gw, gh = gfx["w"], gfx["h"]
|
||||
shape = gfx.get("shape", "icon")
|
||||
color_hex = gfx.get("color_hex", "#6b7280")
|
||||
conf = gfx.get("confidence", 0)
|
||||
|
||||
# Pick draw color based on element color (BGR)
|
||||
gfx_bgr = bg_hex_to_bgr.get(color_hex, (128, 114, 107))
|
||||
|
||||
# Draw bounding box (dashed style via short segments)
|
||||
dash = 6
|
||||
for seg_x in range(gx, gx + gw, dash * 2):
|
||||
end_x = min(seg_x + dash, gx + gw)
|
||||
cv2.line(img, (seg_x, gy), (end_x, gy), gfx_bgr, 2)
|
||||
cv2.line(img, (seg_x, gy + gh), (end_x, gy + gh), gfx_bgr, 2)
|
||||
for seg_y in range(gy, gy + gh, dash * 2):
|
||||
end_y = min(seg_y + dash, gy + gh)
|
||||
cv2.line(img, (gx, seg_y), (gx, end_y), gfx_bgr, 2)
|
||||
cv2.line(img, (gx + gw, seg_y), (gx + gw, end_y), gfx_bgr, 2)
|
||||
|
||||
# Label
|
||||
icon = shape_icons.get(shape, shape.upper()[:5])
|
||||
label = f"{icon} {int(conf * 100)}%"
|
||||
# White background for readability
|
||||
(tw, th), _ = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, 0.4, 1)
|
||||
lx = gx + 2
|
||||
ly = max(gy - 4, th + 4)
|
||||
cv2.rectangle(img, (lx - 1, ly - th - 2), (lx + tw + 2, ly + 3), (255, 255, 255), -1)
|
||||
cv2.putText(img, label, (lx, ly), cv2.FONT_HERSHEY_SIMPLEX, 0.4, gfx_bgr, 1)
|
||||
|
||||
# Encode result
|
||||
_, png_buf = cv2.imencode(".png", img)
|
||||
return Response(content=png_buf.tobytes(), media_type="image/png")
|
||||
|
||||
|
||||
|
||||
async def _get_columns_overlay(session_id: str) -> Response:
|
||||
"""Generate cropped (or dewarped) image with column borders drawn on it."""
|
||||
session = await get_session_db(session_id)
|
||||
if not session:
|
||||
raise HTTPException(status_code=404, detail=f"Session {session_id} not found")
|
||||
|
||||
column_result = session.get("column_result")
|
||||
if not column_result or not column_result.get("columns"):
|
||||
raise HTTPException(status_code=404, detail="No column data available")
|
||||
|
||||
# Load best available base image (cropped > dewarped > original)
|
||||
base_png = await _get_base_image_png(session_id)
|
||||
if not base_png:
|
||||
raise HTTPException(status_code=404, detail="No base image available")
|
||||
|
||||
arr = np.frombuffer(base_png, dtype=np.uint8)
|
||||
img = cv2.imdecode(arr, cv2.IMREAD_COLOR)
|
||||
if img is None:
|
||||
raise HTTPException(status_code=500, detail="Failed to decode image")
|
||||
|
||||
# Color map for region types (BGR)
|
||||
colors = {
|
||||
"column_en": (255, 180, 0), # Blue
|
||||
"column_de": (0, 200, 0), # Green
|
||||
"column_example": (0, 140, 255), # Orange
|
||||
"column_text": (200, 200, 0), # Cyan/Turquoise
|
||||
"page_ref": (200, 0, 200), # Purple
|
||||
"column_marker": (0, 0, 220), # Red
|
||||
"column_ignore": (180, 180, 180), # Light Gray
|
||||
"header": (128, 128, 128), # Gray
|
||||
"footer": (128, 128, 128), # Gray
|
||||
"margin_top": (100, 100, 100), # Dark Gray
|
||||
"margin_bottom": (100, 100, 100), # Dark Gray
|
||||
}
|
||||
|
||||
overlay = img.copy()
|
||||
for col in column_result["columns"]:
|
||||
x, y = col["x"], col["y"]
|
||||
w, h = col["width"], col["height"]
|
||||
color = colors.get(col.get("type", ""), (200, 200, 200))
|
||||
|
||||
# Semi-transparent fill
|
||||
cv2.rectangle(overlay, (x, y), (x + w, y + h), color, -1)
|
||||
|
||||
# Solid border
|
||||
cv2.rectangle(img, (x, y), (x + w, y + h), color, 3)
|
||||
|
||||
# Label with confidence
|
||||
label = col.get("type", "unknown").replace("column_", "").upper()
|
||||
conf = col.get("classification_confidence")
|
||||
if conf is not None and conf < 1.0:
|
||||
label = f"{label} {int(conf * 100)}%"
|
||||
cv2.putText(img, label, (x + 10, y + 30),
|
||||
cv2.FONT_HERSHEY_SIMPLEX, 0.8, color, 2)
|
||||
|
||||
# Blend overlay at 20% opacity
|
||||
cv2.addWeighted(overlay, 0.2, img, 0.8, 0, img)
|
||||
|
||||
# Draw detected box boundaries as dashed rectangles
|
||||
zones = column_result.get("zones") or []
|
||||
for zone in zones:
|
||||
if zone.get("zone_type") == "box" and zone.get("box"):
|
||||
box = zone["box"]
|
||||
bx, by = box["x"], box["y"]
|
||||
bw, bh = box["width"], box["height"]
|
||||
box_color = (0, 200, 255) # Yellow (BGR)
|
||||
# Draw dashed rectangle by drawing short line segments
|
||||
dash_len = 15
|
||||
for edge_x in range(bx, bx + bw, dash_len * 2):
|
||||
end_x = min(edge_x + dash_len, bx + bw)
|
||||
cv2.line(img, (edge_x, by), (end_x, by), box_color, 2)
|
||||
cv2.line(img, (edge_x, by + bh), (end_x, by + bh), box_color, 2)
|
||||
for edge_y in range(by, by + bh, dash_len * 2):
|
||||
end_y = min(edge_y + dash_len, by + bh)
|
||||
cv2.line(img, (bx, edge_y), (bx, end_y), box_color, 2)
|
||||
cv2.line(img, (bx + bw, edge_y), (bx + bw, end_y), box_color, 2)
|
||||
cv2.putText(img, "BOX", (bx + 10, by + bh - 10),
|
||||
cv2.FONT_HERSHEY_SIMPLEX, 0.7, box_color, 2)
|
||||
|
||||
# Red semi-transparent overlay for box zones
|
||||
_draw_box_exclusion_overlay(img, zones)
|
||||
|
||||
success, result_png = cv2.imencode(".png", img)
|
||||
if not success:
|
||||
raise HTTPException(status_code=500, detail="Failed to encode overlay image")
|
||||
|
||||
return Response(content=result_png.tobytes(), media_type="image/png")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Row Detection Endpoints
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
|
||||
async def _get_rows_overlay(session_id: str) -> Response:
|
||||
"""Generate cropped (or dewarped) image with row bands drawn on it."""
|
||||
session = await get_session_db(session_id)
|
||||
if not session:
|
||||
raise HTTPException(status_code=404, detail=f"Session {session_id} not found")
|
||||
|
||||
row_result = session.get("row_result")
|
||||
if not row_result or not row_result.get("rows"):
|
||||
raise HTTPException(status_code=404, detail="No row data available")
|
||||
|
||||
# Load best available base image (cropped > dewarped > original)
|
||||
base_png = await _get_base_image_png(session_id)
|
||||
if not base_png:
|
||||
raise HTTPException(status_code=404, detail="No base image available")
|
||||
|
||||
arr = np.frombuffer(base_png, dtype=np.uint8)
|
||||
img = cv2.imdecode(arr, cv2.IMREAD_COLOR)
|
||||
if img is None:
|
||||
raise HTTPException(status_code=500, detail="Failed to decode image")
|
||||
|
||||
# Color map for row types (BGR)
|
||||
row_colors = {
|
||||
"content": (255, 180, 0), # Blue
|
||||
"header": (128, 128, 128), # Gray
|
||||
"footer": (128, 128, 128), # Gray
|
||||
"margin_top": (100, 100, 100), # Dark Gray
|
||||
"margin_bottom": (100, 100, 100), # Dark Gray
|
||||
}
|
||||
|
||||
overlay = img.copy()
|
||||
for row in row_result["rows"]:
|
||||
x, y = row["x"], row["y"]
|
||||
w, h = row["width"], row["height"]
|
||||
row_type = row.get("row_type", "content")
|
||||
color = row_colors.get(row_type, (200, 200, 200))
|
||||
|
||||
# Semi-transparent fill
|
||||
cv2.rectangle(overlay, (x, y), (x + w, y + h), color, -1)
|
||||
|
||||
# Solid border
|
||||
cv2.rectangle(img, (x, y), (x + w, y + h), color, 2)
|
||||
|
||||
# Label
|
||||
idx = row.get("index", 0)
|
||||
label = f"R{idx} {row_type.upper()}"
|
||||
wc = row.get("word_count", 0)
|
||||
if wc:
|
||||
label = f"{label} ({wc}w)"
|
||||
cv2.putText(img, label, (x + 5, y + 18),
|
||||
cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 1)
|
||||
|
||||
# Blend overlay at 15% opacity
|
||||
cv2.addWeighted(overlay, 0.15, img, 0.85, 0, img)
|
||||
|
||||
# Draw zone separator lines if zones exist
|
||||
column_result = session.get("column_result") or {}
|
||||
zones = column_result.get("zones") or []
|
||||
if zones:
|
||||
img_w_px = img.shape[1]
|
||||
zone_color = (0, 200, 255) # Yellow (BGR)
|
||||
dash_len = 20
|
||||
for zone in zones:
|
||||
if zone.get("zone_type") == "box":
|
||||
zy = zone["y"]
|
||||
zh = zone["height"]
|
||||
for line_y in [zy, zy + zh]:
|
||||
for sx in range(0, img_w_px, dash_len * 2):
|
||||
ex = min(sx + dash_len, img_w_px)
|
||||
cv2.line(img, (sx, line_y), (ex, line_y), zone_color, 2)
|
||||
|
||||
# Red semi-transparent overlay for box zones
|
||||
_draw_box_exclusion_overlay(img, zones)
|
||||
|
||||
success, result_png = cv2.imencode(".png", img)
|
||||
if not success:
|
||||
raise HTTPException(status_code=500, detail="Failed to encode overlay image")
|
||||
|
||||
return Response(content=result_png.tobytes(), media_type="image/png")
|
||||
|
||||
|
||||
|
||||
async def _get_words_overlay(session_id: str) -> Response:
|
||||
"""Generate cropped (or dewarped) image with cell grid drawn on it."""
|
||||
session = await get_session_db(session_id)
|
||||
if not session:
|
||||
raise HTTPException(status_code=404, detail=f"Session {session_id} not found")
|
||||
|
||||
word_result = session.get("word_result")
|
||||
if not word_result:
|
||||
raise HTTPException(status_code=404, detail="No word data available")
|
||||
|
||||
# Support both new cell-based and legacy entry-based formats
|
||||
cells = word_result.get("cells")
|
||||
if not cells and not word_result.get("entries"):
|
||||
raise HTTPException(status_code=404, detail="No word data available")
|
||||
|
||||
# Load best available base image (cropped > dewarped > original)
|
||||
base_png = await _get_base_image_png(session_id)
|
||||
if not base_png:
|
||||
raise HTTPException(status_code=404, detail="No base image available")
|
||||
|
||||
arr = np.frombuffer(base_png, dtype=np.uint8)
|
||||
img = cv2.imdecode(arr, cv2.IMREAD_COLOR)
|
||||
if img is None:
|
||||
raise HTTPException(status_code=500, detail="Failed to decode image")
|
||||
|
||||
img_h, img_w = img.shape[:2]
|
||||
|
||||
overlay = img.copy()
|
||||
|
||||
if cells:
|
||||
# New cell-based overlay: color by column index
|
||||
col_palette = [
|
||||
(255, 180, 0), # Blue (BGR)
|
||||
(0, 200, 0), # Green
|
||||
(0, 140, 255), # Orange
|
||||
(200, 100, 200), # Purple
|
||||
(200, 200, 0), # Cyan
|
||||
(100, 200, 200), # Yellow-ish
|
||||
]
|
||||
|
||||
for cell in cells:
|
||||
bbox = cell.get("bbox_px", {})
|
||||
cx = bbox.get("x", 0)
|
||||
cy = bbox.get("y", 0)
|
||||
cw = bbox.get("w", 0)
|
||||
ch = bbox.get("h", 0)
|
||||
if cw <= 0 or ch <= 0:
|
||||
continue
|
||||
|
||||
col_idx = cell.get("col_index", 0)
|
||||
color = col_palette[col_idx % len(col_palette)]
|
||||
|
||||
# Cell rectangle border
|
||||
cv2.rectangle(img, (cx, cy), (cx + cw, cy + ch), color, 1)
|
||||
# Semi-transparent fill
|
||||
cv2.rectangle(overlay, (cx, cy), (cx + cw, cy + ch), color, -1)
|
||||
|
||||
# Cell-ID label (top-left corner)
|
||||
cell_id = cell.get("cell_id", "")
|
||||
cv2.putText(img, cell_id, (cx + 2, cy + 10),
|
||||
cv2.FONT_HERSHEY_SIMPLEX, 0.28, color, 1)
|
||||
|
||||
# Text label (bottom of cell)
|
||||
text = cell.get("text", "")
|
||||
if text:
|
||||
conf = cell.get("confidence", 0)
|
||||
if conf >= 70:
|
||||
text_color = (0, 180, 0)
|
||||
elif conf >= 50:
|
||||
text_color = (0, 180, 220)
|
||||
else:
|
||||
text_color = (0, 0, 220)
|
||||
|
||||
label = text.replace('\n', ' ')[:30]
|
||||
cv2.putText(img, label, (cx + 3, cy + ch - 4),
|
||||
cv2.FONT_HERSHEY_SIMPLEX, 0.35, text_color, 1)
|
||||
else:
|
||||
# Legacy fallback: entry-based overlay (for old sessions)
|
||||
column_result = session.get("column_result")
|
||||
row_result = session.get("row_result")
|
||||
col_colors = {
|
||||
"column_en": (255, 180, 0),
|
||||
"column_de": (0, 200, 0),
|
||||
"column_example": (0, 140, 255),
|
||||
}
|
||||
|
||||
columns = []
|
||||
if column_result and column_result.get("columns"):
|
||||
columns = [c for c in column_result["columns"]
|
||||
if c.get("type", "").startswith("column_")]
|
||||
|
||||
content_rows_data = []
|
||||
if row_result and row_result.get("rows"):
|
||||
content_rows_data = [r for r in row_result["rows"]
|
||||
if r.get("row_type") == "content"]
|
||||
|
||||
for col in columns:
|
||||
col_type = col.get("type", "")
|
||||
color = col_colors.get(col_type, (200, 200, 200))
|
||||
cx, cw = col["x"], col["width"]
|
||||
for row in content_rows_data:
|
||||
ry, rh = row["y"], row["height"]
|
||||
cv2.rectangle(img, (cx, ry), (cx + cw, ry + rh), color, 1)
|
||||
cv2.rectangle(overlay, (cx, ry), (cx + cw, ry + rh), color, -1)
|
||||
|
||||
entries = word_result["entries"]
|
||||
entry_by_row: Dict[int, Dict] = {}
|
||||
for entry in entries:
|
||||
entry_by_row[entry.get("row_index", -1)] = entry
|
||||
|
||||
for row_idx, row in enumerate(content_rows_data):
|
||||
entry = entry_by_row.get(row_idx)
|
||||
if not entry:
|
||||
continue
|
||||
conf = entry.get("confidence", 0)
|
||||
text_color = (0, 180, 0) if conf >= 70 else (0, 180, 220) if conf >= 50 else (0, 0, 220)
|
||||
ry, rh = row["y"], row["height"]
|
||||
for col in columns:
|
||||
col_type = col.get("type", "")
|
||||
cx, cw = col["x"], col["width"]
|
||||
field = {"column_en": "english", "column_de": "german", "column_example": "example"}.get(col_type, "")
|
||||
text = entry.get(field, "") if field else ""
|
||||
if text:
|
||||
label = text.replace('\n', ' ')[:30]
|
||||
cv2.putText(img, label, (cx + 3, ry + rh - 4),
|
||||
cv2.FONT_HERSHEY_SIMPLEX, 0.35, text_color, 1)
|
||||
|
||||
# Blend overlay at 10% opacity
|
||||
cv2.addWeighted(overlay, 0.1, img, 0.9, 0, img)
|
||||
|
||||
# Red semi-transparent overlay for box zones
|
||||
column_result = session.get("column_result") or {}
|
||||
zones = column_result.get("zones") or []
|
||||
_draw_box_exclusion_overlay(img, zones)
|
||||
|
||||
success, result_png = cv2.imencode(".png", img)
|
||||
if not success:
|
||||
raise HTTPException(status_code=500, detail="Failed to encode overlay image")
|
||||
|
||||
return Response(content=result_png.tobytes(), media_type="image/png")
|
||||
|
||||
|
||||
@@ -1,607 +1,22 @@
|
||||
"""
|
||||
OCR Pipeline Regression Tests — Ground Truth comparison system.
|
||||
OCR Pipeline Regression Tests — barrel re-export.
|
||||
|
||||
Allows marking sessions as "ground truth" and re-running build_grid()
|
||||
to detect regressions after code changes.
|
||||
All implementation split into:
|
||||
ocr_pipeline_regression_helpers — DB persistence, snapshot, comparison
|
||||
ocr_pipeline_regression_endpoints — FastAPI routes
|
||||
|
||||
Lizenz: Apache 2.0
|
||||
DATENSCHUTZ: Alle Verarbeitung erfolgt lokal.
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import time
|
||||
import uuid
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from fastapi import APIRouter, HTTPException, Query
|
||||
|
||||
from grid_editor_api import _build_grid_core
|
||||
from ocr_pipeline_session_store import (
|
||||
get_pool,
|
||||
get_session_db,
|
||||
list_ground_truth_sessions_db,
|
||||
update_session_db,
)
|
||||
# Helpers (used by grid_editor_api_grid.py)
|
||||
from ocr_pipeline_regression_helpers import ( # noqa: F401
|
||||
_init_regression_table,
|
||||
_persist_regression_run,
|
||||
_extract_cells_for_comparison,
|
||||
_build_reference_snapshot,
|
||||
compare_grids,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter(prefix="/api/v1/ocr-pipeline", tags=["regression"])
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# DB persistence for regression runs
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
async def _init_regression_table():
|
||||
"""Ensure regression_runs table exists (idempotent)."""
|
||||
pool = await get_pool()
|
||||
async with pool.acquire() as conn:
|
||||
migration_path = os.path.join(
|
||||
os.path.dirname(__file__),
|
||||
"migrations/008_regression_runs.sql",
|
||||
)
|
||||
if os.path.exists(migration_path):
|
||||
with open(migration_path, "r") as f:
|
||||
sql = f.read()
|
||||
await conn.execute(sql)
|
||||
|
||||
|
||||
async def _persist_regression_run(
|
||||
status: str,
|
||||
summary: dict,
|
||||
results: list,
|
||||
duration_ms: int,
|
||||
triggered_by: str = "manual",
|
||||
) -> str:
|
||||
"""Save a regression run to the database. Returns the run ID."""
|
||||
try:
|
||||
await _init_regression_table()
|
||||
pool = await get_pool()
|
||||
run_id = str(uuid.uuid4())
|
||||
async with pool.acquire() as conn:
|
||||
await conn.execute(
|
||||
"""
|
||||
INSERT INTO regression_runs
|
||||
(id, status, total, passed, failed, errors, duration_ms, results, triggered_by)
|
||||
VALUES ($1, $2, $3, $4, $5, $6, $7, $8::jsonb, $9)
|
||||
""",
|
||||
run_id,
|
||||
status,
|
||||
summary.get("total", 0),
|
||||
summary.get("passed", 0),
|
||||
summary.get("failed", 0),
|
||||
summary.get("errors", 0),
|
||||
duration_ms,
|
||||
json.dumps(results),
|
||||
triggered_by,
|
||||
)
|
||||
logger.info("Regression run %s persisted: %s", run_id, status)
|
||||
return run_id
|
||||
except Exception as e:
|
||||
logger.warning("Failed to persist regression run: %s", e)
|
||||
return ""
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _extract_cells_for_comparison(grid_result: dict) -> List[Dict[str, Any]]:
|
||||
"""Extract a flat list of cells from a grid_editor_result for comparison.
|
||||
|
||||
Only keeps fields relevant for comparison: cell_id, row_index, col_index,
|
||||
col_type, text. Ignores confidence, bbox, word_boxes, duration, is_bold.
|
||||
"""
|
||||
cells = []
|
||||
for zone in grid_result.get("zones", []):
|
||||
for cell in zone.get("cells", []):
|
||||
cells.append({
|
||||
"cell_id": cell.get("cell_id", ""),
|
||||
"row_index": cell.get("row_index"),
|
||||
"col_index": cell.get("col_index"),
|
||||
"col_type": cell.get("col_type", ""),
|
||||
"text": cell.get("text", ""),
|
||||
})
|
||||
return cells
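# Illustrative shape of one extracted comparison cell (field names from the code
# above, values invented for clarity):
#   {"cell_id": "z0_r2_c1", "row_index": 2, "col_index": 1,
#    "col_type": "german", "text": "verkünden"}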
|
||||
|
||||
|
||||
def _build_reference_snapshot(
|
||||
grid_result: dict,
|
||||
pipeline: Optional[str] = None,
|
||||
) -> dict:
|
||||
"""Build a ground-truth reference snapshot from a grid_editor_result."""
|
||||
cells = _extract_cells_for_comparison(grid_result)
|
||||
|
||||
total_zones = len(grid_result.get("zones", []))
|
||||
total_columns = sum(len(z.get("columns", [])) for z in grid_result.get("zones", []))
|
||||
total_rows = sum(len(z.get("rows", [])) for z in grid_result.get("zones", []))
|
||||
|
||||
snapshot = {
|
||||
"saved_at": datetime.now(timezone.utc).isoformat(),
|
||||
"version": 1,
|
||||
"pipeline": pipeline,
|
||||
"summary": {
|
||||
"total_zones": total_zones,
|
||||
"total_columns": total_columns,
|
||||
"total_rows": total_rows,
|
||||
"total_cells": len(cells),
|
||||
},
|
||||
"cells": cells,
|
||||
}
|
||||
return snapshot
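# Illustrative snapshot shape (field names from the code above, values invented):
#   {"saved_at": "2025-01-01T12:00:00+00:00", "version": 1, "pipeline": "kombi",
#    "summary": {"total_zones": 1, "total_columns": 3,
#                "total_rows": 20, "total_cells": 60},
#    "cells": [...]}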
|
||||
|
||||
|
||||
def compare_grids(reference: dict, current: dict) -> dict:
|
||||
"""Compare a reference grid snapshot with a newly computed one.
|
||||
|
||||
Returns a diff report with:
|
||||
- status: "pass" or "fail"
|
||||
- structural_diffs: changes in zone/row/column counts
|
||||
- cell_diffs: list of individual cell changes
|
||||
"""
|
||||
ref_summary = reference.get("summary", {})
|
||||
cur_summary = current.get("summary", {})
|
||||
|
||||
structural_diffs = []
|
||||
for key in ("total_zones", "total_columns", "total_rows", "total_cells"):
|
||||
ref_val = ref_summary.get(key, 0)
|
||||
cur_val = cur_summary.get(key, 0)
|
||||
if ref_val != cur_val:
|
||||
structural_diffs.append({
|
||||
"field": key,
|
||||
"reference": ref_val,
|
||||
"current": cur_val,
|
||||
})
|
||||
|
||||
# Build cell lookup by cell_id
|
||||
ref_cells = {c["cell_id"]: c for c in reference.get("cells", [])}
|
||||
cur_cells = {c["cell_id"]: c for c in current.get("cells", [])}
|
||||
|
||||
cell_diffs: List[Dict[str, Any]] = []
|
||||
|
||||
# Check for missing cells (in reference but not in current)
|
||||
for cell_id in ref_cells:
|
||||
if cell_id not in cur_cells:
|
||||
cell_diffs.append({
|
||||
"type": "cell_missing",
|
||||
"cell_id": cell_id,
|
||||
"reference_text": ref_cells[cell_id].get("text", ""),
|
||||
})
|
||||
|
||||
# Check for added cells (in current but not in reference)
|
||||
for cell_id in cur_cells:
|
||||
if cell_id not in ref_cells:
|
||||
cell_diffs.append({
|
||||
"type": "cell_added",
|
||||
"cell_id": cell_id,
|
||||
"current_text": cur_cells[cell_id].get("text", ""),
|
||||
})
|
||||
|
||||
# Check for changes in shared cells
|
||||
for cell_id in ref_cells:
|
||||
if cell_id not in cur_cells:
|
||||
continue
|
||||
ref_cell = ref_cells[cell_id]
|
||||
cur_cell = cur_cells[cell_id]
|
||||
|
||||
if ref_cell.get("text", "") != cur_cell.get("text", ""):
|
||||
cell_diffs.append({
|
||||
"type": "text_change",
|
||||
"cell_id": cell_id,
|
||||
"reference": ref_cell.get("text", ""),
|
||||
"current": cur_cell.get("text", ""),
|
||||
})
|
||||
|
||||
if ref_cell.get("col_type", "") != cur_cell.get("col_type", ""):
|
||||
cell_diffs.append({
|
||||
"type": "col_type_change",
|
||||
"cell_id": cell_id,
|
||||
"reference": ref_cell.get("col_type", ""),
|
||||
"current": cur_cell.get("col_type", ""),
|
||||
})
|
||||
|
||||
status = "pass" if not structural_diffs and not cell_diffs else "fail"
|
||||
|
||||
return {
|
||||
"status": status,
|
||||
"structural_diffs": structural_diffs,
|
||||
"cell_diffs": cell_diffs,
|
||||
"summary": {
|
||||
"structural_changes": len(structural_diffs),
|
||||
"cells_missing": sum(1 for d in cell_diffs if d["type"] == "cell_missing"),
|
||||
"cells_added": sum(1 for d in cell_diffs if d["type"] == "cell_added"),
|
||||
"text_changes": sum(1 for d in cell_diffs if d["type"] == "text_change"),
|
||||
"col_type_changes": sum(1 for d in cell_diffs if d["type"] == "col_type_change"),
|
||||
},
|
||||
}
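# Illustrative compare_grids() report for a single corrected cell (cell_id and
# texts invented for clarity; keys match the structure returned above):
#   {"status": "fail",
#    "structural_diffs": [],
#    "cell_diffs": [{"type": "text_change", "cell_id": "z0_r3_c1",
#                    "reference": "verkünden", "current": "ve künden"}],
#    "summary": {"structural_changes": 0, "cells_missing": 0, "cells_added": 0,
#                "text_changes": 1, "col_type_changes": 0}}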
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Endpoints
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@router.post("/sessions/{session_id}/mark-ground-truth")
|
||||
async def mark_ground_truth(
|
||||
session_id: str,
|
||||
pipeline: Optional[str] = Query(None, description="Pipeline used: kombi, pipeline, paddle-direct"),
|
||||
):
|
||||
"""Save the current build-grid result as ground-truth reference."""
|
||||
session = await get_session_db(session_id)
|
||||
if not session:
|
||||
raise HTTPException(status_code=404, detail=f"Session {session_id} not found")
|
||||
|
||||
grid_result = session.get("grid_editor_result")
|
||||
if not grid_result or not grid_result.get("zones"):
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="No grid_editor_result found. Run build-grid first.",
|
||||
)
|
||||
|
||||
# Auto-detect pipeline from word_result if not provided
|
||||
if not pipeline:
|
||||
wr = session.get("word_result") or {}
|
||||
engine = wr.get("ocr_engine", "")
|
||||
if engine in ("kombi", "rapid_kombi"):
|
||||
pipeline = "kombi"
|
||||
elif engine == "paddle_direct":
|
||||
pipeline = "paddle-direct"
|
||||
else:
|
||||
pipeline = "pipeline"
|
||||
|
||||
reference = _build_reference_snapshot(grid_result, pipeline=pipeline)
|
||||
|
||||
# Merge into existing ground_truth JSONB
|
||||
gt = session.get("ground_truth") or {}
|
||||
gt["build_grid_reference"] = reference
|
||||
await update_session_db(session_id, ground_truth=gt, current_step=11)
|
||||
|
||||
# Compare with auto-snapshot if available (shows what the user corrected)
|
||||
auto_snapshot = gt.get("auto_grid_snapshot")
|
||||
correction_diff = None
|
||||
if auto_snapshot:
|
||||
correction_diff = compare_grids(auto_snapshot, reference)
|
||||
|
||||
logger.info(
|
||||
"Ground truth marked for session %s: %d cells (corrections: %s)",
|
||||
session_id,
|
||||
len(reference["cells"]),
|
||||
correction_diff["summary"] if correction_diff else "no auto-snapshot",
|
||||
)
|
||||
|
||||
return {
|
||||
"status": "ok",
|
||||
"session_id": session_id,
|
||||
"cells_saved": len(reference["cells"]),
|
||||
"summary": reference["summary"],
|
||||
"correction_diff": correction_diff,
|
||||
}
|
||||
|
||||
|
||||
@router.delete("/sessions/{session_id}/mark-ground-truth")
|
||||
async def unmark_ground_truth(session_id: str):
|
||||
"""Remove the ground-truth reference from a session."""
|
||||
session = await get_session_db(session_id)
|
||||
if not session:
|
||||
raise HTTPException(status_code=404, detail=f"Session {session_id} not found")
|
||||
|
||||
gt = session.get("ground_truth") or {}
|
||||
if "build_grid_reference" not in gt:
|
||||
raise HTTPException(status_code=404, detail="No ground truth reference found")
|
||||
|
||||
del gt["build_grid_reference"]
|
||||
await update_session_db(session_id, ground_truth=gt)
|
||||
|
||||
logger.info("Ground truth removed for session %s", session_id)
|
||||
return {"status": "ok", "session_id": session_id}
|
||||
|
||||
|
||||
@router.get("/sessions/{session_id}/correction-diff")
|
||||
async def get_correction_diff(session_id: str):
|
||||
"""Compare automatic OCR grid with manually corrected ground truth.
|
||||
|
||||
Returns a diff showing exactly which cells the user corrected,
|
||||
broken down by col_type (english, german, ipa, etc.).
|
||||
"""
|
||||
session = await get_session_db(session_id)
|
||||
if not session:
|
||||
raise HTTPException(status_code=404, detail=f"Session {session_id} not found")
|
||||
|
||||
gt = session.get("ground_truth") or {}
|
||||
auto_snapshot = gt.get("auto_grid_snapshot")
|
||||
reference = gt.get("build_grid_reference")
|
||||
|
||||
if not auto_snapshot:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail="No auto_grid_snapshot found. Re-run build-grid to create one.",
|
||||
)
|
||||
if not reference:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail="No ground truth reference found. Mark as ground truth first.",
|
||||
)
|
||||
|
||||
diff = compare_grids(auto_snapshot, reference)
|
||||
|
||||
# Enrich with per-col_type breakdown
|
||||
col_type_stats: Dict[str, Dict[str, int]] = {}
|
||||
for cell_diff in diff.get("cell_diffs", []):
|
||||
if cell_diff["type"] != "text_change":
|
||||
continue
|
||||
# Find col_type from reference cells
|
||||
cell_id = cell_diff["cell_id"]
|
||||
ref_cell = next(
|
||||
(c for c in reference.get("cells", []) if c["cell_id"] == cell_id),
|
||||
None,
|
||||
)
|
||||
ct = ref_cell.get("col_type", "unknown") if ref_cell else "unknown"
|
||||
if ct not in col_type_stats:
|
||||
col_type_stats[ct] = {"total": 0, "corrected": 0}
|
||||
col_type_stats[ct]["corrected"] += 1
|
||||
|
||||
# Count total cells per col_type from reference
|
||||
for cell in reference.get("cells", []):
|
||||
ct = cell.get("col_type", "unknown")
|
||||
if ct not in col_type_stats:
|
||||
col_type_stats[ct] = {"total": 0, "corrected": 0}
|
||||
col_type_stats[ct]["total"] += 1
|
||||
|
||||
# Calculate accuracy per col_type
|
||||
for ct, stats in col_type_stats.items():
|
||||
total = stats["total"]
|
||||
corrected = stats["corrected"]
|
||||
stats["accuracy_pct"] = round((total - corrected) / total * 100, 1) if total > 0 else 100.0
|
||||
|
||||
diff["col_type_breakdown"] = col_type_stats
|
||||
|
||||
return diff
|
||||
|
||||
|
||||
@router.get("/ground-truth-sessions")
|
||||
async def list_ground_truth_sessions():
|
||||
"""List all sessions that have a ground-truth reference."""
|
||||
sessions = await list_ground_truth_sessions_db()
|
||||
|
||||
result = []
|
||||
for s in sessions:
|
||||
gt = s.get("ground_truth") or {}
|
||||
ref = gt.get("build_grid_reference", {})
|
||||
result.append({
|
||||
"session_id": s["id"],
|
||||
"name": s.get("name", ""),
|
||||
"filename": s.get("filename", ""),
|
||||
"document_category": s.get("document_category"),
|
||||
"pipeline": ref.get("pipeline"),
|
||||
"saved_at": ref.get("saved_at"),
|
||||
"summary": ref.get("summary", {}),
|
||||
})
|
||||
|
||||
return {"sessions": result, "count": len(result)}
|
||||
|
||||
|
||||
@router.post("/sessions/{session_id}/regression/run")
|
||||
async def run_single_regression(session_id: str):
|
||||
"""Re-run build_grid for a single session and compare to ground truth."""
|
||||
session = await get_session_db(session_id)
|
||||
if not session:
|
||||
raise HTTPException(status_code=404, detail=f"Session {session_id} not found")
|
||||
|
||||
gt = session.get("ground_truth") or {}
|
||||
reference = gt.get("build_grid_reference")
|
||||
if not reference:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="No ground truth reference found for this session",
|
||||
)
|
||||
|
||||
# Re-compute grid without persisting
|
||||
try:
|
||||
new_result = await _build_grid_core(session_id, session)
|
||||
except Exception as e:
|
||||
return {
|
||||
"session_id": session_id,
|
||||
"name": session.get("name", ""),
|
||||
"status": "error",
|
||||
"error": str(e),
|
||||
}
|
||||
|
||||
new_snapshot = _build_reference_snapshot(new_result)
|
||||
diff = compare_grids(reference, new_snapshot)
|
||||
|
||||
logger.info(
|
||||
"Regression test session %s: %s (%d structural, %d cell diffs)",
|
||||
session_id, diff["status"],
|
||||
diff["summary"]["structural_changes"],
|
||||
sum(v for k, v in diff["summary"].items() if k != "structural_changes"),
|
||||
)
|
||||
|
||||
return {
|
||||
"session_id": session_id,
|
||||
"name": session.get("name", ""),
|
||||
"status": diff["status"],
|
||||
"diff": diff,
|
||||
"reference_summary": reference.get("summary", {}),
|
||||
"current_summary": new_snapshot.get("summary", {}),
|
||||
}
|
||||
|
||||
|
||||
@router.post("/regression/run")
|
||||
async def run_all_regressions(
|
||||
triggered_by: str = Query("manual", description="Who triggered: manual, script, ci"),
|
||||
):
|
||||
"""Re-run build_grid for ALL ground-truth sessions and compare."""
|
||||
start_time = time.monotonic()
|
||||
sessions = await list_ground_truth_sessions_db()
|
||||
|
||||
if not sessions:
|
||||
return {
|
||||
"status": "pass",
|
||||
"message": "No ground truth sessions found",
|
||||
"results": [],
|
||||
"summary": {"total": 0, "passed": 0, "failed": 0, "errors": 0},
|
||||
}
|
||||
|
||||
results = []
|
||||
passed = 0
|
||||
failed = 0
|
||||
errors = 0
|
||||
|
||||
for s in sessions:
|
||||
session_id = s["id"]
|
||||
gt = s.get("ground_truth") or {}
|
||||
reference = gt.get("build_grid_reference")
|
||||
if not reference:
|
||||
continue
|
||||
|
||||
# Re-load full session (list query may not include all JSONB fields)
|
||||
full_session = await get_session_db(session_id)
|
||||
if not full_session:
|
||||
results.append({
|
||||
"session_id": session_id,
|
||||
"name": s.get("name", ""),
|
||||
"status": "error",
|
||||
"error": "Session not found during re-load",
|
||||
})
|
||||
errors += 1
|
||||
continue
|
||||
|
||||
try:
|
||||
new_result = await _build_grid_core(session_id, full_session)
|
||||
except Exception as e:
|
||||
results.append({
|
||||
"session_id": session_id,
|
||||
"name": s.get("name", ""),
|
||||
"status": "error",
|
||||
"error": str(e),
|
||||
})
|
||||
errors += 1
|
||||
continue
|
||||
|
||||
new_snapshot = _build_reference_snapshot(new_result)
|
||||
diff = compare_grids(reference, new_snapshot)
|
||||
|
||||
entry = {
|
||||
"session_id": session_id,
|
||||
"name": s.get("name", ""),
|
||||
"status": diff["status"],
|
||||
"diff_summary": diff["summary"],
|
||||
"reference_summary": reference.get("summary", {}),
|
||||
"current_summary": new_snapshot.get("summary", {}),
|
||||
}
|
||||
|
||||
# Include full diffs only for failures (keep response compact)
|
||||
if diff["status"] == "fail":
|
||||
entry["structural_diffs"] = diff["structural_diffs"]
|
||||
entry["cell_diffs"] = diff["cell_diffs"]
|
||||
failed += 1
|
||||
else:
|
||||
passed += 1
|
||||
|
||||
results.append(entry)
|
||||
|
||||
overall = "pass" if failed == 0 and errors == 0 else "fail"
|
||||
duration_ms = int((time.monotonic() - start_time) * 1000)
|
||||
|
||||
summary = {
|
||||
"total": len(results),
|
||||
"passed": passed,
|
||||
"failed": failed,
|
||||
"errors": errors,
|
||||
}
|
||||
|
||||
logger.info(
|
||||
"Regression suite: %s — %d passed, %d failed, %d errors (of %d) in %dms",
|
||||
overall, passed, failed, errors, len(results), duration_ms,
|
||||
)
|
||||
|
||||
# Persist to DB
|
||||
run_id = await _persist_regression_run(
|
||||
status=overall,
|
||||
summary=summary,
|
||||
results=results,
|
||||
duration_ms=duration_ms,
|
||||
triggered_by=triggered_by,
|
||||
)
|
||||
|
||||
return {
|
||||
"status": overall,
|
||||
"run_id": run_id,
|
||||
"duration_ms": duration_ms,
|
||||
"results": results,
|
||||
"summary": summary,
|
||||
}
|
||||
|
||||
|
||||
@router.get("/regression/history")
|
||||
async def get_regression_history(
|
||||
limit: int = Query(20, ge=1, le=100),
|
||||
):
|
||||
"""Get recent regression run history from the database."""
|
||||
try:
|
||||
await _init_regression_table()
|
||||
pool = await get_pool()
|
||||
async with pool.acquire() as conn:
|
||||
rows = await conn.fetch(
|
||||
"""
|
||||
SELECT id, run_at, status, total, passed, failed, errors,
|
||||
duration_ms, triggered_by
|
||||
FROM regression_runs
|
||||
ORDER BY run_at DESC
|
||||
LIMIT $1
|
||||
""",
|
||||
limit,
|
||||
)
|
||||
return {
|
||||
"runs": [
|
||||
{
|
||||
"id": str(row["id"]),
|
||||
"run_at": row["run_at"].isoformat() if row["run_at"] else None,
|
||||
"status": row["status"],
|
||||
"total": row["total"],
|
||||
"passed": row["passed"],
|
||||
"failed": row["failed"],
|
||||
"errors": row["errors"],
|
||||
"duration_ms": row["duration_ms"],
|
||||
"triggered_by": row["triggered_by"],
|
||||
}
|
||||
for row in rows
|
||||
],
|
||||
"count": len(rows),
|
||||
}
|
||||
except Exception as e:
|
||||
logger.warning("Failed to fetch regression history: %s", e)
|
||||
return {"runs": [], "count": 0, "error": str(e)}
|
||||
|
||||
|
||||
@router.get("/regression/history/{run_id}")
|
||||
async def get_regression_run_detail(run_id: str):
|
||||
"""Get detailed results of a specific regression run."""
|
||||
try:
|
||||
await _init_regression_table()
|
||||
pool = await get_pool()
|
||||
async with pool.acquire() as conn:
|
||||
row = await conn.fetchrow(
|
||||
"SELECT * FROM regression_runs WHERE id = $1",
|
||||
run_id,
|
||||
)
|
||||
if not row:
|
||||
raise HTTPException(status_code=404, detail="Run not found")
|
||||
return {
|
||||
"id": str(row["id"]),
|
||||
"run_at": row["run_at"].isoformat() if row["run_at"] else None,
|
||||
"status": row["status"],
|
||||
"total": row["total"],
|
||||
"passed": row["passed"],
|
||||
"failed": row["failed"],
|
||||
"errors": row["errors"],
|
||||
"duration_ms": row["duration_ms"],
|
||||
"triggered_by": row["triggered_by"],
|
||||
"results": json.loads(row["results"]) if row["results"] else [],
|
||||
}
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
# Endpoints (router used by ocr_pipeline_api.py)
|
||||
from ocr_pipeline_regression_endpoints import router # noqa: F401
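# Illustrative effect of this re-export (application wiring assumed, not shown in
# this diff): callers that previously did
#
#     from ocr_pipeline_regression import router
#
# still receive a working router; its route handlers are now defined in
# ocr_pipeline_regression_endpoints.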
|
||||
klausur-service/backend/ocr_pipeline_regression_endpoints.py (new file, 421 lines)
@@ -0,0 +1,421 @@
|
||||
"""
|
||||
OCR Pipeline Regression Endpoints — FastAPI routes for ground truth and regression.
|
||||
|
||||
Extracted from ocr_pipeline_regression.py for modularity.
|
||||
|
||||
Lizenz: Apache 2.0
|
||||
DATENSCHUTZ: Alle Verarbeitung erfolgt lokal.
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
import time
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from fastapi import APIRouter, HTTPException, Query
|
||||
|
||||
from grid_editor_api import _build_grid_core
|
||||
from ocr_pipeline_session_store import (
|
||||
get_session_db,
|
||||
list_ground_truth_sessions_db,
|
||||
update_session_db,
|
||||
)
|
||||
from ocr_pipeline_regression_helpers import (
|
||||
_build_reference_snapshot,
|
||||
_init_regression_table,
|
||||
_persist_regression_run,
|
||||
compare_grids,
|
||||
get_pool,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter(prefix="/api/v1/ocr-pipeline", tags=["regression"])
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Endpoints
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@router.post("/sessions/{session_id}/mark-ground-truth")
|
||||
async def mark_ground_truth(
|
||||
session_id: str,
|
||||
pipeline: Optional[str] = Query(None, description="Pipeline used: kombi, pipeline, paddle-direct"),
|
||||
):
|
||||
"""Save the current build-grid result as ground-truth reference."""
|
||||
session = await get_session_db(session_id)
|
||||
if not session:
|
||||
raise HTTPException(status_code=404, detail=f"Session {session_id} not found")
|
||||
|
||||
grid_result = session.get("grid_editor_result")
|
||||
if not grid_result or not grid_result.get("zones"):
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="No grid_editor_result found. Run build-grid first.",
|
||||
)
|
||||
|
||||
# Auto-detect pipeline from word_result if not provided
|
||||
if not pipeline:
|
||||
wr = session.get("word_result") or {}
|
||||
engine = wr.get("ocr_engine", "")
|
||||
if engine in ("kombi", "rapid_kombi"):
|
||||
pipeline = "kombi"
|
||||
elif engine == "paddle_direct":
|
||||
pipeline = "paddle-direct"
|
||||
else:
|
||||
pipeline = "pipeline"
|
||||
|
||||
reference = _build_reference_snapshot(grid_result, pipeline=pipeline)
|
||||
|
||||
# Merge into existing ground_truth JSONB
|
||||
gt = session.get("ground_truth") or {}
|
||||
gt["build_grid_reference"] = reference
|
||||
await update_session_db(session_id, ground_truth=gt, current_step=11)
|
||||
|
||||
# Compare with auto-snapshot if available (shows what the user corrected)
|
||||
auto_snapshot = gt.get("auto_grid_snapshot")
|
||||
correction_diff = None
|
||||
if auto_snapshot:
|
||||
correction_diff = compare_grids(auto_snapshot, reference)
|
||||
|
||||
logger.info(
|
||||
"Ground truth marked for session %s: %d cells (corrections: %s)",
|
||||
session_id,
|
||||
len(reference["cells"]),
|
||||
correction_diff["summary"] if correction_diff else "no auto-snapshot",
|
||||
)
|
||||
|
||||
return {
|
||||
"status": "ok",
|
||||
"session_id": session_id,
|
||||
"cells_saved": len(reference["cells"]),
|
||||
"summary": reference["summary"],
|
||||
"correction_diff": correction_diff,
|
||||
}
|
||||
|
||||
|
||||
@router.delete("/sessions/{session_id}/mark-ground-truth")
|
||||
async def unmark_ground_truth(session_id: str):
|
||||
"""Remove the ground-truth reference from a session."""
|
||||
session = await get_session_db(session_id)
|
||||
if not session:
|
||||
raise HTTPException(status_code=404, detail=f"Session {session_id} not found")
|
||||
|
||||
gt = session.get("ground_truth") or {}
|
||||
if "build_grid_reference" not in gt:
|
||||
raise HTTPException(status_code=404, detail="No ground truth reference found")
|
||||
|
||||
del gt["build_grid_reference"]
|
||||
await update_session_db(session_id, ground_truth=gt)
|
||||
|
||||
logger.info("Ground truth removed for session %s", session_id)
|
||||
return {"status": "ok", "session_id": session_id}
|
||||
|
||||
|
||||
@router.get("/sessions/{session_id}/correction-diff")
|
||||
async def get_correction_diff(session_id: str):
|
||||
"""Compare automatic OCR grid with manually corrected ground truth.
|
||||
|
||||
Returns a diff showing exactly which cells the user corrected,
|
||||
broken down by col_type (english, german, ipa, etc.).
|
||||
"""
|
||||
session = await get_session_db(session_id)
|
||||
if not session:
|
||||
raise HTTPException(status_code=404, detail=f"Session {session_id} not found")
|
||||
|
||||
gt = session.get("ground_truth") or {}
|
||||
auto_snapshot = gt.get("auto_grid_snapshot")
|
||||
reference = gt.get("build_grid_reference")
|
||||
|
||||
if not auto_snapshot:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail="No auto_grid_snapshot found. Re-run build-grid to create one.",
|
||||
)
|
||||
if not reference:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail="No ground truth reference found. Mark as ground truth first.",
|
||||
)
|
||||
|
||||
diff = compare_grids(auto_snapshot, reference)
|
||||
|
||||
# Enrich with per-col_type breakdown
|
||||
col_type_stats: Dict[str, Dict[str, int]] = {}
|
||||
for cell_diff in diff.get("cell_diffs", []):
|
||||
if cell_diff["type"] != "text_change":
|
||||
continue
|
||||
# Find col_type from reference cells
|
||||
cell_id = cell_diff["cell_id"]
|
||||
ref_cell = next(
|
||||
(c for c in reference.get("cells", []) if c["cell_id"] == cell_id),
|
||||
None,
|
||||
)
|
||||
ct = ref_cell.get("col_type", "unknown") if ref_cell else "unknown"
|
||||
if ct not in col_type_stats:
|
||||
col_type_stats[ct] = {"total": 0, "corrected": 0}
|
||||
col_type_stats[ct]["corrected"] += 1
|
||||
|
||||
# Count total cells per col_type from reference
|
||||
for cell in reference.get("cells", []):
|
||||
ct = cell.get("col_type", "unknown")
|
||||
if ct not in col_type_stats:
|
||||
col_type_stats[ct] = {"total": 0, "corrected": 0}
|
||||
col_type_stats[ct]["total"] += 1
|
||||
|
||||
# Calculate accuracy per col_type
|
||||
for ct, stats in col_type_stats.items():
|
||||
total = stats["total"]
|
||||
corrected = stats["corrected"]
|
||||
stats["accuracy_pct"] = round((total - corrected) / total * 100, 1) if total > 0 else 100.0
|
||||
|
||||
diff["col_type_breakdown"] = col_type_stats
|
||||
|
||||
return diff
|
||||
|
||||
|
||||
@router.get("/ground-truth-sessions")
|
||||
async def list_ground_truth_sessions():
|
||||
"""List all sessions that have a ground-truth reference."""
|
||||
sessions = await list_ground_truth_sessions_db()
|
||||
|
||||
result = []
|
||||
for s in sessions:
|
||||
gt = s.get("ground_truth") or {}
|
||||
ref = gt.get("build_grid_reference", {})
|
||||
result.append({
|
||||
"session_id": s["id"],
|
||||
"name": s.get("name", ""),
|
||||
"filename": s.get("filename", ""),
|
||||
"document_category": s.get("document_category"),
|
||||
"pipeline": ref.get("pipeline"),
|
||||
"saved_at": ref.get("saved_at"),
|
||||
"summary": ref.get("summary", {}),
|
||||
})
|
||||
|
||||
return {"sessions": result, "count": len(result)}
|
||||
|
||||
|
||||
@router.post("/sessions/{session_id}/regression/run")
|
||||
async def run_single_regression(session_id: str):
|
||||
"""Re-run build_grid for a single session and compare to ground truth."""
|
||||
session = await get_session_db(session_id)
|
||||
if not session:
|
||||
raise HTTPException(status_code=404, detail=f"Session {session_id} not found")
|
||||
|
||||
gt = session.get("ground_truth") or {}
|
||||
reference = gt.get("build_grid_reference")
|
||||
if not reference:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="No ground truth reference found for this session",
|
||||
)
|
||||
|
||||
# Re-compute grid without persisting
|
||||
try:
|
||||
new_result = await _build_grid_core(session_id, session)
|
||||
except Exception as e:
|
||||
return {
|
||||
"session_id": session_id,
|
||||
"name": session.get("name", ""),
|
||||
"status": "error",
|
||||
"error": str(e),
|
||||
}
|
||||
|
||||
new_snapshot = _build_reference_snapshot(new_result)
|
||||
diff = compare_grids(reference, new_snapshot)
|
||||
|
||||
logger.info(
|
||||
"Regression test session %s: %s (%d structural, %d cell diffs)",
|
||||
session_id, diff["status"],
|
||||
diff["summary"]["structural_changes"],
|
||||
sum(v for k, v in diff["summary"].items() if k != "structural_changes"),
|
||||
)
|
||||
|
||||
return {
|
||||
"session_id": session_id,
|
||||
"name": session.get("name", ""),
|
||||
"status": diff["status"],
|
||||
"diff": diff,
|
||||
"reference_summary": reference.get("summary", {}),
|
||||
"current_summary": new_snapshot.get("summary", {}),
|
||||
}
|
||||
|
||||
|
||||
@router.post("/regression/run")
|
||||
async def run_all_regressions(
|
||||
triggered_by: str = Query("manual", description="Who triggered: manual, script, ci"),
|
||||
):
|
||||
"""Re-run build_grid for ALL ground-truth sessions and compare."""
|
||||
start_time = time.monotonic()
|
||||
sessions = await list_ground_truth_sessions_db()
|
||||
|
||||
if not sessions:
|
||||
return {
|
||||
"status": "pass",
|
||||
"message": "No ground truth sessions found",
|
||||
"results": [],
|
||||
"summary": {"total": 0, "passed": 0, "failed": 0, "errors": 0},
|
||||
}
|
||||
|
||||
results = []
|
||||
passed = 0
|
||||
failed = 0
|
||||
errors = 0
|
||||
|
||||
for s in sessions:
|
||||
session_id = s["id"]
|
||||
gt = s.get("ground_truth") or {}
|
||||
reference = gt.get("build_grid_reference")
|
||||
if not reference:
|
||||
continue
|
||||
|
||||
# Re-load full session (list query may not include all JSONB fields)
|
||||
full_session = await get_session_db(session_id)
|
||||
if not full_session:
|
||||
results.append({
|
||||
"session_id": session_id,
|
||||
"name": s.get("name", ""),
|
||||
"status": "error",
|
||||
"error": "Session not found during re-load",
|
||||
})
|
||||
errors += 1
|
||||
continue
|
||||
|
||||
try:
|
||||
new_result = await _build_grid_core(session_id, full_session)
|
||||
except Exception as e:
|
||||
results.append({
|
||||
"session_id": session_id,
|
||||
"name": s.get("name", ""),
|
||||
"status": "error",
|
||||
"error": str(e),
|
||||
})
|
||||
errors += 1
|
||||
continue
|
||||
|
||||
new_snapshot = _build_reference_snapshot(new_result)
|
||||
diff = compare_grids(reference, new_snapshot)
|
||||
|
||||
entry = {
|
||||
"session_id": session_id,
|
||||
"name": s.get("name", ""),
|
||||
"status": diff["status"],
|
||||
"diff_summary": diff["summary"],
|
||||
"reference_summary": reference.get("summary", {}),
|
||||
"current_summary": new_snapshot.get("summary", {}),
|
||||
}
|
||||
|
||||
# Include full diffs only for failures (keep response compact)
|
||||
if diff["status"] == "fail":
|
||||
entry["structural_diffs"] = diff["structural_diffs"]
|
||||
entry["cell_diffs"] = diff["cell_diffs"]
|
||||
failed += 1
|
||||
else:
|
||||
passed += 1
|
||||
|
||||
results.append(entry)
|
||||
|
||||
overall = "pass" if failed == 0 and errors == 0 else "fail"
|
||||
duration_ms = int((time.monotonic() - start_time) * 1000)
|
||||
|
||||
summary = {
|
||||
"total": len(results),
|
||||
"passed": passed,
|
||||
"failed": failed,
|
||||
"errors": errors,
|
||||
}
|
||||
|
||||
logger.info(
|
||||
"Regression suite: %s — %d passed, %d failed, %d errors (of %d) in %dms",
|
||||
overall, passed, failed, errors, len(results), duration_ms,
|
||||
)
|
||||
|
||||
# Persist to DB
|
||||
run_id = await _persist_regression_run(
|
||||
status=overall,
|
||||
summary=summary,
|
||||
results=results,
|
||||
duration_ms=duration_ms,
|
||||
triggered_by=triggered_by,
|
||||
)
|
||||
|
||||
return {
|
||||
"status": overall,
|
||||
"run_id": run_id,
|
||||
"duration_ms": duration_ms,
|
||||
"results": results,
|
||||
"summary": summary,
|
||||
}
|
||||
|
||||
|
||||
@router.get("/regression/history")
|
||||
async def get_regression_history(
|
||||
limit: int = Query(20, ge=1, le=100),
|
||||
):
|
||||
"""Get recent regression run history from the database."""
|
||||
try:
|
||||
await _init_regression_table()
|
||||
pool = await get_pool()
|
||||
async with pool.acquire() as conn:
|
||||
rows = await conn.fetch(
|
||||
"""
|
||||
SELECT id, run_at, status, total, passed, failed, errors,
|
||||
duration_ms, triggered_by
|
||||
FROM regression_runs
|
||||
ORDER BY run_at DESC
|
||||
LIMIT $1
|
||||
""",
|
||||
limit,
|
||||
)
|
||||
return {
|
||||
"runs": [
|
||||
{
|
||||
"id": str(row["id"]),
|
||||
"run_at": row["run_at"].isoformat() if row["run_at"] else None,
|
||||
"status": row["status"],
|
||||
"total": row["total"],
|
||||
"passed": row["passed"],
|
||||
"failed": row["failed"],
|
||||
"errors": row["errors"],
|
||||
"duration_ms": row["duration_ms"],
|
||||
"triggered_by": row["triggered_by"],
|
||||
}
|
||||
for row in rows
|
||||
],
|
||||
"count": len(rows),
|
||||
}
|
||||
except Exception as e:
|
||||
logger.warning("Failed to fetch regression history: %s", e)
|
||||
return {"runs": [], "count": 0, "error": str(e)}
|
||||
|
||||
|
||||
@router.get("/regression/history/{run_id}")
|
||||
async def get_regression_run_detail(run_id: str):
|
||||
"""Get detailed results of a specific regression run."""
|
||||
try:
|
||||
await _init_regression_table()
|
||||
pool = await get_pool()
|
||||
async with pool.acquire() as conn:
|
||||
row = await conn.fetchrow(
|
||||
"SELECT * FROM regression_runs WHERE id = $1",
|
||||
run_id,
|
||||
)
|
||||
if not row:
|
||||
raise HTTPException(status_code=404, detail="Run not found")
|
||||
return {
|
||||
"id": str(row["id"]),
|
||||
"run_at": row["run_at"].isoformat() if row["run_at"] else None,
|
||||
"status": row["status"],
|
||||
"total": row["total"],
|
||||
"passed": row["passed"],
|
||||
"failed": row["failed"],
|
||||
"errors": row["errors"],
|
||||
"duration_ms": row["duration_ms"],
|
||||
"triggered_by": row["triggered_by"],
|
||||
"results": json.loads(row["results"]) if row["results"] else [],
|
||||
}
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
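For completeness, a hedged sketch of how a CI job could gate a build on these endpoints. The base URL and the assumption that this router is mounted under /api/v1/ocr-pipeline are placeholders, not part of this commit.

# Hypothetical CI gate for the regression endpoints above (base URL and router
# prefix are assumptions; adjust to the actual deployment).
import sys
import httpx

BASE = "http://localhost:8000/api/v1/ocr-pipeline"  # assumption

def main() -> int:
    resp = httpx.post(f"{BASE}/regression/run", params={"triggered_by": "ci"}, timeout=600.0)
    resp.raise_for_status()
    report = resp.json()
    s = report["summary"]
    print(f"regression run {report.get('run_id')}: {s['passed']} passed, "
          f"{s['failed']} failed, {s['errors']} errors in {report['duration_ms']}ms")
    return 0 if report["status"] == "pass" else 1

if __name__ == "__main__":
    sys.exit(main())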
207
klausur-service/backend/ocr_pipeline_regression_helpers.py
Normal file
@@ -0,0 +1,207 @@
|
||||
"""
|
||||
OCR Pipeline Regression Helpers — DB persistence, snapshot building, comparison.
|
||||
|
||||
Extracted from ocr_pipeline_regression.py for modularity.
|
||||
|
||||
Lizenz: Apache 2.0
|
||||
DATENSCHUTZ: Alle Verarbeitung erfolgt lokal.
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import uuid
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from ocr_pipeline_session_store import get_pool
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# DB persistence for regression runs
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
async def _init_regression_table():
|
||||
"""Ensure regression_runs table exists (idempotent)."""
|
||||
pool = await get_pool()
|
||||
async with pool.acquire() as conn:
|
||||
migration_path = os.path.join(
|
||||
os.path.dirname(__file__),
|
||||
"migrations/008_regression_runs.sql",
|
||||
)
|
||||
if os.path.exists(migration_path):
|
||||
with open(migration_path, "r") as f:
|
||||
sql = f.read()
|
||||
await conn.execute(sql)
|
||||
|
||||
|
||||
async def _persist_regression_run(
|
||||
status: str,
|
||||
summary: dict,
|
||||
results: list,
|
||||
duration_ms: int,
|
||||
triggered_by: str = "manual",
|
||||
) -> str:
|
||||
"""Save a regression run to the database. Returns the run ID."""
|
||||
try:
|
||||
await _init_regression_table()
|
||||
pool = await get_pool()
|
||||
run_id = str(uuid.uuid4())
|
||||
async with pool.acquire() as conn:
|
||||
await conn.execute(
|
||||
"""
|
||||
INSERT INTO regression_runs
|
||||
(id, status, total, passed, failed, errors, duration_ms, results, triggered_by)
|
||||
VALUES ($1, $2, $3, $4, $5, $6, $7, $8::jsonb, $9)
|
||||
""",
|
||||
run_id,
|
||||
status,
|
||||
summary.get("total", 0),
|
||||
summary.get("passed", 0),
|
||||
summary.get("failed", 0),
|
||||
summary.get("errors", 0),
|
||||
duration_ms,
|
||||
json.dumps(results),
|
||||
triggered_by,
|
||||
)
|
||||
logger.info("Regression run %s persisted: %s", run_id, status)
|
||||
return run_id
|
||||
except Exception as e:
|
||||
logger.warning("Failed to persist regression run: %s", e)
|
||||
return ""
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _extract_cells_for_comparison(grid_result: dict) -> List[Dict[str, Any]]:
|
||||
"""Extract a flat list of cells from a grid_editor_result for comparison.
|
||||
|
||||
Only keeps fields relevant for comparison: cell_id, row_index, col_index,
|
||||
col_type, text. Ignores confidence, bbox, word_boxes, duration, is_bold.
|
||||
"""
|
||||
cells = []
|
||||
for zone in grid_result.get("zones", []):
|
||||
for cell in zone.get("cells", []):
|
||||
cells.append({
|
||||
"cell_id": cell.get("cell_id", ""),
|
||||
"row_index": cell.get("row_index"),
|
||||
"col_index": cell.get("col_index"),
|
||||
"col_type": cell.get("col_type", ""),
|
||||
"text": cell.get("text", ""),
|
||||
})
|
||||
return cells
|
||||
|
||||
|
||||
def _build_reference_snapshot(
|
||||
grid_result: dict,
|
||||
pipeline: Optional[str] = None,
|
||||
) -> dict:
|
||||
"""Build a ground-truth reference snapshot from a grid_editor_result."""
|
||||
cells = _extract_cells_for_comparison(grid_result)
|
||||
|
||||
total_zones = len(grid_result.get("zones", []))
|
||||
total_columns = sum(len(z.get("columns", [])) for z in grid_result.get("zones", []))
|
||||
total_rows = sum(len(z.get("rows", [])) for z in grid_result.get("zones", []))
|
||||
|
||||
snapshot = {
|
||||
"saved_at": datetime.now(timezone.utc).isoformat(),
|
||||
"version": 1,
|
||||
"pipeline": pipeline,
|
||||
"summary": {
|
||||
"total_zones": total_zones,
|
||||
"total_columns": total_columns,
|
||||
"total_rows": total_rows,
|
||||
"total_cells": len(cells),
|
||||
},
|
||||
"cells": cells,
|
||||
}
|
||||
return snapshot
|
||||
|
||||
|
||||
def compare_grids(reference: dict, current: dict) -> dict:
|
||||
"""Compare a reference grid snapshot with a newly computed one.
|
||||
|
||||
Returns a diff report with:
|
||||
- status: "pass" or "fail"
|
||||
- structural_diffs: changes in zone/row/column counts
|
||||
- cell_diffs: list of individual cell changes
|
||||
"""
|
||||
ref_summary = reference.get("summary", {})
|
||||
cur_summary = current.get("summary", {})
|
||||
|
||||
structural_diffs = []
|
||||
for key in ("total_zones", "total_columns", "total_rows", "total_cells"):
|
||||
ref_val = ref_summary.get(key, 0)
|
||||
cur_val = cur_summary.get(key, 0)
|
||||
if ref_val != cur_val:
|
||||
structural_diffs.append({
|
||||
"field": key,
|
||||
"reference": ref_val,
|
||||
"current": cur_val,
|
||||
})
|
||||
|
||||
# Build cell lookup by cell_id
|
||||
ref_cells = {c["cell_id"]: c for c in reference.get("cells", [])}
|
||||
cur_cells = {c["cell_id"]: c for c in current.get("cells", [])}
|
||||
|
||||
cell_diffs: List[Dict[str, Any]] = []
|
||||
|
||||
# Check for missing cells (in reference but not in current)
|
||||
for cell_id in ref_cells:
|
||||
if cell_id not in cur_cells:
|
||||
cell_diffs.append({
|
||||
"type": "cell_missing",
|
||||
"cell_id": cell_id,
|
||||
"reference_text": ref_cells[cell_id].get("text", ""),
|
||||
})
|
||||
|
||||
# Check for added cells (in current but not in reference)
|
||||
for cell_id in cur_cells:
|
||||
if cell_id not in ref_cells:
|
||||
cell_diffs.append({
|
||||
"type": "cell_added",
|
||||
"cell_id": cell_id,
|
||||
"current_text": cur_cells[cell_id].get("text", ""),
|
||||
})
|
||||
|
||||
# Check for changes in shared cells
|
||||
for cell_id in ref_cells:
|
||||
if cell_id not in cur_cells:
|
||||
continue
|
||||
ref_cell = ref_cells[cell_id]
|
||||
cur_cell = cur_cells[cell_id]
|
||||
|
||||
if ref_cell.get("text", "") != cur_cell.get("text", ""):
|
||||
cell_diffs.append({
|
||||
"type": "text_change",
|
||||
"cell_id": cell_id,
|
||||
"reference": ref_cell.get("text", ""),
|
||||
"current": cur_cell.get("text", ""),
|
||||
})
|
||||
|
||||
if ref_cell.get("col_type", "") != cur_cell.get("col_type", ""):
|
||||
cell_diffs.append({
|
||||
"type": "col_type_change",
|
||||
"cell_id": cell_id,
|
||||
"reference": ref_cell.get("col_type", ""),
|
||||
"current": cur_cell.get("col_type", ""),
|
||||
})
|
||||
|
||||
status = "pass" if not structural_diffs and not cell_diffs else "fail"
|
||||
|
||||
return {
|
||||
"status": status,
|
||||
"structural_diffs": structural_diffs,
|
||||
"cell_diffs": cell_diffs,
|
||||
"summary": {
|
||||
"structural_changes": len(structural_diffs),
|
||||
"cells_missing": sum(1 for d in cell_diffs if d["type"] == "cell_missing"),
|
||||
"cells_added": sum(1 for d in cell_diffs if d["type"] == "cell_added"),
|
||||
"text_changes": sum(1 for d in cell_diffs if d["type"] == "text_change"),
|
||||
"col_type_changes": sum(1 for d in cell_diffs if d["type"] == "col_type_change"),
|
||||
},
|
||||
}
|
||||
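For reference, a minimal sketch of the input and output shape of compare_grids, using hand-written snapshots; the cell_id scheme and vocabulary values here are invented for illustration, real snapshots come from _build_reference_snapshot().

# Illustrative only — not real pipeline output.
reference = {
    "summary": {"total_zones": 1, "total_columns": 2, "total_rows": 1, "total_cells": 2},
    "cells": [
        {"cell_id": "z0r0c0", "row_index": 0, "col_index": 0, "col_type": "source", "text": "begrüßen"},
        {"cell_id": "z0r0c1", "row_index": 0, "col_index": 1, "col_type": "target", "text": "to greet"},
    ],
}
current = {**reference, "cells": [
    {**reference["cells"][0], "text": "begrüßeu"},  # OCR-damaged reading of the same cell
    reference["cells"][1],
]}
diff = compare_grids(reference, current)
assert diff["status"] == "fail"
assert diff["summary"]["text_changes"] == 1
assert diff["cell_diffs"][0]["type"] == "text_change"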
@@ -1,597 +1,20 @@
|
||||
"""
|
||||
OCR Pipeline Sessions API - Session management and image serving endpoints.
|
||||
OCR Pipeline Sessions API — barrel re-export.
|
||||
|
||||
Extracted from ocr_pipeline_api.py for modularity.
|
||||
Handles: CRUD for sessions, thumbnails, pipeline logs, categories,
|
||||
image serving (with overlay dispatch), and document type detection.
|
||||
All implementation split into:
|
||||
ocr_pipeline_sessions_crud — session CRUD, box sessions
|
||||
ocr_pipeline_sessions_images — image serving, thumbnails, doc-type detection
|
||||
|
||||
Lizenz: Apache 2.0
|
||||
DATENSCHUTZ: Alle Verarbeitung erfolgt lokal.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import time
|
||||
import uuid
|
||||
from typing import Any, Dict, Optional
|
||||
from fastapi import APIRouter
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
from fastapi import APIRouter, File, Form, HTTPException, Query, UploadFile
|
||||
from fastapi.responses import Response
|
||||
from ocr_pipeline_sessions_crud import router as _crud_router # noqa: F401
|
||||
from ocr_pipeline_sessions_images import router as _images_router # noqa: F401
|
||||
|
||||
from cv_vocab_pipeline import (
|
||||
create_ocr_image,
|
||||
detect_document_type,
|
||||
render_image_high_res,
|
||||
render_pdf_high_res,
|
||||
)
|
||||
from ocr_pipeline_common import (
|
||||
VALID_DOCUMENT_CATEGORIES,
|
||||
UpdateSessionRequest,
|
||||
_append_pipeline_log,
|
||||
_cache,
|
||||
_get_base_image_png,
|
||||
_get_cached,
|
||||
_load_session_to_cache,
|
||||
)
|
||||
from ocr_pipeline_overlays import render_overlay
|
||||
from ocr_pipeline_session_store import (
|
||||
create_session_db,
|
||||
delete_all_sessions_db,
|
||||
delete_session_db,
|
||||
get_session_db,
|
||||
get_session_image,
|
||||
get_sub_sessions,
|
||||
list_sessions_db,
|
||||
update_session_db,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter(prefix="/api/v1/ocr-pipeline", tags=["ocr-pipeline"])
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Session Management Endpoints
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@router.get("/sessions")
|
||||
async def list_sessions(include_sub_sessions: bool = False):
|
||||
"""List OCR pipeline sessions.
|
||||
|
||||
By default, sub-sessions (box regions) are hidden.
|
||||
Pass ?include_sub_sessions=true to show them.
|
||||
"""
|
||||
sessions = await list_sessions_db(include_sub_sessions=include_sub_sessions)
|
||||
return {"sessions": sessions}
|
||||
|
||||
|
||||
@router.post("/sessions")
|
||||
async def create_session(
|
||||
file: UploadFile = File(...),
|
||||
name: Optional[str] = Form(None),
|
||||
):
|
||||
"""Upload a PDF or image file and create a pipeline session.
|
||||
|
||||
For multi-page PDFs (> 1 page), each page becomes its own session
|
||||
grouped under a ``document_group_id``. The response includes a
|
||||
``pages`` array with one entry per page/session.
|
||||
"""
|
||||
file_data = await file.read()
|
||||
filename = file.filename or "upload"
|
||||
content_type = file.content_type or ""
|
||||
|
||||
is_pdf = content_type == "application/pdf" or filename.lower().endswith(".pdf")
|
||||
session_name = name or filename
|
||||
|
||||
# --- Multi-page PDF handling ---
|
||||
if is_pdf:
|
||||
try:
|
||||
import fitz # PyMuPDF
|
||||
pdf_doc = fitz.open(stream=file_data, filetype="pdf")
|
||||
page_count = pdf_doc.page_count
|
||||
pdf_doc.close()
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=400, detail=f"Could not read PDF: {e}")
|
||||
|
||||
if page_count > 1:
|
||||
return await _create_multi_page_sessions(
|
||||
file_data, filename, session_name, page_count,
|
||||
)
|
||||
|
||||
# --- Single page (image or 1-page PDF) ---
|
||||
session_id = str(uuid.uuid4())
|
||||
|
||||
try:
|
||||
if is_pdf:
|
||||
img_bgr = render_pdf_high_res(file_data, page_number=0, zoom=3.0)
|
||||
else:
|
||||
img_bgr = render_image_high_res(file_data)
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=400, detail=f"Could not process file: {e}")
|
||||
|
||||
# Encode original as PNG bytes
|
||||
success, png_buf = cv2.imencode(".png", img_bgr)
|
||||
if not success:
|
||||
raise HTTPException(status_code=500, detail="Failed to encode image")
|
||||
|
||||
original_png = png_buf.tobytes()
|
||||
|
||||
# Persist to DB
|
||||
await create_session_db(
|
||||
session_id=session_id,
|
||||
name=session_name,
|
||||
filename=filename,
|
||||
original_png=original_png,
|
||||
)
|
||||
|
||||
# Cache BGR array for immediate processing
|
||||
_cache[session_id] = {
|
||||
"id": session_id,
|
||||
"filename": filename,
|
||||
"name": session_name,
|
||||
"original_bgr": img_bgr,
|
||||
"oriented_bgr": None,
|
||||
"cropped_bgr": None,
|
||||
"deskewed_bgr": None,
|
||||
"dewarped_bgr": None,
|
||||
"orientation_result": None,
|
||||
"crop_result": None,
|
||||
"deskew_result": None,
|
||||
"dewarp_result": None,
|
||||
"ground_truth": {},
|
||||
"current_step": 1,
|
||||
}
|
||||
|
||||
logger.info(f"OCR Pipeline: created session {session_id} from {filename} "
|
||||
f"({img_bgr.shape[1]}x{img_bgr.shape[0]})")
|
||||
|
||||
return {
|
||||
"session_id": session_id,
|
||||
"filename": filename,
|
||||
"name": session_name,
|
||||
"image_width": img_bgr.shape[1],
|
||||
"image_height": img_bgr.shape[0],
|
||||
"original_image_url": f"/api/v1/ocr-pipeline/sessions/{session_id}/image/original",
|
||||
}
|
||||
|
||||
|
||||
async def _create_multi_page_sessions(
|
||||
pdf_data: bytes,
|
||||
filename: str,
|
||||
base_name: str,
|
||||
page_count: int,
|
||||
) -> dict:
|
||||
"""Create one session per PDF page, grouped by document_group_id."""
|
||||
document_group_id = str(uuid.uuid4())
|
||||
pages = []
|
||||
|
||||
for page_idx in range(page_count):
|
||||
session_id = str(uuid.uuid4())
|
||||
page_name = f"{base_name} — Seite {page_idx + 1}"
|
||||
|
||||
try:
|
||||
img_bgr = render_pdf_high_res(pdf_data, page_number=page_idx, zoom=3.0)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to render PDF page {page_idx + 1}: {e}")
|
||||
continue
|
||||
|
||||
ok, png_buf = cv2.imencode(".png", img_bgr)
|
||||
if not ok:
|
||||
continue
|
||||
page_png = png_buf.tobytes()
|
||||
|
||||
await create_session_db(
|
||||
session_id=session_id,
|
||||
name=page_name,
|
||||
filename=filename,
|
||||
original_png=page_png,
|
||||
document_group_id=document_group_id,
|
||||
page_number=page_idx + 1,
|
||||
)
|
||||
|
||||
_cache[session_id] = {
|
||||
"id": session_id,
|
||||
"filename": filename,
|
||||
"name": page_name,
|
||||
"original_bgr": img_bgr,
|
||||
"oriented_bgr": None,
|
||||
"cropped_bgr": None,
|
||||
"deskewed_bgr": None,
|
||||
"dewarped_bgr": None,
|
||||
"orientation_result": None,
|
||||
"crop_result": None,
|
||||
"deskew_result": None,
|
||||
"dewarp_result": None,
|
||||
"ground_truth": {},
|
||||
"current_step": 1,
|
||||
}
|
||||
|
||||
h, w = img_bgr.shape[:2]
|
||||
pages.append({
|
||||
"session_id": session_id,
|
||||
"name": page_name,
|
||||
"page_number": page_idx + 1,
|
||||
"image_width": w,
|
||||
"image_height": h,
|
||||
"original_image_url": f"/api/v1/ocr-pipeline/sessions/{session_id}/image/original",
|
||||
})
|
||||
|
||||
logger.info(
|
||||
f"OCR Pipeline: created page session {session_id} "
|
||||
f"(page {page_idx + 1}/{page_count}) from {filename} ({w}x{h})"
|
||||
)
|
||||
|
||||
# Include session_id pointing to first page for backwards compatibility
|
||||
# (frontends that expect a single session_id will navigate to page 1)
|
||||
first_session_id = pages[0]["session_id"] if pages else None
|
||||
|
||||
return {
|
||||
"session_id": first_session_id,
|
||||
"document_group_id": document_group_id,
|
||||
"filename": filename,
|
||||
"name": base_name,
|
||||
"page_count": page_count,
|
||||
"pages": pages,
|
||||
}
|
||||
|
||||
|
||||
@router.get("/sessions/{session_id}")
|
||||
async def get_session_info(session_id: str):
|
||||
"""Get session info including deskew/dewarp/column results for step navigation."""
|
||||
session = await get_session_db(session_id)
|
||||
if not session:
|
||||
raise HTTPException(status_code=404, detail=f"Session {session_id} not found")
|
||||
|
||||
# Get image dimensions from original PNG
|
||||
original_png = await get_session_image(session_id, "original")
|
||||
if original_png:
|
||||
arr = np.frombuffer(original_png, dtype=np.uint8)
|
||||
img = cv2.imdecode(arr, cv2.IMREAD_COLOR)
|
||||
img_w, img_h = (img.shape[1], img.shape[0]) if img is not None else (0, 0)
|
||||
else:
|
||||
img_w, img_h = 0, 0
|
||||
|
||||
result = {
|
||||
"session_id": session["id"],
|
||||
"filename": session.get("filename", ""),
|
||||
"name": session.get("name", ""),
|
||||
"image_width": img_w,
|
||||
"image_height": img_h,
|
||||
"original_image_url": f"/api/v1/ocr-pipeline/sessions/{session_id}/image/original",
|
||||
"current_step": session.get("current_step", 1),
|
||||
"document_category": session.get("document_category"),
|
||||
"doc_type": session.get("doc_type"),
|
||||
}
|
||||
|
||||
if session.get("orientation_result"):
|
||||
result["orientation_result"] = session["orientation_result"]
|
||||
if session.get("crop_result"):
|
||||
result["crop_result"] = session["crop_result"]
|
||||
if session.get("deskew_result"):
|
||||
result["deskew_result"] = session["deskew_result"]
|
||||
if session.get("dewarp_result"):
|
||||
result["dewarp_result"] = session["dewarp_result"]
|
||||
if session.get("column_result"):
|
||||
result["column_result"] = session["column_result"]
|
||||
if session.get("row_result"):
|
||||
result["row_result"] = session["row_result"]
|
||||
if session.get("word_result"):
|
||||
result["word_result"] = session["word_result"]
|
||||
if session.get("doc_type_result"):
|
||||
result["doc_type_result"] = session["doc_type_result"]
|
||||
if session.get("structure_result"):
|
||||
result["structure_result"] = session["structure_result"]
|
||||
if session.get("grid_editor_result"):
|
||||
# Include summary only to keep response small
|
||||
gr = session["grid_editor_result"]
|
||||
result["grid_editor_result"] = {
|
||||
"summary": gr.get("summary", {}),
|
||||
"zones_count": len(gr.get("zones", [])),
|
||||
"edited": gr.get("edited", False),
|
||||
}
|
||||
if session.get("ground_truth"):
|
||||
result["ground_truth"] = session["ground_truth"]
|
||||
|
||||
# Box sub-session info (zone_type='box' from column detection — NOT page-split)
|
||||
if session.get("parent_session_id"):
|
||||
result["parent_session_id"] = session["parent_session_id"]
|
||||
result["box_index"] = session.get("box_index")
|
||||
else:
|
||||
# Check for box sub-sessions (column detection creates these)
|
||||
subs = await get_sub_sessions(session_id)
|
||||
if subs:
|
||||
result["sub_sessions"] = [
|
||||
{"id": s["id"], "name": s.get("name"), "box_index": s.get("box_index")}
|
||||
for s in subs
|
||||
]
|
||||
|
||||
return result
|
||||
|
||||
|
||||
@router.put("/sessions/{session_id}")
|
||||
async def update_session(session_id: str, req: UpdateSessionRequest):
|
||||
"""Update session name and/or document category."""
|
||||
kwargs: Dict[str, Any] = {}
|
||||
if req.name is not None:
|
||||
kwargs["name"] = req.name
|
||||
if req.document_category is not None:
|
||||
if req.document_category not in VALID_DOCUMENT_CATEGORIES:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Invalid category '{req.document_category}'. Valid: {sorted(VALID_DOCUMENT_CATEGORIES)}",
|
||||
)
|
||||
kwargs["document_category"] = req.document_category
|
||||
if not kwargs:
|
||||
raise HTTPException(status_code=400, detail="Nothing to update")
|
||||
updated = await update_session_db(session_id, **kwargs)
|
||||
if not updated:
|
||||
raise HTTPException(status_code=404, detail=f"Session {session_id} not found")
|
||||
return {"session_id": session_id, **kwargs}
|
||||
|
||||
|
||||
@router.delete("/sessions/{session_id}")
|
||||
async def delete_session(session_id: str):
|
||||
"""Delete a session."""
|
||||
_cache.pop(session_id, None)
|
||||
deleted = await delete_session_db(session_id)
|
||||
if not deleted:
|
||||
raise HTTPException(status_code=404, detail=f"Session {session_id} not found")
|
||||
return {"session_id": session_id, "deleted": True}
|
||||
|
||||
|
||||
@router.delete("/sessions")
|
||||
async def delete_all_sessions():
|
||||
"""Delete ALL sessions (cleanup)."""
|
||||
_cache.clear()
|
||||
count = await delete_all_sessions_db()
|
||||
return {"deleted_count": count}
|
||||
|
||||
|
||||
@router.post("/sessions/{session_id}/create-box-sessions")
|
||||
async def create_box_sessions(session_id: str):
|
||||
"""Create sub-sessions for each detected box region.
|
||||
|
||||
Crops box regions from the cropped/dewarped image and creates
|
||||
independent sub-sessions that can be processed through the pipeline.
|
||||
"""
|
||||
session = await get_session_db(session_id)
|
||||
if not session:
|
||||
raise HTTPException(status_code=404, detail=f"Session {session_id} not found")
|
||||
|
||||
column_result = session.get("column_result")
|
||||
if not column_result:
|
||||
raise HTTPException(status_code=400, detail="Column detection must be completed first")
|
||||
|
||||
zones = column_result.get("zones") or []
|
||||
box_zones = [z for z in zones if z.get("zone_type") == "box" and z.get("box")]
|
||||
if not box_zones:
|
||||
return {"session_id": session_id, "sub_sessions": [], "message": "No boxes detected"}
|
||||
|
||||
# Check for existing sub-sessions
|
||||
existing = await get_sub_sessions(session_id)
|
||||
if existing:
|
||||
return {
|
||||
"session_id": session_id,
|
||||
"sub_sessions": [{"id": s["id"], "box_index": s.get("box_index")} for s in existing],
|
||||
"message": f"{len(existing)} sub-session(s) already exist",
|
||||
}
|
||||
|
||||
# Load base image
|
||||
base_png = await get_session_image(session_id, "cropped")
|
||||
if not base_png:
|
||||
base_png = await get_session_image(session_id, "dewarped")
|
||||
if not base_png:
|
||||
raise HTTPException(status_code=400, detail="No base image available")
|
||||
|
||||
arr = np.frombuffer(base_png, dtype=np.uint8)
|
||||
img = cv2.imdecode(arr, cv2.IMREAD_COLOR)
|
||||
if img is None:
|
||||
raise HTTPException(status_code=500, detail="Failed to decode image")
|
||||
|
||||
parent_name = session.get("name", "Session")
|
||||
created = []
|
||||
|
||||
for i, zone in enumerate(box_zones):
|
||||
box = zone["box"]
|
||||
bx, by = box["x"], box["y"]
|
||||
bw, bh = box["width"], box["height"]
|
||||
|
||||
# Crop box region with small padding
|
||||
pad = 5
|
||||
y1 = max(0, by - pad)
|
||||
y2 = min(img.shape[0], by + bh + pad)
|
||||
x1 = max(0, bx - pad)
|
||||
x2 = min(img.shape[1], bx + bw + pad)
|
||||
crop = img[y1:y2, x1:x2]
|
||||
|
||||
# Encode as PNG
|
||||
success, png_buf = cv2.imencode(".png", crop)
|
||||
if not success:
|
||||
logger.warning(f"Failed to encode box {i} crop for session {session_id}")
|
||||
continue
|
||||
|
||||
sub_id = str(uuid.uuid4())
|
||||
sub_name = f"{parent_name} — Box {i + 1}"
|
||||
|
||||
await create_session_db(
|
||||
session_id=sub_id,
|
||||
name=sub_name,
|
||||
filename=session.get("filename", "box-crop.png"),
|
||||
original_png=png_buf.tobytes(),
|
||||
parent_session_id=session_id,
|
||||
box_index=i,
|
||||
)
|
||||
|
||||
# Cache the BGR for immediate processing
|
||||
# Promote original to cropped so column/row/word detection finds it
|
||||
box_bgr = crop.copy()
|
||||
_cache[sub_id] = {
|
||||
"id": sub_id,
|
||||
"filename": session.get("filename", "box-crop.png"),
|
||||
"name": sub_name,
|
||||
"parent_session_id": session_id,
|
||||
"original_bgr": box_bgr,
|
||||
"oriented_bgr": None,
|
||||
"cropped_bgr": box_bgr,
|
||||
"deskewed_bgr": None,
|
||||
"dewarped_bgr": None,
|
||||
"orientation_result": None,
|
||||
"crop_result": None,
|
||||
"deskew_result": None,
|
||||
"dewarp_result": None,
|
||||
"ground_truth": {},
|
||||
"current_step": 1,
|
||||
}
|
||||
|
||||
created.append({
|
||||
"id": sub_id,
|
||||
"name": sub_name,
|
||||
"box_index": i,
|
||||
"box": box,
|
||||
"image_width": crop.shape[1],
|
||||
"image_height": crop.shape[0],
|
||||
})
|
||||
|
||||
logger.info(f"Created box sub-session {sub_id} for session {session_id} "
|
||||
f"(box {i}, {crop.shape[1]}x{crop.shape[0]})")
|
||||
|
||||
return {
|
||||
"session_id": session_id,
|
||||
"sub_sessions": created,
|
||||
"total": len(created),
|
||||
}
|
||||
|
||||
|
||||
@router.get("/sessions/{session_id}/thumbnail")
|
||||
async def get_session_thumbnail(session_id: str, size: int = Query(default=80, ge=16, le=400)):
|
||||
"""Return a small thumbnail of the original image."""
|
||||
original_png = await get_session_image(session_id, "original")
|
||||
if not original_png:
|
||||
raise HTTPException(status_code=404, detail=f"Session {session_id} not found or no image")
|
||||
arr = np.frombuffer(original_png, dtype=np.uint8)
|
||||
img = cv2.imdecode(arr, cv2.IMREAD_COLOR)
|
||||
if img is None:
|
||||
raise HTTPException(status_code=500, detail="Failed to decode image")
|
||||
h, w = img.shape[:2]
|
||||
scale = size / max(h, w)
|
||||
new_w, new_h = int(w * scale), int(h * scale)
|
||||
thumb = cv2.resize(img, (new_w, new_h), interpolation=cv2.INTER_AREA)
|
||||
_, png_bytes = cv2.imencode(".png", thumb)
|
||||
return Response(content=png_bytes.tobytes(), media_type="image/png",
|
||||
headers={"Cache-Control": "public, max-age=3600"})
|
||||
|
||||
|
||||
@router.get("/sessions/{session_id}/pipeline-log")
|
||||
async def get_pipeline_log(session_id: str):
|
||||
"""Get the pipeline execution log for a session."""
|
||||
session = await get_session_db(session_id)
|
||||
if not session:
|
||||
raise HTTPException(status_code=404, detail=f"Session {session_id} not found")
|
||||
return {"session_id": session_id, "pipeline_log": session.get("pipeline_log") or {"steps": []}}
|
||||
|
||||
|
||||
@router.get("/categories")
|
||||
async def list_categories():
|
||||
"""List valid document categories."""
|
||||
return {"categories": sorted(VALID_DOCUMENT_CATEGORIES)}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Image Endpoints
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@router.get("/sessions/{session_id}/image/{image_type}")
|
||||
async def get_image(session_id: str, image_type: str):
|
||||
"""Serve session images: original, deskewed, dewarped, binarized, structure-overlay, columns-overlay, or rows-overlay."""
|
||||
valid_types = {"original", "oriented", "cropped", "deskewed", "dewarped", "binarized", "structure-overlay", "columns-overlay", "rows-overlay", "words-overlay", "clean"}
|
||||
if image_type not in valid_types:
|
||||
raise HTTPException(status_code=400, detail=f"Unknown image type: {image_type}")
|
||||
|
||||
if image_type == "structure-overlay":
|
||||
return await render_overlay("structure", session_id)
|
||||
|
||||
if image_type == "columns-overlay":
|
||||
return await render_overlay("columns", session_id)
|
||||
|
||||
if image_type == "rows-overlay":
|
||||
return await render_overlay("rows", session_id)
|
||||
|
||||
if image_type == "words-overlay":
|
||||
return await render_overlay("words", session_id)
|
||||
|
||||
# Try cache first for fast serving
|
||||
cached = _cache.get(session_id)
|
||||
if cached:
|
||||
png_key = f"{image_type}_png" if image_type != "original" else None
|
||||
bgr_key = f"{image_type}_bgr" if image_type != "binarized" else None
|
||||
|
||||
# For binarized, check if we have it cached as PNG
|
||||
if image_type == "binarized" and cached.get("binarized_png"):
|
||||
return Response(content=cached["binarized_png"], media_type="image/png")
|
||||
|
||||
# Load from DB — for cropped/dewarped, fall back through the chain
|
||||
if image_type in ("cropped", "dewarped"):
|
||||
data = await _get_base_image_png(session_id)
|
||||
else:
|
||||
data = await get_session_image(session_id, image_type)
|
||||
if not data:
|
||||
raise HTTPException(status_code=404, detail=f"Image '{image_type}' not available yet")
|
||||
|
||||
return Response(content=data, media_type="image/png")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Document Type Detection (between Dewarp and Columns)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@router.post("/sessions/{session_id}/detect-type")
|
||||
async def detect_type(session_id: str):
|
||||
"""Detect document type (vocab_table, full_text, generic_table).
|
||||
|
||||
Should be called after crop (clean image available).
|
||||
Falls back to dewarped if crop was skipped.
|
||||
Stores result in session for frontend to decide pipeline flow.
|
||||
"""
|
||||
if session_id not in _cache:
|
||||
await _load_session_to_cache(session_id)
|
||||
cached = _get_cached(session_id)
|
||||
|
||||
img_bgr = cached.get("cropped_bgr") if cached.get("cropped_bgr") is not None else cached.get("dewarped_bgr")
|
||||
if img_bgr is None:
|
||||
raise HTTPException(status_code=400, detail="Crop or dewarp must be completed first")
|
||||
|
||||
t0 = time.time()
|
||||
ocr_img = create_ocr_image(img_bgr)
|
||||
result = detect_document_type(ocr_img, img_bgr)
|
||||
duration = time.time() - t0
|
||||
|
||||
result_dict = {
|
||||
"doc_type": result.doc_type,
|
||||
"confidence": result.confidence,
|
||||
"pipeline": result.pipeline,
|
||||
"skip_steps": result.skip_steps,
|
||||
"features": result.features,
|
||||
"duration_seconds": round(duration, 2),
|
||||
}
|
||||
|
||||
# Persist to DB
|
||||
await update_session_db(
|
||||
session_id,
|
||||
doc_type=result.doc_type,
|
||||
doc_type_result=result_dict,
|
||||
)
|
||||
|
||||
cached["doc_type_result"] = result_dict
|
||||
|
||||
logger.info(f"OCR Pipeline: detect-type session {session_id}: "
|
||||
f"{result.doc_type} (confidence={result.confidence}, {duration:.2f}s)")
|
||||
|
||||
await _append_pipeline_log(session_id, "detect_type", {
|
||||
"doc_type": result.doc_type,
|
||||
"pipeline": result.pipeline,
|
||||
"confidence": result.confidence,
|
||||
**{k: v for k, v in (result.features or {}).items() if isinstance(v, (int, float, str, bool))},
|
||||
}, duration_ms=int(duration * 1000))
|
||||
|
||||
return {"session_id": session_id, **result_dict}
|
||||
# Composite router (used by ocr_pipeline_api.py)
|
||||
router = APIRouter()
|
||||
router.include_router(_crud_router)
|
||||
router.include_router(_images_router)
|
||||
|
||||
449
klausur-service/backend/ocr_pipeline_sessions_crud.py
Normal file
@@ -0,0 +1,449 @@
|
||||
"""
|
||||
OCR Pipeline Sessions CRUD — session create, read, update, delete, box sessions.
|
||||
|
||||
Extracted from ocr_pipeline_sessions.py for modularity.
|
||||
|
||||
Lizenz: Apache 2.0
|
||||
DATENSCHUTZ: Alle Verarbeitung erfolgt lokal.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import uuid
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
from fastapi import APIRouter, File, Form, HTTPException, Query, UploadFile
|
||||
|
||||
from cv_vocab_pipeline import render_image_high_res, render_pdf_high_res
|
||||
from ocr_pipeline_common import (
|
||||
VALID_DOCUMENT_CATEGORIES,
|
||||
UpdateSessionRequest,
|
||||
_cache,
|
||||
)
|
||||
from ocr_pipeline_session_store import (
|
||||
create_session_db,
|
||||
delete_all_sessions_db,
|
||||
delete_session_db,
|
||||
get_session_db,
|
||||
get_session_image,
|
||||
get_sub_sessions,
|
||||
list_sessions_db,
|
||||
update_session_db,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter(prefix="/api/v1/ocr-pipeline", tags=["ocr-pipeline"])
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Session Management Endpoints
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@router.get("/sessions")
|
||||
async def list_sessions(include_sub_sessions: bool = False):
|
||||
"""List OCR pipeline sessions.
|
||||
|
||||
By default, sub-sessions (box regions) are hidden.
|
||||
Pass ?include_sub_sessions=true to show them.
|
||||
"""
|
||||
sessions = await list_sessions_db(include_sub_sessions=include_sub_sessions)
|
||||
return {"sessions": sessions}
|
||||
|
||||
|
||||
@router.post("/sessions")
|
||||
async def create_session(
|
||||
file: UploadFile = File(...),
|
||||
name: Optional[str] = Form(None),
|
||||
):
|
||||
"""Upload a PDF or image file and create a pipeline session.
|
||||
|
||||
For multi-page PDFs (> 1 page), each page becomes its own session
|
||||
grouped under a ``document_group_id``. The response includes a
|
||||
``pages`` array with one entry per page/session.
|
||||
"""
|
||||
file_data = await file.read()
|
||||
filename = file.filename or "upload"
|
||||
content_type = file.content_type or ""
|
||||
|
||||
is_pdf = content_type == "application/pdf" or filename.lower().endswith(".pdf")
|
||||
session_name = name or filename
|
||||
|
||||
# --- Multi-page PDF handling ---
|
||||
if is_pdf:
|
||||
try:
|
||||
import fitz # PyMuPDF
|
||||
pdf_doc = fitz.open(stream=file_data, filetype="pdf")
|
||||
page_count = pdf_doc.page_count
|
||||
pdf_doc.close()
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=400, detail=f"Could not read PDF: {e}")
|
||||
|
||||
if page_count > 1:
|
||||
return await _create_multi_page_sessions(
|
||||
file_data, filename, session_name, page_count,
|
||||
)
|
||||
|
||||
# --- Single page (image or 1-page PDF) ---
|
||||
session_id = str(uuid.uuid4())
|
||||
|
||||
try:
|
||||
if is_pdf:
|
||||
img_bgr = render_pdf_high_res(file_data, page_number=0, zoom=3.0)
|
||||
else:
|
||||
img_bgr = render_image_high_res(file_data)
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=400, detail=f"Could not process file: {e}")
|
||||
|
||||
# Encode original as PNG bytes
|
||||
success, png_buf = cv2.imencode(".png", img_bgr)
|
||||
if not success:
|
||||
raise HTTPException(status_code=500, detail="Failed to encode image")
|
||||
|
||||
original_png = png_buf.tobytes()
|
||||
|
||||
# Persist to DB
|
||||
await create_session_db(
|
||||
session_id=session_id,
|
||||
name=session_name,
|
||||
filename=filename,
|
||||
original_png=original_png,
|
||||
)
|
||||
|
||||
# Cache BGR array for immediate processing
|
||||
_cache[session_id] = {
|
||||
"id": session_id,
|
||||
"filename": filename,
|
||||
"name": session_name,
|
||||
"original_bgr": img_bgr,
|
||||
"oriented_bgr": None,
|
||||
"cropped_bgr": None,
|
||||
"deskewed_bgr": None,
|
||||
"dewarped_bgr": None,
|
||||
"orientation_result": None,
|
||||
"crop_result": None,
|
||||
"deskew_result": None,
|
||||
"dewarp_result": None,
|
||||
"ground_truth": {},
|
||||
"current_step": 1,
|
||||
}
|
||||
|
||||
logger.info(f"OCR Pipeline: created session {session_id} from {filename} "
|
||||
f"({img_bgr.shape[1]}x{img_bgr.shape[0]})")
|
||||
|
||||
return {
|
||||
"session_id": session_id,
|
||||
"filename": filename,
|
||||
"name": session_name,
|
||||
"image_width": img_bgr.shape[1],
|
||||
"image_height": img_bgr.shape[0],
|
||||
"original_image_url": f"/api/v1/ocr-pipeline/sessions/{session_id}/image/original",
|
||||
}
|
||||
|
||||
|
||||
async def _create_multi_page_sessions(
|
||||
pdf_data: bytes,
|
||||
filename: str,
|
||||
base_name: str,
|
||||
page_count: int,
|
||||
) -> dict:
|
||||
"""Create one session per PDF page, grouped by document_group_id."""
|
||||
document_group_id = str(uuid.uuid4())
|
||||
pages = []
|
||||
|
||||
for page_idx in range(page_count):
|
||||
session_id = str(uuid.uuid4())
|
||||
page_name = f"{base_name} — Seite {page_idx + 1}"
|
||||
|
||||
try:
|
||||
img_bgr = render_pdf_high_res(pdf_data, page_number=page_idx, zoom=3.0)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to render PDF page {page_idx + 1}: {e}")
|
||||
continue
|
||||
|
||||
ok, png_buf = cv2.imencode(".png", img_bgr)
|
||||
if not ok:
|
||||
continue
|
||||
page_png = png_buf.tobytes()
|
||||
|
||||
await create_session_db(
|
||||
session_id=session_id,
|
||||
name=page_name,
|
||||
filename=filename,
|
||||
original_png=page_png,
|
||||
document_group_id=document_group_id,
|
||||
page_number=page_idx + 1,
|
||||
)
|
||||
|
||||
_cache[session_id] = {
|
||||
"id": session_id,
|
||||
"filename": filename,
|
||||
"name": page_name,
|
||||
"original_bgr": img_bgr,
|
||||
"oriented_bgr": None,
|
||||
"cropped_bgr": None,
|
||||
"deskewed_bgr": None,
|
||||
"dewarped_bgr": None,
|
||||
"orientation_result": None,
|
||||
"crop_result": None,
|
||||
"deskew_result": None,
|
||||
"dewarp_result": None,
|
||||
"ground_truth": {},
|
||||
"current_step": 1,
|
||||
}
|
||||
|
||||
h, w = img_bgr.shape[:2]
|
||||
pages.append({
|
||||
"session_id": session_id,
|
||||
"name": page_name,
|
||||
"page_number": page_idx + 1,
|
||||
"image_width": w,
|
||||
"image_height": h,
|
||||
"original_image_url": f"/api/v1/ocr-pipeline/sessions/{session_id}/image/original",
|
||||
})
|
||||
|
||||
logger.info(
|
||||
f"OCR Pipeline: created page session {session_id} "
|
||||
f"(page {page_idx + 1}/{page_count}) from {filename} ({w}x{h})"
|
||||
)
|
||||
|
||||
# Include session_id pointing to first page for backwards compatibility
|
||||
# (frontends that expect a single session_id will navigate to page 1)
|
||||
first_session_id = pages[0]["session_id"] if pages else None
|
||||
|
||||
return {
|
||||
"session_id": first_session_id,
|
||||
"document_group_id": document_group_id,
|
||||
"filename": filename,
|
||||
"name": base_name,
|
||||
"page_count": page_count,
|
||||
"pages": pages,
|
||||
}
|
||||
|
||||
|
||||
@router.get("/sessions/{session_id}")
|
||||
async def get_session_info(session_id: str):
|
||||
"""Get session info including deskew/dewarp/column results for step navigation."""
|
||||
session = await get_session_db(session_id)
|
||||
if not session:
|
||||
raise HTTPException(status_code=404, detail=f"Session {session_id} not found")
|
||||
|
||||
# Get image dimensions from original PNG
|
||||
original_png = await get_session_image(session_id, "original")
|
||||
if original_png:
|
||||
arr = np.frombuffer(original_png, dtype=np.uint8)
|
||||
img = cv2.imdecode(arr, cv2.IMREAD_COLOR)
|
||||
img_w, img_h = (img.shape[1], img.shape[0]) if img is not None else (0, 0)
|
||||
else:
|
||||
img_w, img_h = 0, 0
|
||||
|
||||
result = {
|
||||
"session_id": session["id"],
|
||||
"filename": session.get("filename", ""),
|
||||
"name": session.get("name", ""),
|
||||
"image_width": img_w,
|
||||
"image_height": img_h,
|
||||
"original_image_url": f"/api/v1/ocr-pipeline/sessions/{session_id}/image/original",
|
||||
"current_step": session.get("current_step", 1),
|
||||
"document_category": session.get("document_category"),
|
||||
"doc_type": session.get("doc_type"),
|
||||
}
|
||||
|
||||
if session.get("orientation_result"):
|
||||
result["orientation_result"] = session["orientation_result"]
|
||||
if session.get("crop_result"):
|
||||
result["crop_result"] = session["crop_result"]
|
||||
if session.get("deskew_result"):
|
||||
result["deskew_result"] = session["deskew_result"]
|
||||
if session.get("dewarp_result"):
|
||||
result["dewarp_result"] = session["dewarp_result"]
|
||||
if session.get("column_result"):
|
||||
result["column_result"] = session["column_result"]
|
||||
if session.get("row_result"):
|
||||
result["row_result"] = session["row_result"]
|
||||
if session.get("word_result"):
|
||||
result["word_result"] = session["word_result"]
|
||||
if session.get("doc_type_result"):
|
||||
result["doc_type_result"] = session["doc_type_result"]
|
||||
if session.get("structure_result"):
|
||||
result["structure_result"] = session["structure_result"]
|
||||
if session.get("grid_editor_result"):
|
||||
# Include summary only to keep response small
|
||||
gr = session["grid_editor_result"]
|
||||
result["grid_editor_result"] = {
|
||||
"summary": gr.get("summary", {}),
|
||||
"zones_count": len(gr.get("zones", [])),
|
||||
"edited": gr.get("edited", False),
|
||||
}
|
||||
if session.get("ground_truth"):
|
||||
result["ground_truth"] = session["ground_truth"]
|
||||
|
||||
# Box sub-session info (zone_type='box' from column detection — NOT page-split)
|
||||
if session.get("parent_session_id"):
|
||||
result["parent_session_id"] = session["parent_session_id"]
|
||||
result["box_index"] = session.get("box_index")
|
||||
else:
|
||||
# Check for box sub-sessions (column detection creates these)
|
||||
subs = await get_sub_sessions(session_id)
|
||||
if subs:
|
||||
result["sub_sessions"] = [
|
||||
{"id": s["id"], "name": s.get("name"), "box_index": s.get("box_index")}
|
||||
for s in subs
|
||||
]
|
||||
|
||||
return result
|
||||
|
||||
|
||||
@router.put("/sessions/{session_id}")
|
||||
async def update_session(session_id: str, req: UpdateSessionRequest):
|
||||
"""Update session name and/or document category."""
|
||||
kwargs: Dict[str, Any] = {}
|
||||
if req.name is not None:
|
||||
kwargs["name"] = req.name
|
||||
if req.document_category is not None:
|
||||
if req.document_category not in VALID_DOCUMENT_CATEGORIES:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Invalid category '{req.document_category}'. Valid: {sorted(VALID_DOCUMENT_CATEGORIES)}",
|
||||
)
|
||||
kwargs["document_category"] = req.document_category
|
||||
if not kwargs:
|
||||
raise HTTPException(status_code=400, detail="Nothing to update")
|
||||
updated = await update_session_db(session_id, **kwargs)
|
||||
if not updated:
|
||||
raise HTTPException(status_code=404, detail=f"Session {session_id} not found")
|
||||
return {"session_id": session_id, **kwargs}
|
||||
|
||||
|
||||
@router.delete("/sessions/{session_id}")
|
||||
async def delete_session(session_id: str):
|
||||
"""Delete a session."""
|
||||
_cache.pop(session_id, None)
|
||||
deleted = await delete_session_db(session_id)
|
||||
if not deleted:
|
||||
raise HTTPException(status_code=404, detail=f"Session {session_id} not found")
|
||||
return {"session_id": session_id, "deleted": True}
|
||||
|
||||
|
||||
@router.delete("/sessions")
|
||||
async def delete_all_sessions():
|
||||
"""Delete ALL sessions (cleanup)."""
|
||||
_cache.clear()
|
||||
count = await delete_all_sessions_db()
|
||||
return {"deleted_count": count}
|
||||
|
||||
|
||||
@router.post("/sessions/{session_id}/create-box-sessions")
|
||||
async def create_box_sessions(session_id: str):
|
||||
"""Create sub-sessions for each detected box region.
|
||||
|
||||
Crops box regions from the cropped/dewarped image and creates
|
||||
independent sub-sessions that can be processed through the pipeline.
|
||||
"""
|
||||
session = await get_session_db(session_id)
|
||||
if not session:
|
||||
raise HTTPException(status_code=404, detail=f"Session {session_id} not found")
|
||||
|
||||
column_result = session.get("column_result")
|
||||
if not column_result:
|
||||
raise HTTPException(status_code=400, detail="Column detection must be completed first")
|
||||
|
||||
zones = column_result.get("zones") or []
|
||||
box_zones = [z for z in zones if z.get("zone_type") == "box" and z.get("box")]
|
||||
if not box_zones:
|
||||
return {"session_id": session_id, "sub_sessions": [], "message": "No boxes detected"}
|
||||
|
||||
# Check for existing sub-sessions
|
||||
existing = await get_sub_sessions(session_id)
|
||||
if existing:
|
||||
return {
|
||||
"session_id": session_id,
|
||||
"sub_sessions": [{"id": s["id"], "box_index": s.get("box_index")} for s in existing],
|
||||
"message": f"{len(existing)} sub-session(s) already exist",
|
||||
}
|
||||
|
||||
# Load base image
|
||||
base_png = await get_session_image(session_id, "cropped")
|
||||
if not base_png:
|
||||
base_png = await get_session_image(session_id, "dewarped")
|
||||
if not base_png:
|
||||
raise HTTPException(status_code=400, detail="No base image available")
|
||||
|
||||
arr = np.frombuffer(base_png, dtype=np.uint8)
|
||||
img = cv2.imdecode(arr, cv2.IMREAD_COLOR)
|
||||
if img is None:
|
||||
raise HTTPException(status_code=500, detail="Failed to decode image")
|
||||
|
||||
parent_name = session.get("name", "Session")
|
||||
created = []
|
||||
|
||||
for i, zone in enumerate(box_zones):
|
||||
box = zone["box"]
|
||||
bx, by = box["x"], box["y"]
|
||||
bw, bh = box["width"], box["height"]
|
||||
|
||||
# Crop box region with small padding
|
||||
pad = 5
|
||||
y1 = max(0, by - pad)
|
||||
y2 = min(img.shape[0], by + bh + pad)
|
||||
x1 = max(0, bx - pad)
|
||||
x2 = min(img.shape[1], bx + bw + pad)
|
||||
crop = img[y1:y2, x1:x2]
|
||||
|
||||
# Encode as PNG
|
||||
success, png_buf = cv2.imencode(".png", crop)
|
||||
if not success:
|
||||
logger.warning(f"Failed to encode box {i} crop for session {session_id}")
|
||||
continue
|
||||
|
||||
sub_id = str(uuid.uuid4())
|
||||
sub_name = f"{parent_name} — Box {i + 1}"
|
||||
|
||||
await create_session_db(
|
||||
session_id=sub_id,
|
||||
name=sub_name,
|
||||
filename=session.get("filename", "box-crop.png"),
|
||||
original_png=png_buf.tobytes(),
|
||||
parent_session_id=session_id,
|
||||
box_index=i,
|
||||
)
|
||||
|
||||
# Cache the BGR for immediate processing
|
||||
# Promote original to cropped so column/row/word detection finds it
|
||||
box_bgr = crop.copy()
|
||||
_cache[sub_id] = {
|
||||
"id": sub_id,
|
||||
"filename": session.get("filename", "box-crop.png"),
|
||||
"name": sub_name,
|
||||
"parent_session_id": session_id,
|
||||
"original_bgr": box_bgr,
|
||||
"oriented_bgr": None,
|
||||
"cropped_bgr": box_bgr,
|
||||
"deskewed_bgr": None,
|
||||
"dewarped_bgr": None,
|
||||
"orientation_result": None,
|
||||
"crop_result": None,
|
||||
"deskew_result": None,
|
||||
"dewarp_result": None,
|
||||
"ground_truth": {},
|
||||
"current_step": 1,
|
||||
}
|
||||
|
||||
created.append({
|
||||
"id": sub_id,
|
||||
"name": sub_name,
|
||||
"box_index": i,
|
||||
"box": box,
|
||||
"image_width": crop.shape[1],
|
||||
"image_height": crop.shape[0],
|
||||
})
|
||||
|
||||
logger.info(f"Created box sub-session {sub_id} for session {session_id} "
|
||||
f"(box {i}, {crop.shape[1]}x{crop.shape[0]})")
|
||||
|
||||
return {
|
||||
"session_id": session_id,
|
||||
"sub_sessions": created,
|
||||
"total": len(created),
|
||||
}
|
||||
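As a usage sketch against the CRUD router above: a client creates a session from a scanned PDF roughly as follows. The host, file name, and session name are placeholders.

# Hypothetical client-side upload (assumed local deployment).
import asyncio
import httpx

async def upload_scan(path: str) -> dict:
    async with httpx.AsyncClient(base_url="http://localhost:8000") as client:  # assumed host
        with open(path, "rb") as f:
            resp = await client.post(
                "/api/v1/ocr-pipeline/sessions",
                files={"file": (path, f, "application/pdf")},
                data={"name": "Vokabelliste Kapitel 3"},
            )
    resp.raise_for_status()
    # Multi-page PDFs additionally return document_group_id and a pages[] array.
    return resp.json()

print(asyncio.run(upload_scan("scan.pdf"))["session_id"])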
176
klausur-service/backend/ocr_pipeline_sessions_images.py
Normal file
@@ -0,0 +1,176 @@
|
||||
"""
|
||||
OCR Pipeline Sessions Images — image serving, thumbnails, pipeline log,
|
||||
categories, and document type detection.
|
||||
|
||||
Extracted from ocr_pipeline_sessions.py for modularity.
|
||||
|
||||
Lizenz: Apache 2.0
|
||||
DATENSCHUTZ: Alle Verarbeitung erfolgt lokal.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import time
|
||||
from typing import Any, Dict
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
from fastapi import APIRouter, HTTPException, Query
|
||||
from fastapi.responses import Response
|
||||
|
||||
from cv_vocab_pipeline import create_ocr_image, detect_document_type
|
||||
from ocr_pipeline_common import (
|
||||
VALID_DOCUMENT_CATEGORIES,
|
||||
_append_pipeline_log,
|
||||
_cache,
|
||||
_get_base_image_png,
|
||||
_get_cached,
|
||||
_load_session_to_cache,
|
||||
)
|
||||
from ocr_pipeline_overlays import render_overlay
|
||||
from ocr_pipeline_session_store import (
|
||||
get_session_db,
|
||||
get_session_image,
|
||||
update_session_db,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter(prefix="/api/v1/ocr-pipeline", tags=["ocr-pipeline"])
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Thumbnail & Log Endpoints
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@router.get("/sessions/{session_id}/thumbnail")
|
||||
async def get_session_thumbnail(session_id: str, size: int = Query(default=80, ge=16, le=400)):
|
||||
"""Return a small thumbnail of the original image."""
|
||||
original_png = await get_session_image(session_id, "original")
|
||||
if not original_png:
|
||||
raise HTTPException(status_code=404, detail=f"Session {session_id} not found or no image")
|
||||
arr = np.frombuffer(original_png, dtype=np.uint8)
|
||||
img = cv2.imdecode(arr, cv2.IMREAD_COLOR)
|
||||
if img is None:
|
||||
raise HTTPException(status_code=500, detail="Failed to decode image")
|
||||
h, w = img.shape[:2]
|
||||
scale = size / max(h, w)
|
||||
new_w, new_h = int(w * scale), int(h * scale)
|
||||
thumb = cv2.resize(img, (new_w, new_h), interpolation=cv2.INTER_AREA)
|
||||
_, png_bytes = cv2.imencode(".png", thumb)
|
||||
return Response(content=png_bytes.tobytes(), media_type="image/png",
|
||||
headers={"Cache-Control": "public, max-age=3600"})
|
||||
|
||||
|
||||
@router.get("/sessions/{session_id}/pipeline-log")
|
||||
async def get_pipeline_log(session_id: str):
|
||||
"""Get the pipeline execution log for a session."""
|
||||
session = await get_session_db(session_id)
|
||||
if not session:
|
||||
raise HTTPException(status_code=404, detail=f"Session {session_id} not found")
|
||||
return {"session_id": session_id, "pipeline_log": session.get("pipeline_log") or {"steps": []}}
|
||||
|
||||
|
||||
@router.get("/categories")
|
||||
async def list_categories():
|
||||
"""List valid document categories."""
|
||||
return {"categories": sorted(VALID_DOCUMENT_CATEGORIES)}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Image Endpoints
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@router.get("/sessions/{session_id}/image/{image_type}")
|
||||
async def get_image(session_id: str, image_type: str):
|
||||
"""Serve session images: original, deskewed, dewarped, binarized, structure-overlay, columns-overlay, or rows-overlay."""
|
||||
valid_types = {"original", "oriented", "cropped", "deskewed", "dewarped", "binarized", "structure-overlay", "columns-overlay", "rows-overlay", "words-overlay", "clean"}
|
||||
if image_type not in valid_types:
|
||||
raise HTTPException(status_code=400, detail=f"Unknown image type: {image_type}")
|
||||
|
||||
if image_type == "structure-overlay":
|
||||
return await render_overlay("structure", session_id)
|
||||
|
||||
if image_type == "columns-overlay":
|
||||
return await render_overlay("columns", session_id)
|
||||
|
||||
if image_type == "rows-overlay":
|
||||
return await render_overlay("rows", session_id)
|
||||
|
||||
if image_type == "words-overlay":
|
||||
return await render_overlay("words", session_id)
|
||||
|
||||
# Try cache first for fast serving
|
||||
cached = _cache.get(session_id)
|
||||
if cached:
|
||||
png_key = f"{image_type}_png" if image_type != "original" else None
|
||||
bgr_key = f"{image_type}_bgr" if image_type != "binarized" else None
|
||||
|
||||
# For binarized, check if we have it cached as PNG
|
||||
if image_type == "binarized" and cached.get("binarized_png"):
|
||||
return Response(content=cached["binarized_png"], media_type="image/png")
|
||||
|
||||
# Load from DB — for cropped/dewarped, fall back through the chain
|
||||
if image_type in ("cropped", "dewarped"):
|
||||
data = await _get_base_image_png(session_id)
|
||||
else:
|
||||
data = await get_session_image(session_id, image_type)
|
||||
if not data:
|
||||
raise HTTPException(status_code=404, detail=f"Image '{image_type}' not available yet")
|
||||
|
||||
return Response(content=data, media_type="image/png")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Document Type Detection (between Dewarp and Columns)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@router.post("/sessions/{session_id}/detect-type")
|
||||
async def detect_type(session_id: str):
|
||||
"""Detect document type (vocab_table, full_text, generic_table).
|
||||
|
||||
Should be called after crop (clean image available).
|
||||
Falls back to dewarped if crop was skipped.
|
||||
Stores result in session for frontend to decide pipeline flow.
|
||||
"""
|
||||
if session_id not in _cache:
|
||||
await _load_session_to_cache(session_id)
|
||||
cached = _get_cached(session_id)
|
||||
|
||||
img_bgr = cached.get("cropped_bgr") if cached.get("cropped_bgr") is not None else cached.get("dewarped_bgr")
|
||||
if img_bgr is None:
|
||||
raise HTTPException(status_code=400, detail="Crop or dewarp must be completed first")
|
||||
|
||||
t0 = time.time()
|
||||
ocr_img = create_ocr_image(img_bgr)
|
||||
result = detect_document_type(ocr_img, img_bgr)
|
||||
duration = time.time() - t0
|
||||
|
||||
result_dict = {
|
||||
"doc_type": result.doc_type,
|
||||
"confidence": result.confidence,
|
||||
"pipeline": result.pipeline,
|
||||
"skip_steps": result.skip_steps,
|
||||
"features": result.features,
|
||||
"duration_seconds": round(duration, 2),
|
||||
}
|
||||
|
||||
# Persist to DB
|
||||
await update_session_db(
|
||||
session_id,
|
||||
doc_type=result.doc_type,
|
||||
doc_type_result=result_dict,
|
||||
)
|
||||
|
||||
cached["doc_type_result"] = result_dict
|
||||
|
||||
logger.info(f"OCR Pipeline: detect-type session {session_id}: "
|
||||
f"{result.doc_type} (confidence={result.confidence}, {duration:.2f}s)")
|
||||
|
||||
await _append_pipeline_log(session_id, "detect_type", {
|
||||
"doc_type": result.doc_type,
|
||||
"pipeline": result.pipeline,
|
||||
"confidence": result.confidence,
|
||||
**{k: v for k, v in (result.features or {}).items() if isinstance(v, (int, float, str, bool))},
|
||||
}, duration_ms=int(duration * 1000))
|
||||
|
||||
return {"session_id": session_id, **result_dict}
|
||||
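A small sketch of how a client step might call the detection endpoint above and use the result; the base URL is an assumption and the returned keys are taken from the response built in detect_type.

# Hypothetical pipeline step: detect the document type, then let the client
# decide which stages to skip (base URL is an assumption).
import httpx

BASE = "http://localhost:8000/api/v1/ocr-pipeline"  # assumed deployment

def plan_next_steps(session_id: str) -> dict:
    resp = httpx.post(f"{BASE}/sessions/{session_id}/detect-type", timeout=120.0)
    resp.raise_for_status()
    info = resp.json()
    # info carries doc_type, confidence, pipeline, and skip_steps; a client can
    # use skip_steps to bypass stages that make no sense for the detected type.
    return {"pipeline": info["pipeline"], "skip": info.get("skip_steps") or []}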
@@ -1,529 +1,38 @@
"""
Self-RAG / Corrective RAG Module
Self-RAG / Corrective RAG Module — barrel re-export.

Implements self-reflective RAG that can:
1. Grade retrieved documents for relevance
2. Decide if more retrieval is needed
3. Reformulate queries if initial retrieval fails
4. Filter irrelevant passages before generation
5. Grade answers for groundedness and hallucination
All implementation split into:
    self_rag_grading — document relevance grading, filtering, decisions
    self_rag_retrieval — query reformulation, retrieval loop, info

IMPORTANT: Self-RAG is DISABLED by default for privacy reasons!
When enabled, search queries and retrieved documents are sent to OpenAI API.

Based on research:
- Self-RAG (Asai et al., 2023): Learning to retrieve, generate, and critique
- Corrective RAG (Yan et al., 2024): Self-correcting retrieval augmented generation

This is especially useful for German educational documents where:
- Queries may use informal language
- Documents use formal/technical terminology
- Context must be precisely matched to scoring criteria
- Self-RAG (Asai et al., 2023)
- Corrective RAG (Yan et al., 2024)
"""

import os
from typing import List, Dict, Optional, Tuple
from enum import Enum
import httpx

# Configuration
# IMPORTANT: Self-RAG is DISABLED by default for privacy reasons!
# When enabled, search queries and retrieved documents are sent to OpenAI API
# for relevance grading and query reformulation. This exposes user data to third parties.
# Only enable if you have explicit user consent for data processing.
SELF_RAG_ENABLED = os.getenv("SELF_RAG_ENABLED", "false").lower() == "true"
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY", "")
SELF_RAG_MODEL = os.getenv("SELF_RAG_MODEL", "gpt-4o-mini")

# Thresholds for self-reflection
RELEVANCE_THRESHOLD = float(os.getenv("SELF_RAG_RELEVANCE_THRESHOLD", "0.6"))
GROUNDING_THRESHOLD = float(os.getenv("SELF_RAG_GROUNDING_THRESHOLD", "0.7"))
MAX_RETRIEVAL_ATTEMPTS = int(os.getenv("SELF_RAG_MAX_ATTEMPTS", "2"))


class RetrievalDecision(Enum):
    """Decision after grading retrieval."""
    SUFFICIENT = "sufficient"    # Context is good, proceed to generation
    NEEDS_MORE = "needs_more"    # Need to retrieve more documents
    REFORMULATE = "reformulate"  # Query needs reformulation
    FALLBACK = "fallback"        # Use fallback (no good context found)


class SelfRAGError(Exception):
    """Error during Self-RAG processing."""
    pass


async def grade_document_relevance(
    query: str,
    document: str,
) -> Tuple[float, str]:
    """
    Grade whether a document is relevant to the query.

    Returns a score between 0 (irrelevant) and 1 (highly relevant)
    along with an explanation.
    """
    if not OPENAI_API_KEY:
        # Fallback: simple keyword overlap
        query_words = set(query.lower().split())
        doc_words = set(document.lower().split())
        overlap = len(query_words & doc_words) / max(len(query_words), 1)
        return min(overlap * 2, 1.0), "Keyword-based relevance (no LLM)"

    prompt = f"""Bewerte, ob das folgende Dokument relevant für die Suchanfrage ist.

SUCHANFRAGE: {query}

DOKUMENT:
{document[:2000]}

Ist dieses Dokument relevant, um die Anfrage zu beantworten?
Berücksichtige:
- Thematische Übereinstimmung
- Enthält das Dokument spezifische Informationen zur Anfrage?
- Würde dieses Dokument bei der Beantwortung helfen?

Antworte im Format:
SCORE: [0.0-1.0]
BEGRÜNDUNG: [Kurze Erklärung]"""

    try:
        async with httpx.AsyncClient() as client:
            response = await client.post(
                "https://api.openai.com/v1/chat/completions",
                headers={
                    "Authorization": f"Bearer {OPENAI_API_KEY}",
                    "Content-Type": "application/json"
                },
                json={
                    "model": SELF_RAG_MODEL,
                    "messages": [{"role": "user", "content": prompt}],
                    "max_tokens": 150,
                    "temperature": 0.0,
                },
                timeout=30.0
            )

            if response.status_code != 200:
                return 0.5, f"API error: {response.status_code}"

            result = response.json()["choices"][0]["message"]["content"]

            import re
            score_match = re.search(r'SCORE:\s*([\d.]+)', result)
            score = float(score_match.group(1)) if score_match else 0.5

            reason_match = re.search(r'BEGRÜNDUNG:\s*(.+)', result, re.DOTALL)
            reason = reason_match.group(1).strip() if reason_match else result

            return min(max(score, 0.0), 1.0), reason

    except Exception as e:
        return 0.5, f"Grading error: {str(e)}"


async def grade_documents_batch(
    query: str,
    documents: List[str],
) -> List[Tuple[float, str]]:
    """
    Grade multiple documents for relevance.

    Returns list of (score, reason) tuples.
    """
    results = []
    for doc in documents:
        score, reason = await grade_document_relevance(query, doc)
        results.append((score, reason))
    return results


async def filter_relevant_documents(
    query: str,
    documents: List[Dict],
    threshold: float = RELEVANCE_THRESHOLD,
) -> Tuple[List[Dict], List[Dict]]:
    """
    Filter documents by relevance, separating relevant from irrelevant.

    Args:
        query: The search query
        documents: List of document dicts with 'text' field
        threshold: Minimum relevance score to keep

    Returns:
        Tuple of (relevant_docs, filtered_out_docs)
    """
    relevant = []
    filtered = []

    for doc in documents:
        text = doc.get("text", "")
        score, reason = await grade_document_relevance(query, text)

        doc_with_grade = doc.copy()
        doc_with_grade["relevance_score"] = score
        doc_with_grade["relevance_reason"] = reason

        if score >= threshold:
            relevant.append(doc_with_grade)
        else:
            filtered.append(doc_with_grade)

    # Sort relevant by score
    relevant.sort(key=lambda x: x.get("relevance_score", 0), reverse=True)

    return relevant, filtered


async def decide_retrieval_strategy(
    query: str,
    documents: List[Dict],
    attempt: int = 1,
) -> Tuple[RetrievalDecision, Dict]:
    """
    Decide what to do based on current retrieval results.

    Args:
        query: The search query
        documents: Retrieved documents with relevance scores
        attempt: Current retrieval attempt number

    Returns:
        Tuple of (decision, metadata)
    """
    if not documents:
        if attempt >= MAX_RETRIEVAL_ATTEMPTS:
            return RetrievalDecision.FALLBACK, {"reason": "No documents found after max attempts"}
        return RetrievalDecision.REFORMULATE, {"reason": "No documents retrieved"}

    # Check average relevance
    scores = [doc.get("relevance_score", 0.5) for doc in documents]
    avg_score = sum(scores) / len(scores)
    max_score = max(scores)

    if max_score >= RELEVANCE_THRESHOLD and avg_score >= RELEVANCE_THRESHOLD * 0.7:
        return RetrievalDecision.SUFFICIENT, {
            "avg_relevance": avg_score,
            "max_relevance": max_score,
            "doc_count": len(documents),
        }

    if attempt >= MAX_RETRIEVAL_ATTEMPTS:
        if max_score >= RELEVANCE_THRESHOLD * 0.5:
            # At least some relevant context, proceed with caution
            return RetrievalDecision.SUFFICIENT, {
                "avg_relevance": avg_score,
                "warning": "Low relevance after max attempts",
            }
        return RetrievalDecision.FALLBACK, {"reason": "Max attempts reached, low relevance"}

    if avg_score < 0.3:
        return RetrievalDecision.REFORMULATE, {
            "reason": "Very low relevance, query reformulation needed",
            "avg_relevance": avg_score,
        }

    return RetrievalDecision.NEEDS_MORE, {
        "reason": "Moderate relevance, retrieving more documents",
        "avg_relevance": avg_score,
    }


async def reformulate_query(
    original_query: str,
    context: Optional[str] = None,
    previous_results_summary: Optional[str] = None,
) -> str:
    """
    Reformulate a query to improve retrieval.

    Uses LLM to generate a better query based on:
    - Original query
    - Optional context (subject, niveau, etc.)
    - Summary of why previous retrieval failed
    """
    if not OPENAI_API_KEY:
        # Simple reformulation: expand abbreviations, add synonyms
        reformulated = original_query
        expansions = {
            "EA": "erhöhtes Anforderungsniveau",
            "eA": "erhöhtes Anforderungsniveau",
            "GA": "grundlegendes Anforderungsniveau",
            "gA": "grundlegendes Anforderungsniveau",
            "AFB": "Anforderungsbereich",
            "Abi": "Abitur",
        }
        for abbr, expansion in expansions.items():
            if abbr in original_query:
                reformulated = reformulated.replace(abbr, f"{abbr} ({expansion})")
        return reformulated

    prompt = f"""Du bist ein Experte für deutsche Bildungsstandards und Prüfungsanforderungen.

Die folgende Suchanfrage hat keine guten Ergebnisse geliefert:
ORIGINAL: {original_query}

{f"KONTEXT: {context}" if context else ""}
{f"PROBLEM MIT VORHERIGEN ERGEBNISSEN: {previous_results_summary}" if previous_results_summary else ""}

Formuliere die Anfrage so um, dass sie:
1. Formellere/technischere Begriffe verwendet (wie in offiziellen Dokumenten)
2. Relevante Synonyme oder verwandte Begriffe einschließt
3. Spezifischer auf Erwartungshorizonte/Bewertungskriterien ausgerichtet ist

Antworte NUR mit der umformulierten Suchanfrage, ohne Erklärung."""

    try:
        async with httpx.AsyncClient() as client:
            response = await client.post(
                "https://api.openai.com/v1/chat/completions",
                headers={
                    "Authorization": f"Bearer {OPENAI_API_KEY}",
                    "Content-Type": "application/json"
                },
                json={
                    "model": SELF_RAG_MODEL,
                    "messages": [{"role": "user", "content": prompt}],
                    "max_tokens": 100,
                    "temperature": 0.3,
                },
                timeout=30.0
            )

            if response.status_code != 200:
                return original_query

            return response.json()["choices"][0]["message"]["content"].strip()

    except Exception:
        return original_query


async def grade_answer_groundedness(
    answer: str,
    contexts: List[str],
) -> Tuple[float, List[str]]:
    """
    Grade whether an answer is grounded in the provided contexts.

    Returns:
        Tuple of (grounding_score, list of unsupported claims)
    """
    if not OPENAI_API_KEY:
        return 0.5, ["LLM not configured for grounding check"]

    context_text = "\n---\n".join(contexts[:5])

    prompt = f"""Analysiere, ob die folgende Antwort vollständig durch die Kontexte gestützt wird.

KONTEXTE:
{context_text}

ANTWORT:
{answer}

Identifiziere:
1. Welche Aussagen sind durch die Kontexte belegt?
2. Welche Aussagen sind NICHT belegt (potenzielle Halluzinationen)?

Antworte im Format:
SCORE: [0.0-1.0] (1.0 = vollständig belegt)
NICHT_BELEGT: [Liste der nicht belegten Aussagen, eine pro Zeile, oder "Keine"]"""

    try:
        async with httpx.AsyncClient() as client:
            response = await client.post(
                "https://api.openai.com/v1/chat/completions",
                headers={
                    "Authorization": f"Bearer {OPENAI_API_KEY}",
                    "Content-Type": "application/json"
                },
                json={
                    "model": SELF_RAG_MODEL,
                    "messages": [{"role": "user", "content": prompt}],
                    "max_tokens": 300,
                    "temperature": 0.0,
                },
                timeout=30.0
            )

            if response.status_code != 200:
                return 0.5, [f"API error: {response.status_code}"]

            result = response.json()["choices"][0]["message"]["content"]

            import re
            score_match = re.search(r'SCORE:\s*([\d.]+)', result)
            score = float(score_match.group(1)) if score_match else 0.5

            unsupported_match = re.search(r'NICHT_BELEGT:\s*(.+)', result, re.DOTALL)
            unsupported_text = unsupported_match.group(1).strip() if unsupported_match else ""

            if unsupported_text.lower() == "keine":
                unsupported = []
            else:
                unsupported = [line.strip() for line in unsupported_text.split("\n") if line.strip()]

            return min(max(score, 0.0), 1.0), unsupported

    except Exception as e:
        return 0.5, [f"Grounding check error: {str(e)}"]


async def self_rag_retrieve(
    query: str,
    search_func,
    subject: Optional[str] = None,
    niveau: Optional[str] = None,
    initial_top_k: int = 10,
    final_top_k: int = 5,
    **search_kwargs
) -> Dict:
    """
    Perform Self-RAG enhanced retrieval with reflection and correction.

    This implements a retrieval loop that:
    1. Retrieves initial documents
    2. Grades them for relevance
    3. Decides if more retrieval is needed
    4. Reformulates query if necessary
    5. Returns filtered, high-quality context

    Args:
        query: The search query
        search_func: Async function to perform the actual search
        subject: Optional subject context
        niveau: Optional niveau context
        initial_top_k: Number of documents for initial retrieval
        final_top_k: Maximum documents to return
        **search_kwargs: Additional args for search_func

    Returns:
        Dict with results, metadata, and reflection trace
    """
    if not SELF_RAG_ENABLED:
        # Fall back to simple search
        results = await search_func(query=query, limit=final_top_k, **search_kwargs)
        return {
            "results": results,
            "self_rag_enabled": False,
            "query_used": query,
        }

    trace = []
    current_query = query
    attempt = 1

    while attempt <= MAX_RETRIEVAL_ATTEMPTS:
        # Step 1: Retrieve documents
        results = await search_func(query=current_query, limit=initial_top_k, **search_kwargs)

        trace.append({
            "attempt": attempt,
            "query": current_query,
            "retrieved_count": len(results) if results else 0,
        })

        if not results:
            attempt += 1
            if attempt <= MAX_RETRIEVAL_ATTEMPTS:
                current_query = await reformulate_query(
                    query,
                    context=f"Fach: {subject}" if subject else None,
                    previous_results_summary="Keine Dokumente gefunden"
                )
                trace[-1]["action"] = "reformulate"
                trace[-1]["new_query"] = current_query
            continue

        # Step 2: Grade documents for relevance
        relevant, filtered = await filter_relevant_documents(current_query, results)

        trace[-1]["relevant_count"] = len(relevant)
        trace[-1]["filtered_count"] = len(filtered)

        # Step 3: Decide what to do
        decision, decision_meta = await decide_retrieval_strategy(
            current_query, relevant, attempt
        )

        trace[-1]["decision"] = decision.value
        trace[-1]["decision_meta"] = decision_meta

        if decision == RetrievalDecision.SUFFICIENT:
            # We have good context, return it
            return {
                "results": relevant[:final_top_k],
                "self_rag_enabled": True,
                "query_used": current_query,
                "original_query": query if current_query != query else None,
                "attempts": attempt,
                "decision": decision.value,
                "trace": trace,
                "filtered_out_count": len(filtered),
            }

        elif decision == RetrievalDecision.REFORMULATE:
            # Reformulate and try again
            avg_score = decision_meta.get("avg_relevance", 0)
            current_query = await reformulate_query(
                query,
                context=f"Fach: {subject}" if subject else None,
                previous_results_summary=f"Durchschnittliche Relevanz: {avg_score:.2f}"
            )
            trace[-1]["action"] = "reformulate"
            trace[-1]["new_query"] = current_query

        elif decision == RetrievalDecision.NEEDS_MORE:
            # Retrieve more with expanded query
            current_query = f"{current_query} Bewertungskriterien Anforderungen"
            trace[-1]["action"] = "expand_query"
            trace[-1]["new_query"] = current_query

        elif decision == RetrievalDecision.FALLBACK:
            # Return what we have, even if not ideal
            return {
                "results": (relevant or results)[:final_top_k],
                "self_rag_enabled": True,
                "query_used": current_query,
                "original_query": query if current_query != query else None,
                "attempts": attempt,
                "decision": decision.value,
                "warning": "Fallback mode - low relevance context",
                "trace": trace,
            }

        attempt += 1

    # Max attempts reached
    return {
        "results": results[:final_top_k] if results else [],
        "self_rag_enabled": True,
        "query_used": current_query,
        "original_query": query if current_query != query else None,
        "attempts": attempt - 1,
        "decision": "max_attempts",
        "warning": "Max retrieval attempts reached",
        "trace": trace,
    }


def get_self_rag_info() -> dict:
    """Get information about Self-RAG configuration."""
    return {
        "enabled": SELF_RAG_ENABLED,
        "llm_configured": bool(OPENAI_API_KEY),
        "model": SELF_RAG_MODEL,
        "relevance_threshold": RELEVANCE_THRESHOLD,
        "grounding_threshold": GROUNDING_THRESHOLD,
        "max_retrieval_attempts": MAX_RETRIEVAL_ATTEMPTS,
        "features": [
            "document_grading",
            "relevance_filtering",
            "query_reformulation",
            "answer_grounding_check",
            "retrieval_decision",
        ],
        "sends_data_externally": True,  # ALWAYS true when enabled - documents sent to OpenAI
        "privacy_warning": "When enabled, queries and documents are sent to OpenAI API for grading",
        "default_enabled": False,  # Disabled by default for privacy
    }
# Grading: relevance, filtering, decisions, groundedness
from self_rag_grading import (  # noqa: F401
    SELF_RAG_ENABLED,
    OPENAI_API_KEY,
    SELF_RAG_MODEL,
    RELEVANCE_THRESHOLD,
    GROUNDING_THRESHOLD,
    MAX_RETRIEVAL_ATTEMPTS,
    RetrievalDecision,
    SelfRAGError,
    grade_document_relevance,
    grade_documents_batch,
    filter_relevant_documents,
    decide_retrieval_strategy,
    grade_answer_groundedness,
)

# Retrieval: reformulation, loop, info
from self_rag_retrieval import (  # noqa: F401
    reformulate_query,
    self_rag_retrieve,
    get_self_rag_info,
)
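Because self_rag.py is now a pure barrel, existing call sites should keep working with their old imports; a quick equivalence sketch, assuming the flat module layout the new files themselves use:

# Old and new import paths resolve to the same objects through the barrel.
import self_rag
import self_rag_grading

assert self_rag.grade_document_relevance is self_rag_grading.grade_document_relevance
assert self_rag.RetrievalDecision is self_rag_grading.RetrievalDecision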
285  klausur-service/backend/self_rag_grading.py  Normal file
@@ -0,0 +1,285 @@
"""
|
||||
Self-RAG Grading — document relevance grading, filtering, retrieval decisions.
|
||||
|
||||
Extracted from self_rag.py for modularity.
|
||||
|
||||
Based on research:
|
||||
- Self-RAG (Asai et al., 2023)
|
||||
- Corrective RAG (Yan et al., 2024)
|
||||
"""
|
||||
|
||||
import os
|
||||
from typing import List, Dict, Optional, Tuple
|
||||
from enum import Enum
|
||||
import httpx
|
||||
|
||||
# Configuration
|
||||
SELF_RAG_ENABLED = os.getenv("SELF_RAG_ENABLED", "false").lower() == "true"
|
||||
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY", "")
|
||||
SELF_RAG_MODEL = os.getenv("SELF_RAG_MODEL", "gpt-4o-mini")
|
||||
|
||||
# Thresholds for self-reflection
|
||||
RELEVANCE_THRESHOLD = float(os.getenv("SELF_RAG_RELEVANCE_THRESHOLD", "0.6"))
|
||||
GROUNDING_THRESHOLD = float(os.getenv("SELF_RAG_GROUNDING_THRESHOLD", "0.7"))
|
||||
MAX_RETRIEVAL_ATTEMPTS = int(os.getenv("SELF_RAG_MAX_ATTEMPTS", "2"))
|
||||
|
||||
|
||||
class RetrievalDecision(Enum):
|
||||
"""Decision after grading retrieval."""
|
||||
SUFFICIENT = "sufficient" # Context is good, proceed to generation
|
||||
NEEDS_MORE = "needs_more" # Need to retrieve more documents
|
||||
REFORMULATE = "reformulate" # Query needs reformulation
|
||||
FALLBACK = "fallback" # Use fallback (no good context found)
|
||||
|
||||
|
||||
class SelfRAGError(Exception):
|
||||
"""Error during Self-RAG processing."""
|
||||
pass
|
||||
|
||||
|
||||
async def grade_document_relevance(
|
||||
query: str,
|
||||
document: str,
|
||||
) -> Tuple[float, str]:
|
||||
"""
|
||||
Grade whether a document is relevant to the query.
|
||||
|
||||
Returns a score between 0 (irrelevant) and 1 (highly relevant)
|
||||
along with an explanation.
|
||||
"""
|
||||
if not OPENAI_API_KEY:
|
||||
# Fallback: simple keyword overlap
|
||||
query_words = set(query.lower().split())
|
||||
doc_words = set(document.lower().split())
|
||||
overlap = len(query_words & doc_words) / max(len(query_words), 1)
|
||||
return min(overlap * 2, 1.0), "Keyword-based relevance (no LLM)"
|
||||
|
||||
prompt = f"""Bewerte, ob das folgende Dokument relevant fuer die Suchanfrage ist.
|
||||
|
||||
SUCHANFRAGE: {query}
|
||||
|
||||
DOKUMENT:
|
||||
{document[:2000]}
|
||||
|
||||
Ist dieses Dokument relevant, um die Anfrage zu beantworten?
|
||||
Beruecksichtige:
|
||||
- Thematische Uebereinstimmung
|
||||
- Enthaelt das Dokument spezifische Informationen zur Anfrage?
|
||||
- Wuerde dieses Dokument bei der Beantwortung helfen?
|
||||
|
||||
Antworte im Format:
|
||||
SCORE: [0.0-1.0]
|
||||
BEGRUENDUNG: [Kurze Erklaerung]"""
|
||||
|
||||
try:
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.post(
|
||||
"https://api.openai.com/v1/chat/completions",
|
||||
headers={
|
||||
"Authorization": f"Bearer {OPENAI_API_KEY}",
|
||||
"Content-Type": "application/json"
|
||||
},
|
||||
json={
|
||||
"model": SELF_RAG_MODEL,
|
||||
"messages": [{"role": "user", "content": prompt}],
|
||||
"max_tokens": 150,
|
||||
"temperature": 0.0,
|
||||
},
|
||||
timeout=30.0
|
||||
)
|
||||
|
||||
if response.status_code != 200:
|
||||
return 0.5, f"API error: {response.status_code}"
|
||||
|
||||
result = response.json()["choices"][0]["message"]["content"]
|
||||
|
||||
import re
|
||||
score_match = re.search(r'SCORE:\s*([\d.]+)', result)
|
||||
score = float(score_match.group(1)) if score_match else 0.5
|
||||
|
||||
reason_match = re.search(r'BEGRUENDUNG:\s*(.+)', result, re.DOTALL)
|
||||
reason = reason_match.group(1).strip() if reason_match else result
|
||||
|
||||
return min(max(score, 0.0), 1.0), reason
|
||||
|
||||
except Exception as e:
|
||||
return 0.5, f"Grading error: {str(e)}"
|
||||
|
||||
|
||||
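# Worked example of the no-LLM fallback above (hypothetical inputs):
#   query = "Abitur Bewertung Erwartungshorizont", so |Q| = 3
#   a document containing "Bewertung" and "Erwartungshorizont" gives overlap 2/3
#   score = min(2 * 2/3, 1.0) = 1.0, i.e. treated as relevant without any API call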
async def grade_documents_batch(
    query: str,
    documents: List[str],
) -> List[Tuple[float, str]]:
    """
    Grade multiple documents for relevance.

    Returns list of (score, reason) tuples.
    """
    results = []
    for doc in documents:
        score, reason = await grade_document_relevance(query, doc)
        results.append((score, reason))
    return results


async def filter_relevant_documents(
    query: str,
    documents: List[Dict],
    threshold: float = RELEVANCE_THRESHOLD,
) -> Tuple[List[Dict], List[Dict]]:
    """
    Filter documents by relevance, separating relevant from irrelevant.

    Args:
        query: The search query
        documents: List of document dicts with 'text' field
        threshold: Minimum relevance score to keep

    Returns:
        Tuple of (relevant_docs, filtered_out_docs)
    """
    relevant = []
    filtered = []

    for doc in documents:
        text = doc.get("text", "")
        score, reason = await grade_document_relevance(query, text)

        doc_with_grade = doc.copy()
        doc_with_grade["relevance_score"] = score
        doc_with_grade["relevance_reason"] = reason

        if score >= threshold:
            relevant.append(doc_with_grade)
        else:
            filtered.append(doc_with_grade)

    # Sort relevant by score
    relevant.sort(key=lambda x: x.get("relevance_score", 0), reverse=True)

    return relevant, filtered


async def decide_retrieval_strategy(
    query: str,
    documents: List[Dict],
    attempt: int = 1,
) -> Tuple[RetrievalDecision, Dict]:
    """
    Decide what to do based on current retrieval results.

    Args:
        query: The search query
        documents: Retrieved documents with relevance scores
        attempt: Current retrieval attempt number

    Returns:
        Tuple of (decision, metadata)
    """
    if not documents:
        if attempt >= MAX_RETRIEVAL_ATTEMPTS:
            return RetrievalDecision.FALLBACK, {"reason": "No documents found after max attempts"}
        return RetrievalDecision.REFORMULATE, {"reason": "No documents retrieved"}

    # Check average relevance
    scores = [doc.get("relevance_score", 0.5) for doc in documents]
    avg_score = sum(scores) / len(scores)
    max_score = max(scores)

    if max_score >= RELEVANCE_THRESHOLD and avg_score >= RELEVANCE_THRESHOLD * 0.7:
        return RetrievalDecision.SUFFICIENT, {
            "avg_relevance": avg_score,
            "max_relevance": max_score,
            "doc_count": len(documents),
        }

    if attempt >= MAX_RETRIEVAL_ATTEMPTS:
        if max_score >= RELEVANCE_THRESHOLD * 0.5:
            # At least some relevant context, proceed with caution
            return RetrievalDecision.SUFFICIENT, {
                "avg_relevance": avg_score,
                "warning": "Low relevance after max attempts",
            }
        return RetrievalDecision.FALLBACK, {"reason": "Max attempts reached, low relevance"}

    if avg_score < 0.3:
        return RetrievalDecision.REFORMULATE, {
            "reason": "Very low relevance, query reformulation needed",
            "avg_relevance": avg_score,
        }

    return RetrievalDecision.NEEDS_MORE, {
        "reason": "Moderate relevance, retrieving more documents",
        "avg_relevance": avg_score,
    }

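# Worked example of the decision rules (defaults: RELEVANCE_THRESHOLD=0.6, 2 attempts):
#   scores [0.8, 0.5] -> max 0.8 >= 0.6 and avg 0.65 >= 0.42  -> SUFFICIENT
#   scores [0.1]      -> avg 0.1 < 0.3 on attempt 1           -> REFORMULATE
#   scores [0.4, 0.4] -> moderate relevance on attempt 1      -> NEEDS_MORE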
async def grade_answer_groundedness(
    answer: str,
    contexts: List[str],
) -> Tuple[float, List[str]]:
    """
    Grade whether an answer is grounded in the provided contexts.

    Returns:
        Tuple of (grounding_score, list of unsupported claims)
    """
    if not OPENAI_API_KEY:
        return 0.5, ["LLM not configured for grounding check"]

    context_text = "\n---\n".join(contexts[:5])

    prompt = f"""Analysiere, ob die folgende Antwort vollstaendig durch die Kontexte gestuetzt wird.

KONTEXTE:
{context_text}

ANTWORT:
{answer}

Identifiziere:
1. Welche Aussagen sind durch die Kontexte belegt?
2. Welche Aussagen sind NICHT belegt (potenzielle Halluzinationen)?

Antworte im Format:
SCORE: [0.0-1.0] (1.0 = vollstaendig belegt)
NICHT_BELEGT: [Liste der nicht belegten Aussagen, eine pro Zeile, oder "Keine"]"""

    try:
        async with httpx.AsyncClient() as client:
            response = await client.post(
                "https://api.openai.com/v1/chat/completions",
                headers={
                    "Authorization": f"Bearer {OPENAI_API_KEY}",
                    "Content-Type": "application/json"
                },
                json={
                    "model": SELF_RAG_MODEL,
                    "messages": [{"role": "user", "content": prompt}],
                    "max_tokens": 300,
                    "temperature": 0.0,
                },
                timeout=30.0
            )

            if response.status_code != 200:
                return 0.5, [f"API error: {response.status_code}"]

            result = response.json()["choices"][0]["message"]["content"]

            import re
            score_match = re.search(r'SCORE:\s*([\d.]+)', result)
            score = float(score_match.group(1)) if score_match else 0.5

            unsupported_match = re.search(r'NICHT_BELEGT:\s*(.+)', result, re.DOTALL)
            unsupported_text = unsupported_match.group(1).strip() if unsupported_match else ""

            if unsupported_text.lower() == "keine":
                unsupported = []
            else:
                unsupported = [line.strip() for line in unsupported_text.split("\n") if line.strip()]

            return min(max(score, 0.0), 1.0), unsupported

    except Exception as e:
        return 0.5, [f"Grounding check error: {str(e)}"]
255  klausur-service/backend/self_rag_retrieval.py  Normal file
@@ -0,0 +1,255 @@
"""
|
||||
Self-RAG Retrieval — query reformulation, retrieval loop, info.
|
||||
|
||||
Extracted from self_rag.py for modularity.
|
||||
|
||||
IMPORTANT: Self-RAG is DISABLED by default for privacy reasons!
|
||||
When enabled, search queries and retrieved documents are sent to OpenAI API
|
||||
for relevance grading and query reformulation.
|
||||
"""
|
||||
|
||||
import os
|
||||
from typing import List, Dict, Optional
|
||||
import httpx
|
||||
|
||||
from self_rag_grading import (
|
||||
SELF_RAG_ENABLED,
|
||||
OPENAI_API_KEY,
|
||||
SELF_RAG_MODEL,
|
||||
RELEVANCE_THRESHOLD,
|
||||
GROUNDING_THRESHOLD,
|
||||
MAX_RETRIEVAL_ATTEMPTS,
|
||||
RetrievalDecision,
|
||||
filter_relevant_documents,
|
||||
decide_retrieval_strategy,
|
||||
)
|
||||
|
||||
|
||||
async def reformulate_query(
|
||||
original_query: str,
|
||||
context: Optional[str] = None,
|
||||
previous_results_summary: Optional[str] = None,
|
||||
) -> str:
|
||||
"""
|
||||
Reformulate a query to improve retrieval.
|
||||
|
||||
Uses LLM to generate a better query based on:
|
||||
- Original query
|
||||
- Optional context (subject, niveau, etc.)
|
||||
- Summary of why previous retrieval failed
|
||||
"""
|
||||
if not OPENAI_API_KEY:
|
||||
# Simple reformulation: expand abbreviations, add synonyms
|
||||
reformulated = original_query
|
||||
expansions = {
|
||||
"EA": "erhoehtes Anforderungsniveau",
|
||||
"eA": "erhoehtes Anforderungsniveau",
|
||||
"GA": "grundlegendes Anforderungsniveau",
|
||||
"gA": "grundlegendes Anforderungsniveau",
|
||||
"AFB": "Anforderungsbereich",
|
||||
"Abi": "Abitur",
|
||||
}
|
||||
for abbr, expansion in expansions.items():
|
||||
if abbr in original_query:
|
||||
reformulated = reformulated.replace(abbr, f"{abbr} ({expansion})")
|
||||
return reformulated
|
||||
|
||||
prompt = f"""Du bist ein Experte fuer deutsche Bildungsstandards und Pruefungsanforderungen.
|
||||
|
||||
Die folgende Suchanfrage hat keine guten Ergebnisse geliefert:
|
||||
ORIGINAL: {original_query}
|
||||
|
||||
{f"KONTEXT: {context}" if context else ""}
|
||||
{f"PROBLEM MIT VORHERIGEN ERGEBNISSEN: {previous_results_summary}" if previous_results_summary else ""}
|
||||
|
||||
Formuliere die Anfrage so um, dass sie:
|
||||
1. Formellere/technischere Begriffe verwendet (wie in offiziellen Dokumenten)
|
||||
2. Relevante Synonyme oder verwandte Begriffe einschliesst
|
||||
3. Spezifischer auf Erwartungshorizonte/Bewertungskriterien ausgerichtet ist
|
||||
|
||||
Antworte NUR mit der umformulierten Suchanfrage, ohne Erklaerung."""
|
||||
|
||||
try:
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.post(
|
||||
"https://api.openai.com/v1/chat/completions",
|
||||
headers={
|
||||
"Authorization": f"Bearer {OPENAI_API_KEY}",
|
||||
"Content-Type": "application/json"
|
||||
},
|
||||
json={
|
||||
"model": SELF_RAG_MODEL,
|
||||
"messages": [{"role": "user", "content": prompt}],
|
||||
"max_tokens": 100,
|
||||
"temperature": 0.3,
|
||||
},
|
||||
timeout=30.0
|
||||
)
|
||||
|
||||
if response.status_code != 200:
|
||||
return original_query
|
||||
|
||||
return response.json()["choices"][0]["message"]["content"].strip()
|
||||
|
||||
except Exception:
|
||||
return original_query
|
||||
|
||||
|
||||
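# Offline path example (no OPENAI_API_KEY set, so the expansion branch above runs):
#   reformulate_query("Abi EA Analysis")
#   -> "Abi (Abitur) EA (erhoehtes Anforderungsniveau) Analysis"
#   Known abbreviations are annotated in place; other tokens are left untouched.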
async def self_rag_retrieve(
    query: str,
    search_func,
    subject: Optional[str] = None,
    niveau: Optional[str] = None,
    initial_top_k: int = 10,
    final_top_k: int = 5,
    **search_kwargs
) -> Dict:
    """
    Perform Self-RAG enhanced retrieval with reflection and correction.

    This implements a retrieval loop that:
    1. Retrieves initial documents
    2. Grades them for relevance
    3. Decides if more retrieval is needed
    4. Reformulates query if necessary
    5. Returns filtered, high-quality context

    Args:
        query: The search query
        search_func: Async function to perform the actual search
        subject: Optional subject context
        niveau: Optional niveau context
        initial_top_k: Number of documents for initial retrieval
        final_top_k: Maximum documents to return
        **search_kwargs: Additional args for search_func

    Returns:
        Dict with results, metadata, and reflection trace
    """
    if not SELF_RAG_ENABLED:
        # Fall back to simple search
        results = await search_func(query=query, limit=final_top_k, **search_kwargs)
        return {
            "results": results,
            "self_rag_enabled": False,
            "query_used": query,
        }

    trace = []
    current_query = query
    attempt = 1

    while attempt <= MAX_RETRIEVAL_ATTEMPTS:
        # Step 1: Retrieve documents
        results = await search_func(query=current_query, limit=initial_top_k, **search_kwargs)

        trace.append({
            "attempt": attempt,
            "query": current_query,
            "retrieved_count": len(results) if results else 0,
        })

        if not results:
            attempt += 1
            if attempt <= MAX_RETRIEVAL_ATTEMPTS:
                current_query = await reformulate_query(
                    query,
                    context=f"Fach: {subject}" if subject else None,
                    previous_results_summary="Keine Dokumente gefunden"
                )
                trace[-1]["action"] = "reformulate"
                trace[-1]["new_query"] = current_query
            continue

        # Step 2: Grade documents for relevance
        relevant, filtered = await filter_relevant_documents(current_query, results)

        trace[-1]["relevant_count"] = len(relevant)
        trace[-1]["filtered_count"] = len(filtered)

        # Step 3: Decide what to do
        decision, decision_meta = await decide_retrieval_strategy(
            current_query, relevant, attempt
        )

        trace[-1]["decision"] = decision.value
        trace[-1]["decision_meta"] = decision_meta

        if decision == RetrievalDecision.SUFFICIENT:
            # We have good context, return it
            return {
                "results": relevant[:final_top_k],
                "self_rag_enabled": True,
                "query_used": current_query,
                "original_query": query if current_query != query else None,
                "attempts": attempt,
                "decision": decision.value,
                "trace": trace,
                "filtered_out_count": len(filtered),
            }

        elif decision == RetrievalDecision.REFORMULATE:
            # Reformulate and try again
            avg_score = decision_meta.get("avg_relevance", 0)
            current_query = await reformulate_query(
                query,
                context=f"Fach: {subject}" if subject else None,
                previous_results_summary=f"Durchschnittliche Relevanz: {avg_score:.2f}"
            )
            trace[-1]["action"] = "reformulate"
            trace[-1]["new_query"] = current_query

        elif decision == RetrievalDecision.NEEDS_MORE:
            # Retrieve more with expanded query
            current_query = f"{current_query} Bewertungskriterien Anforderungen"
            trace[-1]["action"] = "expand_query"
            trace[-1]["new_query"] = current_query

        elif decision == RetrievalDecision.FALLBACK:
            # Return what we have, even if not ideal
            return {
                "results": (relevant or results)[:final_top_k],
                "self_rag_enabled": True,
                "query_used": current_query,
                "original_query": query if current_query != query else None,
                "attempts": attempt,
                "decision": decision.value,
                "warning": "Fallback mode - low relevance context",
                "trace": trace,
            }

        attempt += 1

    # Max attempts reached
    return {
        "results": results[:final_top_k] if results else [],
        "self_rag_enabled": True,
        "query_used": current_query,
        "original_query": query if current_query != query else None,
        "attempts": attempt - 1,
        "decision": "max_attempts",
        "warning": "Max retrieval attempts reached",
        "trace": trace,
    }


def get_self_rag_info() -> dict:
    """Get information about Self-RAG configuration."""
    return {
        "enabled": SELF_RAG_ENABLED,
        "llm_configured": bool(OPENAI_API_KEY),
        "model": SELF_RAG_MODEL,
        "relevance_threshold": RELEVANCE_THRESHOLD,
        "grounding_threshold": GROUNDING_THRESHOLD,
        "max_retrieval_attempts": MAX_RETRIEVAL_ATTEMPTS,
        "features": [
            "document_grading",
            "relevance_filtering",
            "query_reformulation",
            "answer_grounding_check",
            "retrieval_decision",
        ],
        "sends_data_externally": True,  # ALWAYS true when enabled
        "privacy_warning": "When enabled, queries and documents are sent to OpenAI API for grading",
        "default_enabled": False,  # Disabled by default for privacy
    }
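A minimal sketch of driving the retrieval loop with a stubbed search function, assuming SELF_RAG_ENABLED is left at its default (false) so the simple-search branch runs and no data leaves the machine; the corpus is invented:

# Sketch: wiring self_rag_retrieve to a stubbed async search function.
import asyncio
from self_rag_retrieval import self_rag_retrieve

async def fake_search(query: str, limit: int = 5, **kwargs):
    corpus = [{"text": "Erwartungshorizont zur Abituraufgabe ..."},
              {"text": "Unrelated cafeteria menu"}]
    return corpus[:limit]

result = asyncio.run(self_rag_retrieve("Erwartungshorizont Abitur", fake_search))
print(result["self_rag_enabled"], len(result["results"]))  # False, 2 with defaults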
164  klausur-service/backend/services/grid_detection_models.py  Normal file
@@ -0,0 +1,164 @@
"""
|
||||
Grid Detection Models v4
|
||||
|
||||
Data classes for OCR grid detection results.
|
||||
Coordinates use percentage (0-100) and mm (A4 format).
|
||||
"""
|
||||
|
||||
from enum import Enum
|
||||
from dataclasses import dataclass, field
|
||||
from typing import List, Dict, Any
|
||||
|
||||
# A4 dimensions
|
||||
A4_WIDTH_MM = 210.0
|
||||
A4_HEIGHT_MM = 297.0
|
||||
|
||||
# Column margin (1mm)
|
||||
COLUMN_MARGIN_MM = 1.0
|
||||
COLUMN_MARGIN_PCT = (COLUMN_MARGIN_MM / A4_WIDTH_MM) * 100
|
||||
|
||||
|
||||
class CellStatus(str, Enum):
|
||||
EMPTY = "empty"
|
||||
RECOGNIZED = "recognized"
|
||||
PROBLEMATIC = "problematic"
|
||||
MANUAL = "manual"
|
||||
|
||||
|
||||
class ColumnType(str, Enum):
|
||||
ENGLISH = "english"
|
||||
GERMAN = "german"
|
||||
EXAMPLE = "example"
|
||||
UNKNOWN = "unknown"
|
||||
|
||||
|
||||
@dataclass
|
||||
class OCRRegion:
|
||||
"""A word/phrase detected by OCR with bounding box coordinates in percentage (0-100)."""
|
||||
text: str
|
||||
confidence: float
|
||||
x: float # X position as percentage of page width
|
||||
y: float # Y position as percentage of page height
|
||||
width: float # Width as percentage of page width
|
||||
height: float # Height as percentage of page height
|
||||
|
||||
@property
|
||||
def x_mm(self) -> float:
|
||||
return round(self.x / 100 * A4_WIDTH_MM, 1)
|
||||
|
||||
@property
|
||||
def y_mm(self) -> float:
|
||||
return round(self.y / 100 * A4_HEIGHT_MM, 1)
|
||||
|
||||
@property
|
||||
def width_mm(self) -> float:
|
||||
return round(self.width / 100 * A4_WIDTH_MM, 1)
|
||||
|
||||
@property
|
||||
def height_mm(self) -> float:
|
||||
return round(self.height / 100 * A4_HEIGHT_MM, 2)
|
||||
|
||||
@property
|
||||
def center_x(self) -> float:
|
||||
return self.x + self.width / 2
|
||||
|
||||
@property
|
||||
def center_y(self) -> float:
|
||||
return self.y + self.height / 2
|
||||
|
||||
@property
|
||||
def right(self) -> float:
|
||||
return self.x + self.width
|
||||
|
||||
@property
|
||||
def bottom(self) -> float:
|
||||
return self.y + self.height
|
||||
|
||||
|
||||
@dataclass
|
||||
class GridCell:
|
||||
"""A cell in the detected grid with coordinates in percentage (0-100)."""
|
||||
row: int
|
||||
col: int
|
||||
x: float
|
||||
y: float
|
||||
width: float
|
||||
height: float
|
||||
text: str = ""
|
||||
confidence: float = 0.0
|
||||
status: CellStatus = CellStatus.EMPTY
|
||||
column_type: ColumnType = ColumnType.UNKNOWN
|
||||
logical_row: int = 0
|
||||
logical_col: int = 0
|
||||
is_continuation: bool = False
|
||||
|
||||
@property
|
||||
def x_mm(self) -> float:
|
||||
return round(self.x / 100 * A4_WIDTH_MM, 1)
|
||||
|
||||
@property
|
||||
def y_mm(self) -> float:
|
||||
return round(self.y / 100 * A4_HEIGHT_MM, 1)
|
||||
|
||||
@property
|
||||
def width_mm(self) -> float:
|
||||
return round(self.width / 100 * A4_WIDTH_MM, 1)
|
||||
|
||||
@property
|
||||
def height_mm(self) -> float:
|
||||
return round(self.height / 100 * A4_HEIGHT_MM, 2)
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
return {
|
||||
"row": self.row,
|
||||
"col": self.col,
|
||||
"x": round(self.x, 2),
|
||||
"y": round(self.y, 2),
|
||||
"width": round(self.width, 2),
|
||||
"height": round(self.height, 2),
|
||||
"x_mm": self.x_mm,
|
||||
"y_mm": self.y_mm,
|
||||
"width_mm": self.width_mm,
|
||||
"height_mm": self.height_mm,
|
||||
"text": self.text,
|
||||
"confidence": self.confidence,
|
||||
"status": self.status.value,
|
||||
"column_type": self.column_type.value,
|
||||
"logical_row": self.logical_row,
|
||||
"logical_col": self.logical_col,
|
||||
"is_continuation": self.is_continuation,
|
||||
}
|
||||
|
||||
|
||||
@dataclass
|
||||
class GridResult:
|
||||
"""Result of grid detection."""
|
||||
rows: int = 0
|
||||
columns: int = 0
|
||||
cells: List[List[GridCell]] = field(default_factory=list)
|
||||
column_types: List[str] = field(default_factory=list)
|
||||
column_boundaries: List[float] = field(default_factory=list)
|
||||
row_boundaries: List[float] = field(default_factory=list)
|
||||
deskew_angle: float = 0.0
|
||||
stats: Dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
cells_dicts = []
|
||||
for row_cells in self.cells:
|
||||
cells_dicts.append([c.to_dict() for c in row_cells])
|
||||
|
||||
return {
|
||||
"rows": self.rows,
|
||||
"columns": self.columns,
|
||||
"cells": cells_dicts,
|
||||
"column_types": self.column_types,
|
||||
"column_boundaries": [round(b, 2) for b in self.column_boundaries],
|
||||
"row_boundaries": [round(b, 2) for b in self.row_boundaries],
|
||||
"deskew_angle": round(self.deskew_angle, 2),
|
||||
"stats": self.stats,
|
||||
"page_dimensions": {
|
||||
"width_mm": A4_WIDTH_MM,
|
||||
"height_mm": A4_HEIGHT_MM,
|
||||
"format": "A4",
|
||||
},
|
||||
}
|
||||
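A short sanity check of the percent-to-mm properties; the import path is an assumption based on the file location in this commit, and the region values are made up:

# Sketch: percentage -> mm round-trip on A4, matching the properties above.
from services.grid_detection_models import OCRRegion  # assumed import path

r = OCRRegion(text="Haus", confidence=0.97, x=10.0, y=20.0, width=25.0, height=3.0)
print(r.x_mm, r.y_mm)            # 21.0, 59.4  (10% of 210mm, 20% of 297mm)
print(r.width_mm, r.height_mm)   # 52.5, 8.91
print(r.center_x, r.right)       # 22.5, 35.0  (still in percent space)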
@@ -10,166 +10,21 @@ Lizenz: Apache 2.0 (kommerziell nutzbar)

import math
import logging
from enum import Enum
from dataclasses import dataclass, field
from typing import List, Optional, Dict, Any, Tuple
from typing import List

from .grid_detection_models import (
    A4_WIDTH_MM,
    A4_HEIGHT_MM,
    COLUMN_MARGIN_MM,
    CellStatus,
    ColumnType,
    OCRRegion,
    GridCell,
    GridResult,
)

logger = logging.getLogger(__name__)

# A4 dimensions
A4_WIDTH_MM = 210.0
A4_HEIGHT_MM = 297.0

# Column margin (1mm)
COLUMN_MARGIN_MM = 1.0
COLUMN_MARGIN_PCT = (COLUMN_MARGIN_MM / A4_WIDTH_MM) * 100


class CellStatus(str, Enum):
    EMPTY = "empty"
    RECOGNIZED = "recognized"
    PROBLEMATIC = "problematic"
    MANUAL = "manual"


class ColumnType(str, Enum):
    ENGLISH = "english"
    GERMAN = "german"
    EXAMPLE = "example"
    UNKNOWN = "unknown"


@dataclass
class OCRRegion:
    """A word/phrase detected by OCR with bounding box coordinates in percentage (0-100)."""
    text: str
    confidence: float
    x: float       # X position as percentage of page width
    y: float       # Y position as percentage of page height
    width: float   # Width as percentage of page width
    height: float  # Height as percentage of page height

    @property
    def x_mm(self) -> float:
        return round(self.x / 100 * A4_WIDTH_MM, 1)

    @property
    def y_mm(self) -> float:
        return round(self.y / 100 * A4_HEIGHT_MM, 1)

    @property
    def width_mm(self) -> float:
        return round(self.width / 100 * A4_WIDTH_MM, 1)

    @property
    def height_mm(self) -> float:
        return round(self.height / 100 * A4_HEIGHT_MM, 2)

    @property
    def center_x(self) -> float:
        return self.x + self.width / 2

    @property
    def center_y(self) -> float:
        return self.y + self.height / 2

    @property
    def right(self) -> float:
        return self.x + self.width

    @property
    def bottom(self) -> float:
        return self.y + self.height


@dataclass
class GridCell:
    """A cell in the detected grid with coordinates in percentage (0-100)."""
    row: int
    col: int
    x: float
    y: float
    width: float
    height: float
    text: str = ""
    confidence: float = 0.0
    status: CellStatus = CellStatus.EMPTY
    column_type: ColumnType = ColumnType.UNKNOWN
    logical_row: int = 0
    logical_col: int = 0
    is_continuation: bool = False

    @property
    def x_mm(self) -> float:
        return round(self.x / 100 * A4_WIDTH_MM, 1)

    @property
    def y_mm(self) -> float:
        return round(self.y / 100 * A4_HEIGHT_MM, 1)

    @property
    def width_mm(self) -> float:
        return round(self.width / 100 * A4_WIDTH_MM, 1)

    @property
    def height_mm(self) -> float:
        return round(self.height / 100 * A4_HEIGHT_MM, 2)

    def to_dict(self) -> dict:
        return {
            "row": self.row,
            "col": self.col,
            "x": round(self.x, 2),
            "y": round(self.y, 2),
            "width": round(self.width, 2),
            "height": round(self.height, 2),
            "x_mm": self.x_mm,
            "y_mm": self.y_mm,
            "width_mm": self.width_mm,
            "height_mm": self.height_mm,
            "text": self.text,
            "confidence": self.confidence,
            "status": self.status.value,
            "column_type": self.column_type.value,
            "logical_row": self.logical_row,
            "logical_col": self.logical_col,
            "is_continuation": self.is_continuation,
        }


@dataclass
class GridResult:
    """Result of grid detection."""
    rows: int = 0
    columns: int = 0
    cells: List[List[GridCell]] = field(default_factory=list)
    column_types: List[str] = field(default_factory=list)
    column_boundaries: List[float] = field(default_factory=list)
    row_boundaries: List[float] = field(default_factory=list)
    deskew_angle: float = 0.0
    stats: Dict[str, Any] = field(default_factory=dict)

    def to_dict(self) -> dict:
        cells_dicts = []
        for row_cells in self.cells:
            cells_dicts.append([c.to_dict() for c in row_cells])

        return {
            "rows": self.rows,
            "columns": self.columns,
            "cells": cells_dicts,
            "column_types": self.column_types,
            "column_boundaries": [round(b, 2) for b in self.column_boundaries],
            "row_boundaries": [round(b, 2) for b in self.row_boundaries],
            "deskew_angle": round(self.deskew_angle, 2),
            "stats": self.stats,
            "page_dimensions": {
                "width_mm": A4_WIDTH_MM,
                "height_mm": A4_HEIGHT_MM,
                "format": "A4",
            },
        }


class GridDetectionService:
    """Detect grid/table structure from OCR bounding-box regions."""
@@ -184,7 +39,7 @@ class GridDetectionService:
        """Calculate page skew angle from OCR region positions.

        Uses left-edge alignment of regions to detect consistent tilt.
        Returns angle in degrees, clamped to ±5°.
        Returns angle in degrees, clamped to +/-5 degrees.
        """
        if len(regions) < 3:
            return 0.0
@@ -229,12 +84,12 @@ class GridDetectionService:
        slope = (n * sum_xy - sum_y * sum_x) / denom

        # Convert slope to angle (slope is dx/dy in percent space)
        # Adjust for aspect ratio: A4 is 210/297 ≈ 0.707
        # Adjust for aspect ratio: A4 is 210/297 ~ 0.707
        aspect = A4_WIDTH_MM / A4_HEIGHT_MM
        angle_rad = math.atan(slope * aspect)
        angle_deg = math.degrees(angle_rad)

        # Clamp to ±5°
        # Clamp to +/-5 degrees
        return max(-5.0, min(5.0, round(angle_deg, 2)))

    def apply_deskew_to_regions(self, regions: List[OCRRegion], angle: float) -> List[OCRRegion]:
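The hunk above converts a least-squares slope in percent space into a physical angle by rescaling with the A4 aspect ratio before taking the arctangent; a standalone replica of that arithmetic with an assumed slope value:

# Replica of the deskew math: slope is dx/dy in percent space, rescaled to mm space.
import math

A4_WIDTH_MM, A4_HEIGHT_MM = 210.0, 297.0
slope = 0.02                            # e.g. left edges drift 2% in x per 100% of y
aspect = A4_WIDTH_MM / A4_HEIGHT_MM     # ~ 0.707
angle_deg = math.degrees(math.atan(slope * aspect))
print(max(-5.0, min(5.0, round(angle_deg, 2))))  # ~ 0.81 degrees, clamped to +/-5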
@@ -1,594 +1,25 @@
"""
SmartSpellChecker — Language-aware OCR post-correction without LLMs.
SmartSpellChecker — barrel re-export.

Uses pyspellchecker (MIT) with dual EN+DE dictionaries for:
- Automatic language detection per word (dual-dictionary heuristic)
- OCR error correction (digit↔letter, umlauts, transpositions)
- Context-based disambiguation (a/I, l/I) via bigram lookup
- Mixed-language support for example sentences
All implementation split into:
    smart_spell_core — init, data types, language detection, word correction
    smart_spell_text — full text correction, boundary repair, context split

Lizenz: Apache 2.0 (kommerziell nutzbar)
"""

import logging
import re
from dataclasses import dataclass, field
from typing import Dict, List, Literal, Optional, Set, Tuple

logger = logging.getLogger(__name__)

# ---------------------------------------------------------------------------
# Init
# ---------------------------------------------------------------------------

try:
    from spellchecker import SpellChecker as _SpellChecker
    _en_spell = _SpellChecker(language='en', distance=1)
    _de_spell = _SpellChecker(language='de', distance=1)
    _AVAILABLE = True
except ImportError:
    _AVAILABLE = False
    logger.warning("pyspellchecker not installed — SmartSpellChecker disabled")

Lang = Literal["en", "de", "both", "unknown"]

# ---------------------------------------------------------------------------
# Bigram context for a/I disambiguation
# ---------------------------------------------------------------------------

# Words that commonly follow "I" (subject pronoun → verb/modal)
_I_FOLLOWERS: frozenset = frozenset({
    "am", "was", "have", "had", "do", "did", "will", "would", "can",
    "could", "should", "shall", "may", "might", "must",
    "think", "know", "see", "want", "need", "like", "love", "hate",
    "go", "went", "come", "came", "say", "said", "get", "got",
    "make", "made", "take", "took", "give", "gave", "tell", "told",
    "feel", "felt", "find", "found", "believe", "hope", "wish",
    "remember", "forget", "understand", "mean", "meant",
    "don't", "didn't", "can't", "won't", "couldn't", "wouldn't",
    "shouldn't", "haven't", "hadn't", "isn't", "wasn't",
    "really", "just", "also", "always", "never", "often", "sometimes",
})

# Words that commonly follow "a" (article → noun/adjective)
_A_FOLLOWERS: frozenset = frozenset({
    "lot", "few", "little", "bit", "good", "bad", "great", "new", "old",
    "long", "short", "big", "small", "large", "huge", "tiny",
    "nice", "beautiful", "wonderful", "terrible", "horrible",
    "man", "woman", "boy", "girl", "child", "dog", "cat", "bird",
    "book", "car", "house", "room", "school", "teacher", "student",
    "day", "week", "month", "year", "time", "place", "way",
    "friend", "family", "person", "problem", "question", "story",
    "very", "really", "quite", "rather", "pretty", "single",
})

# Digit→letter substitutions (OCR confusion)
_DIGIT_SUBS: Dict[str, List[str]] = {
    '0': ['o', 'O'],
    '1': ['l', 'I'],
    '5': ['s', 'S'],
    '6': ['g', 'G'],
    '8': ['b', 'B'],
    '|': ['I', 'l'],
    '/': ['l'],  # italic 'l' misread as slash (e.g. "p/" → "pl")
}
_SUSPICIOUS_CHARS = frozenset(_DIGIT_SUBS.keys())

# Umlaut confusion: OCR drops dots (ü→u, ä→a, ö→o)
_UMLAUT_MAP = {
    'a': 'ä', 'o': 'ö', 'u': 'ü', 'i': 'ü',
    'A': 'Ä', 'O': 'Ö', 'U': 'Ü', 'I': 'Ü',
}

# Tokenizer — includes | and / so OCR artifacts like "p/" are treated as words
_TOKEN_RE = re.compile(r"([A-Za-zÄÖÜäöüß'|/]+)([^A-Za-zÄÖÜäöüß'|/]*)")


# ---------------------------------------------------------------------------
# Data types
# ---------------------------------------------------------------------------

@dataclass
class CorrectionResult:
    original: str
    corrected: str
    lang_detected: Lang
    changed: bool
    changes: List[str] = field(default_factory=list)


# ---------------------------------------------------------------------------
# Core class
# ---------------------------------------------------------------------------

class SmartSpellChecker:
    """Language-aware OCR spell checker using pyspellchecker (no LLM)."""

    def __init__(self):
        if not _AVAILABLE:
            raise RuntimeError("pyspellchecker not installed")
        self.en = _en_spell
        self.de = _de_spell

    # --- Language detection ---

    def detect_word_lang(self, word: str) -> Lang:
        """Detect language of a single word using dual-dict heuristic."""
        w = word.lower().strip(".,;:!?\"'()")
        if not w:
            return "unknown"
        in_en = bool(self.en.known([w]))
        in_de = bool(self.de.known([w]))
        if in_en and in_de:
            return "both"
        if in_en:
            return "en"
        if in_de:
            return "de"
        return "unknown"

    def detect_text_lang(self, text: str) -> Lang:
        """Detect dominant language of a text string (sentence/phrase)."""
        words = re.findall(r"[A-Za-zÄÖÜäöüß]+", text)
        if not words:
            return "unknown"

        en_count = 0
        de_count = 0
        for w in words:
            lang = self.detect_word_lang(w)
            if lang == "en":
                en_count += 1
            elif lang == "de":
                de_count += 1
            # "both" doesn't count for either

        if en_count > de_count:
            return "en"
        if de_count > en_count:
            return "de"
        if en_count == de_count and en_count > 0:
            return "both"
        return "unknown"
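    # Example (assuming both pyspellchecker dictionaries are installed; exact
    # results depend on dictionary contents):
    #   detect_word_lang("house")  -> "en"
    #   detect_word_lang("läuft")  -> "de"
    #   detect_text_lang("Wir gehen in die Schule") -> "de" (majority vote over words)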
# --- Single-word correction ---
|
||||
|
||||
def _known(self, word: str) -> bool:
|
||||
"""True if word is known in EN or DE dictionary, or is a known abbreviation."""
|
||||
w = word.lower()
|
||||
if bool(self.en.known([w])) or bool(self.de.known([w])):
|
||||
return True
|
||||
# Also accept known abbreviations (sth, sb, adj, etc.)
|
||||
try:
|
||||
from cv_ocr_engines import _KNOWN_ABBREVIATIONS
|
||||
if w in _KNOWN_ABBREVIATIONS:
|
||||
return True
|
||||
except ImportError:
|
||||
pass
|
||||
return False
|
||||
|
||||
def _word_freq(self, word: str) -> float:
|
||||
"""Get word frequency (max of EN and DE)."""
|
||||
w = word.lower()
|
||||
return max(self.en.word_usage_frequency(w), self.de.word_usage_frequency(w))
|
||||
|
||||
def _known_in(self, word: str, lang: str) -> bool:
|
||||
"""True if word is known in a specific language dictionary."""
|
||||
w = word.lower()
|
||||
spell = self.en if lang == "en" else self.de
|
||||
return bool(spell.known([w]))
|
||||
|
||||
def correct_word(self, word: str, lang: str = "en",
|
||||
prev_word: str = "", next_word: str = "") -> Optional[str]:
|
||||
"""Correct a single word for the given language.
|
||||
|
||||
Returns None if no correction needed, or the corrected string.
|
||||
|
||||
Args:
|
||||
word: The word to check/correct
|
||||
lang: Expected language ("en" or "de")
|
||||
prev_word: Previous word (for context)
|
||||
next_word: Next word (for context)
|
||||
"""
|
||||
if not word or not word.strip():
|
||||
return None
|
||||
|
||||
# Skip numbers, abbreviations with dots, very short tokens
|
||||
if word.isdigit() or '.' in word:
|
||||
return None
|
||||
|
||||
# Skip IPA/phonetic content in brackets
|
||||
if '[' in word or ']' in word:
|
||||
return None
|
||||
|
||||
has_suspicious = any(ch in _SUSPICIOUS_CHARS for ch in word)
|
||||
|
||||
# 1. Already known → no fix
|
||||
if self._known(word):
|
||||
# But check a/I disambiguation for single-char words
|
||||
if word.lower() in ('l', '|') and next_word:
|
||||
return self._disambiguate_a_I(word, next_word)
|
||||
return None
|
||||
|
||||
# 2. Digit/pipe substitution
|
||||
if has_suspicious:
|
||||
if word == '|':
|
||||
return 'I'
|
||||
# Try single-char substitutions
|
||||
for i, ch in enumerate(word):
|
||||
if ch not in _DIGIT_SUBS:
|
||||
continue
|
||||
for replacement in _DIGIT_SUBS[ch]:
|
||||
candidate = word[:i] + replacement + word[i + 1:]
|
||||
if self._known(candidate):
|
||||
return candidate
|
||||
# Try multi-char substitution (e.g., "sch00l" → "school")
|
||||
multi = self._try_multi_digit_sub(word)
|
||||
if multi:
|
||||
return multi
|
||||
|
||||
# 3. Umlaut correction (German)
|
||||
if lang == "de" and len(word) >= 3 and word.isalpha():
|
||||
umlaut_fix = self._try_umlaut_fix(word)
|
||||
if umlaut_fix:
|
||||
return umlaut_fix
|
||||
|
||||
# 4. General spell correction
|
||||
if not has_suspicious and len(word) >= 3 and word.isalpha():
|
||||
# Safety: don't correct if the word is valid in the OTHER language
|
||||
# (either directly or via umlaut fix)
|
||||
other_lang = "de" if lang == "en" else "en"
|
||||
if self._known_in(word, other_lang):
|
||||
return None
|
||||
if other_lang == "de" and self._try_umlaut_fix(word):
|
||||
return None # has a valid DE umlaut variant → don't touch
|
||||
|
||||
spell = self.en if lang == "en" else self.de
|
||||
correction = spell.correction(word.lower())
|
||||
if correction and correction != word.lower():
|
||||
if word[0].isupper():
|
||||
correction = correction[0].upper() + correction[1:]
|
||||
if self._known(correction):
|
||||
return correction
|
||||
|
||||
return None
|
||||
|
||||
# --- Multi-digit substitution ---
|
||||
|
||||

    def _try_multi_digit_sub(self, word: str) -> Optional[str]:
        """Try replacing multiple digits simultaneously using BFS."""
        positions = [(i, ch) for i, ch in enumerate(word) if ch in _DIGIT_SUBS]
        if not positions or len(positions) > 4:
            return None

        # BFS over substitution combinations (max 2^4 = 16 for 4 positions)
        queue = [list(word)]
        for pos, ch in positions:
            next_queue = []
            for current in queue:
                # Keep original
                next_queue.append(current[:])
                # Try each substitution
                for repl in _DIGIT_SUBS[ch]:
                    variant = current[:]
                    variant[pos] = repl
                    next_queue.append(variant)
            queue = next_queue

        # Check which combinations produce known words
        for combo in queue:
            candidate = "".join(combo)
            if candidate != word and self._known(candidate):
                return candidate

        return None

    # --- Umlaut fix ---

    def _try_umlaut_fix(self, word: str) -> Optional[str]:
        """Try single-char umlaut substitutions for German words."""
        for i, ch in enumerate(word):
            if ch in _UMLAUT_MAP:
                candidate = word[:i] + _UMLAUT_MAP[ch] + word[i + 1:]
                if self._known(candidate):
                    return candidate
        return None

    # --- Boundary repair (shifted word boundaries) ---

    def _try_boundary_repair(self, word1: str, word2: str) -> Optional[Tuple[str, str]]:
        """Fix shifted word boundaries between adjacent tokens.

        OCR sometimes shifts the boundary: "at sth." → "ats th."
        Try moving 1-2 chars from end of word1 to start of word2 and vice versa.
        Returns (fixed_word1, fixed_word2) or None.
        """
        # Import known abbreviations for vocabulary context
        try:
            from cv_ocr_engines import _KNOWN_ABBREVIATIONS
        except ImportError:
            _KNOWN_ABBREVIATIONS = set()

        # Strip trailing punctuation for checking, preserve for result
        w2_stripped = word2.rstrip(".,;:!?")
        w2_punct = word2[len(w2_stripped):]

        # Try shifting 1-2 chars from word1 → word2
        for shift in (1, 2):
            if len(word1) <= shift:
                continue
            new_w1 = word1[:-shift]
            new_w2_base = word1[-shift:] + w2_stripped

            w1_ok = self._known(new_w1) or new_w1.lower() in _KNOWN_ABBREVIATIONS
            w2_ok = self._known(new_w2_base) or new_w2_base.lower() in _KNOWN_ABBREVIATIONS

            if w1_ok and w2_ok:
                return (new_w1, new_w2_base + w2_punct)

        # Try shifting 1-2 chars from word2 → word1
        for shift in (1, 2):
            if len(w2_stripped) <= shift:
                continue
            new_w1 = word1 + w2_stripped[:shift]
            new_w2_base = w2_stripped[shift:]

            w1_ok = self._known(new_w1) or new_w1.lower() in _KNOWN_ABBREVIATIONS
            w2_ok = self._known(new_w2_base) or new_w2_base.lower() in _KNOWN_ABBREVIATIONS

            if w1_ok and w2_ok:
                return (new_w1, new_w2_base + w2_punct)

        return None

    # --- Context-based word split for ambiguous merges ---

    # Patterns where a valid word is actually "a" + adjective/noun
    _ARTICLE_SPLIT_CANDIDATES = {
        # word → (article, remainder) — only when followed by a compatible word
        "anew": ("a", "new"),
        "areal": ("a", "real"),
        "alive": None,  # genuinely one word, never split
        "alone": None,
        "aware": None,
        "alike": None,
        "apart": None,
        "aside": None,
        "above": None,
        "about": None,
        "among": None,
        "along": None,
    }

    def _try_context_split(self, word: str, next_word: str,
                           prev_word: str) -> Optional[str]:
        """Split words like 'anew' → 'a new' when context indicates a merge.

        Only splits when:
        - The word is in the split candidates list
        - The following word makes sense as a noun (for "a + adj + noun" pattern)
        - OR the word is unknown and can be split into article + known word
        """
        w_lower = word.lower()

        # Check explicit candidates
        if w_lower in self._ARTICLE_SPLIT_CANDIDATES:
            split = self._ARTICLE_SPLIT_CANDIDATES[w_lower]
            if split is None:
                return None  # explicitly marked as "don't split"
            article, remainder = split
            # Only split if followed by a word (noun pattern)
            if next_word and next_word[0].islower():
                return f"{article} {remainder}"
            # Also split if remainder + next_word makes a common phrase
            if next_word and self._known(next_word):
                return f"{article} {remainder}"

        # Generic: if word starts with 'a' and rest is a known adjective/word
        if (len(word) >= 4 and word[0].lower() == 'a'
                and not self._known(word)  # only for UNKNOWN words
                and self._known(word[1:])):
            return f"a {word[1:]}"

        return None

    # --- a/I disambiguation ---

    def _disambiguate_a_I(self, token: str, next_word: str) -> Optional[str]:
        """Disambiguate 'a' vs 'I' (and OCR variants like 'l', '|')."""
        nw = next_word.lower().strip(".,;:!?")
        if nw in _I_FOLLOWERS:
            return "I"
        if nw in _A_FOLLOWERS:
            return "a"
        # Fallback: check if next word is more commonly a verb (→ I) or noun/adj (→ a).
        # Simple heuristic: if the next word starts with uppercase (and isn't first
        # in the sentence), it's likely a German noun... but in English context,
        # uppercase after "I" is unusual.
        return None  # uncertain, don't change

    # --- Full text correction ---

    def correct_text(self, text: str, lang: str = "en") -> CorrectionResult:
        """Correct a full text string (field value).

        Three passes:
        1. Boundary repair — fix shifted word boundaries between adjacent tokens
        2. Context split — split ambiguous merges (anew → a new)
        3. Per-word correction — spell check individual words

        Args:
            text: The text to correct
            lang: Expected language ("en" or "de")
        """
        if not text or not text.strip():
            return CorrectionResult(text, text, "unknown", False)

        detected = self.detect_text_lang(text) if lang == "auto" else lang
        effective_lang = detected if detected in ("en", "de") else "en"

        changes: List[str] = []
        tokens = list(_TOKEN_RE.finditer(text))

        # Extract token list: [(word, separator), ...]
        token_list: List[List[str]] = []  # [[word, sep], ...]
        for m in tokens:
            token_list.append([m.group(1), m.group(2)])

        # --- Pass 1: Boundary repair between adjacent unknown words ---
        # Import abbreviations for the heuristic below
        try:
            from cv_ocr_engines import _KNOWN_ABBREVIATIONS as _ABBREVS
        except ImportError:
            _ABBREVS = set()

        for i in range(len(token_list) - 1):
            w1 = token_list[i][0]
            w2_raw = token_list[i + 1][0]

            # Skip boundary repair for IPA/bracket content.
            # Brackets may be in the token OR in the adjacent separators.
            sep_before_w1 = token_list[i - 1][1] if i > 0 else ""
            sep_after_w1 = token_list[i][1]
            sep_after_w2 = token_list[i + 1][1]
            has_bracket = (
                '[' in w1 or ']' in w1 or '[' in w2_raw or ']' in w2_raw
                or ']' in sep_after_w1   # w1 text was inside [brackets]
                or '[' in sep_after_w1   # w2 starts a bracket
                or ']' in sep_after_w2   # w2 text was inside [brackets]
                or '[' in sep_before_w1  # w1 starts a bracket
            )
            if has_bracket:
                continue

            # Include trailing punct from separator in w2 for abbreviation matching
            w2_with_punct = w2_raw + token_list[i + 1][1].rstrip(" ")

            # Try boundary repair — always, even if both words are valid.
            # Use word-frequency scoring to decide if repair is better.
            repair = self._try_boundary_repair(w1, w2_with_punct)
            if not repair and w2_with_punct != w2_raw:
                repair = self._try_boundary_repair(w1, w2_raw)
            if repair:
                new_w1, new_w2_full = repair
                new_w2_base = new_w2_full.rstrip(".,;:!?")

                # Frequency-based scoring: product of word frequencies.
                # Higher product = more common word pair = better.
                old_freq = self._word_freq(w1) * self._word_freq(w2_raw)
                new_freq = self._word_freq(new_w1) * self._word_freq(new_w2_base)

                # Abbreviation bonus: if repair produces a known abbreviation
                has_abbrev = new_w1.lower() in _ABBREVS or new_w2_base.lower() in _ABBREVS
                if has_abbrev:
                    # Accept abbreviation repair ONLY if at least one of the
                    # original words is rare/unknown (prevents "Can I" → "Ca nI"
                    # where both original words are common and correct).
                    # "Rare" = frequency < 1e-6 (covers "ats", "th" but not "Can", "I")
                    RARE_THRESHOLD = 1e-6
                    orig_both_common = (
                        self._word_freq(w1) > RARE_THRESHOLD
                        and self._word_freq(w2_raw) > RARE_THRESHOLD
                    )
                    if not orig_both_common:
                        new_freq = max(new_freq, old_freq * 10)
                    else:
                        has_abbrev = False  # both originals common → don't trust

                # Accept if repair produces a more frequent word pair
                # (threshold: at least 5x more frequent to avoid false positives)
                if new_freq > old_freq * 5:
                    new_w2_punct = new_w2_full[len(new_w2_base):]
                    changes.append(f"{w1} {w2_raw}→{new_w1} {new_w2_base}")
                    token_list[i][0] = new_w1
                    token_list[i + 1][0] = new_w2_base
                    if new_w2_punct:
                        token_list[i + 1][1] = new_w2_punct + token_list[i + 1][1].lstrip(".,;:!?")

        # --- Pass 2: Context split (anew → a new) ---
        expanded: List[List[str]] = []
        for i, (word, sep) in enumerate(token_list):
            next_word = token_list[i + 1][0] if i + 1 < len(token_list) else ""
            prev_word = token_list[i - 1][0] if i > 0 else ""
            split = self._try_context_split(word, next_word, prev_word)
            if split and split != word:
                changes.append(f"{word}→{split}")
                expanded.append([split, sep])
            else:
                expanded.append([word, sep])
        token_list = expanded

        # --- Pass 3: Per-word correction ---
        parts: List[str] = []

        # Preserve any leading text before the first token match
        # (e.g., "(= " before "I won and he lost.")
        first_start = tokens[0].start() if tokens else 0
        if first_start > 0:
            parts.append(text[:first_start])

        for i, (word, sep) in enumerate(token_list):
            # Skip words inside IPA brackets (brackets land in separators)
            prev_sep = token_list[i - 1][1] if i > 0 else ""
            if '[' in prev_sep or ']' in sep:
                parts.append(word)
                parts.append(sep)
                continue

            next_word = token_list[i + 1][0] if i + 1 < len(token_list) else ""
            prev_word = token_list[i - 1][0] if i > 0 else ""

            correction = self.correct_word(
                word, lang=effective_lang,
                prev_word=prev_word, next_word=next_word,
            )
            if correction and correction != word:
                changes.append(f"{word}→{correction}")
                parts.append(correction)
            else:
                parts.append(word)
            parts.append(sep)

        # Append any trailing text
        last_end = tokens[-1].end() if tokens else 0
        if last_end < len(text):
            parts.append(text[last_end:])

        corrected = "".join(parts)
        return CorrectionResult(
            original=text,
            corrected=corrected,
            lang_detected=detected,
            changed=corrected != text,
            changes=changes,
        )

    # --- Vocabulary entry correction ---

    def correct_vocab_entry(self, english: str, german: str,
                            example: str = "") -> Dict[str, CorrectionResult]:
        """Correct a full vocabulary entry (EN + DE + example).

        Uses column position to determine language — the most reliable signal.
        """
        results = {}
        results["english"] = self.correct_text(english, lang="en")
        results["german"] = self.correct_text(german, lang="de")
        if example:
            # For examples, auto-detect language
            results["example"] = self.correct_text(example, lang="auto")
        return results
# Core: data types, lang detection (re-exported for tests)
from smart_spell_core import (  # noqa: F401
    _AVAILABLE,
    _DIGIT_SUBS,
    _SUSPICIOUS_CHARS,
    _UMLAUT_MAP,
    _TOKEN_RE,
    _I_FOLLOWERS,
    _A_FOLLOWERS,
    CorrectionResult,
    Lang,
)

# Text: SmartSpellChecker class (the main public API)
from smart_spell_text import SmartSpellChecker  # noqa: F401
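The barrel keeps the original import path stable for existing callers and tests. A minimal usage sketch (assuming pyspellchecker with its EN and DE dictionaries is installed; the exact corrections depend on the dictionary contents):

    from smart_spell import SmartSpellChecker

    checker = SmartSpellChecker()
    result = checker.correct_text("He went to sch00l", lang="en")
    print(result.corrected)  # expected "He went to school" via digit substitution
    print(result.changes)    # e.g. ["sch00l→school"]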
298  klausur-service/backend/smart_spell_core.py  Normal file
@@ -0,0 +1,298 @@
"""
|
||||
SmartSpellChecker Core — init, data types, language detection, word correction.
|
||||
|
||||
Extracted from smart_spell.py for modularity.
|
||||
|
||||
Lizenz: Apache 2.0 (kommerziell nutzbar)
|
||||
"""
|
||||
|
||||
import logging
|
||||
import re
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Dict, List, Literal, Optional, Set, Tuple
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Init
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
try:
|
||||
from spellchecker import SpellChecker as _SpellChecker
|
||||
_en_spell = _SpellChecker(language='en', distance=1)
|
||||
_de_spell = _SpellChecker(language='de', distance=1)
|
||||
_AVAILABLE = True
|
||||
except ImportError:
|
||||
_AVAILABLE = False
|
||||
logger.warning("pyspellchecker not installed — SmartSpellChecker disabled")
|
||||
|
||||
Lang = Literal["en", "de", "both", "unknown"]
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Bigram context for a/I disambiguation
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
# Words that commonly follow "I" (subject pronoun -> verb/modal)
|
||||
_I_FOLLOWERS: frozenset = frozenset({
|
||||
"am", "was", "have", "had", "do", "did", "will", "would", "can",
|
||||
"could", "should", "shall", "may", "might", "must",
|
||||
"think", "know", "see", "want", "need", "like", "love", "hate",
|
||||
"go", "went", "come", "came", "say", "said", "get", "got",
|
||||
"make", "made", "take", "took", "give", "gave", "tell", "told",
|
||||
"feel", "felt", "find", "found", "believe", "hope", "wish",
|
||||
"remember", "forget", "understand", "mean", "meant",
|
||||
"don't", "didn't", "can't", "won't", "couldn't", "wouldn't",
|
||||
"shouldn't", "haven't", "hadn't", "isn't", "wasn't",
|
||||
"really", "just", "also", "always", "never", "often", "sometimes",
|
||||
})
|
||||
|
||||
# Words that commonly follow "a" (article -> noun/adjective)
|
||||
_A_FOLLOWERS: frozenset = frozenset({
|
||||
"lot", "few", "little", "bit", "good", "bad", "great", "new", "old",
|
||||
"long", "short", "big", "small", "large", "huge", "tiny",
|
||||
"nice", "beautiful", "wonderful", "terrible", "horrible",
|
||||
"man", "woman", "boy", "girl", "child", "dog", "cat", "bird",
|
||||
"book", "car", "house", "room", "school", "teacher", "student",
|
||||
"day", "week", "month", "year", "time", "place", "way",
|
||||
"friend", "family", "person", "problem", "question", "story",
|
||||
"very", "really", "quite", "rather", "pretty", "single",
|
||||
})
|
||||
|
||||
# Digit->letter substitutions (OCR confusion)
|
||||
_DIGIT_SUBS: Dict[str, List[str]] = {
|
||||
'0': ['o', 'O'],
|
||||
'1': ['l', 'I'],
|
||||
'5': ['s', 'S'],
|
||||
'6': ['g', 'G'],
|
||||
'8': ['b', 'B'],
|
||||
'|': ['I', 'l'],
|
||||
'/': ['l'], # italic 'l' misread as slash (e.g. "p/" -> "pl")
|
||||
}
|
||||
_SUSPICIOUS_CHARS = frozenset(_DIGIT_SUBS.keys())
|
||||
|
||||
# Umlaut confusion: OCR drops the dots (ü -> u, ä -> a, ö -> o)
_UMLAUT_MAP = {
    'a': '\u00e4', 'o': '\u00f6', 'u': '\u00fc', 'i': '\u00fc',
    'A': '\u00c4', 'O': '\u00d6', 'U': '\u00dc', 'I': '\u00dc',
}

# Tokenizer -- includes | and / so OCR artifacts like "p/" are treated as words
_TOKEN_RE = re.compile(r"([A-Za-z\u00c4\u00d6\u00dc\u00e4\u00f6\u00fc\u00df'|/]+)([^A-Za-z\u00c4\u00d6\u00dc\u00e4\u00f6\u00fc\u00df'|/]*)")


# ---------------------------------------------------------------------------
# Data types
# ---------------------------------------------------------------------------

@dataclass
class CorrectionResult:
    original: str
    corrected: str
    lang_detected: Lang
    changed: bool
    changes: List[str] = field(default_factory=list)


# ---------------------------------------------------------------------------
# Core class — language detection and word-level correction
# ---------------------------------------------------------------------------

class _SmartSpellCoreBase:
    """Base class with language detection and single-word correction.

    Not intended for direct use — SmartSpellChecker inherits from this.
    """

    def __init__(self):
        if not _AVAILABLE:
            raise RuntimeError("pyspellchecker not installed")
        self.en = _en_spell
        self.de = _de_spell

    # --- Language detection ---

    def detect_word_lang(self, word: str) -> Lang:
        """Detect language of a single word using dual-dict heuristic."""
        w = word.lower().strip(".,;:!?\"'()")
        if not w:
            return "unknown"
        in_en = bool(self.en.known([w]))
        in_de = bool(self.de.known([w]))
        if in_en and in_de:
            return "both"
        if in_en:
            return "en"
        if in_de:
            return "de"
        return "unknown"

    def detect_text_lang(self, text: str) -> Lang:
        """Detect dominant language of a text string (sentence/phrase)."""
        words = re.findall(r"[A-Za-z\u00c4\u00d6\u00dc\u00e4\u00f6\u00fc\u00df]+", text)
        if not words:
            return "unknown"

        en_count = 0
        de_count = 0
        for w in words:
            lang = self.detect_word_lang(w)
            if lang == "en":
                en_count += 1
            elif lang == "de":
                de_count += 1
            # "both" doesn't count for either

        if en_count > de_count:
            return "en"
        if de_count > en_count:
            return "de"
        if en_count == de_count and en_count > 0:
            return "both"
        return "unknown"

    # --- Single-word correction ---

    def _known(self, word: str) -> bool:
        """True if word is known in EN or DE dictionary, or is a known abbreviation."""
        w = word.lower()
        if bool(self.en.known([w])) or bool(self.de.known([w])):
            return True
        # Also accept known abbreviations (sth, sb, adj, etc.)
        try:
            from cv_ocr_engines import _KNOWN_ABBREVIATIONS
            if w in _KNOWN_ABBREVIATIONS:
                return True
        except ImportError:
            pass
        return False

    def _word_freq(self, word: str) -> float:
        """Get word frequency (max of EN and DE)."""
        w = word.lower()
        return max(self.en.word_usage_frequency(w), self.de.word_usage_frequency(w))

    def _known_in(self, word: str, lang: str) -> bool:
        """True if word is known in a specific language dictionary."""
        w = word.lower()
        spell = self.en if lang == "en" else self.de
        return bool(spell.known([w]))

    def correct_word(self, word: str, lang: str = "en",
                     prev_word: str = "", next_word: str = "") -> Optional[str]:
        """Correct a single word for the given language.

        Returns None if no correction needed, or the corrected string.
        """
        if not word or not word.strip():
            return None

        # Skip numbers, abbreviations with dots, very short tokens
        if word.isdigit() or '.' in word:
            return None

        # Skip IPA/phonetic content in brackets
        if '[' in word or ']' in word:
            return None

        has_suspicious = any(ch in _SUSPICIOUS_CHARS for ch in word)

        # 1. Already known -> no fix
        if self._known(word):
            # But check a/I disambiguation for single-char words
            if word.lower() in ('l', '|') and next_word:
                return self._disambiguate_a_I(word, next_word)
            return None

        # 2. Digit/pipe substitution
        if has_suspicious:
            if word == '|':
                return 'I'
            # Try single-char substitutions
            for i, ch in enumerate(word):
                if ch not in _DIGIT_SUBS:
                    continue
                for replacement in _DIGIT_SUBS[ch]:
                    candidate = word[:i] + replacement + word[i + 1:]
                    if self._known(candidate):
                        return candidate
            # Try multi-char substitution (e.g., "sch00l" -> "school")
            multi = self._try_multi_digit_sub(word)
            if multi:
                return multi

        # 3. Umlaut correction (German)
        if lang == "de" and len(word) >= 3 and word.isalpha():
            umlaut_fix = self._try_umlaut_fix(word)
            if umlaut_fix:
                return umlaut_fix

        # 4. General spell correction
        if not has_suspicious and len(word) >= 3 and word.isalpha():
            # Safety: don't correct if the word is valid in the OTHER language
            other_lang = "de" if lang == "en" else "en"
            if self._known_in(word, other_lang):
                return None
            if other_lang == "de" and self._try_umlaut_fix(word):
                return None  # has a valid DE umlaut variant -> don't touch

            spell = self.en if lang == "en" else self.de
            correction = spell.correction(word.lower())
            if correction and correction != word.lower():
                if word[0].isupper():
                    correction = correction[0].upper() + correction[1:]
                if self._known(correction):
                    return correction

        return None

    # --- Multi-digit substitution ---

    def _try_multi_digit_sub(self, word: str) -> Optional[str]:
        """Try replacing multiple digits simultaneously using BFS."""
        positions = [(i, ch) for i, ch in enumerate(word) if ch in _DIGIT_SUBS]
        if not positions or len(positions) > 4:
            return None

        # BFS over substitution combinations
        queue = [list(word)]
        for pos, ch in positions:
            next_queue = []
            for current in queue:
                # Keep original
                next_queue.append(current[:])
                # Try each substitution
                for repl in _DIGIT_SUBS[ch]:
                    variant = current[:]
                    variant[pos] = repl
                    next_queue.append(variant)
            queue = next_queue

        # Check which combinations produce known words
        for combo in queue:
            candidate = "".join(combo)
            if candidate != word and self._known(candidate):
                return candidate

        return None

    # --- Umlaut fix ---

    def _try_umlaut_fix(self, word: str) -> Optional[str]:
        """Try single-char umlaut substitutions for German words."""
        for i, ch in enumerate(word):
            if ch in _UMLAUT_MAP:
                candidate = word[:i] + _UMLAUT_MAP[ch] + word[i + 1:]
                if self._known(candidate):
                    return candidate
        return None

    # --- a/I disambiguation ---

    def _disambiguate_a_I(self, token: str, next_word: str) -> Optional[str]:
        """Disambiguate 'a' vs 'I' (and OCR variants like 'l', '|')."""
        nw = next_word.lower().strip(".,;:!?")
        if nw in _I_FOLLOWERS:
            return "I"
        if nw in _A_FOLLOWERS:
            return "a"
        return None  # uncertain, don't change
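A short sketch of the word-level heuristics the core base provides, exercised through the public SmartSpellChecker subclass since the base is private (actual results depend on the installed dictionaries):

    from smart_spell_text import SmartSpellChecker

    checker = SmartSpellChecker()
    print(checker.detect_word_lang("Haus"))          # "de", or "both" if also in the EN dict
    print(checker.correct_word("h0use", lang="en"))  # digit substitution: "house"
    print(checker.correct_word("Tur", lang="de"))    # umlaut fix: "Tür", if the DE dict knows it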
289  klausur-service/backend/smart_spell_text.py  Normal file
@@ -0,0 +1,289 @@
"""
|
||||
SmartSpellChecker Text — full text correction, boundary repair, context split.
|
||||
|
||||
Extracted from smart_spell.py for modularity.
|
||||
|
||||
Lizenz: Apache 2.0 (kommerziell nutzbar)
|
||||
"""
|
||||
|
||||
import re
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
from smart_spell_core import (
|
||||
_SmartSpellCoreBase,
|
||||
_TOKEN_RE,
|
||||
CorrectionResult,
|
||||
Lang,
|
||||
)
|
||||
|
||||
|
||||
class SmartSpellChecker(_SmartSpellCoreBase):
|
||||
"""Language-aware OCR spell checker using pyspellchecker (no LLM).
|
||||
|
||||
Inherits single-word correction from _SmartSpellCoreBase.
|
||||
Adds text-level passes: boundary repair, context split, full correction.
|
||||
"""
|
||||
|
||||
# --- Boundary repair (shifted word boundaries) ---
|
||||
|
||||
def _try_boundary_repair(self, word1: str, word2: str) -> Optional[Tuple[str, str]]:
|
||||
"""Fix shifted word boundaries between adjacent tokens.
|
||||
|
||||
OCR sometimes shifts the boundary: "at sth." -> "ats th."
|
||||
Try moving 1-2 chars from end of word1 to start of word2 and vice versa.
|
||||
Returns (fixed_word1, fixed_word2) or None.
|
||||
"""
|
||||
# Import known abbreviations for vocabulary context
|
||||
try:
|
||||
from cv_ocr_engines import _KNOWN_ABBREVIATIONS
|
||||
except ImportError:
|
||||
_KNOWN_ABBREVIATIONS = set()
|
||||
|
||||
# Strip trailing punctuation for checking, preserve for result
|
||||
w2_stripped = word2.rstrip(".,;:!?")
|
||||
w2_punct = word2[len(w2_stripped):]
|
||||
|
||||
# Try shifting 1-2 chars from word1 -> word2
|
||||
for shift in (1, 2):
|
||||
if len(word1) <= shift:
|
||||
continue
|
||||
new_w1 = word1[:-shift]
|
||||
new_w2_base = word1[-shift:] + w2_stripped
|
||||
|
||||
w1_ok = self._known(new_w1) or new_w1.lower() in _KNOWN_ABBREVIATIONS
|
||||
w2_ok = self._known(new_w2_base) or new_w2_base.lower() in _KNOWN_ABBREVIATIONS
|
||||
|
||||
if w1_ok and w2_ok:
|
||||
return (new_w1, new_w2_base + w2_punct)
|
||||
|
||||
# Try shifting 1-2 chars from word2 -> word1
|
||||
for shift in (1, 2):
|
||||
if len(w2_stripped) <= shift:
|
||||
continue
|
||||
new_w1 = word1 + w2_stripped[:shift]
|
||||
new_w2_base = w2_stripped[shift:]
|
||||
|
||||
w1_ok = self._known(new_w1) or new_w1.lower() in _KNOWN_ABBREVIATIONS
|
||||
w2_ok = self._known(new_w2_base) or new_w2_base.lower() in _KNOWN_ABBREVIATIONS
|
||||
|
||||
if w1_ok and w2_ok:
|
||||
return (new_w1, new_w2_base + w2_punct)
|
||||
|
||||
return None
|
||||
|
||||
# --- Context-based word split for ambiguous merges ---
|
||||
|
||||
# Patterns where a valid word is actually "a" + adjective/noun
|
||||
_ARTICLE_SPLIT_CANDIDATES = {
|
||||
# word -> (article, remainder) -- only when followed by a compatible word
|
||||
"anew": ("a", "new"),
|
||||
"areal": ("a", "real"),
|
||||
"alive": None, # genuinely one word, never split
|
||||
"alone": None,
|
||||
"aware": None,
|
||||
"alike": None,
|
||||
"apart": None,
|
||||
"aside": None,
|
||||
"above": None,
|
||||
"about": None,
|
||||
"among": None,
|
||||
"along": None,
|
||||
}
|
||||
|
||||
def _try_context_split(self, word: str, next_word: str,
|
||||
prev_word: str) -> Optional[str]:
|
||||
"""Split words like 'anew' -> 'a new' when context indicates a merge.
|
||||
|
||||
Only splits when:
|
||||
- The word is in the split candidates list
|
||||
- The following word makes sense as a noun (for "a + adj + noun" pattern)
|
||||
- OR the word is unknown and can be split into article + known word
|
||||
"""
|
||||
w_lower = word.lower()
|
||||
|
||||
# Check explicit candidates
|
||||
if w_lower in self._ARTICLE_SPLIT_CANDIDATES:
|
||||
split = self._ARTICLE_SPLIT_CANDIDATES[w_lower]
|
||||
if split is None:
|
||||
return None # explicitly marked as "don't split"
|
||||
article, remainder = split
|
||||
# Only split if followed by a word (noun pattern)
|
||||
if next_word and next_word[0].islower():
|
||||
return f"{article} {remainder}"
|
||||
# Also split if remainder + next_word makes a common phrase
|
||||
if next_word and self._known(next_word):
|
||||
return f"{article} {remainder}"
|
||||
|
||||
# Generic: if word starts with 'a' and rest is a known adjective/word
|
||||
if (len(word) >= 4 and word[0].lower() == 'a'
|
||||
and not self._known(word) # only for UNKNOWN words
|
||||
and self._known(word[1:])):
|
||||
return f"a {word[1:]}"
|
||||
|
||||
return None
|
||||
|
||||
# --- Full text correction ---
|
||||
|
||||
def correct_text(self, text: str, lang: str = "en") -> CorrectionResult:
|
||||
"""Correct a full text string (field value).
|
||||
|
||||
Three passes:
|
||||
1. Boundary repair -- fix shifted word boundaries between adjacent tokens
|
||||
2. Context split -- split ambiguous merges (anew -> a new)
|
||||
3. Per-word correction -- spell check individual words
|
||||
"""
|
||||
if not text or not text.strip():
|
||||
return CorrectionResult(text, text, "unknown", False)
|
||||
|
||||
detected = self.detect_text_lang(text) if lang == "auto" else lang
|
||||
effective_lang = detected if detected in ("en", "de") else "en"
|
||||
|
||||
changes: List[str] = []
|
||||
tokens = list(_TOKEN_RE.finditer(text))
|
||||
|
||||
# Extract token list: [(word, separator), ...]
|
||||
token_list: List[List[str]] = [] # [[word, sep], ...]
|
||||
for m in tokens:
|
||||
token_list.append([m.group(1), m.group(2)])
|
||||
|
||||
# --- Pass 1: Boundary repair between adjacent unknown words ---
|
||||
# Import abbreviations for the heuristic below
|
||||
try:
|
||||
from cv_ocr_engines import _KNOWN_ABBREVIATIONS as _ABBREVS
|
||||
except ImportError:
|
||||
_ABBREVS = set()
|
||||
|
||||
for i in range(len(token_list) - 1):
|
||||
w1 = token_list[i][0]
|
||||
w2_raw = token_list[i + 1][0]
|
||||
|
||||
# Skip boundary repair for IPA/bracket content
|
||||
# Brackets may be in the token OR in the adjacent separators
|
||||
sep_before_w1 = token_list[i - 1][1] if i > 0 else ""
|
||||
sep_after_w1 = token_list[i][1]
|
||||
sep_after_w2 = token_list[i + 1][1]
|
||||
has_bracket = (
|
||||
'[' in w1 or ']' in w1 or '[' in w2_raw or ']' in w2_raw
|
||||
or ']' in sep_after_w1 # w1 text was inside [brackets]
|
||||
or '[' in sep_after_w1 # w2 starts a bracket
|
||||
or ']' in sep_after_w2 # w2 text was inside [brackets]
|
||||
or '[' in sep_before_w1 # w1 starts a bracket
|
||||
)
|
||||
if has_bracket:
|
||||
continue
|
||||
|
||||
# Include trailing punct from separator in w2 for abbreviation matching
|
||||
w2_with_punct = w2_raw + token_list[i + 1][1].rstrip(" ")
|
||||
|
||||
# Try boundary repair -- always, even if both words are valid.
|
||||
# Use word-frequency scoring to decide if repair is better.
|
||||
repair = self._try_boundary_repair(w1, w2_with_punct)
|
||||
if not repair and w2_with_punct != w2_raw:
|
||||
repair = self._try_boundary_repair(w1, w2_raw)
|
||||
if repair:
|
||||
new_w1, new_w2_full = repair
|
||||
new_w2_base = new_w2_full.rstrip(".,;:!?")
|
||||
|
||||
# Frequency-based scoring: product of word frequencies
|
||||
# Higher product = more common word pair = better
|
||||
old_freq = self._word_freq(w1) * self._word_freq(w2_raw)
|
||||
new_freq = self._word_freq(new_w1) * self._word_freq(new_w2_base)
|
||||
|
||||
# Abbreviation bonus: if repair produces a known abbreviation
|
||||
has_abbrev = new_w1.lower() in _ABBREVS or new_w2_base.lower() in _ABBREVS
|
||||
if has_abbrev:
|
||||
# Accept abbreviation repair ONLY if at least one of the
|
||||
# original words is rare/unknown (prevents "Can I" -> "Ca nI"
|
||||
# where both original words are common and correct).
|
||||
RARE_THRESHOLD = 1e-6
|
||||
orig_both_common = (
|
||||
self._word_freq(w1) > RARE_THRESHOLD
|
||||
and self._word_freq(w2_raw) > RARE_THRESHOLD
|
||||
)
|
||||
if not orig_both_common:
|
||||
new_freq = max(new_freq, old_freq * 10)
|
||||
else:
|
||||
has_abbrev = False # both originals common -> don't trust
|
||||
|
||||
# Accept if repair produces a more frequent word pair
|
||||
# (threshold: at least 5x more frequent to avoid false positives)
|
||||
if new_freq > old_freq * 5:
|
||||
new_w2_punct = new_w2_full[len(new_w2_base):]
|
||||
changes.append(f"{w1} {w2_raw}\u2192{new_w1} {new_w2_base}")
|
||||
token_list[i][0] = new_w1
|
||||
token_list[i + 1][0] = new_w2_base
|
||||
if new_w2_punct:
|
||||
token_list[i + 1][1] = new_w2_punct + token_list[i + 1][1].lstrip(".,;:!?")
|
||||
|
||||
# --- Pass 2: Context split (anew -> a new) ---
|
||||
expanded: List[List[str]] = []
|
||||
for i, (word, sep) in enumerate(token_list):
|
||||
next_word = token_list[i + 1][0] if i + 1 < len(token_list) else ""
|
||||
prev_word = token_list[i - 1][0] if i > 0 else ""
|
||||
split = self._try_context_split(word, next_word, prev_word)
|
||||
if split and split != word:
|
||||
changes.append(f"{word}\u2192{split}")
|
||||
expanded.append([split, sep])
|
||||
else:
|
||||
expanded.append([word, sep])
|
||||
token_list = expanded
|
||||
|
||||
# --- Pass 3: Per-word correction ---
|
||||
parts: List[str] = []
|
||||
|
||||
# Preserve any leading text before the first token match
|
||||
first_start = tokens[0].start() if tokens else 0
|
||||
if first_start > 0:
|
||||
parts.append(text[:first_start])
|
||||
|
||||
for i, (word, sep) in enumerate(token_list):
|
||||
# Skip words inside IPA brackets (brackets land in separators)
|
||||
prev_sep = token_list[i - 1][1] if i > 0 else ""
|
||||
if '[' in prev_sep or ']' in sep:
|
||||
parts.append(word)
|
||||
parts.append(sep)
|
||||
continue
|
||||
|
||||
next_word = token_list[i + 1][0] if i + 1 < len(token_list) else ""
|
||||
prev_word = token_list[i - 1][0] if i > 0 else ""
|
||||
|
||||
correction = self.correct_word(
|
||||
word, lang=effective_lang,
|
||||
prev_word=prev_word, next_word=next_word,
|
||||
)
|
||||
if correction and correction != word:
|
||||
changes.append(f"{word}\u2192{correction}")
|
||||
parts.append(correction)
|
||||
else:
|
||||
parts.append(word)
|
||||
parts.append(sep)
|
||||
|
||||
# Append any trailing text
|
||||
last_end = tokens[-1].end() if tokens else 0
|
||||
if last_end < len(text):
|
||||
parts.append(text[last_end:])
|
||||
|
||||
corrected = "".join(parts)
|
||||
return CorrectionResult(
|
||||
original=text,
|
||||
corrected=corrected,
|
||||
lang_detected=detected,
|
||||
changed=corrected != text,
|
||||
changes=changes,
|
||||
)
|
||||
|
||||
# --- Vocabulary entry correction ---
|
||||
|
||||
def correct_vocab_entry(self, english: str, german: str,
|
||||
example: str = "") -> Dict[str, CorrectionResult]:
|
||||
"""Correct a full vocabulary entry (EN + DE + example).
|
||||
|
||||
Uses column position to determine language -- the most reliable signal.
|
||||
"""
|
||||
results = {}
|
||||
results["english"] = self.correct_text(english, lang="en")
|
||||
results["german"] = self.correct_text(german, lang="de")
|
||||
if example:
|
||||
# For examples, auto-detect language
|
||||
results["example"] = self.correct_text(example, lang="auto")
|
||||
return results
|
||||
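A minimal sketch of the three-pass pipeline in action (assuming pyspellchecker with EN and DE dictionaries; exact output varies with dictionary frequencies):

    from smart_spell_text import SmartSpellChecker

    checker = SmartSpellChecker()

    # Pass 2 (context split): "anew" followed by a lowercase word is split.
    result = checker.correct_text("He bought anew car", lang="en")
    print(result.corrected)  # expected: "He bought a new car"

    # Column position decides the language per field, so "Tur" is checked
    # against the DE dictionary (umlaut repair) and "h0use" against EN.
    entry = checker.correct_vocab_entry(english="h0use", german="Tur")
    print(entry["english"].corrected)  # expected: "house"
    print(entry["german"].corrected)   # expected: "Tür", if known in the DE dict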
@@ -1,602 +1,29 @@
"""
|
||||
Mobile Upload API for Klausur-Service
|
||||
Mobile Upload API — barrel re-export.
|
||||
|
||||
All implementation split into:
|
||||
upload_api_chunked — chunked upload endpoints (init, chunk, finalize, simple, status, cancel, list)
|
||||
upload_api_mobile — mobile HTML upload page
|
||||
|
||||
Provides chunked upload endpoints for large PDF files (100MB+) from mobile devices.
|
||||
DSGVO-konform: Data stays local in WLAN, no external transmission.
|
||||
"""
|
||||
|
||||
import os
|
||||
import uuid
|
||||
import shutil
|
||||
import hashlib
|
||||
from pathlib import Path
|
||||
from datetime import datetime, timezone
|
||||
from typing import Dict, Optional
|
||||
|
||||
from fastapi import APIRouter, HTTPException, UploadFile, File, Form
|
||||
from fastapi.responses import HTMLResponse
|
||||
from pydantic import BaseModel
|
||||
|
||||
# Configuration
|
||||
UPLOAD_DIR = Path(os.getenv("UPLOAD_DIR", "/app/uploads"))
|
||||
CHUNK_DIR = Path(os.getenv("CHUNK_DIR", "/app/chunks"))
|
||||
EH_UPLOAD_DIR = Path(os.getenv("EH_UPLOAD_DIR", "/app/eh-uploads"))
|
||||
|
||||
# Ensure directories exist
|
||||
UPLOAD_DIR.mkdir(parents=True, exist_ok=True)
|
||||
CHUNK_DIR.mkdir(parents=True, exist_ok=True)
|
||||
EH_UPLOAD_DIR.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# In-memory storage for upload sessions (for simplicity)
|
||||
# In production, use Redis or database
|
||||
_upload_sessions: Dict[str, dict] = {}
|
||||
|
||||
router = APIRouter(prefix="/api/v1/upload", tags=["Mobile Upload"])
|
||||
|
||||
|
||||
class InitUploadRequest(BaseModel):
|
||||
filename: str
|
||||
filesize: int
|
||||
chunks: int
|
||||
destination: str = "klausur" # "klausur" or "rag"
|
||||
|
||||
|
||||
class InitUploadResponse(BaseModel):
|
||||
upload_id: str
|
||||
chunk_size: int
|
||||
total_chunks: int
|
||||
message: str
|
||||
|
||||
|
||||
class ChunkUploadResponse(BaseModel):
|
||||
upload_id: str
|
||||
chunk_index: int
|
||||
received: bool
|
||||
chunks_received: int
|
||||
total_chunks: int
|
||||
|
||||
|
||||
class FinalizeResponse(BaseModel):
|
||||
upload_id: str
|
||||
filename: str
|
||||
filepath: str
|
||||
filesize: int
|
||||
checksum: str
|
||||
message: str
|
||||
|
||||
|
||||
@router.post("/init", response_model=InitUploadResponse)
|
||||
async def init_upload(request: InitUploadRequest):
|
||||
"""
|
||||
Initialize a chunked upload session.
|
||||
|
||||
Returns an upload_id that must be used for subsequent chunk uploads.
|
||||
"""
|
||||
upload_id = str(uuid.uuid4())
|
||||
|
||||
# Create session directory
|
||||
session_dir = CHUNK_DIR / upload_id
|
||||
session_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Store session info
|
||||
_upload_sessions[upload_id] = {
|
||||
"filename": request.filename,
|
||||
"filesize": request.filesize,
|
||||
"total_chunks": request.chunks,
|
||||
"received_chunks": set(),
|
||||
"destination": request.destination,
|
||||
"session_dir": str(session_dir),
|
||||
"created_at": datetime.now(timezone.utc).isoformat(),
|
||||
}
|
||||
|
||||
return InitUploadResponse(
|
||||
upload_id=upload_id,
|
||||
chunk_size=5 * 1024 * 1024, # 5 MB
|
||||
total_chunks=request.chunks,
|
||||
message="Upload-Session erstellt"
|
||||
)
|
||||
|
||||
|
||||
@router.post("/chunk", response_model=ChunkUploadResponse)
|
||||
async def upload_chunk(
|
||||
chunk: UploadFile = File(...),
|
||||
upload_id: str = Form(...),
|
||||
chunk_index: int = Form(...)
|
||||
):
|
||||
"""
|
||||
Upload a single chunk of a file.
|
||||
|
||||
Chunks are stored temporarily until finalize is called.
|
||||
"""
|
||||
if upload_id not in _upload_sessions:
|
||||
raise HTTPException(status_code=404, detail="Upload-Session nicht gefunden")
|
||||
|
||||
session = _upload_sessions[upload_id]
|
||||
|
||||
if chunk_index < 0 or chunk_index >= session["total_chunks"]:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Ungueltiger Chunk-Index: {chunk_index}"
|
||||
)
|
||||
|
||||
# Save chunk
|
||||
chunk_path = Path(session["session_dir"]) / f"chunk_{chunk_index:05d}"
|
||||
|
||||
with open(chunk_path, "wb") as f:
|
||||
content = await chunk.read()
|
||||
f.write(content)
|
||||
|
||||
# Track received chunks
|
||||
session["received_chunks"].add(chunk_index)
|
||||
|
||||
return ChunkUploadResponse(
|
||||
upload_id=upload_id,
|
||||
chunk_index=chunk_index,
|
||||
received=True,
|
||||
chunks_received=len(session["received_chunks"]),
|
||||
total_chunks=session["total_chunks"]
|
||||
)
|
||||
|
||||
|
||||
@router.post("/finalize", response_model=FinalizeResponse)
|
||||
async def finalize_upload(upload_id: str = Form(...)):
|
||||
"""
|
||||
Finalize the upload by combining all chunks into a single file.
|
||||
|
||||
Validates that all chunks were received and calculates checksum.
|
||||
"""
|
||||
if upload_id not in _upload_sessions:
|
||||
raise HTTPException(status_code=404, detail="Upload-Session nicht gefunden")
|
||||
|
||||
session = _upload_sessions[upload_id]
|
||||
|
||||
# Check if all chunks received
|
||||
if len(session["received_chunks"]) != session["total_chunks"]:
|
||||
missing = session["total_chunks"] - len(session["received_chunks"])
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Nicht alle Chunks empfangen. Fehlend: {missing}"
|
||||
)
|
||||
|
||||
# Determine destination directory
|
||||
if session["destination"] == "rag":
|
||||
dest_dir = EH_UPLOAD_DIR
|
||||
else:
|
||||
dest_dir = UPLOAD_DIR
|
||||
|
||||
# Generate unique filename
|
||||
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||
safe_filename = session["filename"].replace(" ", "_")
|
||||
final_filename = f"{timestamp}_{safe_filename}"
|
||||
final_path = dest_dir / final_filename
|
||||
|
||||
# Combine chunks
|
||||
hasher = hashlib.sha256()
|
||||
total_size = 0
|
||||
|
||||
with open(final_path, "wb") as outfile:
|
||||
for i in range(session["total_chunks"]):
|
||||
chunk_path = Path(session["session_dir"]) / f"chunk_{i:05d}"
|
||||
|
||||
if not chunk_path.exists():
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"Chunk {i} nicht gefunden"
|
||||
)
|
||||
|
||||
with open(chunk_path, "rb") as infile:
|
||||
data = infile.read()
|
||||
outfile.write(data)
|
||||
hasher.update(data)
|
||||
total_size += len(data)
|
||||
|
||||
# Clean up chunks
|
||||
shutil.rmtree(session["session_dir"], ignore_errors=True)
|
||||
del _upload_sessions[upload_id]
|
||||
|
||||
checksum = hasher.hexdigest()
|
||||
|
||||
return FinalizeResponse(
|
||||
upload_id=upload_id,
|
||||
filename=final_filename,
|
||||
filepath=str(final_path),
|
||||
filesize=total_size,
|
||||
checksum=checksum,
|
||||
message="Upload erfolgreich abgeschlossen"
|
||||
)
|
||||
|
||||
|
||||
@router.post("/simple")
|
||||
async def simple_upload(
|
||||
file: UploadFile = File(...),
|
||||
destination: str = Form("klausur")
|
||||
):
|
||||
"""
|
||||
Simple single-request upload for smaller files (<10MB).
|
||||
|
||||
For larger files, use the chunked upload endpoints.
|
||||
"""
|
||||
# Determine destination directory
|
||||
if destination == "rag":
|
||||
dest_dir = EH_UPLOAD_DIR
|
||||
else:
|
||||
dest_dir = UPLOAD_DIR
|
||||
|
||||
# Generate unique filename
|
||||
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||
safe_filename = file.filename.replace(" ", "_") if file.filename else "upload.pdf"
|
||||
final_filename = f"{timestamp}_{safe_filename}"
|
||||
final_path = dest_dir / final_filename
|
||||
|
||||
# Calculate checksum while writing
|
||||
hasher = hashlib.sha256()
|
||||
total_size = 0
|
||||
|
||||
with open(final_path, "wb") as f:
|
||||
while True:
|
||||
chunk = await file.read(1024 * 1024) # Read 1MB at a time
|
||||
if not chunk:
|
||||
break
|
||||
f.write(chunk)
|
||||
hasher.update(chunk)
|
||||
total_size += len(chunk)
|
||||
|
||||
return {
|
||||
"filename": final_filename,
|
||||
"filepath": str(final_path),
|
||||
"filesize": total_size,
|
||||
"checksum": hasher.hexdigest(),
|
||||
"message": "Upload erfolgreich"
|
||||
}
|
||||
|
||||
|
||||
@router.get("/status/{upload_id}")
|
||||
async def get_upload_status(upload_id: str):
|
||||
"""
|
||||
Get the status of an ongoing upload.
|
||||
"""
|
||||
if upload_id not in _upload_sessions:
|
||||
raise HTTPException(status_code=404, detail="Upload-Session nicht gefunden")
|
||||
|
||||
session = _upload_sessions[upload_id]
|
||||
|
||||
return {
|
||||
"upload_id": upload_id,
|
||||
"filename": session["filename"],
|
||||
"total_chunks": session["total_chunks"],
|
||||
"received_chunks": len(session["received_chunks"]),
|
||||
"progress_percent": round(
|
||||
len(session["received_chunks"]) / session["total_chunks"] * 100, 1
|
||||
),
|
||||
"destination": session["destination"],
|
||||
"created_at": session["created_at"]
|
||||
}
|
||||
|
||||
|
||||
@router.delete("/cancel/{upload_id}")
|
||||
async def cancel_upload(upload_id: str):
|
||||
"""
|
||||
Cancel an ongoing upload and clean up temporary files.
|
||||
"""
|
||||
if upload_id not in _upload_sessions:
|
||||
raise HTTPException(status_code=404, detail="Upload-Session nicht gefunden")
|
||||
|
||||
session = _upload_sessions[upload_id]
|
||||
|
||||
# Clean up chunks
|
||||
shutil.rmtree(session["session_dir"], ignore_errors=True)
|
||||
del _upload_sessions[upload_id]
|
||||
|
||||
return {"message": "Upload abgebrochen", "upload_id": upload_id}
|
||||
|
||||
|
||||
@router.get("/list")
|
||||
async def list_uploads(destination: str = "klausur"):
|
||||
"""
|
||||
List all uploaded files in the specified destination.
|
||||
"""
|
||||
if destination == "rag":
|
||||
dest_dir = EH_UPLOAD_DIR
|
||||
else:
|
||||
dest_dir = UPLOAD_DIR
|
||||
|
||||
files = []
|
||||
|
||||
for f in dest_dir.iterdir():
|
||||
if f.is_file() and f.suffix.lower() == ".pdf":
|
||||
stat = f.stat()
|
||||
files.append({
|
||||
"filename": f.name,
|
||||
"size": stat.st_size,
|
||||
"modified": datetime.fromtimestamp(stat.st_mtime).isoformat(),
|
||||
})
|
||||
|
||||
files.sort(key=lambda x: x["modified"], reverse=True)
|
||||
|
||||
return {
|
||||
"destination": destination,
|
||||
"count": len(files),
|
||||
"files": files[:50] # Limit to 50 most recent
|
||||
}
|
||||
|
||||
|
||||
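
# The chunked protocol above (init -> chunk x N -> finalize) can be driven by
# any HTTP client. A minimal client sketch using the requests library; the
# host, file name, and helper name below are hypothetical, and the 5 MB chunk
# size mirrors what /init reports.
def _example_chunked_client():  # hypothetical helper, illustration only
    import requests

    base = "http://192.168.0.10:8000/api/v1/upload"  # assumed local server
    chunk_size = 5 * 1024 * 1024

    data = open("klausur.pdf", "rb").read()  # hypothetical file
    parts = [data[i:i + chunk_size] for i in range(0, len(data), chunk_size)]

    session = requests.post(f"{base}/init", json={
        "filename": "klausur.pdf", "filesize": len(data),
        "chunks": len(parts), "destination": "klausur",
    }).json()

    for idx, part in enumerate(parts):
        requests.post(f"{base}/chunk",
                      files={"chunk": part},
                      data={"upload_id": session["upload_id"], "chunk_index": idx})

    final = requests.post(f"{base}/finalize",
                          data={"upload_id": session["upload_id"]}).json()
    print(final["checksum"])  # sha256 of the reassembled file
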
@router.get("/mobile", response_class=HTMLResponse)
|
||||
async def mobile_upload_page():
|
||||
"""
|
||||
Serve the mobile upload page directly from the klausur-service.
|
||||
This allows mobile devices to upload without needing the Next.js website.
|
||||
"""
|
||||
from fastapi.responses import HTMLResponse
|
||||
|
||||
html_content = '''<!DOCTYPE html>
|
||||
<html lang="de">
|
||||
<head>
|
||||
<meta charset="UTF-8">
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1.0, maximum-scale=1.0, user-scalable=no">
|
||||
<meta name="apple-mobile-web-app-capable" content="yes">
|
||||
<title>BreakPilot Upload</title>
|
||||
<style>
|
||||
* { margin: 0; padding: 0; box-sizing: border-box; }
|
||||
body {
|
||||
font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, sans-serif;
|
||||
background: linear-gradient(135deg, #1e293b 0%, #0f172a 100%);
|
||||
color: white;
|
||||
min-height: 100vh;
|
||||
padding: 16px;
|
||||
}
|
||||
.header {
|
||||
display: flex;
|
||||
justify-content: space-between;
|
||||
align-items: center;
|
||||
padding: 16px;
|
||||
border-bottom: 1px solid #334155;
|
||||
margin-bottom: 24px;
|
||||
}
|
||||
.header h1 { font-size: 20px; color: #60a5fa; }
|
||||
.badge { font-size: 10px; background: #1e293b; padding: 4px 8px; border-radius: 4px; color: #94a3b8; }
|
||||
.destination-selector {
|
||||
display: flex;
|
||||
gap: 8px;
|
||||
margin-bottom: 24px;
|
||||
}
|
||||
.dest-btn {
|
||||
flex: 1;
|
||||
padding: 14px;
|
||||
border: none;
|
||||
border-radius: 10px;
|
||||
font-size: 14px;
|
||||
font-weight: 600;
|
||||
cursor: pointer;
|
||||
transition: all 0.2s;
|
||||
}
|
||||
.dest-btn.active-klausur { background: #2563eb; color: white; box-shadow: 0 4px 15px rgba(37, 99, 235, 0.3); }
|
||||
.dest-btn.active-rag { background: #7c3aed; color: white; box-shadow: 0 4px 15px rgba(124, 58, 237, 0.3); }
|
||||
.dest-btn:not(.active-klausur):not(.active-rag) { background: #1e293b; color: #94a3b8; }
|
||||
.upload-zone {
|
||||
border: 2px dashed #475569;
|
||||
border-radius: 16px;
|
||||
padding: 40px 20px;
|
||||
text-align: center;
|
||||
margin-bottom: 24px;
|
||||
transition: all 0.2s;
|
||||
position: relative;
|
||||
}
|
||||
.upload-zone.dragover { border-color: #60a5fa; background: rgba(96, 165, 250, 0.1); transform: scale(1.02); }
|
||||
.upload-zone input[type="file"] {
|
||||
position: absolute;
|
||||
inset: 0;
|
||||
opacity: 0;
|
||||
cursor: pointer;
|
||||
}
|
||||
.upload-icon {
|
||||
width: 64px;
|
||||
height: 64px;
|
||||
background: #334155;
|
||||
border-radius: 50%;
|
||||
display: flex;
|
||||
align-items: center;
|
||||
justify-content: center;
|
||||
margin: 0 auto 16px;
|
||||
font-size: 28px;
|
||||
}
|
||||
.upload-title { font-size: 18px; margin-bottom: 8px; }
|
||||
.upload-subtitle { font-size: 14px; color: #94a3b8; margin-bottom: 16px; }
|
||||
.upload-hint { font-size: 12px; color: #64748b; }
|
||||
.file-list { margin-bottom: 24px; }
|
||||
.file-item {
|
||||
background: #1e293b;
|
||||
border-radius: 12px;
|
||||
padding: 16px;
|
||||
margin-bottom: 12px;
|
||||
}
|
||||
.file-item.error { border: 2px solid rgba(239, 68, 68, 0.5); }
|
||||
.file-item.complete { border: 2px solid rgba(34, 197, 94, 0.3); }
|
||||
.file-header { display: flex; justify-content: space-between; align-items: flex-start; margin-bottom: 8px; }
|
||||
.file-name { font-weight: 500; word-break: break-all; }
|
||||
.file-size { font-size: 14px; color: #94a3b8; }
|
||||
.remove-btn { background: none; border: none; color: #94a3b8; font-size: 20px; cursor: pointer; padding: 4px; }
|
||||
.progress-bar { height: 6px; background: #334155; border-radius: 3px; overflow: hidden; margin-top: 12px; }
|
||||
.progress-fill { height: 100%; background: linear-gradient(90deg, #3b82f6, #60a5fa); transition: width 0.3s; }
|
||||
.progress-text { font-size: 12px; color: #94a3b8; margin-top: 4px; }
|
||||
.status-complete { display: flex; align-items: center; gap: 8px; color: #22c55e; font-size: 14px; margin-top: 12px; }
|
||||
.status-error { display: flex; align-items: center; gap: 8px; color: #ef4444; font-size: 14px; margin-top: 12px; }
|
||||
.info-box {
|
||||
background: rgba(30, 41, 59, 0.5);
|
||||
border-radius: 12px;
|
||||
padding: 16px;
|
||||
font-size: 14px;
|
||||
color: #94a3b8;
|
||||
}
|
||||
.info-box h3 { color: #cbd5e1; margin-bottom: 8px; font-size: 14px; }
|
||||
.info-box ul { padding-left: 20px; }
|
||||
.info-box li { margin-bottom: 4px; }
|
||||
.server-info { text-align: center; font-size: 12px; color: #64748b; margin-top: 16px; }
|
||||
.stats { display: flex; justify-content: space-between; font-size: 14px; color: #94a3b8; padding: 0 8px; margin-bottom: 12px; }
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
<header class="header">
|
||||
<h1>BreakPilot Upload</h1>
|
||||
<span class="badge">DSGVO-konform</span>
|
||||
</header>
|
||||
|
||||
<div class="destination-selector">
|
||||
<button class="dest-btn active-klausur" id="btn-klausur" onclick="setDestination('klausur')">Klausuren</button>
|
||||
<button class="dest-btn" id="btn-rag" onclick="setDestination('rag')">Erwartungshorizonte</button>
|
||||
</div>
|
||||
|
||||
<div class="upload-zone" id="upload-zone">
|
||||
<input type="file" accept=".pdf" multiple onchange="handleFiles(this.files)">
|
||||
<div class="upload-icon">☁</div>
|
||||
<div class="upload-title">PDF-Dateien hochladen</div>
|
||||
<div class="upload-subtitle">Tippen zum Auswaehlen oder hierher ziehen</div>
|
||||
<div class="upload-hint">Grosse Dateien bis 200 MB werden automatisch in Teilen hochgeladen</div>
|
||||
</div>
|
||||
|
||||
<div class="stats" id="stats" style="display: none;">
|
||||
<span id="completed-count">0 von 0 fertig</span>
|
||||
<span id="total-size">0 B gesamt</span>
|
||||
</div>
|
||||
|
||||
<div class="file-list" id="file-list"></div>
|
||||
|
||||
<div class="info-box">
|
||||
<h3>Hinweise:</h3>
|
||||
<ul>
|
||||
<li>Die Dateien werden lokal im WLAN uebertragen</li>
|
||||
<li>Keine Daten werden ins Internet gesendet</li>
|
||||
<li>Unterstuetzte Formate: PDF</li>
|
||||
</ul>
|
||||
</div>
|
||||
|
||||
<div class="server-info" id="server-info">Server: wird ermittelt...</div>
|
||||
|
||||
<script>
|
||||
const CHUNK_SIZE = 5 * 1024 * 1024;
|
||||
let destination = 'klausur';
|
||||
let files = [];
|
||||
const serverUrl = window.location.origin;
|
||||
document.getElementById('server-info').textContent = 'Server: ' + serverUrl;
|
||||
|
||||
function setDestination(dest) {
|
||||
destination = dest;
|
||||
document.querySelectorAll('.dest-btn').forEach(btn => {
|
||||
btn.classList.remove('active-klausur', 'active-rag');
|
||||
});
|
||||
if (dest === 'klausur') {
|
||||
document.getElementById('btn-klausur').classList.add('active-klausur');
|
||||
} else {
|
||||
document.getElementById('btn-rag').classList.add('active-rag');
|
||||
}
|
||||
}
|
||||
|
||||
function formatSize(bytes) {
|
||||
if (bytes === 0) return '0 B';
|
||||
const k = 1024;
|
||||
const sizes = ['B', 'KB', 'MB', 'GB'];
|
||||
const i = Math.floor(Math.log(bytes) / Math.log(k));
|
||||
return parseFloat((bytes / Math.pow(k, i)).toFixed(1)) + ' ' + sizes[i];
|
||||
}
|
||||
|
||||
function updateStats() {
|
||||
const completed = files.filter(f => f.status === 'complete').length;
|
||||
const total = files.reduce((sum, f) => sum + f.size, 0);
|
||||
document.getElementById('completed-count').textContent = completed + ' von ' + files.length + ' fertig';
|
||||
document.getElementById('total-size').textContent = formatSize(total) + ' gesamt';
|
||||
document.getElementById('stats').style.display = files.length > 0 ? 'flex' : 'none';
|
||||
}
|
||||
|
||||
function renderFiles() {
|
||||
const list = document.getElementById('file-list');
|
||||
list.innerHTML = files.map(f => {
|
||||
let statusHtml = '';
|
||||
if (f.status === 'uploading' || f.status === 'pending') {
|
||||
statusHtml = '<div class="progress-bar"><div class="progress-fill" style="width: ' + f.progress + '%"></div></div><div class="progress-text">' + f.progress + '% hochgeladen</div>';
|
||||
} else if (f.status === 'complete') {
|
||||
statusHtml = '<div class="status-complete">✓ Erfolgreich hochgeladen</div>';
|
||||
} else if (f.status === 'error') {
|
||||
statusHtml = '<div class="status-error">⚠ ' + (f.error || 'Fehler beim Hochladen') + '</div>';
|
||||
}
|
||||
return '<div class="file-item ' + f.status + '"><div class="file-header"><div><div class="file-name">' + f.name + '</div><div class="file-size">' + formatSize(f.size) + '</div></div><button class="remove-btn" onclick="removeFile(\\'' + f.id + '\\')">×</button></div>' + statusHtml + '</div>';
|
||||
}).join('');
|
||||
updateStats();
|
||||
}
|
||||
|
||||
function removeFile(id) {
|
||||
files = files.filter(f => f.id !== id);
|
||||
renderFiles();
|
||||
}
|
||||
|
||||
async function uploadFile(file, fileId) {
|
||||
const updateProgress = (progress) => {
|
||||
const f = files.find(f => f.id === fileId);
|
||||
if (f) { f.progress = progress; renderFiles(); }
|
||||
};
|
||||
const setStatus = (status, error) => {
|
||||
const f = files.find(f => f.id === fileId);
|
||||
if (f) { f.status = status; if (error) f.error = error; renderFiles(); }
|
||||
};
|
||||
|
||||
try {
|
||||
setStatus('uploading');
|
||||
if (file.size > 10 * 1024 * 1024) {
|
||||
// Chunked upload
|
||||
const totalChunks = Math.ceil(file.size / CHUNK_SIZE);
|
||||
const initRes = await fetch(serverUrl + '/api/v1/upload/init', {
|
||||
method: 'POST',
|
||||
headers: { 'Content-Type': 'application/json' },
|
||||
body: JSON.stringify({ filename: file.name, filesize: file.size, chunks: totalChunks, destination: destination })
|
||||
});
|
||||
if (!initRes.ok) throw new Error('Konnte Upload nicht starten');
|
||||
const { upload_id } = await initRes.json();
|
||||
|
||||
for (let i = 0; i < totalChunks; i++) {
|
||||
const start = i * CHUNK_SIZE;
|
||||
const end = Math.min(start + CHUNK_SIZE, file.size);
|
||||
const chunk = file.slice(start, end);
|
||||
const formData = new FormData();
|
||||
formData.append('chunk', chunk);
|
||||
formData.append('upload_id', upload_id);
|
||||
formData.append('chunk_index', i.toString());
|
||||
const chunkRes = await fetch(serverUrl + '/api/v1/upload/chunk', { method: 'POST', body: formData });
|
||||
if (!chunkRes.ok) throw new Error('Fehler bei Teil ' + (i + 1));
|
||||
updateProgress(Math.round(((i + 1) / totalChunks) * 100));
|
||||
}
|
||||
const finalizeForm = new FormData();
|
||||
finalizeForm.append('upload_id', upload_id);
|
||||
const finalRes = await fetch(serverUrl + '/api/v1/upload/finalize', { method: 'POST', body: finalizeForm });
|
||||
if (!finalRes.ok) throw new Error('Fehler beim Abschliessen');
|
||||
} else {
|
||||
// Simple upload
|
||||
const formData = new FormData();
|
||||
formData.append('file', file);
|
||||
formData.append('destination', destination);
|
||||
const res = await fetch(serverUrl + '/api/v1/upload/simple', { method: 'POST', body: formData });
|
||||
if (!res.ok) throw new Error('Upload fehlgeschlagen');
|
||||
updateProgress(100);
|
||||
}
|
||||
setStatus('complete');
|
||||
} catch (e) {
|
||||
setStatus('error', e.message);
|
||||
}
|
||||
}
|
||||
|
||||
function handleFiles(fileList) {
|
||||
const newFiles = Array.from(fileList).filter(f => f.type === 'application/pdf');
|
||||
newFiles.forEach(file => {
|
||||
const id = Math.random().toString(36).substr(2, 9);
|
||||
files.push({ id, name: file.name, size: file.size, progress: 0, status: 'pending', file });
|
||||
renderFiles();
|
||||
uploadFile(file, id);
|
||||
});
|
||||
}
|
||||
|
||||
// Drag & Drop
|
||||
const zone = document.getElementById('upload-zone');
|
||||
zone.addEventListener('dragover', e => { e.preventDefault(); zone.classList.add('dragover'); });
|
||||
zone.addEventListener('dragleave', e => { e.preventDefault(); zone.classList.remove('dragover'); });
|
||||
zone.addEventListener('drop', e => { e.preventDefault(); zone.classList.remove('dragover'); handleFiles(e.dataTransfer.files); });
|
||||
</script>
|
||||
</body>
|
||||
</html>'''
|
||||
return HTMLResponse(content=html_content)

from fastapi import APIRouter

from upload_api_chunked import (  # noqa: F401
    router as _chunked_router,
    UPLOAD_DIR,
    CHUNK_DIR,
    EH_UPLOAD_DIR,
    _upload_sessions,
    InitUploadRequest,
    InitUploadResponse,
    ChunkUploadResponse,
    FinalizeResponse,
)
from upload_api_mobile import router as _mobile_router  # noqa: F401

# Composite router that includes both sub-routers
router = APIRouter()
router.include_router(_chunked_router)
router.include_router(_mobile_router)
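
The barrel keeps the public import path stable: existing code keeps importing `router` from upload_api and now gets both sub-routers. A minimal sketch of how the composite router might be mounted, assuming a main module and app object that are not shown in this commit:

from fastapi import FastAPI

from upload_api import router as upload_router  # barrel re-export

app = FastAPI(title="klausur-service")
# One include registers the /api/v1/upload/* endpoints from both
# upload_api_chunked and upload_api_mobile.
app.include_router(upload_router)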
klausur-service/backend/upload_api_chunked.py (new file, 320 lines)
@@ -0,0 +1,320 @@
"""
|
||||
Chunked Upload API — init, chunk, finalize, simple upload, status, cancel, list.
|
||||
|
||||
Extracted from upload_api.py for modularity.
|
||||
|
||||
DSGVO-konform: Data stays local in WLAN, no external transmission.
|
||||
"""
|
||||
|
||||
import os
|
||||
import uuid
|
||||
import shutil
|
||||
import hashlib
|
||||
from pathlib import Path
|
||||
from datetime import datetime, timezone
|
||||
from typing import Dict, Optional
|
||||
|
||||
from fastapi import APIRouter, HTTPException, UploadFile, File, Form
|
||||
from pydantic import BaseModel
|
||||
|
||||
# Configuration
|
||||
UPLOAD_DIR = Path(os.getenv("UPLOAD_DIR", "/app/uploads"))
|
||||
CHUNK_DIR = Path(os.getenv("CHUNK_DIR", "/app/chunks"))
|
||||
EH_UPLOAD_DIR = Path(os.getenv("EH_UPLOAD_DIR", "/app/eh-uploads"))
|
||||
|
||||
# Ensure directories exist
|
||||
UPLOAD_DIR.mkdir(parents=True, exist_ok=True)
|
||||
CHUNK_DIR.mkdir(parents=True, exist_ok=True)
|
||||
EH_UPLOAD_DIR.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# In-memory storage for upload sessions (for simplicity)
|
||||
# In production, use Redis or a database
|
||||
_upload_sessions: Dict[str, dict] = {}
|
||||
|
||||
router = APIRouter(prefix="/api/v1/upload", tags=["Mobile Upload"])
|
||||
|
||||
|
||||
class InitUploadRequest(BaseModel):
|
||||
filename: str
|
||||
filesize: int
|
||||
chunks: int
|
||||
destination: str = "klausur" # "klausur" or "rag"
|
||||
|
||||
|
||||
class InitUploadResponse(BaseModel):
|
||||
upload_id: str
|
||||
chunk_size: int
|
||||
total_chunks: int
|
||||
message: str
|
||||
|
||||
|
||||
class ChunkUploadResponse(BaseModel):
|
||||
upload_id: str
|
||||
chunk_index: int
|
||||
received: bool
|
||||
chunks_received: int
|
||||
total_chunks: int
|
||||
|
||||
|
||||
class FinalizeResponse(BaseModel):
|
||||
upload_id: str
|
||||
filename: str
|
||||
filepath: str
|
||||
filesize: int
|
||||
checksum: str
|
||||
message: str
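
For illustration, a payload that matches InitUploadRequest (values are made up); the response fields mirror InitUploadResponse:

init_request = {
    "filename": "klausur_10b.pdf",   # illustrative example values
    "filesize": 52_428_800,          # 50 MB -> 10 chunks of 5 MB
    "chunks": 10,
    "destination": "klausur",        # or "rag" for Erwartungshorizont uploads
}
# Server reply (InitUploadResponse): upload_id, chunk_size=5242880,
# total_chunks=10, message="Upload-Session erstellt"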
|
||||
|
||||
|
||||
@router.post("/init", response_model=InitUploadResponse)
|
||||
async def init_upload(request: InitUploadRequest):
|
||||
"""
|
||||
Initialize a chunked upload session.
|
||||
|
||||
Returns an upload_id that must be used for subsequent chunk uploads.
|
||||
"""
|
||||
upload_id = str(uuid.uuid4())
|
||||
|
||||
# Create session directory
|
||||
session_dir = CHUNK_DIR / upload_id
|
||||
session_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Store session info
|
||||
_upload_sessions[upload_id] = {
|
||||
"filename": request.filename,
|
||||
"filesize": request.filesize,
|
||||
"total_chunks": request.chunks,
|
||||
"received_chunks": set(),
|
||||
"destination": request.destination,
|
||||
"session_dir": str(session_dir),
|
||||
"created_at": datetime.now(timezone.utc).isoformat(),
|
||||
}
|
||||
|
||||
return InitUploadResponse(
|
||||
upload_id=upload_id,
|
||||
chunk_size=5 * 1024 * 1024, # 5 MB
|
||||
total_chunks=request.chunks,
|
||||
message="Upload-Session erstellt"
|
||||
)
|
||||
|
||||
|
||||
@router.post("/chunk", response_model=ChunkUploadResponse)
|
||||
async def upload_chunk(
|
||||
chunk: UploadFile = File(...),
|
||||
upload_id: str = Form(...),
|
||||
chunk_index: int = Form(...)
|
||||
):
|
||||
"""
|
||||
Upload a single chunk of a file.
|
||||
|
||||
Chunks are stored temporarily until finalize is called.
|
||||
"""
|
||||
if upload_id not in _upload_sessions:
|
||||
raise HTTPException(status_code=404, detail="Upload-Session nicht gefunden")
|
||||
|
||||
session = _upload_sessions[upload_id]
|
||||
|
||||
if chunk_index < 0 or chunk_index >= session["total_chunks"]:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Ungueltiger Chunk-Index: {chunk_index}"
|
||||
)
|
||||
|
||||
# Save chunk
|
||||
chunk_path = Path(session["session_dir"]) / f"chunk_{chunk_index:05d}"
|
||||
|
||||
with open(chunk_path, "wb") as f:
|
||||
content = await chunk.read()
|
||||
f.write(content)
|
||||
|
||||
# Track received chunks
|
||||
session["received_chunks"].add(chunk_index)
|
||||
|
||||
return ChunkUploadResponse(
|
||||
upload_id=upload_id,
|
||||
chunk_index=chunk_index,
|
||||
received=True,
|
||||
chunks_received=len(session["received_chunks"]),
|
||||
total_chunks=session["total_chunks"]
|
||||
)
|
||||
|
||||
|
||||
@router.post("/finalize", response_model=FinalizeResponse)
|
||||
async def finalize_upload(upload_id: str = Form(...)):
|
||||
"""
|
||||
Finalize the upload by combining all chunks into a single file.
|
||||
|
||||
Validates that all chunks were received and calculates checksum.
|
||||
"""
|
||||
if upload_id not in _upload_sessions:
|
||||
raise HTTPException(status_code=404, detail="Upload-Session nicht gefunden")
|
||||
|
||||
session = _upload_sessions[upload_id]
|
||||
|
||||
# Check if all chunks received
|
||||
if len(session["received_chunks"]) != session["total_chunks"]:
|
||||
missing = session["total_chunks"] - len(session["received_chunks"])
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Nicht alle Chunks empfangen. Fehlend: {missing}"
|
||||
)
|
||||
|
||||
# Determine destination directory
|
||||
if session["destination"] == "rag":
|
||||
dest_dir = EH_UPLOAD_DIR
|
||||
else:
|
||||
dest_dir = UPLOAD_DIR
|
||||
|
||||
# Generate unique filename
|
||||
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||
safe_filename = session["filename"].replace(" ", "_")
|
||||
final_filename = f"{timestamp}_{safe_filename}"
|
||||
final_path = dest_dir / final_filename
|
||||
|
||||
# Combine chunks
|
||||
hasher = hashlib.sha256()
|
||||
total_size = 0
|
||||
|
||||
with open(final_path, "wb") as outfile:
|
||||
for i in range(session["total_chunks"]):
|
||||
chunk_path = Path(session["session_dir"]) / f"chunk_{i:05d}"
|
||||
|
||||
if not chunk_path.exists():
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"Chunk {i} nicht gefunden"
|
||||
)
|
||||
|
||||
with open(chunk_path, "rb") as infile:
|
||||
data = infile.read()
|
||||
outfile.write(data)
|
||||
hasher.update(data)
|
||||
total_size += len(data)
|
||||
|
||||
# Clean up chunks
|
||||
shutil.rmtree(session["session_dir"], ignore_errors=True)
|
||||
del _upload_sessions[upload_id]
|
||||
|
||||
checksum = hasher.hexdigest()
|
||||
|
||||
return FinalizeResponse(
|
||||
upload_id=upload_id,
|
||||
filename=final_filename,
|
||||
filepath=str(final_path),
|
||||
filesize=total_size,
|
||||
checksum=checksum,
|
||||
message="Upload erfolgreich abgeschlossen"
|
||||
)
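
Together, init, chunk, and finalize form a small three-step protocol: register a session, stream 5 MB slices, then let the server reassemble and checksum them. A minimal client sketch of that flow, assuming the requests library and a hypothetical service URL (the mobile page's JavaScript does the same thing with fetch):

import math
from pathlib import Path

import requests  # assumed available; any HTTP client works

BASE = "http://klausur-service:8000/api/v1/upload"  # hypothetical base URL
CHUNK_SIZE = 5 * 1024 * 1024  # matches the chunk_size advertised by /init


def upload_chunked(path: Path, destination: str = "klausur") -> dict:
    size = path.stat().st_size
    total = math.ceil(size / CHUNK_SIZE)

    # 1) init: register the session and obtain an upload_id
    init = requests.post(f"{BASE}/init", json={
        "filename": path.name, "filesize": size,
        "chunks": total, "destination": destination,
    })
    init.raise_for_status()
    upload_id = init.json()["upload_id"]

    # 2) chunk: send consecutive 5 MB slices with their index
    with path.open("rb") as fh:
        for index in range(total):
            data = fh.read(CHUNK_SIZE)
            requests.post(
                f"{BASE}/chunk",
                files={"chunk": (f"chunk_{index}", data)},
                data={"upload_id": upload_id, "chunk_index": str(index)},
            ).raise_for_status()

    # 3) finalize: the server concatenates the chunks and returns the checksum
    final = requests.post(f"{BASE}/finalize", data={"upload_id": upload_id})
    final.raise_for_status()
    return final.json()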
|
||||
|
||||
|
||||
@router.post("/simple")
|
||||
async def simple_upload(
|
||||
file: UploadFile = File(...),
|
||||
destination: str = Form("klausur")
|
||||
):
|
||||
"""
|
||||
Simple single-request upload for smaller files (<10MB).
|
||||
|
||||
For larger files, use the chunked upload endpoints.
|
||||
"""
|
||||
# Determine destination directory
|
||||
if destination == "rag":
|
||||
dest_dir = EH_UPLOAD_DIR
|
||||
else:
|
||||
dest_dir = UPLOAD_DIR
|
||||
|
||||
# Generate unique filename
|
||||
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||
safe_filename = file.filename.replace(" ", "_") if file.filename else "upload.pdf"
|
||||
final_filename = f"{timestamp}_{safe_filename}"
|
||||
final_path = dest_dir / final_filename
|
||||
|
||||
# Calculate checksum while writing
|
||||
hasher = hashlib.sha256()
|
||||
total_size = 0
|
||||
|
||||
with open(final_path, "wb") as f:
|
||||
while True:
|
||||
chunk = await file.read(1024 * 1024) # Read 1MB at a time
|
||||
if not chunk:
|
||||
break
|
||||
f.write(chunk)
|
||||
hasher.update(chunk)
|
||||
total_size += len(chunk)
|
||||
|
||||
return {
|
||||
"filename": final_filename,
|
||||
"filepath": str(final_path),
|
||||
"filesize": total_size,
|
||||
"checksum": hasher.hexdigest(),
|
||||
"message": "Upload erfolgreich"
|
||||
}
|
||||
|
||||
|
||||
@router.get("/status/{upload_id}")
|
||||
async def get_upload_status(upload_id: str):
|
||||
"""
|
||||
Get the status of an ongoing upload.
|
||||
"""
|
||||
if upload_id not in _upload_sessions:
|
||||
raise HTTPException(status_code=404, detail="Upload-Session nicht gefunden")
|
||||
|
||||
session = _upload_sessions[upload_id]
|
||||
|
||||
return {
|
||||
"upload_id": upload_id,
|
||||
"filename": session["filename"],
|
||||
"total_chunks": session["total_chunks"],
|
||||
"received_chunks": len(session["received_chunks"]),
|
||||
"progress_percent": round(
|
||||
len(session["received_chunks"]) / session["total_chunks"] * 100, 1
|
||||
),
|
||||
"destination": session["destination"],
|
||||
"created_at": session["created_at"]
|
||||
}
|
||||
|
||||
|
||||
@router.delete("/cancel/{upload_id}")
|
||||
async def cancel_upload(upload_id: str):
|
||||
"""
|
||||
Cancel an ongoing upload and clean up temporary files.
|
||||
"""
|
||||
if upload_id not in _upload_sessions:
|
||||
raise HTTPException(status_code=404, detail="Upload-Session nicht gefunden")
|
||||
|
||||
session = _upload_sessions[upload_id]
|
||||
|
||||
# Clean up chunks
|
||||
shutil.rmtree(session["session_dir"], ignore_errors=True)
|
||||
del _upload_sessions[upload_id]
|
||||
|
||||
return {"message": "Upload abgebrochen", "upload_id": upload_id}
|
||||
|
||||
|
||||
@router.get("/list")
|
||||
async def list_uploads(destination: str = "klausur"):
|
||||
"""
|
||||
List all uploaded files in the specified destination.
|
||||
"""
|
||||
if destination == "rag":
|
||||
dest_dir = EH_UPLOAD_DIR
|
||||
else:
|
||||
dest_dir = UPLOAD_DIR
|
||||
|
||||
files = []
|
||||
|
||||
for f in dest_dir.iterdir():
|
||||
if f.is_file() and f.suffix.lower() == ".pdf":
|
||||
stat = f.stat()
|
||||
files.append({
|
||||
"filename": f.name,
|
||||
"size": stat.st_size,
|
||||
"modified": datetime.fromtimestamp(stat.st_mtime).isoformat(),
|
||||
})
|
||||
|
||||
files.sort(key=lambda x: x["modified"], reverse=True)
|
||||
|
||||
return {
|
||||
"destination": destination,
|
||||
"count": len(files),
|
||||
"files": files[:50] # Limit to 50 most recent
|
||||
}
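
Because _upload_sessions is a plain in-memory dict (the module's own comment points to Redis or a database for production), sessions that are never finalized or cancelled leave chunk directories behind. A minimal sketch of a periodic sweep that could be scheduled at startup — the function and its scheduling are assumptions, not part of this commit; it reuses the module's _upload_sessions dict and chunk directories:

import shutil
from datetime import datetime, timedelta, timezone


def sweep_stale_sessions(max_age: timedelta = timedelta(hours=6)) -> int:
    """Drop upload sessions older than max_age and delete their chunk dirs."""
    now = datetime.now(timezone.utc)
    removed = 0
    for upload_id, session in list(_upload_sessions.items()):
        created = datetime.fromisoformat(session["created_at"])
        if now - created > max_age:
            shutil.rmtree(session["session_dir"], ignore_errors=True)
            del _upload_sessions[upload_id]
            removed += 1
    return removed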
klausur-service/backend/upload_api_mobile.py (new file, 292 lines)
@@ -0,0 +1,292 @@
|
||||
"""
|
||||
Mobile Upload HTML Page — serves the mobile upload UI directly from klausur-service.
|
||||
|
||||
Extracted from upload_api.py for modularity.
|
||||
|
||||
DSGVO-konform: Data stays local in WLAN, no external transmission.
|
||||
"""
|
||||
|
||||
from fastapi import APIRouter
|
||||
from fastapi.responses import HTMLResponse
|
||||
|
||||
router = APIRouter(prefix="/api/v1/upload", tags=["Mobile Upload"])
|
||||
|
||||
|
||||
@router.get("/mobile", response_class=HTMLResponse)
|
||||
async def mobile_upload_page():
|
||||
"""
|
||||
Serve the mobile upload page directly from the klausur-service.
|
||||
This allows mobile devices to upload without needing the Next.js website.
|
||||
"""
|
||||
html_content = '''<!DOCTYPE html>
|
||||
<html lang="de">
|
||||
<head>
|
||||
<meta charset="UTF-8">
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1.0, maximum-scale=1.0, user-scalable=no">
|
||||
<meta name="apple-mobile-web-app-capable" content="yes">
|
||||
<title>BreakPilot Upload</title>
|
||||
<style>
|
||||
* { margin: 0; padding: 0; box-sizing: border-box; }
|
||||
body {
|
||||
font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, sans-serif;
|
||||
background: linear-gradient(135deg, #1e293b 0%, #0f172a 100%);
|
||||
color: white;
|
||||
min-height: 100vh;
|
||||
padding: 16px;
|
||||
}
|
||||
.header {
|
||||
display: flex;
|
||||
justify-content: space-between;
|
||||
align-items: center;
|
||||
padding: 16px;
|
||||
border-bottom: 1px solid #334155;
|
||||
margin-bottom: 24px;
|
||||
}
|
||||
.header h1 { font-size: 20px; color: #60a5fa; }
|
||||
.badge { font-size: 10px; background: #1e293b; padding: 4px 8px; border-radius: 4px; color: #94a3b8; }
|
||||
.destination-selector {
|
||||
display: flex;
|
||||
gap: 8px;
|
||||
margin-bottom: 24px;
|
||||
}
|
||||
.dest-btn {
|
||||
flex: 1;
|
||||
padding: 14px;
|
||||
border: none;
|
||||
border-radius: 10px;
|
||||
font-size: 14px;
|
||||
font-weight: 600;
|
||||
cursor: pointer;
|
||||
transition: all 0.2s;
|
||||
}
|
||||
.dest-btn.active-klausur { background: #2563eb; color: white; box-shadow: 0 4px 15px rgba(37, 99, 235, 0.3); }
|
||||
.dest-btn.active-rag { background: #7c3aed; color: white; box-shadow: 0 4px 15px rgba(124, 58, 237, 0.3); }
|
||||
.dest-btn:not(.active-klausur):not(.active-rag) { background: #1e293b; color: #94a3b8; }
|
||||
.upload-zone {
|
||||
border: 2px dashed #475569;
|
||||
border-radius: 16px;
|
||||
padding: 40px 20px;
|
||||
text-align: center;
|
||||
margin-bottom: 24px;
|
||||
transition: all 0.2s;
|
||||
position: relative;
|
||||
}
|
||||
.upload-zone.dragover { border-color: #60a5fa; background: rgba(96, 165, 250, 0.1); transform: scale(1.02); }
|
||||
.upload-zone input[type="file"] {
|
||||
position: absolute;
|
||||
inset: 0;
|
||||
opacity: 0;
|
||||
cursor: pointer;
|
||||
}
|
||||
.upload-icon {
|
||||
width: 64px;
|
||||
height: 64px;
|
||||
background: #334155;
|
||||
border-radius: 50%;
|
||||
display: flex;
|
||||
align-items: center;
|
||||
justify-content: center;
|
||||
margin: 0 auto 16px;
|
||||
font-size: 28px;
|
||||
}
|
||||
.upload-title { font-size: 18px; margin-bottom: 8px; }
|
||||
.upload-subtitle { font-size: 14px; color: #94a3b8; margin-bottom: 16px; }
|
||||
.upload-hint { font-size: 12px; color: #64748b; }
|
||||
.file-list { margin-bottom: 24px; }
|
||||
.file-item {
|
||||
background: #1e293b;
|
||||
border-radius: 12px;
|
||||
padding: 16px;
|
||||
margin-bottom: 12px;
|
||||
}
|
||||
.file-item.error { border: 2px solid rgba(239, 68, 68, 0.5); }
|
||||
.file-item.complete { border: 2px solid rgba(34, 197, 94, 0.3); }
|
||||
.file-header { display: flex; justify-content: space-between; align-items: flex-start; margin-bottom: 8px; }
|
||||
.file-name { font-weight: 500; word-break: break-all; }
|
||||
.file-size { font-size: 14px; color: #94a3b8; }
|
||||
.remove-btn { background: none; border: none; color: #94a3b8; font-size: 20px; cursor: pointer; padding: 4px; }
|
||||
.progress-bar { height: 6px; background: #334155; border-radius: 3px; overflow: hidden; margin-top: 12px; }
|
||||
.progress-fill { height: 100%; background: linear-gradient(90deg, #3b82f6, #60a5fa); transition: width 0.3s; }
|
||||
.progress-text { font-size: 12px; color: #94a3b8; margin-top: 4px; }
|
||||
.status-complete { display: flex; align-items: center; gap: 8px; color: #22c55e; font-size: 14px; margin-top: 12px; }
|
||||
.status-error { display: flex; align-items: center; gap: 8px; color: #ef4444; font-size: 14px; margin-top: 12px; }
|
||||
.info-box {
|
||||
background: rgba(30, 41, 59, 0.5);
|
||||
border-radius: 12px;
|
||||
padding: 16px;
|
||||
font-size: 14px;
|
||||
color: #94a3b8;
|
||||
}
|
||||
.info-box h3 { color: #cbd5e1; margin-bottom: 8px; font-size: 14px; }
|
||||
.info-box ul { padding-left: 20px; }
|
||||
.info-box li { margin-bottom: 4px; }
|
||||
.server-info { text-align: center; font-size: 12px; color: #64748b; margin-top: 16px; }
|
||||
.stats { display: flex; justify-content: space-between; font-size: 14px; color: #94a3b8; padding: 0 8px; margin-bottom: 12px; }
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
<header class="header">
|
||||
<h1>BreakPilot Upload</h1>
|
||||
<span class="badge">DSGVO-konform</span>
|
||||
</header>
|
||||
|
||||
<div class="destination-selector">
|
||||
<button class="dest-btn active-klausur" id="btn-klausur" onclick="setDestination('klausur')">Klausuren</button>
|
||||
<button class="dest-btn" id="btn-rag" onclick="setDestination('rag')">Erwartungshorizonte</button>
|
||||
</div>
|
||||
|
||||
<div class="upload-zone" id="upload-zone">
|
||||
<input type="file" accept=".pdf" multiple onchange="handleFiles(this.files)">
|
||||
<div class="upload-icon">☁</div>
|
||||
<div class="upload-title">PDF-Dateien hochladen</div>
|
||||
<div class="upload-subtitle">Tippen zum Auswaehlen oder hierher ziehen</div>
|
||||
<div class="upload-hint">Grosse Dateien bis 200 MB werden automatisch in Teilen hochgeladen</div>
|
||||
</div>
|
||||
|
||||
<div class="stats" id="stats" style="display: none;">
|
||||
<span id="completed-count">0 von 0 fertig</span>
|
||||
<span id="total-size">0 B gesamt</span>
|
||||
</div>
|
||||
|
||||
<div class="file-list" id="file-list"></div>
|
||||
|
||||
<div class="info-box">
|
||||
<h3>Hinweise:</h3>
|
||||
<ul>
|
||||
<li>Die Dateien werden lokal im WLAN uebertragen</li>
|
||||
<li>Keine Daten werden ins Internet gesendet</li>
|
||||
<li>Unterstuetzte Formate: PDF</li>
|
||||
</ul>
|
||||
</div>
|
||||
|
||||
<div class="server-info" id="server-info">Server: wird ermittelt...</div>
|
||||
|
||||
<script>
|
||||
const CHUNK_SIZE = 5 * 1024 * 1024;
|
||||
let destination = 'klausur';
|
||||
let files = [];
|
||||
const serverUrl = window.location.origin;
|
||||
document.getElementById('server-info').textContent = 'Server: ' + serverUrl;
|
||||
|
||||
function setDestination(dest) {
|
||||
destination = dest;
|
||||
document.querySelectorAll('.dest-btn').forEach(btn => {
|
||||
btn.classList.remove('active-klausur', 'active-rag');
|
||||
});
|
||||
if (dest === 'klausur') {
|
||||
document.getElementById('btn-klausur').classList.add('active-klausur');
|
||||
} else {
|
||||
document.getElementById('btn-rag').classList.add('active-rag');
|
||||
}
|
||||
}
|
||||
|
||||
function formatSize(bytes) {
|
||||
if (bytes === 0) return '0 B';
|
||||
const k = 1024;
|
||||
const sizes = ['B', 'KB', 'MB', 'GB'];
|
||||
const i = Math.floor(Math.log(bytes) / Math.log(k));
|
||||
return parseFloat((bytes / Math.pow(k, i)).toFixed(1)) + ' ' + sizes[i];
|
||||
}
|
||||
|
||||
function updateStats() {
|
||||
const completed = files.filter(f => f.status === 'complete').length;
|
||||
const total = files.reduce((sum, f) => sum + f.size, 0);
|
||||
document.getElementById('completed-count').textContent = completed + ' von ' + files.length + ' fertig';
|
||||
document.getElementById('total-size').textContent = formatSize(total) + ' gesamt';
|
||||
document.getElementById('stats').style.display = files.length > 0 ? 'flex' : 'none';
|
||||
}
|
||||
|
||||
function renderFiles() {
|
||||
const list = document.getElementById('file-list');
|
||||
list.innerHTML = files.map(f => {
|
||||
let statusHtml = '';
|
||||
if (f.status === 'uploading' || f.status === 'pending') {
|
||||
statusHtml = '<div class="progress-bar"><div class="progress-fill" style="width: ' + f.progress + '%"></div></div><div class="progress-text">' + f.progress + '% hochgeladen</div>';
|
||||
} else if (f.status === 'complete') {
|
||||
statusHtml = '<div class="status-complete">✓ Erfolgreich hochgeladen</div>';
|
||||
} else if (f.status === 'error') {
|
||||
statusHtml = '<div class="status-error">⚠ ' + (f.error || 'Fehler beim Hochladen') + '</div>';
|
||||
}
|
||||
return '<div class="file-item ' + f.status + '"><div class="file-header"><div><div class="file-name">' + f.name + '</div><div class="file-size">' + formatSize(f.size) + '</div></div><button class="remove-btn" onclick="removeFile(\\'' + f.id + '\\')">×</button></div>' + statusHtml + '</div>';
|
||||
}).join('');
|
||||
updateStats();
|
||||
}
|
||||
|
||||
function removeFile(id) {
|
||||
files = files.filter(f => f.id !== id);
|
||||
renderFiles();
|
||||
}
|
||||
|
||||
async function uploadFile(file, fileId) {
|
||||
const updateProgress = (progress) => {
|
||||
const f = files.find(f => f.id === fileId);
|
||||
if (f) { f.progress = progress; renderFiles(); }
|
||||
};
|
||||
const setStatus = (status, error) => {
|
||||
const f = files.find(f => f.id === fileId);
|
||||
if (f) { f.status = status; if (error) f.error = error; renderFiles(); }
|
||||
};
|
||||
|
||||
try {
|
||||
setStatus('uploading');
|
||||
if (file.size > 10 * 1024 * 1024) {
|
||||
// Chunked upload
|
||||
const totalChunks = Math.ceil(file.size / CHUNK_SIZE);
|
||||
const initRes = await fetch(serverUrl + '/api/v1/upload/init', {
|
||||
method: 'POST',
|
||||
headers: { 'Content-Type': 'application/json' },
|
||||
body: JSON.stringify({ filename: file.name, filesize: file.size, chunks: totalChunks, destination: destination })
|
||||
});
|
||||
if (!initRes.ok) throw new Error('Konnte Upload nicht starten');
|
||||
const { upload_id } = await initRes.json();
|
||||
|
||||
for (let i = 0; i < totalChunks; i++) {
|
||||
const start = i * CHUNK_SIZE;
|
||||
const end = Math.min(start + CHUNK_SIZE, file.size);
|
||||
const chunk = file.slice(start, end);
|
||||
const formData = new FormData();
|
||||
formData.append('chunk', chunk);
|
||||
formData.append('upload_id', upload_id);
|
||||
formData.append('chunk_index', i.toString());
|
||||
const chunkRes = await fetch(serverUrl + '/api/v1/upload/chunk', { method: 'POST', body: formData });
|
||||
if (!chunkRes.ok) throw new Error('Fehler bei Teil ' + (i + 1));
|
||||
updateProgress(Math.round(((i + 1) / totalChunks) * 100));
|
||||
}
|
||||
const finalizeForm = new FormData();
|
||||
finalizeForm.append('upload_id', upload_id);
|
||||
const finalRes = await fetch(serverUrl + '/api/v1/upload/finalize', { method: 'POST', body: finalizeForm });
|
||||
if (!finalRes.ok) throw new Error('Fehler beim Abschliessen');
|
||||
} else {
|
||||
// Simple upload
|
||||
const formData = new FormData();
|
||||
formData.append('file', file);
|
||||
formData.append('destination', destination);
|
||||
const res = await fetch(serverUrl + '/api/v1/upload/simple', { method: 'POST', body: formData });
|
||||
if (!res.ok) throw new Error('Upload fehlgeschlagen');
|
||||
updateProgress(100);
|
||||
}
|
||||
setStatus('complete');
|
||||
} catch (e) {
|
||||
setStatus('error', e.message);
|
||||
}
|
||||
}
|
||||
|
||||
function handleFiles(fileList) {
|
||||
const newFiles = Array.from(fileList).filter(f => f.type === 'application/pdf');
|
||||
newFiles.forEach(file => {
|
||||
const id = Math.random().toString(36).substr(2, 9);
|
||||
files.push({ id, name: file.name, size: file.size, progress: 0, status: 'pending', file });
|
||||
renderFiles();
|
||||
uploadFile(file, id);
|
||||
});
|
||||
}
|
||||
|
||||
// Drag & Drop
|
||||
const zone = document.getElementById('upload-zone');
|
||||
zone.addEventListener('dragover', e => { e.preventDefault(); zone.classList.add('dragover'); });
|
||||
zone.addEventListener('dragleave', e => { e.preventDefault(); zone.classList.remove('dragover'); });
|
||||
zone.addEventListener('drop', e => { e.preventDefault(); zone.classList.remove('dragover'); handleFiles(e.dataTransfer.files); });
|
||||
</script>
|
||||
</body>
|
||||
</html>'''
|
||||
return HTMLResponse(content=html_content)
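
Since the page is inline HTML with no external assets, it can be smoke-tested without the Next.js frontend. A minimal sketch using FastAPI's TestClient; the test module and app wiring are assumptions:

from fastapi import FastAPI
from fastapi.testclient import TestClient

from upload_api_mobile import router as mobile_router

app = FastAPI()
app.include_router(mobile_router)


def test_mobile_page_is_served_locally():
    response = TestClient(app).get("/api/v1/upload/mobile")
    assert response.status_code == 200
    assert "BreakPilot Upload" in response.text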

klausur-service/backend/zeugnis_api.py
@@ -1,537 +1,19 @@
"""
|
||||
Zeugnis Rights-Aware Crawler - API Endpoints
|
||||
Zeugnis Rights-Aware Crawler — barrel re-export.
|
||||
|
||||
All implementation split into:
|
||||
zeugnis_api_sources — sources, seed URLs, initialization
|
||||
zeugnis_api_docs — documents, crawler, statistics, audit
|
||||
|
||||
FastAPI router for managing zeugnis sources, documents, and crawler operations.
|
||||
"""
|
||||
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Optional, List
|
||||
from fastapi import APIRouter, HTTPException, BackgroundTasks, Query
|
||||
from pydantic import BaseModel
|
||||
from fastapi import APIRouter
|
||||
|
||||
from zeugnis_models import (
|
||||
ZeugnisSource, ZeugnisSourceCreate, ZeugnisSourceVerify,
|
||||
SeedUrl, SeedUrlCreate,
|
||||
ZeugnisDocument, ZeugnisStats,
|
||||
CrawlerStatus, CrawlRequest, CrawlQueueItem,
|
||||
UsageEvent, AuditExport,
|
||||
LicenseType, CrawlStatus, DocType, EventType,
|
||||
BUNDESLAENDER, TRAINING_PERMISSIONS,
|
||||
generate_id, get_training_allowed, get_bundesland_name, get_license_for_bundesland,
|
||||
)
|
||||
from zeugnis_crawler import (
|
||||
start_crawler, stop_crawler, get_crawler_status,
|
||||
)
|
||||
from metrics_db import (
|
||||
get_zeugnis_sources, upsert_zeugnis_source,
|
||||
get_zeugnis_documents, get_zeugnis_stats,
|
||||
log_zeugnis_event, get_pool,
|
||||
)
|
||||
from zeugnis_api_sources import router as _sources_router # noqa: F401
|
||||
from zeugnis_api_docs import router as _docs_router # noqa: F401
|
||||
|
||||
|
||||
router = APIRouter(prefix="/api/v1/admin/zeugnis", tags=["Zeugnis Crawler"])
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Sources Endpoints
|
||||
# =============================================================================
|
||||
|
||||
@router.get("/sources", response_model=List[dict])
|
||||
async def list_sources():
|
||||
"""Get all zeugnis sources (Bundesländer)."""
|
||||
sources = await get_zeugnis_sources()
|
||||
if not sources:
|
||||
# Return default sources if none exist
|
||||
return [
|
||||
{
|
||||
"id": None,
|
||||
"bundesland": code,
|
||||
"name": info["name"],
|
||||
"base_url": None,
|
||||
"license_type": str(get_license_for_bundesland(code).value),
|
||||
"training_allowed": get_training_allowed(code),
|
||||
"verified_by": None,
|
||||
"verified_at": None,
|
||||
"created_at": None,
|
||||
"updated_at": None,
|
||||
}
|
||||
for code, info in BUNDESLAENDER.items()
|
||||
]
|
||||
return sources
|
||||
|
||||
|
||||
@router.post("/sources", response_model=dict)
|
||||
async def create_source(source: ZeugnisSourceCreate):
|
||||
"""Create or update a zeugnis source."""
|
||||
source_id = generate_id()
|
||||
success = await upsert_zeugnis_source(
|
||||
id=source_id,
|
||||
bundesland=source.bundesland,
|
||||
name=source.name,
|
||||
license_type=source.license_type.value,
|
||||
training_allowed=source.training_allowed,
|
||||
base_url=source.base_url,
|
||||
)
|
||||
if not success:
|
||||
raise HTTPException(status_code=500, detail="Failed to create source")
|
||||
return {"id": source_id, "success": True}
|
||||
|
||||
|
||||
@router.put("/sources/{source_id}/verify", response_model=dict)
|
||||
async def verify_source(source_id: str, verification: ZeugnisSourceVerify):
|
||||
"""Verify a source's license status."""
|
||||
pool = await get_pool()
|
||||
if not pool:
|
||||
raise HTTPException(status_code=503, detail="Database not available")
|
||||
|
||||
try:
|
||||
async with pool.acquire() as conn:
|
||||
await conn.execute(
|
||||
"""
|
||||
UPDATE zeugnis_sources
|
||||
SET license_type = $2,
|
||||
training_allowed = $3,
|
||||
verified_by = $4,
|
||||
verified_at = NOW(),
|
||||
updated_at = NOW()
|
||||
WHERE id = $1
|
||||
""",
|
||||
source_id, verification.license_type.value,
|
||||
verification.training_allowed, verification.verified_by
|
||||
)
|
||||
return {"success": True, "source_id": source_id}
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.get("/sources/{bundesland}", response_model=dict)
|
||||
async def get_source_by_bundesland(bundesland: str):
|
||||
"""Get source details for a specific Bundesland."""
|
||||
pool = await get_pool()
|
||||
if not pool:
|
||||
# Return default info
|
||||
if bundesland not in BUNDESLAENDER:
|
||||
raise HTTPException(status_code=404, detail=f"Bundesland not found: {bundesland}")
|
||||
return {
|
||||
"bundesland": bundesland,
|
||||
"name": get_bundesland_name(bundesland),
|
||||
"training_allowed": get_training_allowed(bundesland),
|
||||
"license_type": get_license_for_bundesland(bundesland).value,
|
||||
"document_count": 0,
|
||||
}
|
||||
|
||||
try:
|
||||
async with pool.acquire() as conn:
|
||||
source = await conn.fetchrow(
|
||||
"SELECT * FROM zeugnis_sources WHERE bundesland = $1",
|
||||
bundesland
|
||||
)
|
||||
if source:
|
||||
doc_count = await conn.fetchval(
|
||||
"""
|
||||
SELECT COUNT(*) FROM zeugnis_documents d
|
||||
JOIN zeugnis_seed_urls u ON d.seed_url_id = u.id
|
||||
WHERE u.source_id = $1
|
||||
""",
|
||||
source["id"]
|
||||
)
|
||||
return {**dict(source), "document_count": doc_count or 0}
|
||||
|
||||
# Return default
|
||||
return {
|
||||
"bundesland": bundesland,
|
||||
"name": get_bundesland_name(bundesland),
|
||||
"training_allowed": get_training_allowed(bundesland),
|
||||
"license_type": get_license_for_bundesland(bundesland).value,
|
||||
"document_count": 0,
|
||||
}
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Seed URLs Endpoints
|
||||
# =============================================================================
|
||||
|
||||
@router.get("/sources/{source_id}/urls", response_model=List[dict])
|
||||
async def list_seed_urls(source_id: str):
|
||||
"""Get all seed URLs for a source."""
|
||||
pool = await get_pool()
|
||||
if not pool:
|
||||
return []
|
||||
|
||||
try:
|
||||
async with pool.acquire() as conn:
|
||||
rows = await conn.fetch(
|
||||
"SELECT * FROM zeugnis_seed_urls WHERE source_id = $1 ORDER BY created_at",
|
||||
source_id
|
||||
)
|
||||
return [dict(r) for r in rows]
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.post("/sources/{source_id}/urls", response_model=dict)
|
||||
async def add_seed_url(source_id: str, seed_url: SeedUrlCreate):
|
||||
"""Add a new seed URL to a source."""
|
||||
pool = await get_pool()
|
||||
if not pool:
|
||||
raise HTTPException(status_code=503, detail="Database not available")
|
||||
|
||||
url_id = generate_id()
|
||||
try:
|
||||
async with pool.acquire() as conn:
|
||||
await conn.execute(
|
||||
"""
|
||||
INSERT INTO zeugnis_seed_urls (id, source_id, url, doc_type, status)
|
||||
VALUES ($1, $2, $3, $4, 'pending')
|
||||
""",
|
||||
url_id, source_id, seed_url.url, seed_url.doc_type.value
|
||||
)
|
||||
return {"id": url_id, "success": True}
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.delete("/urls/{url_id}", response_model=dict)
|
||||
async def delete_seed_url(url_id: str):
|
||||
"""Delete a seed URL."""
|
||||
pool = await get_pool()
|
||||
if not pool:
|
||||
raise HTTPException(status_code=503, detail="Database not available")
|
||||
|
||||
try:
|
||||
async with pool.acquire() as conn:
|
||||
await conn.execute(
|
||||
"DELETE FROM zeugnis_seed_urls WHERE id = $1",
|
||||
url_id
|
||||
)
|
||||
return {"success": True}
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Documents Endpoints
|
||||
# =============================================================================
|
||||
|
||||
@router.get("/documents", response_model=List[dict])
|
||||
async def list_documents(
|
||||
bundesland: Optional[str] = None,
|
||||
limit: int = Query(100, le=500),
|
||||
offset: int = 0,
|
||||
):
|
||||
"""Get all zeugnis documents with optional filtering."""
|
||||
documents = await get_zeugnis_documents(bundesland=bundesland, limit=limit, offset=offset)
|
||||
return documents
|
||||
|
||||
|
||||
@router.get("/documents/{document_id}", response_model=dict)
|
||||
async def get_document(document_id: str):
|
||||
"""Get details for a specific document."""
|
||||
pool = await get_pool()
|
||||
if not pool:
|
||||
raise HTTPException(status_code=503, detail="Database not available")
|
||||
|
||||
try:
|
||||
async with pool.acquire() as conn:
|
||||
doc = await conn.fetchrow(
|
||||
"""
|
||||
SELECT d.*, s.bundesland, s.name as source_name
|
||||
FROM zeugnis_documents d
|
||||
JOIN zeugnis_seed_urls u ON d.seed_url_id = u.id
|
||||
JOIN zeugnis_sources s ON u.source_id = s.id
|
||||
WHERE d.id = $1
|
||||
""",
|
||||
document_id
|
||||
)
|
||||
if not doc:
|
||||
raise HTTPException(status_code=404, detail="Document not found")
|
||||
|
||||
# Log view event
|
||||
await log_zeugnis_event(document_id, EventType.VIEWED.value)
|
||||
|
||||
return dict(doc)
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.get("/documents/{document_id}/versions", response_model=List[dict])
|
||||
async def get_document_versions(document_id: str):
|
||||
"""Get version history for a document."""
|
||||
pool = await get_pool()
|
||||
if not pool:
|
||||
raise HTTPException(status_code=503, detail="Database not available")
|
||||
|
||||
try:
|
||||
async with pool.acquire() as conn:
|
||||
rows = await conn.fetch(
|
||||
"""
|
||||
SELECT * FROM zeugnis_document_versions
|
||||
WHERE document_id = $1
|
||||
ORDER BY version DESC
|
||||
""",
|
||||
document_id
|
||||
)
|
||||
return [dict(r) for r in rows]
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Crawler Control Endpoints
|
||||
# =============================================================================
|
||||
|
||||
@router.get("/crawler/status", response_model=dict)
|
||||
async def crawler_status():
|
||||
"""Get current crawler status."""
|
||||
return get_crawler_status()
|
||||
|
||||
|
||||
@router.post("/crawler/start", response_model=dict)
|
||||
async def start_crawl(request: CrawlRequest, background_tasks: BackgroundTasks):
|
||||
"""Start the crawler."""
|
||||
success = await start_crawler(
|
||||
bundesland=request.bundesland,
|
||||
source_id=request.source_id,
|
||||
)
|
||||
if not success:
|
||||
raise HTTPException(status_code=409, detail="Crawler already running")
|
||||
return {"success": True, "message": "Crawler started"}
|
||||
|
||||
|
||||
@router.post("/crawler/stop", response_model=dict)
|
||||
async def stop_crawl():
|
||||
"""Stop the crawler."""
|
||||
success = await stop_crawler()
|
||||
if not success:
|
||||
raise HTTPException(status_code=409, detail="Crawler not running")
|
||||
return {"success": True, "message": "Crawler stopped"}
|
||||
|
||||
|
||||
@router.get("/crawler/queue", response_model=List[dict])
|
||||
async def get_queue():
|
||||
"""Get the crawler queue."""
|
||||
pool = await get_pool()
|
||||
if not pool:
|
||||
return []
|
||||
|
||||
try:
|
||||
async with pool.acquire() as conn:
|
||||
rows = await conn.fetch(
|
||||
"""
|
||||
SELECT q.*, s.bundesland, s.name as source_name
|
||||
FROM zeugnis_crawler_queue q
|
||||
JOIN zeugnis_sources s ON q.source_id = s.id
|
||||
ORDER BY q.priority DESC, q.created_at
|
||||
"""
|
||||
)
|
||||
return [dict(r) for r in rows]
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.post("/crawler/queue", response_model=dict)
|
||||
async def add_to_queue(request: CrawlRequest):
|
||||
"""Add a source to the crawler queue."""
|
||||
pool = await get_pool()
|
||||
if not pool:
|
||||
raise HTTPException(status_code=503, detail="Database not available")
|
||||
|
||||
queue_id = generate_id()
|
||||
try:
|
||||
async with pool.acquire() as conn:
|
||||
# Get source ID if bundesland provided
|
||||
source_id = request.source_id
|
||||
if not source_id and request.bundesland:
|
||||
source = await conn.fetchrow(
|
||||
"SELECT id FROM zeugnis_sources WHERE bundesland = $1",
|
||||
request.bundesland
|
||||
)
|
||||
if source:
|
||||
source_id = source["id"]
|
||||
|
||||
if not source_id:
|
||||
raise HTTPException(status_code=400, detail="Source not found")
|
||||
|
||||
await conn.execute(
|
||||
"""
|
||||
INSERT INTO zeugnis_crawler_queue (id, source_id, priority, status)
|
||||
VALUES ($1, $2, $3, 'pending')
|
||||
""",
|
||||
queue_id, source_id, request.priority
|
||||
)
|
||||
return {"id": queue_id, "success": True}
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Statistics Endpoints
|
||||
# =============================================================================
|
||||
|
||||
@router.get("/stats", response_model=dict)
|
||||
async def get_stats():
|
||||
"""Get zeugnis crawler statistics."""
|
||||
stats = await get_zeugnis_stats()
|
||||
return stats
|
||||
|
||||
|
||||
@router.get("/stats/bundesland", response_model=List[dict])
|
||||
async def get_bundesland_stats():
|
||||
"""Get statistics per Bundesland."""
|
||||
pool = await get_pool()
|
||||
|
||||
# Build stats from BUNDESLAENDER with DB data if available
|
||||
stats = []
|
||||
for code, info in BUNDESLAENDER.items():
|
||||
stat = {
|
||||
"bundesland": code,
|
||||
"name": info["name"],
|
||||
"training_allowed": get_training_allowed(code),
|
||||
"document_count": 0,
|
||||
"indexed_count": 0,
|
||||
"last_crawled": None,
|
||||
}
|
||||
|
||||
if pool:
|
||||
try:
|
||||
async with pool.acquire() as conn:
|
||||
row = await conn.fetchrow(
|
||||
"""
|
||||
SELECT
|
||||
COUNT(d.id) as doc_count,
|
||||
COUNT(CASE WHEN d.indexed_in_qdrant THEN 1 END) as indexed_count,
|
||||
MAX(u.last_crawled) as last_crawled
|
||||
FROM zeugnis_sources s
|
||||
LEFT JOIN zeugnis_seed_urls u ON s.id = u.source_id
|
||||
LEFT JOIN zeugnis_documents d ON u.id = d.seed_url_id
|
||||
WHERE s.bundesland = $1
|
||||
GROUP BY s.id
|
||||
""",
|
||||
code
|
||||
)
|
||||
if row:
|
||||
stat["document_count"] = row["doc_count"] or 0
|
||||
stat["indexed_count"] = row["indexed_count"] or 0
|
||||
stat["last_crawled"] = row["last_crawled"].isoformat() if row["last_crawled"] else None
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
stats.append(stat)
|
||||
|
||||
return stats
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Audit Endpoints
|
||||
# =============================================================================
|
||||
|
||||
@router.get("/audit/events", response_model=List[dict])
|
||||
async def get_audit_events(
|
||||
document_id: Optional[str] = None,
|
||||
event_type: Optional[str] = None,
|
||||
limit: int = Query(100, le=1000),
|
||||
days: int = Query(30, le=365),
|
||||
):
|
||||
"""Get audit events with optional filtering."""
|
||||
pool = await get_pool()
|
||||
if not pool:
|
||||
return []
|
||||
|
||||
try:
|
||||
since = datetime.now() - timedelta(days=days)
|
||||
async with pool.acquire() as conn:
|
||||
query = """
|
||||
SELECT * FROM zeugnis_usage_events
|
||||
WHERE created_at >= $1
|
||||
"""
|
||||
params = [since]
|
||||
|
||||
if document_id:
|
||||
query += " AND document_id = $2"
|
||||
params.append(document_id)
|
||||
if event_type:
|
||||
query += f" AND event_type = ${len(params) + 1}"
|
||||
params.append(event_type)
|
||||
|
||||
query += f" ORDER BY created_at DESC LIMIT ${len(params) + 1}"
|
||||
params.append(limit)
|
||||
|
||||
rows = await conn.fetch(query, *params)
|
||||
return [dict(r) for r in rows]
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.get("/audit/export", response_model=dict)
|
||||
async def export_audit(
|
||||
days: int = Query(30, le=365),
|
||||
requested_by: str = Query(..., description="User requesting the export"),
|
||||
):
|
||||
"""Export audit data for GDPR compliance."""
|
||||
pool = await get_pool()
|
||||
if not pool:
|
||||
raise HTTPException(status_code=503, detail="Database not available")
|
||||
|
||||
try:
|
||||
since = datetime.now() - timedelta(days=days)
|
||||
async with pool.acquire() as conn:
|
||||
rows = await conn.fetch(
|
||||
"""
|
||||
SELECT * FROM zeugnis_usage_events
|
||||
WHERE created_at >= $1
|
||||
ORDER BY created_at DESC
|
||||
""",
|
||||
since
|
||||
)
|
||||
|
||||
doc_count = await conn.fetchval(
|
||||
"SELECT COUNT(DISTINCT document_id) FROM zeugnis_usage_events WHERE created_at >= $1",
|
||||
since
|
||||
)
|
||||
|
||||
return {
|
||||
"export_date": datetime.now().isoformat(),
|
||||
"requested_by": requested_by,
|
||||
"events": [dict(r) for r in rows],
|
||||
"document_count": doc_count or 0,
|
||||
"date_range_start": since.isoformat(),
|
||||
"date_range_end": datetime.now().isoformat(),
|
||||
}
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Initialization Endpoint
|
||||
# =============================================================================
|
||||
|
||||
@router.post("/init", response_model=dict)
|
||||
async def initialize_sources():
|
||||
"""Initialize default sources from BUNDESLAENDER."""
|
||||
pool = await get_pool()
|
||||
if not pool:
|
||||
raise HTTPException(status_code=503, detail="Database not available")
|
||||
|
||||
created = 0
|
||||
try:
|
||||
for code, info in BUNDESLAENDER.items():
|
||||
source_id = generate_id()
|
||||
success = await upsert_zeugnis_source(
|
||||
id=source_id,
|
||||
bundesland=code,
|
||||
name=info["name"],
|
||||
license_type=get_license_for_bundesland(code).value,
|
||||
training_allowed=get_training_allowed(code),
|
||||
)
|
||||
if success:
|
||||
created += 1
|
||||
|
||||
return {"success": True, "sources_created": created}
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=str(e))

# Composite router (used by main.py)
router = APIRouter()
router.include_router(_sources_router)
router.include_router(_docs_router)
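
As with the upload barrel, the split stays invisible to consumers: anything that imported the router from zeugnis_api keeps working unchanged. A minimal sketch, assuming a consuming module such as the service's main app (not shown in this diff):

from fastapi import FastAPI

from zeugnis_api import router as zeugnis_router  # same import path as before the split

app = FastAPI()
app.include_router(zeugnis_router)  # /api/v1/admin/zeugnis/* served by both submodules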
klausur-service/backend/zeugnis_api_docs.py (new file, 321 lines)
@@ -0,0 +1,321 @@
|
||||
"""
|
||||
Zeugnis API Docs — documents, crawler control, statistics, audit endpoints.
|
||||
|
||||
Extracted from zeugnis_api.py for modularity.
|
||||
"""
|
||||
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Optional, List
|
||||
from fastapi import APIRouter, HTTPException, BackgroundTasks, Query
|
||||
|
||||
from zeugnis_models import (
|
||||
CrawlRequest, EventType,
|
||||
BUNDESLAENDER,
|
||||
generate_id, get_training_allowed, get_license_for_bundesland,
|
||||
)
|
||||
from zeugnis_crawler import (
|
||||
start_crawler, stop_crawler, get_crawler_status,
|
||||
)
|
||||
from metrics_db import (
|
||||
get_zeugnis_documents, get_zeugnis_stats,
|
||||
log_zeugnis_event, get_pool,
|
||||
)
|
||||
|
||||
|
||||
router = APIRouter(prefix="/api/v1/admin/zeugnis", tags=["Zeugnis Crawler"])
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Documents Endpoints
|
||||
# =============================================================================
|
||||
|
||||
@router.get("/documents", response_model=List[dict])
|
||||
async def list_documents(
|
||||
bundesland: Optional[str] = None,
|
||||
limit: int = Query(100, le=500),
|
||||
offset: int = 0,
|
||||
):
|
||||
"""Get all zeugnis documents with optional filtering."""
|
||||
documents = await get_zeugnis_documents(bundesland=bundesland, limit=limit, offset=offset)
|
||||
return documents
|
||||
|
||||
|
||||
@router.get("/documents/{document_id}", response_model=dict)
|
||||
async def get_document(document_id: str):
|
||||
"""Get details for a specific document."""
|
||||
pool = await get_pool()
|
||||
if not pool:
|
||||
raise HTTPException(status_code=503, detail="Database not available")
|
||||
|
||||
try:
|
||||
async with pool.acquire() as conn:
|
||||
doc = await conn.fetchrow(
|
||||
"""
|
||||
SELECT d.*, s.bundesland, s.name as source_name
|
||||
FROM zeugnis_documents d
|
||||
JOIN zeugnis_seed_urls u ON d.seed_url_id = u.id
|
||||
JOIN zeugnis_sources s ON u.source_id = s.id
|
||||
WHERE d.id = $1
|
||||
""",
|
||||
document_id
|
||||
)
|
||||
if not doc:
|
||||
raise HTTPException(status_code=404, detail="Document not found")
|
||||
|
||||
# Log view event
|
||||
await log_zeugnis_event(document_id, EventType.VIEWED.value)
|
||||
|
||||
return dict(doc)
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.get("/documents/{document_id}/versions", response_model=List[dict])
|
||||
async def get_document_versions(document_id: str):
|
||||
"""Get version history for a document."""
|
||||
pool = await get_pool()
|
||||
if not pool:
|
||||
raise HTTPException(status_code=503, detail="Database not available")
|
||||
|
||||
try:
|
||||
async with pool.acquire() as conn:
|
||||
rows = await conn.fetch(
|
||||
"""
|
||||
SELECT * FROM zeugnis_document_versions
|
||||
WHERE document_id = $1
|
||||
ORDER BY version DESC
|
||||
""",
|
||||
document_id
|
||||
)
|
||||
return [dict(r) for r in rows]
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Crawler Control Endpoints
|
||||
# =============================================================================
|
||||
|
||||
@router.get("/crawler/status", response_model=dict)
|
||||
async def crawler_status():
|
||||
"""Get current crawler status."""
|
||||
return get_crawler_status()
|
||||
|
||||
|
||||
@router.post("/crawler/start", response_model=dict)
|
||||
async def start_crawl(request: CrawlRequest, background_tasks: BackgroundTasks):
|
||||
"""Start the crawler."""
|
||||
success = await start_crawler(
|
||||
bundesland=request.bundesland,
|
||||
source_id=request.source_id,
|
||||
)
|
||||
if not success:
|
||||
raise HTTPException(status_code=409, detail="Crawler already running")
|
||||
return {"success": True, "message": "Crawler started"}
|
||||
|
||||
|
||||
@router.post("/crawler/stop", response_model=dict)
|
||||
async def stop_crawl():
|
||||
"""Stop the crawler."""
|
||||
success = await stop_crawler()
|
||||
if not success:
|
||||
raise HTTPException(status_code=409, detail="Crawler not running")
|
||||
return {"success": True, "message": "Crawler stopped"}
|
||||
|
||||
|
||||
@router.get("/crawler/queue", response_model=List[dict])
|
||||
async def get_queue():
|
||||
"""Get the crawler queue."""
|
||||
pool = await get_pool()
|
||||
if not pool:
|
||||
return []
|
||||
|
||||
try:
|
||||
async with pool.acquire() as conn:
|
||||
rows = await conn.fetch(
|
||||
"""
|
||||
SELECT q.*, s.bundesland, s.name as source_name
|
||||
FROM zeugnis_crawler_queue q
|
||||
JOIN zeugnis_sources s ON q.source_id = s.id
|
||||
ORDER BY q.priority DESC, q.created_at
|
||||
"""
|
||||
)
|
||||
return [dict(r) for r in rows]
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.post("/crawler/queue", response_model=dict)
|
||||
async def add_to_queue(request: CrawlRequest):
|
||||
"""Add a source to the crawler queue."""
|
||||
pool = await get_pool()
|
||||
if not pool:
|
||||
raise HTTPException(status_code=503, detail="Database not available")
|
||||
|
||||
queue_id = generate_id()
|
||||
try:
|
||||
async with pool.acquire() as conn:
|
||||
# Get source ID if bundesland provided
|
||||
source_id = request.source_id
|
||||
if not source_id and request.bundesland:
|
||||
source = await conn.fetchrow(
|
||||
"SELECT id FROM zeugnis_sources WHERE bundesland = $1",
|
||||
request.bundesland
|
||||
)
|
||||
if source:
|
||||
source_id = source["id"]
|
||||
|
||||
if not source_id:
|
||||
raise HTTPException(status_code=400, detail="Source not found")
|
||||
|
||||
await conn.execute(
|
||||
"""
|
||||
INSERT INTO zeugnis_crawler_queue (id, source_id, priority, status)
|
||||
VALUES ($1, $2, $3, 'pending')
|
||||
""",
|
||||
queue_id, source_id, request.priority
|
||||
)
|
||||
return {"id": queue_id, "success": True}
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Statistics Endpoints
|
||||
# =============================================================================
|
||||
|
||||
@router.get("/stats", response_model=dict)
|
||||
async def get_stats():
|
||||
"""Get zeugnis crawler statistics."""
|
||||
stats = await get_zeugnis_stats()
|
||||
return stats
|
||||
|
||||
|
||||
@router.get("/stats/bundesland", response_model=List[dict])
|
||||
async def get_bundesland_stats():
|
||||
"""Get statistics per Bundesland."""
|
||||
pool = await get_pool()
|
||||
|
||||
# Build stats from BUNDESLAENDER with DB data if available
|
||||
stats = []
|
||||
for code, info in BUNDESLAENDER.items():
|
||||
stat = {
|
||||
"bundesland": code,
|
||||
"name": info["name"],
|
||||
"training_allowed": get_training_allowed(code),
|
||||
"document_count": 0,
|
||||
"indexed_count": 0,
|
||||
"last_crawled": None,
|
||||
}
|
||||
|
||||
if pool:
|
||||
try:
|
||||
async with pool.acquire() as conn:
|
||||
row = await conn.fetchrow(
|
||||
"""
|
||||
SELECT
|
||||
COUNT(d.id) as doc_count,
|
||||
COUNT(CASE WHEN d.indexed_in_qdrant THEN 1 END) as indexed_count,
|
||||
MAX(u.last_crawled) as last_crawled
|
||||
FROM zeugnis_sources s
|
||||
LEFT JOIN zeugnis_seed_urls u ON s.id = u.source_id
|
||||
LEFT JOIN zeugnis_documents d ON u.id = d.seed_url_id
|
||||
WHERE s.bundesland = $1
|
||||
GROUP BY s.id
|
||||
""",
|
||||
code
|
||||
)
|
||||
if row:
|
||||
stat["document_count"] = row["doc_count"] or 0
|
||||
stat["indexed_count"] = row["indexed_count"] or 0
|
||||
stat["last_crawled"] = row["last_crawled"].isoformat() if row["last_crawled"] else None
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
stats.append(stat)
|
||||
|
||||
return stats
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Audit Endpoints
|
||||
# =============================================================================
|
||||
|
||||
@router.get("/audit/events", response_model=List[dict])
|
||||
async def get_audit_events(
|
||||
document_id: Optional[str] = None,
|
||||
event_type: Optional[str] = None,
|
||||
limit: int = Query(100, le=1000),
|
||||
days: int = Query(30, le=365),
|
||||
):
|
||||
"""Get audit events with optional filtering."""
|
||||
pool = await get_pool()
|
||||
if not pool:
|
||||
return []
|
||||
|
||||
try:
|
||||
since = datetime.now() - timedelta(days=days)
|
||||
async with pool.acquire() as conn:
|
||||
query = """
|
||||
SELECT * FROM zeugnis_usage_events
|
||||
WHERE created_at >= $1
|
||||
"""
|
||||
params = [since]
|
||||
|
||||
if document_id:
|
||||
query += " AND document_id = $2"
|
||||
params.append(document_id)
|
||||
if event_type:
|
||||
query += f" AND event_type = ${len(params) + 1}"
|
||||
params.append(event_type)
|
||||
|
||||
query += f" ORDER BY created_at DESC LIMIT ${len(params) + 1}"
|
||||
params.append(limit)
|
||||
|
||||
rows = await conn.fetch(query, *params)
|
||||
return [dict(r) for r in rows]
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=str(e))
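
The filter block above grows the asyncpg placeholder index in lockstep with the params list, so $1, $2, ... always line up with the values passed to conn.fetch. For illustration only, the same pattern factored into a tiny helper (not part of the module):

def append_filter(query: str, params: list, column: str, value) -> str:
    """Append 'AND <column> = $N' where N is the next asyncpg placeholder."""
    params.append(value)
    return query + f" AND {column} = ${len(params)}"

# Usage mirroring get_audit_events:
#   query = "SELECT * FROM zeugnis_usage_events WHERE created_at >= $1"
#   params = [since]
#   if document_id:
#       query = append_filter(query, params, "document_id", document_id)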
|
||||
|
||||
|
||||
@router.get("/audit/export", response_model=dict)
|
||||
async def export_audit(
|
||||
days: int = Query(30, le=365),
|
||||
requested_by: str = Query(..., description="User requesting the export"),
|
||||
):
|
||||
"""Export audit data for GDPR compliance."""
|
||||
pool = await get_pool()
|
||||
if not pool:
|
||||
raise HTTPException(status_code=503, detail="Database not available")
|
||||
|
||||
try:
|
||||
since = datetime.now() - timedelta(days=days)
|
||||
async with pool.acquire() as conn:
|
||||
rows = await conn.fetch(
|
||||
"""
|
||||
SELECT * FROM zeugnis_usage_events
|
||||
WHERE created_at >= $1
|
||||
ORDER BY created_at DESC
|
||||
""",
|
||||
since
|
||||
)
|
||||
|
||||
doc_count = await conn.fetchval(
|
||||
"SELECT COUNT(DISTINCT document_id) FROM zeugnis_usage_events WHERE created_at >= $1",
|
||||
since
|
||||
)
|
||||
|
||||
return {
|
||||
"export_date": datetime.now().isoformat(),
|
||||
"requested_by": requested_by,
|
||||
"events": [dict(r) for r in rows],
|
||||
"document_count": doc_count or 0,
|
||||
"date_range_start": since.isoformat(),
|
||||
"date_range_end": datetime.now().isoformat(),
|
||||
}
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=str(e))
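
The export bundles every usage event in the window plus a distinct-document count, intended for GDPR-style accountability requests. A hedged usage sketch, assuming httpx and placeholder host and requester values:

import httpx  # assumed available

resp = httpx.get(
    "http://klausur-service:8000/api/v1/admin/zeugnis/audit/export",  # hypothetical host
    params={"days": 90, "requested_by": "datenschutz@example-schule.de"},
)
resp.raise_for_status()
export = resp.json()
print(export["document_count"], "documents,", len(export["events"]), "events")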
klausur-service/backend/zeugnis_api_sources.py (new file, 232 lines)
@@ -0,0 +1,232 @@
|
||||
"""
|
||||
Zeugnis API Sources — source and seed URL management endpoints.
|
||||
|
||||
Extracted from zeugnis_api.py for modularity.
|
||||
"""
|
||||
|
||||
from typing import Optional, List
|
||||
from fastapi import APIRouter, HTTPException
|
||||
from pydantic import BaseModel
|
||||
|
||||
from zeugnis_models import (
|
||||
ZeugnisSourceCreate, ZeugnisSourceVerify,
|
||||
SeedUrlCreate,
|
||||
LicenseType, DocType,
|
||||
BUNDESLAENDER,
|
||||
generate_id, get_training_allowed, get_bundesland_name, get_license_for_bundesland,
|
||||
)
|
||||
from metrics_db import (
|
||||
get_zeugnis_sources, upsert_zeugnis_source, get_pool,
|
||||
)
|
||||
|
||||
|
||||
router = APIRouter(prefix="/api/v1/admin/zeugnis", tags=["Zeugnis Crawler"])

# =============================================================================
# Sources Endpoints
# =============================================================================

@router.get("/sources", response_model=List[dict])
async def list_sources():
    """Get all zeugnis sources (Bundeslaender)."""
    sources = await get_zeugnis_sources()
    if not sources:
        # Return default sources if none exist
        return [
            {
                "id": None,
                "bundesland": code,
                "name": info["name"],
                "base_url": None,
                "license_type": str(get_license_for_bundesland(code).value),
                "training_allowed": get_training_allowed(code),
                "verified_by": None,
                "verified_at": None,
                "created_at": None,
                "updated_at": None,
            }
            for code, info in BUNDESLAENDER.items()
        ]
    return sources


@router.post("/sources", response_model=dict)
async def create_source(source: ZeugnisSourceCreate):
    """Create or update a zeugnis source."""
    source_id = generate_id()
    success = await upsert_zeugnis_source(
        id=source_id,
        bundesland=source.bundesland,
        name=source.name,
        license_type=source.license_type.value,
        training_allowed=source.training_allowed,
        base_url=source.base_url,
    )
    if not success:
        raise HTTPException(status_code=500, detail="Failed to create source")
    return {"id": source_id, "success": True}


@router.put("/sources/{source_id}/verify", response_model=dict)
async def verify_source(source_id: str, verification: ZeugnisSourceVerify):
    """Verify a source's license status."""
    pool = await get_pool()
    if not pool:
        raise HTTPException(status_code=503, detail="Database not available")

    try:
        async with pool.acquire() as conn:
            await conn.execute(
                """
                UPDATE zeugnis_sources
                SET license_type = $2,
                    training_allowed = $3,
                    verified_by = $4,
                    verified_at = NOW(),
                    updated_at = NOW()
                WHERE id = $1
                """,
                source_id, verification.license_type.value,
                verification.training_allowed, verification.verified_by
            )
        return {"success": True, "source_id": source_id}
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))


@router.get("/sources/{bundesland}", response_model=dict)
async def get_source_by_bundesland(bundesland: str):
    """Get source details for a specific Bundesland."""
    pool = await get_pool()
    if not pool:
        # Return default info
        if bundesland not in BUNDESLAENDER:
            raise HTTPException(status_code=404, detail=f"Bundesland not found: {bundesland}")
        return {
            "bundesland": bundesland,
            "name": get_bundesland_name(bundesland),
            "training_allowed": get_training_allowed(bundesland),
            "license_type": get_license_for_bundesland(bundesland).value,
            "document_count": 0,
        }

    try:
        async with pool.acquire() as conn:
            source = await conn.fetchrow(
                "SELECT * FROM zeugnis_sources WHERE bundesland = $1",
                bundesland
            )
            if source:
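                # A source's documents are linked only indirectly: each document
                # references a seed URL, and the seed URL carries the source_id,
                # hence the JOIN below.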
                doc_count = await conn.fetchval(
                    """
                    SELECT COUNT(*) FROM zeugnis_documents d
                    JOIN zeugnis_seed_urls u ON d.seed_url_id = u.id
                    WHERE u.source_id = $1
                    """,
                    source["id"]
                )
                return {**dict(source), "document_count": doc_count or 0}

        # Return default
        return {
            "bundesland": bundesland,
            "name": get_bundesland_name(bundesland),
            "training_allowed": get_training_allowed(bundesland),
            "license_type": get_license_for_bundesland(bundesland).value,
            "document_count": 0,
        }
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))


# =============================================================================
# Seed URLs Endpoints
# =============================================================================

@router.get("/sources/{source_id}/urls", response_model=List[dict])
async def list_seed_urls(source_id: str):
    """Get all seed URLs for a source."""
    pool = await get_pool()
    if not pool:
        return []

    try:
        async with pool.acquire() as conn:
            rows = await conn.fetch(
                "SELECT * FROM zeugnis_seed_urls WHERE source_id = $1 ORDER BY created_at",
                source_id
            )
            return [dict(r) for r in rows]
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))


@router.post("/sources/{source_id}/urls", response_model=dict)
async def add_seed_url(source_id: str, seed_url: SeedUrlCreate):
    """Add a new seed URL to a source."""
    pool = await get_pool()
    if not pool:
        raise HTTPException(status_code=503, detail="Database not available")

    url_id = generate_id()
    try:
        async with pool.acquire() as conn:
            await conn.execute(
                """
                INSERT INTO zeugnis_seed_urls (id, source_id, url, doc_type, status)
                VALUES ($1, $2, $3, $4, 'pending')
                """,
                url_id, source_id, seed_url.url, seed_url.doc_type.value
            )
        return {"id": url_id, "success": True}
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))


@router.delete("/urls/{url_id}", response_model=dict)
async def delete_seed_url(url_id: str):
    """Delete a seed URL."""
    pool = await get_pool()
    if not pool:
        raise HTTPException(status_code=503, detail="Database not available")

    try:
        async with pool.acquire() as conn:
            await conn.execute(
                "DELETE FROM zeugnis_seed_urls WHERE id = $1",
                url_id
            )
        return {"success": True}
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))


# =============================================================================
# Initialization Endpoint
# =============================================================================

@router.post("/init", response_model=dict)
async def initialize_sources():
    """Initialize default sources from BUNDESLAENDER."""
    pool = await get_pool()
    if not pool:
        raise HTTPException(status_code=503, detail="Database not available")

    created = 0
    try:
        for code, info in BUNDESLAENDER.items():
            source_id = generate_id()
            success = await upsert_zeugnis_source(
                id=source_id,
                bundesland=code,
                name=info["name"],
                license_type=get_license_for_bundesland(code).value,
                training_allowed=get_training_allowed(code),
            )
            if success:
                created += 1

        return {"success": True, "sources_created": created}
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
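
A possible bootstrap flow against these endpoints (a sketch, assuming the service runs locally on port 8000 and httpx is installed): POST /init once to seed one source per Bundesland, then GET /sources to confirm.

    import asyncio
    import httpx

    async def bootstrap_sources():
        async with httpx.AsyncClient(base_url="http://localhost:8000") as client:
            init = await client.post("/api/v1/admin/zeugnis/init")
            init.raise_for_status()
            sources = await client.get("/api/v1/admin/zeugnis/sources")
            sources.raise_for_status()
            print(init.json()["sources_created"], "created,", len(sources.json()), "listed")

    asyncio.run(bootstrap_sources())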