From b2a0126f14d631f25de9928efd57cb387d3ba499 Mon Sep 17 00:00:00 2001 From: Benjamin Admin Date: Fri, 24 Apr 2026 22:47:59 +0200 Subject: [PATCH] [split-required] Split remaining Python monoliths (Phase 1 continued) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit klausur-service (7 monoliths): - grid_editor_helpers.py (1,737 → 5 files: columns, filters, headers, zones) - cv_cell_grid.py (1,675 → 7 files: build, legacy, streaming, merge, vocab) - worksheet_editor_api.py (1,305 → 4 files: models, AI, reconstruct, routes) - legal_corpus_ingestion.py (1,280 → 3 files: registry, chunking, ingestion) - cv_review.py (1,248 → 4 files: pipeline, spell, LLM, barrel) - cv_preprocessing.py (1,166 → 3 files: deskew, dewarp, barrel) - rbac.py, admin_api.py, routes/eh.py remain (next batch) backend-lehrer (1 monolith): - classroom_engine/repository.py (1,705 → 7 files by domain) All re-export barrels preserve backward compatibility. Zero import errors verified. Co-Authored-By: Claude Opus 4.6 (1M context) --- .claude/rules/loc-exceptions.txt | 1 + backend-lehrer/classroom_engine/repository.py | 1728 +--------------- .../classroom_engine/repository_context.py | 453 +++++ .../classroom_engine/repository_feedback.py | 182 ++ .../classroom_engine/repository_homework.py | 382 ++++ .../classroom_engine/repository_reflection.py | 315 +++ .../classroom_engine/repository_session.py | 248 +++ .../classroom_engine/repository_template.py | 167 ++ klausur-service/backend/cv_cell_grid.py | 1707 +--------------- klausur-service/backend/cv_cell_grid_build.py | 498 +++++ .../backend/cv_cell_grid_helpers.py | 136 ++ .../backend/cv_cell_grid_legacy.py | 436 ++++ klausur-service/backend/cv_cell_grid_merge.py | 235 +++ .../backend/cv_cell_grid_streaming.py | 217 ++ klausur-service/backend/cv_cell_grid_vocab.py | 200 ++ klausur-service/backend/cv_preprocessing.py | 1089 +--------- .../backend/cv_preprocessing_deskew.py | 437 ++++ .../backend/cv_preprocessing_dewarp.py | 474 +++++ klausur-service/backend/cv_review.py | 1276 +----------- klausur-service/backend/cv_review_llm.py | 388 ++++ klausur-service/backend/cv_review_pipeline.py | 430 ++++ klausur-service/backend/cv_review_spell.py | 315 +++ .../backend/grid_editor_columns.py | 492 +++++ .../backend/grid_editor_filters.py | 402 ++++ .../backend/grid_editor_headers.py | 499 +++++ .../backend/grid_editor_helpers.py | 1781 +---------------- klausur-service/backend/grid_editor_zones.py | 389 ++++ .../backend/legal_corpus_chunking.py | 197 ++ .../backend/legal_corpus_ingestion.py | 844 +------- .../backend/legal_corpus_registry.py | 608 ++++++ .../backend/worksheet_editor_ai.py | 485 +++++ .../backend/worksheet_editor_api.py | 1029 +--------- .../backend/worksheet_editor_models.py | 133 ++ .../backend/worksheet_editor_reconstruct.py | 255 +++ 34 files changed, 9264 insertions(+), 9164 deletions(-) create mode 100644 backend-lehrer/classroom_engine/repository_context.py create mode 100644 backend-lehrer/classroom_engine/repository_feedback.py create mode 100644 backend-lehrer/classroom_engine/repository_homework.py create mode 100644 backend-lehrer/classroom_engine/repository_reflection.py create mode 100644 backend-lehrer/classroom_engine/repository_session.py create mode 100644 backend-lehrer/classroom_engine/repository_template.py create mode 100644 klausur-service/backend/cv_cell_grid_build.py create mode 100644 klausur-service/backend/cv_cell_grid_helpers.py create mode 100644 klausur-service/backend/cv_cell_grid_legacy.py create mode 100644 klausur-service/backend/cv_cell_grid_merge.py create mode 100644 klausur-service/backend/cv_cell_grid_streaming.py create mode 100644 klausur-service/backend/cv_cell_grid_vocab.py create mode 100644 klausur-service/backend/cv_preprocessing_deskew.py create mode 100644 klausur-service/backend/cv_preprocessing_dewarp.py create mode 100644 klausur-service/backend/cv_review_llm.py create mode 100644 klausur-service/backend/cv_review_pipeline.py create mode 100644 klausur-service/backend/cv_review_spell.py create mode 100644 klausur-service/backend/grid_editor_columns.py create mode 100644 klausur-service/backend/grid_editor_filters.py create mode 100644 klausur-service/backend/grid_editor_headers.py create mode 100644 klausur-service/backend/grid_editor_zones.py create mode 100644 klausur-service/backend/legal_corpus_chunking.py create mode 100644 klausur-service/backend/legal_corpus_registry.py create mode 100644 klausur-service/backend/worksheet_editor_ai.py create mode 100644 klausur-service/backend/worksheet_editor_models.py create mode 100644 klausur-service/backend/worksheet_editor_reconstruct.py diff --git a/.claude/rules/loc-exceptions.txt b/.claude/rules/loc-exceptions.txt index 76b4aa7..c3cced7 100644 --- a/.claude/rules/loc-exceptions.txt +++ b/.claude/rules/loc-exceptions.txt @@ -17,6 +17,7 @@ # Pure Data Registries (keine Logik, nur Daten-Definitionen) **/dsfa_sources_registry.py | owner=klausur | reason=Pure data registry (license + source definitions, no logic) | review=2027-01-01 +**/legal_corpus_registry.py | owner=klausur | reason=Pure data registry (Regulation dataclass + 47 regulation definitions, no logic) | review=2027-01-01 **/backlog/backlog-items.ts | owner=admin-lehrer | reason=Pure data array (506 LOC, no logic, only BacklogItem[] literals) | review=2027-01-01 **/lib/module-registry-data.ts | owner=admin-lehrer | reason=Pure data array (510 LOC, no logic, only BackendModule[] literals) | review=2027-01-01 diff --git a/backend-lehrer/classroom_engine/repository.py b/backend-lehrer/classroom_engine/repository.py index 66a0302..777d714 100644 --- a/backend-lehrer/classroom_engine/repository.py +++ b/backend-lehrer/classroom_engine/repository.py @@ -1,1705 +1,33 @@ """ -Session Repository - CRUD Operationen fuer Classroom Sessions (Feature f14). +Session Repository - Re-export Hub. -Abstraktion der Datenbank-Operationen fuer LessonSessions. +Alle Repository-Klassen werden aus ihren Domain-Modulen re-exportiert, +damit bestehende Imports unveraendert funktionieren: + + from classroom_engine.repository import SessionRepository + from .repository import TeacherContextRepository """ -from datetime import datetime -from typing import Optional, List, Dict, Any -from sqlalchemy.orm import Session as DBSession - -from .db_models import ( - LessonSessionDB, PhaseHistoryDB, LessonTemplateDB, TeacherSettingsDB, - LessonPhaseEnum, HomeworkDB, HomeworkStatusEnum, PhaseMaterialDB, MaterialTypeEnum, - LessonReflectionDB, TeacherFeedbackDB, FeedbackTypeEnum, FeedbackStatusEnum, - FeedbackPriorityEnum -) -from .context_models import ( - TeacherContextDB, SchoolyearEventDB, RecurringRoutineDB, - MacroPhaseEnum, EventTypeEnum, EventStatusEnum, - RoutineTypeEnum, RecurrencePatternEnum, - FEDERAL_STATES, SCHOOL_TYPES -) -from .models import ( - LessonSession, LessonTemplate, LessonPhase, Homework, HomeworkStatus, - PhaseMaterial, MaterialType, get_default_durations -) -from .analytics import ( - LessonReflection, SessionSummary, TeacherAnalytics, AnalyticsCalculator +from .repository_session import SessionRepository, TeacherSettingsRepository +from .repository_template import TemplateRepository +from .repository_homework import HomeworkRepository, MaterialRepository +from .repository_reflection import ReflectionRepository, AnalyticsRepository +from .repository_feedback import TeacherFeedbackRepository +from .repository_context import ( + TeacherContextRepository, + SchoolyearEventRepository, + RecurringRoutineRepository, ) - -class SessionRepository: - """Repository fuer LessonSession CRUD-Operationen.""" - - def __init__(self, db: DBSession): - self.db = db - - # ==================== CREATE ==================== - - def create(self, session: LessonSession) -> LessonSessionDB: - """ - Erstellt eine neue Session in der Datenbank. - - Args: - session: LessonSession Dataclass - - Returns: - LessonSessionDB Model - """ - db_session = LessonSessionDB( - id=session.session_id, - teacher_id=session.teacher_id, - class_id=session.class_id, - subject=session.subject, - topic=session.topic, - current_phase=LessonPhaseEnum(session.current_phase.value), - is_paused=session.is_paused, - lesson_started_at=session.lesson_started_at, - lesson_ended_at=session.lesson_ended_at, - phase_started_at=session.phase_started_at, - pause_started_at=session.pause_started_at, - total_paused_seconds=session.total_paused_seconds, - phase_durations=session.phase_durations, - phase_history=session.phase_history, - notes=session.notes, - homework=session.homework, - ) - self.db.add(db_session) - self.db.commit() - self.db.refresh(db_session) - return db_session - - # ==================== READ ==================== - - def get_by_id(self, session_id: str) -> Optional[LessonSessionDB]: - """Holt eine Session nach ID.""" - return self.db.query(LessonSessionDB).filter( - LessonSessionDB.id == session_id - ).first() - - def get_active_by_teacher(self, teacher_id: str) -> List[LessonSessionDB]: - """Holt alle aktiven Sessions eines Lehrers.""" - return self.db.query(LessonSessionDB).filter( - LessonSessionDB.teacher_id == teacher_id, - LessonSessionDB.current_phase != LessonPhaseEnum.ENDED - ).all() - - def get_history_by_teacher( - self, - teacher_id: str, - limit: int = 20, - offset: int = 0 - ) -> List[LessonSessionDB]: - """Holt Session-History eines Lehrers (Feature f17).""" - return self.db.query(LessonSessionDB).filter( - LessonSessionDB.teacher_id == teacher_id, - LessonSessionDB.current_phase == LessonPhaseEnum.ENDED - ).order_by( - LessonSessionDB.lesson_ended_at.desc() - ).offset(offset).limit(limit).all() - - def get_by_class( - self, - class_id: str, - limit: int = 20 - ) -> List[LessonSessionDB]: - """Holt Sessions einer Klasse.""" - return self.db.query(LessonSessionDB).filter( - LessonSessionDB.class_id == class_id - ).order_by( - LessonSessionDB.created_at.desc() - ).limit(limit).all() - - # ==================== UPDATE ==================== - - def update(self, session: LessonSession) -> Optional[LessonSessionDB]: - """ - Aktualisiert eine bestehende Session. - - Args: - session: LessonSession Dataclass mit aktualisierten Werten - - Returns: - Aktualisierte LessonSessionDB oder None - """ - db_session = self.get_by_id(session.session_id) - if not db_session: - return None - - db_session.current_phase = LessonPhaseEnum(session.current_phase.value) - db_session.is_paused = session.is_paused - db_session.lesson_started_at = session.lesson_started_at - db_session.lesson_ended_at = session.lesson_ended_at - db_session.phase_started_at = session.phase_started_at - db_session.pause_started_at = session.pause_started_at - db_session.total_paused_seconds = session.total_paused_seconds - db_session.phase_durations = session.phase_durations - db_session.phase_history = session.phase_history - db_session.notes = session.notes - db_session.homework = session.homework - - self.db.commit() - self.db.refresh(db_session) - return db_session - - def update_notes( - self, - session_id: str, - notes: str, - homework: str - ) -> Optional[LessonSessionDB]: - """Aktualisiert nur Notizen und Hausaufgaben.""" - db_session = self.get_by_id(session_id) - if not db_session: - return None - - db_session.notes = notes - db_session.homework = homework - - self.db.commit() - self.db.refresh(db_session) - return db_session - - # ==================== DELETE ==================== - - def delete(self, session_id: str) -> bool: - """Loescht eine Session.""" - db_session = self.get_by_id(session_id) - if not db_session: - return False - - self.db.delete(db_session) - self.db.commit() - return True - - # ==================== CONVERSION ==================== - - def to_dataclass(self, db_session: LessonSessionDB) -> LessonSession: - """ - Konvertiert DB-Model zu Dataclass. - - Args: - db_session: LessonSessionDB Model - - Returns: - LessonSession Dataclass - """ - return LessonSession( - session_id=db_session.id, - teacher_id=db_session.teacher_id, - class_id=db_session.class_id, - subject=db_session.subject, - topic=db_session.topic, - current_phase=LessonPhase(db_session.current_phase.value), - phase_started_at=db_session.phase_started_at, - lesson_started_at=db_session.lesson_started_at, - lesson_ended_at=db_session.lesson_ended_at, - is_paused=db_session.is_paused, - pause_started_at=db_session.pause_started_at, - total_paused_seconds=db_session.total_paused_seconds or 0, - phase_durations=db_session.phase_durations or get_default_durations(), - phase_history=db_session.phase_history or [], - notes=db_session.notes or "", - homework=db_session.homework or "", - ) - - -class TeacherSettingsRepository: - """Repository fuer Lehrer-Einstellungen (Feature f16).""" - - def __init__(self, db: DBSession): - self.db = db - - def get_or_create(self, teacher_id: str) -> TeacherSettingsDB: - """Holt oder erstellt Einstellungen fuer einen Lehrer.""" - settings = self.db.query(TeacherSettingsDB).filter( - TeacherSettingsDB.teacher_id == teacher_id - ).first() - - if not settings: - settings = TeacherSettingsDB( - teacher_id=teacher_id, - default_phase_durations=get_default_durations(), - ) - self.db.add(settings) - self.db.commit() - self.db.refresh(settings) - - return settings - - def update_phase_durations( - self, - teacher_id: str, - durations: Dict[str, int] - ) -> TeacherSettingsDB: - """Aktualisiert die Standard-Phasendauern.""" - settings = self.get_or_create(teacher_id) - settings.default_phase_durations = durations - self.db.commit() - self.db.refresh(settings) - return settings - - def update_preferences( - self, - teacher_id: str, - audio_enabled: Optional[bool] = None, - high_contrast: Optional[bool] = None, - show_statistics: Optional[bool] = None - ) -> TeacherSettingsDB: - """Aktualisiert UI-Praeferenzen.""" - settings = self.get_or_create(teacher_id) - - if audio_enabled is not None: - settings.audio_enabled = audio_enabled - if high_contrast is not None: - settings.high_contrast = high_contrast - if show_statistics is not None: - settings.show_statistics = show_statistics - - self.db.commit() - self.db.refresh(settings) - return settings - - -class TemplateRepository: - """Repository fuer Stunden-Vorlagen (Feature f37).""" - - def __init__(self, db: DBSession): - self.db = db - - # ==================== CREATE ==================== - - def create(self, template: LessonTemplate) -> LessonTemplateDB: - """Erstellt eine neue Vorlage.""" - db_template = LessonTemplateDB( - id=template.template_id, - teacher_id=template.teacher_id, - name=template.name, - description=template.description, - subject=template.subject, - grade_level=template.grade_level, - phase_durations=template.phase_durations, - default_topic=template.default_topic, - default_notes=template.default_notes, - is_public=template.is_public, - usage_count=template.usage_count, - ) - self.db.add(db_template) - self.db.commit() - self.db.refresh(db_template) - return db_template - - # ==================== READ ==================== - - def get_by_id(self, template_id: str) -> Optional[LessonTemplateDB]: - """Holt eine Vorlage nach ID.""" - return self.db.query(LessonTemplateDB).filter( - LessonTemplateDB.id == template_id - ).first() - - def get_by_teacher( - self, - teacher_id: str, - include_public: bool = True - ) -> List[LessonTemplateDB]: - """ - Holt alle Vorlagen eines Lehrers. - - Args: - teacher_id: ID des Lehrers - include_public: Auch oeffentliche Vorlagen anderer Lehrer einbeziehen - """ - if include_public: - return self.db.query(LessonTemplateDB).filter( - (LessonTemplateDB.teacher_id == teacher_id) | - (LessonTemplateDB.is_public == True) - ).order_by( - LessonTemplateDB.usage_count.desc() - ).all() - else: - return self.db.query(LessonTemplateDB).filter( - LessonTemplateDB.teacher_id == teacher_id - ).order_by( - LessonTemplateDB.created_at.desc() - ).all() - - def get_public_templates(self, limit: int = 20) -> List[LessonTemplateDB]: - """Holt oeffentliche Vorlagen, sortiert nach Beliebtheit.""" - return self.db.query(LessonTemplateDB).filter( - LessonTemplateDB.is_public == True - ).order_by( - LessonTemplateDB.usage_count.desc() - ).limit(limit).all() - - def get_by_subject( - self, - subject: str, - teacher_id: Optional[str] = None - ) -> List[LessonTemplateDB]: - """Holt Vorlagen fuer ein bestimmtes Fach.""" - query = self.db.query(LessonTemplateDB).filter( - LessonTemplateDB.subject == subject - ) - if teacher_id: - query = query.filter( - (LessonTemplateDB.teacher_id == teacher_id) | - (LessonTemplateDB.is_public == True) - ) - else: - query = query.filter(LessonTemplateDB.is_public == True) - - return query.order_by( - LessonTemplateDB.usage_count.desc() - ).all() - - # ==================== UPDATE ==================== - - def update(self, template: LessonTemplate) -> Optional[LessonTemplateDB]: - """Aktualisiert eine Vorlage.""" - db_template = self.get_by_id(template.template_id) - if not db_template: - return None - - db_template.name = template.name - db_template.description = template.description - db_template.subject = template.subject - db_template.grade_level = template.grade_level - db_template.phase_durations = template.phase_durations - db_template.default_topic = template.default_topic - db_template.default_notes = template.default_notes - db_template.is_public = template.is_public - - self.db.commit() - self.db.refresh(db_template) - return db_template - - def increment_usage(self, template_id: str) -> Optional[LessonTemplateDB]: - """Erhoeht den Usage-Counter einer Vorlage.""" - db_template = self.get_by_id(template_id) - if not db_template: - return None - - db_template.usage_count += 1 - self.db.commit() - self.db.refresh(db_template) - return db_template - - # ==================== DELETE ==================== - - def delete(self, template_id: str) -> bool: - """Loescht eine Vorlage.""" - db_template = self.get_by_id(template_id) - if not db_template: - return False - - self.db.delete(db_template) - self.db.commit() - return True - - # ==================== CONVERSION ==================== - - def to_dataclass(self, db_template: LessonTemplateDB) -> LessonTemplate: - """Konvertiert DB-Model zu Dataclass.""" - return LessonTemplate( - template_id=db_template.id, - teacher_id=db_template.teacher_id, - name=db_template.name, - description=db_template.description or "", - subject=db_template.subject or "", - grade_level=db_template.grade_level or "", - phase_durations=db_template.phase_durations or get_default_durations(), - default_topic=db_template.default_topic or "", - default_notes=db_template.default_notes or "", - is_public=db_template.is_public, - usage_count=db_template.usage_count, - created_at=db_template.created_at, - updated_at=db_template.updated_at, - ) - - -class HomeworkRepository: - """Repository fuer Hausaufgaben-Tracking (Feature f20).""" - - def __init__(self, db: DBSession): - self.db = db - - # ==================== CREATE ==================== - - def create(self, homework: Homework) -> HomeworkDB: - """Erstellt eine neue Hausaufgabe.""" - db_homework = HomeworkDB( - id=homework.homework_id, - teacher_id=homework.teacher_id, - class_id=homework.class_id, - subject=homework.subject, - title=homework.title, - description=homework.description, - session_id=homework.session_id, - due_date=homework.due_date, - status=HomeworkStatusEnum(homework.status.value), - ) - self.db.add(db_homework) - self.db.commit() - self.db.refresh(db_homework) - return db_homework - - # ==================== READ ==================== - - def get_by_id(self, homework_id: str) -> Optional[HomeworkDB]: - """Holt eine Hausaufgabe nach ID.""" - return self.db.query(HomeworkDB).filter( - HomeworkDB.id == homework_id - ).first() - - def get_by_teacher( - self, - teacher_id: str, - status: Optional[str] = None, - limit: int = 50 - ) -> List[HomeworkDB]: - """Holt alle Hausaufgaben eines Lehrers.""" - query = self.db.query(HomeworkDB).filter( - HomeworkDB.teacher_id == teacher_id - ) - if status: - query = query.filter(HomeworkDB.status == HomeworkStatusEnum(status)) - return query.order_by( - HomeworkDB.due_date.asc().nullslast(), - HomeworkDB.created_at.desc() - ).limit(limit).all() - - def get_by_class( - self, - class_id: str, - teacher_id: str, - include_completed: bool = False, - limit: int = 20 - ) -> List[HomeworkDB]: - """Holt alle Hausaufgaben einer Klasse.""" - query = self.db.query(HomeworkDB).filter( - HomeworkDB.class_id == class_id, - HomeworkDB.teacher_id == teacher_id - ) - if not include_completed: - query = query.filter(HomeworkDB.status != HomeworkStatusEnum.COMPLETED) - return query.order_by( - HomeworkDB.due_date.asc().nullslast(), - HomeworkDB.created_at.desc() - ).limit(limit).all() - - def get_by_session(self, session_id: str) -> List[HomeworkDB]: - """Holt alle Hausaufgaben einer Session.""" - return self.db.query(HomeworkDB).filter( - HomeworkDB.session_id == session_id - ).order_by(HomeworkDB.created_at.desc()).all() - - def get_pending( - self, - teacher_id: str, - days_ahead: int = 7 - ) -> List[HomeworkDB]: - """Holt anstehende Hausaufgaben der naechsten X Tage.""" - from datetime import timedelta - cutoff = datetime.utcnow() + timedelta(days=days_ahead) - return self.db.query(HomeworkDB).filter( - HomeworkDB.teacher_id == teacher_id, - HomeworkDB.status.in_([HomeworkStatusEnum.ASSIGNED, HomeworkStatusEnum.IN_PROGRESS]), - HomeworkDB.due_date <= cutoff - ).order_by(HomeworkDB.due_date.asc()).all() - - # ==================== UPDATE ==================== - - def update_status( - self, - homework_id: str, - status: HomeworkStatus - ) -> Optional[HomeworkDB]: - """Aktualisiert den Status einer Hausaufgabe.""" - db_homework = self.get_by_id(homework_id) - if not db_homework: - return None - - db_homework.status = HomeworkStatusEnum(status.value) - self.db.commit() - self.db.refresh(db_homework) - return db_homework - - def update(self, homework: Homework) -> Optional[HomeworkDB]: - """Aktualisiert eine Hausaufgabe.""" - db_homework = self.get_by_id(homework.homework_id) - if not db_homework: - return None - - db_homework.title = homework.title - db_homework.description = homework.description - db_homework.due_date = homework.due_date - db_homework.status = HomeworkStatusEnum(homework.status.value) - - self.db.commit() - self.db.refresh(db_homework) - return db_homework - - # ==================== DELETE ==================== - - def delete(self, homework_id: str) -> bool: - """Loescht eine Hausaufgabe.""" - db_homework = self.get_by_id(homework_id) - if not db_homework: - return False - - self.db.delete(db_homework) - self.db.commit() - return True - - # ==================== CONVERSION ==================== - - def to_dataclass(self, db_homework: HomeworkDB) -> Homework: - """Konvertiert DB-Model zu Dataclass.""" - return Homework( - homework_id=db_homework.id, - teacher_id=db_homework.teacher_id, - class_id=db_homework.class_id, - subject=db_homework.subject, - title=db_homework.title, - description=db_homework.description or "", - session_id=db_homework.session_id, - due_date=db_homework.due_date, - status=HomeworkStatus(db_homework.status.value), - created_at=db_homework.created_at, - updated_at=db_homework.updated_at, - ) - - -class MaterialRepository: - """Repository fuer Phasen-Materialien (Feature f19).""" - - def __init__(self, db: DBSession): - self.db = db - - # ==================== CREATE ==================== - - def create(self, material: PhaseMaterial) -> PhaseMaterialDB: - """Erstellt ein neues Material.""" - db_material = PhaseMaterialDB( - id=material.material_id, - teacher_id=material.teacher_id, - title=material.title, - material_type=MaterialTypeEnum(material.material_type.value), - url=material.url, - description=material.description, - phase=material.phase, - subject=material.subject, - grade_level=material.grade_level, - tags=material.tags, - is_public=material.is_public, - usage_count=material.usage_count, - session_id=material.session_id, - ) - self.db.add(db_material) - self.db.commit() - self.db.refresh(db_material) - return db_material - - # ==================== READ ==================== - - def get_by_id(self, material_id: str) -> Optional[PhaseMaterialDB]: - """Holt ein Material nach ID.""" - return self.db.query(PhaseMaterialDB).filter( - PhaseMaterialDB.id == material_id - ).first() - - def get_by_teacher( - self, - teacher_id: str, - phase: Optional[str] = None, - subject: Optional[str] = None, - limit: int = 50 - ) -> List[PhaseMaterialDB]: - """Holt alle Materialien eines Lehrers.""" - query = self.db.query(PhaseMaterialDB).filter( - PhaseMaterialDB.teacher_id == teacher_id - ) - if phase: - query = query.filter(PhaseMaterialDB.phase == phase) - if subject: - query = query.filter(PhaseMaterialDB.subject == subject) - - return query.order_by( - PhaseMaterialDB.usage_count.desc(), - PhaseMaterialDB.created_at.desc() - ).limit(limit).all() - - def get_by_phase( - self, - phase: str, - teacher_id: str, - include_public: bool = True - ) -> List[PhaseMaterialDB]: - """Holt alle Materialien fuer eine bestimmte Phase.""" - if include_public: - return self.db.query(PhaseMaterialDB).filter( - PhaseMaterialDB.phase == phase, - (PhaseMaterialDB.teacher_id == teacher_id) | - (PhaseMaterialDB.is_public == True) - ).order_by( - PhaseMaterialDB.usage_count.desc() - ).all() - else: - return self.db.query(PhaseMaterialDB).filter( - PhaseMaterialDB.phase == phase, - PhaseMaterialDB.teacher_id == teacher_id - ).order_by( - PhaseMaterialDB.created_at.desc() - ).all() - - def get_by_session(self, session_id: str) -> List[PhaseMaterialDB]: - """Holt alle Materialien einer Session.""" - return self.db.query(PhaseMaterialDB).filter( - PhaseMaterialDB.session_id == session_id - ).order_by(PhaseMaterialDB.phase, PhaseMaterialDB.created_at).all() - - def get_public_materials( - self, - phase: Optional[str] = None, - subject: Optional[str] = None, - limit: int = 20 - ) -> List[PhaseMaterialDB]: - """Holt oeffentliche Materialien.""" - query = self.db.query(PhaseMaterialDB).filter( - PhaseMaterialDB.is_public == True - ) - if phase: - query = query.filter(PhaseMaterialDB.phase == phase) - if subject: - query = query.filter(PhaseMaterialDB.subject == subject) - - return query.order_by( - PhaseMaterialDB.usage_count.desc() - ).limit(limit).all() - - def search_by_tags( - self, - tags: List[str], - teacher_id: Optional[str] = None - ) -> List[PhaseMaterialDB]: - """Sucht Materialien nach Tags.""" - # SQLite/PostgreSQL JSON contains - query = self.db.query(PhaseMaterialDB) - if teacher_id: - query = query.filter( - (PhaseMaterialDB.teacher_id == teacher_id) | - (PhaseMaterialDB.is_public == True) - ) - else: - query = query.filter(PhaseMaterialDB.is_public == True) - - # Filter by tags - vereinfachte Implementierung - results = [] - for material in query.all(): - if material.tags and any(tag in material.tags for tag in tags): - results.append(material) - return results[:50] - - # ==================== UPDATE ==================== - - def update(self, material: PhaseMaterial) -> Optional[PhaseMaterialDB]: - """Aktualisiert ein Material.""" - db_material = self.get_by_id(material.material_id) - if not db_material: - return None - - db_material.title = material.title - db_material.material_type = MaterialTypeEnum(material.material_type.value) - db_material.url = material.url - db_material.description = material.description - db_material.phase = material.phase - db_material.subject = material.subject - db_material.grade_level = material.grade_level - db_material.tags = material.tags - db_material.is_public = material.is_public - - self.db.commit() - self.db.refresh(db_material) - return db_material - - def increment_usage(self, material_id: str) -> Optional[PhaseMaterialDB]: - """Erhoeht den Usage-Counter eines Materials.""" - db_material = self.get_by_id(material_id) - if not db_material: - return None - - db_material.usage_count += 1 - self.db.commit() - self.db.refresh(db_material) - return db_material - - def attach_to_session( - self, - material_id: str, - session_id: str - ) -> Optional[PhaseMaterialDB]: - """Verknuepft ein Material mit einer Session.""" - db_material = self.get_by_id(material_id) - if not db_material: - return None - - db_material.session_id = session_id - db_material.usage_count += 1 - self.db.commit() - self.db.refresh(db_material) - return db_material - - # ==================== DELETE ==================== - - def delete(self, material_id: str) -> bool: - """Loescht ein Material.""" - db_material = self.get_by_id(material_id) - if not db_material: - return False - - self.db.delete(db_material) - self.db.commit() - return True - - # ==================== CONVERSION ==================== - - def to_dataclass(self, db_material: PhaseMaterialDB) -> PhaseMaterial: - """Konvertiert DB-Model zu Dataclass.""" - return PhaseMaterial( - material_id=db_material.id, - teacher_id=db_material.teacher_id, - title=db_material.title, - material_type=MaterialType(db_material.material_type.value), - url=db_material.url, - description=db_material.description or "", - phase=db_material.phase, - subject=db_material.subject or "", - grade_level=db_material.grade_level or "", - tags=db_material.tags or [], - is_public=db_material.is_public, - usage_count=db_material.usage_count, - session_id=db_material.session_id, - created_at=db_material.created_at, - updated_at=db_material.updated_at, - ) - - -# ==================== REFLECTION REPOSITORY (Phase 5) ==================== - -class ReflectionRepository: - """Repository fuer LessonReflection CRUD-Operationen.""" - - def __init__(self, db: DBSession): - self.db = db - - # ==================== CREATE ==================== - - def create(self, reflection: LessonReflection) -> LessonReflectionDB: - """Erstellt eine neue Reflection.""" - db_reflection = LessonReflectionDB( - id=reflection.reflection_id, - session_id=reflection.session_id, - teacher_id=reflection.teacher_id, - notes=reflection.notes, - overall_rating=reflection.overall_rating, - what_worked=reflection.what_worked, - improvements=reflection.improvements, - notes_for_next_lesson=reflection.notes_for_next_lesson, - ) - self.db.add(db_reflection) - self.db.commit() - self.db.refresh(db_reflection) - return db_reflection - - # ==================== READ ==================== - - def get_by_id(self, reflection_id: str) -> Optional[LessonReflectionDB]: - """Holt eine Reflection nach ID.""" - return self.db.query(LessonReflectionDB).filter( - LessonReflectionDB.id == reflection_id - ).first() - - def get_by_session(self, session_id: str) -> Optional[LessonReflectionDB]: - """Holt die Reflection einer Session.""" - return self.db.query(LessonReflectionDB).filter( - LessonReflectionDB.session_id == session_id - ).first() - - def get_by_teacher( - self, - teacher_id: str, - limit: int = 20, - offset: int = 0 - ) -> List[LessonReflectionDB]: - """Holt alle Reflections eines Lehrers.""" - return self.db.query(LessonReflectionDB).filter( - LessonReflectionDB.teacher_id == teacher_id - ).order_by( - LessonReflectionDB.created_at.desc() - ).offset(offset).limit(limit).all() - - # ==================== UPDATE ==================== - - def update(self, reflection: LessonReflection) -> Optional[LessonReflectionDB]: - """Aktualisiert eine Reflection.""" - db_reflection = self.get_by_id(reflection.reflection_id) - if not db_reflection: - return None - - db_reflection.notes = reflection.notes - db_reflection.overall_rating = reflection.overall_rating - db_reflection.what_worked = reflection.what_worked - db_reflection.improvements = reflection.improvements - db_reflection.notes_for_next_lesson = reflection.notes_for_next_lesson - - self.db.commit() - self.db.refresh(db_reflection) - return db_reflection - - # ==================== DELETE ==================== - - def delete(self, reflection_id: str) -> bool: - """Loescht eine Reflection.""" - db_reflection = self.get_by_id(reflection_id) - if not db_reflection: - return False - - self.db.delete(db_reflection) - self.db.commit() - return True - - # ==================== CONVERSION ==================== - - def to_dataclass(self, db_reflection: LessonReflectionDB) -> LessonReflection: - """Konvertiert DB-Model zu Dataclass.""" - return LessonReflection( - reflection_id=db_reflection.id, - session_id=db_reflection.session_id, - teacher_id=db_reflection.teacher_id, - notes=db_reflection.notes or "", - overall_rating=db_reflection.overall_rating, - what_worked=db_reflection.what_worked or [], - improvements=db_reflection.improvements or [], - notes_for_next_lesson=db_reflection.notes_for_next_lesson or "", - created_at=db_reflection.created_at, - updated_at=db_reflection.updated_at, - ) - - -# ==================== ANALYTICS REPOSITORY (Phase 5) ==================== - -class AnalyticsRepository: - """Repository fuer Analytics-Abfragen.""" - - def __init__(self, db: DBSession): - self.db = db - - def get_session_summary(self, session_id: str) -> Optional[SessionSummary]: - """ - Berechnet die Summary einer abgeschlossenen Session. - - Args: - session_id: ID der Session - - Returns: - SessionSummary oder None wenn Session nicht gefunden - """ - db_session = self.db.query(LessonSessionDB).filter( - LessonSessionDB.id == session_id - ).first() - - if not db_session: - return None - - # Session-Daten zusammenstellen - session_data = { - "session_id": db_session.id, - "teacher_id": db_session.teacher_id, - "class_id": db_session.class_id, - "subject": db_session.subject, - "topic": db_session.topic, - "lesson_started_at": db_session.lesson_started_at, - "lesson_ended_at": db_session.lesson_ended_at, - "phase_durations": db_session.phase_durations or {}, - } - - # Phase History aus DB oder JSON - phase_history = db_session.phase_history or [] - - # Summary berechnen - return AnalyticsCalculator.calculate_session_summary( - session_data, phase_history - ) - - def get_teacher_analytics( - self, - teacher_id: str, - period_start: Optional[datetime] = None, - period_end: Optional[datetime] = None - ) -> TeacherAnalytics: - """ - Berechnet aggregierte Statistiken fuer einen Lehrer. - - Args: - teacher_id: ID des Lehrers - period_start: Beginn des Zeitraums (default: 30 Tage zurueck) - period_end: Ende des Zeitraums (default: jetzt) - - Returns: - TeacherAnalytics mit aggregierten Statistiken - """ - from datetime import timedelta - - if not period_end: - period_end = datetime.utcnow() - if not period_start: - period_start = period_end - timedelta(days=30) - - # Sessions im Zeitraum abfragen - sessions_query = self.db.query(LessonSessionDB).filter( - LessonSessionDB.teacher_id == teacher_id, - LessonSessionDB.lesson_started_at >= period_start, - LessonSessionDB.lesson_started_at <= period_end - ).all() - - # Sessions zu Dictionaries konvertieren - sessions_data = [] - for db_session in sessions_query: - sessions_data.append({ - "session_id": db_session.id, - "teacher_id": db_session.teacher_id, - "class_id": db_session.class_id, - "subject": db_session.subject, - "topic": db_session.topic, - "lesson_started_at": db_session.lesson_started_at, - "lesson_ended_at": db_session.lesson_ended_at, - "phase_durations": db_session.phase_durations or {}, - "phase_history": db_session.phase_history or [], - }) - - return AnalyticsCalculator.calculate_teacher_analytics( - sessions_data, period_start, period_end - ) - - def get_phase_duration_trends( - self, - teacher_id: str, - phase: str, - limit: int = 20 - ) -> List[Dict[str, Any]]: - """ - Gibt die Dauer-Trends fuer eine bestimmte Phase zurueck. - - Args: - teacher_id: ID des Lehrers - phase: Phasen-ID (einstieg, erarbeitung, etc.) - limit: Max Anzahl der Datenpunkte - - Returns: - Liste von Datenpunkten [{date, planned, actual, difference}] - """ - sessions = self.db.query(LessonSessionDB).filter( - LessonSessionDB.teacher_id == teacher_id, - LessonSessionDB.current_phase == LessonPhaseEnum.ENDED - ).order_by( - LessonSessionDB.lesson_ended_at.desc() - ).limit(limit).all() - - trends = [] - for db_session in sessions: - history = db_session.phase_history or [] - for entry in history: - if entry.get("phase") == phase: - planned = (db_session.phase_durations or {}).get(phase, 0) * 60 - actual = entry.get("duration_seconds", 0) or 0 - trends.append({ - "date": db_session.lesson_started_at.isoformat() if db_session.lesson_started_at else None, - "session_id": db_session.id, - "subject": db_session.subject, - "planned_seconds": planned, - "actual_seconds": actual, - "difference_seconds": actual - planned, - }) - break - - return list(reversed(trends)) # Chronologisch sortieren - - def get_overtime_analysis( - self, - teacher_id: str, - limit: int = 30 - ) -> Dict[str, Any]: - """ - Analysiert Overtime-Muster. - - Args: - teacher_id: ID des Lehrers - limit: Anzahl der zu analysierenden Sessions - - Returns: - Dict mit Overtime-Statistiken pro Phase - """ - sessions = self.db.query(LessonSessionDB).filter( - LessonSessionDB.teacher_id == teacher_id, - LessonSessionDB.current_phase == LessonPhaseEnum.ENDED - ).order_by( - LessonSessionDB.lesson_ended_at.desc() - ).limit(limit).all() - - phase_overtime: Dict[str, List[int]] = { - "einstieg": [], - "erarbeitung": [], - "sicherung": [], - "transfer": [], - "reflexion": [], - } - - for db_session in sessions: - history = db_session.phase_history or [] - phase_durations = db_session.phase_durations or {} - - for entry in history: - phase = entry.get("phase", "") - if phase in phase_overtime: - planned = phase_durations.get(phase, 0) * 60 - actual = entry.get("duration_seconds", 0) or 0 - overtime = max(0, actual - planned) - phase_overtime[phase].append(overtime) - - # Statistiken berechnen - result = {} - for phase, overtimes in phase_overtime.items(): - if overtimes: - result[phase] = { - "count": len([o for o in overtimes if o > 0]), - "total": len(overtimes), - "avg_overtime_seconds": sum(overtimes) / len(overtimes), - "max_overtime_seconds": max(overtimes), - "overtime_percentage": len([o for o in overtimes if o > 0]) / len(overtimes) * 100, - } - else: - result[phase] = { - "count": 0, - "total": 0, - "avg_overtime_seconds": 0, - "max_overtime_seconds": 0, - "overtime_percentage": 0, - } - - return result - - -# ==================== TEACHER FEEDBACK REPOSITORY (Phase 7) ==================== - - -class TeacherFeedbackRepository: - """ - Repository fuer Lehrer-Feedback CRUD-Operationen. - - Ermoeglicht Lehrern, Feedback (Bugs, Feature-Requests, Verbesserungen) - direkt aus dem Lehrer-Frontend zu senden. - """ - - def __init__(self, db: DBSession): - self.db = db - - def create( - self, - teacher_id: str, - title: str, - description: str, - feedback_type: str = "improvement", - priority: str = "medium", - teacher_name: str = "", - teacher_email: str = "", - context_url: str = "", - context_phase: str = "", - context_session_id: str = None, - user_agent: str = "", - related_feature: str = None, - ) -> TeacherFeedbackDB: - """Erstellt neues Feedback.""" - import uuid - - db_feedback = TeacherFeedbackDB( - id=str(uuid.uuid4()), - teacher_id=teacher_id, - teacher_name=teacher_name, - teacher_email=teacher_email, - title=title, - description=description, - feedback_type=FeedbackTypeEnum(feedback_type), - priority=FeedbackPriorityEnum(priority), - status=FeedbackStatusEnum.NEW, - related_feature=related_feature, - context_url=context_url, - context_phase=context_phase, - context_session_id=context_session_id, - user_agent=user_agent, - ) - - self.db.add(db_feedback) - self.db.commit() - self.db.refresh(db_feedback) - return db_feedback - - def get_by_id(self, feedback_id: str) -> Optional[TeacherFeedbackDB]: - """Holt Feedback nach ID.""" - return self.db.query(TeacherFeedbackDB).filter( - TeacherFeedbackDB.id == feedback_id - ).first() - - def get_all( - self, - status: str = None, - feedback_type: str = None, - limit: int = 100, - offset: int = 0 - ) -> List[TeacherFeedbackDB]: - """Holt alle Feedbacks mit optionalen Filtern.""" - query = self.db.query(TeacherFeedbackDB) - - if status: - query = query.filter(TeacherFeedbackDB.status == FeedbackStatusEnum(status)) - if feedback_type: - query = query.filter(TeacherFeedbackDB.feedback_type == FeedbackTypeEnum(feedback_type)) - - return query.order_by( - TeacherFeedbackDB.created_at.desc() - ).offset(offset).limit(limit).all() - - def get_by_teacher(self, teacher_id: str, limit: int = 50) -> List[TeacherFeedbackDB]: - """Holt Feedback eines bestimmten Lehrers.""" - return self.db.query(TeacherFeedbackDB).filter( - TeacherFeedbackDB.teacher_id == teacher_id - ).order_by( - TeacherFeedbackDB.created_at.desc() - ).limit(limit).all() - - def update_status( - self, - feedback_id: str, - status: str, - response: str = None, - responded_by: str = None - ) -> Optional[TeacherFeedbackDB]: - """Aktualisiert den Status eines Feedbacks.""" - db_feedback = self.get_by_id(feedback_id) - if not db_feedback: - return None - - db_feedback.status = FeedbackStatusEnum(status) - if response: - db_feedback.response = response - db_feedback.responded_at = datetime.utcnow() - db_feedback.responded_by = responded_by - - self.db.commit() - self.db.refresh(db_feedback) - return db_feedback - - def delete(self, feedback_id: str) -> bool: - """Loescht ein Feedback.""" - db_feedback = self.get_by_id(feedback_id) - if not db_feedback: - return False - - self.db.delete(db_feedback) - self.db.commit() - return True - - def get_stats(self) -> Dict[str, Any]: - """Gibt Statistiken ueber alle Feedbacks zurueck.""" - all_feedback = self.db.query(TeacherFeedbackDB).all() - - stats = { - "total": len(all_feedback), - "by_status": {}, - "by_type": {}, - "by_priority": {}, - } - - for fb in all_feedback: - # By Status - status = fb.status.value - stats["by_status"][status] = stats["by_status"].get(status, 0) + 1 - - # By Type - fb_type = fb.feedback_type.value - stats["by_type"][fb_type] = stats["by_type"].get(fb_type, 0) + 1 - - # By Priority - priority = fb.priority.value - stats["by_priority"][priority] = stats["by_priority"].get(priority, 0) + 1 - - return stats - - def to_dict(self, db_feedback: TeacherFeedbackDB) -> Dict[str, Any]: - """Konvertiert DB-Model zu Dictionary.""" - return { - "id": db_feedback.id, - "teacher_id": db_feedback.teacher_id, - "teacher_name": db_feedback.teacher_name, - "teacher_email": db_feedback.teacher_email, - "title": db_feedback.title, - "description": db_feedback.description, - "feedback_type": db_feedback.feedback_type.value, - "priority": db_feedback.priority.value, - "status": db_feedback.status.value, - "related_feature": db_feedback.related_feature, - "context_url": db_feedback.context_url, - "context_phase": db_feedback.context_phase, - "context_session_id": db_feedback.context_session_id, - "user_agent": db_feedback.user_agent, - "response": db_feedback.response, - "responded_at": db_feedback.responded_at.isoformat() if db_feedback.responded_at else None, - "responded_by": db_feedback.responded_by, - "created_at": db_feedback.created_at.isoformat() if db_feedback.created_at else None, - "updated_at": db_feedback.updated_at.isoformat() if db_feedback.updated_at else None, - } - - -# ==================== Phase 8: Teacher Context Repository ==================== - - -class TeacherContextRepository: - """Repository fuer Lehrer-Kontext CRUD-Operationen (Phase 8).""" - - def __init__(self, db: DBSession): - self.db = db - - # ==================== CREATE / GET-OR-CREATE ==================== - - def get_or_create(self, teacher_id: str) -> TeacherContextDB: - """ - Holt den Kontext eines Lehrers oder erstellt einen neuen. - - Args: - teacher_id: ID des Lehrers - - Returns: - TeacherContextDB Model - """ - context = self.get_by_teacher_id(teacher_id) - if context: - return context - - # Neuen Kontext erstellen - from uuid import uuid4 - context = TeacherContextDB( - id=str(uuid4()), - teacher_id=teacher_id, - macro_phase=MacroPhaseEnum.ONBOARDING, - ) - self.db.add(context) - self.db.commit() - self.db.refresh(context) - return context - - # ==================== READ ==================== - - def get_by_teacher_id(self, teacher_id: str) -> Optional[TeacherContextDB]: - """Holt den Kontext eines Lehrers.""" - return self.db.query(TeacherContextDB).filter( - TeacherContextDB.teacher_id == teacher_id - ).first() - - # ==================== UPDATE ==================== - - def update_context( - self, - teacher_id: str, - federal_state: str = None, - school_type: str = None, - schoolyear: str = None, - schoolyear_start: datetime = None, - macro_phase: str = None, - current_week: int = None, - ) -> Optional[TeacherContextDB]: - """Aktualisiert den Kontext eines Lehrers.""" - context = self.get_or_create(teacher_id) - - if federal_state is not None: - context.federal_state = federal_state - if school_type is not None: - context.school_type = school_type - if schoolyear is not None: - context.schoolyear = schoolyear - if schoolyear_start is not None: - context.schoolyear_start = schoolyear_start - if macro_phase is not None: - context.macro_phase = MacroPhaseEnum(macro_phase) - if current_week is not None: - context.current_week = current_week - - self.db.commit() - self.db.refresh(context) - return context - - def complete_onboarding(self, teacher_id: str) -> TeacherContextDB: - """Markiert Onboarding als abgeschlossen.""" - context = self.get_or_create(teacher_id) - context.onboarding_completed = True - context.macro_phase = MacroPhaseEnum.SCHULJAHRESSTART - self.db.commit() - self.db.refresh(context) - return context - - def update_flags( - self, - teacher_id: str, - has_classes: bool = None, - has_schedule: bool = None, - is_exam_period: bool = None, - is_before_holidays: bool = None, - ) -> TeacherContextDB: - """Aktualisiert die Status-Flags eines Kontexts.""" - context = self.get_or_create(teacher_id) - - if has_classes is not None: - context.has_classes = has_classes - if has_schedule is not None: - context.has_schedule = has_schedule - if is_exam_period is not None: - context.is_exam_period = is_exam_period - if is_before_holidays is not None: - context.is_before_holidays = is_before_holidays - - self.db.commit() - self.db.refresh(context) - return context - - def to_dict(self, context: TeacherContextDB) -> Dict[str, Any]: - """Konvertiert DB-Model zu Dictionary.""" - return { - "id": context.id, - "teacher_id": context.teacher_id, - "school": { - "federal_state": context.federal_state, - "federal_state_name": FEDERAL_STATES.get(context.federal_state, ""), - "school_type": context.school_type, - "school_type_name": SCHOOL_TYPES.get(context.school_type, ""), - }, - "school_year": { - "id": context.schoolyear, - "start": context.schoolyear_start.isoformat() if context.schoolyear_start else None, - "current_week": context.current_week, - }, - "macro_phase": { - "id": context.macro_phase.value, - "label": self._get_phase_label(context.macro_phase), - }, - "flags": { - "onboarding_completed": context.onboarding_completed, - "has_classes": context.has_classes, - "has_schedule": context.has_schedule, - "is_exam_period": context.is_exam_period, - "is_before_holidays": context.is_before_holidays, - }, - "created_at": context.created_at.isoformat() if context.created_at else None, - "updated_at": context.updated_at.isoformat() if context.updated_at else None, - } - - def _get_phase_label(self, phase: MacroPhaseEnum) -> str: - """Gibt den Anzeigenamen einer Makro-Phase zurueck.""" - labels = { - MacroPhaseEnum.ONBOARDING: "Einrichtung", - MacroPhaseEnum.SCHULJAHRESSTART: "Schuljahresstart", - MacroPhaseEnum.UNTERRICHTSAUFBAU: "Unterrichtsaufbau", - MacroPhaseEnum.LEISTUNGSPHASE_1: "Leistungsphase 1", - MacroPhaseEnum.HALBJAHRESABSCHLUSS: "Halbjahresabschluss", - MacroPhaseEnum.LEISTUNGSPHASE_2: "Leistungsphase 2", - MacroPhaseEnum.JAHRESABSCHLUSS: "Jahresabschluss", - } - return labels.get(phase, phase.value) - - -# ==================== Phase 8: Schoolyear Event Repository ==================== - - -class SchoolyearEventRepository: - """Repository fuer Schuljahr-Events (Phase 8).""" - - def __init__(self, db: DBSession): - self.db = db - - def create( - self, - teacher_id: str, - title: str, - start_date: datetime, - event_type: str = "other", - end_date: datetime = None, - class_id: str = None, - subject: str = None, - description: str = "", - needs_preparation: bool = True, - reminder_days_before: int = 7, - extra_data: Dict[str, Any] = None, - ) -> SchoolyearEventDB: - """Erstellt ein neues Schuljahr-Event.""" - from uuid import uuid4 - event = SchoolyearEventDB( - id=str(uuid4()), - teacher_id=teacher_id, - title=title, - event_type=EventTypeEnum(event_type), - start_date=start_date, - end_date=end_date, - class_id=class_id, - subject=subject, - description=description, - needs_preparation=needs_preparation, - reminder_days_before=reminder_days_before, - extra_data=extra_data or {}, - ) - self.db.add(event) - self.db.commit() - self.db.refresh(event) - return event - - def get_by_id(self, event_id: str) -> Optional[SchoolyearEventDB]: - """Holt ein Event nach ID.""" - return self.db.query(SchoolyearEventDB).filter( - SchoolyearEventDB.id == event_id - ).first() - - def get_by_teacher( - self, - teacher_id: str, - status: str = None, - event_type: str = None, - limit: int = 50, - ) -> List[SchoolyearEventDB]: - """Holt Events eines Lehrers.""" - query = self.db.query(SchoolyearEventDB).filter( - SchoolyearEventDB.teacher_id == teacher_id - ) - if status: - query = query.filter(SchoolyearEventDB.status == EventStatusEnum(status)) - if event_type: - query = query.filter(SchoolyearEventDB.event_type == EventTypeEnum(event_type)) - - return query.order_by(SchoolyearEventDB.start_date).limit(limit).all() - - def get_upcoming( - self, - teacher_id: str, - days: int = 30, - limit: int = 10, - ) -> List[SchoolyearEventDB]: - """Holt anstehende Events der naechsten X Tage.""" - from datetime import timedelta - now = datetime.utcnow() - end = now + timedelta(days=days) - - return self.db.query(SchoolyearEventDB).filter( - SchoolyearEventDB.teacher_id == teacher_id, - SchoolyearEventDB.start_date >= now, - SchoolyearEventDB.start_date <= end, - SchoolyearEventDB.status != EventStatusEnum.CANCELLED, - ).order_by(SchoolyearEventDB.start_date).limit(limit).all() - - def update_status( - self, - event_id: str, - status: str, - preparation_done: bool = None, - ) -> Optional[SchoolyearEventDB]: - """Aktualisiert den Status eines Events.""" - event = self.get_by_id(event_id) - if not event: - return None - - event.status = EventStatusEnum(status) - if preparation_done is not None: - event.preparation_done = preparation_done - - self.db.commit() - self.db.refresh(event) - return event - - def delete(self, event_id: str) -> bool: - """Loescht ein Event.""" - event = self.get_by_id(event_id) - if not event: - return False - self.db.delete(event) - self.db.commit() - return True - - def to_dict(self, event: SchoolyearEventDB) -> Dict[str, Any]: - """Konvertiert DB-Model zu Dictionary.""" - return { - "id": event.id, - "teacher_id": event.teacher_id, - "event_type": event.event_type.value, - "title": event.title, - "description": event.description, - "start_date": event.start_date.isoformat() if event.start_date else None, - "end_date": event.end_date.isoformat() if event.end_date else None, - "class_id": event.class_id, - "subject": event.subject, - "status": event.status.value, - "needs_preparation": event.needs_preparation, - "preparation_done": event.preparation_done, - "reminder_days_before": event.reminder_days_before, - "extra_data": event.extra_data, - "created_at": event.created_at.isoformat() if event.created_at else None, - } - - -# ==================== Phase 8: Recurring Routine Repository ==================== - - -class RecurringRoutineRepository: - """Repository fuer wiederkehrende Routinen (Phase 8).""" - - def __init__(self, db: DBSession): - self.db = db - - def create( - self, - teacher_id: str, - title: str, - routine_type: str = "other", - recurrence_pattern: str = "weekly", - day_of_week: int = None, - day_of_month: int = None, - time_of_day: str = None, # Format: "14:00" - duration_minutes: int = 60, - description: str = "", - valid_from: datetime = None, - valid_until: datetime = None, - ) -> RecurringRoutineDB: - """Erstellt eine neue wiederkehrende Routine.""" - from uuid import uuid4 - from datetime import time as dt_time - - time_obj = None - if time_of_day: - parts = time_of_day.split(":") - time_obj = dt_time(int(parts[0]), int(parts[1])) - - routine = RecurringRoutineDB( - id=str(uuid4()), - teacher_id=teacher_id, - title=title, - routine_type=RoutineTypeEnum(routine_type), - recurrence_pattern=RecurrencePatternEnum(recurrence_pattern), - day_of_week=day_of_week, - day_of_month=day_of_month, - time_of_day=time_obj, - duration_minutes=duration_minutes, - description=description, - valid_from=valid_from, - valid_until=valid_until, - ) - self.db.add(routine) - self.db.commit() - self.db.refresh(routine) - return routine - - def get_by_id(self, routine_id: str) -> Optional[RecurringRoutineDB]: - """Holt eine Routine nach ID.""" - return self.db.query(RecurringRoutineDB).filter( - RecurringRoutineDB.id == routine_id - ).first() - - def get_by_teacher( - self, - teacher_id: str, - is_active: bool = True, - routine_type: str = None, - ) -> List[RecurringRoutineDB]: - """Holt Routinen eines Lehrers.""" - query = self.db.query(RecurringRoutineDB).filter( - RecurringRoutineDB.teacher_id == teacher_id - ) - if is_active is not None: - query = query.filter(RecurringRoutineDB.is_active == is_active) - if routine_type: - query = query.filter(RecurringRoutineDB.routine_type == RoutineTypeEnum(routine_type)) - - return query.all() - - def get_today(self, teacher_id: str) -> List[RecurringRoutineDB]: - """Holt Routinen die heute stattfinden.""" - today = datetime.utcnow() - day_of_week = today.weekday() # 0 = Montag - day_of_month = today.day - - routines = self.get_by_teacher(teacher_id, is_active=True) - today_routines = [] - - for routine in routines: - if routine.recurrence_pattern == RecurrencePatternEnum.DAILY: - today_routines.append(routine) - elif routine.recurrence_pattern == RecurrencePatternEnum.WEEKLY: - if routine.day_of_week == day_of_week: - today_routines.append(routine) - elif routine.recurrence_pattern == RecurrencePatternEnum.BIWEEKLY: - # Vereinfacht: Pruefen ob Tag passt (echte Logik braucht Startdatum) - if routine.day_of_week == day_of_week: - today_routines.append(routine) - elif routine.recurrence_pattern == RecurrencePatternEnum.MONTHLY: - if routine.day_of_month == day_of_month: - today_routines.append(routine) - - return today_routines - - def update( - self, - routine_id: str, - title: str = None, - is_active: bool = None, - day_of_week: int = None, - time_of_day: str = None, - ) -> Optional[RecurringRoutineDB]: - """Aktualisiert eine Routine.""" - routine = self.get_by_id(routine_id) - if not routine: - return None - - if title is not None: - routine.title = title - if is_active is not None: - routine.is_active = is_active - if day_of_week is not None: - routine.day_of_week = day_of_week - if time_of_day is not None: - from datetime import time as dt_time - parts = time_of_day.split(":") - routine.time_of_day = dt_time(int(parts[0]), int(parts[1])) - - self.db.commit() - self.db.refresh(routine) - return routine - - def delete(self, routine_id: str) -> bool: - """Loescht eine Routine.""" - routine = self.get_by_id(routine_id) - if not routine: - return False - self.db.delete(routine) - self.db.commit() - return True - - def to_dict(self, routine: RecurringRoutineDB) -> Dict[str, Any]: - """Konvertiert DB-Model zu Dictionary.""" - return { - "id": routine.id, - "teacher_id": routine.teacher_id, - "routine_type": routine.routine_type.value, - "title": routine.title, - "description": routine.description, - "recurrence_pattern": routine.recurrence_pattern.value, - "day_of_week": routine.day_of_week, - "day_of_month": routine.day_of_month, - "time_of_day": routine.time_of_day.isoformat() if routine.time_of_day else None, - "duration_minutes": routine.duration_minutes, - "is_active": routine.is_active, - "valid_from": routine.valid_from.isoformat() if routine.valid_from else None, - "valid_until": routine.valid_until.isoformat() if routine.valid_until else None, - "created_at": routine.created_at.isoformat() if routine.created_at else None, - } +__all__ = [ + "SessionRepository", + "TeacherSettingsRepository", + "TemplateRepository", + "HomeworkRepository", + "MaterialRepository", + "ReflectionRepository", + "AnalyticsRepository", + "TeacherFeedbackRepository", + "TeacherContextRepository", + "SchoolyearEventRepository", + "RecurringRoutineRepository", +] diff --git a/backend-lehrer/classroom_engine/repository_context.py b/backend-lehrer/classroom_engine/repository_context.py new file mode 100644 index 0000000..5bbfe42 --- /dev/null +++ b/backend-lehrer/classroom_engine/repository_context.py @@ -0,0 +1,453 @@ +""" +Teacher Context, Schoolyear Event & Recurring Routine Repositories. + +CRUD-Operationen fuer Schuljahres-Kontext (Phase 8). +""" +from datetime import datetime +from typing import Optional, List, Dict, Any + +from sqlalchemy.orm import Session as DBSession + +from .context_models import ( + TeacherContextDB, SchoolyearEventDB, RecurringRoutineDB, + MacroPhaseEnum, EventTypeEnum, EventStatusEnum, + RoutineTypeEnum, RecurrencePatternEnum, + FEDERAL_STATES, SCHOOL_TYPES, +) + + +class TeacherContextRepository: + """Repository fuer Lehrer-Kontext CRUD-Operationen (Phase 8).""" + + def __init__(self, db: DBSession): + self.db = db + + # ==================== CREATE / GET-OR-CREATE ==================== + + def get_or_create(self, teacher_id: str) -> TeacherContextDB: + """ + Holt den Kontext eines Lehrers oder erstellt einen neuen. + + Args: + teacher_id: ID des Lehrers + + Returns: + TeacherContextDB Model + """ + context = self.get_by_teacher_id(teacher_id) + if context: + return context + + # Neuen Kontext erstellen + from uuid import uuid4 + context = TeacherContextDB( + id=str(uuid4()), + teacher_id=teacher_id, + macro_phase=MacroPhaseEnum.ONBOARDING, + ) + self.db.add(context) + self.db.commit() + self.db.refresh(context) + return context + + # ==================== READ ==================== + + def get_by_teacher_id(self, teacher_id: str) -> Optional[TeacherContextDB]: + """Holt den Kontext eines Lehrers.""" + return self.db.query(TeacherContextDB).filter( + TeacherContextDB.teacher_id == teacher_id + ).first() + + # ==================== UPDATE ==================== + + def update_context( + self, + teacher_id: str, + federal_state: str = None, + school_type: str = None, + schoolyear: str = None, + schoolyear_start: datetime = None, + macro_phase: str = None, + current_week: int = None, + ) -> Optional[TeacherContextDB]: + """Aktualisiert den Kontext eines Lehrers.""" + context = self.get_or_create(teacher_id) + + if federal_state is not None: + context.federal_state = federal_state + if school_type is not None: + context.school_type = school_type + if schoolyear is not None: + context.schoolyear = schoolyear + if schoolyear_start is not None: + context.schoolyear_start = schoolyear_start + if macro_phase is not None: + context.macro_phase = MacroPhaseEnum(macro_phase) + if current_week is not None: + context.current_week = current_week + + self.db.commit() + self.db.refresh(context) + return context + + def complete_onboarding(self, teacher_id: str) -> TeacherContextDB: + """Markiert Onboarding als abgeschlossen.""" + context = self.get_or_create(teacher_id) + context.onboarding_completed = True + context.macro_phase = MacroPhaseEnum.SCHULJAHRESSTART + self.db.commit() + self.db.refresh(context) + return context + + def update_flags( + self, + teacher_id: str, + has_classes: bool = None, + has_schedule: bool = None, + is_exam_period: bool = None, + is_before_holidays: bool = None, + ) -> TeacherContextDB: + """Aktualisiert die Status-Flags eines Kontexts.""" + context = self.get_or_create(teacher_id) + + if has_classes is not None: + context.has_classes = has_classes + if has_schedule is not None: + context.has_schedule = has_schedule + if is_exam_period is not None: + context.is_exam_period = is_exam_period + if is_before_holidays is not None: + context.is_before_holidays = is_before_holidays + + self.db.commit() + self.db.refresh(context) + return context + + def to_dict(self, context: TeacherContextDB) -> Dict[str, Any]: + """Konvertiert DB-Model zu Dictionary.""" + return { + "id": context.id, + "teacher_id": context.teacher_id, + "school": { + "federal_state": context.federal_state, + "federal_state_name": FEDERAL_STATES.get(context.federal_state, ""), + "school_type": context.school_type, + "school_type_name": SCHOOL_TYPES.get(context.school_type, ""), + }, + "school_year": { + "id": context.schoolyear, + "start": context.schoolyear_start.isoformat() if context.schoolyear_start else None, + "current_week": context.current_week, + }, + "macro_phase": { + "id": context.macro_phase.value, + "label": self._get_phase_label(context.macro_phase), + }, + "flags": { + "onboarding_completed": context.onboarding_completed, + "has_classes": context.has_classes, + "has_schedule": context.has_schedule, + "is_exam_period": context.is_exam_period, + "is_before_holidays": context.is_before_holidays, + }, + "created_at": context.created_at.isoformat() if context.created_at else None, + "updated_at": context.updated_at.isoformat() if context.updated_at else None, + } + + def _get_phase_label(self, phase: MacroPhaseEnum) -> str: + """Gibt den Anzeigenamen einer Makro-Phase zurueck.""" + labels = { + MacroPhaseEnum.ONBOARDING: "Einrichtung", + MacroPhaseEnum.SCHULJAHRESSTART: "Schuljahresstart", + MacroPhaseEnum.UNTERRICHTSAUFBAU: "Unterrichtsaufbau", + MacroPhaseEnum.LEISTUNGSPHASE_1: "Leistungsphase 1", + MacroPhaseEnum.HALBJAHRESABSCHLUSS: "Halbjahresabschluss", + MacroPhaseEnum.LEISTUNGSPHASE_2: "Leistungsphase 2", + MacroPhaseEnum.JAHRESABSCHLUSS: "Jahresabschluss", + } + return labels.get(phase, phase.value) + + +class SchoolyearEventRepository: + """Repository fuer Schuljahr-Events (Phase 8).""" + + def __init__(self, db: DBSession): + self.db = db + + def create( + self, + teacher_id: str, + title: str, + start_date: datetime, + event_type: str = "other", + end_date: datetime = None, + class_id: str = None, + subject: str = None, + description: str = "", + needs_preparation: bool = True, + reminder_days_before: int = 7, + extra_data: Dict[str, Any] = None, + ) -> SchoolyearEventDB: + """Erstellt ein neues Schuljahr-Event.""" + from uuid import uuid4 + event = SchoolyearEventDB( + id=str(uuid4()), + teacher_id=teacher_id, + title=title, + event_type=EventTypeEnum(event_type), + start_date=start_date, + end_date=end_date, + class_id=class_id, + subject=subject, + description=description, + needs_preparation=needs_preparation, + reminder_days_before=reminder_days_before, + extra_data=extra_data or {}, + ) + self.db.add(event) + self.db.commit() + self.db.refresh(event) + return event + + def get_by_id(self, event_id: str) -> Optional[SchoolyearEventDB]: + """Holt ein Event nach ID.""" + return self.db.query(SchoolyearEventDB).filter( + SchoolyearEventDB.id == event_id + ).first() + + def get_by_teacher( + self, + teacher_id: str, + status: str = None, + event_type: str = None, + limit: int = 50, + ) -> List[SchoolyearEventDB]: + """Holt Events eines Lehrers.""" + query = self.db.query(SchoolyearEventDB).filter( + SchoolyearEventDB.teacher_id == teacher_id + ) + if status: + query = query.filter(SchoolyearEventDB.status == EventStatusEnum(status)) + if event_type: + query = query.filter(SchoolyearEventDB.event_type == EventTypeEnum(event_type)) + + return query.order_by(SchoolyearEventDB.start_date).limit(limit).all() + + def get_upcoming( + self, + teacher_id: str, + days: int = 30, + limit: int = 10, + ) -> List[SchoolyearEventDB]: + """Holt anstehende Events der naechsten X Tage.""" + from datetime import timedelta + now = datetime.utcnow() + end = now + timedelta(days=days) + + return self.db.query(SchoolyearEventDB).filter( + SchoolyearEventDB.teacher_id == teacher_id, + SchoolyearEventDB.start_date >= now, + SchoolyearEventDB.start_date <= end, + SchoolyearEventDB.status != EventStatusEnum.CANCELLED, + ).order_by(SchoolyearEventDB.start_date).limit(limit).all() + + def update_status( + self, + event_id: str, + status: str, + preparation_done: bool = None, + ) -> Optional[SchoolyearEventDB]: + """Aktualisiert den Status eines Events.""" + event = self.get_by_id(event_id) + if not event: + return None + + event.status = EventStatusEnum(status) + if preparation_done is not None: + event.preparation_done = preparation_done + + self.db.commit() + self.db.refresh(event) + return event + + def delete(self, event_id: str) -> bool: + """Loescht ein Event.""" + event = self.get_by_id(event_id) + if not event: + return False + self.db.delete(event) + self.db.commit() + return True + + def to_dict(self, event: SchoolyearEventDB) -> Dict[str, Any]: + """Konvertiert DB-Model zu Dictionary.""" + return { + "id": event.id, + "teacher_id": event.teacher_id, + "event_type": event.event_type.value, + "title": event.title, + "description": event.description, + "start_date": event.start_date.isoformat() if event.start_date else None, + "end_date": event.end_date.isoformat() if event.end_date else None, + "class_id": event.class_id, + "subject": event.subject, + "status": event.status.value, + "needs_preparation": event.needs_preparation, + "preparation_done": event.preparation_done, + "reminder_days_before": event.reminder_days_before, + "extra_data": event.extra_data, + "created_at": event.created_at.isoformat() if event.created_at else None, + } + + +class RecurringRoutineRepository: + """Repository fuer wiederkehrende Routinen (Phase 8).""" + + def __init__(self, db: DBSession): + self.db = db + + def create( + self, + teacher_id: str, + title: str, + routine_type: str = "other", + recurrence_pattern: str = "weekly", + day_of_week: int = None, + day_of_month: int = None, + time_of_day: str = None, # Format: "14:00" + duration_minutes: int = 60, + description: str = "", + valid_from: datetime = None, + valid_until: datetime = None, + ) -> RecurringRoutineDB: + """Erstellt eine neue wiederkehrende Routine.""" + from uuid import uuid4 + from datetime import time as dt_time + + time_obj = None + if time_of_day: + parts = time_of_day.split(":") + time_obj = dt_time(int(parts[0]), int(parts[1])) + + routine = RecurringRoutineDB( + id=str(uuid4()), + teacher_id=teacher_id, + title=title, + routine_type=RoutineTypeEnum(routine_type), + recurrence_pattern=RecurrencePatternEnum(recurrence_pattern), + day_of_week=day_of_week, + day_of_month=day_of_month, + time_of_day=time_obj, + duration_minutes=duration_minutes, + description=description, + valid_from=valid_from, + valid_until=valid_until, + ) + self.db.add(routine) + self.db.commit() + self.db.refresh(routine) + return routine + + def get_by_id(self, routine_id: str) -> Optional[RecurringRoutineDB]: + """Holt eine Routine nach ID.""" + return self.db.query(RecurringRoutineDB).filter( + RecurringRoutineDB.id == routine_id + ).first() + + def get_by_teacher( + self, + teacher_id: str, + is_active: bool = True, + routine_type: str = None, + ) -> List[RecurringRoutineDB]: + """Holt Routinen eines Lehrers.""" + query = self.db.query(RecurringRoutineDB).filter( + RecurringRoutineDB.teacher_id == teacher_id + ) + if is_active is not None: + query = query.filter(RecurringRoutineDB.is_active == is_active) + if routine_type: + query = query.filter(RecurringRoutineDB.routine_type == RoutineTypeEnum(routine_type)) + + return query.all() + + def get_today(self, teacher_id: str) -> List[RecurringRoutineDB]: + """Holt Routinen die heute stattfinden.""" + today = datetime.utcnow() + day_of_week = today.weekday() # 0 = Montag + day_of_month = today.day + + routines = self.get_by_teacher(teacher_id, is_active=True) + today_routines = [] + + for routine in routines: + if routine.recurrence_pattern == RecurrencePatternEnum.DAILY: + today_routines.append(routine) + elif routine.recurrence_pattern == RecurrencePatternEnum.WEEKLY: + if routine.day_of_week == day_of_week: + today_routines.append(routine) + elif routine.recurrence_pattern == RecurrencePatternEnum.BIWEEKLY: + # Vereinfacht: Pruefen ob Tag passt (echte Logik braucht Startdatum) + if routine.day_of_week == day_of_week: + today_routines.append(routine) + elif routine.recurrence_pattern == RecurrencePatternEnum.MONTHLY: + if routine.day_of_month == day_of_month: + today_routines.append(routine) + + return today_routines + + def update( + self, + routine_id: str, + title: str = None, + is_active: bool = None, + day_of_week: int = None, + time_of_day: str = None, + ) -> Optional[RecurringRoutineDB]: + """Aktualisiert eine Routine.""" + routine = self.get_by_id(routine_id) + if not routine: + return None + + if title is not None: + routine.title = title + if is_active is not None: + routine.is_active = is_active + if day_of_week is not None: + routine.day_of_week = day_of_week + if time_of_day is not None: + from datetime import time as dt_time + parts = time_of_day.split(":") + routine.time_of_day = dt_time(int(parts[0]), int(parts[1])) + + self.db.commit() + self.db.refresh(routine) + return routine + + def delete(self, routine_id: str) -> bool: + """Loescht eine Routine.""" + routine = self.get_by_id(routine_id) + if not routine: + return False + self.db.delete(routine) + self.db.commit() + return True + + def to_dict(self, routine: RecurringRoutineDB) -> Dict[str, Any]: + """Konvertiert DB-Model zu Dictionary.""" + return { + "id": routine.id, + "teacher_id": routine.teacher_id, + "routine_type": routine.routine_type.value, + "title": routine.title, + "description": routine.description, + "recurrence_pattern": routine.recurrence_pattern.value, + "day_of_week": routine.day_of_week, + "day_of_month": routine.day_of_month, + "time_of_day": routine.time_of_day.isoformat() if routine.time_of_day else None, + "duration_minutes": routine.duration_minutes, + "is_active": routine.is_active, + "valid_from": routine.valid_from.isoformat() if routine.valid_from else None, + "valid_until": routine.valid_until.isoformat() if routine.valid_until else None, + "created_at": routine.created_at.isoformat() if routine.created_at else None, + } diff --git a/backend-lehrer/classroom_engine/repository_feedback.py b/backend-lehrer/classroom_engine/repository_feedback.py new file mode 100644 index 0000000..192b3e2 --- /dev/null +++ b/backend-lehrer/classroom_engine/repository_feedback.py @@ -0,0 +1,182 @@ +""" +Teacher Feedback Repository. + +CRUD-Operationen fuer Lehrer-Feedback (Phase 7). +Ermoeglicht Lehrern, Bugs, Feature-Requests und Verbesserungen zu melden. +""" +from datetime import datetime +from typing import Optional, List, Dict, Any + +from sqlalchemy.orm import Session as DBSession + +from .db_models import ( + TeacherFeedbackDB, FeedbackTypeEnum, FeedbackStatusEnum, + FeedbackPriorityEnum, +) + + +class TeacherFeedbackRepository: + """ + Repository fuer Lehrer-Feedback CRUD-Operationen. + + Ermoeglicht Lehrern, Feedback (Bugs, Feature-Requests, Verbesserungen) + direkt aus dem Lehrer-Frontend zu senden. + """ + + def __init__(self, db: DBSession): + self.db = db + + def create( + self, + teacher_id: str, + title: str, + description: str, + feedback_type: str = "improvement", + priority: str = "medium", + teacher_name: str = "", + teacher_email: str = "", + context_url: str = "", + context_phase: str = "", + context_session_id: str = None, + user_agent: str = "", + related_feature: str = None, + ) -> TeacherFeedbackDB: + """Erstellt neues Feedback.""" + import uuid + + db_feedback = TeacherFeedbackDB( + id=str(uuid.uuid4()), + teacher_id=teacher_id, + teacher_name=teacher_name, + teacher_email=teacher_email, + title=title, + description=description, + feedback_type=FeedbackTypeEnum(feedback_type), + priority=FeedbackPriorityEnum(priority), + status=FeedbackStatusEnum.NEW, + related_feature=related_feature, + context_url=context_url, + context_phase=context_phase, + context_session_id=context_session_id, + user_agent=user_agent, + ) + + self.db.add(db_feedback) + self.db.commit() + self.db.refresh(db_feedback) + return db_feedback + + def get_by_id(self, feedback_id: str) -> Optional[TeacherFeedbackDB]: + """Holt Feedback nach ID.""" + return self.db.query(TeacherFeedbackDB).filter( + TeacherFeedbackDB.id == feedback_id + ).first() + + def get_all( + self, + status: str = None, + feedback_type: str = None, + limit: int = 100, + offset: int = 0 + ) -> List[TeacherFeedbackDB]: + """Holt alle Feedbacks mit optionalen Filtern.""" + query = self.db.query(TeacherFeedbackDB) + + if status: + query = query.filter(TeacherFeedbackDB.status == FeedbackStatusEnum(status)) + if feedback_type: + query = query.filter(TeacherFeedbackDB.feedback_type == FeedbackTypeEnum(feedback_type)) + + return query.order_by( + TeacherFeedbackDB.created_at.desc() + ).offset(offset).limit(limit).all() + + def get_by_teacher(self, teacher_id: str, limit: int = 50) -> List[TeacherFeedbackDB]: + """Holt Feedback eines bestimmten Lehrers.""" + return self.db.query(TeacherFeedbackDB).filter( + TeacherFeedbackDB.teacher_id == teacher_id + ).order_by( + TeacherFeedbackDB.created_at.desc() + ).limit(limit).all() + + def update_status( + self, + feedback_id: str, + status: str, + response: str = None, + responded_by: str = None + ) -> Optional[TeacherFeedbackDB]: + """Aktualisiert den Status eines Feedbacks.""" + db_feedback = self.get_by_id(feedback_id) + if not db_feedback: + return None + + db_feedback.status = FeedbackStatusEnum(status) + if response: + db_feedback.response = response + db_feedback.responded_at = datetime.utcnow() + db_feedback.responded_by = responded_by + + self.db.commit() + self.db.refresh(db_feedback) + return db_feedback + + def delete(self, feedback_id: str) -> bool: + """Loescht ein Feedback.""" + db_feedback = self.get_by_id(feedback_id) + if not db_feedback: + return False + + self.db.delete(db_feedback) + self.db.commit() + return True + + def get_stats(self) -> Dict[str, Any]: + """Gibt Statistiken ueber alle Feedbacks zurueck.""" + all_feedback = self.db.query(TeacherFeedbackDB).all() + + stats = { + "total": len(all_feedback), + "by_status": {}, + "by_type": {}, + "by_priority": {}, + } + + for fb in all_feedback: + # By Status + status = fb.status.value + stats["by_status"][status] = stats["by_status"].get(status, 0) + 1 + + # By Type + fb_type = fb.feedback_type.value + stats["by_type"][fb_type] = stats["by_type"].get(fb_type, 0) + 1 + + # By Priority + priority = fb.priority.value + stats["by_priority"][priority] = stats["by_priority"].get(priority, 0) + 1 + + return stats + + def to_dict(self, db_feedback: TeacherFeedbackDB) -> Dict[str, Any]: + """Konvertiert DB-Model zu Dictionary.""" + return { + "id": db_feedback.id, + "teacher_id": db_feedback.teacher_id, + "teacher_name": db_feedback.teacher_name, + "teacher_email": db_feedback.teacher_email, + "title": db_feedback.title, + "description": db_feedback.description, + "feedback_type": db_feedback.feedback_type.value, + "priority": db_feedback.priority.value, + "status": db_feedback.status.value, + "related_feature": db_feedback.related_feature, + "context_url": db_feedback.context_url, + "context_phase": db_feedback.context_phase, + "context_session_id": db_feedback.context_session_id, + "user_agent": db_feedback.user_agent, + "response": db_feedback.response, + "responded_at": db_feedback.responded_at.isoformat() if db_feedback.responded_at else None, + "responded_by": db_feedback.responded_by, + "created_at": db_feedback.created_at.isoformat() if db_feedback.created_at else None, + "updated_at": db_feedback.updated_at.isoformat() if db_feedback.updated_at else None, + } diff --git a/backend-lehrer/classroom_engine/repository_homework.py b/backend-lehrer/classroom_engine/repository_homework.py new file mode 100644 index 0000000..25e07b9 --- /dev/null +++ b/backend-lehrer/classroom_engine/repository_homework.py @@ -0,0 +1,382 @@ +""" +Homework & Material Repositories. + +CRUD-Operationen fuer Hausaufgaben (Feature f20) und Phasen-Materialien (Feature f19). +""" +from datetime import datetime +from typing import Optional, List + +from sqlalchemy.orm import Session as DBSession + +from .db_models import ( + HomeworkDB, HomeworkStatusEnum, PhaseMaterialDB, MaterialTypeEnum, +) +from .models import ( + Homework, HomeworkStatus, PhaseMaterial, MaterialType, +) + + +class HomeworkRepository: + """Repository fuer Hausaufgaben-Tracking (Feature f20).""" + + def __init__(self, db: DBSession): + self.db = db + + # ==================== CREATE ==================== + + def create(self, homework: Homework) -> HomeworkDB: + """Erstellt eine neue Hausaufgabe.""" + db_homework = HomeworkDB( + id=homework.homework_id, + teacher_id=homework.teacher_id, + class_id=homework.class_id, + subject=homework.subject, + title=homework.title, + description=homework.description, + session_id=homework.session_id, + due_date=homework.due_date, + status=HomeworkStatusEnum(homework.status.value), + ) + self.db.add(db_homework) + self.db.commit() + self.db.refresh(db_homework) + return db_homework + + # ==================== READ ==================== + + def get_by_id(self, homework_id: str) -> Optional[HomeworkDB]: + """Holt eine Hausaufgabe nach ID.""" + return self.db.query(HomeworkDB).filter( + HomeworkDB.id == homework_id + ).first() + + def get_by_teacher( + self, + teacher_id: str, + status: Optional[str] = None, + limit: int = 50 + ) -> List[HomeworkDB]: + """Holt alle Hausaufgaben eines Lehrers.""" + query = self.db.query(HomeworkDB).filter( + HomeworkDB.teacher_id == teacher_id + ) + if status: + query = query.filter(HomeworkDB.status == HomeworkStatusEnum(status)) + return query.order_by( + HomeworkDB.due_date.asc().nullslast(), + HomeworkDB.created_at.desc() + ).limit(limit).all() + + def get_by_class( + self, + class_id: str, + teacher_id: str, + include_completed: bool = False, + limit: int = 20 + ) -> List[HomeworkDB]: + """Holt alle Hausaufgaben einer Klasse.""" + query = self.db.query(HomeworkDB).filter( + HomeworkDB.class_id == class_id, + HomeworkDB.teacher_id == teacher_id + ) + if not include_completed: + query = query.filter(HomeworkDB.status != HomeworkStatusEnum.COMPLETED) + return query.order_by( + HomeworkDB.due_date.asc().nullslast(), + HomeworkDB.created_at.desc() + ).limit(limit).all() + + def get_by_session(self, session_id: str) -> List[HomeworkDB]: + """Holt alle Hausaufgaben einer Session.""" + return self.db.query(HomeworkDB).filter( + HomeworkDB.session_id == session_id + ).order_by(HomeworkDB.created_at.desc()).all() + + def get_pending( + self, + teacher_id: str, + days_ahead: int = 7 + ) -> List[HomeworkDB]: + """Holt anstehende Hausaufgaben der naechsten X Tage.""" + from datetime import timedelta + cutoff = datetime.utcnow() + timedelta(days=days_ahead) + return self.db.query(HomeworkDB).filter( + HomeworkDB.teacher_id == teacher_id, + HomeworkDB.status.in_([HomeworkStatusEnum.ASSIGNED, HomeworkStatusEnum.IN_PROGRESS]), + HomeworkDB.due_date <= cutoff + ).order_by(HomeworkDB.due_date.asc()).all() + + # ==================== UPDATE ==================== + + def update_status( + self, + homework_id: str, + status: HomeworkStatus + ) -> Optional[HomeworkDB]: + """Aktualisiert den Status einer Hausaufgabe.""" + db_homework = self.get_by_id(homework_id) + if not db_homework: + return None + + db_homework.status = HomeworkStatusEnum(status.value) + self.db.commit() + self.db.refresh(db_homework) + return db_homework + + def update(self, homework: Homework) -> Optional[HomeworkDB]: + """Aktualisiert eine Hausaufgabe.""" + db_homework = self.get_by_id(homework.homework_id) + if not db_homework: + return None + + db_homework.title = homework.title + db_homework.description = homework.description + db_homework.due_date = homework.due_date + db_homework.status = HomeworkStatusEnum(homework.status.value) + + self.db.commit() + self.db.refresh(db_homework) + return db_homework + + # ==================== DELETE ==================== + + def delete(self, homework_id: str) -> bool: + """Loescht eine Hausaufgabe.""" + db_homework = self.get_by_id(homework_id) + if not db_homework: + return False + + self.db.delete(db_homework) + self.db.commit() + return True + + # ==================== CONVERSION ==================== + + def to_dataclass(self, db_homework: HomeworkDB) -> Homework: + """Konvertiert DB-Model zu Dataclass.""" + return Homework( + homework_id=db_homework.id, + teacher_id=db_homework.teacher_id, + class_id=db_homework.class_id, + subject=db_homework.subject, + title=db_homework.title, + description=db_homework.description or "", + session_id=db_homework.session_id, + due_date=db_homework.due_date, + status=HomeworkStatus(db_homework.status.value), + created_at=db_homework.created_at, + updated_at=db_homework.updated_at, + ) + + +class MaterialRepository: + """Repository fuer Phasen-Materialien (Feature f19).""" + + def __init__(self, db: DBSession): + self.db = db + + # ==================== CREATE ==================== + + def create(self, material: PhaseMaterial) -> PhaseMaterialDB: + """Erstellt ein neues Material.""" + db_material = PhaseMaterialDB( + id=material.material_id, + teacher_id=material.teacher_id, + title=material.title, + material_type=MaterialTypeEnum(material.material_type.value), + url=material.url, + description=material.description, + phase=material.phase, + subject=material.subject, + grade_level=material.grade_level, + tags=material.tags, + is_public=material.is_public, + usage_count=material.usage_count, + session_id=material.session_id, + ) + self.db.add(db_material) + self.db.commit() + self.db.refresh(db_material) + return db_material + + # ==================== READ ==================== + + def get_by_id(self, material_id: str) -> Optional[PhaseMaterialDB]: + """Holt ein Material nach ID.""" + return self.db.query(PhaseMaterialDB).filter( + PhaseMaterialDB.id == material_id + ).first() + + def get_by_teacher( + self, + teacher_id: str, + phase: Optional[str] = None, + subject: Optional[str] = None, + limit: int = 50 + ) -> List[PhaseMaterialDB]: + """Holt alle Materialien eines Lehrers.""" + query = self.db.query(PhaseMaterialDB).filter( + PhaseMaterialDB.teacher_id == teacher_id + ) + if phase: + query = query.filter(PhaseMaterialDB.phase == phase) + if subject: + query = query.filter(PhaseMaterialDB.subject == subject) + + return query.order_by( + PhaseMaterialDB.usage_count.desc(), + PhaseMaterialDB.created_at.desc() + ).limit(limit).all() + + def get_by_phase( + self, + phase: str, + teacher_id: str, + include_public: bool = True + ) -> List[PhaseMaterialDB]: + """Holt alle Materialien fuer eine bestimmte Phase.""" + if include_public: + return self.db.query(PhaseMaterialDB).filter( + PhaseMaterialDB.phase == phase, + (PhaseMaterialDB.teacher_id == teacher_id) | + (PhaseMaterialDB.is_public == True) + ).order_by( + PhaseMaterialDB.usage_count.desc() + ).all() + else: + return self.db.query(PhaseMaterialDB).filter( + PhaseMaterialDB.phase == phase, + PhaseMaterialDB.teacher_id == teacher_id + ).order_by( + PhaseMaterialDB.created_at.desc() + ).all() + + def get_by_session(self, session_id: str) -> List[PhaseMaterialDB]: + """Holt alle Materialien einer Session.""" + return self.db.query(PhaseMaterialDB).filter( + PhaseMaterialDB.session_id == session_id + ).order_by(PhaseMaterialDB.phase, PhaseMaterialDB.created_at).all() + + def get_public_materials( + self, + phase: Optional[str] = None, + subject: Optional[str] = None, + limit: int = 20 + ) -> List[PhaseMaterialDB]: + """Holt oeffentliche Materialien.""" + query = self.db.query(PhaseMaterialDB).filter( + PhaseMaterialDB.is_public == True + ) + if phase: + query = query.filter(PhaseMaterialDB.phase == phase) + if subject: + query = query.filter(PhaseMaterialDB.subject == subject) + + return query.order_by( + PhaseMaterialDB.usage_count.desc() + ).limit(limit).all() + + def search_by_tags( + self, + tags: List[str], + teacher_id: Optional[str] = None + ) -> List[PhaseMaterialDB]: + """Sucht Materialien nach Tags.""" + query = self.db.query(PhaseMaterialDB) + if teacher_id: + query = query.filter( + (PhaseMaterialDB.teacher_id == teacher_id) | + (PhaseMaterialDB.is_public == True) + ) + else: + query = query.filter(PhaseMaterialDB.is_public == True) + + # Filter by tags - vereinfachte Implementierung + results = [] + for material in query.all(): + if material.tags and any(tag in material.tags for tag in tags): + results.append(material) + return results[:50] + + # ==================== UPDATE ==================== + + def update(self, material: PhaseMaterial) -> Optional[PhaseMaterialDB]: + """Aktualisiert ein Material.""" + db_material = self.get_by_id(material.material_id) + if not db_material: + return None + + db_material.title = material.title + db_material.material_type = MaterialTypeEnum(material.material_type.value) + db_material.url = material.url + db_material.description = material.description + db_material.phase = material.phase + db_material.subject = material.subject + db_material.grade_level = material.grade_level + db_material.tags = material.tags + db_material.is_public = material.is_public + + self.db.commit() + self.db.refresh(db_material) + return db_material + + def increment_usage(self, material_id: str) -> Optional[PhaseMaterialDB]: + """Erhoeht den Usage-Counter eines Materials.""" + db_material = self.get_by_id(material_id) + if not db_material: + return None + + db_material.usage_count += 1 + self.db.commit() + self.db.refresh(db_material) + return db_material + + def attach_to_session( + self, + material_id: str, + session_id: str + ) -> Optional[PhaseMaterialDB]: + """Verknuepft ein Material mit einer Session.""" + db_material = self.get_by_id(material_id) + if not db_material: + return None + + db_material.session_id = session_id + db_material.usage_count += 1 + self.db.commit() + self.db.refresh(db_material) + return db_material + + # ==================== DELETE ==================== + + def delete(self, material_id: str) -> bool: + """Loescht ein Material.""" + db_material = self.get_by_id(material_id) + if not db_material: + return False + + self.db.delete(db_material) + self.db.commit() + return True + + # ==================== CONVERSION ==================== + + def to_dataclass(self, db_material: PhaseMaterialDB) -> PhaseMaterial: + """Konvertiert DB-Model zu Dataclass.""" + return PhaseMaterial( + material_id=db_material.id, + teacher_id=db_material.teacher_id, + title=db_material.title, + material_type=MaterialType(db_material.material_type.value), + url=db_material.url, + description=db_material.description or "", + phase=db_material.phase, + subject=db_material.subject or "", + grade_level=db_material.grade_level or "", + tags=db_material.tags or [], + is_public=db_material.is_public, + usage_count=db_material.usage_count, + session_id=db_material.session_id, + created_at=db_material.created_at, + updated_at=db_material.updated_at, + ) diff --git a/backend-lehrer/classroom_engine/repository_reflection.py b/backend-lehrer/classroom_engine/repository_reflection.py new file mode 100644 index 0000000..159fb5f --- /dev/null +++ b/backend-lehrer/classroom_engine/repository_reflection.py @@ -0,0 +1,315 @@ +""" +Reflection & Analytics Repositories. + +CRUD-Operationen fuer Lesson-Reflections und Analytics-Abfragen (Phase 5). +""" +from datetime import datetime +from typing import Optional, List, Dict, Any + +from sqlalchemy.orm import Session as DBSession + +from .db_models import LessonSessionDB, LessonPhaseEnum, LessonReflectionDB +from .analytics import ( + LessonReflection, SessionSummary, TeacherAnalytics, AnalyticsCalculator, +) + + +class ReflectionRepository: + """Repository fuer LessonReflection CRUD-Operationen.""" + + def __init__(self, db: DBSession): + self.db = db + + # ==================== CREATE ==================== + + def create(self, reflection: LessonReflection) -> LessonReflectionDB: + """Erstellt eine neue Reflection.""" + db_reflection = LessonReflectionDB( + id=reflection.reflection_id, + session_id=reflection.session_id, + teacher_id=reflection.teacher_id, + notes=reflection.notes, + overall_rating=reflection.overall_rating, + what_worked=reflection.what_worked, + improvements=reflection.improvements, + notes_for_next_lesson=reflection.notes_for_next_lesson, + ) + self.db.add(db_reflection) + self.db.commit() + self.db.refresh(db_reflection) + return db_reflection + + # ==================== READ ==================== + + def get_by_id(self, reflection_id: str) -> Optional[LessonReflectionDB]: + """Holt eine Reflection nach ID.""" + return self.db.query(LessonReflectionDB).filter( + LessonReflectionDB.id == reflection_id + ).first() + + def get_by_session(self, session_id: str) -> Optional[LessonReflectionDB]: + """Holt die Reflection einer Session.""" + return self.db.query(LessonReflectionDB).filter( + LessonReflectionDB.session_id == session_id + ).first() + + def get_by_teacher( + self, + teacher_id: str, + limit: int = 20, + offset: int = 0 + ) -> List[LessonReflectionDB]: + """Holt alle Reflections eines Lehrers.""" + return self.db.query(LessonReflectionDB).filter( + LessonReflectionDB.teacher_id == teacher_id + ).order_by( + LessonReflectionDB.created_at.desc() + ).offset(offset).limit(limit).all() + + # ==================== UPDATE ==================== + + def update(self, reflection: LessonReflection) -> Optional[LessonReflectionDB]: + """Aktualisiert eine Reflection.""" + db_reflection = self.get_by_id(reflection.reflection_id) + if not db_reflection: + return None + + db_reflection.notes = reflection.notes + db_reflection.overall_rating = reflection.overall_rating + db_reflection.what_worked = reflection.what_worked + db_reflection.improvements = reflection.improvements + db_reflection.notes_for_next_lesson = reflection.notes_for_next_lesson + + self.db.commit() + self.db.refresh(db_reflection) + return db_reflection + + # ==================== DELETE ==================== + + def delete(self, reflection_id: str) -> bool: + """Loescht eine Reflection.""" + db_reflection = self.get_by_id(reflection_id) + if not db_reflection: + return False + + self.db.delete(db_reflection) + self.db.commit() + return True + + # ==================== CONVERSION ==================== + + def to_dataclass(self, db_reflection: LessonReflectionDB) -> LessonReflection: + """Konvertiert DB-Model zu Dataclass.""" + return LessonReflection( + reflection_id=db_reflection.id, + session_id=db_reflection.session_id, + teacher_id=db_reflection.teacher_id, + notes=db_reflection.notes or "", + overall_rating=db_reflection.overall_rating, + what_worked=db_reflection.what_worked or [], + improvements=db_reflection.improvements or [], + notes_for_next_lesson=db_reflection.notes_for_next_lesson or "", + created_at=db_reflection.created_at, + updated_at=db_reflection.updated_at, + ) + + +class AnalyticsRepository: + """Repository fuer Analytics-Abfragen.""" + + def __init__(self, db: DBSession): + self.db = db + + def get_session_summary(self, session_id: str) -> Optional[SessionSummary]: + """ + Berechnet die Summary einer abgeschlossenen Session. + + Args: + session_id: ID der Session + + Returns: + SessionSummary oder None wenn Session nicht gefunden + """ + db_session = self.db.query(LessonSessionDB).filter( + LessonSessionDB.id == session_id + ).first() + + if not db_session: + return None + + # Session-Daten zusammenstellen + session_data = { + "session_id": db_session.id, + "teacher_id": db_session.teacher_id, + "class_id": db_session.class_id, + "subject": db_session.subject, + "topic": db_session.topic, + "lesson_started_at": db_session.lesson_started_at, + "lesson_ended_at": db_session.lesson_ended_at, + "phase_durations": db_session.phase_durations or {}, + } + + # Phase History aus DB oder JSON + phase_history = db_session.phase_history or [] + + # Summary berechnen + return AnalyticsCalculator.calculate_session_summary( + session_data, phase_history + ) + + def get_teacher_analytics( + self, + teacher_id: str, + period_start: Optional[datetime] = None, + period_end: Optional[datetime] = None + ) -> TeacherAnalytics: + """ + Berechnet aggregierte Statistiken fuer einen Lehrer. + + Args: + teacher_id: ID des Lehrers + period_start: Beginn des Zeitraums (default: 30 Tage zurueck) + period_end: Ende des Zeitraums (default: jetzt) + + Returns: + TeacherAnalytics mit aggregierten Statistiken + """ + from datetime import timedelta + + if not period_end: + period_end = datetime.utcnow() + if not period_start: + period_start = period_end - timedelta(days=30) + + # Sessions im Zeitraum abfragen + sessions_query = self.db.query(LessonSessionDB).filter( + LessonSessionDB.teacher_id == teacher_id, + LessonSessionDB.lesson_started_at >= period_start, + LessonSessionDB.lesson_started_at <= period_end + ).all() + + # Sessions zu Dictionaries konvertieren + sessions_data = [] + for db_session in sessions_query: + sessions_data.append({ + "session_id": db_session.id, + "teacher_id": db_session.teacher_id, + "class_id": db_session.class_id, + "subject": db_session.subject, + "topic": db_session.topic, + "lesson_started_at": db_session.lesson_started_at, + "lesson_ended_at": db_session.lesson_ended_at, + "phase_durations": db_session.phase_durations or {}, + "phase_history": db_session.phase_history or [], + }) + + return AnalyticsCalculator.calculate_teacher_analytics( + sessions_data, period_start, period_end + ) + + def get_phase_duration_trends( + self, + teacher_id: str, + phase: str, + limit: int = 20 + ) -> List[Dict[str, Any]]: + """ + Gibt die Dauer-Trends fuer eine bestimmte Phase zurueck. + + Args: + teacher_id: ID des Lehrers + phase: Phasen-ID (einstieg, erarbeitung, etc.) + limit: Max Anzahl der Datenpunkte + + Returns: + Liste von Datenpunkten [{date, planned, actual, difference}] + """ + sessions = self.db.query(LessonSessionDB).filter( + LessonSessionDB.teacher_id == teacher_id, + LessonSessionDB.current_phase == LessonPhaseEnum.ENDED + ).order_by( + LessonSessionDB.lesson_ended_at.desc() + ).limit(limit).all() + + trends = [] + for db_session in sessions: + history = db_session.phase_history or [] + for entry in history: + if entry.get("phase") == phase: + planned = (db_session.phase_durations or {}).get(phase, 0) * 60 + actual = entry.get("duration_seconds", 0) or 0 + trends.append({ + "date": db_session.lesson_started_at.isoformat() if db_session.lesson_started_at else None, + "session_id": db_session.id, + "subject": db_session.subject, + "planned_seconds": planned, + "actual_seconds": actual, + "difference_seconds": actual - planned, + }) + break + + return list(reversed(trends)) # Chronologisch sortieren + + def get_overtime_analysis( + self, + teacher_id: str, + limit: int = 30 + ) -> Dict[str, Any]: + """ + Analysiert Overtime-Muster. + + Args: + teacher_id: ID des Lehrers + limit: Anzahl der zu analysierenden Sessions + + Returns: + Dict mit Overtime-Statistiken pro Phase + """ + sessions = self.db.query(LessonSessionDB).filter( + LessonSessionDB.teacher_id == teacher_id, + LessonSessionDB.current_phase == LessonPhaseEnum.ENDED + ).order_by( + LessonSessionDB.lesson_ended_at.desc() + ).limit(limit).all() + + phase_overtime: Dict[str, List[int]] = { + "einstieg": [], + "erarbeitung": [], + "sicherung": [], + "transfer": [], + "reflexion": [], + } + + for db_session in sessions: + history = db_session.phase_history or [] + phase_durations = db_session.phase_durations or {} + + for entry in history: + phase = entry.get("phase", "") + if phase in phase_overtime: + planned = phase_durations.get(phase, 0) * 60 + actual = entry.get("duration_seconds", 0) or 0 + overtime = max(0, actual - planned) + phase_overtime[phase].append(overtime) + + # Statistiken berechnen + result = {} + for phase, overtimes in phase_overtime.items(): + if overtimes: + result[phase] = { + "count": len([o for o in overtimes if o > 0]), + "total": len(overtimes), + "avg_overtime_seconds": sum(overtimes) / len(overtimes), + "max_overtime_seconds": max(overtimes), + "overtime_percentage": len([o for o in overtimes if o > 0]) / len(overtimes) * 100, + } + else: + result[phase] = { + "count": 0, + "total": 0, + "avg_overtime_seconds": 0, + "max_overtime_seconds": 0, + "overtime_percentage": 0, + } + + return result diff --git a/backend-lehrer/classroom_engine/repository_session.py b/backend-lehrer/classroom_engine/repository_session.py new file mode 100644 index 0000000..1165d33 --- /dev/null +++ b/backend-lehrer/classroom_engine/repository_session.py @@ -0,0 +1,248 @@ +""" +Session & Teacher Settings Repositories. + +CRUD-Operationen fuer LessonSessions und Lehrer-Einstellungen. +""" +from typing import Optional, List, Dict + +from sqlalchemy.orm import Session as DBSession + +from .db_models import ( + LessonSessionDB, LessonPhaseEnum, TeacherSettingsDB, +) +from .models import ( + LessonSession, LessonPhase, get_default_durations, +) + + +class SessionRepository: + """Repository fuer LessonSession CRUD-Operationen.""" + + def __init__(self, db: DBSession): + self.db = db + + # ==================== CREATE ==================== + + def create(self, session: LessonSession) -> LessonSessionDB: + """ + Erstellt eine neue Session in der Datenbank. + + Args: + session: LessonSession Dataclass + + Returns: + LessonSessionDB Model + """ + db_session = LessonSessionDB( + id=session.session_id, + teacher_id=session.teacher_id, + class_id=session.class_id, + subject=session.subject, + topic=session.topic, + current_phase=LessonPhaseEnum(session.current_phase.value), + is_paused=session.is_paused, + lesson_started_at=session.lesson_started_at, + lesson_ended_at=session.lesson_ended_at, + phase_started_at=session.phase_started_at, + pause_started_at=session.pause_started_at, + total_paused_seconds=session.total_paused_seconds, + phase_durations=session.phase_durations, + phase_history=session.phase_history, + notes=session.notes, + homework=session.homework, + ) + self.db.add(db_session) + self.db.commit() + self.db.refresh(db_session) + return db_session + + # ==================== READ ==================== + + def get_by_id(self, session_id: str) -> Optional[LessonSessionDB]: + """Holt eine Session nach ID.""" + return self.db.query(LessonSessionDB).filter( + LessonSessionDB.id == session_id + ).first() + + def get_active_by_teacher(self, teacher_id: str) -> List[LessonSessionDB]: + """Holt alle aktiven Sessions eines Lehrers.""" + return self.db.query(LessonSessionDB).filter( + LessonSessionDB.teacher_id == teacher_id, + LessonSessionDB.current_phase != LessonPhaseEnum.ENDED + ).all() + + def get_history_by_teacher( + self, + teacher_id: str, + limit: int = 20, + offset: int = 0 + ) -> List[LessonSessionDB]: + """Holt Session-History eines Lehrers (Feature f17).""" + return self.db.query(LessonSessionDB).filter( + LessonSessionDB.teacher_id == teacher_id, + LessonSessionDB.current_phase == LessonPhaseEnum.ENDED + ).order_by( + LessonSessionDB.lesson_ended_at.desc() + ).offset(offset).limit(limit).all() + + def get_by_class( + self, + class_id: str, + limit: int = 20 + ) -> List[LessonSessionDB]: + """Holt Sessions einer Klasse.""" + return self.db.query(LessonSessionDB).filter( + LessonSessionDB.class_id == class_id + ).order_by( + LessonSessionDB.created_at.desc() + ).limit(limit).all() + + # ==================== UPDATE ==================== + + def update(self, session: LessonSession) -> Optional[LessonSessionDB]: + """ + Aktualisiert eine bestehende Session. + + Args: + session: LessonSession Dataclass mit aktualisierten Werten + + Returns: + Aktualisierte LessonSessionDB oder None + """ + db_session = self.get_by_id(session.session_id) + if not db_session: + return None + + db_session.current_phase = LessonPhaseEnum(session.current_phase.value) + db_session.is_paused = session.is_paused + db_session.lesson_started_at = session.lesson_started_at + db_session.lesson_ended_at = session.lesson_ended_at + db_session.phase_started_at = session.phase_started_at + db_session.pause_started_at = session.pause_started_at + db_session.total_paused_seconds = session.total_paused_seconds + db_session.phase_durations = session.phase_durations + db_session.phase_history = session.phase_history + db_session.notes = session.notes + db_session.homework = session.homework + + self.db.commit() + self.db.refresh(db_session) + return db_session + + def update_notes( + self, + session_id: str, + notes: str, + homework: str + ) -> Optional[LessonSessionDB]: + """Aktualisiert nur Notizen und Hausaufgaben.""" + db_session = self.get_by_id(session_id) + if not db_session: + return None + + db_session.notes = notes + db_session.homework = homework + + self.db.commit() + self.db.refresh(db_session) + return db_session + + # ==================== DELETE ==================== + + def delete(self, session_id: str) -> bool: + """Loescht eine Session.""" + db_session = self.get_by_id(session_id) + if not db_session: + return False + + self.db.delete(db_session) + self.db.commit() + return True + + # ==================== CONVERSION ==================== + + def to_dataclass(self, db_session: LessonSessionDB) -> LessonSession: + """ + Konvertiert DB-Model zu Dataclass. + + Args: + db_session: LessonSessionDB Model + + Returns: + LessonSession Dataclass + """ + return LessonSession( + session_id=db_session.id, + teacher_id=db_session.teacher_id, + class_id=db_session.class_id, + subject=db_session.subject, + topic=db_session.topic, + current_phase=LessonPhase(db_session.current_phase.value), + phase_started_at=db_session.phase_started_at, + lesson_started_at=db_session.lesson_started_at, + lesson_ended_at=db_session.lesson_ended_at, + is_paused=db_session.is_paused, + pause_started_at=db_session.pause_started_at, + total_paused_seconds=db_session.total_paused_seconds or 0, + phase_durations=db_session.phase_durations or get_default_durations(), + phase_history=db_session.phase_history or [], + notes=db_session.notes or "", + homework=db_session.homework or "", + ) + + +class TeacherSettingsRepository: + """Repository fuer Lehrer-Einstellungen (Feature f16).""" + + def __init__(self, db: DBSession): + self.db = db + + def get_or_create(self, teacher_id: str) -> TeacherSettingsDB: + """Holt oder erstellt Einstellungen fuer einen Lehrer.""" + settings = self.db.query(TeacherSettingsDB).filter( + TeacherSettingsDB.teacher_id == teacher_id + ).first() + + if not settings: + settings = TeacherSettingsDB( + teacher_id=teacher_id, + default_phase_durations=get_default_durations(), + ) + self.db.add(settings) + self.db.commit() + self.db.refresh(settings) + + return settings + + def update_phase_durations( + self, + teacher_id: str, + durations: Dict[str, int] + ) -> TeacherSettingsDB: + """Aktualisiert die Standard-Phasendauern.""" + settings = self.get_or_create(teacher_id) + settings.default_phase_durations = durations + self.db.commit() + self.db.refresh(settings) + return settings + + def update_preferences( + self, + teacher_id: str, + audio_enabled: Optional[bool] = None, + high_contrast: Optional[bool] = None, + show_statistics: Optional[bool] = None + ) -> TeacherSettingsDB: + """Aktualisiert UI-Praeferenzen.""" + settings = self.get_or_create(teacher_id) + + if audio_enabled is not None: + settings.audio_enabled = audio_enabled + if high_contrast is not None: + settings.high_contrast = high_contrast + if show_statistics is not None: + settings.show_statistics = show_statistics + + self.db.commit() + self.db.refresh(settings) + return settings diff --git a/backend-lehrer/classroom_engine/repository_template.py b/backend-lehrer/classroom_engine/repository_template.py new file mode 100644 index 0000000..e97c16d --- /dev/null +++ b/backend-lehrer/classroom_engine/repository_template.py @@ -0,0 +1,167 @@ +""" +Template Repository. + +CRUD-Operationen fuer Stunden-Vorlagen (Feature f37). +""" +from typing import Optional, List + +from sqlalchemy.orm import Session as DBSession + +from .db_models import LessonTemplateDB +from .models import LessonTemplate, get_default_durations + + +class TemplateRepository: + """Repository fuer Stunden-Vorlagen (Feature f37).""" + + def __init__(self, db: DBSession): + self.db = db + + # ==================== CREATE ==================== + + def create(self, template: LessonTemplate) -> LessonTemplateDB: + """Erstellt eine neue Vorlage.""" + db_template = LessonTemplateDB( + id=template.template_id, + teacher_id=template.teacher_id, + name=template.name, + description=template.description, + subject=template.subject, + grade_level=template.grade_level, + phase_durations=template.phase_durations, + default_topic=template.default_topic, + default_notes=template.default_notes, + is_public=template.is_public, + usage_count=template.usage_count, + ) + self.db.add(db_template) + self.db.commit() + self.db.refresh(db_template) + return db_template + + # ==================== READ ==================== + + def get_by_id(self, template_id: str) -> Optional[LessonTemplateDB]: + """Holt eine Vorlage nach ID.""" + return self.db.query(LessonTemplateDB).filter( + LessonTemplateDB.id == template_id + ).first() + + def get_by_teacher( + self, + teacher_id: str, + include_public: bool = True + ) -> List[LessonTemplateDB]: + """ + Holt alle Vorlagen eines Lehrers. + + Args: + teacher_id: ID des Lehrers + include_public: Auch oeffentliche Vorlagen anderer Lehrer einbeziehen + """ + if include_public: + return self.db.query(LessonTemplateDB).filter( + (LessonTemplateDB.teacher_id == teacher_id) | + (LessonTemplateDB.is_public == True) + ).order_by( + LessonTemplateDB.usage_count.desc() + ).all() + else: + return self.db.query(LessonTemplateDB).filter( + LessonTemplateDB.teacher_id == teacher_id + ).order_by( + LessonTemplateDB.created_at.desc() + ).all() + + def get_public_templates(self, limit: int = 20) -> List[LessonTemplateDB]: + """Holt oeffentliche Vorlagen, sortiert nach Beliebtheit.""" + return self.db.query(LessonTemplateDB).filter( + LessonTemplateDB.is_public == True + ).order_by( + LessonTemplateDB.usage_count.desc() + ).limit(limit).all() + + def get_by_subject( + self, + subject: str, + teacher_id: Optional[str] = None + ) -> List[LessonTemplateDB]: + """Holt Vorlagen fuer ein bestimmtes Fach.""" + query = self.db.query(LessonTemplateDB).filter( + LessonTemplateDB.subject == subject + ) + if teacher_id: + query = query.filter( + (LessonTemplateDB.teacher_id == teacher_id) | + (LessonTemplateDB.is_public == True) + ) + else: + query = query.filter(LessonTemplateDB.is_public == True) + + return query.order_by( + LessonTemplateDB.usage_count.desc() + ).all() + + # ==================== UPDATE ==================== + + def update(self, template: LessonTemplate) -> Optional[LessonTemplateDB]: + """Aktualisiert eine Vorlage.""" + db_template = self.get_by_id(template.template_id) + if not db_template: + return None + + db_template.name = template.name + db_template.description = template.description + db_template.subject = template.subject + db_template.grade_level = template.grade_level + db_template.phase_durations = template.phase_durations + db_template.default_topic = template.default_topic + db_template.default_notes = template.default_notes + db_template.is_public = template.is_public + + self.db.commit() + self.db.refresh(db_template) + return db_template + + def increment_usage(self, template_id: str) -> Optional[LessonTemplateDB]: + """Erhoeht den Usage-Counter einer Vorlage.""" + db_template = self.get_by_id(template_id) + if not db_template: + return None + + db_template.usage_count += 1 + self.db.commit() + self.db.refresh(db_template) + return db_template + + # ==================== DELETE ==================== + + def delete(self, template_id: str) -> bool: + """Loescht eine Vorlage.""" + db_template = self.get_by_id(template_id) + if not db_template: + return False + + self.db.delete(db_template) + self.db.commit() + return True + + # ==================== CONVERSION ==================== + + def to_dataclass(self, db_template: LessonTemplateDB) -> LessonTemplate: + """Konvertiert DB-Model zu Dataclass.""" + return LessonTemplate( + template_id=db_template.id, + teacher_id=db_template.teacher_id, + name=db_template.name, + description=db_template.description or "", + subject=db_template.subject or "", + grade_level=db_template.grade_level or "", + phase_durations=db_template.phase_durations or get_default_durations(), + default_topic=db_template.default_topic or "", + default_notes=db_template.default_notes or "", + is_public=db_template.is_public, + usage_count=db_template.usage_count, + created_at=db_template.created_at, + updated_at=db_template.updated_at, + ) diff --git a/klausur-service/backend/cv_cell_grid.py b/klausur-service/backend/cv_cell_grid.py index 56fd472..466565e 100644 --- a/klausur-service/backend/cv_cell_grid.py +++ b/klausur-service/backend/cv_cell_grid.py @@ -1,1675 +1,60 @@ """ Cell-grid construction (v2 + legacy), vocab conversion, and word-grid OCR. +Re-export hub — all public and private names remain importable from here +for backward compatibility. The actual implementations live in: + + cv_cell_grid_helpers.py — shared helpers (_heal_row_gaps, _is_artifact_row, ...) + cv_cell_grid_build.py — v2 hybrid grid (build_cell_grid_v2, _ocr_cell_crop) + cv_cell_grid_legacy.py — deprecated v1 grid (build_cell_grid, _ocr_single_cell) + cv_cell_grid_streaming.py — streaming variants (build_cell_grid_v2_streaming, ...) + cv_cell_grid_merge.py — row-merging logic (_merge_wrapped_rows, ...) + cv_cell_grid_vocab.py — vocab extraction (_cells_to_vocab_entries, build_word_grid) + Lizenz: Apache 2.0 (kommerziell nutzbar) DATENSCHUTZ: Alle Verarbeitung erfolgt lokal. """ -import logging -import re -import time -from concurrent.futures import ThreadPoolExecutor, as_completed -from typing import Any, Dict, Generator, List, Optional, Tuple - -import numpy as np - -from cv_vocab_types import PageRegion, RowGeometry -from cv_ocr_engines import ( - RAPIDOCR_AVAILABLE, - _RE_ALPHA, - _assign_row_words_to_columns, - _attach_example_sentences, - _clean_cell_text, - _clean_cell_text_lite, - _fix_phonetic_brackets, - _split_comma_entries, - _words_to_reading_order_text, - _words_to_spaced_text, - ocr_region_lighton, - ocr_region_rapid, - ocr_region_trocr, +# --- Helpers --- +from cv_cell_grid_helpers import ( # noqa: F401 + _MIN_WORD_CONF, + _compute_cell_padding, + _ensure_minimum_crop_size, + _heal_row_gaps, + _is_artifact_row, + _select_psm_for_column, ) -logger = logging.getLogger(__name__) - -try: - import cv2 -except ImportError: - cv2 = None # type: ignore[assignment] - -try: - from PIL import Image -except ImportError: - Image = None # type: ignore[assignment,misc] - -# Minimum OCR word confidence to keep (used across multiple functions) -_MIN_WORD_CONF = 30 - -# --------------------------------------------------------------------------- - -def _ocr_cell_crop( - row_idx: int, - col_idx: int, - row: RowGeometry, - col: PageRegion, - ocr_img: np.ndarray, - img_bgr: Optional[np.ndarray], - img_w: int, - img_h: int, - engine_name: str, - lang: str, - lang_map: Dict[str, str], -) -> Dict[str, Any]: - """OCR a single cell by cropping the exact column×row intersection. - - No padding beyond cell boundaries → no neighbour bleeding. - """ - # Display bbox: exact column × row intersection - disp_x = col.x - disp_y = row.y - disp_w = col.width - disp_h = row.height - - # Crop boundaries: add small internal padding (3px each side) to avoid - # clipping characters near column/row edges (e.g. parentheses, descenders). - # Stays within image bounds but may extend slightly beyond strict cell. - # 3px is small enough to avoid neighbour content at typical scan DPI (200-300). - _PAD = 3 - cx = max(0, disp_x - _PAD) - cy = max(0, disp_y - _PAD) - cx2 = min(img_w, disp_x + disp_w + _PAD) - cy2 = min(img_h, disp_y + disp_h + _PAD) - cw = cx2 - cx - ch = cy2 - cy - - empty_cell = { - 'cell_id': f"R{row_idx:02d}_C{col_idx}", - 'row_index': row_idx, - 'col_index': col_idx, - 'col_type': col.type, - 'text': '', - 'confidence': 0.0, - 'bbox_px': {'x': disp_x, 'y': disp_y, 'w': disp_w, 'h': disp_h}, - 'bbox_pct': { - 'x': round(disp_x / img_w * 100, 2) if img_w else 0, - 'y': round(disp_y / img_h * 100, 2) if img_h else 0, - 'w': round(disp_w / img_w * 100, 2) if img_w else 0, - 'h': round(disp_h / img_h * 100, 2) if img_h else 0, - }, - 'ocr_engine': 'cell_crop_v2', - 'is_bold': False, - } - - if cw <= 0 or ch <= 0: - logger.debug("_ocr_cell_crop R%02d_C%d: zero-size crop (%dx%d)", row_idx, col_idx, cw, ch) - return empty_cell - - # --- Pixel-density check: skip truly empty cells --- - if ocr_img is not None: - crop = ocr_img[cy:cy + ch, cx:cx + cw] - if crop.size > 0: - dark_ratio = float(np.count_nonzero(crop < 180)) / crop.size - if dark_ratio < 0.005: - logger.debug("_ocr_cell_crop R%02d_C%d: skip empty (dark_ratio=%.4f, crop=%dx%d)", - row_idx, col_idx, dark_ratio, cw, ch) - return empty_cell - - # --- Prepare crop for OCR --- - cell_lang = lang_map.get(col.type, lang) - psm = _select_psm_for_column(col.type, col.width, row.height) - text = '' - avg_conf = 0.0 - used_engine = 'cell_crop_v2' - - if engine_name in ("trocr-printed", "trocr-handwritten") and img_bgr is not None: - cell_region = PageRegion(type=col.type, x=cx, y=cy, width=cw, height=ch) - words = ocr_region_trocr(img_bgr, cell_region, - handwritten=(engine_name == "trocr-handwritten")) - elif engine_name == "lighton" and img_bgr is not None: - cell_region = PageRegion(type=col.type, x=cx, y=cy, width=cw, height=ch) - words = ocr_region_lighton(img_bgr, cell_region) - elif engine_name == "rapid" and img_bgr is not None: - # Upscale small BGR crops for RapidOCR. - # Cell crops typically have height 35-55px but width >300px. - # _ensure_minimum_crop_size only scales when EITHER dim < min_dim, - # using uniform scale → a 365×54 crop becomes ~1014×150 (scale ~2.78). - # For very short heights (< 80px), force 3× upscale for better OCR - # of small characters like periods, ellipsis, and phonetic symbols. - bgr_crop = img_bgr[cy:cy + ch, cx:cx + cw] - if bgr_crop.size == 0: - words = [] - else: - crop_h, crop_w = bgr_crop.shape[:2] - if crop_h < 80: - # Force 3× upscale for short rows — small chars need more pixels - scale = 3.0 - bgr_up = cv2.resize(bgr_crop, None, fx=scale, fy=scale, - interpolation=cv2.INTER_CUBIC) - else: - bgr_up = _ensure_minimum_crop_size(bgr_crop, min_dim=150, max_scale=3) - up_h, up_w = bgr_up.shape[:2] - scale_x = up_w / max(crop_w, 1) - scale_y = up_h / max(crop_h, 1) - was_scaled = (up_w != crop_w or up_h != crop_h) - logger.debug("_ocr_cell_crop R%02d_C%d: rapid %dx%d -> %dx%d (scale=%.1fx)", - row_idx, col_idx, crop_w, crop_h, up_w, up_h, scale_y) - tmp_region = PageRegion(type=col.type, x=0, y=0, width=up_w, height=up_h) - words = ocr_region_rapid(bgr_up, tmp_region) - # Remap positions back to original image coords - if words and was_scaled: - for w in words: - w['left'] = int(w['left'] / scale_x) + cx - w['top'] = int(w['top'] / scale_y) + cy - w['width'] = int(w['width'] / scale_x) - w['height'] = int(w['height'] / scale_y) - elif words: - for w in words: - w['left'] += cx - w['top'] += cy - else: - # Tesseract: upscale tiny crops for better recognition - if ocr_img is not None: - crop_slice = ocr_img[cy:cy + ch, cx:cx + cw] - upscaled = _ensure_minimum_crop_size(crop_slice) - up_h, up_w = upscaled.shape[:2] - tmp_region = PageRegion(type=col.type, x=0, y=0, width=up_w, height=up_h) - words = ocr_region(upscaled, tmp_region, lang=cell_lang, psm=psm) - # Remap word positions back to original image coordinates - if words and (up_w != cw or up_h != ch): - sx = cw / max(up_w, 1) - sy = ch / max(up_h, 1) - for w in words: - w['left'] = int(w['left'] * sx) + cx - w['top'] = int(w['top'] * sy) + cy - w['width'] = int(w['width'] * sx) - w['height'] = int(w['height'] * sy) - elif words: - for w in words: - w['left'] += cx - w['top'] += cy - else: - words = [] - - # Filter low-confidence words - if words: - words = [w for w in words if w.get('conf', 0) >= _MIN_WORD_CONF] - - if words: - y_tol = max(15, ch) - text = _words_to_reading_order_text(words, y_tolerance_px=y_tol) - avg_conf = round(sum(w['conf'] for w in words) / len(words), 1) - logger.debug("_ocr_cell_crop R%02d_C%d: OCR raw text=%r conf=%.1f nwords=%d crop=%dx%d psm=%s engine=%s", - row_idx, col_idx, text, avg_conf, len(words), cw, ch, psm, engine_name) - else: - logger.debug("_ocr_cell_crop R%02d_C%d: OCR returned NO words (crop=%dx%d psm=%s engine=%s)", - row_idx, col_idx, cw, ch, psm, engine_name) - - # --- PSM 7 fallback for still-empty Tesseract cells --- - if not text.strip() and engine_name == "tesseract" and ocr_img is not None: - crop_slice = ocr_img[cy:cy + ch, cx:cx + cw] - upscaled = _ensure_minimum_crop_size(crop_slice) - up_h, up_w = upscaled.shape[:2] - tmp_region = PageRegion(type=col.type, x=0, y=0, width=up_w, height=up_h) - psm7_words = ocr_region(upscaled, tmp_region, lang=cell_lang, psm=7) - if psm7_words: - psm7_words = [w for w in psm7_words if w.get('conf', 0) >= _MIN_WORD_CONF] - if psm7_words: - p7_text = _words_to_reading_order_text(psm7_words, y_tolerance_px=10) - if p7_text.strip(): - text = p7_text - avg_conf = round( - sum(w['conf'] for w in psm7_words) / len(psm7_words), 1 - ) - used_engine = 'cell_crop_v2_psm7' - # Remap PSM7 word positions back to original image coords - if up_w != cw or up_h != ch: - sx = cw / max(up_w, 1) - sy = ch / max(up_h, 1) - for w in psm7_words: - w['left'] = int(w['left'] * sx) + cx - w['top'] = int(w['top'] * sy) + cy - w['width'] = int(w['width'] * sx) - w['height'] = int(w['height'] * sy) - else: - for w in psm7_words: - w['left'] += cx - w['top'] += cy - words = psm7_words - - # --- Noise filter --- - if text.strip(): - pre_filter = text - text = _clean_cell_text_lite(text) - if not text: - logger.debug("_ocr_cell_crop R%02d_C%d: _clean_cell_text_lite REMOVED %r", - row_idx, col_idx, pre_filter) - avg_conf = 0.0 - - result = dict(empty_cell) - result['text'] = text - result['confidence'] = avg_conf - result['ocr_engine'] = used_engine - - # Store individual word bounding boxes (absolute image coordinates) - # for pixel-accurate overlay positioning in the frontend. - if words and text.strip(): - result['word_boxes'] = [ - { - 'text': w.get('text', ''), - 'left': w['left'], - 'top': w['top'], - 'width': w['width'], - 'height': w['height'], - 'conf': w.get('conf', 0), - } - for w in words - if w.get('text', '').strip() - ] - - return result - - -# Threshold: columns narrower than this (% of image width) use single-cell -# crop OCR instead of full-page word assignment. -# -# Broad columns (>= threshold): Full-page Tesseract word assignment. -# Better for multi-word content (sentences, IPA brackets, punctuation). -# Examples: EN vocabulary, DE translation, example sentences. -# -# Narrow columns (< threshold): Isolated cell-crop OCR. -# Prevents neighbour bleeding from adjacent broad columns. -# Examples: page_ref, marker, numbering columns. -# -# 15% was empirically validated across vocab table scans with 3-5 columns. -# Typical broad columns: 20-40% width. Typical narrow columns: 3-12% width. -# The 15% boundary cleanly separates the two groups. -_NARROW_COL_THRESHOLD_PCT = 15.0 - - -def build_cell_grid_v2( - ocr_img: np.ndarray, - column_regions: List[PageRegion], - row_geometries: List[RowGeometry], - img_w: int, - img_h: int, - lang: str = "eng+deu", - ocr_engine: str = "auto", - img_bgr: Optional[np.ndarray] = None, - skip_heal_gaps: bool = False, -) -> Tuple[List[Dict[str, Any]], List[Dict[str, Any]]]: - """Hybrid Grid: full-page OCR for broad columns, cell-crop for narrow ones. - - Drop-in replacement for build_cell_grid() — same signature & return type. - - Strategy: - - Broad columns (>15% image width): Use pre-assigned full-page Tesseract - words (from row.words). Handles IPA brackets, punctuation, sentence - continuity correctly. - - Narrow columns (<15% image width): Use isolated cell-crop OCR to prevent - neighbour bleeding from adjacent broad columns. - """ - engine_name = "tesseract" - if ocr_engine in ("trocr-printed", "trocr-handwritten", "lighton"): - engine_name = ocr_engine - elif ocr_engine == "rapid" and RAPIDOCR_AVAILABLE: - engine_name = "rapid" - - logger.info(f"build_cell_grid_v2: using OCR engine '{engine_name}' (hybrid mode)") - - # Filter to content rows only - content_rows = [r for r in row_geometries if r.row_type == 'content'] - if not content_rows: - logger.warning("build_cell_grid_v2: no content rows found") - return [], [] - - # Filter phantom rows (word_count=0) and artifact rows - before = len(content_rows) - content_rows = [r for r in content_rows if r.word_count > 0] - skipped = before - len(content_rows) - if skipped > 0: - logger.info(f"build_cell_grid_v2: skipped {skipped} phantom rows (word_count=0)") - if not content_rows: - logger.warning("build_cell_grid_v2: no content rows with words found") - return [], [] - - before_art = len(content_rows) - content_rows = [r for r in content_rows if not _is_artifact_row(r)] - artifact_skipped = before_art - len(content_rows) - if artifact_skipped > 0: - logger.info(f"build_cell_grid_v2: skipped {artifact_skipped} artifact rows") - if not content_rows: - logger.warning("build_cell_grid_v2: no content rows after artifact filtering") - return [], [] - - # Filter columns - _skip_types = {'column_ignore', 'header', 'footer', 'margin_top', - 'margin_bottom', 'margin_left', 'margin_right'} - relevant_cols = [c for c in column_regions if c.type not in _skip_types] - if not relevant_cols: - logger.warning("build_cell_grid_v2: no usable columns found") - return [], [] - - # Heal row gaps — use header/footer boundaries - content_rows.sort(key=lambda r: r.y) - header_rows = [r for r in row_geometries if r.row_type == 'header'] - footer_rows = [r for r in row_geometries if r.row_type == 'footer'] - if header_rows: - top_bound = max(r.y + r.height for r in header_rows) - else: - top_bound = content_rows[0].y - if footer_rows: - bottom_bound = min(r.y for r in footer_rows) - else: - bottom_bound = content_rows[-1].y + content_rows[-1].height - - # skip_heal_gaps: When True, keep cell positions at their exact row geometry - # positions without expanding to fill gaps from removed rows. Useful for - # overlay rendering where pixel-precise positioning matters more than - # full-coverage OCR crops. - if not skip_heal_gaps: - _heal_row_gaps(content_rows, top_bound=top_bound, bottom_bound=bottom_bound) - - relevant_cols.sort(key=lambda c: c.x) - - columns_meta = [ - {'index': ci, 'type': c.type, 'x': c.x, 'width': c.width} - for ci, c in enumerate(relevant_cols) - ] - - lang_map = { - 'column_en': 'eng', - 'column_de': 'deu', - 'column_example': 'eng+deu', - } - - # --- Classify columns as broad vs narrow --- - narrow_col_indices = set() - for ci, col in enumerate(relevant_cols): - col_pct = (col.width / img_w * 100) if img_w > 0 else 0 - if col_pct < _NARROW_COL_THRESHOLD_PCT: - narrow_col_indices.add(ci) - - broad_col_count = len(relevant_cols) - len(narrow_col_indices) - logger.info(f"build_cell_grid_v2: {broad_col_count} broad columns (full-page), " - f"{len(narrow_col_indices)} narrow columns (cell-crop)") - - # --- Phase 1: Broad columns via full-page word assignment --- - cells: List[Dict[str, Any]] = [] - - for row_idx, row in enumerate(content_rows): - # Assign full-page words to columns for this row - col_words = _assign_row_words_to_columns(row, relevant_cols) - - for col_idx, col in enumerate(relevant_cols): - if col_idx not in narrow_col_indices: - # BROAD column: use pre-assigned full-page words - words = col_words.get(col_idx, []) - # Filter low-confidence words - words = [w for w in words if w.get('conf', 0) >= _MIN_WORD_CONF] - - # Single full-width column (box sub-session): preserve spacing - is_single_full_column = ( - len(relevant_cols) == 1 - and img_w > 0 - and relevant_cols[0].width / img_w > 0.9 - ) - - if words: - y_tol = max(15, row.height) - if is_single_full_column: - text = _words_to_spaced_text(words, y_tolerance_px=y_tol) - logger.info(f"R{row_idx:02d}: {len(words)} words, " - f"text={text!r:.100}") - else: - text = _words_to_reading_order_text(words, y_tolerance_px=y_tol) - avg_conf = round(sum(w['conf'] for w in words) / len(words), 1) - else: - text = '' - avg_conf = 0.0 - if is_single_full_column: - logger.info(f"R{row_idx:02d}: 0 words (row has " - f"{row.word_count} total, y={row.y}..{row.y+row.height})") - - # Apply noise filter — but NOT for single-column sub-sessions: - # 1. _clean_cell_text strips trailing non-alpha tokens (e.g. €0.50, - # £1, €2.50) which are valid content in box layouts. - # 2. _clean_cell_text joins tokens with single space, destroying - # the proportional spacing from _words_to_spaced_text. - if not is_single_full_column: - text = _clean_cell_text(text) - - cell = { - 'cell_id': f"R{row_idx:02d}_C{col_idx}", - 'row_index': row_idx, - 'col_index': col_idx, - 'col_type': col.type, - 'text': text, - 'confidence': avg_conf, - 'bbox_px': { - 'x': col.x, 'y': row.y, - 'w': col.width, 'h': row.height, - }, - 'bbox_pct': { - 'x': round(col.x / img_w * 100, 2) if img_w else 0, - 'y': round(row.y / img_h * 100, 2) if img_h else 0, - 'w': round(col.width / img_w * 100, 2) if img_w else 0, - 'h': round(row.height / img_h * 100, 2) if img_h else 0, - }, - 'ocr_engine': 'word_lookup', - 'is_bold': False, - } - # Store word bounding boxes for pixel-accurate overlay - if words and text.strip(): - cell['word_boxes'] = [ - { - 'text': w.get('text', ''), - 'left': w['left'], - 'top': w['top'], - 'width': w['width'], - 'height': w['height'], - 'conf': w.get('conf', 0), - } - for w in words - if w.get('text', '').strip() - ] - cells.append(cell) - - # --- Phase 2: Narrow columns via cell-crop OCR (parallel) --- - narrow_tasks = [] - for row_idx, row in enumerate(content_rows): - for col_idx, col in enumerate(relevant_cols): - if col_idx in narrow_col_indices: - narrow_tasks.append((row_idx, col_idx, row, col)) - - if narrow_tasks: - max_workers = 4 if engine_name == "tesseract" else 2 - with ThreadPoolExecutor(max_workers=max_workers) as pool: - futures = { - pool.submit( - _ocr_cell_crop, - ri, ci, row, col, - ocr_img, img_bgr, img_w, img_h, - engine_name, lang, lang_map, - ): (ri, ci) - for ri, ci, row, col in narrow_tasks - } - for future in as_completed(futures): - try: - cell = future.result() - cells.append(cell) - except Exception as e: - ri, ci = futures[future] - logger.error(f"build_cell_grid_v2: narrow cell R{ri:02d}_C{ci} failed: {e}") - - # Sort cells by (row_index, col_index) - cells.sort(key=lambda c: (c['row_index'], c['col_index'])) - - # Remove all-empty rows - rows_with_text: set = set() - for cell in cells: - if cell['text'].strip(): - rows_with_text.add(cell['row_index']) - before_filter = len(cells) - cells = [c for c in cells if c['row_index'] in rows_with_text] - empty_rows_removed = (before_filter - len(cells)) // max(len(relevant_cols), 1) - if empty_rows_removed > 0: - logger.info(f"build_cell_grid_v2: removed {empty_rows_removed} all-empty rows") - - # Bold detection disabled: cell-level stroke-width analysis cannot - # distinguish bold from non-bold when cells contain mixed formatting - # (e.g. "cookie ['kuki]" — bold word + non-bold phonetics). - # TODO: word-level bold detection would require per-word bounding boxes. - - logger.info(f"build_cell_grid_v2: {len(cells)} cells from " - f"{len(content_rows)} rows × {len(relevant_cols)} columns, " - f"engine={engine_name} (hybrid)") - - return cells, columns_meta - - -def build_cell_grid_v2_streaming( - ocr_img: np.ndarray, - column_regions: List[PageRegion], - row_geometries: List[RowGeometry], - img_w: int, - img_h: int, - lang: str = "eng+deu", - ocr_engine: str = "auto", - img_bgr: Optional[np.ndarray] = None, -) -> Generator[Tuple[Dict[str, Any], List[Dict[str, Any]], int], None, None]: - """Streaming variant of build_cell_grid_v2 — yields each cell as OCR'd. - - Yields: - (cell_dict, columns_meta, total_cells) - """ - # Resolve engine — default to Tesseract for cell-first OCR. - # Tesseract excels at isolated text crops (binarized, upscaled). - # RapidOCR is optimized for full-page scene-text and produces artifacts - # on small cell crops (extra chars, missing punctuation, garbled IPA). - use_rapid = False - if ocr_engine in ("trocr-printed", "trocr-handwritten", "lighton"): - engine_name = ocr_engine - elif ocr_engine == "auto": - engine_name = "tesseract" - elif ocr_engine == "rapid": - if not RAPIDOCR_AVAILABLE: - logger.warning("RapidOCR requested but not available, falling back to Tesseract") - else: - use_rapid = True - engine_name = "rapid" if use_rapid else "tesseract" - else: - engine_name = "tesseract" - - content_rows = [r for r in row_geometries if r.row_type == 'content'] - if not content_rows: - return - - content_rows = [r for r in content_rows if r.word_count > 0] - if not content_rows: - return - - _skip_types = {'column_ignore', 'header', 'footer', 'margin_top', - 'margin_bottom', 'margin_left', 'margin_right'} - relevant_cols = [c for c in column_regions if c.type not in _skip_types] - if not relevant_cols: - return - - content_rows = [r for r in content_rows if not _is_artifact_row(r)] - if not content_rows: - return - - # Use header/footer boundaries for heal_row_gaps (same as build_cell_grid_v2) - content_rows.sort(key=lambda r: r.y) - header_rows = [r for r in row_geometries if r.row_type == 'header'] - footer_rows = [r for r in row_geometries if r.row_type == 'footer'] - if header_rows: - top_bound = max(r.y + r.height for r in header_rows) - else: - top_bound = content_rows[0].y - if footer_rows: - bottom_bound = min(r.y for r in footer_rows) - else: - bottom_bound = content_rows[-1].y + content_rows[-1].height - - _heal_row_gaps(content_rows, top_bound=top_bound, bottom_bound=bottom_bound) - - relevant_cols.sort(key=lambda c: c.x) - - columns_meta = [ - {'index': ci, 'type': c.type, 'x': c.x, 'width': c.width} - for ci, c in enumerate(relevant_cols) - ] - - lang_map = { - 'column_en': 'eng', - 'column_de': 'deu', - 'column_example': 'eng+deu', - } - - total_cells = len(content_rows) * len(relevant_cols) - - for row_idx, row in enumerate(content_rows): - for col_idx, col in enumerate(relevant_cols): - cell = _ocr_cell_crop( - row_idx, col_idx, row, col, - ocr_img, img_bgr, img_w, img_h, - engine_name, lang, lang_map, - ) - yield cell, columns_meta, total_cells - - -# --------------------------------------------------------------------------- -# Narrow-column OCR helpers (Proposal B) — DEPRECATED (kept for legacy build_cell_grid) -# --------------------------------------------------------------------------- - -def _compute_cell_padding(col_width: int, img_w: int) -> int: - """Adaptive padding for OCR crops based on column width. - - Narrow columns (page_ref, marker) need more surrounding context so - Tesseract can segment characters correctly. Wide columns keep the - minimal 4 px padding to avoid pulling in neighbours. - """ - col_pct = col_width / img_w * 100 if img_w > 0 else 100 - if col_pct < 5: - return max(20, col_width // 2) - if col_pct < 10: - return max(12, col_width // 4) - if col_pct < 15: - return 8 - return 4 - - -def _ensure_minimum_crop_size(crop: np.ndarray, min_dim: int = 150, - max_scale: int = 3) -> np.ndarray: - """Upscale tiny crops so Tesseract gets enough pixel data. - - If either dimension is below *min_dim*, the crop is bicubic-upscaled - so the smallest dimension reaches *min_dim* (capped at *max_scale* ×). - """ - h, w = crop.shape[:2] - if h >= min_dim and w >= min_dim: - return crop - scale = min(max_scale, max(min_dim / max(h, 1), min_dim / max(w, 1))) - if scale <= 1.0: - return crop - new_w = int(w * scale) - new_h = int(h * scale) - return cv2.resize(crop, (new_w, new_h), interpolation=cv2.INTER_CUBIC) - - -def _select_psm_for_column(col_type: str, col_width: int, - row_height: int) -> int: - """Choose the best Tesseract PSM for a given column geometry. - - - page_ref columns are almost always single short tokens → PSM 8 - - Very narrow or short cells → PSM 7 (single text line) - - Everything else → PSM 6 (uniform block) - """ - if col_type in ('page_ref', 'marker'): - return 8 # single word - if col_width < 100 or row_height < 30: - return 7 # single line - return 6 # uniform block - - -def _ocr_single_cell( - row_idx: int, - col_idx: int, - row: RowGeometry, - col: PageRegion, - ocr_img: np.ndarray, - img_bgr: Optional[np.ndarray], - img_w: int, - img_h: int, - use_rapid: bool, - engine_name: str, - lang: str, - lang_map: Dict[str, str], - preassigned_words: Optional[List[Dict]] = None, -) -> Dict[str, Any]: - """Populate a single cell (column x row intersection) via word lookup.""" - # Display bbox: exact column × row intersection (no padding) - disp_x = col.x - disp_y = row.y - disp_w = col.width - disp_h = row.height - - # OCR crop: adaptive padding — narrow columns get more context - pad = _compute_cell_padding(col.width, img_w) - cell_x = max(0, col.x - pad) - cell_y = max(0, row.y - pad) - cell_w = min(col.width + 2 * pad, img_w - cell_x) - cell_h = min(row.height + 2 * pad, img_h - cell_y) - is_narrow = (col.width / img_w * 100) < 15 if img_w > 0 else False - - if disp_w <= 0 or disp_h <= 0: - return { - 'cell_id': f"R{row_idx:02d}_C{col_idx}", - 'row_index': row_idx, - 'col_index': col_idx, - 'col_type': col.type, - 'text': '', - 'confidence': 0.0, - 'bbox_px': {'x': col.x, 'y': row.y, 'w': col.width, 'h': row.height}, - 'bbox_pct': { - 'x': round(col.x / img_w * 100, 2), - 'y': round(row.y / img_h * 100, 2), - 'w': round(col.width / img_w * 100, 2), - 'h': round(row.height / img_h * 100, 2), - }, - 'ocr_engine': 'word_lookup', - } - - # --- PRIMARY: Word-lookup from full-page Tesseract --- - words = preassigned_words if preassigned_words is not None else [] - used_engine = 'word_lookup' - - # Filter low-confidence words (OCR noise from images/artifacts). - # Tesseract gives low confidence to misread image edges, borders, - # and other non-text elements. - if words: - words = [w for w in words if w.get('conf', 0) >= _MIN_WORD_CONF] - - if words: - # Use row height as Y-tolerance so all words within a single row - # are grouped onto one line (avoids splitting e.g. "Maus, Mäuse" - # across two lines due to slight vertical offset). - y_tol = max(15, row.height) - text = _words_to_reading_order_text(words, y_tolerance_px=y_tol) - avg_conf = round(sum(w['conf'] for w in words) / len(words), 1) - else: - text = '' - avg_conf = 0.0 - - # --- FALLBACK: Cell-OCR for empty cells --- - # Full-page Tesseract can miss small or isolated words (e.g. "Ei"). - # Re-run OCR on the cell crop to catch what word-lookup missed. - # To avoid wasting time on truly empty cells, check pixel density first: - # only run Tesseract if the cell crop contains enough dark pixels to - # plausibly contain text. - _run_fallback = False - if not text.strip() and cell_w > 0 and cell_h > 0: - if ocr_img is not None: - crop = ocr_img[cell_y:cell_y + cell_h, cell_x:cell_x + cell_w] - if crop.size > 0: - # Threshold: pixels darker than 180 (on 0-255 grayscale). - # Use 0.5% to catch even small text like "Ei" (2 chars) - # in an otherwise empty cell. - dark_ratio = float(np.count_nonzero(crop < 180)) / crop.size - _run_fallback = dark_ratio > 0.005 - if _run_fallback: - # For narrow columns, upscale the crop before OCR - if is_narrow and ocr_img is not None: - _crop_slice = ocr_img[cell_y:cell_y + cell_h, cell_x:cell_x + cell_w] - _upscaled = _ensure_minimum_crop_size(_crop_slice) - if _upscaled is not _crop_slice: - # Build a temporary full-size image with the upscaled crop - # placed at origin so ocr_region can crop it cleanly. - _up_h, _up_w = _upscaled.shape[:2] - _tmp_region = PageRegion( - type=col.type, x=0, y=0, width=_up_w, height=_up_h, - ) - _cell_psm = _select_psm_for_column(col.type, col.width, row.height) - cell_lang = lang_map.get(col.type, lang) - fallback_words = ocr_region(_upscaled, _tmp_region, - lang=cell_lang, psm=_cell_psm) - # Remap word positions back to original image coordinates - _sx = cell_w / max(_up_w, 1) - _sy = cell_h / max(_up_h, 1) - for _fw in (fallback_words or []): - _fw['left'] = int(_fw['left'] * _sx) + cell_x - _fw['top'] = int(_fw['top'] * _sy) + cell_y - _fw['width'] = int(_fw['width'] * _sx) - _fw['height'] = int(_fw['height'] * _sy) - else: - # No upscaling needed, use adaptive PSM - cell_region = PageRegion( - type=col.type, x=cell_x, y=cell_y, - width=cell_w, height=cell_h, - ) - _cell_psm = _select_psm_for_column(col.type, col.width, row.height) - cell_lang = lang_map.get(col.type, lang) - fallback_words = ocr_region(ocr_img, cell_region, - lang=cell_lang, psm=_cell_psm) - else: - cell_region = PageRegion( - type=col.type, - x=cell_x, y=cell_y, - width=cell_w, height=cell_h, - ) - if engine_name in ("trocr-printed", "trocr-handwritten") and img_bgr is not None: - fallback_words = ocr_region_trocr(img_bgr, cell_region, handwritten=(engine_name == "trocr-handwritten")) - elif engine_name == "lighton" and img_bgr is not None: - fallback_words = ocr_region_lighton(img_bgr, cell_region) - elif use_rapid and img_bgr is not None: - fallback_words = ocr_region_rapid(img_bgr, cell_region) - else: - _cell_psm = _select_psm_for_column(col.type, col.width, row.height) - cell_lang = lang_map.get(col.type, lang) - fallback_words = ocr_region(ocr_img, cell_region, - lang=cell_lang, psm=_cell_psm) - - if fallback_words: - # Apply same confidence filter to fallback words - fallback_words = [w for w in fallback_words if w.get('conf', 0) >= _MIN_WORD_CONF] - if fallback_words: - fb_avg_h = sum(w['height'] for w in fallback_words) / len(fallback_words) - fb_y_tol = max(10, int(fb_avg_h * 0.5)) - fb_text = _words_to_reading_order_text(fallback_words, y_tolerance_px=fb_y_tol) - if fb_text.strip(): - text = fb_text - avg_conf = round( - sum(w['conf'] for w in fallback_words) / len(fallback_words), 1 - ) - used_engine = 'cell_ocr_fallback' - - # --- SECONDARY FALLBACK: PSM=7 (single line) for still-empty cells --- - if not text.strip() and _run_fallback and not use_rapid: - _fb_region = PageRegion( - type=col.type, x=cell_x, y=cell_y, - width=cell_w, height=cell_h, - ) - cell_lang = lang_map.get(col.type, lang) - psm7_words = ocr_region(ocr_img, _fb_region, lang=cell_lang, psm=7) - if psm7_words: - psm7_words = [w for w in psm7_words if w.get('conf', 0) >= _MIN_WORD_CONF] - if psm7_words: - p7_text = _words_to_reading_order_text(psm7_words, y_tolerance_px=10) - if p7_text.strip(): - text = p7_text - avg_conf = round( - sum(w['conf'] for w in psm7_words) / len(psm7_words), 1 - ) - used_engine = 'cell_ocr_psm7' - - # --- TERTIARY FALLBACK: Row-strip re-OCR for narrow columns --- - # If a narrow cell is still empty, OCR the entire row strip with - # RapidOCR (which handles small text better) and assign words by - # X-position overlap with this column. - if not text.strip() and is_narrow and img_bgr is not None: - row_region = PageRegion( - type='_row_strip', x=0, y=row.y, - width=img_w, height=row.height, - ) - strip_words = ocr_region_rapid(img_bgr, row_region) - if strip_words: - # Filter to words overlapping this column's X-range - col_left = col.x - col_right = col.x + col.width - col_words = [] - for sw in strip_words: - sw_left = sw.get('left', 0) - sw_right = sw_left + sw.get('width', 0) - overlap = max(0, min(sw_right, col_right) - max(sw_left, col_left)) - if overlap > sw.get('width', 1) * 0.3: - col_words.append(sw) - if col_words: - col_words = [w for w in col_words if w.get('conf', 0) >= _MIN_WORD_CONF] - if col_words: - rs_text = _words_to_reading_order_text(col_words, y_tolerance_px=row.height) - if rs_text.strip(): - text = rs_text - avg_conf = round( - sum(w['conf'] for w in col_words) / len(col_words), 1 - ) - used_engine = 'row_strip_rapid' - - # --- NOISE FILTER: clear cells that contain only OCR artifacts --- - if text.strip(): - text = _clean_cell_text(text) - if not text: - avg_conf = 0.0 - - return { - 'cell_id': f"R{row_idx:02d}_C{col_idx}", - 'row_index': row_idx, - 'col_index': col_idx, - 'col_type': col.type, - 'text': text, - 'confidence': avg_conf, - 'bbox_px': {'x': disp_x, 'y': disp_y, 'w': disp_w, 'h': disp_h}, - 'bbox_pct': { - 'x': round(disp_x / img_w * 100, 2), - 'y': round(disp_y / img_h * 100, 2), - 'w': round(disp_w / img_w * 100, 2), - 'h': round(disp_h / img_h * 100, 2), - }, - 'ocr_engine': used_engine, - } - - -def _is_artifact_row(row: RowGeometry) -> bool: - """Return True if this row contains only scan artifacts, not real text. - - Artifact rows (scanner shadows, noise) typically produce only single-character - detections. A real content row always has at least one token with 2+ characters. - """ - if row.word_count == 0: - return True - texts = [w.get('text', '').strip() for w in row.words] - return all(len(t) <= 1 for t in texts) - - -def _heal_row_gaps( - rows: List[RowGeometry], - top_bound: int, - bottom_bound: int, -) -> None: - """Expand row y/height to fill vertical gaps caused by removed adjacent rows. - - After filtering out empty or artifact rows, remaining content rows may have - gaps between them where the removed rows used to be. This function mutates - each row to extend upward/downward to the midpoint of such gaps so that - OCR crops cover the full available content area. - - The first row always extends to top_bound; the last row to bottom_bound. - """ - if not rows: - return - rows.sort(key=lambda r: r.y) - n = len(rows) - orig = [(r.y, r.y + r.height) for r in rows] # snapshot before mutation - - for i, row in enumerate(rows): - # New top: midpoint between previous row's bottom and this row's top - if i == 0: - new_top = top_bound - else: - prev_bot = orig[i - 1][1] - my_top = orig[i][0] - gap = my_top - prev_bot - new_top = prev_bot + gap // 2 if gap > 1 else my_top - - # New bottom: midpoint between this row's bottom and next row's top - if i == n - 1: - new_bottom = bottom_bound - else: - my_bot = orig[i][1] - next_top = orig[i + 1][0] - gap = next_top - my_bot - new_bottom = my_bot + gap // 2 if gap > 1 else my_bot - - row.y = new_top - row.height = max(5, new_bottom - new_top) - - logger.debug( - f"_heal_row_gaps: {n} rows → y range [{rows[0].y}..{rows[-1].y + rows[-1].height}] " - f"(bounds: top={top_bound}, bottom={bottom_bound})" - ) - - -def build_cell_grid( - ocr_img: np.ndarray, - column_regions: List[PageRegion], - row_geometries: List[RowGeometry], - img_w: int, - img_h: int, - lang: str = "eng+deu", - ocr_engine: str = "auto", - img_bgr: Optional[np.ndarray] = None, -) -> Tuple[List[Dict[str, Any]], List[Dict[str, Any]]]: - """Generic Cell-Grid: Columns × Rows → cells with OCR text. - - This is the layout-agnostic foundation. Every column (except column_ignore) - is intersected with every content row to produce numbered cells. - - Args: - ocr_img: Binarized full-page image (for Tesseract). - column_regions: Classified columns from Step 3 (PageRegion list). - row_geometries: Rows from Step 4 (RowGeometry list). - img_w: Image width in pixels. - img_h: Image height in pixels. - lang: Default Tesseract language. - ocr_engine: 'tesseract', 'rapid', 'auto', 'trocr-printed', 'trocr-handwritten', or 'lighton'. - img_bgr: BGR color image (required for RapidOCR / TrOCR / LightOnOCR). - - Returns: - (cells, columns_meta) where cells is a list of cell dicts and - columns_meta describes the columns used. - """ - # Resolve engine choice - use_rapid = False - if ocr_engine in ("trocr-printed", "trocr-handwritten", "lighton"): - engine_name = ocr_engine - elif ocr_engine == "auto": - use_rapid = RAPIDOCR_AVAILABLE and img_bgr is not None - engine_name = "rapid" if use_rapid else "tesseract" - elif ocr_engine == "rapid": - if not RAPIDOCR_AVAILABLE: - logger.warning("RapidOCR requested but not available, falling back to Tesseract") - else: - use_rapid = True - engine_name = "rapid" if use_rapid else "tesseract" - else: - engine_name = "tesseract" - - logger.info(f"build_cell_grid: using OCR engine '{engine_name}'") - - # Filter to content rows only (skip header/footer) - content_rows = [r for r in row_geometries if r.row_type == 'content'] - if not content_rows: - logger.warning("build_cell_grid: no content rows found") - return [], [] - - # Filter phantom rows: rows with no Tesseract words assigned are - # inter-line whitespace gaps that would produce garbage OCR. - before = len(content_rows) - content_rows = [r for r in content_rows if r.word_count > 0] - skipped = before - len(content_rows) - if skipped > 0: - logger.info(f"build_cell_grid: skipped {skipped} phantom rows (word_count=0)") - if not content_rows: - logger.warning("build_cell_grid: no content rows with words found") - return [], [] - - # Use columns only — skip ignore, header, footer, page_ref - _skip_types = {'column_ignore', 'header', 'footer', 'margin_top', 'margin_bottom', 'margin_left', 'margin_right'} - relevant_cols = [c for c in column_regions if c.type not in _skip_types] - if not relevant_cols: - logger.warning("build_cell_grid: no usable columns found") - return [], [] - - # Filter artifact rows: rows whose detected words are all single characters - # are caused by scanner shadows or noise, not real text. - before_art = len(content_rows) - content_rows = [r for r in content_rows if not _is_artifact_row(r)] - artifact_skipped = before_art - len(content_rows) - if artifact_skipped > 0: - logger.info(f"build_cell_grid: skipped {artifact_skipped} artifact rows (all single-char words)") - if not content_rows: - logger.warning("build_cell_grid: no content rows after artifact filtering") - return [], [] - - # Heal row gaps: rows removed above leave vertical gaps; expand adjacent rows - # to fill the space so OCR crops are not artificially narrow. - _heal_row_gaps( - content_rows, - top_bound=min(c.y for c in relevant_cols), - bottom_bound=max(c.y + c.height for c in relevant_cols), - ) - - # Sort columns left-to-right - relevant_cols.sort(key=lambda c: c.x) - - # Build columns_meta - columns_meta = [ - { - 'index': col_idx, - 'type': col.type, - 'x': col.x, - 'width': col.width, - } - for col_idx, col in enumerate(relevant_cols) - ] - - # Choose OCR language per column type (Tesseract only) - lang_map = { - 'column_en': 'eng', - 'column_de': 'deu', - 'column_example': 'eng+deu', - } - - cells: List[Dict[str, Any]] = [] - - for row_idx, row in enumerate(content_rows): - # Pre-assign each word to exactly one column (nearest center) - col_words = _assign_row_words_to_columns(row, relevant_cols) - for col_idx, col in enumerate(relevant_cols): - cell = _ocr_single_cell( - row_idx, col_idx, row, col, - ocr_img, img_bgr, img_w, img_h, - use_rapid, engine_name, lang, lang_map, - preassigned_words=col_words[col_idx], - ) - cells.append(cell) - - # --- BATCH FALLBACK: re-OCR empty cells by column strip --- - # Collect cells that are still empty but have visible pixels. - # Instead of calling Tesseract once per cell (expensive), crop an entire - # column strip and run OCR once, then assign words to cells by Y position. - empty_by_col: Dict[int, List[int]] = {} # col_idx → [cell list indices] - for ci, cell in enumerate(cells): - if not cell['text'].strip() and cell.get('ocr_engine') != 'cell_ocr_psm7': - bpx = cell['bbox_px'] - x, y, w, h = bpx['x'], bpx['y'], bpx['w'], bpx['h'] - if w > 0 and h > 0 and ocr_img is not None: - crop = ocr_img[y:y + h, x:x + w] - if crop.size > 0: - dark_ratio = float(np.count_nonzero(crop < 180)) / crop.size - if dark_ratio > 0.005: - empty_by_col.setdefault(cell['col_index'], []).append(ci) - - for col_idx, cell_indices in empty_by_col.items(): - if len(cell_indices) < 3: - continue # Not worth batching for < 3 cells - - # Find the column strip bounding box (union of all empty cell bboxes) - min_y = min(cells[ci]['bbox_px']['y'] for ci in cell_indices) - max_y_h = max(cells[ci]['bbox_px']['y'] + cells[ci]['bbox_px']['h'] for ci in cell_indices) - col_x = cells[cell_indices[0]]['bbox_px']['x'] - col_w = cells[cell_indices[0]]['bbox_px']['w'] - - strip_region = PageRegion( - type=relevant_cols[col_idx].type, - x=col_x, y=min_y, - width=col_w, height=max_y_h - min_y, - ) - strip_lang = lang_map.get(relevant_cols[col_idx].type, lang) - - if engine_name in ("trocr-printed", "trocr-handwritten") and img_bgr is not None: - strip_words = ocr_region_trocr(img_bgr, strip_region, handwritten=(engine_name == "trocr-handwritten")) - elif engine_name == "lighton" and img_bgr is not None: - strip_words = ocr_region_lighton(img_bgr, strip_region) - elif use_rapid and img_bgr is not None: - strip_words = ocr_region_rapid(img_bgr, strip_region) - else: - strip_words = ocr_region(ocr_img, strip_region, lang=strip_lang, psm=6) - - if not strip_words: - continue - - strip_words = [w for w in strip_words if w.get('conf', 0) >= 30] - if not strip_words: - continue - - # Assign words to cells by Y overlap - for ci in cell_indices: - cell_y = cells[ci]['bbox_px']['y'] - cell_h = cells[ci]['bbox_px']['h'] - cell_mid_y = cell_y + cell_h / 2 - - matched_words = [ - w for w in strip_words - if abs((w['top'] + w['height'] / 2) - cell_mid_y) < cell_h * 0.8 - ] - if matched_words: - matched_words.sort(key=lambda w: w['left']) - batch_text = ' '.join(w['text'] for w in matched_words) - batch_text = _clean_cell_text(batch_text) - if batch_text.strip(): - cells[ci]['text'] = batch_text - cells[ci]['confidence'] = round( - sum(w['conf'] for w in matched_words) / len(matched_words), 1 - ) - cells[ci]['ocr_engine'] = 'batch_column_ocr' - - batch_filled = sum(1 for ci in cell_indices if cells[ci]['text'].strip()) - if batch_filled > 0: - logger.info( - f"build_cell_grid: batch OCR filled {batch_filled}/{len(cell_indices)} " - f"empty cells in column {col_idx}" - ) - - # Post-OCR: remove rows where ALL cells are empty (inter-row gaps - # that had stray Tesseract artifacts giving word_count > 0). - rows_with_text: set = set() - for cell in cells: - if cell['text'].strip(): - rows_with_text.add(cell['row_index']) - before_filter = len(cells) - cells = [c for c in cells if c['row_index'] in rows_with_text] - empty_rows_removed = (before_filter - len(cells)) // max(len(relevant_cols), 1) - if empty_rows_removed > 0: - logger.info(f"build_cell_grid: removed {empty_rows_removed} all-empty rows after OCR") - - logger.info(f"build_cell_grid: {len(cells)} cells from " - f"{len(content_rows)} rows × {len(relevant_cols)} columns, " - f"engine={engine_name}") - - return cells, columns_meta - - -def build_cell_grid_streaming( - ocr_img: np.ndarray, - column_regions: List[PageRegion], - row_geometries: List[RowGeometry], - img_w: int, - img_h: int, - lang: str = "eng+deu", - ocr_engine: str = "auto", - img_bgr: Optional[np.ndarray] = None, -) -> Generator[Tuple[Dict[str, Any], List[Dict[str, Any]], int], None, None]: - """Like build_cell_grid(), but yields each cell as it is OCR'd. - - Yields: - (cell_dict, columns_meta, total_cells) for each cell. - """ - # Resolve engine choice (same as build_cell_grid) - use_rapid = False - if ocr_engine in ("trocr-printed", "trocr-handwritten", "lighton"): - engine_name = ocr_engine - elif ocr_engine == "auto": - use_rapid = RAPIDOCR_AVAILABLE and img_bgr is not None - engine_name = "rapid" if use_rapid else "tesseract" - elif ocr_engine == "rapid": - if not RAPIDOCR_AVAILABLE: - logger.warning("RapidOCR requested but not available, falling back to Tesseract") - else: - use_rapid = True - engine_name = "rapid" if use_rapid else "tesseract" - else: - engine_name = "tesseract" - - content_rows = [r for r in row_geometries if r.row_type == 'content'] - if not content_rows: - return - - # Filter phantom rows: rows with no Tesseract words assigned are - # inter-line whitespace gaps that would produce garbage OCR. - before = len(content_rows) - content_rows = [r for r in content_rows if r.word_count > 0] - skipped = before - len(content_rows) - if skipped > 0: - logger.info(f"build_cell_grid_streaming: skipped {skipped} phantom rows (word_count=0)") - if not content_rows: - return - - _skip_types = {'column_ignore', 'header', 'footer', 'margin_top', 'margin_bottom', 'margin_left', 'margin_right'} - relevant_cols = [c for c in column_regions if c.type not in _skip_types] - if not relevant_cols: - return - - # Filter artifact rows + heal gaps (same logic as build_cell_grid) - before_art = len(content_rows) - content_rows = [r for r in content_rows if not _is_artifact_row(r)] - artifact_skipped = before_art - len(content_rows) - if artifact_skipped > 0: - logger.info(f"build_cell_grid_streaming: skipped {artifact_skipped} artifact rows") - if not content_rows: - return - _heal_row_gaps( - content_rows, - top_bound=min(c.y for c in relevant_cols), - bottom_bound=max(c.y + c.height for c in relevant_cols), - ) - - relevant_cols.sort(key=lambda c: c.x) - - columns_meta = [ - { - 'index': col_idx, - 'type': col.type, - 'x': col.x, - 'width': col.width, - } - for col_idx, col in enumerate(relevant_cols) - ] - - lang_map = { - 'column_en': 'eng', - 'column_de': 'deu', - 'column_example': 'eng+deu', - } - - total_cells = len(content_rows) * len(relevant_cols) - - for row_idx, row in enumerate(content_rows): - # Pre-assign each word to exactly one column (nearest center) - col_words = _assign_row_words_to_columns(row, relevant_cols) - for col_idx, col in enumerate(relevant_cols): - cell = _ocr_single_cell( - row_idx, col_idx, row, col, - ocr_img, img_bgr, img_w, img_h, - use_rapid, engine_name, lang, lang_map, - preassigned_words=col_words[col_idx], - ) - yield cell, columns_meta, total_cells - - -def _cells_to_vocab_entries( - cells: List[Dict[str, Any]], - columns_meta: List[Dict[str, Any]], -) -> List[Dict[str, Any]]: - """Map generic cells to vocab entries with english/german/example fields. - - Groups cells by row_index, maps col_type → field name, and produces - one entry per row (only rows with at least one non-empty field). - """ - # Determine image dimensions from first cell (for row-level bbox) - col_type_to_field = { - 'column_en': 'english', - 'column_de': 'german', - 'column_example': 'example', - 'page_ref': 'source_page', - 'column_marker': 'marker', - 'column_text': 'text', # generic single-column (box sub-sessions) - } - bbox_key_map = { - 'column_en': 'bbox_en', - 'column_de': 'bbox_de', - 'column_example': 'bbox_ex', - 'page_ref': 'bbox_ref', - 'column_marker': 'bbox_marker', - 'column_text': 'bbox_text', - } - - # Group cells by row_index - rows: Dict[int, List[Dict]] = {} - for cell in cells: - ri = cell['row_index'] - rows.setdefault(ri, []).append(cell) - - entries: List[Dict[str, Any]] = [] - for row_idx in sorted(rows.keys()): - row_cells = rows[row_idx] - entry: Dict[str, Any] = { - 'row_index': row_idx, - 'english': '', - 'german': '', - 'example': '', - 'text': '', # generic single-column (box sub-sessions) - 'source_page': '', - 'marker': '', - 'confidence': 0.0, - 'bbox': None, - 'bbox_en': None, - 'bbox_de': None, - 'bbox_ex': None, - 'bbox_ref': None, - 'bbox_marker': None, - 'bbox_text': None, - 'ocr_engine': row_cells[0].get('ocr_engine', '') if row_cells else '', - } - - confidences = [] - for cell in row_cells: - col_type = cell['col_type'] - field = col_type_to_field.get(col_type) - if field: - entry[field] = cell['text'] - bbox_field = bbox_key_map.get(col_type) - if bbox_field: - entry[bbox_field] = cell['bbox_pct'] - if cell['confidence'] > 0: - confidences.append(cell['confidence']) - - # Compute row-level bbox as union of all cell bboxes - all_bboxes = [c['bbox_pct'] for c in row_cells if c.get('bbox_pct')] - if all_bboxes: - min_x = min(b['x'] for b in all_bboxes) - min_y = min(b['y'] for b in all_bboxes) - max_x2 = max(b['x'] + b['w'] for b in all_bboxes) - max_y2 = max(b['y'] + b['h'] for b in all_bboxes) - entry['bbox'] = { - 'x': round(min_x, 2), - 'y': round(min_y, 2), - 'w': round(max_x2 - min_x, 2), - 'h': round(max_y2 - min_y, 2), - } - - entry['confidence'] = round( - sum(confidences) / len(confidences), 1 - ) if confidences else 0.0 - - # Only include if at least one mapped field has text - has_content = any( - entry.get(f) - for f in col_type_to_field.values() - ) - if has_content: - entries.append(entry) - - return entries - - -# Regex: line starts with phonetic bracket content only (no real word before it) -_PHONETIC_ONLY_RE = re.compile( - r'''^\s*[\[\('"]*[^\]]*[\])\s]*$''' +# --- v2 build (current default) --- +from cv_cell_grid_build import ( # noqa: F401 + _NARROW_COL_THRESHOLD_PCT, + _ocr_cell_crop, + build_cell_grid_v2, ) +# --- Legacy build (DEPRECATED) --- +from cv_cell_grid_legacy import ( # noqa: F401 + _ocr_single_cell, + build_cell_grid, +) -def _is_phonetic_only_text(text: str) -> bool: - """Check if text consists only of phonetic transcription. +# --- Streaming variants --- +from cv_cell_grid_streaming import ( # noqa: F401 + build_cell_grid_streaming, + build_cell_grid_v2_streaming, +) - Phonetic-only patterns: - ['mani serva] → True - [dɑːns] → True - ["a:mand] → True - almond ['a:mand] → False (has real word before bracket) - Mandel → False - """ - t = text.strip() - if not t: - return False - # Must contain at least one bracket - if '[' not in t and ']' not in t: - return False - # Remove all bracket content and surrounding punctuation/whitespace - without_brackets = re.sub(r"\[.*?\]", '', t) - without_brackets = re.sub(r"[\[\]'\"()\s]", '', without_brackets) - # If nothing meaningful remains, it's phonetic-only - alpha_remaining = ''.join(_RE_ALPHA.findall(without_brackets)) - return len(alpha_remaining) < 2 - - -def _merge_phonetic_continuation_rows( - entries: List[Dict[str, Any]], -) -> List[Dict[str, Any]]: - """Merge rows that contain only phonetic transcription into previous entry. - - In dictionary pages, phonetic transcription sometimes wraps to the next - row. E.g.: - Row 28: EN="it's a money-saver" DE="es spart Kosten" - Row 29: EN="['mani serva]" DE="" - - Row 29 is phonetic-only → merge into row 28's EN field. - """ - if len(entries) < 2: - return entries - - merged: List[Dict[str, Any]] = [] - for entry in entries: - en = (entry.get('english') or '').strip() - de = (entry.get('german') or '').strip() - ex = (entry.get('example') or '').strip() - - # Check if this entry is phonetic-only (EN has only phonetics, DE empty) - if merged and _is_phonetic_only_text(en) and not de: - prev = merged[-1] - prev_en = (prev.get('english') or '').strip() - # Append phonetic to previous entry's EN - if prev_en: - prev['english'] = prev_en + ' ' + en - else: - prev['english'] = en - # If there was an example, append to previous too - if ex: - prev_ex = (prev.get('example') or '').strip() - prev['example'] = (prev_ex + ' ' + ex).strip() if prev_ex else ex - logger.debug( - f"Merged phonetic row {entry.get('row_index')} " - f"into previous entry: {prev['english']!r}" - ) - continue - - merged.append(entry) - - return merged - - -def _merge_wrapped_rows( - entries: List[Dict[str, Any]], -) -> List[Dict[str, Any]]: - """Merge rows where the primary column (EN) is empty — cell wrap continuation. - - In textbook vocabulary tables, columns are often narrow, so the author - wraps text within a cell. OCR treats each physical line as a separate row. - The key indicator: if the EN column is empty but DE/example have text, - this row is a continuation of the previous row's cells. - - Example (original textbook has ONE row): - Row 2: EN="take part (in)" DE="teilnehmen (an), mitmachen" EX="More than 200 singers took" - Row 3: EN="" DE="(bei)" EX="part in the concert." - → Merged: EN="take part (in)" DE="teilnehmen (an), mitmachen (bei)" EX="More than 200 singers took part in the concert." - - Also handles the reverse case: DE empty but EN has text (wrap in EN column). - """ - if len(entries) < 2: - return entries - - merged: List[Dict[str, Any]] = [] - for entry in entries: - en = (entry.get('english') or '').strip() - de = (entry.get('german') or '').strip() - ex = (entry.get('example') or '').strip() - - if not merged: - merged.append(entry) - continue - - prev = merged[-1] - prev_en = (prev.get('english') or '').strip() - prev_de = (prev.get('german') or '').strip() - prev_ex = (prev.get('example') or '').strip() - - # Case 1: EN is empty → continuation of previous row - # (DE or EX have text that should be appended to previous row) - if not en and (de or ex) and prev_en: - if de: - if prev_de.endswith(','): - sep = ' ' # "Wort," + " " + "Ausdruck" - elif prev_de.endswith(('-', '(')): - sep = '' # "teil-" + "nehmen" or "(" + "bei)" - else: - sep = ' ' - prev['german'] = (prev_de + sep + de).strip() - if ex: - sep = ' ' if prev_ex else '' - prev['example'] = (prev_ex + sep + ex).strip() - logger.debug( - f"Merged wrapped row {entry.get('row_index')} into previous " - f"(empty EN): DE={prev['german']!r}, EX={prev.get('example', '')!r}" - ) - continue - - # Case 2: DE is empty, EN has text that looks like continuation - # (starts with lowercase or is a parenthetical like "(bei)") - if en and not de and prev_de: - is_paren = en.startswith('(') - first_alpha = next((c for c in en if c.isalpha()), '') - starts_lower = first_alpha and first_alpha.islower() - - if (is_paren or starts_lower) and len(en.split()) < 5: - sep = ' ' if prev_en and not prev_en.endswith((',', '-', '(')) else '' - prev['english'] = (prev_en + sep + en).strip() - if ex: - sep2 = ' ' if prev_ex else '' - prev['example'] = (prev_ex + sep2 + ex).strip() - logger.debug( - f"Merged wrapped row {entry.get('row_index')} into previous " - f"(empty DE): EN={prev['english']!r}" - ) - continue - - merged.append(entry) - - if len(merged) < len(entries): - logger.info( - f"_merge_wrapped_rows: merged {len(entries) - len(merged)} " - f"continuation rows ({len(entries)} → {len(merged)})" - ) - return merged - - -def _merge_continuation_rows( - entries: List[Dict[str, Any]], -) -> List[Dict[str, Any]]: - """Merge multi-line vocabulary entries where text wraps to the next row. - - A row is a continuation of the previous entry when: - - EN has text, but DE is empty - - EN starts with a lowercase letter (not a new vocab entry) - - Previous entry's EN does NOT end with a sentence terminator (.!?) - - The continuation text has fewer than 4 words (not an example sentence) - - The row was not already merged as phonetic - - Example: - Row 5: EN="to put up" DE="aufstellen" - Row 6: EN="with sth." DE="" - → Merged: EN="to put up with sth." DE="aufstellen" - """ - if len(entries) < 2: - return entries - - merged: List[Dict[str, Any]] = [] - for entry in entries: - en = (entry.get('english') or '').strip() - de = (entry.get('german') or '').strip() - - if merged and en and not de: - # Check: not phonetic (already handled) - if _is_phonetic_only_text(en): - merged.append(entry) - continue - - # Check: starts with lowercase - first_alpha = next((c for c in en if c.isalpha()), '') - starts_lower = first_alpha and first_alpha.islower() - - # Check: fewer than 4 words (not an example sentence) - word_count = len(en.split()) - is_short = word_count < 4 - - # Check: previous entry doesn't end with sentence terminator - prev = merged[-1] - prev_en = (prev.get('english') or '').strip() - prev_ends_sentence = prev_en and prev_en[-1] in '.!?' - - if starts_lower and is_short and not prev_ends_sentence: - # Merge into previous entry - prev['english'] = (prev_en + ' ' + en).strip() - # Merge example if present - ex = (entry.get('example') or '').strip() - if ex: - prev_ex = (prev.get('example') or '').strip() - prev['example'] = (prev_ex + ' ' + ex).strip() if prev_ex else ex - logger.debug( - f"Merged continuation row {entry.get('row_index')} " - f"into previous entry: {prev['english']!r}" - ) - continue - - merged.append(entry) - - return merged - - -def build_word_grid( - ocr_img: np.ndarray, - column_regions: List[PageRegion], - row_geometries: List[RowGeometry], - img_w: int, - img_h: int, - lang: str = "eng+deu", - ocr_engine: str = "auto", - img_bgr: Optional[np.ndarray] = None, - pronunciation: str = "british", -) -> List[Dict[str, Any]]: - """Vocab-specific: Cell-Grid + Vocab-Mapping + Post-Processing. - - Wrapper around build_cell_grid() that adds vocabulary-specific logic: - - Maps cells to english/german/example entries - - Applies character confusion fixes, IPA lookup, comma splitting, etc. - - Falls back to returning raw cells if no vocab columns detected. - - Args: - ocr_img: Binarized full-page image (for Tesseract). - column_regions: Classified columns from Step 3. - row_geometries: Rows from Step 4. - img_w, img_h: Image dimensions. - lang: Default Tesseract language. - ocr_engine: 'tesseract', 'rapid', or 'auto'. - img_bgr: BGR color image (required for RapidOCR). - pronunciation: 'british' or 'american' for IPA lookup. - - Returns: - List of entry dicts with english/german/example text and bbox info (percent). - """ - cells, columns_meta = build_cell_grid( - ocr_img, column_regions, row_geometries, img_w, img_h, - lang=lang, ocr_engine=ocr_engine, img_bgr=img_bgr, - ) - - if not cells: - return [] - - # Check if vocab layout is present - col_types = {c['type'] for c in columns_meta} - if not (col_types & {'column_en', 'column_de'}): - logger.info("build_word_grid: no vocab columns — returning raw cells") - return cells - - # Vocab mapping: cells → entries - entries = _cells_to_vocab_entries(cells, columns_meta) - - # --- Post-processing pipeline (deterministic, no LLM) --- - n_raw = len(entries) - - # 0. Merge cell-wrap continuation rows (empty primary column = text wrap) - entries = _merge_wrapped_rows(entries) - - # 0a. Merge phonetic-only continuation rows into previous entry - entries = _merge_phonetic_continuation_rows(entries) - - # 0b. Merge multi-line continuation rows (lowercase EN, empty DE) - entries = _merge_continuation_rows(entries) - - # 1. Character confusion (| → I, 1 → I, 8 → B) is now run in - # llm_review_entries_streaming so changes are visible to the user in Step 6. - - # 2. Replace OCR'd phonetics with dictionary IPA - entries = _fix_phonetic_brackets(entries, pronunciation=pronunciation) - - # 3. Split comma-separated word forms (break, broke, broken → 3 entries) - entries = _split_comma_entries(entries) - - # 4. Attach example sentences (rows without DE → examples for preceding entry) - entries = _attach_example_sentences(entries) - - engine_name = cells[0].get('ocr_engine', 'unknown') if cells else 'unknown' - logger.info(f"build_word_grid: {len(entries)} entries from " - f"{n_raw} raw → {len(entries)} after post-processing " - f"(engine={engine_name})") - - return entries +# --- Row merging --- +from cv_cell_grid_merge import ( # noqa: F401 + _PHONETIC_ONLY_RE, + _is_phonetic_only_text, + _merge_continuation_rows, + _merge_phonetic_continuation_rows, + _merge_wrapped_rows, +) +# --- Vocab extraction --- +from cv_cell_grid_vocab import ( # noqa: F401 + _cells_to_vocab_entries, + build_word_grid, +) diff --git a/klausur-service/backend/cv_cell_grid_build.py b/klausur-service/backend/cv_cell_grid_build.py new file mode 100644 index 0000000..9ac0ac5 --- /dev/null +++ b/klausur-service/backend/cv_cell_grid_build.py @@ -0,0 +1,498 @@ +""" +Cell-grid construction v2 (hybrid: broad columns via word lookup, narrow via cell-crop). +Extracted from cv_cell_grid.py. +Lizenz: Apache 2.0 — DATENSCHUTZ: Alle Verarbeitung erfolgt lokal. +""" + +import logging +from concurrent.futures import ThreadPoolExecutor, as_completed +from typing import Any, Dict, List, Optional, Tuple + +import numpy as np + +from cv_vocab_types import PageRegion, RowGeometry +from cv_ocr_engines import ( + RAPIDOCR_AVAILABLE, + _assign_row_words_to_columns, + _clean_cell_text, + _clean_cell_text_lite, + _words_to_reading_order_text, + _words_to_spaced_text, + ocr_region_lighton, + ocr_region_rapid, + ocr_region_trocr, +) +from cv_cell_grid_helpers import ( + _MIN_WORD_CONF, + _ensure_minimum_crop_size, + _heal_row_gaps, + _is_artifact_row, + _select_psm_for_column, +) + +logger = logging.getLogger(__name__) + +try: + import cv2 +except ImportError: + cv2 = None # type: ignore[assignment] + + +# --------------------------------------------------------------------------- +# _ocr_cell_crop — isolated cell-crop OCR for v2 hybrid mode +# --------------------------------------------------------------------------- + +def _ocr_cell_crop( + row_idx: int, + col_idx: int, + row: RowGeometry, + col: PageRegion, + ocr_img: np.ndarray, + img_bgr: Optional[np.ndarray], + img_w: int, + img_h: int, + engine_name: str, + lang: str, + lang_map: Dict[str, str], +) -> Dict[str, Any]: + """OCR a single cell by cropping the exact column x row intersection. + + No padding beyond cell boundaries -> no neighbour bleeding. + """ + # Display bbox: exact column x row intersection + disp_x = col.x + disp_y = row.y + disp_w = col.width + disp_h = row.height + + # Crop boundaries: add small internal padding (3px each side) to avoid + # clipping characters near column/row edges (e.g. parentheses, descenders). + # Stays within image bounds but may extend slightly beyond strict cell. + # 3px is small enough to avoid neighbour content at typical scan DPI (200-300). + _PAD = 3 + cx = max(0, disp_x - _PAD) + cy = max(0, disp_y - _PAD) + cx2 = min(img_w, disp_x + disp_w + _PAD) + cy2 = min(img_h, disp_y + disp_h + _PAD) + cw = cx2 - cx + ch = cy2 - cy + + empty_cell = { + 'cell_id': f"R{row_idx:02d}_C{col_idx}", + 'row_index': row_idx, + 'col_index': col_idx, + 'col_type': col.type, + 'text': '', + 'confidence': 0.0, + 'bbox_px': {'x': disp_x, 'y': disp_y, 'w': disp_w, 'h': disp_h}, + 'bbox_pct': { + 'x': round(disp_x / img_w * 100, 2) if img_w else 0, + 'y': round(disp_y / img_h * 100, 2) if img_h else 0, + 'w': round(disp_w / img_w * 100, 2) if img_w else 0, + 'h': round(disp_h / img_h * 100, 2) if img_h else 0, + }, + 'ocr_engine': 'cell_crop_v2', + 'is_bold': False, + } + + if cw <= 0 or ch <= 0: + logger.debug("_ocr_cell_crop R%02d_C%d: zero-size crop (%dx%d)", row_idx, col_idx, cw, ch) + return empty_cell + + # --- Pixel-density check: skip truly empty cells --- + if ocr_img is not None: + crop = ocr_img[cy:cy + ch, cx:cx + cw] + if crop.size > 0: + dark_ratio = float(np.count_nonzero(crop < 180)) / crop.size + if dark_ratio < 0.005: + logger.debug("_ocr_cell_crop R%02d_C%d: skip empty (dark_ratio=%.4f, crop=%dx%d)", + row_idx, col_idx, dark_ratio, cw, ch) + return empty_cell + + # --- Prepare crop for OCR --- + cell_lang = lang_map.get(col.type, lang) + psm = _select_psm_for_column(col.type, col.width, row.height) + text = '' + avg_conf = 0.0 + used_engine = 'cell_crop_v2' + + if engine_name in ("trocr-printed", "trocr-handwritten") and img_bgr is not None: + cell_region = PageRegion(type=col.type, x=cx, y=cy, width=cw, height=ch) + words = ocr_region_trocr(img_bgr, cell_region, + handwritten=(engine_name == "trocr-handwritten")) + elif engine_name == "lighton" and img_bgr is not None: + cell_region = PageRegion(type=col.type, x=cx, y=cy, width=cw, height=ch) + words = ocr_region_lighton(img_bgr, cell_region) + elif engine_name == "rapid" and img_bgr is not None: + # Upscale small BGR crops for RapidOCR. + bgr_crop = img_bgr[cy:cy + ch, cx:cx + cw] + if bgr_crop.size == 0: + words = [] + else: + crop_h, crop_w = bgr_crop.shape[:2] + if crop_h < 80: + # Force 3x upscale for short rows — small chars need more pixels + scale = 3.0 + bgr_up = cv2.resize(bgr_crop, None, fx=scale, fy=scale, + interpolation=cv2.INTER_CUBIC) + else: + bgr_up = _ensure_minimum_crop_size(bgr_crop, min_dim=150, max_scale=3) + up_h, up_w = bgr_up.shape[:2] + scale_x = up_w / max(crop_w, 1) + scale_y = up_h / max(crop_h, 1) + was_scaled = (up_w != crop_w or up_h != crop_h) + logger.debug("_ocr_cell_crop R%02d_C%d: rapid %dx%d -> %dx%d (scale=%.1fx)", + row_idx, col_idx, crop_w, crop_h, up_w, up_h, scale_y) + tmp_region = PageRegion(type=col.type, x=0, y=0, width=up_w, height=up_h) + words = ocr_region_rapid(bgr_up, tmp_region) + # Remap positions back to original image coords + if words and was_scaled: + for w in words: + w['left'] = int(w['left'] / scale_x) + cx + w['top'] = int(w['top'] / scale_y) + cy + w['width'] = int(w['width'] / scale_x) + w['height'] = int(w['height'] / scale_y) + elif words: + for w in words: + w['left'] += cx + w['top'] += cy + else: + # Tesseract: upscale tiny crops for better recognition + if ocr_img is not None: + crop_slice = ocr_img[cy:cy + ch, cx:cx + cw] + upscaled = _ensure_minimum_crop_size(crop_slice) + up_h, up_w = upscaled.shape[:2] + tmp_region = PageRegion(type=col.type, x=0, y=0, width=up_w, height=up_h) + words = ocr_region(upscaled, tmp_region, lang=cell_lang, psm=psm) + # Remap word positions back to original image coordinates + if words and (up_w != cw or up_h != ch): + sx = cw / max(up_w, 1) + sy = ch / max(up_h, 1) + for w in words: + w['left'] = int(w['left'] * sx) + cx + w['top'] = int(w['top'] * sy) + cy + w['width'] = int(w['width'] * sx) + w['height'] = int(w['height'] * sy) + elif words: + for w in words: + w['left'] += cx + w['top'] += cy + else: + words = [] + + # Filter low-confidence words + if words: + words = [w for w in words if w.get('conf', 0) >= _MIN_WORD_CONF] + + if words: + y_tol = max(15, ch) + text = _words_to_reading_order_text(words, y_tolerance_px=y_tol) + avg_conf = round(sum(w['conf'] for w in words) / len(words), 1) + logger.debug("_ocr_cell_crop R%02d_C%d: OCR raw text=%r conf=%.1f nwords=%d crop=%dx%d psm=%s engine=%s", + row_idx, col_idx, text, avg_conf, len(words), cw, ch, psm, engine_name) + else: + logger.debug("_ocr_cell_crop R%02d_C%d: OCR returned NO words (crop=%dx%d psm=%s engine=%s)", + row_idx, col_idx, cw, ch, psm, engine_name) + + # --- PSM 7 fallback for still-empty Tesseract cells --- + if not text.strip() and engine_name == "tesseract" and ocr_img is not None: + crop_slice = ocr_img[cy:cy + ch, cx:cx + cw] + upscaled = _ensure_minimum_crop_size(crop_slice) + up_h, up_w = upscaled.shape[:2] + tmp_region = PageRegion(type=col.type, x=0, y=0, width=up_w, height=up_h) + psm7_words = ocr_region(upscaled, tmp_region, lang=cell_lang, psm=7) + if psm7_words: + psm7_words = [w for w in psm7_words if w.get('conf', 0) >= _MIN_WORD_CONF] + if psm7_words: + p7_text = _words_to_reading_order_text(psm7_words, y_tolerance_px=10) + if p7_text.strip(): + text = p7_text + avg_conf = round( + sum(w['conf'] for w in psm7_words) / len(psm7_words), 1 + ) + used_engine = 'cell_crop_v2_psm7' + # Remap PSM7 word positions back to original image coords + if up_w != cw or up_h != ch: + sx = cw / max(up_w, 1) + sy = ch / max(up_h, 1) + for w in psm7_words: + w['left'] = int(w['left'] * sx) + cx + w['top'] = int(w['top'] * sy) + cy + w['width'] = int(w['width'] * sx) + w['height'] = int(w['height'] * sy) + else: + for w in psm7_words: + w['left'] += cx + w['top'] += cy + words = psm7_words + + # --- Noise filter --- + if text.strip(): + pre_filter = text + text = _clean_cell_text_lite(text) + if not text: + logger.debug("_ocr_cell_crop R%02d_C%d: _clean_cell_text_lite REMOVED %r", + row_idx, col_idx, pre_filter) + avg_conf = 0.0 + + result = dict(empty_cell) + result['text'] = text + result['confidence'] = avg_conf + result['ocr_engine'] = used_engine + + # Store individual word bounding boxes (absolute image coordinates) + # for pixel-accurate overlay positioning in the frontend. + if words and text.strip(): + result['word_boxes'] = [ + { + 'text': w.get('text', ''), + 'left': w['left'], + 'top': w['top'], + 'width': w['width'], + 'height': w['height'], + 'conf': w.get('conf', 0), + } + for w in words + if w.get('text', '').strip() + ] + + return result + + +# Threshold: columns narrower than this (% of image width) use single-cell +# crop OCR instead of full-page word assignment. +_NARROW_COL_THRESHOLD_PCT = 15.0 + + +# --------------------------------------------------------------------------- +# build_cell_grid_v2 — hybrid grid builder (current default) +# --------------------------------------------------------------------------- + +def build_cell_grid_v2( + ocr_img: np.ndarray, + column_regions: List[PageRegion], + row_geometries: List[RowGeometry], + img_w: int, + img_h: int, + lang: str = "eng+deu", + ocr_engine: str = "auto", + img_bgr: Optional[np.ndarray] = None, + skip_heal_gaps: bool = False, +) -> Tuple[List[Dict[str, Any]], List[Dict[str, Any]]]: + """Hybrid Grid: full-page OCR for broad columns, cell-crop for narrow ones. + + Drop-in replacement for build_cell_grid() -- same signature & return type. + + Strategy: + - Broad columns (>15% image width): Use pre-assigned full-page Tesseract + words (from row.words). Handles IPA brackets, punctuation, sentence + continuity correctly. + - Narrow columns (<15% image width): Use isolated cell-crop OCR to prevent + neighbour bleeding from adjacent broad columns. + """ + engine_name = "tesseract" + if ocr_engine in ("trocr-printed", "trocr-handwritten", "lighton"): + engine_name = ocr_engine + elif ocr_engine == "rapid" and RAPIDOCR_AVAILABLE: + engine_name = "rapid" + + logger.info(f"build_cell_grid_v2: using OCR engine '{engine_name}' (hybrid mode)") + + # Filter to content rows only + content_rows = [r for r in row_geometries if r.row_type == 'content'] + if not content_rows: + logger.warning("build_cell_grid_v2: no content rows found") + return [], [] + + # Filter phantom rows (word_count=0) and artifact rows + before = len(content_rows) + content_rows = [r for r in content_rows if r.word_count > 0] + skipped = before - len(content_rows) + if skipped > 0: + logger.info(f"build_cell_grid_v2: skipped {skipped} phantom rows (word_count=0)") + if not content_rows: + logger.warning("build_cell_grid_v2: no content rows with words found") + return [], [] + + before_art = len(content_rows) + content_rows = [r for r in content_rows if not _is_artifact_row(r)] + artifact_skipped = before_art - len(content_rows) + if artifact_skipped > 0: + logger.info(f"build_cell_grid_v2: skipped {artifact_skipped} artifact rows") + if not content_rows: + logger.warning("build_cell_grid_v2: no content rows after artifact filtering") + return [], [] + + # Filter columns + _skip_types = {'column_ignore', 'header', 'footer', 'margin_top', + 'margin_bottom', 'margin_left', 'margin_right'} + relevant_cols = [c for c in column_regions if c.type not in _skip_types] + if not relevant_cols: + logger.warning("build_cell_grid_v2: no usable columns found") + return [], [] + + # Heal row gaps -- use header/footer boundaries + content_rows.sort(key=lambda r: r.y) + header_rows = [r for r in row_geometries if r.row_type == 'header'] + footer_rows = [r for r in row_geometries if r.row_type == 'footer'] + if header_rows: + top_bound = max(r.y + r.height for r in header_rows) + else: + top_bound = content_rows[0].y + if footer_rows: + bottom_bound = min(r.y for r in footer_rows) + else: + bottom_bound = content_rows[-1].y + content_rows[-1].height + + # skip_heal_gaps: When True, keep cell positions at their exact row geometry + # positions without expanding to fill gaps from removed rows. + if not skip_heal_gaps: + _heal_row_gaps(content_rows, top_bound=top_bound, bottom_bound=bottom_bound) + + relevant_cols.sort(key=lambda c: c.x) + + columns_meta = [ + {'index': ci, 'type': c.type, 'x': c.x, 'width': c.width} + for ci, c in enumerate(relevant_cols) + ] + + lang_map = { + 'column_en': 'eng', + 'column_de': 'deu', + 'column_example': 'eng+deu', + } + + # --- Classify columns as broad vs narrow --- + narrow_col_indices = set() + for ci, col in enumerate(relevant_cols): + col_pct = (col.width / img_w * 100) if img_w > 0 else 0 + if col_pct < _NARROW_COL_THRESHOLD_PCT: + narrow_col_indices.add(ci) + + broad_col_count = len(relevant_cols) - len(narrow_col_indices) + logger.info(f"build_cell_grid_v2: {broad_col_count} broad columns (full-page), " + f"{len(narrow_col_indices)} narrow columns (cell-crop)") + + # --- Phase 1: Broad columns via full-page word assignment --- + cells: List[Dict[str, Any]] = [] + + for row_idx, row in enumerate(content_rows): + # Assign full-page words to columns for this row + col_words = _assign_row_words_to_columns(row, relevant_cols) + + for col_idx, col in enumerate(relevant_cols): + if col_idx not in narrow_col_indices: + # BROAD column: use pre-assigned full-page words + words = col_words.get(col_idx, []) + # Filter low-confidence words + words = [w for w in words if w.get('conf', 0) >= _MIN_WORD_CONF] + + # Single full-width column (box sub-session): preserve spacing + is_single_full_column = ( + len(relevant_cols) == 1 + and img_w > 0 + and relevant_cols[0].width / img_w > 0.9 + ) + + if words: + y_tol = max(15, row.height) + if is_single_full_column: + text = _words_to_spaced_text(words, y_tolerance_px=y_tol) + logger.info(f"R{row_idx:02d}: {len(words)} words, " + f"text={text!r:.100}") + else: + text = _words_to_reading_order_text(words, y_tolerance_px=y_tol) + avg_conf = round(sum(w['conf'] for w in words) / len(words), 1) + else: + text = '' + avg_conf = 0.0 + if is_single_full_column: + logger.info(f"R{row_idx:02d}: 0 words (row has " + f"{row.word_count} total, y={row.y}..{row.y+row.height})") + + # Apply noise filter -- but NOT for single-column sub-sessions + if not is_single_full_column: + text = _clean_cell_text(text) + + cell = { + 'cell_id': f"R{row_idx:02d}_C{col_idx}", + 'row_index': row_idx, + 'col_index': col_idx, + 'col_type': col.type, + 'text': text, + 'confidence': avg_conf, + 'bbox_px': { + 'x': col.x, 'y': row.y, + 'w': col.width, 'h': row.height, + }, + 'bbox_pct': { + 'x': round(col.x / img_w * 100, 2) if img_w else 0, + 'y': round(row.y / img_h * 100, 2) if img_h else 0, + 'w': round(col.width / img_w * 100, 2) if img_w else 0, + 'h': round(row.height / img_h * 100, 2) if img_h else 0, + }, + 'ocr_engine': 'word_lookup', + 'is_bold': False, + } + # Store word bounding boxes for pixel-accurate overlay + if words and text.strip(): + cell['word_boxes'] = [ + { + 'text': w.get('text', ''), + 'left': w['left'], + 'top': w['top'], + 'width': w['width'], + 'height': w['height'], + 'conf': w.get('conf', 0), + } + for w in words + if w.get('text', '').strip() + ] + cells.append(cell) + + # --- Phase 2: Narrow columns via cell-crop OCR (parallel) --- + narrow_tasks = [] + for row_idx, row in enumerate(content_rows): + for col_idx, col in enumerate(relevant_cols): + if col_idx in narrow_col_indices: + narrow_tasks.append((row_idx, col_idx, row, col)) + + if narrow_tasks: + max_workers = 4 if engine_name == "tesseract" else 2 + with ThreadPoolExecutor(max_workers=max_workers) as pool: + futures = { + pool.submit( + _ocr_cell_crop, + ri, ci, row, col, + ocr_img, img_bgr, img_w, img_h, + engine_name, lang, lang_map, + ): (ri, ci) + for ri, ci, row, col in narrow_tasks + } + for future in as_completed(futures): + try: + cell = future.result() + cells.append(cell) + except Exception as e: + ri, ci = futures[future] + logger.error(f"build_cell_grid_v2: narrow cell R{ri:02d}_C{ci} failed: {e}") + + # Sort cells by (row_index, col_index) + cells.sort(key=lambda c: (c['row_index'], c['col_index'])) + + # Remove all-empty rows + rows_with_text: set = set() + for cell in cells: + if cell['text'].strip(): + rows_with_text.add(cell['row_index']) + before_filter = len(cells) + cells = [c for c in cells if c['row_index'] in rows_with_text] + empty_rows_removed = (before_filter - len(cells)) // max(len(relevant_cols), 1) + if empty_rows_removed > 0: + logger.info(f"build_cell_grid_v2: removed {empty_rows_removed} all-empty rows") + + logger.info(f"build_cell_grid_v2: {len(cells)} cells from " + f"{len(content_rows)} rows x {len(relevant_cols)} columns, " + f"engine={engine_name} (hybrid)") + + return cells, columns_meta diff --git a/klausur-service/backend/cv_cell_grid_helpers.py b/klausur-service/backend/cv_cell_grid_helpers.py new file mode 100644 index 0000000..f5e41d3 --- /dev/null +++ b/klausur-service/backend/cv_cell_grid_helpers.py @@ -0,0 +1,136 @@ +""" +Shared helpers for cell-grid construction (v2 + legacy). + +Extracted from cv_cell_grid.py — used by both cv_cell_grid_build and +cv_cell_grid_legacy. + +Lizenz: Apache 2.0 (kommerziell nutzbar) +DATENSCHUTZ: Alle Verarbeitung erfolgt lokal. +""" + +import logging +from typing import List + +import numpy as np + +from cv_vocab_types import RowGeometry + +logger = logging.getLogger(__name__) + +try: + import cv2 +except ImportError: + cv2 = None # type: ignore[assignment] + +# Minimum OCR word confidence to keep (used across multiple functions) +_MIN_WORD_CONF = 30 + + +def _compute_cell_padding(col_width: int, img_w: int) -> int: + """Adaptive padding for OCR crops based on column width. + + Narrow columns (page_ref, marker) need more surrounding context so + Tesseract can segment characters correctly. Wide columns keep the + minimal 4 px padding to avoid pulling in neighbours. + """ + col_pct = col_width / img_w * 100 if img_w > 0 else 100 + if col_pct < 5: + return max(20, col_width // 2) + if col_pct < 10: + return max(12, col_width // 4) + if col_pct < 15: + return 8 + return 4 + + +def _ensure_minimum_crop_size(crop: np.ndarray, min_dim: int = 150, + max_scale: int = 3) -> np.ndarray: + """Upscale tiny crops so Tesseract gets enough pixel data. + + If either dimension is below *min_dim*, the crop is bicubic-upscaled + so the smallest dimension reaches *min_dim* (capped at *max_scale* x). + """ + h, w = crop.shape[:2] + if h >= min_dim and w >= min_dim: + return crop + scale = min(max_scale, max(min_dim / max(h, 1), min_dim / max(w, 1))) + if scale <= 1.0: + return crop + new_w = int(w * scale) + new_h = int(h * scale) + return cv2.resize(crop, (new_w, new_h), interpolation=cv2.INTER_CUBIC) + + +def _select_psm_for_column(col_type: str, col_width: int, + row_height: int) -> int: + """Choose the best Tesseract PSM for a given column geometry. + + - page_ref columns are almost always single short tokens -> PSM 8 + - Very narrow or short cells -> PSM 7 (single text line) + - Everything else -> PSM 6 (uniform block) + """ + if col_type in ('page_ref', 'marker'): + return 8 # single word + if col_width < 100 or row_height < 30: + return 7 # single line + return 6 # uniform block + + +def _is_artifact_row(row: RowGeometry) -> bool: + """Return True if this row contains only scan artifacts, not real text. + + Artifact rows (scanner shadows, noise) typically produce only single-character + detections. A real content row always has at least one token with 2+ characters. + """ + if row.word_count == 0: + return True + texts = [w.get('text', '').strip() for w in row.words] + return all(len(t) <= 1 for t in texts) + + +def _heal_row_gaps( + rows: List[RowGeometry], + top_bound: int, + bottom_bound: int, +) -> None: + """Expand row y/height to fill vertical gaps caused by removed adjacent rows. + + After filtering out empty or artifact rows, remaining content rows may have + gaps between them where the removed rows used to be. This function mutates + each row to extend upward/downward to the midpoint of such gaps so that + OCR crops cover the full available content area. + + The first row always extends to top_bound; the last row to bottom_bound. + """ + if not rows: + return + rows.sort(key=lambda r: r.y) + n = len(rows) + orig = [(r.y, r.y + r.height) for r in rows] # snapshot before mutation + + for i, row in enumerate(rows): + # New top: midpoint between previous row's bottom and this row's top + if i == 0: + new_top = top_bound + else: + prev_bot = orig[i - 1][1] + my_top = orig[i][0] + gap = my_top - prev_bot + new_top = prev_bot + gap // 2 if gap > 1 else my_top + + # New bottom: midpoint between this row's bottom and next row's top + if i == n - 1: + new_bottom = bottom_bound + else: + my_bot = orig[i][1] + next_top = orig[i + 1][0] + gap = next_top - my_bot + new_bottom = my_bot + gap // 2 if gap > 1 else my_bot + + row.y = new_top + row.height = max(5, new_bottom - new_top) + + logger.debug( + f"_heal_row_gaps: {n} rows -> y range [{rows[0].y}..{rows[-1].y + rows[-1].height}] " + f"(bounds: top={top_bound}, bottom={bottom_bound})" + ) diff --git a/klausur-service/backend/cv_cell_grid_legacy.py b/klausur-service/backend/cv_cell_grid_legacy.py new file mode 100644 index 0000000..e00df7c --- /dev/null +++ b/klausur-service/backend/cv_cell_grid_legacy.py @@ -0,0 +1,436 @@ +""" +Legacy cell-grid construction (v1) -- DEPRECATED, kept for backward compat. + +Extracted from cv_cell_grid.py. Prefer build_cell_grid_v2 for new code. + +Lizenz: Apache 2.0 (kommerziell nutzbar) +DATENSCHUTZ: Alle Verarbeitung erfolgt lokal. +""" + +import logging +from typing import Any, Dict, List, Optional, Tuple + +import numpy as np + +from cv_vocab_types import PageRegion, RowGeometry +from cv_ocr_engines import ( + RAPIDOCR_AVAILABLE, + _assign_row_words_to_columns, + _clean_cell_text, + _words_to_reading_order_text, + ocr_region_lighton, + ocr_region_rapid, + ocr_region_trocr, +) +from cv_cell_grid_helpers import ( + _MIN_WORD_CONF, + _compute_cell_padding, + _ensure_minimum_crop_size, + _heal_row_gaps, + _is_artifact_row, + _select_psm_for_column, +) + +logger = logging.getLogger(__name__) + + +# --------------------------------------------------------------------------- +# _ocr_single_cell — legacy per-cell OCR with multi-level fallback +# --------------------------------------------------------------------------- + +def _ocr_single_cell( + row_idx: int, + col_idx: int, + row: RowGeometry, + col: PageRegion, + ocr_img: np.ndarray, + img_bgr: Optional[np.ndarray], + img_w: int, + img_h: int, + use_rapid: bool, + engine_name: str, + lang: str, + lang_map: Dict[str, str], + preassigned_words: Optional[List[Dict]] = None, +) -> Dict[str, Any]: + """Populate a single cell (column x row intersection) via word lookup.""" + # Display bbox: exact column x row intersection (no padding) + disp_x = col.x + disp_y = row.y + disp_w = col.width + disp_h = row.height + + # OCR crop: adaptive padding -- narrow columns get more context + pad = _compute_cell_padding(col.width, img_w) + cell_x = max(0, col.x - pad) + cell_y = max(0, row.y - pad) + cell_w = min(col.width + 2 * pad, img_w - cell_x) + cell_h = min(row.height + 2 * pad, img_h - cell_y) + is_narrow = (col.width / img_w * 100) < 15 if img_w > 0 else False + + if disp_w <= 0 or disp_h <= 0: + return { + 'cell_id': f"R{row_idx:02d}_C{col_idx}", + 'row_index': row_idx, + 'col_index': col_idx, + 'col_type': col.type, + 'text': '', + 'confidence': 0.0, + 'bbox_px': {'x': col.x, 'y': row.y, 'w': col.width, 'h': row.height}, + 'bbox_pct': { + 'x': round(col.x / img_w * 100, 2), + 'y': round(row.y / img_h * 100, 2), + 'w': round(col.width / img_w * 100, 2), + 'h': round(row.height / img_h * 100, 2), + }, + 'ocr_engine': 'word_lookup', + } + + # --- PRIMARY: Word-lookup from full-page Tesseract --- + words = preassigned_words if preassigned_words is not None else [] + used_engine = 'word_lookup' + + # Filter low-confidence words + if words: + words = [w for w in words if w.get('conf', 0) >= _MIN_WORD_CONF] + + if words: + y_tol = max(15, row.height) + text = _words_to_reading_order_text(words, y_tolerance_px=y_tol) + avg_conf = round(sum(w['conf'] for w in words) / len(words), 1) + else: + text = '' + avg_conf = 0.0 + + # --- FALLBACK: Cell-OCR for empty cells --- + _run_fallback = False + if not text.strip() and cell_w > 0 and cell_h > 0: + if ocr_img is not None: + crop = ocr_img[cell_y:cell_y + cell_h, cell_x:cell_x + cell_w] + if crop.size > 0: + dark_ratio = float(np.count_nonzero(crop < 180)) / crop.size + _run_fallback = dark_ratio > 0.005 + if _run_fallback: + # For narrow columns, upscale the crop before OCR + if is_narrow and ocr_img is not None: + _crop_slice = ocr_img[cell_y:cell_y + cell_h, cell_x:cell_x + cell_w] + _upscaled = _ensure_minimum_crop_size(_crop_slice) + if _upscaled is not _crop_slice: + _up_h, _up_w = _upscaled.shape[:2] + _tmp_region = PageRegion( + type=col.type, x=0, y=0, width=_up_w, height=_up_h, + ) + _cell_psm = _select_psm_for_column(col.type, col.width, row.height) + cell_lang = lang_map.get(col.type, lang) + fallback_words = ocr_region(_upscaled, _tmp_region, + lang=cell_lang, psm=_cell_psm) + # Remap word positions back to original image coordinates + _sx = cell_w / max(_up_w, 1) + _sy = cell_h / max(_up_h, 1) + for _fw in (fallback_words or []): + _fw['left'] = int(_fw['left'] * _sx) + cell_x + _fw['top'] = int(_fw['top'] * _sy) + cell_y + _fw['width'] = int(_fw['width'] * _sx) + _fw['height'] = int(_fw['height'] * _sy) + else: + cell_region = PageRegion( + type=col.type, x=cell_x, y=cell_y, + width=cell_w, height=cell_h, + ) + _cell_psm = _select_psm_for_column(col.type, col.width, row.height) + cell_lang = lang_map.get(col.type, lang) + fallback_words = ocr_region(ocr_img, cell_region, + lang=cell_lang, psm=_cell_psm) + else: + cell_region = PageRegion( + type=col.type, + x=cell_x, y=cell_y, + width=cell_w, height=cell_h, + ) + if engine_name in ("trocr-printed", "trocr-handwritten") and img_bgr is not None: + fallback_words = ocr_region_trocr(img_bgr, cell_region, handwritten=(engine_name == "trocr-handwritten")) + elif engine_name == "lighton" and img_bgr is not None: + fallback_words = ocr_region_lighton(img_bgr, cell_region) + elif use_rapid and img_bgr is not None: + fallback_words = ocr_region_rapid(img_bgr, cell_region) + else: + _cell_psm = _select_psm_for_column(col.type, col.width, row.height) + cell_lang = lang_map.get(col.type, lang) + fallback_words = ocr_region(ocr_img, cell_region, + lang=cell_lang, psm=_cell_psm) + + if fallback_words: + fallback_words = [w for w in fallback_words if w.get('conf', 0) >= _MIN_WORD_CONF] + if fallback_words: + fb_avg_h = sum(w['height'] for w in fallback_words) / len(fallback_words) + fb_y_tol = max(10, int(fb_avg_h * 0.5)) + fb_text = _words_to_reading_order_text(fallback_words, y_tolerance_px=fb_y_tol) + if fb_text.strip(): + text = fb_text + avg_conf = round( + sum(w['conf'] for w in fallback_words) / len(fallback_words), 1 + ) + used_engine = 'cell_ocr_fallback' + + # --- SECONDARY FALLBACK: PSM=7 (single line) for still-empty cells --- + if not text.strip() and _run_fallback and not use_rapid: + _fb_region = PageRegion( + type=col.type, x=cell_x, y=cell_y, + width=cell_w, height=cell_h, + ) + cell_lang = lang_map.get(col.type, lang) + psm7_words = ocr_region(ocr_img, _fb_region, lang=cell_lang, psm=7) + if psm7_words: + psm7_words = [w for w in psm7_words if w.get('conf', 0) >= _MIN_WORD_CONF] + if psm7_words: + p7_text = _words_to_reading_order_text(psm7_words, y_tolerance_px=10) + if p7_text.strip(): + text = p7_text + avg_conf = round( + sum(w['conf'] for w in psm7_words) / len(psm7_words), 1 + ) + used_engine = 'cell_ocr_psm7' + + # --- TERTIARY FALLBACK: Row-strip re-OCR for narrow columns --- + if not text.strip() and is_narrow and img_bgr is not None: + row_region = PageRegion( + type='_row_strip', x=0, y=row.y, + width=img_w, height=row.height, + ) + strip_words = ocr_region_rapid(img_bgr, row_region) + if strip_words: + col_left = col.x + col_right = col.x + col.width + col_words = [] + for sw in strip_words: + sw_left = sw.get('left', 0) + sw_right = sw_left + sw.get('width', 0) + overlap = max(0, min(sw_right, col_right) - max(sw_left, col_left)) + if overlap > sw.get('width', 1) * 0.3: + col_words.append(sw) + if col_words: + col_words = [w for w in col_words if w.get('conf', 0) >= _MIN_WORD_CONF] + if col_words: + rs_text = _words_to_reading_order_text(col_words, y_tolerance_px=row.height) + if rs_text.strip(): + text = rs_text + avg_conf = round( + sum(w['conf'] for w in col_words) / len(col_words), 1 + ) + used_engine = 'row_strip_rapid' + + # --- NOISE FILTER: clear cells that contain only OCR artifacts --- + if text.strip(): + text = _clean_cell_text(text) + if not text: + avg_conf = 0.0 + + return { + 'cell_id': f"R{row_idx:02d}_C{col_idx}", + 'row_index': row_idx, + 'col_index': col_idx, + 'col_type': col.type, + 'text': text, + 'confidence': avg_conf, + 'bbox_px': {'x': disp_x, 'y': disp_y, 'w': disp_w, 'h': disp_h}, + 'bbox_pct': { + 'x': round(disp_x / img_w * 100, 2), + 'y': round(disp_y / img_h * 100, 2), + 'w': round(disp_w / img_w * 100, 2), + 'h': round(disp_h / img_h * 100, 2), + }, + 'ocr_engine': used_engine, + } + + +# --------------------------------------------------------------------------- +# build_cell_grid — legacy grid builder (DEPRECATED) +# --------------------------------------------------------------------------- + +def build_cell_grid( + ocr_img: np.ndarray, + column_regions: List[PageRegion], + row_geometries: List[RowGeometry], + img_w: int, + img_h: int, + lang: str = "eng+deu", + ocr_engine: str = "auto", + img_bgr: Optional[np.ndarray] = None, +) -> Tuple[List[Dict[str, Any]], List[Dict[str, Any]]]: + """Generic Cell-Grid: Columns x Rows -> cells with OCR text. + + DEPRECATED: Use build_cell_grid_v2 instead. + """ + # Resolve engine choice + use_rapid = False + if ocr_engine in ("trocr-printed", "trocr-handwritten", "lighton"): + engine_name = ocr_engine + elif ocr_engine == "auto": + use_rapid = RAPIDOCR_AVAILABLE and img_bgr is not None + engine_name = "rapid" if use_rapid else "tesseract" + elif ocr_engine == "rapid": + if not RAPIDOCR_AVAILABLE: + logger.warning("RapidOCR requested but not available, falling back to Tesseract") + else: + use_rapid = True + engine_name = "rapid" if use_rapid else "tesseract" + else: + engine_name = "tesseract" + + logger.info(f"build_cell_grid: using OCR engine '{engine_name}'") + + # Filter to content rows only (skip header/footer) + content_rows = [r for r in row_geometries if r.row_type == 'content'] + if not content_rows: + logger.warning("build_cell_grid: no content rows found") + return [], [] + + before = len(content_rows) + content_rows = [r for r in content_rows if r.word_count > 0] + skipped = before - len(content_rows) + if skipped > 0: + logger.info(f"build_cell_grid: skipped {skipped} phantom rows (word_count=0)") + if not content_rows: + logger.warning("build_cell_grid: no content rows with words found") + return [], [] + + _skip_types = {'column_ignore', 'header', 'footer', 'margin_top', 'margin_bottom', 'margin_left', 'margin_right'} + relevant_cols = [c for c in column_regions if c.type not in _skip_types] + if not relevant_cols: + logger.warning("build_cell_grid: no usable columns found") + return [], [] + + before_art = len(content_rows) + content_rows = [r for r in content_rows if not _is_artifact_row(r)] + artifact_skipped = before_art - len(content_rows) + if artifact_skipped > 0: + logger.info(f"build_cell_grid: skipped {artifact_skipped} artifact rows (all single-char words)") + if not content_rows: + logger.warning("build_cell_grid: no content rows after artifact filtering") + return [], [] + + _heal_row_gaps( + content_rows, + top_bound=min(c.y for c in relevant_cols), + bottom_bound=max(c.y + c.height for c in relevant_cols), + ) + + relevant_cols.sort(key=lambda c: c.x) + + columns_meta = [ + { + 'index': col_idx, + 'type': col.type, + 'x': col.x, + 'width': col.width, + } + for col_idx, col in enumerate(relevant_cols) + ] + + lang_map = { + 'column_en': 'eng', + 'column_de': 'deu', + 'column_example': 'eng+deu', + } + + cells: List[Dict[str, Any]] = [] + + for row_idx, row in enumerate(content_rows): + col_words = _assign_row_words_to_columns(row, relevant_cols) + for col_idx, col in enumerate(relevant_cols): + cell = _ocr_single_cell( + row_idx, col_idx, row, col, + ocr_img, img_bgr, img_w, img_h, + use_rapid, engine_name, lang, lang_map, + preassigned_words=col_words[col_idx], + ) + cells.append(cell) + + # --- BATCH FALLBACK: re-OCR empty cells by column strip --- + empty_by_col: Dict[int, List[int]] = {} + for ci, cell in enumerate(cells): + if not cell['text'].strip() and cell.get('ocr_engine') != 'cell_ocr_psm7': + bpx = cell['bbox_px'] + x, y, w, h = bpx['x'], bpx['y'], bpx['w'], bpx['h'] + if w > 0 and h > 0 and ocr_img is not None: + crop = ocr_img[y:y + h, x:x + w] + if crop.size > 0: + dark_ratio = float(np.count_nonzero(crop < 180)) / crop.size + if dark_ratio > 0.005: + empty_by_col.setdefault(cell['col_index'], []).append(ci) + + for col_idx, cell_indices in empty_by_col.items(): + if len(cell_indices) < 3: + continue + + min_y = min(cells[ci]['bbox_px']['y'] for ci in cell_indices) + max_y_h = max(cells[ci]['bbox_px']['y'] + cells[ci]['bbox_px']['h'] for ci in cell_indices) + col_x = cells[cell_indices[0]]['bbox_px']['x'] + col_w = cells[cell_indices[0]]['bbox_px']['w'] + + strip_region = PageRegion( + type=relevant_cols[col_idx].type, + x=col_x, y=min_y, + width=col_w, height=max_y_h - min_y, + ) + strip_lang = lang_map.get(relevant_cols[col_idx].type, lang) + + if engine_name in ("trocr-printed", "trocr-handwritten") and img_bgr is not None: + strip_words = ocr_region_trocr(img_bgr, strip_region, handwritten=(engine_name == "trocr-handwritten")) + elif engine_name == "lighton" and img_bgr is not None: + strip_words = ocr_region_lighton(img_bgr, strip_region) + elif use_rapid and img_bgr is not None: + strip_words = ocr_region_rapid(img_bgr, strip_region) + else: + strip_words = ocr_region(ocr_img, strip_region, lang=strip_lang, psm=6) + + if not strip_words: + continue + + strip_words = [w for w in strip_words if w.get('conf', 0) >= 30] + if not strip_words: + continue + + for ci in cell_indices: + cell_y = cells[ci]['bbox_px']['y'] + cell_h = cells[ci]['bbox_px']['h'] + cell_mid_y = cell_y + cell_h / 2 + + matched_words = [ + w for w in strip_words + if abs((w['top'] + w['height'] / 2) - cell_mid_y) < cell_h * 0.8 + ] + if matched_words: + matched_words.sort(key=lambda w: w['left']) + batch_text = ' '.join(w['text'] for w in matched_words) + batch_text = _clean_cell_text(batch_text) + if batch_text.strip(): + cells[ci]['text'] = batch_text + cells[ci]['confidence'] = round( + sum(w['conf'] for w in matched_words) / len(matched_words), 1 + ) + cells[ci]['ocr_engine'] = 'batch_column_ocr' + + batch_filled = sum(1 for ci in cell_indices if cells[ci]['text'].strip()) + if batch_filled > 0: + logger.info( + f"build_cell_grid: batch OCR filled {batch_filled}/{len(cell_indices)} " + f"empty cells in column {col_idx}" + ) + + # Remove all-empty rows + rows_with_text: set = set() + for cell in cells: + if cell['text'].strip(): + rows_with_text.add(cell['row_index']) + before_filter = len(cells) + cells = [c for c in cells if c['row_index'] in rows_with_text] + empty_rows_removed = (before_filter - len(cells)) // max(len(relevant_cols), 1) + if empty_rows_removed > 0: + logger.info(f"build_cell_grid: removed {empty_rows_removed} all-empty rows after OCR") + + logger.info(f"build_cell_grid: {len(cells)} cells from " + f"{len(content_rows)} rows x {len(relevant_cols)} columns, " + f"engine={engine_name}") + + return cells, columns_meta diff --git a/klausur-service/backend/cv_cell_grid_merge.py b/klausur-service/backend/cv_cell_grid_merge.py new file mode 100644 index 0000000..a86770e --- /dev/null +++ b/klausur-service/backend/cv_cell_grid_merge.py @@ -0,0 +1,235 @@ +""" +Row-merging logic for vocabulary entries (phonetic, wrapped, continuation rows). + +Extracted from cv_cell_grid.py. + +Lizenz: Apache 2.0 (kommerziell nutzbar) +DATENSCHUTZ: Alle Verarbeitung erfolgt lokal. +""" + +import logging +import re +from typing import Any, Dict, List + +from cv_ocr_engines import _RE_ALPHA + +logger = logging.getLogger(__name__) + +# Regex: line starts with phonetic bracket content only (no real word before it) +_PHONETIC_ONLY_RE = re.compile( + r'''^\s*[\[\('"]*[^\]]*[\])\s]*$''' +) + + +def _is_phonetic_only_text(text: str) -> bool: + """Check if text consists only of phonetic transcription. + + Phonetic-only patterns: + ['mani serva] -> True + [dance] -> True + ["a:mand] -> True + almond ['a:mand] -> False (has real word before bracket) + Mandel -> False + """ + t = text.strip() + if not t: + return False + # Must contain at least one bracket + if '[' not in t and ']' not in t: + return False + # Remove all bracket content and surrounding punctuation/whitespace + without_brackets = re.sub(r"\[.*?\]", '', t) + without_brackets = re.sub(r"[\[\]'\"()\s]", '', without_brackets) + # If nothing meaningful remains, it's phonetic-only + alpha_remaining = ''.join(_RE_ALPHA.findall(without_brackets)) + return len(alpha_remaining) < 2 + + +def _merge_phonetic_continuation_rows( + entries: List[Dict[str, Any]], +) -> List[Dict[str, Any]]: + """Merge rows that contain only phonetic transcription into previous entry. + + In dictionary pages, phonetic transcription sometimes wraps to the next + row. E.g.: + Row 28: EN="it's a money-saver" DE="es spart Kosten" + Row 29: EN="['mani serva]" DE="" + + Row 29 is phonetic-only -> merge into row 28's EN field. + """ + if len(entries) < 2: + return entries + + merged: List[Dict[str, Any]] = [] + for entry in entries: + en = (entry.get('english') or '').strip() + de = (entry.get('german') or '').strip() + ex = (entry.get('example') or '').strip() + + # Check if this entry is phonetic-only (EN has only phonetics, DE empty) + if merged and _is_phonetic_only_text(en) and not de: + prev = merged[-1] + prev_en = (prev.get('english') or '').strip() + # Append phonetic to previous entry's EN + if prev_en: + prev['english'] = prev_en + ' ' + en + else: + prev['english'] = en + # If there was an example, append to previous too + if ex: + prev_ex = (prev.get('example') or '').strip() + prev['example'] = (prev_ex + ' ' + ex).strip() if prev_ex else ex + logger.debug( + f"Merged phonetic row {entry.get('row_index')} " + f"into previous entry: {prev['english']!r}" + ) + continue + + merged.append(entry) + + return merged + + +def _merge_wrapped_rows( + entries: List[Dict[str, Any]], +) -> List[Dict[str, Any]]: + """Merge rows where the primary column (EN) is empty -- cell wrap continuation. + + In textbook vocabulary tables, columns are often narrow, so the author + wraps text within a cell. OCR treats each physical line as a separate row. + The key indicator: if the EN column is empty but DE/example have text, + this row is a continuation of the previous row's cells. + + Example (original textbook has ONE row): + Row 2: EN="take part (in)" DE="teilnehmen (an), mitmachen" EX="More than 200 singers took" + Row 3: EN="" DE="(bei)" EX="part in the concert." + -> Merged: EN="take part (in)" DE="teilnehmen (an), mitmachen (bei)" EX="..." + + Also handles the reverse case: DE empty but EN has text (wrap in EN column). + """ + if len(entries) < 2: + return entries + + merged: List[Dict[str, Any]] = [] + for entry in entries: + en = (entry.get('english') or '').strip() + de = (entry.get('german') or '').strip() + ex = (entry.get('example') or '').strip() + + if not merged: + merged.append(entry) + continue + + prev = merged[-1] + prev_en = (prev.get('english') or '').strip() + prev_de = (prev.get('german') or '').strip() + prev_ex = (prev.get('example') or '').strip() + + # Case 1: EN is empty -> continuation of previous row + if not en and (de or ex) and prev_en: + if de: + if prev_de.endswith(','): + sep = ' ' + elif prev_de.endswith(('-', '(')): + sep = '' + else: + sep = ' ' + prev['german'] = (prev_de + sep + de).strip() + if ex: + sep = ' ' if prev_ex else '' + prev['example'] = (prev_ex + sep + ex).strip() + logger.debug( + f"Merged wrapped row {entry.get('row_index')} into previous " + f"(empty EN): DE={prev['german']!r}, EX={prev.get('example', '')!r}" + ) + continue + + # Case 2: DE is empty, EN has text that looks like continuation + if en and not de and prev_de: + is_paren = en.startswith('(') + first_alpha = next((c for c in en if c.isalpha()), '') + starts_lower = first_alpha and first_alpha.islower() + + if (is_paren or starts_lower) and len(en.split()) < 5: + sep = ' ' if prev_en and not prev_en.endswith((',', '-', '(')) else '' + prev['english'] = (prev_en + sep + en).strip() + if ex: + sep2 = ' ' if prev_ex else '' + prev['example'] = (prev_ex + sep2 + ex).strip() + logger.debug( + f"Merged wrapped row {entry.get('row_index')} into previous " + f"(empty DE): EN={prev['english']!r}" + ) + continue + + merged.append(entry) + + if len(merged) < len(entries): + logger.info( + f"_merge_wrapped_rows: merged {len(entries) - len(merged)} " + f"continuation rows ({len(entries)} -> {len(merged)})" + ) + return merged + + +def _merge_continuation_rows( + entries: List[Dict[str, Any]], +) -> List[Dict[str, Any]]: + """Merge multi-line vocabulary entries where text wraps to the next row. + + A row is a continuation of the previous entry when: + - EN has text, but DE is empty + - EN starts with a lowercase letter (not a new vocab entry) + - Previous entry's EN does NOT end with a sentence terminator (.!?) + - The continuation text has fewer than 4 words (not an example sentence) + - The row was not already merged as phonetic + + Example: + Row 5: EN="to put up" DE="aufstellen" + Row 6: EN="with sth." DE="" + -> Merged: EN="to put up with sth." DE="aufstellen" + """ + if len(entries) < 2: + return entries + + merged: List[Dict[str, Any]] = [] + for entry in entries: + en = (entry.get('english') or '').strip() + de = (entry.get('german') or '').strip() + + if merged and en and not de: + # Check: not phonetic (already handled) + if _is_phonetic_only_text(en): + merged.append(entry) + continue + + # Check: starts with lowercase + first_alpha = next((c for c in en if c.isalpha()), '') + starts_lower = first_alpha and first_alpha.islower() + + # Check: fewer than 4 words (not an example sentence) + word_count = len(en.split()) + is_short = word_count < 4 + + # Check: previous entry doesn't end with sentence terminator + prev = merged[-1] + prev_en = (prev.get('english') or '').strip() + prev_ends_sentence = prev_en and prev_en[-1] in '.!?' + + if starts_lower and is_short and not prev_ends_sentence: + # Merge into previous entry + prev['english'] = (prev_en + ' ' + en).strip() + # Merge example if present + ex = (entry.get('example') or '').strip() + if ex: + prev_ex = (prev.get('example') or '').strip() + prev['example'] = (prev_ex + ' ' + ex).strip() if prev_ex else ex + logger.debug( + f"Merged continuation row {entry.get('row_index')} " + f"into previous entry: {prev['english']!r}" + ) + continue + + merged.append(entry) + + return merged diff --git a/klausur-service/backend/cv_cell_grid_streaming.py b/klausur-service/backend/cv_cell_grid_streaming.py new file mode 100644 index 0000000..4db3268 --- /dev/null +++ b/klausur-service/backend/cv_cell_grid_streaming.py @@ -0,0 +1,217 @@ +""" +Streaming variants of cell-grid builders (v2 + legacy). + +Extracted from cv_cell_grid.py. These yield cells one-by-one as OCR'd, +useful for progress reporting. + +Lizenz: Apache 2.0 (kommerziell nutzbar) +DATENSCHUTZ: Alle Verarbeitung erfolgt lokal. +""" + +import logging +from typing import Any, Dict, Generator, List, Optional, Tuple + +import numpy as np + +from cv_vocab_types import PageRegion, RowGeometry +from cv_ocr_engines import ( + RAPIDOCR_AVAILABLE, + _assign_row_words_to_columns, +) +from cv_cell_grid_helpers import ( + _heal_row_gaps, + _is_artifact_row, +) +from cv_cell_grid_build import _ocr_cell_crop +from cv_cell_grid_legacy import _ocr_single_cell + +logger = logging.getLogger(__name__) + + +# --------------------------------------------------------------------------- +# build_cell_grid_v2_streaming +# --------------------------------------------------------------------------- + +def build_cell_grid_v2_streaming( + ocr_img: np.ndarray, + column_regions: List[PageRegion], + row_geometries: List[RowGeometry], + img_w: int, + img_h: int, + lang: str = "eng+deu", + ocr_engine: str = "auto", + img_bgr: Optional[np.ndarray] = None, +) -> Generator[Tuple[Dict[str, Any], List[Dict[str, Any]], int], None, None]: + """Streaming variant of build_cell_grid_v2 -- yields each cell as OCR'd. + + Yields: + (cell_dict, columns_meta, total_cells) + """ + use_rapid = False + if ocr_engine in ("trocr-printed", "trocr-handwritten", "lighton"): + engine_name = ocr_engine + elif ocr_engine == "auto": + engine_name = "tesseract" + elif ocr_engine == "rapid": + if not RAPIDOCR_AVAILABLE: + logger.warning("RapidOCR requested but not available, falling back to Tesseract") + else: + use_rapid = True + engine_name = "rapid" if use_rapid else "tesseract" + else: + engine_name = "tesseract" + + content_rows = [r for r in row_geometries if r.row_type == 'content'] + if not content_rows: + return + + content_rows = [r for r in content_rows if r.word_count > 0] + if not content_rows: + return + + _skip_types = {'column_ignore', 'header', 'footer', 'margin_top', + 'margin_bottom', 'margin_left', 'margin_right'} + relevant_cols = [c for c in column_regions if c.type not in _skip_types] + if not relevant_cols: + return + + content_rows = [r for r in content_rows if not _is_artifact_row(r)] + if not content_rows: + return + + # Use header/footer boundaries for heal_row_gaps + content_rows.sort(key=lambda r: r.y) + header_rows = [r for r in row_geometries if r.row_type == 'header'] + footer_rows = [r for r in row_geometries if r.row_type == 'footer'] + if header_rows: + top_bound = max(r.y + r.height for r in header_rows) + else: + top_bound = content_rows[0].y + if footer_rows: + bottom_bound = min(r.y for r in footer_rows) + else: + bottom_bound = content_rows[-1].y + content_rows[-1].height + + _heal_row_gaps(content_rows, top_bound=top_bound, bottom_bound=bottom_bound) + + relevant_cols.sort(key=lambda c: c.x) + + columns_meta = [ + {'index': ci, 'type': c.type, 'x': c.x, 'width': c.width} + for ci, c in enumerate(relevant_cols) + ] + + lang_map = { + 'column_en': 'eng', + 'column_de': 'deu', + 'column_example': 'eng+deu', + } + + total_cells = len(content_rows) * len(relevant_cols) + + for row_idx, row in enumerate(content_rows): + for col_idx, col in enumerate(relevant_cols): + cell = _ocr_cell_crop( + row_idx, col_idx, row, col, + ocr_img, img_bgr, img_w, img_h, + engine_name, lang, lang_map, + ) + yield cell, columns_meta, total_cells + + +# --------------------------------------------------------------------------- +# build_cell_grid_streaming — legacy streaming variant +# --------------------------------------------------------------------------- + +def build_cell_grid_streaming( + ocr_img: np.ndarray, + column_regions: List[PageRegion], + row_geometries: List[RowGeometry], + img_w: int, + img_h: int, + lang: str = "eng+deu", + ocr_engine: str = "auto", + img_bgr: Optional[np.ndarray] = None, +) -> Generator[Tuple[Dict[str, Any], List[Dict[str, Any]], int], None, None]: + """Like build_cell_grid(), but yields each cell as it is OCR'd. + + DEPRECATED: Use build_cell_grid_v2_streaming instead. + + Yields: + (cell_dict, columns_meta, total_cells) for each cell. + """ + use_rapid = False + if ocr_engine in ("trocr-printed", "trocr-handwritten", "lighton"): + engine_name = ocr_engine + elif ocr_engine == "auto": + use_rapid = RAPIDOCR_AVAILABLE and img_bgr is not None + engine_name = "rapid" if use_rapid else "tesseract" + elif ocr_engine == "rapid": + if not RAPIDOCR_AVAILABLE: + logger.warning("RapidOCR requested but not available, falling back to Tesseract") + else: + use_rapid = True + engine_name = "rapid" if use_rapid else "tesseract" + else: + engine_name = "tesseract" + + content_rows = [r for r in row_geometries if r.row_type == 'content'] + if not content_rows: + return + + before = len(content_rows) + content_rows = [r for r in content_rows if r.word_count > 0] + skipped = before - len(content_rows) + if skipped > 0: + logger.info(f"build_cell_grid_streaming: skipped {skipped} phantom rows (word_count=0)") + if not content_rows: + return + + _skip_types = {'column_ignore', 'header', 'footer', 'margin_top', 'margin_bottom', 'margin_left', 'margin_right'} + relevant_cols = [c for c in column_regions if c.type not in _skip_types] + if not relevant_cols: + return + + before_art = len(content_rows) + content_rows = [r for r in content_rows if not _is_artifact_row(r)] + artifact_skipped = before_art - len(content_rows) + if artifact_skipped > 0: + logger.info(f"build_cell_grid_streaming: skipped {artifact_skipped} artifact rows") + if not content_rows: + return + _heal_row_gaps( + content_rows, + top_bound=min(c.y for c in relevant_cols), + bottom_bound=max(c.y + c.height for c in relevant_cols), + ) + + relevant_cols.sort(key=lambda c: c.x) + + columns_meta = [ + { + 'index': col_idx, + 'type': col.type, + 'x': col.x, + 'width': col.width, + } + for col_idx, col in enumerate(relevant_cols) + ] + + lang_map = { + 'column_en': 'eng', + 'column_de': 'deu', + 'column_example': 'eng+deu', + } + + total_cells = len(content_rows) * len(relevant_cols) + + for row_idx, row in enumerate(content_rows): + col_words = _assign_row_words_to_columns(row, relevant_cols) + for col_idx, col in enumerate(relevant_cols): + cell = _ocr_single_cell( + row_idx, col_idx, row, col, + ocr_img, img_bgr, img_w, img_h, + use_rapid, engine_name, lang, lang_map, + preassigned_words=col_words[col_idx], + ) + yield cell, columns_meta, total_cells diff --git a/klausur-service/backend/cv_cell_grid_vocab.py b/klausur-service/backend/cv_cell_grid_vocab.py new file mode 100644 index 0000000..d475c33 --- /dev/null +++ b/klausur-service/backend/cv_cell_grid_vocab.py @@ -0,0 +1,200 @@ +""" +Vocabulary extraction: cells -> vocab entries, and build_word_grid wrapper. + +Extracted from cv_cell_grid.py. + +Lizenz: Apache 2.0 (kommerziell nutzbar) +DATENSCHUTZ: Alle Verarbeitung erfolgt lokal. +""" + +import logging +from typing import Any, Dict, List + +from cv_ocr_engines import ( + _attach_example_sentences, + _fix_phonetic_brackets, + _split_comma_entries, +) +from cv_cell_grid_legacy import build_cell_grid +from cv_cell_grid_merge import ( + _merge_continuation_rows, + _merge_phonetic_continuation_rows, + _merge_wrapped_rows, +) + +logger = logging.getLogger(__name__) + + +def _cells_to_vocab_entries( + cells: List[Dict[str, Any]], + columns_meta: List[Dict[str, Any]], +) -> List[Dict[str, Any]]: + """Map generic cells to vocab entries with english/german/example fields. + + Groups cells by row_index, maps col_type -> field name, and produces + one entry per row (only rows with at least one non-empty field). + """ + col_type_to_field = { + 'column_en': 'english', + 'column_de': 'german', + 'column_example': 'example', + 'page_ref': 'source_page', + 'column_marker': 'marker', + 'column_text': 'text', # generic single-column (box sub-sessions) + } + bbox_key_map = { + 'column_en': 'bbox_en', + 'column_de': 'bbox_de', + 'column_example': 'bbox_ex', + 'page_ref': 'bbox_ref', + 'column_marker': 'bbox_marker', + 'column_text': 'bbox_text', + } + + # Group cells by row_index + rows: Dict[int, List[Dict]] = {} + for cell in cells: + ri = cell['row_index'] + rows.setdefault(ri, []).append(cell) + + entries: List[Dict[str, Any]] = [] + for row_idx in sorted(rows.keys()): + row_cells = rows[row_idx] + entry: Dict[str, Any] = { + 'row_index': row_idx, + 'english': '', + 'german': '', + 'example': '', + 'text': '', # generic single-column (box sub-sessions) + 'source_page': '', + 'marker': '', + 'confidence': 0.0, + 'bbox': None, + 'bbox_en': None, + 'bbox_de': None, + 'bbox_ex': None, + 'bbox_ref': None, + 'bbox_marker': None, + 'bbox_text': None, + 'ocr_engine': row_cells[0].get('ocr_engine', '') if row_cells else '', + } + + confidences = [] + for cell in row_cells: + col_type = cell['col_type'] + field = col_type_to_field.get(col_type) + if field: + entry[field] = cell['text'] + bbox_field = bbox_key_map.get(col_type) + if bbox_field: + entry[bbox_field] = cell['bbox_pct'] + if cell['confidence'] > 0: + confidences.append(cell['confidence']) + + # Compute row-level bbox as union of all cell bboxes + all_bboxes = [c['bbox_pct'] for c in row_cells if c.get('bbox_pct')] + if all_bboxes: + min_x = min(b['x'] for b in all_bboxes) + min_y = min(b['y'] for b in all_bboxes) + max_x2 = max(b['x'] + b['w'] for b in all_bboxes) + max_y2 = max(b['y'] + b['h'] for b in all_bboxes) + entry['bbox'] = { + 'x': round(min_x, 2), + 'y': round(min_y, 2), + 'w': round(max_x2 - min_x, 2), + 'h': round(max_y2 - min_y, 2), + } + + entry['confidence'] = round( + sum(confidences) / len(confidences), 1 + ) if confidences else 0.0 + + # Only include if at least one mapped field has text + has_content = any( + entry.get(f) + for f in col_type_to_field.values() + ) + if has_content: + entries.append(entry) + + return entries + + +def build_word_grid( + ocr_img, + column_regions, + row_geometries, + img_w: int, + img_h: int, + lang: str = "eng+deu", + ocr_engine: str = "auto", + img_bgr=None, + pronunciation: str = "british", +) -> List[Dict[str, Any]]: + """Vocab-specific: Cell-Grid + Vocab-Mapping + Post-Processing. + + Wrapper around build_cell_grid() that adds vocabulary-specific logic: + - Maps cells to english/german/example entries + - Applies character confusion fixes, IPA lookup, comma splitting, etc. + - Falls back to returning raw cells if no vocab columns detected. + + Args: + ocr_img: Binarized full-page image (for Tesseract). + column_regions: Classified columns from Step 3. + row_geometries: Rows from Step 4. + img_w, img_h: Image dimensions. + lang: Default Tesseract language. + ocr_engine: 'tesseract', 'rapid', or 'auto'. + img_bgr: BGR color image (required for RapidOCR). + pronunciation: 'british' or 'american' for IPA lookup. + + Returns: + List of entry dicts with english/german/example text and bbox info (percent). + """ + cells, columns_meta = build_cell_grid( + ocr_img, column_regions, row_geometries, img_w, img_h, + lang=lang, ocr_engine=ocr_engine, img_bgr=img_bgr, + ) + + if not cells: + return [] + + # Check if vocab layout is present + col_types = {c['type'] for c in columns_meta} + if not (col_types & {'column_en', 'column_de'}): + logger.info("build_word_grid: no vocab columns -- returning raw cells") + return cells + + # Vocab mapping: cells -> entries + entries = _cells_to_vocab_entries(cells, columns_meta) + + # --- Post-processing pipeline (deterministic, no LLM) --- + n_raw = len(entries) + + # 0. Merge cell-wrap continuation rows (empty primary column = text wrap) + entries = _merge_wrapped_rows(entries) + + # 0a. Merge phonetic-only continuation rows into previous entry + entries = _merge_phonetic_continuation_rows(entries) + + # 0b. Merge multi-line continuation rows (lowercase EN, empty DE) + entries = _merge_continuation_rows(entries) + + # 1. Character confusion (| -> I, 1 -> I, 8 -> B) is now run in + # llm_review_entries_streaming so changes are visible to the user in Step 6. + + # 2. Replace OCR'd phonetics with dictionary IPA + entries = _fix_phonetic_brackets(entries, pronunciation=pronunciation) + + # 3. Split comma-separated word forms (break, broke, broken -> 3 entries) + entries = _split_comma_entries(entries) + + # 4. Attach example sentences (rows without DE -> examples for preceding entry) + entries = _attach_example_sentences(entries) + + engine_name = cells[0].get('ocr_engine', 'unknown') if cells else 'unknown' + logger.info(f"build_word_grid: {len(entries)} entries from " + f"{n_raw} raw -> {len(entries)} after post-processing " + f"(engine={engine_name})") + + return entries diff --git a/klausur-service/backend/cv_preprocessing.py b/klausur-service/backend/cv_preprocessing.py index 71c4f50..0cb2841 100644 --- a/klausur-service/backend/cv_preprocessing.py +++ b/klausur-service/backend/cv_preprocessing.py @@ -1,14 +1,19 @@ """ Image I/O, orientation detection, deskew, and dewarp for the CV vocabulary pipeline. +Re-export facade -- all logic lives in the sub-modules: + + cv_preprocessing_deskew Rotation correction (Hough, word-alignment, iterative, two-pass) + cv_preprocessing_dewarp Vertical shear detection and correction (4 methods + ensemble) + +This file contains the image I/O and orientation detection functions. + Lizenz: Apache 2.0 (kommerziell nutzbar) DATENSCHUTZ: Alle Verarbeitung erfolgt lokal. """ import logging -import time -from collections import defaultdict -from typing import Any, Dict, List, Tuple +from typing import Tuple import numpy as np @@ -19,7 +24,7 @@ from cv_vocab_types import ( logger = logging.getLogger(__name__) -# Guarded imports — mirror cv_vocab_types guards +# Guarded imports try: import cv2 except ImportError: @@ -32,6 +37,33 @@ except ImportError: pytesseract = None # type: ignore[assignment] Image = None # type: ignore[assignment,misc] +# Re-export all deskew functions +from cv_preprocessing_deskew import ( # noqa: F401 + deskew_image, + deskew_image_by_word_alignment, + deskew_image_iterative, + deskew_two_pass, + _projection_gradient_score, + _measure_textline_slope, +) + +# Re-export all dewarp functions +from cv_preprocessing_dewarp import ( # noqa: F401 + _apply_shear, + _detect_shear_angle, + _detect_shear_by_hough, + _detect_shear_by_projection, + _detect_shear_by_text_lines, + _dewarp_quality_check, + _ensemble_shear, + dewarp_image, + dewarp_image_manual, +) + + +# ============================================================================= +# Image I/O +# ============================================================================= def render_pdf_high_res(pdf_data: bytes, page_number: int = 0, zoom: float = 3.0) -> np.ndarray: """Render a PDF page to a high-resolution numpy array (BGR). @@ -54,7 +86,6 @@ def render_pdf_high_res(pdf_data: bytes, page_number: int = 0, zoom: float = 3.0 mat = fitz.Matrix(zoom, zoom) pix = page.get_pixmap(matrix=mat) - # Convert to numpy BGR img_data = np.frombuffer(pix.samples, dtype=np.uint8).reshape(pix.h, pix.w, pix.n) if pix.n == 4: # RGBA img_bgr = cv2.cvtColor(img_data, cv2.COLOR_RGBA2BGR) @@ -84,23 +115,19 @@ def render_image_high_res(image_data: bytes) -> np.ndarray: # ============================================================================= -# Stage 1b: Orientation Detection (0°/90°/180°/270°) +# Orientation Detection (0/90/180/270) # ============================================================================= def detect_and_fix_orientation(img_bgr: np.ndarray) -> Tuple[np.ndarray, int]: """Detect page orientation via Tesseract OSD and rotate if needed. - Handles upside-down scans (180°) common with book scanners where - every other page is flipped due to the scanner hinge. - Returns: - (corrected_image, rotation_degrees) — rotation is 0, 90, 180, or 270. + (corrected_image, rotation_degrees) -- rotation is 0, 90, 180, or 270. """ if pytesseract is None: return img_bgr, 0 try: - # Tesseract OSD needs a grayscale or RGB image gray = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2GRAY) pil_img = Image.fromarray(gray) @@ -108,12 +135,11 @@ def detect_and_fix_orientation(img_bgr: np.ndarray) -> Tuple[np.ndarray, int]: rotate = osd.get("rotate", 0) confidence = osd.get("orientation_conf", 0.0) - logger.info(f"OSD: orientation={rotate}° confidence={confidence:.1f}") + logger.info(f"OSD: orientation={rotate}\u00b0 confidence={confidence:.1f}") if rotate == 0 or confidence < 1.0: return img_bgr, 0 - # Apply rotation — OSD rotate is the clockwise correction needed if rotate == 180: corrected = cv2.rotate(img_bgr, cv2.ROTATE_180) elif rotate == 90: @@ -123,1044 +149,9 @@ def detect_and_fix_orientation(img_bgr: np.ndarray) -> Tuple[np.ndarray, int]: else: return img_bgr, 0 - logger.info(f"OSD: rotated {rotate}° to fix orientation") + logger.info(f"OSD: rotated {rotate}\u00b0 to fix orientation") return corrected, rotate except Exception as e: logger.warning(f"OSD orientation detection failed: {e}") return img_bgr, 0 - - -# ============================================================================= -# Stage 2: Deskew (Rotation Correction) -# ============================================================================= - -def deskew_image(img: np.ndarray) -> Tuple[np.ndarray, float]: - """Correct rotation using Hough Line detection. - - Args: - img: BGR image. - - Returns: - Tuple of (corrected image, detected angle in degrees). - """ - gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) - # Binarize for line detection - _, binary = cv2.threshold(gray, 0, 255, cv2.THRESH_BINARY_INV + cv2.THRESH_OTSU) - - # Detect lines - lines = cv2.HoughLinesP(binary, 1, np.pi / 180, threshold=100, - minLineLength=img.shape[1] // 4, maxLineGap=20) - - if lines is None or len(lines) < 3: - return img, 0.0 - - # Compute angles of near-horizontal lines - angles = [] - for line in lines: - x1, y1, x2, y2 = line[0] - angle = np.degrees(np.arctan2(y2 - y1, x2 - x1)) - if abs(angle) < 15: # Only near-horizontal - angles.append(angle) - - if not angles: - return img, 0.0 - - median_angle = float(np.median(angles)) - - # Limit correction to ±5° - if abs(median_angle) > 5.0: - median_angle = 5.0 * np.sign(median_angle) - - if abs(median_angle) < 0.1: - return img, 0.0 - - # Rotate - h, w = img.shape[:2] - center = (w // 2, h // 2) - M = cv2.getRotationMatrix2D(center, median_angle, 1.0) - corrected = cv2.warpAffine(img, M, (w, h), - flags=cv2.INTER_LINEAR, - borderMode=cv2.BORDER_REPLICATE) - - logger.info(f"Deskew: corrected {median_angle:.2f}° rotation") - return corrected, median_angle - - -def deskew_image_by_word_alignment( - image_data: bytes, - lang: str = "eng+deu", - downscale_factor: float = 0.5, -) -> Tuple[bytes, float]: - """Correct rotation by fitting a line through left-most word starts per text line. - - More robust than Hough-based deskew for vocabulary worksheets where text lines - have consistent left-alignment. Runs a quick Tesseract pass on a downscaled - copy to find word positions, computes the dominant left-edge column, fits a - line through those points and rotates the full-resolution image. - - Args: - image_data: Raw image bytes (PNG/JPEG). - lang: Tesseract language string for the quick pass. - downscale_factor: Shrink factor for the quick Tesseract pass (0.5 = 50%). - - Returns: - Tuple of (rotated image as PNG bytes, detected angle in degrees). - """ - if not CV2_AVAILABLE or not TESSERACT_AVAILABLE: - return image_data, 0.0 - - # 1. Decode image - img_array = np.frombuffer(image_data, dtype=np.uint8) - img = cv2.imdecode(img_array, cv2.IMREAD_COLOR) - if img is None: - logger.warning("deskew_by_word_alignment: could not decode image") - return image_data, 0.0 - - orig_h, orig_w = img.shape[:2] - - # 2. Downscale for fast Tesseract pass - small_w = int(orig_w * downscale_factor) - small_h = int(orig_h * downscale_factor) - small = cv2.resize(img, (small_w, small_h), interpolation=cv2.INTER_AREA) - - # 3. Quick Tesseract — word-level positions - pil_small = Image.fromarray(cv2.cvtColor(small, cv2.COLOR_BGR2RGB)) - try: - data = pytesseract.image_to_data( - pil_small, lang=lang, config="--psm 6 --oem 3", - output_type=pytesseract.Output.DICT, - ) - except Exception as e: - logger.warning(f"deskew_by_word_alignment: Tesseract failed: {e}") - return image_data, 0.0 - - # 4. Per text-line, find the left-most word start - # Group by (block_num, par_num, line_num) - line_groups: Dict[tuple, list] = defaultdict(list) - for i in range(len(data["text"])): - text = (data["text"][i] or "").strip() - conf = int(data["conf"][i]) - if not text or conf < 20: - continue - key = (data["block_num"][i], data["par_num"][i], data["line_num"][i]) - line_groups[key].append(i) - - if len(line_groups) < 5: - logger.info(f"deskew_by_word_alignment: only {len(line_groups)} lines, skipping") - return image_data, 0.0 - - # For each line, pick the word with smallest 'left' → compute (left_x, center_y) - # Scale back to original resolution - scale = 1.0 / downscale_factor - points = [] # list of (x, y) in original-image coords - for key, indices in line_groups.items(): - best_idx = min(indices, key=lambda i: data["left"][i]) - lx = data["left"][best_idx] * scale - top = data["top"][best_idx] * scale - h = data["height"][best_idx] * scale - cy = top + h / 2.0 - points.append((lx, cy)) - - # 5. Find dominant left-edge column + compute angle - xs = np.array([p[0] for p in points]) - ys = np.array([p[1] for p in points]) - median_x = float(np.median(xs)) - tolerance = orig_w * 0.03 # 3% of image width - - mask = np.abs(xs - median_x) <= tolerance - filtered_xs = xs[mask] - filtered_ys = ys[mask] - - if len(filtered_xs) < 5: - logger.info(f"deskew_by_word_alignment: only {len(filtered_xs)} aligned points after filter, skipping") - return image_data, 0.0 - - # polyfit: x = a*y + b → a = dx/dy → angle = arctan(a) - coeffs = np.polyfit(filtered_ys, filtered_xs, 1) - slope = coeffs[0] # dx/dy - angle_rad = np.arctan(slope) - angle_deg = float(np.degrees(angle_rad)) - - # Clamp to ±5° - angle_deg = max(-5.0, min(5.0, angle_deg)) - - logger.info(f"deskew_by_word_alignment: detected {angle_deg:.2f}° from {len(filtered_xs)} points " - f"(total lines: {len(line_groups)})") - - if abs(angle_deg) < 0.05: - return image_data, 0.0 - - # 6. Rotate full-res image - center = (orig_w // 2, orig_h // 2) - M = cv2.getRotationMatrix2D(center, angle_deg, 1.0) - rotated = cv2.warpAffine(img, M, (orig_w, orig_h), - flags=cv2.INTER_LINEAR, - borderMode=cv2.BORDER_REPLICATE) - - # Encode back to PNG - success, png_buf = cv2.imencode(".png", rotated) - if not success: - logger.warning("deskew_by_word_alignment: PNG encoding failed") - return image_data, 0.0 - - return png_buf.tobytes(), angle_deg - - -def _projection_gradient_score(profile: np.ndarray) -> float: - """Score a projection profile by the L2-norm of its first derivative. - - Higher score = sharper transitions between text-lines and gaps, - i.e. better row/column alignment. - """ - diff = np.diff(profile) - return float(np.sum(diff * diff)) - - -def deskew_image_iterative( - img: np.ndarray, - coarse_range: float = 5.0, - coarse_step: float = 0.1, - fine_range: float = 0.15, - fine_step: float = 0.02, -) -> Tuple[np.ndarray, float, Dict[str, Any]]: - """Iterative deskew using vertical-edge projection optimisation. - - The key insight: at the correct rotation angle, vertical features - (word left-edges, column borders) become truly vertical, producing - the sharpest peaks in the vertical projection of vertical edges. - - Method: - 1. Detect vertical edges via Sobel-X on the central crop. - 2. Coarse sweep: rotate edge image, compute vertical projection - gradient score. The angle where vertical edges align best wins. - 3. Fine sweep: refine around the coarse winner. - - Args: - img: BGR image (full resolution). - coarse_range: half-range in degrees for the coarse sweep. - coarse_step: step size in degrees for the coarse sweep. - fine_range: half-range around the coarse winner for the fine sweep. - fine_step: step size in degrees for the fine sweep. - - Returns: - (rotated_bgr, angle_degrees, debug_dict) - """ - h, w = img.shape[:2] - debug: Dict[str, Any] = {} - - # --- Grayscale + vertical edge detection --- - gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) - - # Central crop (15%-85% height, 10%-90% width) to avoid page margins - y_lo, y_hi = int(h * 0.15), int(h * 0.85) - x_lo, x_hi = int(w * 0.10), int(w * 0.90) - gray_crop = gray[y_lo:y_hi, x_lo:x_hi] - - # Sobel-X → absolute vertical edges - sobel_x = cv2.Sobel(gray_crop, cv2.CV_64F, 1, 0, ksize=3) - edges = np.abs(sobel_x) - # Normalise to 0-255 for consistent scoring - edge_max = edges.max() - if edge_max > 0: - edges = (edges / edge_max * 255).astype(np.uint8) - else: - return img, 0.0, {"error": "no edges detected"} - - crop_h, crop_w = edges.shape[:2] - crop_center = (crop_w // 2, crop_h // 2) - - # Trim margin after rotation to avoid border artifacts - trim_y = max(4, int(crop_h * 0.03)) - trim_x = max(4, int(crop_w * 0.03)) - - def _sweep_edges(angles: np.ndarray) -> list: - """Score each angle by vertical projection gradient of vertical edges.""" - results = [] - for angle in angles: - if abs(angle) < 1e-6: - rotated = edges - else: - M = cv2.getRotationMatrix2D(crop_center, angle, 1.0) - rotated = cv2.warpAffine(edges, M, (crop_w, crop_h), - flags=cv2.INTER_NEAREST, - borderMode=cv2.BORDER_REPLICATE) - # Trim borders to avoid edge artifacts - trimmed = rotated[trim_y:-trim_y, trim_x:-trim_x] - v_profile = np.sum(trimmed, axis=0, dtype=np.float64) - score = _projection_gradient_score(v_profile) - results.append((float(angle), score)) - return results - - # --- Phase 1: coarse sweep --- - coarse_angles = np.arange(-coarse_range, coarse_range + coarse_step * 0.5, coarse_step) - coarse_results = _sweep_edges(coarse_angles) - best_coarse = max(coarse_results, key=lambda x: x[1]) - best_coarse_angle, best_coarse_score = best_coarse - - debug["coarse_best_angle"] = round(best_coarse_angle, 2) - debug["coarse_best_score"] = round(best_coarse_score, 1) - debug["coarse_scores"] = [(round(a, 2), round(s, 1)) for a, s in coarse_results] - - # --- Phase 2: fine sweep around coarse winner --- - fine_lo = best_coarse_angle - fine_range - fine_hi = best_coarse_angle + fine_range - fine_angles = np.arange(fine_lo, fine_hi + fine_step * 0.5, fine_step) - fine_results = _sweep_edges(fine_angles) - best_fine = max(fine_results, key=lambda x: x[1]) - best_fine_angle, best_fine_score = best_fine - - debug["fine_best_angle"] = round(best_fine_angle, 2) - debug["fine_best_score"] = round(best_fine_score, 1) - debug["fine_scores"] = [(round(a, 2), round(s, 1)) for a, s in fine_results] - - final_angle = best_fine_angle - - # Clamp to ±5° - final_angle = max(-5.0, min(5.0, final_angle)) - - logger.info(f"deskew_iterative: coarse={best_coarse_angle:.2f}° fine={best_fine_angle:.2f}° -> {final_angle:.2f}°") - - if abs(final_angle) < 0.05: - return img, 0.0, debug - - # --- Rotate full-res image --- - center = (w // 2, h // 2) - M = cv2.getRotationMatrix2D(center, final_angle, 1.0) - rotated = cv2.warpAffine(img, M, (w, h), - flags=cv2.INTER_LINEAR, - borderMode=cv2.BORDER_REPLICATE) - - return rotated, final_angle, debug - - -def _measure_textline_slope(img: np.ndarray) -> float: - """Measure residual text-line slope via Tesseract word-position regression. - - Groups Tesseract words by (block, par, line), fits a linear regression - per line (y = slope * x + b), and returns the trimmed-mean slope in - degrees. Positive = text rises to the right, negative = falls. - - This is the most direct measurement of remaining rotation after deskew. - """ - import math as _math - - if not TESSERACT_AVAILABLE or not CV2_AVAILABLE: - return 0.0 - - h, w = img.shape[:2] - gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) - data = pytesseract.image_to_data( - Image.fromarray(gray), - output_type=pytesseract.Output.DICT, - config="--psm 6", - ) - - # Group word centres by text line - lines: Dict[tuple, list] = {} - for i in range(len(data["text"])): - txt = (data["text"][i] or "").strip() - if len(txt) < 2 or int(data["conf"][i]) < 30: - continue - key = (data["block_num"][i], data["par_num"][i], data["line_num"][i]) - cx = data["left"][i] + data["width"][i] / 2.0 - cy = data["top"][i] + data["height"][i] / 2.0 - lines.setdefault(key, []).append((cx, cy)) - - # Per-line linear regression → slope angle - slopes: list = [] - for pts in lines.values(): - if len(pts) < 3: - continue - pts.sort(key=lambda p: p[0]) - xs = np.array([p[0] for p in pts], dtype=np.float64) - ys = np.array([p[1] for p in pts], dtype=np.float64) - if xs[-1] - xs[0] < w * 0.15: - continue # skip short lines - A = np.vstack([xs, np.ones_like(xs)]).T - result = np.linalg.lstsq(A, ys, rcond=None) - slope = result[0][0] - slopes.append(_math.degrees(_math.atan(slope))) - - if len(slopes) < 3: - return 0.0 - - # Trimmed mean (drop 10% extremes on each side) - slopes.sort() - trim = max(1, len(slopes) // 10) - trimmed = slopes[trim:-trim] if len(slopes) > 2 * trim else slopes - if not trimmed: - return 0.0 - - return sum(trimmed) / len(trimmed) - - -def deskew_two_pass( - img: np.ndarray, - coarse_range: float = 5.0, -) -> Tuple[np.ndarray, float, Dict[str, Any]]: - """Two-pass deskew: iterative projection + word-alignment residual check. - - Pass 1: ``deskew_image_iterative()`` (vertical-edge projection, wide range). - Pass 2: ``deskew_image_by_word_alignment()`` on the already-corrected image - to detect and fix residual skew that the projection method missed. - - The two corrections are summed. If the residual from Pass 2 is below - 0.3° it is ignored (already good enough). - - Returns: - (corrected_bgr, total_angle_degrees, debug_dict) - """ - debug: Dict[str, Any] = {} - - # --- Pass 1: iterative projection --- - corrected, angle1, dbg1 = deskew_image_iterative( - img.copy(), coarse_range=coarse_range, - ) - debug["pass1_angle"] = round(angle1, 3) - debug["pass1_method"] = "iterative" - debug["pass1_debug"] = dbg1 - - # --- Pass 2: word-alignment residual check on corrected image --- - angle2 = 0.0 - try: - # Encode the corrected image to PNG bytes for word-alignment - ok, buf = cv2.imencode(".png", corrected) - if ok: - corrected_bytes, angle2 = deskew_image_by_word_alignment(buf.tobytes()) - if abs(angle2) >= 0.3: - # Significant residual — decode and use the second correction - arr2 = np.frombuffer(corrected_bytes, dtype=np.uint8) - corrected2 = cv2.imdecode(arr2, cv2.IMREAD_COLOR) - if corrected2 is not None: - corrected = corrected2 - logger.info(f"deskew_two_pass: pass2 residual={angle2:.2f}° applied " - f"(total={angle1 + angle2:.2f}°)") - else: - angle2 = 0.0 - else: - logger.info(f"deskew_two_pass: pass2 residual={angle2:.2f}° < 0.3° — skipped") - angle2 = 0.0 - except Exception as e: - logger.warning(f"deskew_two_pass: pass2 word-alignment failed: {e}") - angle2 = 0.0 - - # --- Pass 3: Tesseract text-line regression residual check --- - # The most reliable final check: measure actual text-line slopes - # using Tesseract word positions and linear regression per line. - angle3 = 0.0 - try: - residual = _measure_textline_slope(corrected) - debug["pass3_raw"] = round(residual, 3) - if abs(residual) >= 0.3: - h3, w3 = corrected.shape[:2] - center3 = (w3 // 2, h3 // 2) - M3 = cv2.getRotationMatrix2D(center3, residual, 1.0) - corrected = cv2.warpAffine( - corrected, M3, (w3, h3), - flags=cv2.INTER_LINEAR, - borderMode=cv2.BORDER_REPLICATE, - ) - angle3 = residual - logger.info( - "deskew_two_pass: pass3 text-line residual=%.2f° applied", - residual, - ) - else: - logger.info( - "deskew_two_pass: pass3 text-line residual=%.2f° < 0.3° — skipped", - residual, - ) - except Exception as e: - logger.warning("deskew_two_pass: pass3 text-line check failed: %s", e) - - total_angle = angle1 + angle2 + angle3 - debug["pass2_angle"] = round(angle2, 3) - debug["pass2_method"] = "word_alignment" - debug["pass3_angle"] = round(angle3, 3) - debug["pass3_method"] = "textline_regression" - debug["total_angle"] = round(total_angle, 3) - - logger.info( - "deskew_two_pass: pass1=%.2f° + pass2=%.2f° + pass3=%.2f° = %.2f°", - angle1, angle2, angle3, total_angle, - ) - - return corrected, total_angle, debug - - -# ============================================================================= -# Stage 3: Dewarp (Book Curvature Correction) -# ============================================================================= - -def _detect_shear_angle(img: np.ndarray) -> Dict[str, Any]: - """Detect the vertical shear angle of the page. - - After deskew (horizontal lines aligned), vertical features like column - edges may still be tilted. This measures that tilt by tracking the - strongest vertical edge across horizontal strips. - - The result is a shear angle in degrees: the angular difference between - true vertical and the detected column edge. - - Returns: - Dict with keys: method, shear_degrees, confidence. - """ - h, w = img.shape[:2] - result = {"method": "vertical_edge", "shear_degrees": 0.0, "confidence": 0.0} - - gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) - - # Vertical Sobel to find vertical edges - sobel_x = cv2.Sobel(gray, cv2.CV_64F, 1, 0, ksize=3) - abs_sobel = np.abs(sobel_x).astype(np.uint8) - - # Binarize with Otsu - _, binary = cv2.threshold(abs_sobel, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU) - - num_strips = 20 - strip_h = h // num_strips - edge_positions = [] # (y_center, x_position) - - for i in range(num_strips): - y_start = i * strip_h - y_end = min((i + 1) * strip_h, h) - strip = binary[y_start:y_end, :] - - # Project vertically (sum along y-axis) - projection = np.sum(strip, axis=0).astype(np.float64) - if projection.max() == 0: - continue - - # Find the strongest vertical edge in left 40% of image - search_w = int(w * 0.4) - left_proj = projection[:search_w] - if left_proj.max() == 0: - continue - - # Smooth and find peak - kernel_size = max(3, w // 100) - if kernel_size % 2 == 0: - kernel_size += 1 - smoothed = cv2.GaussianBlur(left_proj.reshape(1, -1), (kernel_size, 1), 0).flatten() - x_pos = float(np.argmax(smoothed)) - y_center = (y_start + y_end) / 2.0 - edge_positions.append((y_center, x_pos)) - - if len(edge_positions) < 8: - return result - - ys = np.array([p[0] for p in edge_positions]) - xs = np.array([p[1] for p in edge_positions]) - - # Remove outliers (> 2 std from median) - median_x = np.median(xs) - std_x = max(np.std(xs), 1.0) - mask = np.abs(xs - median_x) < 2 * std_x - ys = ys[mask] - xs = xs[mask] - - if len(ys) < 6: - return result - - # Fit straight line: x = slope * y + intercept - # The slope tells us the tilt of the vertical edge - straight_coeffs = np.polyfit(ys, xs, 1) - slope = straight_coeffs[0] # dx/dy in pixels - fitted = np.polyval(straight_coeffs, ys) - residuals = xs - fitted - rmse = float(np.sqrt(np.mean(residuals ** 2))) - - # Convert slope to angle: arctan(dx/dy) in degrees - import math - shear_degrees = math.degrees(math.atan(slope)) - - confidence = min(1.0, len(ys) / 15.0) * max(0.5, 1.0 - rmse / 5.0) - - result["shear_degrees"] = round(shear_degrees, 3) - result["confidence"] = round(float(confidence), 2) - - return result - - -def _detect_shear_by_projection(img: np.ndarray) -> Dict[str, Any]: - """Detect shear angle by maximising variance of horizontal text-line projections. - - Principle: horizontal text lines produce a row-projection profile with sharp - peaks (high variance) when the image is correctly aligned. Any residual shear - smears the peaks and reduces variance. We sweep ±3° and pick the angle whose - corrected projection has the highest variance. - - Works best on pages with clear horizontal banding (vocabulary tables, prose). - Complements _detect_shear_angle() which needs strong vertical edges. - - Returns: - Dict with keys: method, shear_degrees, confidence. - """ - import math - result = {"method": "projection", "shear_degrees": 0.0, "confidence": 0.0} - - h, w = img.shape[:2] - gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) - - # Otsu binarisation - _, binary = cv2.threshold(gray, 0, 255, cv2.THRESH_BINARY_INV + cv2.THRESH_OTSU) - - # Work at half resolution for speed - small = cv2.resize(binary, (w // 2, h // 2), interpolation=cv2.INTER_AREA) - sh, sw = small.shape - - # 2-pass angle sweep for 10x better precision: - # Pass 1: Coarse sweep ±3° in 0.5° steps (13 values) - # Pass 2: Fine sweep ±0.5° around coarse best in 0.05° steps (21 values) - - def _sweep_variance(angles_list): - results = [] - for angle_deg in angles_list: - if abs(angle_deg) < 0.001: - rotated = small - else: - shear_tan = math.tan(math.radians(angle_deg)) - M = np.float32([[1, shear_tan, -sh / 2.0 * shear_tan], [0, 1, 0]]) - rotated = cv2.warpAffine(small, M, (sw, sh), - flags=cv2.INTER_NEAREST, - borderMode=cv2.BORDER_CONSTANT) - profile = np.sum(rotated, axis=1).astype(float) - results.append((angle_deg, float(np.var(profile)))) - return results - - # Pass 1: coarse - coarse_angles = [a * 0.5 for a in range(-6, 7)] # 13 values - coarse_results = _sweep_variance(coarse_angles) - coarse_best = max(coarse_results, key=lambda x: x[1]) - - # Pass 2: fine around coarse best - fine_center = coarse_best[0] - fine_angles = [fine_center + a * 0.05 for a in range(-10, 11)] # 21 values - fine_results = _sweep_variance(fine_angles) - fine_best = max(fine_results, key=lambda x: x[1]) - - best_angle = fine_best[0] - best_variance = fine_best[1] - variances = coarse_results + fine_results - - # Confidence: how much sharper is the best angle vs. the mean? - all_mean = sum(v for _, v in variances) / len(variances) - if all_mean > 0 and best_variance > all_mean: - confidence = min(1.0, (best_variance - all_mean) / (all_mean + 1.0) * 0.6) - else: - confidence = 0.0 - - result["shear_degrees"] = round(best_angle, 3) - result["confidence"] = round(max(0.0, min(1.0, confidence)), 2) - return result - - -def _detect_shear_by_hough(img: np.ndarray) -> Dict[str, Any]: - """Detect shear using Hough transform on printed table / ruled lines. - - Vocabulary worksheets have near-horizontal printed table borders. After - deskew these should be exactly horizontal; any residual tilt equals the - vertical shear angle (with inverted sign). - - The sign convention: a horizontal line tilting +α degrees (left end lower) - means the page has vertical shear of -α degrees (left column edge drifts - to the left going downward). - - Returns: - Dict with keys: method, shear_degrees, confidence. - """ - result = {"method": "hough_lines", "shear_degrees": 0.0, "confidence": 0.0} - - h, w = img.shape[:2] - gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) - - edges = cv2.Canny(gray, 50, 150, apertureSize=3) - - min_len = int(w * 0.15) - lines = cv2.HoughLinesP( - edges, rho=1, theta=np.pi / 360, - threshold=int(w * 0.08), - minLineLength=min_len, - maxLineGap=20, - ) - - if lines is None or len(lines) < 3: - return result - - horizontal_angles: List[Tuple[float, float]] = [] - for line in lines: - x1, y1, x2, y2 = line[0] - if x1 == x2: - continue - angle = float(np.degrees(np.arctan2(y2 - y1, x2 - x1))) - if abs(angle) <= 5.0: - length = float(np.sqrt((x2 - x1) ** 2 + (y2 - y1) ** 2)) - horizontal_angles.append((angle, length)) - - if len(horizontal_angles) < 3: - return result - - # Weighted median - angles_arr = np.array([a for a, _ in horizontal_angles]) - weights_arr = np.array([l for _, l in horizontal_angles]) - sorted_idx = np.argsort(angles_arr) - s_angles = angles_arr[sorted_idx] - s_weights = weights_arr[sorted_idx] - cum = np.cumsum(s_weights) - mid_idx = int(np.searchsorted(cum, cum[-1] / 2.0)) - median_angle = float(s_angles[min(mid_idx, len(s_angles) - 1)]) - - agree = sum(1 for a, _ in horizontal_angles if abs(a - median_angle) < 1.0) - confidence = min(1.0, agree / max(len(horizontal_angles), 1)) * 0.85 - - # Sign inversion: horizontal line tilt is complementary to vertical shear - shear_degrees = -median_angle - - result["shear_degrees"] = round(shear_degrees, 3) - result["confidence"] = round(max(0.0, min(1.0, confidence)), 2) - return result - - -def _detect_shear_by_text_lines(img: np.ndarray) -> Dict[str, Any]: - """Detect shear by measuring text-line straightness (Method D). - - Runs a quick Tesseract scan (PSM 11, 50% downscale) to locate word - bounding boxes, groups them into vertical columns by X-proximity, - and measures how the left-edge X position drifts with Y (vertical - position). The drift dx/dy is the tangent of the shear angle. - - This directly measures vertical shear (column tilt) rather than - horizontal text-line slope, which is already corrected by deskew. - - Returns: - Dict with keys: method, shear_degrees, confidence. - """ - import math - result = {"method": "text_lines", "shear_degrees": 0.0, "confidence": 0.0} - - h, w = img.shape[:2] - # Downscale 50% for speed - scale = 0.5 - small = cv2.resize(img, (int(w * scale), int(h * scale)), - interpolation=cv2.INTER_AREA) - gray = cv2.cvtColor(small, cv2.COLOR_BGR2GRAY) - pil_img = Image.fromarray(gray) - - try: - data = pytesseract.image_to_data( - pil_img, lang='eng+deu', config='--psm 11 --oem 3', - output_type=pytesseract.Output.DICT, - ) - except Exception: - return result - - # Collect word left-edges (x) and vertical centres (y) - words = [] - for i in range(len(data['text'])): - text = data['text'][i].strip() - conf = int(data['conf'][i]) - if not text or conf < 20 or len(text) < 2: - continue - left_x = float(data['left'][i]) - cy = data['top'][i] + data['height'][i] / 2.0 - word_w = float(data['width'][i]) - words.append((left_x, cy, word_w)) - - if len(words) < 15: - return result - - # --- Group words into vertical columns by left-edge X proximity --- - # Sort by x, then cluster words whose left-edges are within x_tol - avg_w = sum(ww for _, _, ww in words) / len(words) - x_tol = max(avg_w * 0.4, 8) # tolerance for "same column" - - words_by_x = sorted(words, key=lambda w: w[0]) - columns: List[List[Tuple[float, float]]] = [] # each: [(left_x, cy), ...] - cur_col: List[Tuple[float, float]] = [(words_by_x[0][0], words_by_x[0][1])] - cur_x = words_by_x[0][0] - - for lx, cy, _ in words_by_x[1:]: - if abs(lx - cur_x) <= x_tol: - cur_col.append((lx, cy)) - # Update running x as median of cluster - cur_x = cur_x * 0.8 + lx * 0.2 - else: - if len(cur_col) >= 5: - columns.append(cur_col) - cur_col = [(lx, cy)] - cur_x = lx - if len(cur_col) >= 5: - columns.append(cur_col) - - if len(columns) < 2: - return result - - # --- For each column, measure X-drift as a function of Y --- - # Fit: left_x = a * cy + b → a = dx/dy = tan(shear_angle) - drifts = [] - for col in columns: - ys = np.array([p[1] for p in col]) - xs = np.array([p[0] for p in col]) - y_range = ys.max() - ys.min() - if y_range < h * scale * 0.3: - continue # column must span at least 30% of image height - # Linear regression: x = a*y + b - coeffs = np.polyfit(ys, xs, 1) - drifts.append(coeffs[0]) # dx/dy - - if len(drifts) < 2: - return result - - # Median dx/dy → shear angle - # dx/dy > 0 means left-edges move RIGHT as we go DOWN → columns lean right - median_drift = float(np.median(drifts)) - shear_degrees = math.degrees(math.atan(median_drift)) - - # Confidence from column count + drift consistency - drift_std = float(np.std(drifts)) - consistency = max(0.0, 1.0 - drift_std * 50) # tighter penalty for drift variance - count_factor = min(1.0, len(drifts) / 4.0) - confidence = count_factor * 0.5 + consistency * 0.5 - - result["shear_degrees"] = round(shear_degrees, 3) - result["confidence"] = round(max(0.0, min(1.0, confidence)), 2) - logger.info("text_lines(v2): %d columns, %d drifts, median=%.4f, " - "shear=%.3f°, conf=%.2f", - len(columns), len(drifts), median_drift, - shear_degrees, confidence) - return result - - -def _dewarp_quality_check(original: np.ndarray, corrected: np.ndarray) -> bool: - """Check whether the dewarp correction actually improved alignment. - - Compares horizontal projection variance before and after correction. - Higher variance means sharper text-line peaks, which indicates better - horizontal alignment. - - Returns True if the correction improved the image, False if it should - be discarded. - """ - def _h_proj_variance(img: np.ndarray) -> float: - gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) - _, binary = cv2.threshold(gray, 0, 255, - cv2.THRESH_BINARY_INV + cv2.THRESH_OTSU) - small = cv2.resize(binary, (binary.shape[1] // 2, binary.shape[0] // 2), - interpolation=cv2.INTER_AREA) - profile = np.sum(small, axis=1).astype(float) - return float(np.var(profile)) - - var_before = _h_proj_variance(original) - var_after = _h_proj_variance(corrected) - - # Correction must improve variance (even by a tiny margin) - return var_after > var_before - - -def _apply_shear(img: np.ndarray, shear_degrees: float) -> np.ndarray: - """Apply a vertical shear correction to an image. - - Shifts each row horizontally proportional to its distance from the - vertical center. This corrects the tilt of vertical features (columns) - without affecting horizontal alignment (text lines). - - Args: - img: BGR image. - shear_degrees: Shear angle in degrees. Positive = shift top-right/bottom-left. - - Returns: - Corrected image. - """ - import math - h, w = img.shape[:2] - shear_tan = math.tan(math.radians(shear_degrees)) - - # Affine matrix: shift x by shear_tan * (y - h/2) - # [1 shear_tan -h/2*shear_tan] - # [0 1 0 ] - M = np.float32([ - [1, shear_tan, -h / 2.0 * shear_tan], - [0, 1, 0], - ]) - - corrected = cv2.warpAffine(img, M, (w, h), - flags=cv2.INTER_LINEAR, - borderMode=cv2.BORDER_REPLICATE) - return corrected - - -def _ensemble_shear(detections: List[Dict[str, Any]]) -> Tuple[float, float, str]: - """Combine multiple shear detections into a single weighted estimate (v2). - - Ensemble v2 changes vs v1: - - Minimum confidence raised to 0.5 (was 0.3) - - text_lines method gets 1.5× weight boost (most reliable detector) - - Outlier filter at 1° from weighted mean - - Returns: - (shear_degrees, ensemble_confidence, methods_used_str) - """ - # Confidence threshold — lowered from 0.5 to 0.35 to catch subtle shear - # that individual methods detect with moderate confidence. - _MIN_CONF = 0.35 - - # text_lines gets a weight boost as the most content-aware method - _METHOD_WEIGHT_BOOST = {"text_lines": 1.5} - - accepted = [] - for d in detections: - if d["confidence"] < _MIN_CONF: - continue - boost = _METHOD_WEIGHT_BOOST.get(d["method"], 1.0) - effective_conf = d["confidence"] * boost - accepted.append((d["shear_degrees"], effective_conf, d["method"])) - - if not accepted: - return 0.0, 0.0, "none" - - if len(accepted) == 1: - deg, conf, method = accepted[0] - return deg, min(conf, 1.0), method - - # First pass: weighted mean - total_w = sum(c for _, c, _ in accepted) - w_mean = sum(d * c for d, c, _ in accepted) / total_w - - # Outlier filter: keep results within 1° of weighted mean - filtered = [(d, c, m) for d, c, m in accepted if abs(d - w_mean) <= 1.0] - if not filtered: - filtered = accepted # fallback: keep all - - # Second pass: weighted mean on filtered results - total_w2 = sum(c for _, c, _ in filtered) - final_deg = sum(d * c for d, c, _ in filtered) / total_w2 - - # Ensemble confidence: average of individual confidences, boosted when - # methods agree (all within 0.5° of each other) - avg_conf = total_w2 / len(filtered) - spread = max(d for d, _, _ in filtered) - min(d for d, _, _ in filtered) - agreement_bonus = 0.15 if spread < 0.5 else 0.0 - ensemble_conf = min(1.0, avg_conf + agreement_bonus) - - methods_str = "+".join(m for _, _, m in filtered) - return round(final_deg, 3), round(min(ensemble_conf, 1.0), 2), methods_str - - -def dewarp_image(img: np.ndarray, use_ensemble: bool = True) -> Tuple[np.ndarray, Dict[str, Any]]: - """Correct vertical shear after deskew (v2 with quality gate). - - After deskew aligns horizontal text lines, vertical features (column - edges) may still be tilted. This detects the tilt angle using an ensemble - of four complementary methods and applies an affine shear correction. - - Methods (all run in ~150ms total): - A. _detect_shear_angle() — vertical edge profile (~50ms) - B. _detect_shear_by_projection() — horizontal text-line variance (~30ms) - C. _detect_shear_by_hough() — Hough lines on table borders (~20ms) - D. _detect_shear_by_text_lines() — text-line straightness (~50ms) - - Quality gate: after correction, horizontal projection variance is compared - before vs after. If correction worsened alignment, it is discarded. - - Args: - img: BGR image (already deskewed). - use_ensemble: If False, fall back to single-method behaviour (method A only). - - Returns: - Tuple of (corrected_image, dewarp_info). - dewarp_info keys: method, shear_degrees, confidence, detections. - """ - no_correction = { - "method": "none", - "shear_degrees": 0.0, - "confidence": 0.0, - "detections": [], - } - - if not CV2_AVAILABLE: - return img, no_correction - - t0 = time.time() - - if use_ensemble: - det_a = _detect_shear_angle(img) - det_b = _detect_shear_by_projection(img) - det_c = _detect_shear_by_hough(img) - det_d = _detect_shear_by_text_lines(img) - detections = [det_a, det_b, det_c, det_d] - shear_deg, confidence, method = _ensemble_shear(detections) - else: - det_a = _detect_shear_angle(img) - detections = [det_a] - shear_deg = det_a["shear_degrees"] - confidence = det_a["confidence"] - method = det_a["method"] - - duration = time.time() - t0 - - logger.info( - "dewarp: ensemble shear=%.3f° conf=%.2f method=%s (%.2fs) | " - "A=%.3f/%.2f B=%.3f/%.2f C=%.3f/%.2f D=%.3f/%.2f", - shear_deg, confidence, method, duration, - detections[0]["shear_degrees"], detections[0]["confidence"], - detections[1]["shear_degrees"] if len(detections) > 1 else 0.0, - detections[1]["confidence"] if len(detections) > 1 else 0.0, - detections[2]["shear_degrees"] if len(detections) > 2 else 0.0, - detections[2]["confidence"] if len(detections) > 2 else 0.0, - detections[3]["shear_degrees"] if len(detections) > 3 else 0.0, - detections[3]["confidence"] if len(detections) > 3 else 0.0, - ) - - # Always include individual detections (even when no correction applied) - _all_detections = [ - {"method": d["method"], "shear_degrees": d["shear_degrees"], - "confidence": d["confidence"]} - for d in detections - ] - - # Thresholds: very small shear (<0.08°) is truly irrelevant for OCR. - # For ensemble confidence, require at least 0.4 (lowered from 0.5 to - # catch moderate-confidence detections from multiple agreeing methods). - if abs(shear_deg) < 0.08 or confidence < 0.4: - no_correction["detections"] = _all_detections - return img, no_correction - - # Apply correction (negate the detected shear to straighten) - corrected = _apply_shear(img, -shear_deg) - - # Quality gate: verify the correction actually improved alignment. - # For small corrections (< 0.5°), the projection variance change can be - # negligible, so we skip the quality gate — the cost of a tiny wrong - # correction is much less than the cost of leaving 0.4° uncorrected - # (which shifts content ~25px at image edges on tall scans). - if abs(shear_deg) >= 0.5 and not _dewarp_quality_check(img, corrected): - logger.info("dewarp: quality gate REJECTED correction (%.3f°) — " - "projection variance did not improve", shear_deg) - no_correction["detections"] = _all_detections - return img, no_correction - - info = { - "method": method, - "shear_degrees": shear_deg, - "confidence": confidence, - "detections": _all_detections, - } - - return corrected, info - - -def dewarp_image_manual(img: np.ndarray, shear_degrees: float) -> np.ndarray: - """Apply shear correction with a manual angle. - - Args: - img: BGR image (deskewed, before dewarp). - shear_degrees: Shear angle in degrees to correct. - - Returns: - Corrected image. - """ - if abs(shear_degrees) < 0.001: - return img - return _apply_shear(img, -shear_degrees) - diff --git a/klausur-service/backend/cv_preprocessing_deskew.py b/klausur-service/backend/cv_preprocessing_deskew.py new file mode 100644 index 0000000..1bdb27e --- /dev/null +++ b/klausur-service/backend/cv_preprocessing_deskew.py @@ -0,0 +1,437 @@ +""" +CV Preprocessing Deskew — Rotation correction via Hough lines, word alignment, and iterative projection. + +Lizenz: Apache 2.0 (kommerziell nutzbar) +DATENSCHUTZ: Alle Verarbeitung erfolgt lokal. +""" + +import logging +from collections import defaultdict +from typing import Any, Dict, Tuple + +import numpy as np + +from cv_vocab_types import ( + CV2_AVAILABLE, + TESSERACT_AVAILABLE, +) + +logger = logging.getLogger(__name__) + +try: + import cv2 +except ImportError: + cv2 = None # type: ignore[assignment] + +try: + import pytesseract + from PIL import Image +except ImportError: + pytesseract = None # type: ignore[assignment] + Image = None # type: ignore[assignment,misc] + + +# ============================================================================= +# Deskew via Hough Lines +# ============================================================================= + +def deskew_image(img: np.ndarray) -> Tuple[np.ndarray, float]: + """Correct rotation using Hough Line detection. + + Args: + img: BGR image. + + Returns: + Tuple of (corrected image, detected angle in degrees). + """ + gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) + _, binary = cv2.threshold(gray, 0, 255, cv2.THRESH_BINARY_INV + cv2.THRESH_OTSU) + + lines = cv2.HoughLinesP(binary, 1, np.pi / 180, threshold=100, + minLineLength=img.shape[1] // 4, maxLineGap=20) + + if lines is None or len(lines) < 3: + return img, 0.0 + + angles = [] + for line in lines: + x1, y1, x2, y2 = line[0] + angle = np.degrees(np.arctan2(y2 - y1, x2 - x1)) + if abs(angle) < 15: + angles.append(angle) + + if not angles: + return img, 0.0 + + median_angle = float(np.median(angles)) + + if abs(median_angle) > 5.0: + median_angle = 5.0 * np.sign(median_angle) + + if abs(median_angle) < 0.1: + return img, 0.0 + + h, w = img.shape[:2] + center = (w // 2, h // 2) + M = cv2.getRotationMatrix2D(center, median_angle, 1.0) + corrected = cv2.warpAffine(img, M, (w, h), + flags=cv2.INTER_LINEAR, + borderMode=cv2.BORDER_REPLICATE) + + logger.info(f"Deskew: corrected {median_angle:.2f}\u00b0 rotation") + return corrected, median_angle + + +# ============================================================================= +# Deskew via Word Alignment +# ============================================================================= + +def deskew_image_by_word_alignment( + image_data: bytes, + lang: str = "eng+deu", + downscale_factor: float = 0.5, +) -> Tuple[bytes, float]: + """Correct rotation by fitting a line through left-most word starts per text line. + + More robust than Hough-based deskew for vocabulary worksheets where text lines + have consistent left-alignment. + + Args: + image_data: Raw image bytes (PNG/JPEG). + lang: Tesseract language string for the quick pass. + downscale_factor: Shrink factor for the quick Tesseract pass (0.5 = 50%). + + Returns: + Tuple of (rotated image as PNG bytes, detected angle in degrees). + """ + if not CV2_AVAILABLE or not TESSERACT_AVAILABLE: + return image_data, 0.0 + + img_array = np.frombuffer(image_data, dtype=np.uint8) + img = cv2.imdecode(img_array, cv2.IMREAD_COLOR) + if img is None: + logger.warning("deskew_by_word_alignment: could not decode image") + return image_data, 0.0 + + orig_h, orig_w = img.shape[:2] + + small_w = int(orig_w * downscale_factor) + small_h = int(orig_h * downscale_factor) + small = cv2.resize(img, (small_w, small_h), interpolation=cv2.INTER_AREA) + + pil_small = Image.fromarray(cv2.cvtColor(small, cv2.COLOR_BGR2RGB)) + try: + data = pytesseract.image_to_data( + pil_small, lang=lang, config="--psm 6 --oem 3", + output_type=pytesseract.Output.DICT, + ) + except Exception as e: + logger.warning(f"deskew_by_word_alignment: Tesseract failed: {e}") + return image_data, 0.0 + + line_groups: Dict[tuple, list] = defaultdict(list) + for i in range(len(data["text"])): + text = (data["text"][i] or "").strip() + conf = int(data["conf"][i]) + if not text or conf < 20: + continue + key = (data["block_num"][i], data["par_num"][i], data["line_num"][i]) + line_groups[key].append(i) + + if len(line_groups) < 5: + logger.info(f"deskew_by_word_alignment: only {len(line_groups)} lines, skipping") + return image_data, 0.0 + + scale = 1.0 / downscale_factor + points = [] + for key, indices in line_groups.items(): + best_idx = min(indices, key=lambda i: data["left"][i]) + lx = data["left"][best_idx] * scale + top = data["top"][best_idx] * scale + h = data["height"][best_idx] * scale + cy = top + h / 2.0 + points.append((lx, cy)) + + xs = np.array([p[0] for p in points]) + ys = np.array([p[1] for p in points]) + median_x = float(np.median(xs)) + tolerance = orig_w * 0.03 + + mask = np.abs(xs - median_x) <= tolerance + filtered_xs = xs[mask] + filtered_ys = ys[mask] + + if len(filtered_xs) < 5: + logger.info(f"deskew_by_word_alignment: only {len(filtered_xs)} aligned points after filter, skipping") + return image_data, 0.0 + + coeffs = np.polyfit(filtered_ys, filtered_xs, 1) + slope = coeffs[0] + angle_rad = np.arctan(slope) + angle_deg = float(np.degrees(angle_rad)) + + angle_deg = max(-5.0, min(5.0, angle_deg)) + + logger.info(f"deskew_by_word_alignment: detected {angle_deg:.2f}\u00b0 from {len(filtered_xs)} points " + f"(total lines: {len(line_groups)})") + + if abs(angle_deg) < 0.05: + return image_data, 0.0 + + center = (orig_w // 2, orig_h // 2) + M = cv2.getRotationMatrix2D(center, angle_deg, 1.0) + rotated = cv2.warpAffine(img, M, (orig_w, orig_h), + flags=cv2.INTER_LINEAR, + borderMode=cv2.BORDER_REPLICATE) + + success, png_buf = cv2.imencode(".png", rotated) + if not success: + logger.warning("deskew_by_word_alignment: PNG encoding failed") + return image_data, 0.0 + + return png_buf.tobytes(), angle_deg + + +# ============================================================================= +# Projection Gradient Scoring +# ============================================================================= + +def _projection_gradient_score(profile: np.ndarray) -> float: + """Score a projection profile by the L2-norm of its first derivative.""" + diff = np.diff(profile) + return float(np.sum(diff * diff)) + + +# ============================================================================= +# Iterative Deskew (Vertical-Edge Projection) +# ============================================================================= + +def deskew_image_iterative( + img: np.ndarray, + coarse_range: float = 5.0, + coarse_step: float = 0.1, + fine_range: float = 0.15, + fine_step: float = 0.02, +) -> Tuple[np.ndarray, float, Dict[str, Any]]: + """Iterative deskew using vertical-edge projection optimisation. + + Args: + img: BGR image (full resolution). + coarse_range: half-range in degrees for the coarse sweep. + coarse_step: step size in degrees for the coarse sweep. + fine_range: half-range around the coarse winner for the fine sweep. + fine_step: step size in degrees for the fine sweep. + + Returns: + (rotated_bgr, angle_degrees, debug_dict) + """ + h, w = img.shape[:2] + debug: Dict[str, Any] = {} + + gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) + + y_lo, y_hi = int(h * 0.15), int(h * 0.85) + x_lo, x_hi = int(w * 0.10), int(w * 0.90) + gray_crop = gray[y_lo:y_hi, x_lo:x_hi] + + sobel_x = cv2.Sobel(gray_crop, cv2.CV_64F, 1, 0, ksize=3) + edges = np.abs(sobel_x) + edge_max = edges.max() + if edge_max > 0: + edges = (edges / edge_max * 255).astype(np.uint8) + else: + return img, 0.0, {"error": "no edges detected"} + + crop_h, crop_w = edges.shape[:2] + crop_center = (crop_w // 2, crop_h // 2) + + trim_y = max(4, int(crop_h * 0.03)) + trim_x = max(4, int(crop_w * 0.03)) + + def _sweep_edges(angles: np.ndarray) -> list: + results = [] + for angle in angles: + if abs(angle) < 1e-6: + rotated = edges + else: + M = cv2.getRotationMatrix2D(crop_center, angle, 1.0) + rotated = cv2.warpAffine(edges, M, (crop_w, crop_h), + flags=cv2.INTER_NEAREST, + borderMode=cv2.BORDER_REPLICATE) + trimmed = rotated[trim_y:-trim_y, trim_x:-trim_x] + v_profile = np.sum(trimmed, axis=0, dtype=np.float64) + score = _projection_gradient_score(v_profile) + results.append((float(angle), score)) + return results + + coarse_angles = np.arange(-coarse_range, coarse_range + coarse_step * 0.5, coarse_step) + coarse_results = _sweep_edges(coarse_angles) + best_coarse = max(coarse_results, key=lambda x: x[1]) + best_coarse_angle, best_coarse_score = best_coarse + + debug["coarse_best_angle"] = round(best_coarse_angle, 2) + debug["coarse_best_score"] = round(best_coarse_score, 1) + debug["coarse_scores"] = [(round(a, 2), round(s, 1)) for a, s in coarse_results] + + fine_lo = best_coarse_angle - fine_range + fine_hi = best_coarse_angle + fine_range + fine_angles = np.arange(fine_lo, fine_hi + fine_step * 0.5, fine_step) + fine_results = _sweep_edges(fine_angles) + best_fine = max(fine_results, key=lambda x: x[1]) + best_fine_angle, best_fine_score = best_fine + + debug["fine_best_angle"] = round(best_fine_angle, 2) + debug["fine_best_score"] = round(best_fine_score, 1) + debug["fine_scores"] = [(round(a, 2), round(s, 1)) for a, s in fine_results] + + final_angle = best_fine_angle + final_angle = max(-5.0, min(5.0, final_angle)) + + logger.info(f"deskew_iterative: coarse={best_coarse_angle:.2f}\u00b0 fine={best_fine_angle:.2f}\u00b0 -> {final_angle:.2f}\u00b0") + + if abs(final_angle) < 0.05: + return img, 0.0, debug + + center = (w // 2, h // 2) + M = cv2.getRotationMatrix2D(center, final_angle, 1.0) + rotated = cv2.warpAffine(img, M, (w, h), + flags=cv2.INTER_LINEAR, + borderMode=cv2.BORDER_REPLICATE) + + return rotated, final_angle, debug + + +# ============================================================================= +# Text-Line Slope Measurement +# ============================================================================= + +def _measure_textline_slope(img: np.ndarray) -> float: + """Measure residual text-line slope via Tesseract word-position regression.""" + import math as _math + + if not TESSERACT_AVAILABLE or not CV2_AVAILABLE: + return 0.0 + + h, w = img.shape[:2] + gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) + data = pytesseract.image_to_data( + Image.fromarray(gray), + output_type=pytesseract.Output.DICT, + config="--psm 6", + ) + + lines: Dict[tuple, list] = {} + for i in range(len(data["text"])): + txt = (data["text"][i] or "").strip() + if len(txt) < 2 or int(data["conf"][i]) < 30: + continue + key = (data["block_num"][i], data["par_num"][i], data["line_num"][i]) + cx = data["left"][i] + data["width"][i] / 2.0 + cy = data["top"][i] + data["height"][i] / 2.0 + lines.setdefault(key, []).append((cx, cy)) + + slopes: list = [] + for pts in lines.values(): + if len(pts) < 3: + continue + pts.sort(key=lambda p: p[0]) + xs = np.array([p[0] for p in pts], dtype=np.float64) + ys = np.array([p[1] for p in pts], dtype=np.float64) + if xs[-1] - xs[0] < w * 0.15: + continue + A = np.vstack([xs, np.ones_like(xs)]).T + result = np.linalg.lstsq(A, ys, rcond=None) + slope = result[0][0] + slopes.append(_math.degrees(_math.atan(slope))) + + if len(slopes) < 3: + return 0.0 + + slopes.sort() + trim = max(1, len(slopes) // 10) + trimmed = slopes[trim:-trim] if len(slopes) > 2 * trim else slopes + if not trimmed: + return 0.0 + + return sum(trimmed) / len(trimmed) + + +# ============================================================================= +# Two-Pass Deskew +# ============================================================================= + +def deskew_two_pass( + img: np.ndarray, + coarse_range: float = 5.0, +) -> Tuple[np.ndarray, float, Dict[str, Any]]: + """Two-pass deskew: iterative projection + word-alignment residual check. + + Returns: + (corrected_bgr, total_angle_degrees, debug_dict) + """ + debug: Dict[str, Any] = {} + + # --- Pass 1: iterative projection --- + corrected, angle1, dbg1 = deskew_image_iterative( + img.copy(), coarse_range=coarse_range, + ) + debug["pass1_angle"] = round(angle1, 3) + debug["pass1_method"] = "iterative" + debug["pass1_debug"] = dbg1 + + # --- Pass 2: word-alignment residual check --- + angle2 = 0.0 + try: + ok, buf = cv2.imencode(".png", corrected) + if ok: + corrected_bytes, angle2 = deskew_image_by_word_alignment(buf.tobytes()) + if abs(angle2) >= 0.3: + arr2 = np.frombuffer(corrected_bytes, dtype=np.uint8) + corrected2 = cv2.imdecode(arr2, cv2.IMREAD_COLOR) + if corrected2 is not None: + corrected = corrected2 + logger.info(f"deskew_two_pass: pass2 residual={angle2:.2f}\u00b0 applied " + f"(total={angle1 + angle2:.2f}\u00b0)") + else: + angle2 = 0.0 + else: + logger.info(f"deskew_two_pass: pass2 residual={angle2:.2f}\u00b0 < 0.3\u00b0 -- skipped") + angle2 = 0.0 + except Exception as e: + logger.warning(f"deskew_two_pass: pass2 word-alignment failed: {e}") + angle2 = 0.0 + + # --- Pass 3: Tesseract text-line regression residual check --- + angle3 = 0.0 + try: + residual = _measure_textline_slope(corrected) + debug["pass3_raw"] = round(residual, 3) + if abs(residual) >= 0.3: + h3, w3 = corrected.shape[:2] + center3 = (w3 // 2, h3 // 2) + M3 = cv2.getRotationMatrix2D(center3, residual, 1.0) + corrected = cv2.warpAffine( + corrected, M3, (w3, h3), + flags=cv2.INTER_LINEAR, + borderMode=cv2.BORDER_REPLICATE, + ) + angle3 = residual + logger.info("deskew_two_pass: pass3 text-line residual=%.2f\u00b0 applied", residual) + else: + logger.info("deskew_two_pass: pass3 text-line residual=%.2f\u00b0 < 0.3\u00b0 -- skipped", residual) + except Exception as e: + logger.warning("deskew_two_pass: pass3 text-line check failed: %s", e) + + total_angle = angle1 + angle2 + angle3 + debug["pass2_angle"] = round(angle2, 3) + debug["pass2_method"] = "word_alignment" + debug["pass3_angle"] = round(angle3, 3) + debug["pass3_method"] = "textline_regression" + debug["total_angle"] = round(total_angle, 3) + + logger.info( + "deskew_two_pass: pass1=%.2f\u00b0 + pass2=%.2f\u00b0 + pass3=%.2f\u00b0 = %.2f\u00b0", + angle1, angle2, angle3, total_angle, + ) + + return corrected, total_angle, debug diff --git a/klausur-service/backend/cv_preprocessing_dewarp.py b/klausur-service/backend/cv_preprocessing_dewarp.py new file mode 100644 index 0000000..640c87c --- /dev/null +++ b/klausur-service/backend/cv_preprocessing_dewarp.py @@ -0,0 +1,474 @@ +""" +CV Preprocessing Dewarp — Vertical shear detection and correction. + +Provides four shear detection methods (vertical edge, projection variance, +Hough lines, text-line drift), ensemble combination, quality gating, +and the main dewarp_image() function. + +Lizenz: Apache 2.0 (kommerziell nutzbar) +DATENSCHUTZ: Alle Verarbeitung erfolgt lokal. +""" + +import logging +import math +import time +from typing import Any, Dict, List, Tuple + +import numpy as np + +from cv_vocab_types import ( + CV2_AVAILABLE, + TESSERACT_AVAILABLE, +) + +logger = logging.getLogger(__name__) + +try: + import cv2 +except ImportError: + cv2 = None # type: ignore[assignment] + +try: + import pytesseract + from PIL import Image +except ImportError: + pytesseract = None # type: ignore[assignment] + Image = None # type: ignore[assignment,misc] + + +# ============================================================================= +# Shear Detection Methods +# ============================================================================= + +def _detect_shear_angle(img: np.ndarray) -> Dict[str, Any]: + """Detect vertical shear angle via strongest vertical edge tracking (Method A).""" + h, w = img.shape[:2] + result = {"method": "vertical_edge", "shear_degrees": 0.0, "confidence": 0.0} + + gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) + + sobel_x = cv2.Sobel(gray, cv2.CV_64F, 1, 0, ksize=3) + abs_sobel = np.abs(sobel_x).astype(np.uint8) + + _, binary = cv2.threshold(abs_sobel, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU) + + num_strips = 20 + strip_h = h // num_strips + edge_positions = [] + + for i in range(num_strips): + y_start = i * strip_h + y_end = min((i + 1) * strip_h, h) + strip = binary[y_start:y_end, :] + + projection = np.sum(strip, axis=0).astype(np.float64) + if projection.max() == 0: + continue + + search_w = int(w * 0.4) + left_proj = projection[:search_w] + if left_proj.max() == 0: + continue + + kernel_size = max(3, w // 100) + if kernel_size % 2 == 0: + kernel_size += 1 + smoothed = cv2.GaussianBlur(left_proj.reshape(1, -1), (kernel_size, 1), 0).flatten() + x_pos = float(np.argmax(smoothed)) + y_center = (y_start + y_end) / 2.0 + edge_positions.append((y_center, x_pos)) + + if len(edge_positions) < 8: + return result + + ys = np.array([p[0] for p in edge_positions]) + xs = np.array([p[1] for p in edge_positions]) + + median_x = np.median(xs) + std_x = max(np.std(xs), 1.0) + mask = np.abs(xs - median_x) < 2 * std_x + ys = ys[mask] + xs = xs[mask] + + if len(ys) < 6: + return result + + straight_coeffs = np.polyfit(ys, xs, 1) + slope = straight_coeffs[0] + fitted = np.polyval(straight_coeffs, ys) + residuals = xs - fitted + rmse = float(np.sqrt(np.mean(residuals ** 2))) + + shear_degrees = math.degrees(math.atan(slope)) + + confidence = min(1.0, len(ys) / 15.0) * max(0.5, 1.0 - rmse / 5.0) + + result["shear_degrees"] = round(shear_degrees, 3) + result["confidence"] = round(float(confidence), 2) + + return result + + +def _detect_shear_by_projection(img: np.ndarray) -> Dict[str, Any]: + """Detect shear angle by maximising variance of horizontal text-line projections (Method B).""" + result = {"method": "projection", "shear_degrees": 0.0, "confidence": 0.0} + + h, w = img.shape[:2] + gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) + + _, binary = cv2.threshold(gray, 0, 255, cv2.THRESH_BINARY_INV + cv2.THRESH_OTSU) + + small = cv2.resize(binary, (w // 2, h // 2), interpolation=cv2.INTER_AREA) + sh, sw = small.shape + + def _sweep_variance(angles_list): + results = [] + for angle_deg in angles_list: + if abs(angle_deg) < 0.001: + rotated = small + else: + shear_tan = math.tan(math.radians(angle_deg)) + M = np.float32([[1, shear_tan, -sh / 2.0 * shear_tan], [0, 1, 0]]) + rotated = cv2.warpAffine(small, M, (sw, sh), + flags=cv2.INTER_NEAREST, + borderMode=cv2.BORDER_CONSTANT) + profile = np.sum(rotated, axis=1).astype(float) + results.append((angle_deg, float(np.var(profile)))) + return results + + coarse_angles = [a * 0.5 for a in range(-6, 7)] + coarse_results = _sweep_variance(coarse_angles) + coarse_best = max(coarse_results, key=lambda x: x[1]) + + fine_center = coarse_best[0] + fine_angles = [fine_center + a * 0.05 for a in range(-10, 11)] + fine_results = _sweep_variance(fine_angles) + fine_best = max(fine_results, key=lambda x: x[1]) + + best_angle = fine_best[0] + best_variance = fine_best[1] + variances = coarse_results + fine_results + + all_mean = sum(v for _, v in variances) / len(variances) + if all_mean > 0 and best_variance > all_mean: + confidence = min(1.0, (best_variance - all_mean) / (all_mean + 1.0) * 0.6) + else: + confidence = 0.0 + + result["shear_degrees"] = round(best_angle, 3) + result["confidence"] = round(max(0.0, min(1.0, confidence)), 2) + return result + + +def _detect_shear_by_hough(img: np.ndarray) -> Dict[str, Any]: + """Detect shear using Hough transform on printed table / ruled lines (Method C).""" + result = {"method": "hough_lines", "shear_degrees": 0.0, "confidence": 0.0} + + h, w = img.shape[:2] + gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) + + edges = cv2.Canny(gray, 50, 150, apertureSize=3) + + min_len = int(w * 0.15) + lines = cv2.HoughLinesP( + edges, rho=1, theta=np.pi / 360, + threshold=int(w * 0.08), + minLineLength=min_len, + maxLineGap=20, + ) + + if lines is None or len(lines) < 3: + return result + + horizontal_angles: List[Tuple[float, float]] = [] + for line in lines: + x1, y1, x2, y2 = line[0] + if x1 == x2: + continue + angle = float(np.degrees(np.arctan2(y2 - y1, x2 - x1))) + if abs(angle) <= 5.0: + length = float(np.sqrt((x2 - x1) ** 2 + (y2 - y1) ** 2)) + horizontal_angles.append((angle, length)) + + if len(horizontal_angles) < 3: + return result + + angles_arr = np.array([a for a, _ in horizontal_angles]) + weights_arr = np.array([l for _, l in horizontal_angles]) + sorted_idx = np.argsort(angles_arr) + s_angles = angles_arr[sorted_idx] + s_weights = weights_arr[sorted_idx] + cum = np.cumsum(s_weights) + mid_idx = int(np.searchsorted(cum, cum[-1] / 2.0)) + median_angle = float(s_angles[min(mid_idx, len(s_angles) - 1)]) + + agree = sum(1 for a, _ in horizontal_angles if abs(a - median_angle) < 1.0) + confidence = min(1.0, agree / max(len(horizontal_angles), 1)) * 0.85 + + shear_degrees = -median_angle + + result["shear_degrees"] = round(shear_degrees, 3) + result["confidence"] = round(max(0.0, min(1.0, confidence)), 2) + return result + + +def _detect_shear_by_text_lines(img: np.ndarray) -> Dict[str, Any]: + """Detect shear by measuring text-line straightness (Method D).""" + result = {"method": "text_lines", "shear_degrees": 0.0, "confidence": 0.0} + + h, w = img.shape[:2] + scale = 0.5 + small = cv2.resize(img, (int(w * scale), int(h * scale)), + interpolation=cv2.INTER_AREA) + gray = cv2.cvtColor(small, cv2.COLOR_BGR2GRAY) + pil_img = Image.fromarray(gray) + + try: + data = pytesseract.image_to_data( + pil_img, lang='eng+deu', config='--psm 11 --oem 3', + output_type=pytesseract.Output.DICT, + ) + except Exception: + return result + + words = [] + for i in range(len(data['text'])): + text = data['text'][i].strip() + conf = int(data['conf'][i]) + if not text or conf < 20 or len(text) < 2: + continue + left_x = float(data['left'][i]) + cy = data['top'][i] + data['height'][i] / 2.0 + word_w = float(data['width'][i]) + words.append((left_x, cy, word_w)) + + if len(words) < 15: + return result + + avg_w = sum(ww for _, _, ww in words) / len(words) + x_tol = max(avg_w * 0.4, 8) + + words_by_x = sorted(words, key=lambda w: w[0]) + columns: List[List[Tuple[float, float]]] = [] + cur_col: List[Tuple[float, float]] = [(words_by_x[0][0], words_by_x[0][1])] + cur_x = words_by_x[0][0] + + for lx, cy, _ in words_by_x[1:]: + if abs(lx - cur_x) <= x_tol: + cur_col.append((lx, cy)) + cur_x = cur_x * 0.8 + lx * 0.2 + else: + if len(cur_col) >= 5: + columns.append(cur_col) + cur_col = [(lx, cy)] + cur_x = lx + if len(cur_col) >= 5: + columns.append(cur_col) + + if len(columns) < 2: + return result + + drifts = [] + for col in columns: + ys = np.array([p[1] for p in col]) + xs = np.array([p[0] for p in col]) + y_range = ys.max() - ys.min() + if y_range < h * scale * 0.3: + continue + coeffs = np.polyfit(ys, xs, 1) + drifts.append(coeffs[0]) + + if len(drifts) < 2: + return result + + median_drift = float(np.median(drifts)) + shear_degrees = math.degrees(math.atan(median_drift)) + + drift_std = float(np.std(drifts)) + consistency = max(0.0, 1.0 - drift_std * 50) + count_factor = min(1.0, len(drifts) / 4.0) + confidence = count_factor * 0.5 + consistency * 0.5 + + result["shear_degrees"] = round(shear_degrees, 3) + result["confidence"] = round(max(0.0, min(1.0, confidence)), 2) + logger.info("text_lines(v2): %d columns, %d drifts, median=%.4f, " + "shear=%.3f\u00b0, conf=%.2f", + len(columns), len(drifts), median_drift, + shear_degrees, confidence) + return result + + +# ============================================================================= +# Quality Check and Shear Application +# ============================================================================= + +def _dewarp_quality_check(original: np.ndarray, corrected: np.ndarray) -> bool: + """Check whether the dewarp correction actually improved alignment.""" + def _h_proj_variance(img: np.ndarray) -> float: + gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) + _, binary = cv2.threshold(gray, 0, 255, + cv2.THRESH_BINARY_INV + cv2.THRESH_OTSU) + small = cv2.resize(binary, (binary.shape[1] // 2, binary.shape[0] // 2), + interpolation=cv2.INTER_AREA) + profile = np.sum(small, axis=1).astype(float) + return float(np.var(profile)) + + var_before = _h_proj_variance(original) + var_after = _h_proj_variance(corrected) + + return var_after > var_before + + +def _apply_shear(img: np.ndarray, shear_degrees: float) -> np.ndarray: + """Apply a vertical shear correction to an image.""" + h, w = img.shape[:2] + shear_tan = math.tan(math.radians(shear_degrees)) + + M = np.float32([ + [1, shear_tan, -h / 2.0 * shear_tan], + [0, 1, 0], + ]) + + corrected = cv2.warpAffine(img, M, (w, h), + flags=cv2.INTER_LINEAR, + borderMode=cv2.BORDER_REPLICATE) + return corrected + + +# ============================================================================= +# Ensemble Shear Combination +# ============================================================================= + +def _ensemble_shear(detections: List[Dict[str, Any]]) -> Tuple[float, float, str]: + """Combine multiple shear detections into a single weighted estimate (v2).""" + _MIN_CONF = 0.35 + _METHOD_WEIGHT_BOOST = {"text_lines": 1.5} + + accepted = [] + for d in detections: + if d["confidence"] < _MIN_CONF: + continue + boost = _METHOD_WEIGHT_BOOST.get(d["method"], 1.0) + effective_conf = d["confidence"] * boost + accepted.append((d["shear_degrees"], effective_conf, d["method"])) + + if not accepted: + return 0.0, 0.0, "none" + + if len(accepted) == 1: + deg, conf, method = accepted[0] + return deg, min(conf, 1.0), method + + total_w = sum(c for _, c, _ in accepted) + w_mean = sum(d * c for d, c, _ in accepted) / total_w + + filtered = [(d, c, m) for d, c, m in accepted if abs(d - w_mean) <= 1.0] + if not filtered: + filtered = accepted + + total_w2 = sum(c for _, c, _ in filtered) + final_deg = sum(d * c for d, c, _ in filtered) / total_w2 + + avg_conf = total_w2 / len(filtered) + spread = max(d for d, _, _ in filtered) - min(d for d, _, _ in filtered) + agreement_bonus = 0.15 if spread < 0.5 else 0.0 + ensemble_conf = min(1.0, avg_conf + agreement_bonus) + + methods_str = "+".join(m for _, _, m in filtered) + return round(final_deg, 3), round(min(ensemble_conf, 1.0), 2), methods_str + + +# ============================================================================= +# Main Dewarp Function +# ============================================================================= + +def dewarp_image(img: np.ndarray, use_ensemble: bool = True) -> Tuple[np.ndarray, Dict[str, Any]]: + """Correct vertical shear after deskew (v2 with quality gate). + + Methods (all run in ~150ms total): + A. _detect_shear_angle() -- vertical edge profile (~50ms) + B. _detect_shear_by_projection() -- horizontal text-line variance (~30ms) + C. _detect_shear_by_hough() -- Hough lines on table borders (~20ms) + D. _detect_shear_by_text_lines() -- text-line straightness (~50ms) + + Args: + img: BGR image (already deskewed). + use_ensemble: If False, fall back to single-method behaviour (method A only). + + Returns: + Tuple of (corrected_image, dewarp_info). + """ + no_correction = { + "method": "none", + "shear_degrees": 0.0, + "confidence": 0.0, + "detections": [], + } + + if not CV2_AVAILABLE: + return img, no_correction + + t0 = time.time() + + if use_ensemble: + det_a = _detect_shear_angle(img) + det_b = _detect_shear_by_projection(img) + det_c = _detect_shear_by_hough(img) + det_d = _detect_shear_by_text_lines(img) + detections = [det_a, det_b, det_c, det_d] + shear_deg, confidence, method = _ensemble_shear(detections) + else: + det_a = _detect_shear_angle(img) + detections = [det_a] + shear_deg = det_a["shear_degrees"] + confidence = det_a["confidence"] + method = det_a["method"] + + duration = time.time() - t0 + + logger.info( + "dewarp: ensemble shear=%.3f\u00b0 conf=%.2f method=%s (%.2fs) | " + "A=%.3f/%.2f B=%.3f/%.2f C=%.3f/%.2f D=%.3f/%.2f", + shear_deg, confidence, method, duration, + detections[0]["shear_degrees"], detections[0]["confidence"], + detections[1]["shear_degrees"] if len(detections) > 1 else 0.0, + detections[1]["confidence"] if len(detections) > 1 else 0.0, + detections[2]["shear_degrees"] if len(detections) > 2 else 0.0, + detections[2]["confidence"] if len(detections) > 2 else 0.0, + detections[3]["shear_degrees"] if len(detections) > 3 else 0.0, + detections[3]["confidence"] if len(detections) > 3 else 0.0, + ) + + _all_detections = [ + {"method": d["method"], "shear_degrees": d["shear_degrees"], + "confidence": d["confidence"]} + for d in detections + ] + + if abs(shear_deg) < 0.08 or confidence < 0.4: + no_correction["detections"] = _all_detections + return img, no_correction + + corrected = _apply_shear(img, -shear_deg) + + if abs(shear_deg) >= 0.5 and not _dewarp_quality_check(img, corrected): + logger.info("dewarp: quality gate REJECTED correction (%.3f\u00b0) -- " + "projection variance did not improve", shear_deg) + no_correction["detections"] = _all_detections + return img, no_correction + + info = { + "method": method, + "shear_degrees": shear_deg, + "confidence": confidence, + "detections": _all_detections, + } + + return corrected, info + + +def dewarp_image_manual(img: np.ndarray, shear_degrees: float) -> np.ndarray: + """Apply shear correction with a manual angle.""" + if abs(shear_degrees) < 0.001: + return img + return _apply_shear(img, -shear_degrees) diff --git a/klausur-service/backend/cv_review.py b/klausur-service/backend/cv_review.py index 5da85c2..217e463 100644 --- a/klausur-service/backend/cv_review.py +++ b/klausur-service/backend/cv_review.py @@ -1,1248 +1,46 @@ """ Multi-pass OCR, line matching, LLM/spell review, and pipeline orchestration. +Re-export facade -- all logic lives in the sub-modules: + + cv_review_pipeline Stages 6-8: OCR, line alignment, orchestrator + cv_review_spell Rule-based spell-checker OCR correction + cv_review_llm LLM-based OCR correction, prompt building, streaming + Lizenz: Apache 2.0 (kommerziell nutzbar) DATENSCHUTZ: Alle Verarbeitung erfolgt lokal. """ -import json -import logging -import os -import re -import time -from typing import Any, Dict, List, Optional, Tuple - -import numpy as np - -from cv_vocab_types import ( - CV_PIPELINE_AVAILABLE, - PageRegion, - PipelineResult, - VocabRow, -) -from cv_preprocessing import ( - deskew_image, - dewarp_image, - render_image_high_res, - render_pdf_high_res, -) -from cv_layout import ( - analyze_layout, - create_layout_image, - create_ocr_image, -) -from cv_ocr_engines import ( - _fix_character_confusion, - _group_words_into_lines, +# Re-export everything for backward compatibility +from cv_review_pipeline import ( # noqa: F401 + ocr_region, + run_multi_pass_ocr, + match_lines_to_vocab, + llm_post_correct, + run_cv_pipeline, ) -logger = logging.getLogger(__name__) - -try: - import cv2 -except ImportError: - cv2 = None # type: ignore[assignment] - -try: - import pytesseract - from PIL import Image -except ImportError: - pytesseract = None # type: ignore[assignment] - Image = None # type: ignore[assignment,misc] - - -# ============================================================================= -# Stage 6: Multi-Pass OCR -# ============================================================================= - -def ocr_region(ocr_img: np.ndarray, region: PageRegion, lang: str, - psm: int, fallback_psm: Optional[int] = None, - min_confidence: float = 40.0) -> List[Dict[str, Any]]: - """Run Tesseract OCR on a specific region with given PSM. - - Args: - ocr_img: Binarized full-page image. - region: Region to crop and OCR. - lang: Tesseract language string. - psm: Page Segmentation Mode. - fallback_psm: If confidence too low, retry with this PSM per line. - min_confidence: Minimum average confidence before fallback. - - Returns: - List of word dicts with text, position, confidence. - """ - # Crop region - crop = ocr_img[region.y:region.y + region.height, - region.x:region.x + region.width] - - if crop.size == 0: - return [] - - # Convert to PIL for pytesseract - pil_img = Image.fromarray(crop) - - # Run Tesseract with specified PSM - config = f'--psm {psm} --oem 3' - try: - data = pytesseract.image_to_data(pil_img, lang=lang, config=config, - output_type=pytesseract.Output.DICT) - except Exception as e: - logger.warning(f"Tesseract failed for region {region.type}: {e}") - return [] - - words = [] - for i in range(len(data['text'])): - text = data['text'][i].strip() - conf = int(data['conf'][i]) - if not text or conf < 10: - continue - words.append({ - 'text': text, - 'left': data['left'][i] + region.x, # Absolute coords - 'top': data['top'][i] + region.y, - 'width': data['width'][i], - 'height': data['height'][i], - 'conf': conf, - 'region_type': region.type, - }) - - # Check average confidence - if words and fallback_psm is not None: - avg_conf = sum(w['conf'] for w in words) / len(words) - if avg_conf < min_confidence: - logger.info(f"Region {region.type}: avg confidence {avg_conf:.0f}% < {min_confidence}%, " - f"trying fallback PSM {fallback_psm}") - words = _ocr_region_line_by_line(ocr_img, region, lang, fallback_psm) - - return words - - -def _ocr_region_line_by_line(ocr_img: np.ndarray, region: PageRegion, - lang: str, psm: int) -> List[Dict[str, Any]]: - """OCR a region line by line (fallback for low-confidence regions). - - Splits the region into horizontal strips based on text density, - then OCRs each strip individually with the given PSM. - """ - crop = ocr_img[region.y:region.y + region.height, - region.x:region.x + region.width] - - if crop.size == 0: - return [] - - # Find text lines via horizontal projection - inv = cv2.bitwise_not(crop) - h_proj = np.sum(inv, axis=1) - threshold = np.max(h_proj) * 0.05 if np.max(h_proj) > 0 else 0 - - # Find line boundaries - lines = [] - in_text = False - line_start = 0 - for y in range(len(h_proj)): - if h_proj[y] > threshold and not in_text: - line_start = y - in_text = True - elif h_proj[y] <= threshold and in_text: - if y - line_start > 5: # Minimum line height - lines.append((line_start, y)) - in_text = False - if in_text and len(h_proj) - line_start > 5: - lines.append((line_start, len(h_proj))) - - all_words = [] - config = f'--psm {psm} --oem 3' - - for line_y_start, line_y_end in lines: - # Add small padding - pad = 3 - y1 = max(0, line_y_start - pad) - y2 = min(crop.shape[0], line_y_end + pad) - line_crop = crop[y1:y2, :] - - if line_crop.size == 0: - continue - - pil_img = Image.fromarray(line_crop) - try: - data = pytesseract.image_to_data(pil_img, lang=lang, config=config, - output_type=pytesseract.Output.DICT) - except Exception: - continue - - for i in range(len(data['text'])): - text = data['text'][i].strip() - conf = int(data['conf'][i]) - if not text or conf < 10: - continue - all_words.append({ - 'text': text, - 'left': data['left'][i] + region.x, - 'top': data['top'][i] + region.y + y1, - 'width': data['width'][i], - 'height': data['height'][i], - 'conf': conf, - 'region_type': region.type, - }) - - return all_words - - -def run_multi_pass_ocr(ocr_img: np.ndarray, - regions: List[PageRegion], - lang: str = "eng+deu") -> Dict[str, List[Dict]]: - """Run OCR on each detected region with optimized settings. - - Args: - ocr_img: Binarized full-page image. - regions: Detected page regions. - lang: Default language. - - Returns: - Dict mapping region type to list of word dicts. - """ - results: Dict[str, List[Dict]] = {} - - _ocr_skip = {'header', 'footer', 'margin_top', 'margin_bottom', 'margin_left', 'margin_right'} - for region in regions: - if region.type in _ocr_skip: - continue # Skip non-content regions - - if region.type == 'column_en': - words = ocr_region(ocr_img, region, lang='eng', psm=4) - elif region.type == 'column_de': - words = ocr_region(ocr_img, region, lang='deu', psm=4) - elif region.type == 'column_example': - words = ocr_region(ocr_img, region, lang=lang, psm=6, - fallback_psm=7, min_confidence=40.0) - else: - words = ocr_region(ocr_img, region, lang=lang, psm=6) - - results[region.type] = words - logger.info(f"OCR {region.type}: {len(words)} words") - - return results - - -# ============================================================================= -# Stage 7: Line Alignment → Vocabulary Entries -# ============================================================================= - -def match_lines_to_vocab(ocr_results: Dict[str, List[Dict]], - regions: List[PageRegion], - y_tolerance_px: int = 25) -> List[VocabRow]: - """Align OCR results from different columns into vocabulary rows. - - Uses Y-coordinate matching to pair English words, German translations, - and example sentences that appear on the same line. - - Args: - ocr_results: Dict mapping region type to word lists. - regions: Detected regions (for reference). - y_tolerance_px: Max Y-distance to consider words on the same row. - - Returns: - List of VocabRow objects. - """ - # If no vocabulary columns detected (e.g. plain text page), return empty - if 'column_en' not in ocr_results and 'column_de' not in ocr_results: - logger.info("match_lines_to_vocab: no column_en/column_de in OCR results, returning empty") - return [] - - # Group words into lines per column - en_lines = _group_words_into_lines(ocr_results.get('column_en', []), y_tolerance_px) - de_lines = _group_words_into_lines(ocr_results.get('column_de', []), y_tolerance_px) - ex_lines = _group_words_into_lines(ocr_results.get('column_example', []), y_tolerance_px) - - def line_y_center(line: List[Dict]) -> float: - return sum(w['top'] + w['height'] / 2 for w in line) / len(line) - - def line_text(line: List[Dict]) -> str: - return ' '.join(w['text'] for w in line) - - def line_confidence(line: List[Dict]) -> float: - return sum(w['conf'] for w in line) / len(line) if line else 0 - - # Build EN entries as the primary reference - vocab_rows: List[VocabRow] = [] - - for en_line in en_lines: - en_y = line_y_center(en_line) - en_text = line_text(en_line) - en_conf = line_confidence(en_line) - - # Skip very short or likely header content - if len(en_text.strip()) < 2: - continue - - # Find matching DE line - de_text = "" - de_conf = 0.0 - best_de_dist = float('inf') - best_de_idx = -1 - for idx, de_line in enumerate(de_lines): - dist = abs(line_y_center(de_line) - en_y) - if dist < y_tolerance_px and dist < best_de_dist: - best_de_dist = dist - best_de_idx = idx - - if best_de_idx >= 0: - de_text = line_text(de_lines[best_de_idx]) - de_conf = line_confidence(de_lines[best_de_idx]) - - # Find matching example line - ex_text = "" - ex_conf = 0.0 - best_ex_dist = float('inf') - best_ex_idx = -1 - for idx, ex_line in enumerate(ex_lines): - dist = abs(line_y_center(ex_line) - en_y) - if dist < y_tolerance_px and dist < best_ex_dist: - best_ex_dist = dist - best_ex_idx = idx - - if best_ex_idx >= 0: - ex_text = line_text(ex_lines[best_ex_idx]) - ex_conf = line_confidence(ex_lines[best_ex_idx]) - - avg_conf = en_conf - conf_count = 1 - if de_conf > 0: - avg_conf += de_conf - conf_count += 1 - if ex_conf > 0: - avg_conf += ex_conf - conf_count += 1 - - vocab_rows.append(VocabRow( - english=en_text.strip(), - german=de_text.strip(), - example=ex_text.strip(), - confidence=avg_conf / conf_count, - y_position=int(en_y), - )) - - # Handle multi-line wrapping in example column: - # If an example line has no matching EN/DE, append to previous entry - matched_ex_ys = set() - for row in vocab_rows: - if row.example: - matched_ex_ys.add(row.y_position) - - for ex_line in ex_lines: - ex_y = line_y_center(ex_line) - # Check if already matched - already_matched = any(abs(ex_y - y) < y_tolerance_px for y in matched_ex_ys) - if already_matched: - continue - - # Find nearest previous vocab row - best_row = None - best_dist = float('inf') - for row in vocab_rows: - dist = ex_y - row.y_position - if 0 < dist < y_tolerance_px * 3 and dist < best_dist: - best_dist = dist - best_row = row - - if best_row: - continuation = line_text(ex_line).strip() - if continuation: - best_row.example = (best_row.example + " " + continuation).strip() - - # Sort by Y position - vocab_rows.sort(key=lambda r: r.y_position) - - return vocab_rows - - -# ============================================================================= -# Stage 8: Optional LLM Post-Correction -# ============================================================================= - -async def llm_post_correct(img: np.ndarray, vocab_rows: List[VocabRow], - confidence_threshold: float = 50.0, - enabled: bool = False) -> List[VocabRow]: - """Optionally send low-confidence regions to Qwen-VL for correction. - - Default: disabled. Enable per parameter. - - Args: - img: Original BGR image. - vocab_rows: Current vocabulary rows. - confidence_threshold: Rows below this get LLM correction. - enabled: Whether to actually run LLM correction. - - Returns: - Corrected vocabulary rows. - """ - if not enabled: - return vocab_rows - - # TODO: Implement Qwen-VL correction for low-confidence entries - # For each row with confidence < threshold: - # 1. Crop the relevant region from img - # 2. Send crop + OCR text to Qwen-VL - # 3. Replace text if LLM provides a confident correction - logger.info(f"LLM post-correction skipped (not yet implemented)") - return vocab_rows - - -# ============================================================================= -# Orchestrator -# ============================================================================= - -async def run_cv_pipeline( - pdf_data: Optional[bytes] = None, - image_data: Optional[bytes] = None, - page_number: int = 0, - zoom: float = 3.0, - enable_dewarp: bool = True, - enable_llm_correction: bool = False, - lang: str = "eng+deu", -) -> PipelineResult: - """Run the complete CV document reconstruction pipeline. - - Args: - pdf_data: Raw PDF bytes (mutually exclusive with image_data). - image_data: Raw image bytes (mutually exclusive with pdf_data). - page_number: 0-indexed page number (for PDF). - zoom: PDF rendering zoom factor. - enable_dewarp: Whether to run dewarp stage. - enable_llm_correction: Whether to run LLM post-correction. - lang: Tesseract language string. - - Returns: - PipelineResult with vocabulary and timing info. - """ - if not CV_PIPELINE_AVAILABLE: - return PipelineResult(error="CV pipeline not available (OpenCV or Tesseract missing)") - - result = PipelineResult() - total_start = time.time() - - try: - # Stage 1: Render - t = time.time() - if pdf_data: - img = render_pdf_high_res(pdf_data, page_number, zoom) - elif image_data: - img = render_image_high_res(image_data) - else: - return PipelineResult(error="No input data (pdf_data or image_data required)") - result.stages['render'] = round(time.time() - t, 2) - result.image_width = img.shape[1] - result.image_height = img.shape[0] - logger.info(f"Stage 1 (render): {img.shape[1]}x{img.shape[0]} in {result.stages['render']}s") - - # Stage 2: Deskew - t = time.time() - img, angle = deskew_image(img) - result.stages['deskew'] = round(time.time() - t, 2) - logger.info(f"Stage 2 (deskew): {angle:.2f}° in {result.stages['deskew']}s") - - # Stage 3: Dewarp - if enable_dewarp: - t = time.time() - img, _dewarp_info = dewarp_image(img) - result.stages['dewarp'] = round(time.time() - t, 2) - - # Stage 4: Dual image preparation - t = time.time() - ocr_img = create_ocr_image(img) - layout_img = create_layout_image(img) - result.stages['image_prep'] = round(time.time() - t, 2) - - # Stage 5: Layout analysis - t = time.time() - regions = analyze_layout(layout_img, ocr_img) - result.stages['layout'] = round(time.time() - t, 2) - result.columns_detected = len([r for r in regions if r.type.startswith('column')]) - logger.info(f"Stage 5 (layout): {result.columns_detected} columns in {result.stages['layout']}s") - - # Stage 6: Multi-pass OCR - t = time.time() - ocr_results = run_multi_pass_ocr(ocr_img, regions, lang) - result.stages['ocr'] = round(time.time() - t, 2) - total_words = sum(len(w) for w in ocr_results.values()) - result.word_count = total_words - logger.info(f"Stage 6 (OCR): {total_words} words in {result.stages['ocr']}s") - - # Stage 7: Line alignment - t = time.time() - vocab_rows = match_lines_to_vocab(ocr_results, regions) - result.stages['alignment'] = round(time.time() - t, 2) - - # Stage 8: Optional LLM correction - if enable_llm_correction: - t = time.time() - vocab_rows = await llm_post_correct(img, vocab_rows) - result.stages['llm_correction'] = round(time.time() - t, 2) - - # Convert to output format - result.vocabulary = [ - { - "english": row.english, - "german": row.german, - "example": row.example, - "confidence": round(row.confidence, 1), - } - for row in vocab_rows - if row.english or row.german # Skip empty rows - ] - - result.duration_seconds = round(time.time() - total_start, 2) - logger.info(f"CV Pipeline complete: {len(result.vocabulary)} entries in {result.duration_seconds}s") - - except Exception as e: - logger.error(f"CV Pipeline error: {e}") - import traceback - logger.debug(traceback.format_exc()) - result.error = str(e) - result.duration_seconds = round(time.time() - total_start, 2) - - return result - - -# --------------------------------------------------------------------------- -# LLM-based OCR Correction (Step 6) -# --------------------------------------------------------------------------- - -import httpx -import os -import json as _json -import re as _re - -_OLLAMA_URL = os.getenv("OLLAMA_BASE_URL", "http://host.docker.internal:11434") -OLLAMA_REVIEW_MODEL = os.getenv("OLLAMA_REVIEW_MODEL", "qwen3:0.6b") -_REVIEW_BATCH_SIZE = int(os.getenv("OLLAMA_REVIEW_BATCH_SIZE", "20")) -logger.info("LLM review model: %s (batch=%d)", OLLAMA_REVIEW_MODEL, _REVIEW_BATCH_SIZE) - -# Regex: entry contains IPA phonetic brackets like "dance [dɑːns]" -_HAS_PHONETIC_RE = _re.compile(r'\[.*?[ˈˌːʃʒθðŋɑɒɔəɜɪʊʌæ].*?\]') - -# Regex: digit adjacent to a letter — the hallmark of OCR digit↔letter confusion. -# Matches digits 0,1,5,6,8 (common OCR confusions: 0→O, 1→l/I, 5→S, 6→G, 8→B) -# when they appear inside or next to a word character. -_OCR_DIGIT_IN_WORD_RE = _re.compile(r'(?<=[A-Za-zÄÖÜäöüß])[01568]|[01568](?=[A-Za-zÄÖÜäöüß])') - - -def _entry_needs_review(entry: Dict) -> bool: - """Check if an entry should be sent to the LLM for review. - - Sends all non-empty entries that don't have IPA phonetic transcriptions. - The LLM prompt and _is_spurious_change() guard against unwanted changes. - """ - en = entry.get("english", "") or "" - de = entry.get("german", "") or "" - - # Skip completely empty entries - if not en.strip() and not de.strip(): - return False - # Skip entries with IPA/phonetic brackets — dictionary-corrected, LLM must not touch them - if _HAS_PHONETIC_RE.search(en) or _HAS_PHONETIC_RE.search(de): - return False - return True - - -def _build_llm_prompt(table_lines: List[Dict]) -> str: - """Build the LLM correction prompt for a batch of entries.""" - return f"""Du bist ein OCR-Zeichenkorrektur-Werkzeug fuer Vokabeltabellen (Englisch-Deutsch). - -DEINE EINZIGE AUFGABE: Einzelne Zeichen korrigieren, die vom OCR-Scanner als Ziffer statt als Buchstabe erkannt wurden. - -NUR diese Korrekturen sind erlaubt: -- Ziffer 8 statt B: "8en" → "Ben", "8uch" → "Buch", "8all" → "Ball" -- Ziffer 0 statt O oder o: "L0ndon" → "London", "0ld" → "Old" -- Ziffer 1 statt l oder I: "1ong" → "long", "Ber1in" → "Berlin" -- Ziffer 5 statt S oder s: "5tadt" → "Stadt", "5ee" → "See" -- Ziffer 6 statt G oder g: "6eld" → "Geld" -- Senkrechter Strich | statt I oder l: "| want" → "I want", "|ong" → "long", "he| p" → "help" - -ABSOLUT VERBOTEN — aendere NIEMALS: -- Woerter die korrekt geschrieben sind — auch wenn du eine andere Schreibweise kennst -- Uebersetzungen — du uebersetzt NICHTS, weder EN→DE noch DE→EN -- Korrekte englische Woerter (en-Spalte) — auch wenn du eine Bedeutung kennst -- Korrekte deutsche Woerter (de-Spalte) — auch wenn du sie anders sagen wuerdest -- Eigennamen: Ben, London, China, Africa, Shakespeare usw. -- Abkuerzungen: sth., sb., etc., e.g., i.e., v.t., smb. usw. -- Lautschrift in eckigen Klammern [...] — diese NIEMALS beruehren -- Beispielsaetze in der ex-Spalte — NIEMALS aendern - -Wenn ein Wort keinen Ziffer-Buchstaben-Fehler enthaelt: gib es UNVERAENDERT zurueck und setze "corrected": false. - -Antworte NUR mit dem JSON-Array. Kein Text davor oder danach. -Behalte die exakte Struktur (gleiche Anzahl Eintraege, gleiche Reihenfolge). - -/no_think - -Eingabe: -{_json.dumps(table_lines, ensure_ascii=False, indent=2)}""" - - -def _is_spurious_change(old_val: str, new_val: str) -> bool: - """Detect LLM changes that are likely wrong and should be discarded. - - Only digit↔letter substitutions (0→O, 1→l, 5→S, 6→G, 8→B) are - legitimate OCR corrections. Everything else is rejected. - - Filters out: - - Case-only changes - - Changes that don't contain any digit→letter fix - - Completely different words (LLM translating or hallucinating) - - Additions or removals of whole words (count changed) - """ - if not old_val or not new_val: - return False - - # Case-only change — never a real OCR error - if old_val.lower() == new_val.lower(): - return True - - # If the word count changed significantly, the LLM rewrote rather than fixed - old_words = old_val.split() - new_words = new_val.split() - if abs(len(old_words) - len(new_words)) > 1: - return True - - # Core rule: a legitimate correction replaces a digit with the corresponding - # letter. If the change doesn't include such a substitution, reject it. - # Build a set of (old_char, new_char) pairs that differ between old and new. - # Use character-level diff heuristic: if lengths are close, zip and compare. - # Map of characters that OCR commonly misreads → set of correct replacements - _OCR_CHAR_MAP = { - # Digits mistaken for letters - '0': set('oOgG'), - '1': set('lLiI'), - '5': set('sS'), - '6': set('gG'), - '8': set('bB'), - # Non-letter symbols mistaken for letters - '|': set('lLiI1'), # pipe → lowercase l, capital I, or digit 1 - 'l': set('iI|1'), # lowercase l → capital I (and reverse) - } - has_valid_fix = False - if len(old_val) == len(new_val): - for oc, nc in zip(old_val, new_val): - if oc != nc: - if oc in _OCR_CHAR_MAP and nc in _OCR_CHAR_MAP[oc]: - has_valid_fix = True - elif nc in _OCR_CHAR_MAP and oc in _OCR_CHAR_MAP[nc]: - # Reverse check (e.g. l→I where new is the "correct" char) - has_valid_fix = True - else: - # Length changed by 1: accept if old had a suspicious char sequence - _OCR_SUSPICIOUS_RE = _re.compile(r'[|01568]') - if abs(len(old_val) - len(new_val)) <= 1 and _OCR_SUSPICIOUS_RE.search(old_val): - has_valid_fix = True - - if not has_valid_fix: - return True # Reject — looks like translation or hallucination - - return False - - -def _diff_batch(originals: List[Dict], corrected: List[Dict]) -> Tuple[List[Dict], List[Dict]]: - """Compare original entries with LLM-corrected ones, return (changes, corrected_entries).""" - changes = [] - entries_out = [] - for i, orig in enumerate(originals): - if i < len(corrected): - c = corrected[i] - entry = dict(orig) - for field_name, key in [("english", "en"), ("german", "de"), ("example", "ex")]: - new_val = c.get(key, "").strip() - old_val = (orig.get(field_name, "") or "").strip() - if new_val and new_val != old_val: - # Filter spurious LLM changes - if _is_spurious_change(old_val, new_val): - continue - changes.append({ - "row_index": orig.get("row_index", i), - "field": field_name, - "old": old_val, - "new": new_val, - }) - entry[field_name] = new_val - entry["llm_corrected"] = True - entries_out.append(entry) - else: - entries_out.append(dict(orig)) - return changes, entries_out - - -# ─── Spell-Checker OCR Review (Rule-Based, no LLM) ──────────────────────────── - -REVIEW_ENGINE = os.getenv("REVIEW_ENGINE", "spell") # "spell" (default) | "llm" - -try: - from spellchecker import SpellChecker as _SpellChecker - _en_spell = _SpellChecker(language='en', distance=1) - _de_spell = _SpellChecker(language='de', distance=1) - _SPELL_AVAILABLE = True - logger.info("pyspellchecker loaded (EN+DE), review engine: %s", REVIEW_ENGINE) -except ImportError: - _SPELL_AVAILABLE = False - logger.warning("pyspellchecker not installed — falling back to LLM review") - -# ─── Page-Ref Normalization ─────────────────────────────────────────────────── -# Normalizes OCR variants like "p-60", "p 61", "p60" → "p.60" -_PAGE_REF_RE = _re.compile(r'\bp[\s\-]?(\d+)', _re.IGNORECASE) - - -def _normalize_page_ref(text: str) -> str: - """Normalize page references: 'p-60' / 'p 61' / 'p60' → 'p.60'.""" - if not text: - return text - return _PAGE_REF_RE.sub(lambda m: f"p.{m.group(1)}", text) - - -# Suspicious OCR chars → ordered list of most-likely correct replacements -_SPELL_SUBS: Dict[str, List[str]] = { - '0': ['O', 'o'], - '1': ['l', 'I'], - '5': ['S', 's'], - '6': ['G', 'g'], - '8': ['B', 'b'], - '|': ['I', 'l', '1'], -} -_SPELL_SUSPICIOUS = frozenset(_SPELL_SUBS.keys()) - -# Tokenizer: word tokens (letters + pipe) alternating with separators -_SPELL_TOKEN_RE = _re.compile(r'([A-Za-zÄÖÜäöüß|]+)([^A-Za-zÄÖÜäöüß|]*)') - - -def _spell_dict_knows(word: str) -> bool: - """True if word is known in EN or DE dictionary.""" - if not _SPELL_AVAILABLE: - return False - w = word.lower() - return bool(_en_spell.known([w])) or bool(_de_spell.known([w])) - - -def _try_split_merged_word(token: str) -> Optional[str]: - """Try to split a merged word like 'atmyschool' into 'at my school'. - - Uses dynamic programming to find the shortest sequence of dictionary - words that covers the entire token. Only returns a result when the - split produces at least 2 words and ALL parts are known dictionary words. - - Preserves original capitalisation by mapping back to the input string. - """ - if not _SPELL_AVAILABLE or len(token) < 4: - return None - - lower = token.lower() - n = len(lower) - - # dp[i] = (word_lengths_list, score) for best split of lower[:i], or None - # Score: (-word_count, sum_of_squared_lengths) — fewer words first, - # then prefer longer words (e.g. "come on" over "com eon") - dp: list = [None] * (n + 1) - dp[0] = ([], 0) - - for i in range(1, n + 1): - for j in range(max(0, i - 20), i): - if dp[j] is None: - continue - candidate = lower[j:i] - word_len = i - j - if word_len == 1 and candidate not in ('a', 'i'): - continue - if _spell_dict_knows(candidate): - prev_words, prev_sq = dp[j] - new_words = prev_words + [word_len] - new_sq = prev_sq + word_len * word_len - new_key = (-len(new_words), new_sq) - if dp[i] is None: - dp[i] = (new_words, new_sq) - else: - old_key = (-len(dp[i][0]), dp[i][1]) - if new_key >= old_key: - # >= so that later splits (longer first word) win ties - dp[i] = (new_words, new_sq) - - if dp[n] is None or len(dp[n][0]) < 2: - return None - - # Reconstruct with original casing - result = [] - pos = 0 - for wlen in dp[n][0]: - result.append(token[pos:pos + wlen]) - pos += wlen - - logger.debug("Split merged word: %r → %r", token, " ".join(result)) - return " ".join(result) - - -def _spell_fix_token(token: str, field: str = "") -> Optional[str]: - """Return corrected form of token, or None if no fix needed/possible. - - *field* is 'english' or 'german' — used to pick the right dictionary - for general spell correction (step 3 below). - """ - has_suspicious = any(ch in _SPELL_SUSPICIOUS for ch in token) - - # 1. Already known word → no fix needed - if _spell_dict_knows(token): - return None - - # 2. Digit/pipe substitution (existing logic) - if has_suspicious: - # Standalone pipe → capital I - if token == '|': - return 'I' - # Dictionary-backed single-char substitution - for i, ch in enumerate(token): - if ch not in _SPELL_SUBS: - continue - for replacement in _SPELL_SUBS[ch]: - candidate = token[:i] + replacement + token[i + 1:] - if _spell_dict_knows(candidate): - return candidate - # Structural rule: suspicious char at position 0 + rest is all lowercase letters - first = token[0] - if first in _SPELL_SUBS and len(token) >= 2: - rest = token[1:] - if rest.isalpha() and rest.islower(): - candidate = _SPELL_SUBS[first][0] + rest - if not candidate[0].isdigit(): - return candidate - - # 3. OCR umlaut confusion: OCR often drops umlaut dots (ü→i, ä→a, ö→o, ü→u) - # Try single-char umlaut substitutions and check against dictionary. - if len(token) >= 3 and token.isalpha() and field == "german": - _UMLAUT_SUBS = {'a': 'ä', 'o': 'ö', 'u': 'ü', 'i': 'ü', - 'A': 'Ä', 'O': 'Ö', 'U': 'Ü', 'I': 'Ü'} - for i, ch in enumerate(token): - if ch in _UMLAUT_SUBS: - candidate = token[:i] + _UMLAUT_SUBS[ch] + token[i + 1:] - if _spell_dict_knows(candidate): - return candidate - - # 4. General spell correction for unknown words (no digits/pipes) - # e.g. "beautful" → "beautiful" - if not has_suspicious and len(token) >= 3 and token.isalpha(): - spell = _en_spell if field == "english" else _de_spell if field == "german" else None - if spell is not None: - correction = spell.correction(token.lower()) - if correction and correction != token.lower(): - # Preserve original capitalisation pattern - if token[0].isupper(): - correction = correction[0].upper() + correction[1:] - if _spell_dict_knows(correction): - return correction - - # 5. Merged-word split: OCR often merges adjacent words when spacing - # is too tight, e.g. "atmyschool" → "at my school" - if len(token) >= 4 and token.isalpha(): - split = _try_split_merged_word(token) - if split: - return split - - return None - - -def _spell_fix_field(text: str, field: str = "") -> Tuple[str, bool]: - """Apply OCR corrections to a text field. Returns (fixed_text, was_changed). - - *field* is 'english' or 'german' — forwarded to _spell_fix_token for - dictionary selection. - """ - if not text: - return text, False - has_suspicious = any(ch in text for ch in _SPELL_SUSPICIOUS) - # If no suspicious chars AND no alpha chars that could be misspelled, skip - if not has_suspicious and not any(c.isalpha() for c in text): - return text, False - # Pattern: | immediately before . or , → numbered list prefix ("|. " → "1. ") - fixed = _re.sub(r'(? Dict: - """Rule-based OCR correction: spell-checker + structural heuristics. - - Deterministic — never translates, never touches IPA, never hallucinates. - Uses SmartSpellChecker for language-aware corrections with context-based - disambiguation (a/I), multi-digit substitution, and cross-language guard. - """ - t0 = time.time() - changes: List[Dict] = [] - all_corrected: List[Dict] = [] - - # Use SmartSpellChecker if available, fall back to legacy _spell_fix_field - _smart = None - try: - from smart_spell import SmartSpellChecker - _smart = SmartSpellChecker() - logger.debug("spell_review: using SmartSpellChecker") - except Exception: - logger.debug("spell_review: SmartSpellChecker not available, using legacy") - - # Map field names → language codes for SmartSpellChecker - _LANG_MAP = {"english": "en", "german": "de", "example": "auto"} - - for i, entry in enumerate(entries): - e = dict(entry) - # Page-ref normalization (always, regardless of review status) - old_ref = (e.get("source_page") or "").strip() - if old_ref: - new_ref = _normalize_page_ref(old_ref) - if new_ref != old_ref: - changes.append({ - "row_index": e.get("row_index", i), - "field": "source_page", - "old": old_ref, - "new": new_ref, - }) - e["source_page"] = new_ref - e["llm_corrected"] = True - if not _entry_needs_review(e): - all_corrected.append(e) - continue - for field_name in ("english", "german", "example"): - old_val = (e.get(field_name) or "").strip() - if not old_val: - continue - - if _smart: - # SmartSpellChecker path — language-aware, context-based - lang_code = _LANG_MAP.get(field_name, "en") - result = _smart.correct_text(old_val, lang=lang_code) - new_val = result.corrected - was_changed = result.changed - else: - # Legacy path - lang = "german" if field_name in ("german", "example") else "english" - new_val, was_changed = _spell_fix_field(old_val, field=lang) - - if was_changed and new_val != old_val: - changes.append({ - "row_index": e.get("row_index", i), - "field": field_name, - "old": old_val, - "new": new_val, - }) - e[field_name] = new_val - e["llm_corrected"] = True - all_corrected.append(e) - duration_ms = int((time.time() - t0) * 1000) - model_name = "smart-spell-checker" if _smart else "spell-checker" - return { - "entries_original": entries, - "entries_corrected": all_corrected, - "changes": changes, - "skipped_count": 0, - "model_used": model_name, - "duration_ms": duration_ms, - } - - -async def spell_review_entries_streaming(entries: List[Dict], batch_size: int = 50): - """Async generator yielding SSE-compatible events for spell-checker review.""" - total = len(entries) - yield { - "type": "meta", - "total_entries": total, - "to_review": total, - "skipped": 0, - "model": "spell-checker", - "batch_size": batch_size, - } - result = spell_review_entries_sync(entries) - changes = result["changes"] - yield { - "type": "batch", - "batch_index": 0, - "entries_reviewed": [e.get("row_index", i) for i, e in enumerate(entries)], - "changes": changes, - "duration_ms": result["duration_ms"], - "progress": {"current": total, "total": total}, - } - yield { - "type": "complete", - "changes": changes, - "model_used": "spell-checker", - "duration_ms": result["duration_ms"], - "total_entries": total, - "reviewed": total, - "skipped": 0, - "corrections_found": len(changes), - "entries_corrected": result["entries_corrected"], - } - -# ─── End Spell-Checker ──────────────────────────────────────────────────────── - - -async def llm_review_entries( - entries: List[Dict], - model: str = None, -) -> Dict: - """OCR error correction. Uses spell-checker (REVIEW_ENGINE=spell) or LLM (REVIEW_ENGINE=llm).""" - if REVIEW_ENGINE == "spell" and _SPELL_AVAILABLE: - return spell_review_entries_sync(entries) - if REVIEW_ENGINE == "spell" and not _SPELL_AVAILABLE: - logger.warning("REVIEW_ENGINE=spell but pyspellchecker not installed, using LLM") - - model = model or OLLAMA_REVIEW_MODEL - - # Filter: only entries that need review - reviewable = [(i, e) for i, e in enumerate(entries) if _entry_needs_review(e)] - - if not reviewable: - return { - "entries_original": entries, - "entries_corrected": [dict(e) for e in entries], - "changes": [], - "skipped_count": len(entries), - "model_used": model, - "duration_ms": 0, - } - - review_entries = [e for _, e in reviewable] - table_lines = [ - {"row": e.get("row_index", 0), "en": e.get("english", ""), "de": e.get("german", ""), "ex": e.get("example", "")} - for e in review_entries - ] - - logger.info("LLM review: sending %d/%d entries to %s (skipped %d without digit-pattern)", - len(review_entries), len(entries), model, len(entries) - len(reviewable)) - logger.debug("LLM review input: %s", _json.dumps(table_lines[:3], ensure_ascii=False)) - - prompt = _build_llm_prompt(table_lines) - - t0 = time.time() - async with httpx.AsyncClient(timeout=300.0) as client: - resp = await client.post( - f"{_OLLAMA_URL}/api/chat", - json={ - "model": model, - "messages": [{"role": "user", "content": prompt}], - "stream": False, - "think": False, # qwen3: disable chain-of-thought (Ollama >=0.6) - "options": {"temperature": 0.1, "num_predict": 8192}, - }, - ) - resp.raise_for_status() - content = resp.json().get("message", {}).get("content", "") - duration_ms = int((time.time() - t0) * 1000) - - logger.info("LLM review: response in %dms, raw length=%d chars", duration_ms, len(content)) - logger.debug("LLM review raw response (first 500): %.500s", content) - - corrected = _parse_llm_json_array(content) - logger.info("LLM review: parsed %d corrected entries, applying diff...", len(corrected)) - changes, corrected_entries = _diff_batch(review_entries, corrected) - - # Merge corrected entries back into the full list - all_corrected = [dict(e) for e in entries] - for batch_idx, (orig_idx, _) in enumerate(reviewable): - if batch_idx < len(corrected_entries): - all_corrected[orig_idx] = corrected_entries[batch_idx] - - return { - "entries_original": entries, - "entries_corrected": all_corrected, - "changes": changes, - "skipped_count": len(entries) - len(reviewable), - "model_used": model, - "duration_ms": duration_ms, - } - - -async def llm_review_entries_streaming( - entries: List[Dict], - model: str = None, - batch_size: int = _REVIEW_BATCH_SIZE, -): - """Async generator: yield SSE events. Uses spell-checker or LLM depending on REVIEW_ENGINE. - - Phase 0 (always): Run _fix_character_confusion and emit any changes so they are - visible in the UI — this is the only place the fix now runs (removed from Step 1 - of build_vocab_pipeline_streaming). - """ - # --- Phase 0: Character confusion fix (| → I, 1 → I, 8 → B, etc.) --- - _CONF_FIELDS = ('english', 'german', 'example') - originals = [{f: e.get(f, '') for f in _CONF_FIELDS} for e in entries] - _fix_character_confusion(entries) # modifies in-place, returns same list - char_changes = [ - {'row_index': i, 'field': f, 'old': originals[i][f], 'new': entries[i].get(f, '')} - for i in range(len(entries)) - for f in _CONF_FIELDS - if originals[i][f] != entries[i].get(f, '') - ] - - if REVIEW_ENGINE == "spell" and _SPELL_AVAILABLE: - # Inject char_changes as a batch right after the meta event from the spell checker - _meta_sent = False - async for event in spell_review_entries_streaming(entries, batch_size): - yield event - if not _meta_sent and event.get('type') == 'meta' and char_changes: - _meta_sent = True - yield { - 'type': 'batch', - 'changes': char_changes, - 'entries_reviewed': sorted({c['row_index'] for c in char_changes}), - 'progress': {'current': 0, 'total': len(entries)}, - } - return - - if REVIEW_ENGINE == "spell" and not _SPELL_AVAILABLE: - logger.warning("REVIEW_ENGINE=spell but pyspellchecker not installed, using LLM") - - # LLM path: emit char_changes first (before meta) so they appear in the UI - if char_changes: - yield { - 'type': 'batch', - 'changes': char_changes, - 'entries_reviewed': sorted({c['row_index'] for c in char_changes}), - 'progress': {'current': 0, 'total': len(entries)}, - } - - model = model or OLLAMA_REVIEW_MODEL - - # Separate reviewable from skipped entries - reviewable = [] - skipped_indices = [] - for i, e in enumerate(entries): - if _entry_needs_review(e): - reviewable.append((i, e)) - else: - skipped_indices.append(i) - - total_to_review = len(reviewable) - - # meta event - yield { - "type": "meta", - "total_entries": len(entries), - "to_review": total_to_review, - "skipped": len(skipped_indices), - "model": model, - "batch_size": batch_size, - } - - all_changes = [] - all_corrected = [dict(e) for e in entries] - total_duration_ms = 0 - reviewed_count = 0 - - # Process in batches - for batch_start in range(0, total_to_review, batch_size): - batch_items = reviewable[batch_start:batch_start + batch_size] - batch_entries = [e for _, e in batch_items] - - table_lines = [ - {"row": e.get("row_index", 0), "en": e.get("english", ""), "de": e.get("german", ""), "ex": e.get("example", "")} - for e in batch_entries - ] - - prompt = _build_llm_prompt(table_lines) - - logger.info("LLM review streaming: batch %d — sending %d entries to %s", - batch_start // batch_size, len(batch_entries), model) - - t0 = time.time() - async with httpx.AsyncClient(timeout=300.0) as client: - resp = await client.post( - f"{_OLLAMA_URL}/api/chat", - json={ - "model": model, - "messages": [{"role": "user", "content": prompt}], - "stream": False, - "think": False, # qwen3: disable chain-of-thought - "options": {"temperature": 0.1, "num_predict": 8192}, - }, - ) - resp.raise_for_status() - content = resp.json().get("message", {}).get("content", "") - batch_ms = int((time.time() - t0) * 1000) - total_duration_ms += batch_ms - - logger.info("LLM review streaming: response %dms, length=%d chars", batch_ms, len(content)) - logger.debug("LLM review streaming raw (first 500): %.500s", content) - - corrected = _parse_llm_json_array(content) - logger.info("LLM review streaming: parsed %d entries, applying diff...", len(corrected)) - batch_changes, batch_corrected = _diff_batch(batch_entries, corrected) - - # Merge back - for batch_idx, (orig_idx, _) in enumerate(batch_items): - if batch_idx < len(batch_corrected): - all_corrected[orig_idx] = batch_corrected[batch_idx] - - all_changes.extend(batch_changes) - reviewed_count += len(batch_items) - - # Yield batch result - yield { - "type": "batch", - "batch_index": batch_start // batch_size, - "entries_reviewed": [e.get("row_index", 0) for _, e in batch_items], - "changes": batch_changes, - "duration_ms": batch_ms, - "progress": {"current": reviewed_count, "total": total_to_review}, - } - - # Complete event - yield { - "type": "complete", - "changes": all_changes, - "model_used": model, - "duration_ms": total_duration_ms, - "total_entries": len(entries), - "reviewed": total_to_review, - "skipped": len(skipped_indices), - "corrections_found": len(all_changes), - "entries_corrected": all_corrected, - } - - -def _sanitize_for_json(text: str) -> str: - """Remove or escape control characters that break JSON parsing. - - Keeps tab (\\t), newline (\\n), carriage return (\\r) which are valid - JSON whitespace. Removes all other ASCII control characters (0x00-0x1f) - that are only valid inside JSON strings when properly escaped. - """ - # Replace literal control chars (except \\t \\n \\r) with a space - return _re.sub(r'[\x00-\x08\x0b\x0c\x0e-\x1f\x7f]', ' ', text) - - -def _parse_llm_json_array(text: str) -> List[Dict]: - """Extract JSON array from LLM response (handles markdown fences and qwen3 think-tags).""" - # Strip qwen3 ... blocks (present even with think=False on some builds) - text = _re.sub(r'.*?', '', text, flags=_re.DOTALL) - # Strip markdown code fences - text = _re.sub(r'```json\s*', '', text) - text = _re.sub(r'```\s*', '', text) - # Sanitize control characters before JSON parsing - text = _sanitize_for_json(text) - # Find first [ ... last ] - match = _re.search(r'\[.*\]', text, _re.DOTALL) - if match: - try: - return _json.loads(match.group()) - except (ValueError, _json.JSONDecodeError) as e: - logger.warning("LLM review: JSON parse failed: %s | raw snippet: %.200s", e, match.group()[:200]) - else: - logger.warning("LLM review: no JSON array found in response (%.200s)", text[:200]) - return [] +from cv_review_spell import ( # noqa: F401 + _SPELL_AVAILABLE, + _spell_dict_knows, + _spell_fix_field, + _spell_fix_token, + _try_split_merged_word, + _normalize_page_ref, + spell_review_entries_sync, + spell_review_entries_streaming, +) + +from cv_review_llm import ( # noqa: F401 + OLLAMA_REVIEW_MODEL, + REVIEW_ENGINE, + _REVIEW_BATCH_SIZE, + _build_llm_prompt, + _diff_batch, + _entry_needs_review, + _is_spurious_change, + _parse_llm_json_array, + _sanitize_for_json, + llm_review_entries, + llm_review_entries_streaming, +) diff --git a/klausur-service/backend/cv_review_llm.py b/klausur-service/backend/cv_review_llm.py new file mode 100644 index 0000000..dc3b288 --- /dev/null +++ b/klausur-service/backend/cv_review_llm.py @@ -0,0 +1,388 @@ +""" +CV Review LLM — LLM-based OCR correction: prompt building, change detection, streaming. + +Handles the LLM review path (REVIEW_ENGINE=llm) and shared utilities like +_entry_needs_review, _is_spurious_change, _diff_batch, and JSON parsing. + +Lizenz: Apache 2.0 (kommerziell nutzbar) +DATENSCHUTZ: Alle Verarbeitung erfolgt lokal. +""" + +import json +import logging +import os +import re +import time +from typing import Dict, List, Tuple + +import httpx + +logger = logging.getLogger(__name__) + +_OLLAMA_URL = os.getenv("OLLAMA_BASE_URL", "http://host.docker.internal:11434") +OLLAMA_REVIEW_MODEL = os.getenv("OLLAMA_REVIEW_MODEL", "qwen3:0.6b") +_REVIEW_BATCH_SIZE = int(os.getenv("OLLAMA_REVIEW_BATCH_SIZE", "20")) +logger.info("LLM review model: %s (batch=%d)", OLLAMA_REVIEW_MODEL, _REVIEW_BATCH_SIZE) + +REVIEW_ENGINE = os.getenv("REVIEW_ENGINE", "spell") # "spell" (default) | "llm" + +# Regex: entry contains IPA phonetic brackets like "dance [da:ns]" +_HAS_PHONETIC_RE = re.compile(r'\[.*?[\u02c8\u02cc\u02d0\u0283\u0292\u03b8\u00f0\u014b\u0251\u0252\u0254\u0259\u025c\u026a\u028a\u028c\u00e6].*?\]') + +# Regex: digit adjacent to a letter -- OCR digit<->letter confusion +_OCR_DIGIT_IN_WORD_RE = re.compile(r'(?<=[A-Za-z\u00c4\u00d6\u00dc\u00e4\u00f6\u00fc\u00df])[01568]|[01568](?=[A-Za-z\u00c4\u00d6\u00dc\u00e4\u00f6\u00fc\u00df])') + + +def _entry_needs_review(entry: Dict) -> bool: + """Check if an entry should be sent for review. + + Sends all non-empty entries that don't have IPA phonetic transcriptions. + """ + en = entry.get("english", "") or "" + de = entry.get("german", "") or "" + + if not en.strip() and not de.strip(): + return False + if _HAS_PHONETIC_RE.search(en) or _HAS_PHONETIC_RE.search(de): + return False + return True + + +def _build_llm_prompt(table_lines: List[Dict]) -> str: + """Build the LLM correction prompt for a batch of entries.""" + return f"""Du bist ein OCR-Zeichenkorrektur-Werkzeug fuer Vokabeltabellen (Englisch-Deutsch). + +DEINE EINZIGE AUFGABE: Einzelne Zeichen korrigieren, die vom OCR-Scanner als Ziffer statt als Buchstabe erkannt wurden. + +NUR diese Korrekturen sind erlaubt: +- Ziffer 8 statt B: "8en" -> "Ben", "8uch" -> "Buch", "8all" -> "Ball" +- Ziffer 0 statt O oder o: "L0ndon" -> "London", "0ld" -> "Old" +- Ziffer 1 statt l oder I: "1ong" -> "long", "Ber1in" -> "Berlin" +- Ziffer 5 statt S oder s: "5tadt" -> "Stadt", "5ee" -> "See" +- Ziffer 6 statt G oder g: "6eld" -> "Geld" +- Senkrechter Strich | statt I oder l: "| want" -> "I want", "|ong" -> "long", "he| p" -> "help" + +ABSOLUT VERBOTEN -- aendere NIEMALS: +- Woerter die korrekt geschrieben sind -- auch wenn du eine andere Schreibweise kennst +- Uebersetzungen -- du uebersetzt NICHTS, weder EN->DE noch DE->EN +- Korrekte englische Woerter (en-Spalte) -- auch wenn du eine Bedeutung kennst +- Korrekte deutsche Woerter (de-Spalte) -- auch wenn du sie anders sagen wuerdest +- Eigennamen: Ben, London, China, Africa, Shakespeare usw. +- Abkuerzungen: sth., sb., etc., e.g., i.e., v.t., smb. usw. +- Lautschrift in eckigen Klammern [...] -- diese NIEMALS beruehren +- Beispielsaetze in der ex-Spalte -- NIEMALS aendern + +Wenn ein Wort keinen Ziffer-Buchstaben-Fehler enthaelt: gib es UNVERAENDERT zurueck und setze "corrected": false. + +Antworte NUR mit dem JSON-Array. Kein Text davor oder danach. +Behalte die exakte Struktur (gleiche Anzahl Eintraege, gleiche Reihenfolge). + +/no_think + +Eingabe: +{json.dumps(table_lines, ensure_ascii=False, indent=2)}""" + + +def _is_spurious_change(old_val: str, new_val: str) -> bool: + """Detect LLM changes that are likely wrong and should be discarded. + + Only digit<->letter substitutions (0->O, 1->l, 5->S, 6->G, 8->B) are + legitimate OCR corrections. Everything else is rejected. + """ + if not old_val or not new_val: + return False + + if old_val.lower() == new_val.lower(): + return True + + old_words = old_val.split() + new_words = new_val.split() + if abs(len(old_words) - len(new_words)) > 1: + return True + + _OCR_CHAR_MAP = { + '0': set('oOgG'), + '1': set('lLiI'), + '5': set('sS'), + '6': set('gG'), + '8': set('bB'), + '|': set('lLiI1'), + 'l': set('iI|1'), + } + has_valid_fix = False + if len(old_val) == len(new_val): + for oc, nc in zip(old_val, new_val): + if oc != nc: + if oc in _OCR_CHAR_MAP and nc in _OCR_CHAR_MAP[oc]: + has_valid_fix = True + elif nc in _OCR_CHAR_MAP and oc in _OCR_CHAR_MAP[nc]: + has_valid_fix = True + else: + _OCR_SUSPICIOUS_RE = re.compile(r'[|01568]') + if abs(len(old_val) - len(new_val)) <= 1 and _OCR_SUSPICIOUS_RE.search(old_val): + has_valid_fix = True + + if not has_valid_fix: + return True + + return False + + +def _diff_batch(originals: List[Dict], corrected: List[Dict]) -> Tuple[List[Dict], List[Dict]]: + """Compare original entries with LLM-corrected ones, return (changes, corrected_entries).""" + changes = [] + entries_out = [] + for i, orig in enumerate(originals): + if i < len(corrected): + c = corrected[i] + entry = dict(orig) + for field_name, key in [("english", "en"), ("german", "de"), ("example", "ex")]: + new_val = c.get(key, "").strip() + old_val = (orig.get(field_name, "") or "").strip() + if new_val and new_val != old_val: + if _is_spurious_change(old_val, new_val): + continue + changes.append({ + "row_index": orig.get("row_index", i), + "field": field_name, + "old": old_val, + "new": new_val, + }) + entry[field_name] = new_val + entry["llm_corrected"] = True + entries_out.append(entry) + else: + entries_out.append(dict(orig)) + return changes, entries_out + + +def _sanitize_for_json(text: str) -> str: + """Remove or escape control characters that break JSON parsing.""" + return re.sub(r'[\x00-\x08\x0b\x0c\x0e-\x1f\x7f]', ' ', text) + + +def _parse_llm_json_array(text: str) -> List[Dict]: + """Extract JSON array from LLM response (handles markdown fences and qwen3 think-tags).""" + text = re.sub(r'.*?', '', text, flags=re.DOTALL) + text = re.sub(r'```json\s*', '', text) + text = re.sub(r'```\s*', '', text) + text = _sanitize_for_json(text) + match = re.search(r'\[.*\]', text, re.DOTALL) + if match: + try: + return json.loads(match.group()) + except (ValueError, json.JSONDecodeError) as e: + logger.warning("LLM review: JSON parse failed: %s | raw snippet: %.200s", e, match.group()[:200]) + else: + logger.warning("LLM review: no JSON array found in response (%.200s)", text[:200]) + return [] + + +async def llm_review_entries( + entries: List[Dict], + model: str = None, +) -> Dict: + """OCR error correction. Uses spell-checker (REVIEW_ENGINE=spell) or LLM (REVIEW_ENGINE=llm).""" + from cv_review_spell import spell_review_entries_sync, _SPELL_AVAILABLE + + if REVIEW_ENGINE == "spell" and _SPELL_AVAILABLE: + return spell_review_entries_sync(entries) + if REVIEW_ENGINE == "spell" and not _SPELL_AVAILABLE: + logger.warning("REVIEW_ENGINE=spell but pyspellchecker not installed, using LLM") + + model = model or OLLAMA_REVIEW_MODEL + + reviewable = [(i, e) for i, e in enumerate(entries) if _entry_needs_review(e)] + + if not reviewable: + return { + "entries_original": entries, + "entries_corrected": [dict(e) for e in entries], + "changes": [], + "skipped_count": len(entries), + "model_used": model, + "duration_ms": 0, + } + + review_entries = [e for _, e in reviewable] + table_lines = [ + {"row": e.get("row_index", 0), "en": e.get("english", ""), "de": e.get("german", ""), "ex": e.get("example", "")} + for e in review_entries + ] + + logger.info("LLM review: sending %d/%d entries to %s (skipped %d without digit-pattern)", + len(review_entries), len(entries), model, len(entries) - len(reviewable)) + + prompt = _build_llm_prompt(table_lines) + + t0 = time.time() + async with httpx.AsyncClient(timeout=300.0) as client: + resp = await client.post( + f"{_OLLAMA_URL}/api/chat", + json={ + "model": model, + "messages": [{"role": "user", "content": prompt}], + "stream": False, + "think": False, + "options": {"temperature": 0.1, "num_predict": 8192}, + }, + ) + resp.raise_for_status() + content = resp.json().get("message", {}).get("content", "") + duration_ms = int((time.time() - t0) * 1000) + + logger.info("LLM review: response in %dms, raw length=%d chars", duration_ms, len(content)) + + corrected = _parse_llm_json_array(content) + changes, corrected_entries = _diff_batch(review_entries, corrected) + + all_corrected = [dict(e) for e in entries] + for batch_idx, (orig_idx, _) in enumerate(reviewable): + if batch_idx < len(corrected_entries): + all_corrected[orig_idx] = corrected_entries[batch_idx] + + return { + "entries_original": entries, + "entries_corrected": all_corrected, + "changes": changes, + "skipped_count": len(entries) - len(reviewable), + "model_used": model, + "duration_ms": duration_ms, + } + + +async def llm_review_entries_streaming( + entries: List[Dict], + model: str = None, + batch_size: int = _REVIEW_BATCH_SIZE, +): + """Async generator: yield SSE events. Uses spell-checker or LLM depending on REVIEW_ENGINE. + + Phase 0 (always): Run _fix_character_confusion and emit any changes. + """ + from cv_ocr_engines import _fix_character_confusion + from cv_review_spell import spell_review_entries_streaming, _SPELL_AVAILABLE + + _CONF_FIELDS = ('english', 'german', 'example') + originals = [{f: e.get(f, '') for f in _CONF_FIELDS} for e in entries] + _fix_character_confusion(entries) + char_changes = [ + {'row_index': i, 'field': f, 'old': originals[i][f], 'new': entries[i].get(f, '')} + for i in range(len(entries)) + for f in _CONF_FIELDS + if originals[i][f] != entries[i].get(f, '') + ] + + if REVIEW_ENGINE == "spell" and _SPELL_AVAILABLE: + _meta_sent = False + async for event in spell_review_entries_streaming(entries, batch_size): + yield event + if not _meta_sent and event.get('type') == 'meta' and char_changes: + _meta_sent = True + yield { + 'type': 'batch', + 'changes': char_changes, + 'entries_reviewed': sorted({c['row_index'] for c in char_changes}), + 'progress': {'current': 0, 'total': len(entries)}, + } + return + + if REVIEW_ENGINE == "spell" and not _SPELL_AVAILABLE: + logger.warning("REVIEW_ENGINE=spell but pyspellchecker not installed, using LLM") + + # LLM path + if char_changes: + yield { + 'type': 'batch', + 'changes': char_changes, + 'entries_reviewed': sorted({c['row_index'] for c in char_changes}), + 'progress': {'current': 0, 'total': len(entries)}, + } + + model = model or OLLAMA_REVIEW_MODEL + + reviewable = [] + skipped_indices = [] + for i, e in enumerate(entries): + if _entry_needs_review(e): + reviewable.append((i, e)) + else: + skipped_indices.append(i) + + total_to_review = len(reviewable) + + yield { + "type": "meta", + "total_entries": len(entries), + "to_review": total_to_review, + "skipped": len(skipped_indices), + "model": model, + "batch_size": batch_size, + } + + all_changes = [] + all_corrected = [dict(e) for e in entries] + total_duration_ms = 0 + reviewed_count = 0 + + for batch_start in range(0, total_to_review, batch_size): + batch_items = reviewable[batch_start:batch_start + batch_size] + batch_entries = [e for _, e in batch_items] + + table_lines = [ + {"row": e.get("row_index", 0), "en": e.get("english", ""), "de": e.get("german", ""), "ex": e.get("example", "")} + for e in batch_entries + ] + + prompt = _build_llm_prompt(table_lines) + + logger.info("LLM review streaming: batch %d -- sending %d entries to %s", + batch_start // batch_size, len(batch_entries), model) + + t0 = time.time() + async with httpx.AsyncClient(timeout=300.0) as client: + resp = await client.post( + f"{_OLLAMA_URL}/api/chat", + json={ + "model": model, + "messages": [{"role": "user", "content": prompt}], + "stream": False, + "think": False, + "options": {"temperature": 0.1, "num_predict": 8192}, + }, + ) + resp.raise_for_status() + content = resp.json().get("message", {}).get("content", "") + batch_ms = int((time.time() - t0) * 1000) + total_duration_ms += batch_ms + + corrected = _parse_llm_json_array(content) + batch_changes, batch_corrected = _diff_batch(batch_entries, corrected) + + for batch_idx, (orig_idx, _) in enumerate(batch_items): + if batch_idx < len(batch_corrected): + all_corrected[orig_idx] = batch_corrected[batch_idx] + + all_changes.extend(batch_changes) + reviewed_count += len(batch_items) + + yield { + "type": "batch", + "batch_index": batch_start // batch_size, + "entries_reviewed": [e.get("row_index", 0) for _, e in batch_items], + "changes": batch_changes, + "duration_ms": batch_ms, + "progress": {"current": reviewed_count, "total": total_to_review}, + } + + yield { + "type": "complete", + "changes": all_changes, + "model_used": model, + "duration_ms": total_duration_ms, + "total_entries": len(entries), + "reviewed": total_to_review, + "skipped": len(skipped_indices), + "corrections_found": len(all_changes), + "entries_corrected": all_corrected, + } diff --git a/klausur-service/backend/cv_review_pipeline.py b/klausur-service/backend/cv_review_pipeline.py new file mode 100644 index 0000000..746b45c --- /dev/null +++ b/klausur-service/backend/cv_review_pipeline.py @@ -0,0 +1,430 @@ +""" +CV Review Pipeline — Multi-pass OCR, line alignment, LLM post-correction, and orchestration. + +Stages 6-8 of the CV vocabulary pipeline plus the main orchestrator. + +Lizenz: Apache 2.0 (kommerziell nutzbar) +DATENSCHUTZ: Alle Verarbeitung erfolgt lokal. +""" + +import logging +import time +from typing import Any, Dict, List, Optional + +import numpy as np + +from cv_vocab_types import ( + CV_PIPELINE_AVAILABLE, + PageRegion, + PipelineResult, + VocabRow, +) +from cv_preprocessing import ( + deskew_image, + dewarp_image, + render_image_high_res, + render_pdf_high_res, +) +from cv_layout import ( + analyze_layout, + create_layout_image, + create_ocr_image, +) +from cv_ocr_engines import ( + _group_words_into_lines, +) + +logger = logging.getLogger(__name__) + +try: + import cv2 +except ImportError: + cv2 = None # type: ignore[assignment] + +try: + import pytesseract + from PIL import Image +except ImportError: + pytesseract = None # type: ignore[assignment] + Image = None # type: ignore[assignment,misc] + + +# ============================================================================= +# Stage 6: Multi-Pass OCR +# ============================================================================= + +def ocr_region(ocr_img: np.ndarray, region: PageRegion, lang: str, + psm: int, fallback_psm: Optional[int] = None, + min_confidence: float = 40.0) -> List[Dict[str, Any]]: + """Run Tesseract OCR on a specific region with given PSM. + + Args: + ocr_img: Binarized full-page image. + region: Region to crop and OCR. + lang: Tesseract language string. + psm: Page Segmentation Mode. + fallback_psm: If confidence too low, retry with this PSM per line. + min_confidence: Minimum average confidence before fallback. + + Returns: + List of word dicts with text, position, confidence. + """ + crop = ocr_img[region.y:region.y + region.height, + region.x:region.x + region.width] + + if crop.size == 0: + return [] + + pil_img = Image.fromarray(crop) + + config = f'--psm {psm} --oem 3' + try: + data = pytesseract.image_to_data(pil_img, lang=lang, config=config, + output_type=pytesseract.Output.DICT) + except Exception as e: + logger.warning(f"Tesseract failed for region {region.type}: {e}") + return [] + + words = [] + for i in range(len(data['text'])): + text = data['text'][i].strip() + conf = int(data['conf'][i]) + if not text or conf < 10: + continue + words.append({ + 'text': text, + 'left': data['left'][i] + region.x, + 'top': data['top'][i] + region.y, + 'width': data['width'][i], + 'height': data['height'][i], + 'conf': conf, + 'region_type': region.type, + }) + + if words and fallback_psm is not None: + avg_conf = sum(w['conf'] for w in words) / len(words) + if avg_conf < min_confidence: + logger.info(f"Region {region.type}: avg confidence {avg_conf:.0f}% < {min_confidence}%, " + f"trying fallback PSM {fallback_psm}") + words = _ocr_region_line_by_line(ocr_img, region, lang, fallback_psm) + + return words + + +def _ocr_region_line_by_line(ocr_img: np.ndarray, region: PageRegion, + lang: str, psm: int) -> List[Dict[str, Any]]: + """OCR a region line by line (fallback for low-confidence regions).""" + crop = ocr_img[region.y:region.y + region.height, + region.x:region.x + region.width] + + if crop.size == 0: + return [] + + inv = cv2.bitwise_not(crop) + h_proj = np.sum(inv, axis=1) + threshold = np.max(h_proj) * 0.05 if np.max(h_proj) > 0 else 0 + + lines = [] + in_text = False + line_start = 0 + for y in range(len(h_proj)): + if h_proj[y] > threshold and not in_text: + line_start = y + in_text = True + elif h_proj[y] <= threshold and in_text: + if y - line_start > 5: + lines.append((line_start, y)) + in_text = False + if in_text and len(h_proj) - line_start > 5: + lines.append((line_start, len(h_proj))) + + all_words = [] + config = f'--psm {psm} --oem 3' + + for line_y_start, line_y_end in lines: + pad = 3 + y1 = max(0, line_y_start - pad) + y2 = min(crop.shape[0], line_y_end + pad) + line_crop = crop[y1:y2, :] + + if line_crop.size == 0: + continue + + pil_img = Image.fromarray(line_crop) + try: + data = pytesseract.image_to_data(pil_img, lang=lang, config=config, + output_type=pytesseract.Output.DICT) + except Exception: + continue + + for i in range(len(data['text'])): + text = data['text'][i].strip() + conf = int(data['conf'][i]) + if not text or conf < 10: + continue + all_words.append({ + 'text': text, + 'left': data['left'][i] + region.x, + 'top': data['top'][i] + region.y + y1, + 'width': data['width'][i], + 'height': data['height'][i], + 'conf': conf, + 'region_type': region.type, + }) + + return all_words + + +def run_multi_pass_ocr(ocr_img: np.ndarray, + regions: List[PageRegion], + lang: str = "eng+deu") -> Dict[str, List[Dict]]: + """Run OCR on each detected region with optimized settings.""" + results: Dict[str, List[Dict]] = {} + + _ocr_skip = {'header', 'footer', 'margin_top', 'margin_bottom', 'margin_left', 'margin_right'} + for region in regions: + if region.type in _ocr_skip: + continue + + if region.type == 'column_en': + words = ocr_region(ocr_img, region, lang='eng', psm=4) + elif region.type == 'column_de': + words = ocr_region(ocr_img, region, lang='deu', psm=4) + elif region.type == 'column_example': + words = ocr_region(ocr_img, region, lang=lang, psm=6, + fallback_psm=7, min_confidence=40.0) + else: + words = ocr_region(ocr_img, region, lang=lang, psm=6) + + results[region.type] = words + logger.info(f"OCR {region.type}: {len(words)} words") + + return results + + +# ============================================================================= +# Stage 7: Line Alignment -> Vocabulary Entries +# ============================================================================= + +def match_lines_to_vocab(ocr_results: Dict[str, List[Dict]], + regions: List[PageRegion], + y_tolerance_px: int = 25) -> List[VocabRow]: + """Align OCR results from different columns into vocabulary rows.""" + if 'column_en' not in ocr_results and 'column_de' not in ocr_results: + logger.info("match_lines_to_vocab: no column_en/column_de in OCR results, returning empty") + return [] + + en_lines = _group_words_into_lines(ocr_results.get('column_en', []), y_tolerance_px) + de_lines = _group_words_into_lines(ocr_results.get('column_de', []), y_tolerance_px) + ex_lines = _group_words_into_lines(ocr_results.get('column_example', []), y_tolerance_px) + + def line_y_center(line: List[Dict]) -> float: + return sum(w['top'] + w['height'] / 2 for w in line) / len(line) + + def line_text(line: List[Dict]) -> str: + return ' '.join(w['text'] for w in line) + + def line_confidence(line: List[Dict]) -> float: + return sum(w['conf'] for w in line) / len(line) if line else 0 + + vocab_rows: List[VocabRow] = [] + + for en_line in en_lines: + en_y = line_y_center(en_line) + en_text = line_text(en_line) + en_conf = line_confidence(en_line) + + if len(en_text.strip()) < 2: + continue + + de_text = "" + de_conf = 0.0 + best_de_dist = float('inf') + best_de_idx = -1 + for idx, de_line in enumerate(de_lines): + dist = abs(line_y_center(de_line) - en_y) + if dist < y_tolerance_px and dist < best_de_dist: + best_de_dist = dist + best_de_idx = idx + + if best_de_idx >= 0: + de_text = line_text(de_lines[best_de_idx]) + de_conf = line_confidence(de_lines[best_de_idx]) + + ex_text = "" + ex_conf = 0.0 + best_ex_dist = float('inf') + best_ex_idx = -1 + for idx, ex_line in enumerate(ex_lines): + dist = abs(line_y_center(ex_line) - en_y) + if dist < y_tolerance_px and dist < best_ex_dist: + best_ex_dist = dist + best_ex_idx = idx + + if best_ex_idx >= 0: + ex_text = line_text(ex_lines[best_ex_idx]) + ex_conf = line_confidence(ex_lines[best_ex_idx]) + + avg_conf = en_conf + conf_count = 1 + if de_conf > 0: + avg_conf += de_conf + conf_count += 1 + if ex_conf > 0: + avg_conf += ex_conf + conf_count += 1 + + vocab_rows.append(VocabRow( + english=en_text.strip(), + german=de_text.strip(), + example=ex_text.strip(), + confidence=avg_conf / conf_count, + y_position=int(en_y), + )) + + # Handle multi-line wrapping in example column + matched_ex_ys = set() + for row in vocab_rows: + if row.example: + matched_ex_ys.add(row.y_position) + + for ex_line in ex_lines: + ex_y = line_y_center(ex_line) + already_matched = any(abs(ex_y - y) < y_tolerance_px for y in matched_ex_ys) + if already_matched: + continue + + best_row = None + best_dist = float('inf') + for row in vocab_rows: + dist = ex_y - row.y_position + if 0 < dist < y_tolerance_px * 3 and dist < best_dist: + best_dist = dist + best_row = row + + if best_row: + continuation = line_text(ex_line).strip() + if continuation: + best_row.example = (best_row.example + " " + continuation).strip() + + vocab_rows.sort(key=lambda r: r.y_position) + + return vocab_rows + + +# ============================================================================= +# Stage 8: Optional LLM Post-Correction +# ============================================================================= + +async def llm_post_correct(img: np.ndarray, vocab_rows: List[VocabRow], + confidence_threshold: float = 50.0, + enabled: bool = False) -> List[VocabRow]: + """Optionally send low-confidence regions to Qwen-VL for correction.""" + if not enabled: + return vocab_rows + + logger.info(f"LLM post-correction skipped (not yet implemented)") + return vocab_rows + + +# ============================================================================= +# Orchestrator +# ============================================================================= + +async def run_cv_pipeline( + pdf_data: Optional[bytes] = None, + image_data: Optional[bytes] = None, + page_number: int = 0, + zoom: float = 3.0, + enable_dewarp: bool = True, + enable_llm_correction: bool = False, + lang: str = "eng+deu", +) -> PipelineResult: + """Run the complete CV document reconstruction pipeline.""" + if not CV_PIPELINE_AVAILABLE: + return PipelineResult(error="CV pipeline not available (OpenCV or Tesseract missing)") + + result = PipelineResult() + total_start = time.time() + + try: + # Stage 1: Render + t = time.time() + if pdf_data: + img = render_pdf_high_res(pdf_data, page_number, zoom) + elif image_data: + img = render_image_high_res(image_data) + else: + return PipelineResult(error="No input data (pdf_data or image_data required)") + result.stages['render'] = round(time.time() - t, 2) + result.image_width = img.shape[1] + result.image_height = img.shape[0] + logger.info(f"Stage 1 (render): {img.shape[1]}x{img.shape[0]} in {result.stages['render']}s") + + # Stage 2: Deskew + t = time.time() + img, angle = deskew_image(img) + result.stages['deskew'] = round(time.time() - t, 2) + logger.info(f"Stage 2 (deskew): {angle:.2f}\u00b0 in {result.stages['deskew']}s") + + # Stage 3: Dewarp + if enable_dewarp: + t = time.time() + img, _dewarp_info = dewarp_image(img) + result.stages['dewarp'] = round(time.time() - t, 2) + + # Stage 4: Dual image preparation + t = time.time() + ocr_img = create_ocr_image(img) + layout_img = create_layout_image(img) + result.stages['image_prep'] = round(time.time() - t, 2) + + # Stage 5: Layout analysis + t = time.time() + regions = analyze_layout(layout_img, ocr_img) + result.stages['layout'] = round(time.time() - t, 2) + result.columns_detected = len([r for r in regions if r.type.startswith('column')]) + logger.info(f"Stage 5 (layout): {result.columns_detected} columns in {result.stages['layout']}s") + + # Stage 6: Multi-pass OCR + t = time.time() + ocr_results = run_multi_pass_ocr(ocr_img, regions, lang) + result.stages['ocr'] = round(time.time() - t, 2) + total_words = sum(len(w) for w in ocr_results.values()) + result.word_count = total_words + logger.info(f"Stage 6 (OCR): {total_words} words in {result.stages['ocr']}s") + + # Stage 7: Line alignment + t = time.time() + vocab_rows = match_lines_to_vocab(ocr_results, regions) + result.stages['alignment'] = round(time.time() - t, 2) + + # Stage 8: Optional LLM correction + if enable_llm_correction: + t = time.time() + vocab_rows = await llm_post_correct(img, vocab_rows) + result.stages['llm_correction'] = round(time.time() - t, 2) + + # Convert to output format + result.vocabulary = [ + { + "english": row.english, + "german": row.german, + "example": row.example, + "confidence": round(row.confidence, 1), + } + for row in vocab_rows + if row.english or row.german + ] + + result.duration_seconds = round(time.time() - total_start, 2) + logger.info(f"CV Pipeline complete: {len(result.vocabulary)} entries in {result.duration_seconds}s") + + except Exception as e: + logger.error(f"CV Pipeline error: {e}") + import traceback + logger.debug(traceback.format_exc()) + result.error = str(e) + result.duration_seconds = round(time.time() - total_start, 2) + + return result diff --git a/klausur-service/backend/cv_review_spell.py b/klausur-service/backend/cv_review_spell.py new file mode 100644 index 0000000..5398a21 --- /dev/null +++ b/klausur-service/backend/cv_review_spell.py @@ -0,0 +1,315 @@ +""" +CV Review Spell — Rule-based OCR spell correction (no LLM). + +Provides dictionary-backed digit-to-letter substitution, umlaut correction, +general spell correction, merged-word splitting, and page-ref normalization. + +Lizenz: Apache 2.0 (kommerziell nutzbar) +DATENSCHUTZ: Alle Verarbeitung erfolgt lokal. +""" + +import logging +import re +import time +from typing import Dict, List, Optional, Tuple + +logger = logging.getLogger(__name__) + +try: + from spellchecker import SpellChecker as _SpellChecker + _en_spell = _SpellChecker(language='en', distance=1) + _de_spell = _SpellChecker(language='de', distance=1) + _SPELL_AVAILABLE = True + logger.info("pyspellchecker loaded (EN+DE)") +except ImportError: + _SPELL_AVAILABLE = False + _en_spell = None # type: ignore[assignment] + _de_spell = None # type: ignore[assignment] + logger.warning("pyspellchecker not installed") + + +# ---- Page-Ref Normalization ---- +# Normalizes OCR variants like "p-60", "p 61", "p60" -> "p.60" +_PAGE_REF_RE = re.compile(r'\bp[\s\-]?(\d+)', re.IGNORECASE) + + +def _normalize_page_ref(text: str) -> str: + """Normalize page references: 'p-60' / 'p 61' / 'p60' -> 'p.60'.""" + if not text: + return text + return _PAGE_REF_RE.sub(lambda m: f"p.{m.group(1)}", text) + + +# Suspicious OCR chars -> ordered list of most-likely correct replacements +_SPELL_SUBS: Dict[str, List[str]] = { + '0': ['O', 'o'], + '1': ['l', 'I'], + '5': ['S', 's'], + '6': ['G', 'g'], + '8': ['B', 'b'], + '|': ['I', 'l', '1'], +} +_SPELL_SUSPICIOUS = frozenset(_SPELL_SUBS.keys()) + +# Tokenizer: word tokens (letters + pipe) alternating with separators +_SPELL_TOKEN_RE = re.compile(r'([A-Za-z\u00c4\u00d6\u00dc\u00e4\u00f6\u00fc\u00df|]+)([^A-Za-z\u00c4\u00d6\u00dc\u00e4\u00f6\u00fc\u00df|]*)') + + +def _spell_dict_knows(word: str) -> bool: + """True if word is known in EN or DE dictionary.""" + if not _SPELL_AVAILABLE: + return False + w = word.lower() + return bool(_en_spell.known([w])) or bool(_de_spell.known([w])) + + +def _try_split_merged_word(token: str) -> Optional[str]: + """Try to split a merged word like 'atmyschool' into 'at my school'. + + Uses dynamic programming to find the shortest sequence of dictionary + words that covers the entire token. Only returns a result when the + split produces at least 2 words and ALL parts are known dictionary words. + + Preserves original capitalisation by mapping back to the input string. + """ + if not _SPELL_AVAILABLE or len(token) < 4: + return None + + lower = token.lower() + n = len(lower) + + # dp[i] = (word_lengths_list, score) for best split of lower[:i], or None + dp: list = [None] * (n + 1) + dp[0] = ([], 0) + + for i in range(1, n + 1): + for j in range(max(0, i - 20), i): + if dp[j] is None: + continue + candidate = lower[j:i] + word_len = i - j + if word_len == 1 and candidate not in ('a', 'i'): + continue + if _spell_dict_knows(candidate): + prev_words, prev_sq = dp[j] + new_words = prev_words + [word_len] + new_sq = prev_sq + word_len * word_len + new_key = (-len(new_words), new_sq) + if dp[i] is None: + dp[i] = (new_words, new_sq) + else: + old_key = (-len(dp[i][0]), dp[i][1]) + if new_key >= old_key: + dp[i] = (new_words, new_sq) + + if dp[n] is None or len(dp[n][0]) < 2: + return None + + result = [] + pos = 0 + for wlen in dp[n][0]: + result.append(token[pos:pos + wlen]) + pos += wlen + + logger.debug("Split merged word: %r -> %r", token, " ".join(result)) + return " ".join(result) + + +def _spell_fix_token(token: str, field: str = "") -> Optional[str]: + """Return corrected form of token, or None if no fix needed/possible. + + *field* is 'english' or 'german' -- used to pick the right dictionary. + """ + has_suspicious = any(ch in _SPELL_SUSPICIOUS for ch in token) + + # 1. Already known word -> no fix needed + if _spell_dict_knows(token): + return None + + # 2. Digit/pipe substitution + if has_suspicious: + if token == '|': + return 'I' + for i, ch in enumerate(token): + if ch not in _SPELL_SUBS: + continue + for replacement in _SPELL_SUBS[ch]: + candidate = token[:i] + replacement + token[i + 1:] + if _spell_dict_knows(candidate): + return candidate + first = token[0] + if first in _SPELL_SUBS and len(token) >= 2: + rest = token[1:] + if rest.isalpha() and rest.islower(): + candidate = _SPELL_SUBS[first][0] + rest + if not candidate[0].isdigit(): + return candidate + + # 3. OCR umlaut confusion + if len(token) >= 3 and token.isalpha() and field == "german": + _UMLAUT_SUBS = {'a': '\u00e4', 'o': '\u00f6', 'u': '\u00fc', 'i': '\u00fc', + 'A': '\u00c4', 'O': '\u00d6', 'U': '\u00dc', 'I': '\u00dc'} + for i, ch in enumerate(token): + if ch in _UMLAUT_SUBS: + candidate = token[:i] + _UMLAUT_SUBS[ch] + token[i + 1:] + if _spell_dict_knows(candidate): + return candidate + + # 4. General spell correction for unknown words (no digits/pipes) + if not has_suspicious and len(token) >= 3 and token.isalpha(): + spell = _en_spell if field == "english" else _de_spell if field == "german" else None + if spell is not None: + correction = spell.correction(token.lower()) + if correction and correction != token.lower(): + if token[0].isupper(): + correction = correction[0].upper() + correction[1:] + if _spell_dict_knows(correction): + return correction + + # 5. Merged-word split + if len(token) >= 4 and token.isalpha(): + split = _try_split_merged_word(token) + if split: + return split + + return None + + +def _spell_fix_field(text: str, field: str = "") -> Tuple[str, bool]: + """Apply OCR corrections to a text field. Returns (fixed_text, was_changed).""" + if not text: + return text, False + has_suspicious = any(ch in text for ch in _SPELL_SUSPICIOUS) + if not has_suspicious and not any(c.isalpha() for c in text): + return text, False + # Pattern: | immediately before . or , -> numbered list prefix + fixed = re.sub(r'(? Dict: + """Rule-based OCR correction: spell-checker + structural heuristics. + + Deterministic -- never translates, never touches IPA, never hallucinates. + Uses SmartSpellChecker for language-aware corrections with context-based + disambiguation (a/I), multi-digit substitution, and cross-language guard. + """ + from cv_review_llm import _entry_needs_review + + t0 = time.time() + changes: List[Dict] = [] + all_corrected: List[Dict] = [] + + # Use SmartSpellChecker if available + _smart = None + try: + from smart_spell import SmartSpellChecker + _smart = SmartSpellChecker() + logger.debug("spell_review: using SmartSpellChecker") + except Exception: + logger.debug("spell_review: SmartSpellChecker not available, using legacy") + + _LANG_MAP = {"english": "en", "german": "de", "example": "auto"} + + for i, entry in enumerate(entries): + e = dict(entry) + # Page-ref normalization + old_ref = (e.get("source_page") or "").strip() + if old_ref: + new_ref = _normalize_page_ref(old_ref) + if new_ref != old_ref: + changes.append({ + "row_index": e.get("row_index", i), + "field": "source_page", + "old": old_ref, + "new": new_ref, + }) + e["source_page"] = new_ref + e["llm_corrected"] = True + if not _entry_needs_review(e): + all_corrected.append(e) + continue + for field_name in ("english", "german", "example"): + old_val = (e.get(field_name) or "").strip() + if not old_val: + continue + + if _smart: + lang_code = _LANG_MAP.get(field_name, "en") + result = _smart.correct_text(old_val, lang=lang_code) + new_val = result.corrected + was_changed = result.changed + else: + lang = "german" if field_name in ("german", "example") else "english" + new_val, was_changed = _spell_fix_field(old_val, field=lang) + + if was_changed and new_val != old_val: + changes.append({ + "row_index": e.get("row_index", i), + "field": field_name, + "old": old_val, + "new": new_val, + }) + e[field_name] = new_val + e["llm_corrected"] = True + all_corrected.append(e) + duration_ms = int((time.time() - t0) * 1000) + model_name = "smart-spell-checker" if _smart else "spell-checker" + return { + "entries_original": entries, + "entries_corrected": all_corrected, + "changes": changes, + "skipped_count": 0, + "model_used": model_name, + "duration_ms": duration_ms, + } + + +async def spell_review_entries_streaming(entries: List[Dict], batch_size: int = 50): + """Async generator yielding SSE-compatible events for spell-checker review.""" + total = len(entries) + yield { + "type": "meta", + "total_entries": total, + "to_review": total, + "skipped": 0, + "model": "spell-checker", + "batch_size": batch_size, + } + result = spell_review_entries_sync(entries) + changes = result["changes"] + yield { + "type": "batch", + "batch_index": 0, + "entries_reviewed": [e.get("row_index", i) for i, e in enumerate(entries)], + "changes": changes, + "duration_ms": result["duration_ms"], + "progress": {"current": total, "total": total}, + } + yield { + "type": "complete", + "changes": changes, + "model_used": "spell-checker", + "duration_ms": result["duration_ms"], + "total_entries": total, + "reviewed": total, + "skipped": 0, + "corrections_found": len(changes), + "entries_corrected": result["entries_corrected"], + } diff --git a/klausur-service/backend/grid_editor_columns.py b/klausur-service/backend/grid_editor_columns.py new file mode 100644 index 0000000..6731798 --- /dev/null +++ b/klausur-service/backend/grid_editor_columns.py @@ -0,0 +1,492 @@ +""" +Grid Editor — column detection, cross-column splitting, marker merging. + +Split from grid_editor_helpers.py for maintainability. +All functions are pure computation — no HTTP, DB, or session side effects. + +Lizenz: Apache 2.0 (kommerziell nutzbar) +DATENSCHUTZ: Alle Verarbeitung erfolgt lokal. +""" + +import logging +import re +from typing import Any, Dict, List, Optional + +logger = logging.getLogger(__name__) + + +# --------------------------------------------------------------------------- +# Cross-column word splitting +# --------------------------------------------------------------------------- + +_spell_cache: Optional[Any] = None +_spell_loaded = False + + +def _is_recognized_word(text: str) -> bool: + """Check if *text* is a recognized German or English word. + + Uses the spellchecker library (same as cv_syllable_detect.py). + Returns True for real words like "oder", "Kabel", "Zeitung". + Returns False for OCR merge artifacts like "sichzie", "dasZimmer". + """ + global _spell_cache, _spell_loaded + if not text or len(text) < 2: + return False + + if not _spell_loaded: + _spell_loaded = True + try: + from spellchecker import SpellChecker + _spell_cache = SpellChecker(language="de") + except Exception: + pass + + if _spell_cache is None: + return False + + return text.lower() in _spell_cache + + +def _split_cross_column_words( + words: List[Dict], + columns: List[Dict], +) -> List[Dict]: + """Split word boxes that span across column boundaries. + + When OCR merges adjacent words from different columns (e.g. "sichzie" + spanning Col 1 and Col 2, or "dasZimmer" crossing the boundary), + split the word box at the column boundary so each piece is assigned + to the correct column. + + Only splits when: + - The word has significant overlap (>15% of its width) on both sides + - AND the word is not a recognized real word (OCR merge artifact), OR + the word contains a case transition (lowercase->uppercase) near the + boundary indicating two merged words like "dasZimmer". + """ + if len(columns) < 2: + return words + + # Column boundaries = midpoints between adjacent column edges + boundaries = [] + for i in range(len(columns) - 1): + boundary = (columns[i]["x_max"] + columns[i + 1]["x_min"]) / 2 + boundaries.append(boundary) + + new_words: List[Dict] = [] + split_count = 0 + for w in words: + w_left = w["left"] + w_width = w["width"] + w_right = w_left + w_width + text = (w.get("text") or "").strip() + + if not text or len(text) < 4 or w_width < 10: + new_words.append(w) + continue + + # Find the first boundary this word straddles significantly + split_boundary = None + for b in boundaries: + if w_left < b < w_right: + left_part = b - w_left + right_part = w_right - b + # Both sides must have at least 15% of the word width + if left_part > w_width * 0.15 and right_part > w_width * 0.15: + split_boundary = b + break + + if split_boundary is None: + new_words.append(w) + continue + + # Compute approximate split position in the text. + left_width = split_boundary - w_left + split_ratio = left_width / w_width + approx_pos = len(text) * split_ratio + + # Strategy 1: look for a case transition (lowercase->uppercase) near + # the approximate split point — e.g. "dasZimmer" splits at 'Z'. + split_char = None + search_lo = max(1, int(approx_pos) - 3) + search_hi = min(len(text), int(approx_pos) + 2) + for i in range(search_lo, search_hi): + if text[i - 1].islower() and text[i].isupper(): + split_char = i + break + + # Strategy 2: if no case transition, only split if the whole word + # is NOT a real word (i.e. it's an OCR merge artifact like "sichzie"). + # Real words like "oder", "Kabel", "Zeitung" must not be split. + if split_char is None: + clean = re.sub(r"[,;:.!?]+$", "", text) # strip trailing punct + if _is_recognized_word(clean): + new_words.append(w) + continue + # Not a real word — use floor of proportional position + split_char = max(1, min(len(text) - 1, int(approx_pos))) + + left_text = text[:split_char].rstrip() + right_text = text[split_char:].lstrip() + + if len(left_text) < 2 or len(right_text) < 2: + new_words.append(w) + continue + + right_width = w_width - round(left_width) + new_words.append({ + **w, + "text": left_text, + "width": round(left_width), + }) + new_words.append({ + **w, + "text": right_text, + "left": round(split_boundary), + "width": right_width, + }) + split_count += 1 + logger.info( + "split cross-column word %r -> %r + %r at boundary %.0f", + text, left_text, right_text, split_boundary, + ) + + if split_count: + logger.info("split %d cross-column word(s)", split_count) + return new_words + + +def _cluster_columns_by_alignment( + words: List[Dict], + zone_w: int, + rows: List[Dict], +) -> List[Dict[str, Any]]: + """Detect columns by clustering left-edge alignment across rows. + + Hybrid approach: + 1. Group words by row, find "group start" positions within each row + (words preceded by a large gap or first word in row) + 2. Cluster group-start left-edges by X-proximity across rows + 3. Filter by row coverage (how many rows have a group start here) + 4. Merge nearby clusters + 5. Build column boundaries + + This filters out mid-phrase word positions (e.g. IPA transcriptions, + second words in multi-word entries) by only considering positions + where a new word group begins within a row. + """ + if not words or not rows: + return [] + + total_rows = len(rows) + if total_rows == 0: + return [] + + # --- Group words by row --- + row_words: Dict[int, List[Dict]] = {} + for w in words: + y_center = w["top"] + w["height"] / 2 + best = min(rows, key=lambda r: abs(r["y_center"] - y_center)) + row_words.setdefault(best["index"], []).append(w) + + # --- Compute adaptive gap threshold for group-start detection --- + all_gaps: List[float] = [] + for ri, rw_list in row_words.items(): + sorted_rw = sorted(rw_list, key=lambda w: w["left"]) + for i in range(len(sorted_rw) - 1): + right = sorted_rw[i]["left"] + sorted_rw[i]["width"] + gap = sorted_rw[i + 1]["left"] - right + if gap > 0: + all_gaps.append(gap) + + if all_gaps: + sorted_gaps = sorted(all_gaps) + median_gap = sorted_gaps[len(sorted_gaps) // 2] + heights = [w["height"] for w in words if w.get("height", 0) > 0] + median_h = sorted(heights)[len(heights) // 2] if heights else 25 + + # For small word counts (boxes, sub-zones): PaddleOCR returns + # multi-word blocks, so ALL inter-word gaps are potential column + # boundaries. Use a low threshold based on word height — any gap + # wider than ~1x median word height is a column separator. + if len(words) <= 60: + gap_threshold = max(median_h * 1.0, 25) + logger.info( + "alignment columns (small zone): gap_threshold=%.0f " + "(median_h=%.0f, %d words, %d gaps: %s)", + gap_threshold, median_h, len(words), len(sorted_gaps), + [int(g) for g in sorted_gaps[:10]], + ) + else: + # Standard approach for large zones (full pages) + gap_threshold = max(median_gap * 3, median_h * 1.5, 30) + # Cap at 25% of zone width + max_gap = zone_w * 0.25 + if gap_threshold > max_gap > 30: + logger.info("alignment columns: capping gap_threshold %.0f -> %.0f (25%% of zone_w=%d)", gap_threshold, max_gap, zone_w) + gap_threshold = max_gap + else: + gap_threshold = 50 + + # --- Find group-start positions (left-edges that begin a new column) --- + start_positions: List[tuple] = [] # (left_edge, row_index) + for ri, rw_list in row_words.items(): + sorted_rw = sorted(rw_list, key=lambda w: w["left"]) + # First word in row is always a group start + start_positions.append((sorted_rw[0]["left"], ri)) + for i in range(1, len(sorted_rw)): + right_prev = sorted_rw[i - 1]["left"] + sorted_rw[i - 1]["width"] + gap = sorted_rw[i]["left"] - right_prev + if gap >= gap_threshold: + start_positions.append((sorted_rw[i]["left"], ri)) + + start_positions.sort(key=lambda x: x[0]) + + logger.info( + "alignment columns: %d group-start positions from %d words " + "(gap_threshold=%.0f, %d rows)", + len(start_positions), len(words), gap_threshold, total_rows, + ) + + if not start_positions: + x_min = min(w["left"] for w in words) + x_max = max(w["left"] + w["width"] for w in words) + return [{"index": 0, "type": "column_text", "x_min": x_min, "x_max": x_max}] + + # --- Cluster group-start positions by X-proximity --- + tolerance = max(10, int(zone_w * 0.01)) + clusters: List[Dict[str, Any]] = [] + cur_edges = [start_positions[0][0]] + cur_rows = {start_positions[0][1]} + + for left, row_idx in start_positions[1:]: + if left - cur_edges[-1] <= tolerance: + cur_edges.append(left) + cur_rows.add(row_idx) + else: + clusters.append({ + "mean_x": int(sum(cur_edges) / len(cur_edges)), + "min_edge": min(cur_edges), + "max_edge": max(cur_edges), + "count": len(cur_edges), + "distinct_rows": len(cur_rows), + "row_coverage": len(cur_rows) / total_rows, + }) + cur_edges = [left] + cur_rows = {row_idx} + clusters.append({ + "mean_x": int(sum(cur_edges) / len(cur_edges)), + "min_edge": min(cur_edges), + "max_edge": max(cur_edges), + "count": len(cur_edges), + "distinct_rows": len(cur_rows), + "row_coverage": len(cur_rows) / total_rows, + }) + + # --- Filter by row coverage --- + # These thresholds must be high enough to avoid false columns in flowing + # text (random inter-word gaps) while still detecting real columns in + # vocabulary worksheets (which typically have >80% row coverage). + MIN_COVERAGE_PRIMARY = 0.35 + MIN_COVERAGE_SECONDARY = 0.12 + MIN_WORDS_SECONDARY = 4 + MIN_DISTINCT_ROWS = 3 + + # Content boundary for left-margin detection + content_x_min = min(w["left"] for w in words) + content_x_max = max(w["left"] + w["width"] for w in words) + content_span = content_x_max - content_x_min + + primary = [ + c for c in clusters + if c["row_coverage"] >= MIN_COVERAGE_PRIMARY + and c["distinct_rows"] >= MIN_DISTINCT_ROWS + ] + primary_ids = {id(c) for c in primary} + secondary = [ + c for c in clusters + if id(c) not in primary_ids + and c["row_coverage"] >= MIN_COVERAGE_SECONDARY + and c["count"] >= MIN_WORDS_SECONDARY + and c["distinct_rows"] >= MIN_DISTINCT_ROWS + ] + + # Tertiary: narrow left-margin columns (page refs, markers) that have + # too few rows for secondary but are clearly left-aligned and separated + # from the main content. These appear at the far left or far right and + # have a large gap to the nearest significant cluster. + used_ids = {id(c) for c in primary} | {id(c) for c in secondary} + sig_xs = [c["mean_x"] for c in primary + secondary] + + # Tertiary: clusters that are clearly to the LEFT of the first + # significant column (or RIGHT of the last). If words consistently + # start at a position left of the established first column boundary, + # they MUST be a separate column — regardless of how few rows they + # cover. The only requirement is a clear spatial gap. + MIN_COVERAGE_TERTIARY = 0.02 # at least 1 row effectively + tertiary = [] + for c in clusters: + if id(c) in used_ids: + continue + if c["distinct_rows"] < 1: + continue + if c["row_coverage"] < MIN_COVERAGE_TERTIARY: + continue + # Must be near left or right content margin (within 15%) + rel_pos = (c["mean_x"] - content_x_min) / content_span if content_span else 0.5 + if not (rel_pos < 0.15 or rel_pos > 0.85): + continue + # Must have significant gap to nearest significant cluster + if sig_xs: + min_dist = min(abs(c["mean_x"] - sx) for sx in sig_xs) + if min_dist < max(30, content_span * 0.02): + continue + tertiary.append(c) + + if tertiary: + for c in tertiary: + logger.info( + " tertiary (margin) cluster: x=%d (range %d-%d), %d words, %d rows (%.0f%%)", + c["mean_x"], c["min_edge"], c["max_edge"], + c["count"], c["distinct_rows"], c["row_coverage"] * 100, + ) + + significant = sorted(primary + secondary + tertiary, key=lambda c: c["mean_x"]) + + for c in significant: + logger.info( + " significant cluster: x=%d (range %d-%d), %d words, %d rows (%.0f%%)", + c["mean_x"], c["min_edge"], c["max_edge"], + c["count"], c["distinct_rows"], c["row_coverage"] * 100, + ) + logger.info( + "alignment columns: %d clusters, %d primary, %d secondary -> %d significant", + len(clusters), len(primary), len(secondary), len(significant), + ) + + if not significant: + # Fallback: single column covering all content + x_min = min(w["left"] for w in words) + x_max = max(w["left"] + w["width"] for w in words) + return [{"index": 0, "type": "column_text", "x_min": x_min, "x_max": x_max}] + + # --- Merge nearby clusters --- + merge_distance = max(25, int(zone_w * 0.03)) + merged = [significant[0].copy()] + for s in significant[1:]: + if s["mean_x"] - merged[-1]["mean_x"] < merge_distance: + prev = merged[-1] + total = prev["count"] + s["count"] + prev["mean_x"] = ( + prev["mean_x"] * prev["count"] + s["mean_x"] * s["count"] + ) // total + prev["count"] = total + prev["min_edge"] = min(prev["min_edge"], s["min_edge"]) + prev["max_edge"] = max(prev["max_edge"], s["max_edge"]) + prev["distinct_rows"] = max(prev["distinct_rows"], s["distinct_rows"]) + else: + merged.append(s.copy()) + + logger.info( + "alignment columns: %d after merge (distance=%d)", + len(merged), merge_distance, + ) + + # --- Build column boundaries --- + margin = max(5, int(zone_w * 0.005)) + content_x_min = min(w["left"] for w in words) + content_x_max = max(w["left"] + w["width"] for w in words) + + columns: List[Dict[str, Any]] = [] + for i, cluster in enumerate(merged): + x_min = max(content_x_min, cluster["min_edge"] - margin) + if i + 1 < len(merged): + x_max = merged[i + 1]["min_edge"] - margin + else: + x_max = content_x_max + + columns.append({ + "index": i, + "type": f"column_{i + 1}" if len(merged) > 1 else "column_text", + "x_min": x_min, + "x_max": x_max, + }) + + return columns + + +_MARKER_CHARS = set("*-+#>") + + +def _merge_inline_marker_columns( + columns: List[Dict], + words: List[Dict], +) -> List[Dict]: + """Merge narrow marker columns (bullets, numbering) into adjacent text. + + Bullet points (*, -) and numbering (1., 2.) create narrow columns + at the left edge of a zone. These are inline markers that indent text, + not real separate columns. Merge them with their right neighbour. + + Does NOT merge columns containing alphabetic words like "to", "in", + "der", "die", "das" — those are legitimate content columns. + """ + if len(columns) < 2: + return columns + + merged: List[Dict] = [] + skip: set = set() + + for i, col in enumerate(columns): + if i in skip: + continue + + # Find words in this column + col_words = [ + w for w in words + if col["x_min"] <= w["left"] + w["width"] / 2 < col["x_max"] + ] + col_width = col["x_max"] - col["x_min"] + + # Narrow column with mostly short words -> MIGHT be inline markers + if col_words and col_width < 80: + avg_len = sum(len(w.get("text", "")) for w in col_words) / len(col_words) + if avg_len <= 2 and i + 1 < len(columns): + # Check if words are actual markers (symbols/numbers) vs + # real alphabetic words like "to", "in", "der", "die" + texts = [(w.get("text") or "").strip() for w in col_words] + alpha_count = sum( + 1 for t in texts + if t and t[0].isalpha() and t not in _MARKER_CHARS + ) + alpha_ratio = alpha_count / len(texts) if texts else 0 + + # If >=50% of words are alphabetic, this is a real column + if alpha_ratio >= 0.5: + logger.info( + " kept narrow column %d (w=%d, avg_len=%.1f, " + "alpha=%.0f%%) -- contains real words", + i, col_width, avg_len, alpha_ratio * 100, + ) + else: + # Merge into next column + next_col = columns[i + 1].copy() + next_col["x_min"] = col["x_min"] + merged.append(next_col) + skip.add(i + 1) + logger.info( + " merged inline marker column %d (w=%d, avg_len=%.1f) " + "into column %d", + i, col_width, avg_len, i + 1, + ) + continue + + merged.append(col) + + # Re-index + for i, col in enumerate(merged): + col["index"] = i + col["type"] = f"column_{i + 1}" if len(merged) > 1 else "column_text" + + return merged diff --git a/klausur-service/backend/grid_editor_filters.py b/klausur-service/backend/grid_editor_filters.py new file mode 100644 index 0000000..c938569 --- /dev/null +++ b/klausur-service/backend/grid_editor_filters.py @@ -0,0 +1,402 @@ +""" +Grid Editor — word/zone filtering, border ghosts, decorative margins, footers. + +Split from grid_editor_helpers.py for maintainability. +All functions are pure computation — no HTTP, DB, or session side effects. + +Lizenz: Apache 2.0 (kommerziell nutzbar) +DATENSCHUTZ: Alle Verarbeitung erfolgt lokal. +""" + +import logging +from typing import Any, Dict, List, Optional, Tuple + +logger = logging.getLogger(__name__) + + +def _filter_border_strip_words(words: List[Dict]) -> Tuple[List[Dict], int]: + """Remove page-border decoration strip words BEFORE column detection. + + Scans from each page edge inward to find the first significant x-gap + (>30 px). If the edge cluster contains <15 % of total words, those + words are removed as border-strip artifacts (alphabet letters, + illustration fragments). + + Must run BEFORE ``_build_zone_grid`` so that column detection only + sees real content words and doesn't produce inflated row counts. + """ + if len(words) < 10: + return words, 0 + + sorted_words = sorted(words, key=lambda w: w.get("left", 0)) + total = len(sorted_words) + + # -- Left-edge scan (running max right-edge) -- + left_count = 0 + running_right = 0 + for gi in range(total - 1): + running_right = max( + running_right, + sorted_words[gi].get("left", 0) + sorted_words[gi].get("width", 0), + ) + if sorted_words[gi + 1].get("left", 0) - running_right > 30: + left_count = gi + 1 + break + + # -- Right-edge scan (running min left) -- + right_count = 0 + running_left = sorted_words[-1].get("left", 0) + for gi in range(total - 1, 0, -1): + running_left = min(running_left, sorted_words[gi].get("left", 0)) + prev_right = ( + sorted_words[gi - 1].get("left", 0) + + sorted_words[gi - 1].get("width", 0) + ) + if running_left - prev_right > 30: + right_count = total - gi + break + + # Validate candidate strip: real border decorations are mostly short + # words (alphabet letters like "A", "Bb", stray marks). Multi-word + # content like "der Ranzen" or "die Schals" (continuation of German + # translations) must NOT be removed. + def _is_decorative_strip(candidates: List[Dict]) -> bool: + if not candidates: + return False + short = sum(1 for w in candidates if len((w.get("text") or "").strip()) <= 2) + return short / len(candidates) >= 0.45 + + strip_ids: set = set() + if left_count > 0 and left_count / total < 0.20: + candidates = sorted_words[:left_count] + if _is_decorative_strip(candidates): + strip_ids = {id(w) for w in candidates} + elif right_count > 0 and right_count / total < 0.20: + candidates = sorted_words[total - right_count:] + if _is_decorative_strip(candidates): + strip_ids = {id(w) for w in candidates} + + if not strip_ids: + return words, 0 + + return [w for w in words if id(w) not in strip_ids], len(strip_ids) + + +# Characters that are typically OCR artefacts from box border lines. +# Intentionally excludes ! (red markers) and . , ; (real punctuation). +_GRID_GHOST_CHARS = set("|1lI[](){}/\\-\u2014\u2013_~=+") + + +def _filter_border_ghosts( + words: List[Dict], + boxes: List, +) -> tuple: + """Remove words sitting on box borders that are OCR artefacts. + + Returns (filtered_words, removed_count). + """ + if not boxes or not words: + return words, 0 + + # Build border bands from detected boxes + x_bands: List[tuple] = [] + y_bands: List[tuple] = [] + for b in boxes: + bt = ( + b.border_thickness + if hasattr(b, "border_thickness") + else b.get("border_thickness", 3) + ) + # Skip borderless boxes (images/graphics) -- no border line to produce ghosts + if bt == 0: + continue + bx = b.x if hasattr(b, "x") else b.get("x", 0) + by = b.y if hasattr(b, "y") else b.get("y", 0) + bw = b.width if hasattr(b, "width") else b.get("w", b.get("width", 0)) + bh = b.height if hasattr(b, "height") else b.get("h", b.get("height", 0)) + margin = max(bt * 2, 10) + 6 + x_bands.append((bx - margin, bx + margin)) + x_bands.append((bx + bw - margin, bx + bw + margin)) + y_bands.append((by - margin, by + margin)) + y_bands.append((by + bh - margin, by + bh + margin)) + + def _is_ghost(w: Dict) -> bool: + text = (w.get("text") or "").strip() + if not text: + return False + # Check if any word edge (not just center) touches a border band + w_left = w["left"] + w_right = w["left"] + w["width"] + w_top = w["top"] + w_bottom = w["top"] + w["height"] + on_border = ( + any(lo <= w_left <= hi or lo <= w_right <= hi for lo, hi in x_bands) + or any(lo <= w_top <= hi or lo <= w_bottom <= hi for lo, hi in y_bands) + ) + if not on_border: + return False + if len(text) == 1 and text in _GRID_GHOST_CHARS: + return True + return False + + filtered = [w for w in words if not _is_ghost(w)] + return filtered, len(words) - len(filtered) + + +def _flatten_word_boxes(cells: List[Dict]) -> List[Dict]: + """Extract all word_boxes from cells into a flat list of word dicts.""" + words: List[Dict] = [] + for cell in cells: + for wb in cell.get("word_boxes") or []: + if wb.get("text", "").strip(): + words.append({ + "text": wb["text"], + "left": wb["left"], + "top": wb["top"], + "width": wb["width"], + "height": wb["height"], + "conf": wb.get("conf", 0), + }) + return words + + +def _words_in_zone( + words: List[Dict], + zone_y: int, + zone_h: int, + zone_x: int, + zone_w: int, +) -> List[Dict]: + """Filter words whose Y-center falls within a zone's bounds.""" + zone_y_end = zone_y + zone_h + zone_x_end = zone_x + zone_w + result = [] + for w in words: + cy = w["top"] + w["height"] / 2 + cx = w["left"] + w["width"] / 2 + if zone_y <= cy <= zone_y_end and zone_x <= cx <= zone_x_end: + result.append(w) + return result + + +def _get_content_bounds(words: List[Dict]) -> tuple: + """Get content bounds from word positions.""" + if not words: + return 0, 0, 0, 0 + x_min = min(w["left"] for w in words) + y_min = min(w["top"] for w in words) + x_max = max(w["left"] + w["width"] for w in words) + y_max = max(w["top"] + w["height"] for w in words) + return x_min, y_min, x_max - x_min, y_max - y_min + + +def _filter_decorative_margin( + words: List[Dict], + img_w: int, + log: Any, + session_id: str, +) -> Dict[str, Any]: + """Remove words that belong to a decorative alphabet strip on a margin. + + Some vocabulary worksheets have a vertical A-Z alphabet graphic along + the left or right edge. OCR reads each letter as an isolated single- + character word. These decorative elements are not content and confuse + column/row detection. + + Detection criteria (phase 1 -- find the strip using single-char words): + - Words are in the outer 30% of the page (left or right) + - Nearly all words are single characters (letters or digits) + - At least 8 such words form a vertical strip (>=8 unique Y positions) + - Average horizontal spread of the strip is small (< 80px) + + Phase 2 -- once a strip is confirmed, also remove any short word (<=3 + chars) in the same narrow x-range. This catches multi-char OCR + artifacts like "Vv" that belong to the same decorative element. + + Modifies *words* in place. + + Returns: + Dict with 'found' (bool), 'side' (str), 'letters_detected' (int). + """ + no_strip: Dict[str, Any] = {"found": False, "side": "", "letters_detected": 0} + if not words or img_w <= 0: + return no_strip + + margin_cutoff = img_w * 0.30 + # Phase 1: find candidate strips using short words (1-2 chars). + # OCR often reads alphabet sidebar letters as pairs ("Aa", "Bb") + # rather than singles, so accept <=2-char words as strip candidates. + left_strip = [ + w for w in words + if len((w.get("text") or "").strip()) <= 2 + and w["left"] + w.get("width", 0) / 2 < margin_cutoff + ] + right_strip = [ + w for w in words + if len((w.get("text") or "").strip()) <= 2 + and w["left"] + w.get("width", 0) / 2 > img_w - margin_cutoff + ] + + for strip, side in [(left_strip, "left"), (right_strip, "right")]: + if len(strip) < 6: + continue + # Check vertical distribution: should have many distinct Y positions + y_centers = sorted(set( + int(w["top"] + w.get("height", 0) / 2) // 20 * 20 # bucket + for w in strip + )) + if len(y_centers) < 6: + continue + # Check horizontal compactness + x_positions = [w["left"] for w in strip] + x_min = min(x_positions) + x_max = max(x_positions) + x_spread = x_max - x_min + if x_spread > 80: + continue + + # Phase 2: strip confirmed -- also collect short words in same x-range + # Expand x-range slightly to catch neighbors (e.g. "Vv" next to "U") + strip_x_lo = x_min - 20 + strip_x_hi = x_max + 60 # word width + tolerance + all_strip_words = [ + w for w in words + if len((w.get("text") or "").strip()) <= 3 + and strip_x_lo <= w["left"] <= strip_x_hi + and (w["left"] + w.get("width", 0) / 2 < margin_cutoff + if side == "left" + else w["left"] + w.get("width", 0) / 2 > img_w - margin_cutoff) + ] + + strip_set = set(id(w) for w in all_strip_words) + before = len(words) + words[:] = [w for w in words if id(w) not in strip_set] + removed = before - len(words) + if removed: + log.info( + "build-grid session %s: removed %d decorative %s-margin words " + "(strip x=%d-%d)", + session_id, removed, side, strip_x_lo, strip_x_hi, + ) + return {"found": True, "side": side, "letters_detected": len(strip)} + + return no_strip + + +def _filter_footer_words( + words: List[Dict], + img_h: int, + log: Any, + session_id: str, +) -> Optional[Dict]: + """Remove isolated words in the bottom 5% of the page (page numbers). + + Modifies *words* in place and returns a page_number metadata dict + if a page number was extracted, or None. + """ + if not words or img_h <= 0: + return None + footer_y = img_h * 0.95 + footer_words = [ + w for w in words + if w["top"] + w.get("height", 0) / 2 > footer_y + ] + if not footer_words: + return None + # Only remove if footer has very few words (<= 3) with short text + total_text = "".join((w.get("text") or "").strip() for w in footer_words) + if len(footer_words) <= 3 and len(total_text) <= 10: + # Extract page number metadata before removing + page_number_info = { + "text": total_text.strip(), + "y_pct": round(footer_words[0]["top"] / img_h * 100, 1), + } + # Try to parse as integer + digits = "".join(c for c in total_text if c.isdigit()) + if digits: + page_number_info["number"] = int(digits) + + footer_set = set(id(w) for w in footer_words) + words[:] = [w for w in words if id(w) not in footer_set] + log.info( + "build-grid session %s: extracted page number '%s' and removed %d footer words", + session_id, total_text, len(footer_words), + ) + return page_number_info + return None + + +def _filter_header_junk( + words: List[Dict], + img_h: int, + log: Any, + session_id: str, +) -> None: + """Remove OCR junk from header illustrations above the real content. + + Textbook pages often have decorative header graphics (illustrations, + icons) that OCR reads as low-confidence junk characters. Real content + typically starts further down the page. + + Algorithm: + 1. Find the "content start" -- the first Y position where a dense + horizontal row of 3+ high-confidence words begins. + 2. Above that line, remove words with conf < 75 and text <= 3 chars. + These are almost certainly OCR artifacts from illustrations. + + Modifies *words* in place. + """ + if not words or img_h <= 0: + return + + # --- Find content start: first horizontal row with >=3 high-conf words --- + # Sort words by Y + sorted_by_y = sorted(words, key=lambda w: w["top"]) + content_start_y = 0 + _ROW_TOLERANCE = img_h * 0.02 # words within 2% of page height = same row + _MIN_ROW_WORDS = 3 + _MIN_CONF = 80 + + i = 0 + while i < len(sorted_by_y): + row_y = sorted_by_y[i]["top"] + # Collect words in this row band + row_words = [] + j = i + while j < len(sorted_by_y) and sorted_by_y[j]["top"] - row_y < _ROW_TOLERANCE: + row_words.append(sorted_by_y[j]) + j += 1 + # Count high-confidence words with real text (> 1 char) + high_conf = [ + w for w in row_words + if w.get("conf", 0) >= _MIN_CONF + and len((w.get("text") or "").strip()) > 1 + ] + if len(high_conf) >= _MIN_ROW_WORDS: + content_start_y = row_y + break + i = j if j > i else i + 1 + + if content_start_y <= 0: + return # no clear content start found + + # --- Remove low-conf short junk above content start --- + junk = [ + w for w in words + if w["top"] + w.get("height", 0) < content_start_y + and w.get("conf", 0) < 75 + and len((w.get("text") or "").strip()) <= 3 + ] + if not junk: + return + + junk_set = set(id(w) for w in junk) + before = len(words) + words[:] = [w for w in words if id(w) not in junk_set] + removed = before - len(words) + if removed: + log.info( + "build-grid session %s: removed %d header junk words above y=%d " + "(content start)", + session_id, removed, content_start_y, + ) diff --git a/klausur-service/backend/grid_editor_headers.py b/klausur-service/backend/grid_editor_headers.py new file mode 100644 index 0000000..3096e59 --- /dev/null +++ b/klausur-service/backend/grid_editor_headers.py @@ -0,0 +1,499 @@ +""" +Grid Editor — header/heading detection and colspan (merged cell) detection. +Split from grid_editor_helpers.py. Pure computation, no HTTP/DB side effects. +Lizenz: Apache 2.0 | DATENSCHUTZ: Alle Verarbeitung erfolgt lokal. +""" + +import logging +import re +from typing import Any, Dict, List, Optional + +from cv_ocr_engines import _text_has_garbled_ipa + +logger = logging.getLogger(__name__) + + +def _detect_heading_rows_by_color(zones_data: List[Dict], img_w: int, img_h: int) -> int: + """Detect heading rows by color + height after color annotation. + + A row is a heading if: + 1. ALL word_boxes have color_name != 'black' (typically 'blue') + 2. Mean word height > 1.2x median height of all words in the zone + + Detected heading rows are merged into a single spanning cell. + Returns count of headings detected. + """ + heading_count = 0 + + for z in zones_data: + cells = z.get("cells", []) + rows = z.get("rows", []) + columns = z.get("columns", []) + if not cells or not rows or len(columns) < 2: + continue + + # Compute median word height across the zone + all_heights = [] + for cell in cells: + for wb in cell.get("word_boxes") or []: + h = wb.get("height", 0) + if h > 0: + all_heights.append(h) + if not all_heights: + continue + all_heights_sorted = sorted(all_heights) + median_h = all_heights_sorted[len(all_heights_sorted) // 2] + + heading_row_indices = [] + for row in rows: + if row.get("is_header"): + continue # already detected as header + ri = row["index"] + row_cells = [c for c in cells if c.get("row_index") == ri] + row_wbs = [ + wb for cell in row_cells + for wb in cell.get("word_boxes") or [] + ] + if not row_wbs: + continue + + # Condition 1: ALL words are non-black + all_colored = all( + wb.get("color_name", "black") != "black" + for wb in row_wbs + ) + if not all_colored: + continue + + # Condition 2: mean height > 1.2x median + mean_h = sum(wb.get("height", 0) for wb in row_wbs) / len(row_wbs) + if mean_h <= median_h * 1.2: + continue + + heading_row_indices.append(ri) + + # Merge heading cells into spanning cells + for hri in heading_row_indices: + header_cells = [c for c in cells if c.get("row_index") == hri] + if len(header_cells) <= 1: + # Single cell -- just mark it as heading + if header_cells: + header_cells[0]["col_type"] = "heading" + heading_count += 1 + # Mark row as header + for row in rows: + if row["index"] == hri: + row["is_header"] = True + continue + + # Collect all word_boxes and text from all columns + all_wb = [] + all_text_parts = [] + for hc in sorted(header_cells, key=lambda c: c["col_index"]): + all_wb.extend(hc.get("word_boxes", [])) + if hc.get("text", "").strip(): + all_text_parts.append(hc["text"].strip()) + + # Remove all cells for this row, replace with one spanning cell + z["cells"] = [c for c in z["cells"] if c.get("row_index") != hri] + + if all_wb: + x_min = min(wb["left"] for wb in all_wb) + y_min = min(wb["top"] for wb in all_wb) + x_max = max(wb["left"] + wb["width"] for wb in all_wb) + y_max = max(wb["top"] + wb["height"] for wb in all_wb) + + # Use the actual starting col_index from the first cell + first_col = min(hc["col_index"] for hc in header_cells) + zone_idx = z.get("zone_index", 0) + z["cells"].append({ + "cell_id": f"Z{zone_idx}_R{hri:02d}_C{first_col}", + "zone_index": zone_idx, + "row_index": hri, + "col_index": first_col, + "col_type": "heading", + "text": " ".join(all_text_parts), + "confidence": 0.0, + "bbox_px": {"x": x_min, "y": y_min, + "w": x_max - x_min, "h": y_max - y_min}, + "bbox_pct": { + "x": round(x_min / img_w * 100, 2) if img_w else 0, + "y": round(y_min / img_h * 100, 2) if img_h else 0, + "w": round((x_max - x_min) / img_w * 100, 2) if img_w else 0, + "h": round((y_max - y_min) / img_h * 100, 2) if img_h else 0, + }, + "word_boxes": all_wb, + "ocr_engine": "words_first", + "is_bold": True, + }) + + # Mark row as header + for row in rows: + if row["index"] == hri: + row["is_header"] = True + heading_count += 1 + + return heading_count + + +def _detect_heading_rows_by_single_cell( + zones_data: List[Dict], img_w: int, img_h: int, +) -> int: + """Detect heading rows that have only a single content cell. + + Black headings like "Theme" have normal color and height, so they are + missed by ``_detect_heading_rows_by_color``. The distinguishing signal + is that they occupy only one column while normal vocabulary rows fill + at least 2-3 columns. + + A row qualifies as a heading if: + 1. It is not already marked as a header/heading. + 2. It has exactly ONE cell whose col_type starts with ``column_`` + (excluding column_1 / page_ref which only carries page numbers). + 3. That single cell is NOT in the last column (continuation/example + lines like "2. Ver\u00e4nderung, Wechsel" often sit alone in column_4). + 4. The text does not start with ``[`` (IPA continuation). + 5. The zone has >=3 columns and >=5 rows (avoids false positives in + tiny zones). + 6. The majority of rows in the zone have >=2 content cells (ensures + we are in a multi-column vocab layout). + """ + heading_count = 0 + + for z in zones_data: + cells = z.get("cells", []) + rows = z.get("rows", []) + columns = z.get("columns", []) + if len(columns) < 3 or len(rows) < 5: + continue + + # Determine the last col_index (example/sentence column) + col_indices = sorted(set(c.get("col_index", 0) for c in cells)) + if not col_indices: + continue + last_col = col_indices[-1] + + # Count content cells per row (column_* but not column_1/page_ref). + # Exception: column_1 cells that contain a dictionary article word + # (die/der/das etc.) ARE content -- they appear in dictionary layouts + # where the leftmost column holds grammatical articles. + _ARTICLE_WORDS = { + "die", "der", "das", "dem", "den", "des", "ein", "eine", + "the", "a", "an", + } + row_content_counts: Dict[int, int] = {} + for cell in cells: + ct = cell.get("col_type", "") + if not ct.startswith("column_"): + continue + if ct == "column_1": + ctext = (cell.get("text") or "").strip().lower() + if ctext not in _ARTICLE_WORDS: + continue + ri = cell.get("row_index", -1) + row_content_counts[ri] = row_content_counts.get(ri, 0) + 1 + + # Majority of rows must have >=2 content cells + multi_col_rows = sum(1 for cnt in row_content_counts.values() if cnt >= 2) + if multi_col_rows < len(rows) * 0.4: + continue + + # Exclude first and last non-header rows -- these are typically + # page numbers or footer text, not headings. + non_header_rows = [r for r in rows if not r.get("is_header")] + if len(non_header_rows) < 3: + continue + first_ri = non_header_rows[0]["index"] + last_ri = non_header_rows[-1]["index"] + + heading_row_indices = [] + for row in rows: + if row.get("is_header"): + continue + ri = row["index"] + if ri == first_ri or ri == last_ri: + continue + row_cells = [c for c in cells if c.get("row_index") == ri] + content_cells = [ + c for c in row_cells + if c.get("col_type", "").startswith("column_") + and (c.get("col_type") != "column_1" + or (c.get("text") or "").strip().lower() in _ARTICLE_WORDS) + ] + if len(content_cells) != 1: + continue + cell = content_cells[0] + # Not in the last column (continuation/example lines) + if cell.get("col_index") == last_col: + continue + text = (cell.get("text") or "").strip() + if not text or text.startswith("["): + continue + # Continuation lines start with "(" -- e.g. "(usw.)", "(TV-Serie)" + if text.startswith("("): + continue + # Single cell NOT in the first content column is likely a + # continuation/overflow line, not a heading. Real headings + # ("Theme 1", "Unit 3: ...") appear in the first or second + # content column. + first_content_col = col_indices[0] if col_indices else 0 + if cell.get("col_index", 0) > first_content_col + 1: + continue + # Skip garbled IPA without brackets (e.g. "ska:f -- ska:vz") + # but NOT text with real IPA symbols (e.g. "Theme [\u03b8\u02c8i\u02d0m]") + _REAL_IPA_CHARS = set("\u02c8\u02cc\u0259\u026a\u025b\u0252\u028a\u028c\u00e6\u0251\u0254\u0283\u0292\u03b8\u00f0\u014b") + if _text_has_garbled_ipa(text) and not any(c in _REAL_IPA_CHARS for c in text): + continue + # Guard: dictionary section headings are short (1-4 alpha chars + # like "A", "Ab", "Zi", "Sch"). Longer text that starts + # lowercase is a regular vocabulary word (e.g. "zentral") that + # happens to appear alone in its row. + alpha_only = re.sub(r'[^a-zA-Z\u00e4\u00f6\u00fc\u00c4\u00d6\u00dc\u00df\u1e9e]', '', text) + if len(alpha_only) > 4 and text[0].islower(): + continue + heading_row_indices.append(ri) + + # Guard: if >25% of eligible rows would become headings, the + # heuristic is misfiring (e.g. sparse single-column layout where + # most rows naturally have only 1 content cell). + eligible_rows = len(non_header_rows) - 2 # minus first/last excluded + if eligible_rows > 0 and len(heading_row_indices) > eligible_rows * 0.25: + logger.debug( + "Skipping single-cell heading detection for zone %s: " + "%d/%d rows would be headings (>25%%)", + z.get("zone_index"), len(heading_row_indices), eligible_rows, + ) + continue + + for hri in heading_row_indices: + header_cells = [c for c in cells if c.get("row_index") == hri] + if not header_cells: + continue + + # Collect all word_boxes and text + all_wb = [] + all_text_parts = [] + for hc in sorted(header_cells, key=lambda c: c["col_index"]): + all_wb.extend(hc.get("word_boxes", [])) + if hc.get("text", "").strip(): + all_text_parts.append(hc["text"].strip()) + + first_col_idx = min(hc["col_index"] for hc in header_cells) + + # Remove old cells for this row, add spanning heading cell + z["cells"] = [c for c in z["cells"] if c.get("row_index") != hri] + + if all_wb: + x_min = min(wb["left"] for wb in all_wb) + y_min = min(wb["top"] for wb in all_wb) + x_max = max(wb["left"] + wb["width"] for wb in all_wb) + y_max = max(wb["top"] + wb["height"] for wb in all_wb) + else: + # Fallback to first cell bbox + bp = header_cells[0].get("bbox_px", {}) + x_min = bp.get("x", 0) + y_min = bp.get("y", 0) + x_max = x_min + bp.get("w", 0) + y_max = y_min + bp.get("h", 0) + + zone_idx = z.get("zone_index", 0) + z["cells"].append({ + "cell_id": f"Z{zone_idx}_R{hri:02d}_C{first_col_idx}", + "zone_index": zone_idx, + "row_index": hri, + "col_index": first_col_idx, + "col_type": "heading", + "text": " ".join(all_text_parts), + "confidence": 0.0, + "bbox_px": {"x": x_min, "y": y_min, + "w": x_max - x_min, "h": y_max - y_min}, + "bbox_pct": { + "x": round(x_min / img_w * 100, 2) if img_w else 0, + "y": round(y_min / img_h * 100, 2) if img_h else 0, + "w": round((x_max - x_min) / img_w * 100, 2) if img_w else 0, + "h": round((y_max - y_min) / img_h * 100, 2) if img_h else 0, + }, + "word_boxes": all_wb, + "ocr_engine": "words_first", + "is_bold": False, + }) + + for row in rows: + if row["index"] == hri: + row["is_header"] = True + heading_count += 1 + + return heading_count + + +def _detect_header_rows( + rows: List[Dict], + zone_words: List[Dict], + zone_y: int, + columns: Optional[List[Dict]] = None, + skip_first_row_header: bool = False, +) -> List[int]: + """Detect header rows: first-row heuristic + spanning header detection. + + A "spanning header" is a row whose words stretch across multiple column + boundaries (e.g. "Unit4: Bonnie Scotland" centred across 4 columns). + """ + if len(rows) < 2: + return [] + + headers = [] + + if not skip_first_row_header: + first_row = rows[0] + second_row = rows[1] + + # Gap between first and second row > 0.5x average row height + avg_h = sum(r["y_max"] - r["y_min"] for r in rows) / len(rows) + gap = second_row["y_min"] - first_row["y_max"] + if gap > avg_h * 0.5: + headers.append(0) + + # Also check if first row words are taller than average (bold/header text) + all_heights = [w["height"] for w in zone_words] + median_h = sorted(all_heights)[len(all_heights) // 2] if all_heights else 20 + first_row_words = [ + w for w in zone_words + if first_row["y_min"] <= w["top"] + w["height"] / 2 <= first_row["y_max"] + ] + if first_row_words: + first_h = max(w["height"] for w in first_row_words) + if first_h > median_h * 1.3: + if 0 not in headers: + headers.append(0) + + # Note: Spanning-header detection (rows spanning all columns) has been + # disabled because it produces too many false positives on vocabulary + # worksheets where IPA transcriptions or short entries naturally span + # multiple columns with few words. The first-row heuristic above is + # sufficient for detecting real headers. + + return headers + + +def _detect_colspan_cells( + zone_words: List[Dict], + columns: List[Dict], + rows: List[Dict], + cells: List[Dict], + img_w: int, + img_h: int, +) -> List[Dict]: + """Detect and merge cells that span multiple columns (colspan). + + A word-block (PaddleOCR phrase) that extends significantly past a column + boundary into the next column indicates a merged cell. This replaces + the incorrectly split cells with a single cell spanning multiple columns. + + Works for both full-page scans and box zones. + """ + if len(columns) < 2 or not zone_words or not rows: + return cells + + from cv_words_first import _assign_word_to_row + + # Column boundaries (midpoints between adjacent columns) + col_boundaries = [] + for ci in range(len(columns) - 1): + col_boundaries.append((columns[ci]["x_max"] + columns[ci + 1]["x_min"]) / 2) + + def _cols_covered(w_left: float, w_right: float) -> List[int]: + """Return list of column indices that a word-block covers.""" + covered = [] + for col in columns: + col_mid = (col["x_min"] + col["x_max"]) / 2 + # Word covers a column if it extends past the column's midpoint + if w_left < col_mid < w_right: + covered.append(col["index"]) + # Also include column if word starts within it + elif col["x_min"] <= w_left < col["x_max"]: + covered.append(col["index"]) + return sorted(set(covered)) + + # Group original word-blocks by row + row_word_blocks: Dict[int, List[Dict]] = {} + for w in zone_words: + ri = _assign_word_to_row(w, rows) + row_word_blocks.setdefault(ri, []).append(w) + + # For each row, check if any word-block spans multiple columns + rows_to_merge: Dict[int, List[Dict]] = {} # row_index -> list of spanning word-blocks + + for ri, wblocks in row_word_blocks.items(): + spanning = [] + for w in wblocks: + w_left = w["left"] + w_right = w_left + w["width"] + covered = _cols_covered(w_left, w_right) + if len(covered) >= 2: + spanning.append({"word": w, "cols": covered}) + if spanning: + rows_to_merge[ri] = spanning + + if not rows_to_merge: + return cells + + # Merge cells for spanning rows + new_cells = [] + for cell in cells: + ri = cell.get("row_index", -1) + if ri not in rows_to_merge: + new_cells.append(cell) + continue + + # Check if this cell's column is part of a spanning block + ci = cell.get("col_index", -1) + is_part_of_span = False + for span in rows_to_merge[ri]: + if ci in span["cols"]: + is_part_of_span = True + # Only emit the merged cell for the FIRST column in the span + if ci == span["cols"][0]: + # Use the ORIGINAL word-block text (not the split cell texts + # which may have broken words like "euros a" + "nd cents") + orig_word = span["word"] + merged_text = orig_word.get("text", "").strip() + all_wb = [orig_word] + + # Compute merged bbox + if all_wb: + x_min = min(wb["left"] for wb in all_wb) + y_min = min(wb["top"] for wb in all_wb) + x_max = max(wb["left"] + wb["width"] for wb in all_wb) + y_max = max(wb["top"] + wb["height"] for wb in all_wb) + else: + x_min = y_min = x_max = y_max = 0 + + new_cells.append({ + "cell_id": cell["cell_id"], + "row_index": ri, + "col_index": span["cols"][0], + "col_type": "spanning_header", + "colspan": len(span["cols"]), + "text": merged_text, + "confidence": cell.get("confidence", 0), + "bbox_px": {"x": x_min, "y": y_min, + "w": x_max - x_min, "h": y_max - y_min}, + "bbox_pct": { + "x": round(x_min / img_w * 100, 2) if img_w else 0, + "y": round(y_min / img_h * 100, 2) if img_h else 0, + "w": round((x_max - x_min) / img_w * 100, 2) if img_w else 0, + "h": round((y_max - y_min) / img_h * 100, 2) if img_h else 0, + }, + "word_boxes": all_wb, + "ocr_engine": cell.get("ocr_engine", ""), + "is_bold": cell.get("is_bold", False), + }) + logger.info( + "colspan detected: row %d, cols %s -> merged %d cells (%r)", + ri, span["cols"], len(span["cols"]), merged_text[:50], + ) + break + if not is_part_of_span: + new_cells.append(cell) + + return new_cells diff --git a/klausur-service/backend/grid_editor_helpers.py b/klausur-service/backend/grid_editor_helpers.py index e126dba..c75e161 100644 --- a/klausur-service/backend/grid_editor_helpers.py +++ b/klausur-service/backend/grid_editor_helpers.py @@ -1,1737 +1,58 @@ """ -Grid Editor helper functions — filters, detectors, and zone grid building. +Grid Editor helper functions — barrel re-export module. -Extracted from grid_editor_api.py for maintainability. -All functions are pure computation — no HTTP, DB, or session side effects. +This file re-exports all public symbols from the split sub-modules +so that existing ``from grid_editor_helpers import ...`` statements +continue to work without changes. + +Sub-modules: + - grid_editor_columns — column detection, cross-column splitting, marker merging + - grid_editor_filters — word/zone filtering, border ghosts, decorative margins + - grid_editor_headers — header/heading detection, colspan detection + - grid_editor_zones — vertical dividers, zone splitting/merging, zone grid building Lizenz: Apache 2.0 (kommerziell nutzbar) DATENSCHUTZ: Alle Verarbeitung erfolgt lokal. """ -import logging -import re -from typing import Any, Dict, List, Optional, Tuple - -import cv2 -import numpy as np - -from cv_vocab_types import PageZone -from cv_words_first import _cluster_rows, _build_cells -from cv_ocr_engines import _text_has_garbled_ipa - -logger = logging.getLogger(__name__) - - -# --------------------------------------------------------------------------- -# Cross-column word splitting -# --------------------------------------------------------------------------- - -_spell_cache: Optional[Any] = None -_spell_loaded = False - - -def _is_recognized_word(text: str) -> bool: - """Check if *text* is a recognized German or English word. - - Uses the spellchecker library (same as cv_syllable_detect.py). - Returns True for real words like "oder", "Kabel", "Zeitung". - Returns False for OCR merge artifacts like "sichzie", "dasZimmer". - """ - global _spell_cache, _spell_loaded - if not text or len(text) < 2: - return False - - if not _spell_loaded: - _spell_loaded = True - try: - from spellchecker import SpellChecker - _spell_cache = SpellChecker(language="de") - except Exception: - pass - - if _spell_cache is None: - return False - - return text.lower() in _spell_cache - - -def _split_cross_column_words( - words: List[Dict], - columns: List[Dict], -) -> List[Dict]: - """Split word boxes that span across column boundaries. - - When OCR merges adjacent words from different columns (e.g. "sichzie" - spanning Col 1 and Col 2, or "dasZimmer" crossing the boundary), - split the word box at the column boundary so each piece is assigned - to the correct column. - - Only splits when: - - The word has significant overlap (>15% of its width) on both sides - - AND the word is not a recognized real word (OCR merge artifact), OR - the word contains a case transition (lowercase→uppercase) near the - boundary indicating two merged words like "dasZimmer". - """ - if len(columns) < 2: - return words - - # Column boundaries = midpoints between adjacent column edges - boundaries = [] - for i in range(len(columns) - 1): - boundary = (columns[i]["x_max"] + columns[i + 1]["x_min"]) / 2 - boundaries.append(boundary) - - new_words: List[Dict] = [] - split_count = 0 - for w in words: - w_left = w["left"] - w_width = w["width"] - w_right = w_left + w_width - text = (w.get("text") or "").strip() - - if not text or len(text) < 4 or w_width < 10: - new_words.append(w) - continue - - # Find the first boundary this word straddles significantly - split_boundary = None - for b in boundaries: - if w_left < b < w_right: - left_part = b - w_left - right_part = w_right - b - # Both sides must have at least 15% of the word width - if left_part > w_width * 0.15 and right_part > w_width * 0.15: - split_boundary = b - break - - if split_boundary is None: - new_words.append(w) - continue - - # Compute approximate split position in the text. - left_width = split_boundary - w_left - split_ratio = left_width / w_width - approx_pos = len(text) * split_ratio - - # Strategy 1: look for a case transition (lowercase→uppercase) near - # the approximate split point — e.g. "dasZimmer" splits at 'Z'. - split_char = None - search_lo = max(1, int(approx_pos) - 3) - search_hi = min(len(text), int(approx_pos) + 2) - for i in range(search_lo, search_hi): - if text[i - 1].islower() and text[i].isupper(): - split_char = i - break - - # Strategy 2: if no case transition, only split if the whole word - # is NOT a real word (i.e. it's an OCR merge artifact like "sichzie"). - # Real words like "oder", "Kabel", "Zeitung" must not be split. - if split_char is None: - clean = re.sub(r"[,;:.!?]+$", "", text) # strip trailing punct - if _is_recognized_word(clean): - new_words.append(w) - continue - # Not a real word — use floor of proportional position - split_char = max(1, min(len(text) - 1, int(approx_pos))) - - left_text = text[:split_char].rstrip() - right_text = text[split_char:].lstrip() - - if len(left_text) < 2 or len(right_text) < 2: - new_words.append(w) - continue - - right_width = w_width - round(left_width) - new_words.append({ - **w, - "text": left_text, - "width": round(left_width), - }) - new_words.append({ - **w, - "text": right_text, - "left": round(split_boundary), - "width": right_width, - }) - split_count += 1 - logger.info( - "split cross-column word %r → %r + %r at boundary %.0f", - text, left_text, right_text, split_boundary, - ) - - if split_count: - logger.info("split %d cross-column word(s)", split_count) - return new_words - - -def _filter_border_strip_words(words: List[Dict]) -> Tuple[List[Dict], int]: - """Remove page-border decoration strip words BEFORE column detection. - - Scans from each page edge inward to find the first significant x-gap - (>30 px). If the edge cluster contains <15 % of total words, those - words are removed as border-strip artifacts (alphabet letters, - illustration fragments). - - Must run BEFORE ``_build_zone_grid`` so that column detection only - sees real content words and doesn't produce inflated row counts. - """ - if len(words) < 10: - return words, 0 - - sorted_words = sorted(words, key=lambda w: w.get("left", 0)) - total = len(sorted_words) - - # -- Left-edge scan (running max right-edge) -- - left_count = 0 - running_right = 0 - for gi in range(total - 1): - running_right = max( - running_right, - sorted_words[gi].get("left", 0) + sorted_words[gi].get("width", 0), - ) - if sorted_words[gi + 1].get("left", 0) - running_right > 30: - left_count = gi + 1 - break - - # -- Right-edge scan (running min left) -- - right_count = 0 - running_left = sorted_words[-1].get("left", 0) - for gi in range(total - 1, 0, -1): - running_left = min(running_left, sorted_words[gi].get("left", 0)) - prev_right = ( - sorted_words[gi - 1].get("left", 0) - + sorted_words[gi - 1].get("width", 0) - ) - if running_left - prev_right > 30: - right_count = total - gi - break - - # Validate candidate strip: real border decorations are mostly short - # words (alphabet letters like "A", "Bb", stray marks). Multi-word - # content like "der Ranzen" or "die Schals" (continuation of German - # translations) must NOT be removed. - def _is_decorative_strip(candidates: List[Dict]) -> bool: - if not candidates: - return False - short = sum(1 for w in candidates if len((w.get("text") or "").strip()) <= 2) - return short / len(candidates) >= 0.45 - - strip_ids: set = set() - if left_count > 0 and left_count / total < 0.20: - candidates = sorted_words[:left_count] - if _is_decorative_strip(candidates): - strip_ids = {id(w) for w in candidates} - elif right_count > 0 and right_count / total < 0.20: - candidates = sorted_words[total - right_count:] - if _is_decorative_strip(candidates): - strip_ids = {id(w) for w in candidates} - - if not strip_ids: - return words, 0 - - return [w for w in words if id(w) not in strip_ids], len(strip_ids) - - -def _cluster_columns_by_alignment( - words: List[Dict], - zone_w: int, - rows: List[Dict], -) -> List[Dict[str, Any]]: - """Detect columns by clustering left-edge alignment across rows. - - Hybrid approach: - 1. Group words by row, find "group start" positions within each row - (words preceded by a large gap or first word in row) - 2. Cluster group-start left-edges by X-proximity across rows - 3. Filter by row coverage (how many rows have a group start here) - 4. Merge nearby clusters - 5. Build column boundaries - - This filters out mid-phrase word positions (e.g. IPA transcriptions, - second words in multi-word entries) by only considering positions - where a new word group begins within a row. - """ - if not words or not rows: - return [] - - total_rows = len(rows) - if total_rows == 0: - return [] - - # --- Group words by row --- - row_words: Dict[int, List[Dict]] = {} - for w in words: - y_center = w["top"] + w["height"] / 2 - best = min(rows, key=lambda r: abs(r["y_center"] - y_center)) - row_words.setdefault(best["index"], []).append(w) - - # --- Compute adaptive gap threshold for group-start detection --- - all_gaps: List[float] = [] - for ri, rw_list in row_words.items(): - sorted_rw = sorted(rw_list, key=lambda w: w["left"]) - for i in range(len(sorted_rw) - 1): - right = sorted_rw[i]["left"] + sorted_rw[i]["width"] - gap = sorted_rw[i + 1]["left"] - right - if gap > 0: - all_gaps.append(gap) - - if all_gaps: - sorted_gaps = sorted(all_gaps) - median_gap = sorted_gaps[len(sorted_gaps) // 2] - heights = [w["height"] for w in words if w.get("height", 0) > 0] - median_h = sorted(heights)[len(heights) // 2] if heights else 25 - - # For small word counts (boxes, sub-zones): PaddleOCR returns - # multi-word blocks, so ALL inter-word gaps are potential column - # boundaries. Use a low threshold based on word height — any gap - # wider than ~1x median word height is a column separator. - if len(words) <= 60: - gap_threshold = max(median_h * 1.0, 25) - logger.info( - "alignment columns (small zone): gap_threshold=%.0f " - "(median_h=%.0f, %d words, %d gaps: %s)", - gap_threshold, median_h, len(words), len(sorted_gaps), - [int(g) for g in sorted_gaps[:10]], - ) - else: - # Standard approach for large zones (full pages) - gap_threshold = max(median_gap * 3, median_h * 1.5, 30) - # Cap at 25% of zone width - max_gap = zone_w * 0.25 - if gap_threshold > max_gap > 30: - logger.info("alignment columns: capping gap_threshold %.0f → %.0f (25%% of zone_w=%d)", gap_threshold, max_gap, zone_w) - gap_threshold = max_gap - else: - gap_threshold = 50 - - # --- Find group-start positions (left-edges that begin a new column) --- - start_positions: List[tuple] = [] # (left_edge, row_index) - for ri, rw_list in row_words.items(): - sorted_rw = sorted(rw_list, key=lambda w: w["left"]) - # First word in row is always a group start - start_positions.append((sorted_rw[0]["left"], ri)) - for i in range(1, len(sorted_rw)): - right_prev = sorted_rw[i - 1]["left"] + sorted_rw[i - 1]["width"] - gap = sorted_rw[i]["left"] - right_prev - if gap >= gap_threshold: - start_positions.append((sorted_rw[i]["left"], ri)) - - start_positions.sort(key=lambda x: x[0]) - - logger.info( - "alignment columns: %d group-start positions from %d words " - "(gap_threshold=%.0f, %d rows)", - len(start_positions), len(words), gap_threshold, total_rows, - ) - - if not start_positions: - x_min = min(w["left"] for w in words) - x_max = max(w["left"] + w["width"] for w in words) - return [{"index": 0, "type": "column_text", "x_min": x_min, "x_max": x_max}] - - # --- Cluster group-start positions by X-proximity --- - tolerance = max(10, int(zone_w * 0.01)) - clusters: List[Dict[str, Any]] = [] - cur_edges = [start_positions[0][0]] - cur_rows = {start_positions[0][1]} - - for left, row_idx in start_positions[1:]: - if left - cur_edges[-1] <= tolerance: - cur_edges.append(left) - cur_rows.add(row_idx) - else: - clusters.append({ - "mean_x": int(sum(cur_edges) / len(cur_edges)), - "min_edge": min(cur_edges), - "max_edge": max(cur_edges), - "count": len(cur_edges), - "distinct_rows": len(cur_rows), - "row_coverage": len(cur_rows) / total_rows, - }) - cur_edges = [left] - cur_rows = {row_idx} - clusters.append({ - "mean_x": int(sum(cur_edges) / len(cur_edges)), - "min_edge": min(cur_edges), - "max_edge": max(cur_edges), - "count": len(cur_edges), - "distinct_rows": len(cur_rows), - "row_coverage": len(cur_rows) / total_rows, - }) - - # --- Filter by row coverage --- - # These thresholds must be high enough to avoid false columns in flowing - # text (random inter-word gaps) while still detecting real columns in - # vocabulary worksheets (which typically have >80% row coverage). - MIN_COVERAGE_PRIMARY = 0.35 - MIN_COVERAGE_SECONDARY = 0.12 - MIN_WORDS_SECONDARY = 4 - MIN_DISTINCT_ROWS = 3 - - # Content boundary for left-margin detection - content_x_min = min(w["left"] for w in words) - content_x_max = max(w["left"] + w["width"] for w in words) - content_span = content_x_max - content_x_min - - primary = [ - c for c in clusters - if c["row_coverage"] >= MIN_COVERAGE_PRIMARY - and c["distinct_rows"] >= MIN_DISTINCT_ROWS - ] - primary_ids = {id(c) for c in primary} - secondary = [ - c for c in clusters - if id(c) not in primary_ids - and c["row_coverage"] >= MIN_COVERAGE_SECONDARY - and c["count"] >= MIN_WORDS_SECONDARY - and c["distinct_rows"] >= MIN_DISTINCT_ROWS - ] - - # Tertiary: narrow left-margin columns (page refs, markers) that have - # too few rows for secondary but are clearly left-aligned and separated - # from the main content. These appear at the far left or far right and - # have a large gap to the nearest significant cluster. - used_ids = {id(c) for c in primary} | {id(c) for c in secondary} - sig_xs = [c["mean_x"] for c in primary + secondary] - - # Tertiary: clusters that are clearly to the LEFT of the first - # significant column (or RIGHT of the last). If words consistently - # start at a position left of the established first column boundary, - # they MUST be a separate column — regardless of how few rows they - # cover. The only requirement is a clear spatial gap. - MIN_COVERAGE_TERTIARY = 0.02 # at least 1 row effectively - tertiary = [] - for c in clusters: - if id(c) in used_ids: - continue - if c["distinct_rows"] < 1: - continue - if c["row_coverage"] < MIN_COVERAGE_TERTIARY: - continue - # Must be near left or right content margin (within 15%) - rel_pos = (c["mean_x"] - content_x_min) / content_span if content_span else 0.5 - if not (rel_pos < 0.15 or rel_pos > 0.85): - continue - # Must have significant gap to nearest significant cluster - if sig_xs: - min_dist = min(abs(c["mean_x"] - sx) for sx in sig_xs) - if min_dist < max(30, content_span * 0.02): - continue - tertiary.append(c) - - if tertiary: - for c in tertiary: - logger.info( - " tertiary (margin) cluster: x=%d (range %d-%d), %d words, %d rows (%.0f%%)", - c["mean_x"], c["min_edge"], c["max_edge"], - c["count"], c["distinct_rows"], c["row_coverage"] * 100, - ) - - significant = sorted(primary + secondary + tertiary, key=lambda c: c["mean_x"]) - - for c in significant: - logger.info( - " significant cluster: x=%d (range %d-%d), %d words, %d rows (%.0f%%)", - c["mean_x"], c["min_edge"], c["max_edge"], - c["count"], c["distinct_rows"], c["row_coverage"] * 100, - ) - logger.info( - "alignment columns: %d clusters, %d primary, %d secondary → %d significant", - len(clusters), len(primary), len(secondary), len(significant), - ) - - if not significant: - # Fallback: single column covering all content - x_min = min(w["left"] for w in words) - x_max = max(w["left"] + w["width"] for w in words) - return [{"index": 0, "type": "column_text", "x_min": x_min, "x_max": x_max}] - - # --- Merge nearby clusters --- - merge_distance = max(25, int(zone_w * 0.03)) - merged = [significant[0].copy()] - for s in significant[1:]: - if s["mean_x"] - merged[-1]["mean_x"] < merge_distance: - prev = merged[-1] - total = prev["count"] + s["count"] - prev["mean_x"] = ( - prev["mean_x"] * prev["count"] + s["mean_x"] * s["count"] - ) // total - prev["count"] = total - prev["min_edge"] = min(prev["min_edge"], s["min_edge"]) - prev["max_edge"] = max(prev["max_edge"], s["max_edge"]) - prev["distinct_rows"] = max(prev["distinct_rows"], s["distinct_rows"]) - else: - merged.append(s.copy()) - - logger.info( - "alignment columns: %d after merge (distance=%d)", - len(merged), merge_distance, - ) - - # --- Build column boundaries --- - margin = max(5, int(zone_w * 0.005)) - content_x_min = min(w["left"] for w in words) - content_x_max = max(w["left"] + w["width"] for w in words) - - columns: List[Dict[str, Any]] = [] - for i, cluster in enumerate(merged): - x_min = max(content_x_min, cluster["min_edge"] - margin) - if i + 1 < len(merged): - x_max = merged[i + 1]["min_edge"] - margin - else: - x_max = content_x_max - - columns.append({ - "index": i, - "type": f"column_{i + 1}" if len(merged) > 1 else "column_text", - "x_min": x_min, - "x_max": x_max, - }) - - return columns - - -# Characters that are typically OCR artefacts from box border lines. -# Intentionally excludes ! (red markers) and . , ; (real punctuation). -_GRID_GHOST_CHARS = set("|1lI[](){}/\\-—–_~=+") - - -def _filter_border_ghosts( - words: List[Dict], - boxes: List, -) -> tuple: - """Remove words sitting on box borders that are OCR artefacts. - - Returns (filtered_words, removed_count). - """ - if not boxes or not words: - return words, 0 - - # Build border bands from detected boxes - x_bands: List[tuple] = [] - y_bands: List[tuple] = [] - for b in boxes: - bt = ( - b.border_thickness - if hasattr(b, "border_thickness") - else b.get("border_thickness", 3) - ) - # Skip borderless boxes (images/graphics) — no border line to produce ghosts - if bt == 0: - continue - bx = b.x if hasattr(b, "x") else b.get("x", 0) - by = b.y if hasattr(b, "y") else b.get("y", 0) - bw = b.width if hasattr(b, "width") else b.get("w", b.get("width", 0)) - bh = b.height if hasattr(b, "height") else b.get("h", b.get("height", 0)) - margin = max(bt * 2, 10) + 6 - x_bands.append((bx - margin, bx + margin)) - x_bands.append((bx + bw - margin, bx + bw + margin)) - y_bands.append((by - margin, by + margin)) - y_bands.append((by + bh - margin, by + bh + margin)) - - def _is_ghost(w: Dict) -> bool: - text = (w.get("text") or "").strip() - if not text: - return False - # Check if any word edge (not just center) touches a border band - w_left = w["left"] - w_right = w["left"] + w["width"] - w_top = w["top"] - w_bottom = w["top"] + w["height"] - on_border = ( - any(lo <= w_left <= hi or lo <= w_right <= hi for lo, hi in x_bands) - or any(lo <= w_top <= hi or lo <= w_bottom <= hi for lo, hi in y_bands) - ) - if not on_border: - return False - if len(text) == 1 and text in _GRID_GHOST_CHARS: - return True - return False - - filtered = [w for w in words if not _is_ghost(w)] - return filtered, len(words) - len(filtered) - - -_MARKER_CHARS = set("•*·-–—|~=+#>→►▸▪◆○●□■✓✗✔✘") - - -def _merge_inline_marker_columns( - columns: List[Dict], - words: List[Dict], -) -> List[Dict]: - """Merge narrow marker columns (bullets, numbering) into adjacent text. - - Bullet points (•, *, -) and numbering (1., 2.) create narrow columns - at the left edge of a zone. These are inline markers that indent text, - not real separate columns. Merge them with their right neighbour. - - Does NOT merge columns containing alphabetic words like "to", "in", - "der", "die", "das" — those are legitimate content columns. - """ - if len(columns) < 2: - return columns - - merged: List[Dict] = [] - skip: set = set() - - for i, col in enumerate(columns): - if i in skip: - continue - - # Find words in this column - col_words = [ - w for w in words - if col["x_min"] <= w["left"] + w["width"] / 2 < col["x_max"] - ] - col_width = col["x_max"] - col["x_min"] - - # Narrow column with mostly short words → MIGHT be inline markers - if col_words and col_width < 80: - avg_len = sum(len(w.get("text", "")) for w in col_words) / len(col_words) - if avg_len <= 2 and i + 1 < len(columns): - # Check if words are actual markers (symbols/numbers) vs - # real alphabetic words like "to", "in", "der", "die" - texts = [(w.get("text") or "").strip() for w in col_words] - alpha_count = sum( - 1 for t in texts - if t and t[0].isalpha() and t not in _MARKER_CHARS - ) - alpha_ratio = alpha_count / len(texts) if texts else 0 - - # If ≥50% of words are alphabetic, this is a real column - if alpha_ratio >= 0.5: - logger.info( - " kept narrow column %d (w=%d, avg_len=%.1f, " - "alpha=%.0f%%) — contains real words", - i, col_width, avg_len, alpha_ratio * 100, - ) - else: - # Merge into next column - next_col = columns[i + 1].copy() - next_col["x_min"] = col["x_min"] - merged.append(next_col) - skip.add(i + 1) - logger.info( - " merged inline marker column %d (w=%d, avg_len=%.1f) " - "into column %d", - i, col_width, avg_len, i + 1, - ) - continue - - merged.append(col) - - # Re-index - for i, col in enumerate(merged): - col["index"] = i - col["type"] = f"column_{i + 1}" if len(merged) > 1 else "column_text" - - return merged - - -def _flatten_word_boxes(cells: List[Dict]) -> List[Dict]: - """Extract all word_boxes from cells into a flat list of word dicts.""" - words: List[Dict] = [] - for cell in cells: - for wb in cell.get("word_boxes") or []: - if wb.get("text", "").strip(): - words.append({ - "text": wb["text"], - "left": wb["left"], - "top": wb["top"], - "width": wb["width"], - "height": wb["height"], - "conf": wb.get("conf", 0), - }) - return words - - -def _words_in_zone( - words: List[Dict], - zone_y: int, - zone_h: int, - zone_x: int, - zone_w: int, -) -> List[Dict]: - """Filter words whose Y-center falls within a zone's bounds.""" - zone_y_end = zone_y + zone_h - zone_x_end = zone_x + zone_w - result = [] - for w in words: - cy = w["top"] + w["height"] / 2 - cx = w["left"] + w["width"] / 2 - if zone_y <= cy <= zone_y_end and zone_x <= cx <= zone_x_end: - result.append(w) - return result - - -# --------------------------------------------------------------------------- -# Vertical divider detection and zone splitting -# --------------------------------------------------------------------------- - -_PIPE_RE_VSPLIT = re.compile(r"^\|+$") - - -def _detect_vertical_dividers( - words: List[Dict], - zone_x: int, - zone_w: int, - zone_y: int, - zone_h: int, -) -> List[float]: - """Detect vertical divider lines from pipe word_boxes at consistent x. - - Returns list of divider x-positions (empty if no dividers found). - """ - if not words or zone_w <= 0 or zone_h <= 0: - return [] - - # Collect pipe word_boxes - pipes = [ - w for w in words - if _PIPE_RE_VSPLIT.match((w.get("text") or "").strip()) - ] - if len(pipes) < 5: - return [] - - # Cluster pipe x-centers by proximity - tolerance = max(15, int(zone_w * 0.02)) - pipe_xs = sorted(w["left"] + w["width"] / 2 for w in pipes) - - clusters: List[List[float]] = [[pipe_xs[0]]] - for x in pipe_xs[1:]: - if x - clusters[-1][-1] <= tolerance: - clusters[-1].append(x) - else: - clusters.append([x]) - - dividers: List[float] = [] - for cluster in clusters: - if len(cluster) < 5: - continue - mean_x = sum(cluster) / len(cluster) - # Must be between 15% and 85% of zone width - rel_pos = (mean_x - zone_x) / zone_w - if rel_pos < 0.15 or rel_pos > 0.85: - continue - # Check vertical coverage: pipes must span >= 50% of zone height - cluster_pipes = [ - w for w in pipes - if abs(w["left"] + w["width"] / 2 - mean_x) <= tolerance - ] - ys = [w["top"] for w in cluster_pipes] + [w["top"] + w["height"] for w in cluster_pipes] - y_span = max(ys) - min(ys) if ys else 0 - if y_span < zone_h * 0.5: - continue - dividers.append(mean_x) - - return sorted(dividers) - - -def _split_zone_at_vertical_dividers( - zone: "PageZone", - divider_xs: List[float], - vsplit_group_id: int, -) -> List["PageZone"]: - """Split a PageZone at vertical divider positions into sub-zones.""" - from cv_vocab_types import PageZone - - boundaries = [zone.x] + divider_xs + [zone.x + zone.width] - hints = [] - for i in range(len(boundaries) - 1): - if i == 0: - hints.append("left_of_vsplit") - elif i == len(boundaries) - 2: - hints.append("right_of_vsplit") - else: - hints.append("middle_of_vsplit") - - sub_zones = [] - for i in range(len(boundaries) - 1): - x_start = int(boundaries[i]) - x_end = int(boundaries[i + 1]) - sub = PageZone( - index=0, # re-indexed later - zone_type=zone.zone_type, - y=zone.y, - height=zone.height, - x=x_start, - width=x_end - x_start, - box=zone.box, - image_overlays=zone.image_overlays, - layout_hint=hints[i], - vsplit_group=vsplit_group_id, - ) - sub_zones.append(sub) - - return sub_zones - - -def _merge_content_zones_across_boxes( - zones: List, - content_x: int, - content_w: int, -) -> List: - """Merge content zones separated by box zones into single zones. - - Box zones become image_overlays on the merged content zone. - Pattern: [content, box*, content] → [merged_content with overlay] - Box zones NOT between two content zones stay as standalone zones. - """ - if len(zones) < 3: - return zones - - # Group consecutive runs of [content, box+, content] - result: List = [] - i = 0 - while i < len(zones): - z = zones[i] - if z.zone_type != "content": - result.append(z) - i += 1 - continue - - # Start of a potential merge group: content zone - group_contents = [z] - group_boxes = [] - j = i + 1 - # Absorb [box, content] pairs — only absorb a box if it's - # confirmed to be followed by another content zone. - while j < len(zones): - if (zones[j].zone_type == "box" - and j + 1 < len(zones) - and zones[j + 1].zone_type == "content"): - group_boxes.append(zones[j]) - group_contents.append(zones[j + 1]) - j += 2 - else: - break - - if len(group_contents) >= 2 and group_boxes: - # Merge: create one large content zone spanning all - y_min = min(c.y for c in group_contents) - y_max = max(c.y + c.height for c in group_contents) - overlays = [] - for bz in group_boxes: - overlay = { - "y": bz.y, - "height": bz.height, - "x": bz.x, - "width": bz.width, - } - if bz.box: - overlay["box"] = { - "x": bz.box.x, - "y": bz.box.y, - "width": bz.box.width, - "height": bz.box.height, - "confidence": bz.box.confidence, - "border_thickness": bz.box.border_thickness, - } - overlays.append(overlay) - - merged = PageZone( - index=0, # re-indexed below - zone_type="content", - y=y_min, - height=y_max - y_min, - x=content_x, - width=content_w, - image_overlays=overlays, - ) - result.append(merged) - i = j - else: - # No merge possible — emit just the content zone - result.append(z) - i += 1 - - # Re-index zones - for idx, z in enumerate(result): - z.index = idx - - logger.info( - "zone-merge: %d zones → %d zones after merging across boxes", - len(zones), len(result), - ) - return result - - -def _detect_heading_rows_by_color(zones_data: List[Dict], img_w: int, img_h: int) -> int: - """Detect heading rows by color + height after color annotation. - - A row is a heading if: - 1. ALL word_boxes have color_name != 'black' (typically 'blue') - 2. Mean word height > 1.2x median height of all words in the zone - - Detected heading rows are merged into a single spanning cell. - Returns count of headings detected. - """ - heading_count = 0 - - for z in zones_data: - cells = z.get("cells", []) - rows = z.get("rows", []) - columns = z.get("columns", []) - if not cells or not rows or len(columns) < 2: - continue - - # Compute median word height across the zone - all_heights = [] - for cell in cells: - for wb in cell.get("word_boxes") or []: - h = wb.get("height", 0) - if h > 0: - all_heights.append(h) - if not all_heights: - continue - all_heights_sorted = sorted(all_heights) - median_h = all_heights_sorted[len(all_heights_sorted) // 2] - - heading_row_indices = [] - for row in rows: - if row.get("is_header"): - continue # already detected as header - ri = row["index"] - row_cells = [c for c in cells if c.get("row_index") == ri] - row_wbs = [ - wb for cell in row_cells - for wb in cell.get("word_boxes") or [] - ] - if not row_wbs: - continue - - # Condition 1: ALL words are non-black - all_colored = all( - wb.get("color_name", "black") != "black" - for wb in row_wbs - ) - if not all_colored: - continue - - # Condition 2: mean height > 1.2x median - mean_h = sum(wb.get("height", 0) for wb in row_wbs) / len(row_wbs) - if mean_h <= median_h * 1.2: - continue - - heading_row_indices.append(ri) - - # Merge heading cells into spanning cells - for hri in heading_row_indices: - header_cells = [c for c in cells if c.get("row_index") == hri] - if len(header_cells) <= 1: - # Single cell — just mark it as heading - if header_cells: - header_cells[0]["col_type"] = "heading" - heading_count += 1 - # Mark row as header - for row in rows: - if row["index"] == hri: - row["is_header"] = True - continue - - # Collect all word_boxes and text from all columns - all_wb = [] - all_text_parts = [] - for hc in sorted(header_cells, key=lambda c: c["col_index"]): - all_wb.extend(hc.get("word_boxes", [])) - if hc.get("text", "").strip(): - all_text_parts.append(hc["text"].strip()) - - # Remove all cells for this row, replace with one spanning cell - z["cells"] = [c for c in z["cells"] if c.get("row_index") != hri] - - if all_wb: - x_min = min(wb["left"] for wb in all_wb) - y_min = min(wb["top"] for wb in all_wb) - x_max = max(wb["left"] + wb["width"] for wb in all_wb) - y_max = max(wb["top"] + wb["height"] for wb in all_wb) - - # Use the actual starting col_index from the first cell - first_col = min(hc["col_index"] for hc in header_cells) - zone_idx = z.get("zone_index", 0) - z["cells"].append({ - "cell_id": f"Z{zone_idx}_R{hri:02d}_C{first_col}", - "zone_index": zone_idx, - "row_index": hri, - "col_index": first_col, - "col_type": "heading", - "text": " ".join(all_text_parts), - "confidence": 0.0, - "bbox_px": {"x": x_min, "y": y_min, - "w": x_max - x_min, "h": y_max - y_min}, - "bbox_pct": { - "x": round(x_min / img_w * 100, 2) if img_w else 0, - "y": round(y_min / img_h * 100, 2) if img_h else 0, - "w": round((x_max - x_min) / img_w * 100, 2) if img_w else 0, - "h": round((y_max - y_min) / img_h * 100, 2) if img_h else 0, - }, - "word_boxes": all_wb, - "ocr_engine": "words_first", - "is_bold": True, - }) - - # Mark row as header - for row in rows: - if row["index"] == hri: - row["is_header"] = True - heading_count += 1 - - return heading_count - - -def _detect_heading_rows_by_single_cell( - zones_data: List[Dict], img_w: int, img_h: int, -) -> int: - """Detect heading rows that have only a single content cell. - - Black headings like "Theme" have normal color and height, so they are - missed by ``_detect_heading_rows_by_color``. The distinguishing signal - is that they occupy only one column while normal vocabulary rows fill - at least 2-3 columns. - - A row qualifies as a heading if: - 1. It is not already marked as a header/heading. - 2. It has exactly ONE cell whose col_type starts with ``column_`` - (excluding column_1 / page_ref which only carries page numbers). - 3. That single cell is NOT in the last column (continuation/example - lines like "2. Veränderung, Wechsel" often sit alone in column_4). - 4. The text does not start with ``[`` (IPA continuation). - 5. The zone has ≥3 columns and ≥5 rows (avoids false positives in - tiny zones). - 6. The majority of rows in the zone have ≥2 content cells (ensures - we are in a multi-column vocab layout). - """ - heading_count = 0 - - for z in zones_data: - cells = z.get("cells", []) - rows = z.get("rows", []) - columns = z.get("columns", []) - if len(columns) < 3 or len(rows) < 5: - continue - - # Determine the last col_index (example/sentence column) - col_indices = sorted(set(c.get("col_index", 0) for c in cells)) - if not col_indices: - continue - last_col = col_indices[-1] - - # Count content cells per row (column_* but not column_1/page_ref). - # Exception: column_1 cells that contain a dictionary article word - # (die/der/das etc.) ARE content — they appear in dictionary layouts - # where the leftmost column holds grammatical articles. - _ARTICLE_WORDS = { - "die", "der", "das", "dem", "den", "des", "ein", "eine", - "the", "a", "an", - } - row_content_counts: Dict[int, int] = {} - for cell in cells: - ct = cell.get("col_type", "") - if not ct.startswith("column_"): - continue - if ct == "column_1": - ctext = (cell.get("text") or "").strip().lower() - if ctext not in _ARTICLE_WORDS: - continue - ri = cell.get("row_index", -1) - row_content_counts[ri] = row_content_counts.get(ri, 0) + 1 - - # Majority of rows must have ≥2 content cells - multi_col_rows = sum(1 for cnt in row_content_counts.values() if cnt >= 2) - if multi_col_rows < len(rows) * 0.4: - continue - - # Exclude first and last non-header rows — these are typically - # page numbers or footer text, not headings. - non_header_rows = [r for r in rows if not r.get("is_header")] - if len(non_header_rows) < 3: - continue - first_ri = non_header_rows[0]["index"] - last_ri = non_header_rows[-1]["index"] - - heading_row_indices = [] - for row in rows: - if row.get("is_header"): - continue - ri = row["index"] - if ri == first_ri or ri == last_ri: - continue - row_cells = [c for c in cells if c.get("row_index") == ri] - content_cells = [ - c for c in row_cells - if c.get("col_type", "").startswith("column_") - and (c.get("col_type") != "column_1" - or (c.get("text") or "").strip().lower() in _ARTICLE_WORDS) - ] - if len(content_cells) != 1: - continue - cell = content_cells[0] - # Not in the last column (continuation/example lines) - if cell.get("col_index") == last_col: - continue - text = (cell.get("text") or "").strip() - if not text or text.startswith("["): - continue - # Continuation lines start with "(" — e.g. "(usw.)", "(TV-Serie)" - if text.startswith("("): - continue - # Single cell NOT in the first content column is likely a - # continuation/overflow line, not a heading. Real headings - # ("Theme 1", "Unit 3: ...") appear in the first or second - # content column. - first_content_col = col_indices[0] if col_indices else 0 - if cell.get("col_index", 0) > first_content_col + 1: - continue - # Skip garbled IPA without brackets (e.g. "ska:f – ska:vz") - # but NOT text with real IPA symbols (e.g. "Theme [θˈiːm]") - _REAL_IPA_CHARS = set("ˈˌəɪɛɒʊʌæɑɔʃʒθðŋ") - if _text_has_garbled_ipa(text) and not any(c in _REAL_IPA_CHARS for c in text): - continue - # Guard: dictionary section headings are short (1-4 alpha chars - # like "A", "Ab", "Zi", "Sch"). Longer text that starts - # lowercase is a regular vocabulary word (e.g. "zentral") that - # happens to appear alone in its row. - alpha_only = re.sub(r'[^a-zA-ZäöüÄÖÜßẞ]', '', text) - if len(alpha_only) > 4 and text[0].islower(): - continue - heading_row_indices.append(ri) - - # Guard: if >25% of eligible rows would become headings, the - # heuristic is misfiring (e.g. sparse single-column layout where - # most rows naturally have only 1 content cell). - eligible_rows = len(non_header_rows) - 2 # minus first/last excluded - if eligible_rows > 0 and len(heading_row_indices) > eligible_rows * 0.25: - logger.debug( - "Skipping single-cell heading detection for zone %s: " - "%d/%d rows would be headings (>25%%)", - z.get("zone_index"), len(heading_row_indices), eligible_rows, - ) - continue - - for hri in heading_row_indices: - header_cells = [c for c in cells if c.get("row_index") == hri] - if not header_cells: - continue - - # Collect all word_boxes and text - all_wb = [] - all_text_parts = [] - for hc in sorted(header_cells, key=lambda c: c["col_index"]): - all_wb.extend(hc.get("word_boxes", [])) - if hc.get("text", "").strip(): - all_text_parts.append(hc["text"].strip()) - - first_col_idx = min(hc["col_index"] for hc in header_cells) - - # Remove old cells for this row, add spanning heading cell - z["cells"] = [c for c in z["cells"] if c.get("row_index") != hri] - - if all_wb: - x_min = min(wb["left"] for wb in all_wb) - y_min = min(wb["top"] for wb in all_wb) - x_max = max(wb["left"] + wb["width"] for wb in all_wb) - y_max = max(wb["top"] + wb["height"] for wb in all_wb) - else: - # Fallback to first cell bbox - bp = header_cells[0].get("bbox_px", {}) - x_min = bp.get("x", 0) - y_min = bp.get("y", 0) - x_max = x_min + bp.get("w", 0) - y_max = y_min + bp.get("h", 0) - - zone_idx = z.get("zone_index", 0) - z["cells"].append({ - "cell_id": f"Z{zone_idx}_R{hri:02d}_C{first_col_idx}", - "zone_index": zone_idx, - "row_index": hri, - "col_index": first_col_idx, - "col_type": "heading", - "text": " ".join(all_text_parts), - "confidence": 0.0, - "bbox_px": {"x": x_min, "y": y_min, - "w": x_max - x_min, "h": y_max - y_min}, - "bbox_pct": { - "x": round(x_min / img_w * 100, 2) if img_w else 0, - "y": round(y_min / img_h * 100, 2) if img_h else 0, - "w": round((x_max - x_min) / img_w * 100, 2) if img_w else 0, - "h": round((y_max - y_min) / img_h * 100, 2) if img_h else 0, - }, - "word_boxes": all_wb, - "ocr_engine": "words_first", - "is_bold": False, - }) - - for row in rows: - if row["index"] == hri: - row["is_header"] = True - heading_count += 1 - - return heading_count - - -def _detect_header_rows( - rows: List[Dict], - zone_words: List[Dict], - zone_y: int, - columns: Optional[List[Dict]] = None, - skip_first_row_header: bool = False, -) -> List[int]: - """Detect header rows: first-row heuristic + spanning header detection. - - A "spanning header" is a row whose words stretch across multiple column - boundaries (e.g. "Unit4: Bonnie Scotland" centred across 4 columns). - """ - if len(rows) < 2: - return [] - - headers = [] - - if not skip_first_row_header: - first_row = rows[0] - second_row = rows[1] - - # Gap between first and second row > 0.5x average row height - avg_h = sum(r["y_max"] - r["y_min"] for r in rows) / len(rows) - gap = second_row["y_min"] - first_row["y_max"] - if gap > avg_h * 0.5: - headers.append(0) - - # Also check if first row words are taller than average (bold/header text) - all_heights = [w["height"] for w in zone_words] - median_h = sorted(all_heights)[len(all_heights) // 2] if all_heights else 20 - first_row_words = [ - w for w in zone_words - if first_row["y_min"] <= w["top"] + w["height"] / 2 <= first_row["y_max"] - ] - if first_row_words: - first_h = max(w["height"] for w in first_row_words) - if first_h > median_h * 1.3: - if 0 not in headers: - headers.append(0) - - # Note: Spanning-header detection (rows spanning all columns) has been - # disabled because it produces too many false positives on vocabulary - # worksheets where IPA transcriptions or short entries naturally span - # multiple columns with few words. The first-row heuristic above is - # sufficient for detecting real headers. - - return headers - - -def _detect_colspan_cells( - zone_words: List[Dict], - columns: List[Dict], - rows: List[Dict], - cells: List[Dict], - img_w: int, - img_h: int, -) -> List[Dict]: - """Detect and merge cells that span multiple columns (colspan). - - A word-block (PaddleOCR phrase) that extends significantly past a column - boundary into the next column indicates a merged cell. This replaces - the incorrectly split cells with a single cell spanning multiple columns. - - Works for both full-page scans and box zones. - """ - if len(columns) < 2 or not zone_words or not rows: - return cells - - from cv_words_first import _assign_word_to_row - - # Column boundaries (midpoints between adjacent columns) - col_boundaries = [] - for ci in range(len(columns) - 1): - col_boundaries.append((columns[ci]["x_max"] + columns[ci + 1]["x_min"]) / 2) - - def _cols_covered(w_left: float, w_right: float) -> List[int]: - """Return list of column indices that a word-block covers.""" - covered = [] - for col in columns: - col_mid = (col["x_min"] + col["x_max"]) / 2 - # Word covers a column if it extends past the column's midpoint - if w_left < col_mid < w_right: - covered.append(col["index"]) - # Also include column if word starts within it - elif col["x_min"] <= w_left < col["x_max"]: - covered.append(col["index"]) - return sorted(set(covered)) - - # Group original word-blocks by row - row_word_blocks: Dict[int, List[Dict]] = {} - for w in zone_words: - ri = _assign_word_to_row(w, rows) - row_word_blocks.setdefault(ri, []).append(w) - - # For each row, check if any word-block spans multiple columns - rows_to_merge: Dict[int, List[Dict]] = {} # row_index → list of spanning word-blocks - - for ri, wblocks in row_word_blocks.items(): - spanning = [] - for w in wblocks: - w_left = w["left"] - w_right = w_left + w["width"] - covered = _cols_covered(w_left, w_right) - if len(covered) >= 2: - spanning.append({"word": w, "cols": covered}) - if spanning: - rows_to_merge[ri] = spanning - - if not rows_to_merge: - return cells - - # Merge cells for spanning rows - new_cells = [] - for cell in cells: - ri = cell.get("row_index", -1) - if ri not in rows_to_merge: - new_cells.append(cell) - continue - - # Check if this cell's column is part of a spanning block - ci = cell.get("col_index", -1) - is_part_of_span = False - for span in rows_to_merge[ri]: - if ci in span["cols"]: - is_part_of_span = True - # Only emit the merged cell for the FIRST column in the span - if ci == span["cols"][0]: - # Use the ORIGINAL word-block text (not the split cell texts - # which may have broken words like "euros a" + "nd cents") - orig_word = span["word"] - merged_text = orig_word.get("text", "").strip() - all_wb = [orig_word] - - # Compute merged bbox - if all_wb: - x_min = min(wb["left"] for wb in all_wb) - y_min = min(wb["top"] for wb in all_wb) - x_max = max(wb["left"] + wb["width"] for wb in all_wb) - y_max = max(wb["top"] + wb["height"] for wb in all_wb) - else: - x_min = y_min = x_max = y_max = 0 - - new_cells.append({ - "cell_id": cell["cell_id"], - "row_index": ri, - "col_index": span["cols"][0], - "col_type": "spanning_header", - "colspan": len(span["cols"]), - "text": merged_text, - "confidence": cell.get("confidence", 0), - "bbox_px": {"x": x_min, "y": y_min, - "w": x_max - x_min, "h": y_max - y_min}, - "bbox_pct": { - "x": round(x_min / img_w * 100, 2) if img_w else 0, - "y": round(y_min / img_h * 100, 2) if img_h else 0, - "w": round((x_max - x_min) / img_w * 100, 2) if img_w else 0, - "h": round((y_max - y_min) / img_h * 100, 2) if img_h else 0, - }, - "word_boxes": all_wb, - "ocr_engine": cell.get("ocr_engine", ""), - "is_bold": cell.get("is_bold", False), - }) - logger.info( - "colspan detected: row %d, cols %s → merged %d cells (%r)", - ri, span["cols"], len(span["cols"]), merged_text[:50], - ) - break - if not is_part_of_span: - new_cells.append(cell) - - return new_cells - - -def _build_zone_grid( - zone_words: List[Dict], - zone_x: int, - zone_y: int, - zone_w: int, - zone_h: int, - zone_index: int, - img_w: int, - img_h: int, - global_columns: Optional[List[Dict]] = None, - skip_first_row_header: bool = False, -) -> Dict[str, Any]: - """Build columns, rows, cells for a single zone from its words. - - Args: - global_columns: If provided, use these pre-computed column boundaries - instead of detecting columns per zone. Used for content zones so - that all content zones (above/between/below boxes) share the same - column structure. Box zones always detect columns independently. - """ - if not zone_words: - return { - "columns": [], - "rows": [], - "cells": [], - "header_rows": [], - } - - # Cluster rows first (needed for column alignment analysis) - rows = _cluster_rows(zone_words) - - # Diagnostic logging for small/medium zones (box zones typically have 40-60 words) - if len(zone_words) <= 60: - import statistics as _st - _heights = [w['height'] for w in zone_words if w.get('height', 0) > 0] - _med_h = _st.median(_heights) if _heights else 20 - _y_tol = max(_med_h * 0.5, 5) - logger.info( - "zone %d row-clustering: %d words, median_h=%.0f, y_tol=%.1f → %d rows", - zone_index, len(zone_words), _med_h, _y_tol, len(rows), - ) - for w in sorted(zone_words, key=lambda ww: (ww['top'], ww['left'])): - logger.info( - " zone %d word: y=%d x=%d h=%d w=%d '%s'", - zone_index, w['top'], w['left'], w['height'], w['width'], - w.get('text', '')[:40], - ) - for r in rows: - logger.info( - " zone %d row %d: y_min=%d y_max=%d y_center=%.0f", - zone_index, r['index'], r['y_min'], r['y_max'], r['y_center'], - ) - - # Use global columns if provided, otherwise detect per zone - columns = global_columns if global_columns else _cluster_columns_by_alignment(zone_words, zone_w, rows) - - # Merge inline marker columns (bullets, numbering) into adjacent text - if not global_columns: - columns = _merge_inline_marker_columns(columns, zone_words) - - if not columns or not rows: - return { - "columns": [], - "rows": [], - "cells": [], - "header_rows": [], - } - - # Split word boxes that straddle column boundaries (e.g. "sichzie" - # spanning Col 1 + Col 2). Must happen after column detection and - # before cell assignment. - # Keep original words for colspan detection (split destroys span info). - original_zone_words = zone_words - if len(columns) >= 2: - zone_words = _split_cross_column_words(zone_words, columns) - - # Build cells - cells = _build_cells(zone_words, columns, rows, img_w, img_h) - - # --- Detect colspan (merged cells spanning multiple columns) --- - # Uses the ORIGINAL (pre-split) words to detect word-blocks that span - # multiple columns. _split_cross_column_words would have destroyed - # this information by cutting words at column boundaries. - if len(columns) >= 2: - cells = _detect_colspan_cells(original_zone_words, columns, rows, cells, img_w, img_h) - - # Prefix cell IDs with zone index - for cell in cells: - cell["cell_id"] = f"Z{zone_index}_{cell['cell_id']}" - cell["zone_index"] = zone_index - - # Detect header rows (pass columns for spanning header detection) - header_rows = _detect_header_rows(rows, zone_words, zone_y, columns, - skip_first_row_header=skip_first_row_header) - - # Merge cells in spanning header rows into a single col-0 cell - if header_rows and len(columns) >= 2: - for hri in header_rows: - header_cells = [c for c in cells if c["row_index"] == hri] - if len(header_cells) <= 1: - continue - # Collect all word_boxes and text from all columns - all_wb = [] - all_text_parts = [] - for hc in sorted(header_cells, key=lambda c: c["col_index"]): - all_wb.extend(hc.get("word_boxes", [])) - if hc.get("text", "").strip(): - all_text_parts.append(hc["text"].strip()) - # Remove all header cells, replace with one spanning cell - cells = [c for c in cells if c["row_index"] != hri] - if all_wb: - x_min = min(wb["left"] for wb in all_wb) - y_min = min(wb["top"] for wb in all_wb) - x_max = max(wb["left"] + wb["width"] for wb in all_wb) - y_max = max(wb["top"] + wb["height"] for wb in all_wb) - cells.append({ - "cell_id": f"R{hri:02d}_C0", - "row_index": hri, - "col_index": 0, - "col_type": "spanning_header", - "text": " ".join(all_text_parts), - "confidence": 0.0, - "bbox_px": {"x": x_min, "y": y_min, - "w": x_max - x_min, "h": y_max - y_min}, - "bbox_pct": { - "x": round(x_min / img_w * 100, 2) if img_w else 0, - "y": round(y_min / img_h * 100, 2) if img_h else 0, - "w": round((x_max - x_min) / img_w * 100, 2) if img_w else 0, - "h": round((y_max - y_min) / img_h * 100, 2) if img_h else 0, - }, - "word_boxes": all_wb, - "ocr_engine": "words_first", - "is_bold": True, - }) - - # Convert columns to output format with percentages - out_columns = [] - for col in columns: - x_min = col["x_min"] - x_max = col["x_max"] - out_columns.append({ - "index": col["index"], - "label": col["type"], - "x_min_px": round(x_min), - "x_max_px": round(x_max), - "x_min_pct": round(x_min / img_w * 100, 2) if img_w else 0, - "x_max_pct": round(x_max / img_w * 100, 2) if img_w else 0, - "bold": False, - }) - - # Convert rows to output format with percentages - out_rows = [] - for row in rows: - out_rows.append({ - "index": row["index"], - "y_min_px": round(row["y_min"]), - "y_max_px": round(row["y_max"]), - "y_min_pct": round(row["y_min"] / img_h * 100, 2) if img_h else 0, - "y_max_pct": round(row["y_max"] / img_h * 100, 2) if img_h else 0, - "is_header": row["index"] in header_rows, - }) - - return { - "columns": out_columns, - "rows": out_rows, - "cells": cells, - "header_rows": header_rows, - "_raw_columns": columns, # internal: for propagation to other zones - } - - -def _get_content_bounds(words: List[Dict]) -> tuple: - """Get content bounds from word positions.""" - if not words: - return 0, 0, 0, 0 - x_min = min(w["left"] for w in words) - y_min = min(w["top"] for w in words) - x_max = max(w["left"] + w["width"] for w in words) - y_max = max(w["top"] + w["height"] for w in words) - return x_min, y_min, x_max - x_min, y_max - y_min - - -def _filter_decorative_margin( - words: List[Dict], - img_w: int, - log: Any, - session_id: str, -) -> Dict[str, Any]: - """Remove words that belong to a decorative alphabet strip on a margin. - - Some vocabulary worksheets have a vertical A–Z alphabet graphic along - the left or right edge. OCR reads each letter as an isolated single- - character word. These decorative elements are not content and confuse - column/row detection. - - Detection criteria (phase 1 — find the strip using single-char words): - - Words are in the outer 30% of the page (left or right) - - Nearly all words are single characters (letters or digits) - - At least 8 such words form a vertical strip (≥8 unique Y positions) - - Average horizontal spread of the strip is small (< 80px) - - Phase 2 — once a strip is confirmed, also remove any short word (≤3 - chars) in the same narrow x-range. This catches multi-char OCR - artifacts like "Vv" that belong to the same decorative element. - - Modifies *words* in place. - - Returns: - Dict with 'found' (bool), 'side' (str), 'letters_detected' (int). - """ - no_strip: Dict[str, Any] = {"found": False, "side": "", "letters_detected": 0} - if not words or img_w <= 0: - return no_strip - - margin_cutoff = img_w * 0.30 - # Phase 1: find candidate strips using short words (1-2 chars). - # OCR often reads alphabet sidebar letters as pairs ("Aa", "Bb") - # rather than singles, so accept ≤2-char words as strip candidates. - left_strip = [ - w for w in words - if len((w.get("text") or "").strip()) <= 2 - and w["left"] + w.get("width", 0) / 2 < margin_cutoff - ] - right_strip = [ - w for w in words - if len((w.get("text") or "").strip()) <= 2 - and w["left"] + w.get("width", 0) / 2 > img_w - margin_cutoff - ] - - for strip, side in [(left_strip, "left"), (right_strip, "right")]: - if len(strip) < 6: - continue - # Check vertical distribution: should have many distinct Y positions - y_centers = sorted(set( - int(w["top"] + w.get("height", 0) / 2) // 20 * 20 # bucket - for w in strip - )) - if len(y_centers) < 6: - continue - # Check horizontal compactness - x_positions = [w["left"] for w in strip] - x_min = min(x_positions) - x_max = max(x_positions) - x_spread = x_max - x_min - if x_spread > 80: - continue - - # Phase 2: strip confirmed — also collect short words in same x-range - # Expand x-range slightly to catch neighbors (e.g. "Vv" next to "U") - strip_x_lo = x_min - 20 - strip_x_hi = x_max + 60 # word width + tolerance - all_strip_words = [ - w for w in words - if len((w.get("text") or "").strip()) <= 3 - and strip_x_lo <= w["left"] <= strip_x_hi - and (w["left"] + w.get("width", 0) / 2 < margin_cutoff - if side == "left" - else w["left"] + w.get("width", 0) / 2 > img_w - margin_cutoff) - ] - - strip_set = set(id(w) for w in all_strip_words) - before = len(words) - words[:] = [w for w in words if id(w) not in strip_set] - removed = before - len(words) - if removed: - log.info( - "build-grid session %s: removed %d decorative %s-margin words " - "(strip x=%d-%d)", - session_id, removed, side, strip_x_lo, strip_x_hi, - ) - return {"found": True, "side": side, "letters_detected": len(strip)} - - return no_strip - - -def _filter_footer_words( - words: List[Dict], - img_h: int, - log: Any, - session_id: str, -) -> Optional[Dict]: - """Remove isolated words in the bottom 5% of the page (page numbers). - - Modifies *words* in place and returns a page_number metadata dict - if a page number was extracted, or None. - """ - if not words or img_h <= 0: - return None - footer_y = img_h * 0.95 - footer_words = [ - w for w in words - if w["top"] + w.get("height", 0) / 2 > footer_y - ] - if not footer_words: - return None - # Only remove if footer has very few words (≤ 3) with short text - total_text = "".join((w.get("text") or "").strip() for w in footer_words) - if len(footer_words) <= 3 and len(total_text) <= 10: - # Extract page number metadata before removing - page_number_info = { - "text": total_text.strip(), - "y_pct": round(footer_words[0]["top"] / img_h * 100, 1), - } - # Try to parse as integer - digits = "".join(c for c in total_text if c.isdigit()) - if digits: - page_number_info["number"] = int(digits) - - footer_set = set(id(w) for w in footer_words) - words[:] = [w for w in words if id(w) not in footer_set] - log.info( - "build-grid session %s: extracted page number '%s' and removed %d footer words", - session_id, total_text, len(footer_words), - ) - return page_number_info - return None - - -def _filter_header_junk( - words: List[Dict], - img_h: int, - log: Any, - session_id: str, -) -> None: - """Remove OCR junk from header illustrations above the real content. - - Textbook pages often have decorative header graphics (illustrations, - icons) that OCR reads as low-confidence junk characters. Real content - typically starts further down the page. - - Algorithm: - 1. Find the "content start" — the first Y position where a dense - horizontal row of 3+ high-confidence words begins. - 2. Above that line, remove words with conf < 75 and text ≤ 3 chars. - These are almost certainly OCR artifacts from illustrations. - - Modifies *words* in place. - """ - if not words or img_h <= 0: - return - - # --- Find content start: first horizontal row with ≥3 high-conf words --- - # Sort words by Y - sorted_by_y = sorted(words, key=lambda w: w["top"]) - content_start_y = 0 - _ROW_TOLERANCE = img_h * 0.02 # words within 2% of page height = same row - _MIN_ROW_WORDS = 3 - _MIN_CONF = 80 - - i = 0 - while i < len(sorted_by_y): - row_y = sorted_by_y[i]["top"] - # Collect words in this row band - row_words = [] - j = i - while j < len(sorted_by_y) and sorted_by_y[j]["top"] - row_y < _ROW_TOLERANCE: - row_words.append(sorted_by_y[j]) - j += 1 - # Count high-confidence words with real text (> 1 char) - high_conf = [ - w for w in row_words - if w.get("conf", 0) >= _MIN_CONF - and len((w.get("text") or "").strip()) > 1 - ] - if len(high_conf) >= _MIN_ROW_WORDS: - content_start_y = row_y - break - i = j if j > i else i + 1 - - if content_start_y <= 0: - return # no clear content start found - - # --- Remove low-conf short junk above content start --- - junk = [ - w for w in words - if w["top"] + w.get("height", 0) < content_start_y - and w.get("conf", 0) < 75 - and len((w.get("text") or "").strip()) <= 3 - ] - if not junk: - return - - junk_set = set(id(w) for w in junk) - before = len(words) - words[:] = [w for w in words if id(w) not in junk_set] - removed = before - len(words) - if removed: - log.info( - "build-grid session %s: removed %d header junk words above y=%d " - "(content start)", - session_id, removed, content_start_y, - ) - +# --- Re-export: columns --------------------------------------------------- +from grid_editor_columns import ( # noqa: F401 + _is_recognized_word, + _split_cross_column_words, + _cluster_columns_by_alignment, + _MARKER_CHARS, + _merge_inline_marker_columns, +) + +# --- Re-export: filters ---------------------------------------------------- +from grid_editor_filters import ( # noqa: F401 + _filter_border_strip_words, + _GRID_GHOST_CHARS, + _filter_border_ghosts, + _flatten_word_boxes, + _words_in_zone, + _get_content_bounds, + _filter_decorative_margin, + _filter_footer_words, + _filter_header_junk, +) + +# --- Re-export: headers ---------------------------------------------------- +from grid_editor_headers import ( # noqa: F401 + _detect_heading_rows_by_color, + _detect_heading_rows_by_single_cell, + _detect_header_rows, + _detect_colspan_cells, +) + +# --- Re-export: zones ------------------------------------------------------- +from grid_editor_zones import ( # noqa: F401 + _PIPE_RE_VSPLIT, + _detect_vertical_dividers, + _split_zone_at_vertical_dividers, + _merge_content_zones_across_boxes, + _build_zone_grid, +) + +# --- Re-export from cv_words_first (used by cv_box_layout.py) --------------- +from cv_words_first import _cluster_rows # noqa: F401 diff --git a/klausur-service/backend/grid_editor_zones.py b/klausur-service/backend/grid_editor_zones.py new file mode 100644 index 0000000..2640c09 --- /dev/null +++ b/klausur-service/backend/grid_editor_zones.py @@ -0,0 +1,389 @@ +""" +Grid Editor — vertical divider detection, zone splitting/merging, zone grid building. + +Split from grid_editor_helpers.py for maintainability. +All functions are pure computation — no HTTP, DB, or session side effects. + +Lizenz: Apache 2.0 (kommerziell nutzbar) +DATENSCHUTZ: Alle Verarbeitung erfolgt lokal. +""" + +import logging +import re +from typing import Any, Dict, List, Optional + +from cv_vocab_types import PageZone +from cv_words_first import _cluster_rows, _build_cells + +from grid_editor_columns import ( + _cluster_columns_by_alignment, + _merge_inline_marker_columns, + _split_cross_column_words, +) +from grid_editor_headers import ( + _detect_header_rows, + _detect_colspan_cells, +) + +logger = logging.getLogger(__name__) + + +# --------------------------------------------------------------------------- +# Vertical divider detection and zone splitting +# --------------------------------------------------------------------------- + +_PIPE_RE_VSPLIT = re.compile(r"^\|+$") + + +def _detect_vertical_dividers( + words: List[Dict], + zone_x: int, + zone_w: int, + zone_y: int, + zone_h: int, +) -> List[float]: + """Detect vertical divider lines from pipe word_boxes at consistent x. + + Returns list of divider x-positions (empty if no dividers found). + """ + if not words or zone_w <= 0 or zone_h <= 0: + return [] + + # Collect pipe word_boxes + pipes = [ + w for w in words + if _PIPE_RE_VSPLIT.match((w.get("text") or "").strip()) + ] + if len(pipes) < 5: + return [] + + # Cluster pipe x-centers by proximity + tolerance = max(15, int(zone_w * 0.02)) + pipe_xs = sorted(w["left"] + w["width"] / 2 for w in pipes) + + clusters: List[List[float]] = [[pipe_xs[0]]] + for x in pipe_xs[1:]: + if x - clusters[-1][-1] <= tolerance: + clusters[-1].append(x) + else: + clusters.append([x]) + + dividers: List[float] = [] + for cluster in clusters: + if len(cluster) < 5: + continue + mean_x = sum(cluster) / len(cluster) + # Must be between 15% and 85% of zone width + rel_pos = (mean_x - zone_x) / zone_w + if rel_pos < 0.15 or rel_pos > 0.85: + continue + # Check vertical coverage: pipes must span >= 50% of zone height + cluster_pipes = [ + w for w in pipes + if abs(w["left"] + w["width"] / 2 - mean_x) <= tolerance + ] + ys = [w["top"] for w in cluster_pipes] + [w["top"] + w["height"] for w in cluster_pipes] + y_span = max(ys) - min(ys) if ys else 0 + if y_span < zone_h * 0.5: + continue + dividers.append(mean_x) + + return sorted(dividers) + + +def _split_zone_at_vertical_dividers( + zone: "PageZone", + divider_xs: List[float], + vsplit_group_id: int, +) -> List["PageZone"]: + """Split a PageZone at vertical divider positions into sub-zones.""" + boundaries = [zone.x] + divider_xs + [zone.x + zone.width] + hints = [] + for i in range(len(boundaries) - 1): + if i == 0: + hints.append("left_of_vsplit") + elif i == len(boundaries) - 2: + hints.append("right_of_vsplit") + else: + hints.append("middle_of_vsplit") + + sub_zones = [] + for i in range(len(boundaries) - 1): + x_start = int(boundaries[i]) + x_end = int(boundaries[i + 1]) + sub = PageZone( + index=0, # re-indexed later + zone_type=zone.zone_type, + y=zone.y, + height=zone.height, + x=x_start, + width=x_end - x_start, + box=zone.box, + image_overlays=zone.image_overlays, + layout_hint=hints[i], + vsplit_group=vsplit_group_id, + ) + sub_zones.append(sub) + + return sub_zones + + +def _merge_content_zones_across_boxes( + zones: List, + content_x: int, + content_w: int, +) -> List: + """Merge content zones separated by box zones into single zones. + + Box zones become image_overlays on the merged content zone. + Pattern: [content, box*, content] -> [merged_content with overlay] + Box zones NOT between two content zones stay as standalone zones. + """ + if len(zones) < 3: + return zones + + # Group consecutive runs of [content, box+, content] + result: List = [] + i = 0 + while i < len(zones): + z = zones[i] + if z.zone_type != "content": + result.append(z) + i += 1 + continue + + # Start of a potential merge group: content zone + group_contents = [z] + group_boxes = [] + j = i + 1 + # Absorb [box, content] pairs -- only absorb a box if it's + # confirmed to be followed by another content zone. + while j < len(zones): + if (zones[j].zone_type == "box" + and j + 1 < len(zones) + and zones[j + 1].zone_type == "content"): + group_boxes.append(zones[j]) + group_contents.append(zones[j + 1]) + j += 2 + else: + break + + if len(group_contents) >= 2 and group_boxes: + # Merge: create one large content zone spanning all + y_min = min(c.y for c in group_contents) + y_max = max(c.y + c.height for c in group_contents) + overlays = [] + for bz in group_boxes: + overlay = { + "y": bz.y, + "height": bz.height, + "x": bz.x, + "width": bz.width, + } + if bz.box: + overlay["box"] = { + "x": bz.box.x, + "y": bz.box.y, + "width": bz.box.width, + "height": bz.box.height, + "confidence": bz.box.confidence, + "border_thickness": bz.box.border_thickness, + } + overlays.append(overlay) + + merged = PageZone( + index=0, # re-indexed below + zone_type="content", + y=y_min, + height=y_max - y_min, + x=content_x, + width=content_w, + image_overlays=overlays, + ) + result.append(merged) + i = j + else: + # No merge possible -- emit just the content zone + result.append(z) + i += 1 + + # Re-index zones + for idx, z in enumerate(result): + z.index = idx + + logger.info( + "zone-merge: %d zones -> %d zones after merging across boxes", + len(zones), len(result), + ) + return result + + +def _build_zone_grid( + zone_words: List[Dict], + zone_x: int, + zone_y: int, + zone_w: int, + zone_h: int, + zone_index: int, + img_w: int, + img_h: int, + global_columns: Optional[List[Dict]] = None, + skip_first_row_header: bool = False, +) -> Dict[str, Any]: + """Build columns, rows, cells for a single zone from its words. + + Args: + global_columns: If provided, use these pre-computed column boundaries + instead of detecting columns per zone. Used for content zones so + that all content zones (above/between/below boxes) share the same + column structure. Box zones always detect columns independently. + """ + if not zone_words: + return { + "columns": [], + "rows": [], + "cells": [], + "header_rows": [], + } + + # Cluster rows first (needed for column alignment analysis) + rows = _cluster_rows(zone_words) + + # Diagnostic logging for small/medium zones (box zones typically have 40-60 words) + if len(zone_words) <= 60: + import statistics as _st + _heights = [w['height'] for w in zone_words if w.get('height', 0) > 0] + _med_h = _st.median(_heights) if _heights else 20 + _y_tol = max(_med_h * 0.5, 5) + logger.info( + "zone %d row-clustering: %d words, median_h=%.0f, y_tol=%.1f -> %d rows", + zone_index, len(zone_words), _med_h, _y_tol, len(rows), + ) + for w in sorted(zone_words, key=lambda ww: (ww['top'], ww['left'])): + logger.info( + " zone %d word: y=%d x=%d h=%d w=%d '%s'", + zone_index, w['top'], w['left'], w['height'], w['width'], + w.get('text', '')[:40], + ) + for r in rows: + logger.info( + " zone %d row %d: y_min=%d y_max=%d y_center=%.0f", + zone_index, r['index'], r['y_min'], r['y_max'], r['y_center'], + ) + + # Use global columns if provided, otherwise detect per zone + columns = global_columns if global_columns else _cluster_columns_by_alignment(zone_words, zone_w, rows) + + # Merge inline marker columns (bullets, numbering) into adjacent text + if not global_columns: + columns = _merge_inline_marker_columns(columns, zone_words) + + if not columns or not rows: + return { + "columns": [], + "rows": [], + "cells": [], + "header_rows": [], + } + + # Split word boxes that straddle column boundaries (e.g. "sichzie" + # spanning Col 1 + Col 2). Must happen after column detection and + # before cell assignment. + # Keep original words for colspan detection (split destroys span info). + original_zone_words = zone_words + if len(columns) >= 2: + zone_words = _split_cross_column_words(zone_words, columns) + + # Build cells + cells = _build_cells(zone_words, columns, rows, img_w, img_h) + + # --- Detect colspan (merged cells spanning multiple columns) --- + # Uses the ORIGINAL (pre-split) words to detect word-blocks that span + # multiple columns. _split_cross_column_words would have destroyed + # this information by cutting words at column boundaries. + if len(columns) >= 2: + cells = _detect_colspan_cells(original_zone_words, columns, rows, cells, img_w, img_h) + + # Prefix cell IDs with zone index + for cell in cells: + cell["cell_id"] = f"Z{zone_index}_{cell['cell_id']}" + cell["zone_index"] = zone_index + + # Detect header rows (pass columns for spanning header detection) + header_rows = _detect_header_rows(rows, zone_words, zone_y, columns, + skip_first_row_header=skip_first_row_header) + + # Merge cells in spanning header rows into a single col-0 cell + if header_rows and len(columns) >= 2: + for hri in header_rows: + header_cells = [c for c in cells if c["row_index"] == hri] + if len(header_cells) <= 1: + continue + # Collect all word_boxes and text from all columns + all_wb = [] + all_text_parts = [] + for hc in sorted(header_cells, key=lambda c: c["col_index"]): + all_wb.extend(hc.get("word_boxes", [])) + if hc.get("text", "").strip(): + all_text_parts.append(hc["text"].strip()) + # Remove all header cells, replace with one spanning cell + cells = [c for c in cells if c["row_index"] != hri] + if all_wb: + x_min = min(wb["left"] for wb in all_wb) + y_min = min(wb["top"] for wb in all_wb) + x_max = max(wb["left"] + wb["width"] for wb in all_wb) + y_max = max(wb["top"] + wb["height"] for wb in all_wb) + cells.append({ + "cell_id": f"R{hri:02d}_C0", + "row_index": hri, + "col_index": 0, + "col_type": "spanning_header", + "text": " ".join(all_text_parts), + "confidence": 0.0, + "bbox_px": {"x": x_min, "y": y_min, + "w": x_max - x_min, "h": y_max - y_min}, + "bbox_pct": { + "x": round(x_min / img_w * 100, 2) if img_w else 0, + "y": round(y_min / img_h * 100, 2) if img_h else 0, + "w": round((x_max - x_min) / img_w * 100, 2) if img_w else 0, + "h": round((y_max - y_min) / img_h * 100, 2) if img_h else 0, + }, + "word_boxes": all_wb, + "ocr_engine": "words_first", + "is_bold": True, + }) + + # Convert columns to output format with percentages + out_columns = [] + for col in columns: + x_min = col["x_min"] + x_max = col["x_max"] + out_columns.append({ + "index": col["index"], + "label": col["type"], + "x_min_px": round(x_min), + "x_max_px": round(x_max), + "x_min_pct": round(x_min / img_w * 100, 2) if img_w else 0, + "x_max_pct": round(x_max / img_w * 100, 2) if img_w else 0, + "bold": False, + }) + + # Convert rows to output format with percentages + out_rows = [] + for row in rows: + out_rows.append({ + "index": row["index"], + "y_min_px": round(row["y_min"]), + "y_max_px": round(row["y_max"]), + "y_min_pct": round(row["y_min"] / img_h * 100, 2) if img_h else 0, + "y_max_pct": round(row["y_max"] / img_h * 100, 2) if img_h else 0, + "is_header": row["index"] in header_rows, + }) + + return { + "columns": out_columns, + "rows": out_rows, + "cells": cells, + "header_rows": header_rows, + "_raw_columns": columns, # internal: for propagation to other zones + } diff --git a/klausur-service/backend/legal_corpus_chunking.py b/klausur-service/backend/legal_corpus_chunking.py new file mode 100644 index 0000000..b3c47fc --- /dev/null +++ b/klausur-service/backend/legal_corpus_chunking.py @@ -0,0 +1,197 @@ +""" +Legal Corpus Chunking — Text splitting, semantic chunking, and HTML-to-text conversion. + +Provides German-aware sentence splitting, paragraph splitting, semantic chunking +with overlap, and HTML-to-text conversion for legal document ingestion. +""" + +import re +from typing import Dict, List, Optional, Tuple + + +# German abbreviations that don't end sentences +GERMAN_ABBREVIATIONS = { + 'bzw', 'ca', 'chr', 'd.h', 'dr', 'etc', 'evtl', 'ggf', 'inkl', 'max', + 'min', 'mio', 'mrd', 'nr', 'prof', 's', 'sog', 'u.a', 'u.ä', 'usw', + 'v.a', 'vgl', 'vs', 'z.b', 'z.t', 'zzgl', 'abs', 'art', 'aufl', + 'bd', 'betr', 'bzgl', 'dgl', 'ebd', 'hrsg', 'jg', 'kap', 'lt', + 'rdnr', 'rn', 'std', 'str', 'tel', 'ua', 'uvm', 'va', 'zb', + 'bsi', 'tr', 'owasp', 'iso', 'iec', 'din', 'en' +} + + +def split_into_sentences(text: str) -> List[str]: + """Split text into sentences with German language support.""" + if not text: + return [] + + text = re.sub(r'\s+', ' ', text).strip() + + # Protect abbreviations + protected_text = text + for abbrev in GERMAN_ABBREVIATIONS: + pattern = re.compile(r'\b' + re.escape(abbrev) + r'\.', re.IGNORECASE) + protected_text = pattern.sub(abbrev.replace('.', '') + '', protected_text) + + # Protect decimal/ordinal numbers and requirement IDs (e.g., "O.Data_1") + protected_text = re.sub(r'(\d)\.(\d)', r'\1\2', protected_text) + protected_text = re.sub(r'(\d+)\.(\s)', r'\1\2', protected_text) + protected_text = re.sub(r'([A-Z])\.([A-Z])', r'\1\2', protected_text) # O.Data_1 + + # Split on sentence endings + sentence_pattern = r'(?<=[.!?])\s+(?=[A-ZÄÖÜ0-9])|(?<=[.!?])$' + raw_sentences = re.split(sentence_pattern, protected_text) + + # Restore protected characters + sentences = [] + for s in raw_sentences: + s = s.replace('', '.').replace('', '.').replace('', '.').replace('', '.').replace('', '.') + s = s.strip() + if s: + sentences.append(s) + + return sentences + + +def split_into_paragraphs(text: str) -> List[str]: + """Split text into paragraphs.""" + if not text: + return [] + + raw_paragraphs = re.split(r'\n\s*\n', text) + return [para.strip() for para in raw_paragraphs if para.strip()] + + +def chunk_text_semantic( + text: str, + chunk_size: int = 1000, + overlap: int = 200, +) -> List[Tuple[str, int]]: + """ + Semantic chunking that respects paragraph and sentence boundaries. + Matches NIBIS chunking strategy for consistency. + + Returns list of (chunk_text, start_position) tuples. + """ + if not text: + return [] + + if len(text) <= chunk_size: + return [(text.strip(), 0)] + + paragraphs = split_into_paragraphs(text) + overlap_sentences = max(1, overlap // 100) # Convert char overlap to sentence overlap + + chunks = [] + current_chunk_parts: List[str] = [] + current_chunk_length = 0 + chunk_start = 0 + position = 0 + + for para in paragraphs: + if len(para) > chunk_size: + # Large paragraph: split into sentences + sentences = split_into_sentences(para) + + for sentence in sentences: + sentence_len = len(sentence) + + if sentence_len > chunk_size: + # Very long sentence: save current chunk first + if current_chunk_parts: + chunk_text = ' '.join(current_chunk_parts) + chunks.append((chunk_text, chunk_start)) + overlap_buffer = current_chunk_parts[-overlap_sentences:] if overlap_sentences > 0 else [] + current_chunk_parts = list(overlap_buffer) + current_chunk_length = sum(len(s) + 1 for s in current_chunk_parts) + + # Add long sentence as its own chunk + chunks.append((sentence, position)) + current_chunk_parts = [sentence] + current_chunk_length = len(sentence) + 1 + position += sentence_len + 1 + continue + + if current_chunk_length + sentence_len + 1 > chunk_size and current_chunk_parts: + # Current chunk is full, save it + chunk_text = ' '.join(current_chunk_parts) + chunks.append((chunk_text, chunk_start)) + overlap_buffer = current_chunk_parts[-overlap_sentences:] if overlap_sentences > 0 else [] + current_chunk_parts = list(overlap_buffer) + current_chunk_length = sum(len(s) + 1 for s in current_chunk_parts) + chunk_start = position - current_chunk_length + + current_chunk_parts.append(sentence) + current_chunk_length += sentence_len + 1 + position += sentence_len + 1 + else: + # Small paragraph: try to keep together + para_len = len(para) + if current_chunk_length + para_len + 2 > chunk_size and current_chunk_parts: + chunk_text = ' '.join(current_chunk_parts) + chunks.append((chunk_text, chunk_start)) + last_para_sentences = split_into_sentences(current_chunk_parts[-1] if current_chunk_parts else "") + overlap_buffer = last_para_sentences[-overlap_sentences:] if overlap_sentences > 0 and last_para_sentences else [] + current_chunk_parts = list(overlap_buffer) + current_chunk_length = sum(len(s) + 1 for s in current_chunk_parts) + chunk_start = position - current_chunk_length + + if current_chunk_parts: + current_chunk_parts.append(para) + current_chunk_length += para_len + 2 + else: + current_chunk_parts = [para] + current_chunk_length = para_len + chunk_start = position + + position += para_len + 2 + + # Don't forget the last chunk + if current_chunk_parts: + chunk_text = ' '.join(current_chunk_parts) + chunks.append((chunk_text, chunk_start)) + + # Clean up whitespace + return [(re.sub(r'\s+', ' ', c).strip(), pos) for c, pos in chunks if c.strip()] + + +def extract_article_info(text: str) -> Optional[Dict]: + """Extract article number and paragraph from text.""" + # Pattern for "Artikel X" or "Art. X" + article_match = re.search(r'(?:Artikel|Art\.?)\s+(\d+)', text) + paragraph_match = re.search(r'(?:Absatz|Abs\.?)\s+(\d+)', text) + + if article_match: + return { + "article": article_match.group(1), + "paragraph": paragraph_match.group(1) if paragraph_match else None, + } + return None + + +def html_to_text(html_content: str) -> str: + """Convert HTML to clean text.""" + # Remove script and style tags + html_content = re.sub(r']*>.*?', '', html_content, flags=re.DOTALL) + html_content = re.sub(r']*>.*?', '', html_content, flags=re.DOTALL) + # Remove comments + html_content = re.sub(r'', '', html_content, flags=re.DOTALL) + # Replace common HTML entities + html_content = html_content.replace(' ', ' ') + html_content = html_content.replace('&', '&') + html_content = html_content.replace('<', '<') + html_content = html_content.replace('>', '>') + html_content = html_content.replace('"', '"') + # Convert breaks and paragraphs to newlines for better chunking + html_content = re.sub(r'', '\n', html_content, flags=re.IGNORECASE) + html_content = re.sub(r'

', '\n\n', html_content, flags=re.IGNORECASE) + html_content = re.sub(r'', '\n', html_content, flags=re.IGNORECASE) + html_content = re.sub(r'', '\n\n', html_content, flags=re.IGNORECASE) + # Remove remaining HTML tags + text = re.sub(r'<[^>]+>', ' ', html_content) + # Clean up whitespace (but preserve paragraph breaks) + text = re.sub(r'[ \t]+', ' ', text) + text = re.sub(r'\n[ \t]+', '\n', text) + text = re.sub(r'[ \t]+\n', '\n', text) + text = re.sub(r'\n{3,}', '\n\n', text) + return text.strip() diff --git a/klausur-service/backend/legal_corpus_ingestion.py b/klausur-service/backend/legal_corpus_ingestion.py index 62b5e55..235b49c 100644 --- a/klausur-service/backend/legal_corpus_ingestion.py +++ b/klausur-service/backend/legal_corpus_ingestion.py @@ -8,6 +8,10 @@ Includes EU regulations, DACH national laws, and EDPB guidelines. Collections: - bp_legal_corpus: All regulation texts (GDPR, AI Act, CRA, BSI, etc.) +Split modules: +- legal_corpus_registry: Regulation dataclass + REGULATIONS list (pure data) +- legal_corpus_chunking: Sentence/paragraph splitting, semantic chunking, HTML-to-text + Usage: python legal_corpus_ingestion.py --ingest-all python legal_corpus_ingestion.py --ingest GDPR AIACT @@ -19,11 +23,9 @@ import hashlib import json import logging import os -import re -from dataclasses import dataclass, field from datetime import datetime from pathlib import Path -from typing import Dict, List, Optional, Tuple +from typing import Dict, List, Optional from urllib.parse import urlparse import httpx @@ -37,6 +39,17 @@ from qdrant_client.models import ( VectorParams, ) +# Re-export for backward compatibility +from legal_corpus_registry import Regulation, REGULATIONS # noqa: F401 +from legal_corpus_chunking import ( # noqa: F401 + chunk_text_semantic, + extract_article_info, + html_to_text, + split_into_sentences, + split_into_paragraphs, + GERMAN_ABBREVIATIONS, +) + # Configure logging logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) @@ -44,8 +57,6 @@ logger = logging.getLogger(__name__) # Configuration - Support both QDRANT_URL and QDRANT_HOST/PORT _qdrant_url = os.getenv("QDRANT_URL", "") if _qdrant_url: - # Parse URL: http://qdrant:6333 -> host=qdrant, port=6333 - from urllib.parse import urlparse _parsed = urlparse(_qdrant_url) QDRANT_HOST = _parsed.hostname or "localhost" QDRANT_PORT = _parsed.port or 6333 @@ -61,614 +72,12 @@ CHUNK_SIZE = int(os.getenv("LEGAL_CHUNK_SIZE", "1000")) CHUNK_OVERLAP = int(os.getenv("LEGAL_CHUNK_OVERLAP", "200")) # Base path for local PDF/HTML files -# In Docker: /app/docs/legal_corpus (mounted volume) -# Local dev: relative to script location _default_docs_path = Path(__file__).parent.parent / "docs" / "legal_corpus" LEGAL_DOCS_PATH = Path(os.getenv("LEGAL_DOCS_PATH", str(_default_docs_path))) -# Docker-specific override: if /app/docs exists, use it if Path("/app/docs/legal_corpus").exists(): LEGAL_DOCS_PATH = Path("/app/docs/legal_corpus") -@dataclass -class Regulation: - """Regulation metadata.""" - code: str - name: str - full_name: str - regulation_type: str - source_url: str - description: str - celex: Optional[str] = None # CELEX number for EUR-Lex direct access - local_path: Optional[str] = None - language: str = "de" - requirement_count: int = 0 - - -# All regulations from Compliance Hub (EU + DACH national laws + guidelines) -REGULATIONS: List[Regulation] = [ - Regulation( - code="GDPR", - name="DSGVO", - full_name="Verordnung (EU) 2016/679 - Datenschutz-Grundverordnung", - regulation_type="eu_regulation", - source_url="https://eur-lex.europa.eu/eli/reg/2016/679/oj/deu", - description="Grundverordnung zum Schutz natuerlicher Personen bei der Verarbeitung personenbezogener Daten.", - celex="32016R0679", - requirement_count=99, - ), - Regulation( - code="EPRIVACY", - name="ePrivacy-Richtlinie", - full_name="Richtlinie 2002/58/EG", - regulation_type="eu_directive", - source_url="https://eur-lex.europa.eu/eli/dir/2002/58/oj/deu", - description="Datenschutz in der elektronischen Kommunikation, Cookies und Tracking.", - celex="32002L0058", - requirement_count=25, - ), - Regulation( - code="TDDDG", - name="TDDDG", - full_name="Telekommunikation-Digitale-Dienste-Datenschutz-Gesetz", - regulation_type="de_law", - source_url="https://www.gesetze-im-internet.de/ttdsg/TDDDG.pdf", - description="Deutsche Umsetzung der ePrivacy-Richtlinie (30 Paragraphen).", - requirement_count=30, - ), - Regulation( - code="SCC", - name="Standardvertragsklauseln", - full_name="Durchfuehrungsbeschluss (EU) 2021/914", - regulation_type="eu_regulation", - source_url="https://eur-lex.europa.eu/eli/dec_impl/2021/914/oj/deu", - description="Standardvertragsklauseln fuer Drittlandtransfers.", - celex="32021D0914", - requirement_count=18, - ), - Regulation( - code="DPF", - name="EU-US Data Privacy Framework", - full_name="Durchfuehrungsbeschluss (EU) 2023/1795", - regulation_type="eu_regulation", - source_url="https://eur-lex.europa.eu/eli/dec_impl/2023/1795/oj", - description="Angemessenheitsbeschluss fuer USA-Transfers.", - celex="32023D1795", - requirement_count=12, - ), - Regulation( - code="AIACT", - name="EU AI Act", - full_name="Verordnung (EU) 2024/1689 - KI-Verordnung", - regulation_type="eu_regulation", - source_url="https://eur-lex.europa.eu/eli/reg/2024/1689/oj/deu", - description="EU-Verordnung zur Regulierung von KI-Systemen nach Risikostufen.", - celex="32024R1689", - requirement_count=85, - ), - Regulation( - code="CRA", - name="Cyber Resilience Act", - full_name="Verordnung (EU) 2024/2847", - regulation_type="eu_regulation", - source_url="https://eur-lex.europa.eu/eli/reg/2024/2847/oj/deu", - description="Cybersicherheitsanforderungen, SBOM-Pflicht.", - celex="32024R2847", - requirement_count=45, - ), - Regulation( - code="NIS2", - name="NIS2-Richtlinie", - full_name="Richtlinie (EU) 2022/2555", - regulation_type="eu_directive", - source_url="https://eur-lex.europa.eu/eli/dir/2022/2555/oj/deu", - description="Cybersicherheit fuer wesentliche Einrichtungen.", - celex="32022L2555", - requirement_count=46, - ), - Regulation( - code="EUCSA", - name="EU Cybersecurity Act", - full_name="Verordnung (EU) 2019/881", - regulation_type="eu_regulation", - source_url="https://eur-lex.europa.eu/eli/reg/2019/881/oj/deu", - description="ENISA und Cybersicherheitszertifizierung.", - celex="32019R0881", - requirement_count=35, - ), - Regulation( - code="DATAACT", - name="Data Act", - full_name="Verordnung (EU) 2023/2854", - regulation_type="eu_regulation", - source_url="https://eur-lex.europa.eu/eli/reg/2023/2854/oj/deu", - description="Fairer Datenzugang, IoT-Daten, Cloud-Wechsel.", - celex="32023R2854", - requirement_count=42, - ), - Regulation( - code="DGA", - name="Data Governance Act", - full_name="Verordnung (EU) 2022/868", - regulation_type="eu_regulation", - source_url="https://eur-lex.europa.eu/eli/reg/2022/868/oj/deu", - description="Weiterverwendung oeffentlicher Daten.", - celex="32022R0868", - requirement_count=35, - ), - Regulation( - code="DSA", - name="Digital Services Act", - full_name="Verordnung (EU) 2022/2065", - regulation_type="eu_regulation", - source_url="https://eur-lex.europa.eu/eli/reg/2022/2065/oj/deu", - description="Digitale Dienste, Transparenzpflichten.", - celex="32022R2065", - requirement_count=93, - ), - Regulation( - code="EAA", - name="European Accessibility Act", - full_name="Richtlinie (EU) 2019/882", - regulation_type="eu_directive", - source_url="https://eur-lex.europa.eu/eli/dir/2019/882/oj/deu", - description="Barrierefreiheit digitaler Produkte.", - celex="32019L0882", - requirement_count=25, - ), - Regulation( - code="DSM", - name="DSM-Urheberrechtsrichtlinie", - full_name="Richtlinie (EU) 2019/790", - regulation_type="eu_directive", - source_url="https://eur-lex.europa.eu/eli/dir/2019/790/oj/deu", - description="Urheberrecht, Text- und Data-Mining.", - celex="32019L0790", - requirement_count=22, - ), - Regulation( - code="PLD", - name="Produkthaftungsrichtlinie", - full_name="Richtlinie (EU) 2024/2853", - regulation_type="eu_directive", - source_url="https://eur-lex.europa.eu/eli/dir/2024/2853/oj/deu", - description="Produkthaftung inkl. Software und KI.", - celex="32024L2853", - requirement_count=18, - ), - Regulation( - code="GPSR", - name="General Product Safety", - full_name="Verordnung (EU) 2023/988", - regulation_type="eu_regulation", - source_url="https://eur-lex.europa.eu/eli/reg/2023/988/oj/deu", - description="Allgemeine Produktsicherheit.", - celex="32023R0988", - requirement_count=30, - ), - Regulation( - code="BSI-TR-03161-1", - name="BSI-TR-03161 Teil 1", - full_name="BSI Technische Richtlinie - Allgemeine Anforderungen", - regulation_type="bsi_standard", - source_url="https://www.bsi.bund.de/SharedDocs/Downloads/DE/BSI/Publikationen/TechnischeRichtlinien/TR03161/BSI-TR-03161-1.pdf?__blob=publicationFile&v=6", - description="Allgemeine Sicherheitsanforderungen (45 Pruefaspekte).", - requirement_count=45, - ), - Regulation( - code="BSI-TR-03161-2", - name="BSI-TR-03161 Teil 2", - full_name="BSI Technische Richtlinie - Web-Anwendungen", - regulation_type="bsi_standard", - source_url="https://www.bsi.bund.de/SharedDocs/Downloads/DE/BSI/Publikationen/TechnischeRichtlinien/TR03161/BSI-TR-03161-2.pdf?__blob=publicationFile&v=5", - description="Web-Sicherheit (40 Pruefaspekte).", - requirement_count=40, - ), - Regulation( - code="BSI-TR-03161-3", - name="BSI-TR-03161 Teil 3", - full_name="BSI Technische Richtlinie - Hintergrundsysteme", - regulation_type="bsi_standard", - source_url="https://www.bsi.bund.de/SharedDocs/Downloads/DE/BSI/Publikationen/TechnischeRichtlinien/TR03161/BSI-TR-03161-3.pdf?__blob=publicationFile&v=5", - description="Backend-Sicherheit (35 Pruefaspekte).", - requirement_count=35, - ), - # Additional regulations for financial sector and health - Regulation( - code="DORA", - name="DORA", - full_name="Verordnung (EU) 2022/2554 - Digital Operational Resilience Act", - regulation_type="eu_regulation", - source_url="https://eur-lex.europa.eu/eli/reg/2022/2554/oj/deu", - description="Digitale operationale Resilienz fuer den Finanzsektor. IKT-Risikomanagement, Vorfallmeldung, Resilienz-Tests.", - celex="32022R2554", - requirement_count=64, - ), - Regulation( - code="PSD2", - name="PSD2", - full_name="Richtlinie (EU) 2015/2366 - Zahlungsdiensterichtlinie", - regulation_type="eu_directive", - source_url="https://eur-lex.europa.eu/eli/dir/2015/2366/oj/deu", - description="Zahlungsdienste im Binnenmarkt. Starke Kundenauthentifizierung, Open Banking APIs.", - celex="32015L2366", - requirement_count=117, - ), - Regulation( - code="AMLR", - name="AML-Verordnung", - full_name="Verordnung (EU) 2024/1624 - Geldwaeschebekaempfung", - regulation_type="eu_regulation", - source_url="https://eur-lex.europa.eu/eli/reg/2024/1624/oj/deu", - description="Verhinderung der Nutzung des Finanzsystems zur Geldwaesche und Terrorismusfinanzierung.", - celex="32024R1624", - requirement_count=89, - ), - Regulation( - code="EHDS", - name="EHDS", - full_name="Verordnung (EU) 2025/327 - Europaeischer Gesundheitsdatenraum", - regulation_type="eu_regulation", - source_url="https://eur-lex.europa.eu/eli/reg/2025/327/oj/deu", - description="Europaeischer Raum fuer Gesundheitsdaten. Primaer- und Sekundaernutzung von Gesundheitsdaten.", - celex="32025R0327", - requirement_count=95, - ), - Regulation( - code="MiCA", - name="MiCA", - full_name="Verordnung (EU) 2023/1114 - Markets in Crypto-Assets", - regulation_type="eu_regulation", - source_url="https://eur-lex.europa.eu/eli/reg/2023/1114/oj/deu", - description="Regulierung von Kryptowerten, Stablecoins und Crypto-Asset-Dienstleistern.", - celex="32023R1114", - requirement_count=149, - ), - # ===================================================================== - # DACH National Laws — Deutschland (P1) - # ===================================================================== - Regulation( - code="DE_DDG", - name="Digitale-Dienste-Gesetz", - full_name="Digitale-Dienste-Gesetz (DDG)", - regulation_type="de_law", - source_url="https://www.gesetze-im-internet.de/ddg/", - description="Deutsches Umsetzungsgesetz zum DSA. Regelt Impressumspflicht (§5), Informationspflichten fuer digitale Dienste und Cookies.", - requirement_count=30, - ), - Regulation( - code="DE_BGB_AGB", - name="BGB AGB-Recht", - full_name="BGB §§305-310, 312-312k — AGB und Fernabsatz", - regulation_type="de_law", - source_url="https://www.gesetze-im-internet.de/bgb/", - description="Deutsches AGB-Recht (§§305-310 BGB) und Fernabsatzrecht (§§312-312k BGB). Klauselverbote, Inhaltskontrolle, Widerrufsrecht, Button-Loesung.", - local_path="DE_BGB_AGB.txt", - requirement_count=40, - ), - Regulation( - code="DE_EGBGB", - name="EGBGB Art. 246-248", - full_name="Einfuehrungsgesetz zum BGB — Informationspflichten", - regulation_type="de_law", - source_url="https://www.gesetze-im-internet.de/bgbeg/", - description="Informationspflichten bei Verbrauchervertraegen (Art. 246), Fernabsatz (Art. 246a), E-Commerce (Art. 246c).", - local_path="DE_EGBGB.txt", - requirement_count=20, - ), - Regulation( - code="DE_UWG", - name="UWG Deutschland", - full_name="Gesetz gegen den unlauteren Wettbewerb (UWG)", - regulation_type="de_law", - source_url="https://www.gesetze-im-internet.de/uwg_2004/", - description="Unlauterer Wettbewerb: irrefuehrende Werbung, Spam-Verbot, Preisangaben, Online-Marketing-Regeln.", - requirement_count=25, - ), - Regulation( - code="DE_HGB_RET", - name="HGB Aufbewahrung", - full_name="HGB §§238-261, 257 — Handelsbuecher und Aufbewahrungsfristen", - regulation_type="de_law", - source_url="https://www.gesetze-im-internet.de/hgb/", - description="Buchfuehrungspflicht, Aufbewahrungsfristen 6/10 Jahre, Anforderungen an elektronische Aufbewahrung.", - local_path="DE_HGB_RET.txt", - requirement_count=15, - ), - Regulation( - code="DE_AO_RET", - name="AO Aufbewahrung", - full_name="Abgabenordnung §§140-148 — Steuerliche Aufbewahrungspflichten", - regulation_type="de_law", - source_url="https://www.gesetze-im-internet.de/ao_1977/", - description="Steuerliche Buchfuehrungs- und Aufbewahrungspflichten. 6/10 Jahre Fristen, Datenzugriff durch Finanzbehoerden.", - local_path="DE_AO_RET.txt", - requirement_count=12, - ), - Regulation( - code="DE_TKG", - name="TKG 2021", - full_name="Telekommunikationsgesetz 2021", - regulation_type="de_law", - source_url="https://www.gesetze-im-internet.de/tkg_2021/", - description="Telekommunikationsregulierung: Kundenschutz, Datenschutz, Vertragslaufzeiten, Netzinfrastruktur.", - requirement_count=45, - ), - # ===================================================================== - # DACH National Laws — Oesterreich (P1) - # ===================================================================== - Regulation( - code="AT_ECG", - name="E-Commerce-Gesetz AT", - full_name="E-Commerce-Gesetz (ECG) Oesterreich", - regulation_type="at_law", - source_url="https://www.ris.bka.gv.at/GeltendeFassung.wxe?Abfrage=Bundesnormen&Gesetzesnummer=20001703", - description="Oesterreichisches E-Commerce-Gesetz: Impressum/Offenlegungspflicht (§5), Informationspflichten, Haftung von Diensteanbietern.", - language="de", - requirement_count=30, - ), - Regulation( - code="AT_TKG", - name="TKG 2021 AT", - full_name="Telekommunikationsgesetz 2021 Oesterreich", - regulation_type="at_law", - source_url="https://www.ris.bka.gv.at/GeltendeFassung.wxe?Abfrage=Bundesnormen&Gesetzesnummer=20011678", - description="Oesterreichisches TKG: Cookie-Bestimmungen (§165), Kommunikationsgeheimnis, Endgeraetezugriff.", - language="de", - requirement_count=40, - ), - Regulation( - code="AT_KSCHG", - name="KSchG Oesterreich", - full_name="Konsumentenschutzgesetz (KSchG) Oesterreich", - regulation_type="at_law", - source_url="https://www.ris.bka.gv.at/GeltendeFassung.wxe?Abfrage=Bundesnormen&Gesetzesnummer=10002462", - description="Konsumentenschutz: AGB-Kontrolle (§6 Klauselverbote, §9 Verbandsklage), Ruecktrittsrecht, Informationspflichten.", - language="de", - requirement_count=35, - ), - Regulation( - code="AT_FAGG", - name="FAGG Oesterreich", - full_name="Fern- und Auswaertsgeschaefte-Gesetz (FAGG) Oesterreich", - regulation_type="at_law", - source_url="https://www.ris.bka.gv.at/GeltendeFassung.wxe?Abfrage=Bundesnormen&Gesetzesnummer=20008847", - description="Fernabsatzrecht: Informationspflichten, Widerrufsrecht 14 Tage, Button-Loesung, Ausnahmen.", - language="de", - requirement_count=20, - ), - Regulation( - code="AT_UGB_RET", - name="UGB Aufbewahrung AT", - full_name="UGB §§189-216, 212 — Rechnungslegung und Aufbewahrung Oesterreich", - regulation_type="at_law", - source_url="https://www.ris.bka.gv.at/GeltendeFassung.wxe?Abfrage=Bundesnormen&Gesetzesnummer=10001702", - description="Oesterreichische Rechnungslegungspflicht und Aufbewahrungsfristen (7 Jahre). Buchfuehrung, Jahresabschluss.", - local_path="AT_UGB_RET.txt", - language="de", - requirement_count=15, - ), - Regulation( - code="AT_BAO_RET", - name="BAO §132 AT", - full_name="Bundesabgabenordnung §132 — Aufbewahrung Oesterreich", - regulation_type="at_law", - source_url="https://www.ris.bka.gv.at/GeltendeFassung.wxe?Abfrage=Bundesnormen&Gesetzesnummer=10003940", - description="Steuerliche Aufbewahrungspflicht 7 Jahre fuer Buecher, Aufzeichnungen und Belege. Grundstuecke 22 Jahre.", - language="de", - requirement_count=5, - ), - Regulation( - code="AT_MEDIENG", - name="MedienG §§24-25 AT", - full_name="Mediengesetz §§24-25 Oesterreich — Impressum und Offenlegung", - regulation_type="at_law", - source_url="https://www.ris.bka.gv.at/GeltendeFassung.wxe?Abfrage=Bundesnormen&Gesetzesnummer=10000719", - description="Impressum/Offenlegungspflicht fuer periodische Medien und Websites in Oesterreich.", - language="de", - requirement_count=10, - ), - # ===================================================================== - # DACH National Laws — Schweiz (P1) - # ===================================================================== - Regulation( - code="CH_DSV", - name="DSV Schweiz", - full_name="Datenschutzverordnung (DSV) Schweiz — SR 235.11", - regulation_type="ch_law", - source_url="https://www.fedlex.admin.ch/eli/cc/2022/568/de", - description="Ausfuehrungsverordnung zum revDSG: Meldepflichten, DSFA-Verfahren, Auslandtransfers, technische Massnahmen.", - language="de", - requirement_count=30, - ), - Regulation( - code="CH_OR_AGB", - name="OR AGB/Aufbewahrung CH", - full_name="Obligationenrecht — AGB-Kontrolle und Aufbewahrung Schweiz (SR 220)", - regulation_type="ch_law", - source_url="https://www.fedlex.admin.ch/eli/cc/27/317_321_377/de", - description="Art. 8 OR (AGB-Inhaltskontrolle), Art. 19/20 (Vertragsfreiheit), Art. 957-958f (Buchfuehrung, 10 Jahre Aufbewahrung).", - local_path="CH_OR_AGB.txt", - language="de", - requirement_count=20, - ), - Regulation( - code="CH_UWG", - name="UWG Schweiz", - full_name="Bundesgesetz gegen den unlauteren Wettbewerb Schweiz (SR 241)", - regulation_type="ch_law", - source_url="https://www.fedlex.admin.ch/eli/cc/1988/223_223_223/de", - description="Lauterkeitsrecht: Impressumspflicht, irrefuehrende Werbung, aggressive Verkaufsmethoden, AGB-Transparenz.", - language="de", - requirement_count=20, - ), - Regulation( - code="CH_FMG", - name="FMG Schweiz", - full_name="Fernmeldegesetz Schweiz (SR 784.10)", - regulation_type="ch_law", - source_url="https://www.fedlex.admin.ch/eli/cc/1997/2187_2187_2187/de", - description="Telekommunikationsregulierung: Fernmeldegeheimnis, Cookies/Tracking (Art. 45c), Spam-Verbot, Datenschutz.", - language="de", - requirement_count=25, - ), - # ===================================================================== - # Deutschland P2 - # ===================================================================== - Regulation( - code="DE_PANGV", - name="PAngV", - full_name="Preisangabenverordnung (PAngV 2022)", - regulation_type="de_law", - source_url="https://www.gesetze-im-internet.de/pangv_2022/", - description="Preisangaben: Gesamtpreis, Grundpreis, Streichpreise (§11), Online-Preisauszeichnung.", - requirement_count=15, - ), - Regulation( - code="DE_DLINFOV", - name="DL-InfoV", - full_name="Dienstleistungs-Informationspflichten-Verordnung", - regulation_type="de_law", - source_url="https://www.gesetze-im-internet.de/dlinfov/", - description="Informationspflichten fuer Dienstleister: Identitaet, Kontakt, Berufshaftpflicht, AGB-Zugang.", - requirement_count=10, - ), - Regulation( - code="DE_BETRVG", - name="BetrVG §87", - full_name="Betriebsverfassungsgesetz §87 Abs.1 Nr.6", - regulation_type="de_law", - source_url="https://www.gesetze-im-internet.de/betrvg/", - description="Mitbestimmung bei technischer Ueberwachung: Betriebsrat-Beteiligung bei IT-Systemen, die Arbeitnehmerverhalten ueberwachen koennen.", - requirement_count=5, - ), - # ===================================================================== - # Oesterreich P2 - # ===================================================================== - Regulation( - code="AT_ABGB_AGB", - name="ABGB AGB-Recht AT", - full_name="ABGB §§861-879, 864a — AGB-Kontrolle Oesterreich", - regulation_type="at_law", - source_url="https://www.ris.bka.gv.at/GeltendeFassung.wxe?Abfrage=Bundesnormen&Gesetzesnummer=10001622", - description="Geltungskontrolle (§864a), Sittenwidrigkeitskontrolle (§879 Abs.3), allgemeine Vertragsregeln.", - local_path="AT_ABGB_AGB.txt", - language="de", - requirement_count=10, - ), - Regulation( - code="AT_UWG", - name="UWG Oesterreich", - full_name="Bundesgesetz gegen den unlauteren Wettbewerb Oesterreich", - regulation_type="at_law", - source_url="https://www.ris.bka.gv.at/GeltendeFassung.wxe?Abfrage=Bundesnormen&Gesetzesnummer=10002665", - description="Lauterkeitsrecht AT: irrefuehrende Geschaeftspraktiken, aggressive Praktiken, Preisauszeichnung.", - language="de", - requirement_count=15, - ), - # ===================================================================== - # Schweiz P2 - # ===================================================================== - Regulation( - code="CH_GEBUV", - name="GeBuV Schweiz", - full_name="Geschaeftsbuecher-Verordnung Schweiz (SR 221.431)", - regulation_type="ch_law", - source_url="https://www.fedlex.admin.ch/eli/cc/2002/468_468_468/de", - description="Ausfuehrungsvorschriften zur Buchfuehrung: elektronische Aufbewahrung, Integritaet, Datentraeger.", - language="de", - requirement_count=10, - ), - Regulation( - code="CH_ZERTES", - name="ZertES Schweiz", - full_name="Bundesgesetz ueber die elektronische Signatur (SR 943.03)", - regulation_type="ch_law", - source_url="https://www.fedlex.admin.ch/eli/cc/2016/752/de", - description="Elektronische Signatur und Zertifizierung: Qualifizierte Signaturen, Zertifizierungsdiensteanbieter.", - language="de", - requirement_count=10, - ), - # ===================================================================== - # Deutschland P3 - # ===================================================================== - Regulation( - code="DE_GESCHGEHG", - name="GeschGehG", - full_name="Gesetz zum Schutz von Geschaeftsgeheimnissen", - regulation_type="de_law", - source_url="https://www.gesetze-im-internet.de/geschgehg/", - description="Schutz von Geschaeftsgeheimnissen: Definition, angemessene Geheimhaltungsmassnahmen, Reverse Engineering.", - requirement_count=10, - ), - Regulation( - code="DE_BSIG", - name="BSI-Gesetz", - full_name="Gesetz ueber das Bundesamt fuer Sicherheit in der Informationstechnik (BSIG)", - regulation_type="de_law", - source_url="https://www.gesetze-im-internet.de/bsig_2009/", - description="BSI-Aufgaben, KRITIS-Meldepflichten, IT-Sicherheitsstandards, Zertifizierung.", - requirement_count=20, - ), - Regulation( - code="DE_USTG_RET", - name="UStG §14b", - full_name="Umsatzsteuergesetz §14b — Aufbewahrung von Rechnungen", - regulation_type="de_law", - source_url="https://www.gesetze-im-internet.de/ustg_1980/", - description="Aufbewahrungspflicht fuer Rechnungen: 10 Jahre, Grundstuecke 20 Jahre, elektronische Aufbewahrung.", - local_path="DE_USTG_RET.txt", - requirement_count=5, - ), - # ===================================================================== - # Schweiz P3 - # ===================================================================== - Regulation( - code="CH_ZGB_PERS", - name="ZGB Persoenlichkeitsschutz CH", - full_name="Zivilgesetzbuch Art. 28-28l — Persoenlichkeitsschutz Schweiz (SR 210)", - regulation_type="ch_law", - source_url="https://www.fedlex.admin.ch/eli/cc/24/233_245_233/de", - description="Persoenlichkeitsschutz: Recht am eigenen Bild, Schutz der Privatsphaere, Gegendarstellungsrecht.", - language="de", - requirement_count=8, - ), - # ===================================================================== - # 3 fehlgeschlagene Quellen mit alternativen URLs nachholen - # ===================================================================== - Regulation( - code="LU_DPA_LAW", - name="Datenschutzgesetz Luxemburg", - full_name="Loi du 1er aout 2018 — Datenschutzgesetz Luxemburg", - regulation_type="national_law", - source_url="https://legilux.public.lu/eli/etat/leg/loi/2018/08/01/a686/jo", - description="Luxemburgisches Datenschutzgesetz: Organisation der CNPD, nationale DSGVO-Ergaenzung.", - language="fr", - requirement_count=40, - ), - Regulation( - code="DK_DATABESKYTTELSESLOVEN", - name="Databeskyttelsesloven DK", - full_name="Databeskyttelsesloven — Datenschutzgesetz Daenemark", - regulation_type="national_law", - source_url="https://www.retsinformation.dk/eli/lta/2018/502", - description="Daenisches Datenschutzgesetz als ergaenzende Bestimmungen zur DSGVO. Reguliert durch Datatilsynet.", - language="da", - requirement_count=30, - ), - Regulation( - code="EDPB_GUIDELINES_1_2022", - name="EDPB GL Bussgelder", - full_name="EDPB Leitlinien 04/2022 zur Berechnung von Bussgeldern nach der DSGVO", - regulation_type="eu_guideline", - source_url="https://www.edpb.europa.eu/system/files/2023-05/edpb_guidelines_042022_calculationofadministrativefines_en.pdf", - description="EDPB-Leitlinien zur Berechnung von Verwaltungsbussgeldern unter der DSGVO.", - language="en", - requirement_count=15, - ), -] - - class LegalCorpusIngestion: """Handles ingestion of legal documents into Qdrant.""" @@ -710,156 +119,24 @@ class LegalCorpusIngestion: logger.error(f"Embedding generation failed: {e}") raise - # German abbreviations that don't end sentences - GERMAN_ABBREVIATIONS = { - 'bzw', 'ca', 'chr', 'd.h', 'dr', 'etc', 'evtl', 'ggf', 'inkl', 'max', - 'min', 'mio', 'mrd', 'nr', 'prof', 's', 'sog', 'u.a', 'u.ä', 'usw', - 'v.a', 'vgl', 'vs', 'z.b', 'z.t', 'zzgl', 'abs', 'art', 'aufl', - 'bd', 'betr', 'bzgl', 'dgl', 'ebd', 'hrsg', 'jg', 'kap', 'lt', - 'rdnr', 'rn', 'std', 'str', 'tel', 'ua', 'uvm', 'va', 'zb', - 'bsi', 'tr', 'owasp', 'iso', 'iec', 'din', 'en' - } + # Delegate chunking/text methods to legal_corpus_chunking module + # Keep as instance methods for backward compatibility + GERMAN_ABBREVIATIONS = GERMAN_ABBREVIATIONS def _split_into_sentences(self, text: str) -> List[str]: - """Split text into sentences with German language support.""" - if not text: - return [] - - text = re.sub(r'\s+', ' ', text).strip() - - # Protect abbreviations - protected_text = text - for abbrev in self.GERMAN_ABBREVIATIONS: - pattern = re.compile(r'\b' + re.escape(abbrev) + r'\.', re.IGNORECASE) - protected_text = pattern.sub(abbrev.replace('.', '') + '', protected_text) - - # Protect decimal/ordinal numbers and requirement IDs (e.g., "O.Data_1") - protected_text = re.sub(r'(\d)\.(\d)', r'\1\2', protected_text) - protected_text = re.sub(r'(\d+)\.(\s)', r'\1\2', protected_text) - protected_text = re.sub(r'([A-Z])\.([A-Z])', r'\1\2', protected_text) # O.Data_1 - - # Split on sentence endings - sentence_pattern = r'(?<=[.!?])\s+(?=[A-ZÄÖÜ0-9])|(?<=[.!?])$' - raw_sentences = re.split(sentence_pattern, protected_text) - - # Restore protected characters - sentences = [] - for s in raw_sentences: - s = s.replace('', '.').replace('', '.').replace('', '.').replace('', '.').replace('', '.') - s = s.strip() - if s: - sentences.append(s) - - return sentences + return split_into_sentences(text) def _split_into_paragraphs(self, text: str) -> List[str]: - """Split text into paragraphs.""" - if not text: - return [] + return split_into_paragraphs(text) - raw_paragraphs = re.split(r'\n\s*\n', text) - return [para.strip() for para in raw_paragraphs if para.strip()] + def _chunk_text_semantic(self, text: str, chunk_size: int = CHUNK_SIZE, overlap: int = CHUNK_OVERLAP): + return chunk_text_semantic(text, chunk_size, overlap) - def _chunk_text_semantic(self, text: str, chunk_size: int = CHUNK_SIZE, overlap: int = CHUNK_OVERLAP) -> List[Tuple[str, int]]: - """ - Semantic chunking that respects paragraph and sentence boundaries. - Matches NIBIS chunking strategy for consistency. + def _extract_article_info(self, text: str): + return extract_article_info(text) - Returns list of (chunk_text, start_position) tuples. - """ - if not text: - return [] - - if len(text) <= chunk_size: - return [(text.strip(), 0)] - - paragraphs = self._split_into_paragraphs(text) - overlap_sentences = max(1, overlap // 100) # Convert char overlap to sentence overlap - - chunks = [] - current_chunk_parts = [] - current_chunk_length = 0 - chunk_start = 0 - position = 0 - - for para in paragraphs: - if len(para) > chunk_size: - # Large paragraph: split into sentences - sentences = self._split_into_sentences(para) - - for sentence in sentences: - sentence_len = len(sentence) - - if sentence_len > chunk_size: - # Very long sentence: save current chunk first - if current_chunk_parts: - chunk_text = ' '.join(current_chunk_parts) - chunks.append((chunk_text, chunk_start)) - overlap_buffer = current_chunk_parts[-overlap_sentences:] if overlap_sentences > 0 else [] - current_chunk_parts = list(overlap_buffer) - current_chunk_length = sum(len(s) + 1 for s in current_chunk_parts) - - # Add long sentence as its own chunk - chunks.append((sentence, position)) - current_chunk_parts = [sentence] - current_chunk_length = len(sentence) + 1 - position += sentence_len + 1 - continue - - if current_chunk_length + sentence_len + 1 > chunk_size and current_chunk_parts: - # Current chunk is full, save it - chunk_text = ' '.join(current_chunk_parts) - chunks.append((chunk_text, chunk_start)) - overlap_buffer = current_chunk_parts[-overlap_sentences:] if overlap_sentences > 0 else [] - current_chunk_parts = list(overlap_buffer) - current_chunk_length = sum(len(s) + 1 for s in current_chunk_parts) - chunk_start = position - current_chunk_length - - current_chunk_parts.append(sentence) - current_chunk_length += sentence_len + 1 - position += sentence_len + 1 - else: - # Small paragraph: try to keep together - para_len = len(para) - if current_chunk_length + para_len + 2 > chunk_size and current_chunk_parts: - chunk_text = ' '.join(current_chunk_parts) - chunks.append((chunk_text, chunk_start)) - last_para_sentences = self._split_into_sentences(current_chunk_parts[-1] if current_chunk_parts else "") - overlap_buffer = last_para_sentences[-overlap_sentences:] if overlap_sentences > 0 and last_para_sentences else [] - current_chunk_parts = list(overlap_buffer) - current_chunk_length = sum(len(s) + 1 for s in current_chunk_parts) - chunk_start = position - current_chunk_length - - if current_chunk_parts: - current_chunk_parts.append(para) - current_chunk_length += para_len + 2 - else: - current_chunk_parts = [para] - current_chunk_length = para_len - chunk_start = position - - position += para_len + 2 - - # Don't forget the last chunk - if current_chunk_parts: - chunk_text = ' '.join(current_chunk_parts) - chunks.append((chunk_text, chunk_start)) - - # Clean up whitespace - return [(re.sub(r'\s+', ' ', c).strip(), pos) for c, pos in chunks if c.strip()] - - def _extract_article_info(self, text: str) -> Optional[Dict]: - """Extract article number and paragraph from text.""" - # Pattern for "Artikel X" or "Art. X" - article_match = re.search(r'(?:Artikel|Art\.?)\s+(\d+)', text) - paragraph_match = re.search(r'(?:Absatz|Abs\.?)\s+(\d+)', text) - - if article_match: - return { - "article": article_match.group(1), - "paragraph": paragraph_match.group(1) if paragraph_match else None, - } - return None + def _html_to_text(self, html_content: str) -> str: + return html_to_text(html_content) async def _fetch_document_text(self, regulation: Regulation) -> Optional[str]: """ @@ -880,7 +157,6 @@ class LegalCorpusIngestion: if local_pdf.exists(): logger.info(f"Extracting text from PDF: {local_pdf}") try: - # Use embedding service for PDF extraction response = await self.http_client.post( f"{EMBEDDING_SERVICE_URL}/extract-pdf", files={"file": open(local_pdf, "rb")}, @@ -892,7 +168,7 @@ class LegalCorpusIngestion: except Exception as e: logger.error(f"PDF extraction failed for {regulation.code}: {e}") - # Try EUR-Lex CELEX URL if available (bypasses JavaScript CAPTCHA) + # Try EUR-Lex CELEX URL if available if regulation.celex: celex_url = f"https://eur-lex.europa.eu/legal-content/DE/TXT/HTML/?uri=CELEX:{regulation.celex}" logger.info(f"Fetching {regulation.code} from EUR-Lex CELEX: {celex_url}") @@ -911,7 +187,6 @@ class LegalCorpusIngestion: html_content = response.text - # Check if we got actual content, not a CAPTCHA page if "verify that you're not a robot" not in html_content and len(html_content) > 10000: text = self._html_to_text(html_content) if text and len(text) > 1000: @@ -927,7 +202,6 @@ class LegalCorpusIngestion: # Fallback to original source URL logger.info(f"Fetching {regulation.code} from: {regulation.source_url}") try: - # Check if source URL is a PDF (handle URLs with query parameters) parsed_url = urlparse(regulation.source_url) is_pdf_url = parsed_url.path.lower().endswith('.pdf') if is_pdf_url: @@ -943,7 +217,6 @@ class LegalCorpusIngestion: ) response.raise_for_status() - # Extract text from PDF via embedding service pdf_content = response.content extract_response = await self.http_client.post( f"{EMBEDDING_SERVICE_URL}/extract-pdf", @@ -960,7 +233,6 @@ class LegalCorpusIngestion: logger.warning(f"PDF extraction returned empty text for {regulation.code}") return None else: - # Regular HTML fetch response = await self.http_client.get( regulation.source_url, follow_redirects=True, @@ -979,55 +251,21 @@ class LegalCorpusIngestion: logger.error(f"Failed to fetch {regulation.code}: {e}") return None - def _html_to_text(self, html_content: str) -> str: - """Convert HTML to clean text.""" - # Remove script and style tags - html_content = re.sub(r']*>.*?', '', html_content, flags=re.DOTALL) - html_content = re.sub(r']*>.*?', '', html_content, flags=re.DOTALL) - # Remove comments - html_content = re.sub(r'', '', html_content, flags=re.DOTALL) - # Replace common HTML entities - html_content = html_content.replace(' ', ' ') - html_content = html_content.replace('&', '&') - html_content = html_content.replace('<', '<') - html_content = html_content.replace('>', '>') - html_content = html_content.replace('"', '"') - # Convert breaks and paragraphs to newlines for better chunking - html_content = re.sub(r'', '\n', html_content, flags=re.IGNORECASE) - html_content = re.sub(r'

', '\n\n', html_content, flags=re.IGNORECASE) - html_content = re.sub(r'', '\n', html_content, flags=re.IGNORECASE) - html_content = re.sub(r'', '\n\n', html_content, flags=re.IGNORECASE) - # Remove remaining HTML tags - text = re.sub(r'<[^>]+>', ' ', html_content) - # Clean up whitespace (but preserve paragraph breaks) - text = re.sub(r'[ \t]+', ' ', text) - text = re.sub(r'\n[ \t]+', '\n', text) - text = re.sub(r'[ \t]+\n', '\n', text) - text = re.sub(r'\n{3,}', '\n\n', text) - return text.strip() - async def ingest_regulation(self, regulation: Regulation) -> int: - """ - Ingest a single regulation into Qdrant. - - Returns number of chunks indexed. - """ + """Ingest a single regulation into Qdrant. Returns number of chunks indexed.""" logger.info(f"Ingesting {regulation.code}: {regulation.name}") - # Fetch document text text = await self._fetch_document_text(regulation) if not text or len(text) < 100: logger.warning(f"No text found for {regulation.code}, skipping") return 0 - # Chunk the text chunks = self._chunk_text_semantic(text) logger.info(f"Created {len(chunks)} chunks for {regulation.code}") if not chunks: return 0 - # Generate embeddings in batches (very small for CPU stability) batch_size = 4 all_points = [] max_retries = 3 @@ -1036,7 +274,6 @@ class LegalCorpusIngestion: batch_chunks = chunks[i:i + batch_size] chunk_texts = [c[0] for c in batch_chunks] - # Retry logic for embedding service stability embeddings = None for retry in range(max_retries): try: @@ -1045,21 +282,19 @@ class LegalCorpusIngestion: except Exception as e: logger.warning(f"Embedding attempt {retry+1}/{max_retries} failed for batch {i//batch_size}: {e}") if retry < max_retries - 1: - await asyncio.sleep(3 * (retry + 1)) # Longer backoff: 3s, 6s, 9s + await asyncio.sleep(3 * (retry + 1)) else: logger.error(f"Embedding failed permanently for batch {i//batch_size}") if embeddings is None: continue - # Longer delay between batches for CPU stability await asyncio.sleep(1.5) for j, ((chunk_text, position), embedding) in enumerate(zip(batch_chunks, embeddings)): chunk_idx = i + j point_id = hashlib.md5(f"{regulation.code}-{chunk_idx}".encode()).hexdigest() - # Extract article info if present article_info = self._extract_article_info(chunk_text) point = PointStruct( @@ -1078,12 +313,11 @@ class LegalCorpusIngestion: "paragraph": article_info.get("paragraph") if article_info else None, "language": regulation.language, "indexed_at": datetime.utcnow().isoformat(), - "training_allowed": False, # Legal texts - no training + "training_allowed": False, }, ) all_points.append(point) - # Upsert to Qdrant if all_points: self.qdrant.upsert( collection_name=LEGAL_CORPUS_COLLECTION, @@ -1135,7 +369,6 @@ class LegalCorpusIngestion: try: collection_info = self.qdrant.get_collection(LEGAL_CORPUS_COLLECTION) - # Count points per regulation regulation_counts = {} for reg in REGULATIONS: result = self.qdrant.count( @@ -1171,22 +404,10 @@ class LegalCorpusIngestion: regulation_codes: Optional[List[str]] = None, top_k: int = 5, ) -> List[Dict]: - """ - Search the legal corpus for relevant passages. - - Args: - query: Search query text - regulation_codes: Optional list of regulation codes to filter - top_k: Number of results to return - - Returns: - List of search results with text and metadata - """ - # Generate query embedding + """Search the legal corpus for relevant passages.""" embeddings = await self._generate_embeddings([query]) query_vector = embeddings[0] - # Build filter search_filter = None if regulation_codes: search_filter = Filter( @@ -1199,7 +420,6 @@ class LegalCorpusIngestion: ] ) - # Search results = self.qdrant.search( collection_name=LEGAL_CORPUS_COLLECTION, query_vector=query_vector, diff --git a/klausur-service/backend/legal_corpus_registry.py b/klausur-service/backend/legal_corpus_registry.py new file mode 100644 index 0000000..b48e54c --- /dev/null +++ b/klausur-service/backend/legal_corpus_registry.py @@ -0,0 +1,608 @@ +""" +Legal Corpus Registry — Regulation metadata and definitions. + +Pure data module: contains the Regulation dataclass and the REGULATIONS list +with all EU regulations, DACH national laws, and EDPB guidelines. +""" + +from dataclasses import dataclass +from typing import List, Optional + + +@dataclass +class Regulation: + """Regulation metadata.""" + code: str + name: str + full_name: str + regulation_type: str + source_url: str + description: str + celex: Optional[str] = None # CELEX number for EUR-Lex direct access + local_path: Optional[str] = None + language: str = "de" + requirement_count: int = 0 + + +# All regulations from Compliance Hub (EU + DACH national laws + guidelines) +REGULATIONS: List[Regulation] = [ + Regulation( + code="GDPR", + name="DSGVO", + full_name="Verordnung (EU) 2016/679 - Datenschutz-Grundverordnung", + regulation_type="eu_regulation", + source_url="https://eur-lex.europa.eu/eli/reg/2016/679/oj/deu", + description="Grundverordnung zum Schutz natuerlicher Personen bei der Verarbeitung personenbezogener Daten.", + celex="32016R0679", + requirement_count=99, + ), + Regulation( + code="EPRIVACY", + name="ePrivacy-Richtlinie", + full_name="Richtlinie 2002/58/EG", + regulation_type="eu_directive", + source_url="https://eur-lex.europa.eu/eli/dir/2002/58/oj/deu", + description="Datenschutz in der elektronischen Kommunikation, Cookies und Tracking.", + celex="32002L0058", + requirement_count=25, + ), + Regulation( + code="TDDDG", + name="TDDDG", + full_name="Telekommunikation-Digitale-Dienste-Datenschutz-Gesetz", + regulation_type="de_law", + source_url="https://www.gesetze-im-internet.de/ttdsg/TDDDG.pdf", + description="Deutsche Umsetzung der ePrivacy-Richtlinie (30 Paragraphen).", + requirement_count=30, + ), + Regulation( + code="SCC", + name="Standardvertragsklauseln", + full_name="Durchfuehrungsbeschluss (EU) 2021/914", + regulation_type="eu_regulation", + source_url="https://eur-lex.europa.eu/eli/dec_impl/2021/914/oj/deu", + description="Standardvertragsklauseln fuer Drittlandtransfers.", + celex="32021D0914", + requirement_count=18, + ), + Regulation( + code="DPF", + name="EU-US Data Privacy Framework", + full_name="Durchfuehrungsbeschluss (EU) 2023/1795", + regulation_type="eu_regulation", + source_url="https://eur-lex.europa.eu/eli/dec_impl/2023/1795/oj", + description="Angemessenheitsbeschluss fuer USA-Transfers.", + celex="32023D1795", + requirement_count=12, + ), + Regulation( + code="AIACT", + name="EU AI Act", + full_name="Verordnung (EU) 2024/1689 - KI-Verordnung", + regulation_type="eu_regulation", + source_url="https://eur-lex.europa.eu/eli/reg/2024/1689/oj/deu", + description="EU-Verordnung zur Regulierung von KI-Systemen nach Risikostufen.", + celex="32024R1689", + requirement_count=85, + ), + Regulation( + code="CRA", + name="Cyber Resilience Act", + full_name="Verordnung (EU) 2024/2847", + regulation_type="eu_regulation", + source_url="https://eur-lex.europa.eu/eli/reg/2024/2847/oj/deu", + description="Cybersicherheitsanforderungen, SBOM-Pflicht.", + celex="32024R2847", + requirement_count=45, + ), + Regulation( + code="NIS2", + name="NIS2-Richtlinie", + full_name="Richtlinie (EU) 2022/2555", + regulation_type="eu_directive", + source_url="https://eur-lex.europa.eu/eli/dir/2022/2555/oj/deu", + description="Cybersicherheit fuer wesentliche Einrichtungen.", + celex="32022L2555", + requirement_count=46, + ), + Regulation( + code="EUCSA", + name="EU Cybersecurity Act", + full_name="Verordnung (EU) 2019/881", + regulation_type="eu_regulation", + source_url="https://eur-lex.europa.eu/eli/reg/2019/881/oj/deu", + description="ENISA und Cybersicherheitszertifizierung.", + celex="32019R0881", + requirement_count=35, + ), + Regulation( + code="DATAACT", + name="Data Act", + full_name="Verordnung (EU) 2023/2854", + regulation_type="eu_regulation", + source_url="https://eur-lex.europa.eu/eli/reg/2023/2854/oj/deu", + description="Fairer Datenzugang, IoT-Daten, Cloud-Wechsel.", + celex="32023R2854", + requirement_count=42, + ), + Regulation( + code="DGA", + name="Data Governance Act", + full_name="Verordnung (EU) 2022/868", + regulation_type="eu_regulation", + source_url="https://eur-lex.europa.eu/eli/reg/2022/868/oj/deu", + description="Weiterverwendung oeffentlicher Daten.", + celex="32022R0868", + requirement_count=35, + ), + Regulation( + code="DSA", + name="Digital Services Act", + full_name="Verordnung (EU) 2022/2065", + regulation_type="eu_regulation", + source_url="https://eur-lex.europa.eu/eli/reg/2022/2065/oj/deu", + description="Digitale Dienste, Transparenzpflichten.", + celex="32022R2065", + requirement_count=93, + ), + Regulation( + code="EAA", + name="European Accessibility Act", + full_name="Richtlinie (EU) 2019/882", + regulation_type="eu_directive", + source_url="https://eur-lex.europa.eu/eli/dir/2019/882/oj/deu", + description="Barrierefreiheit digitaler Produkte.", + celex="32019L0882", + requirement_count=25, + ), + Regulation( + code="DSM", + name="DSM-Urheberrechtsrichtlinie", + full_name="Richtlinie (EU) 2019/790", + regulation_type="eu_directive", + source_url="https://eur-lex.europa.eu/eli/dir/2019/790/oj/deu", + description="Urheberrecht, Text- und Data-Mining.", + celex="32019L0790", + requirement_count=22, + ), + Regulation( + code="PLD", + name="Produkthaftungsrichtlinie", + full_name="Richtlinie (EU) 2024/2853", + regulation_type="eu_directive", + source_url="https://eur-lex.europa.eu/eli/dir/2024/2853/oj/deu", + description="Produkthaftung inkl. Software und KI.", + celex="32024L2853", + requirement_count=18, + ), + Regulation( + code="GPSR", + name="General Product Safety", + full_name="Verordnung (EU) 2023/988", + regulation_type="eu_regulation", + source_url="https://eur-lex.europa.eu/eli/reg/2023/988/oj/deu", + description="Allgemeine Produktsicherheit.", + celex="32023R0988", + requirement_count=30, + ), + Regulation( + code="BSI-TR-03161-1", + name="BSI-TR-03161 Teil 1", + full_name="BSI Technische Richtlinie - Allgemeine Anforderungen", + regulation_type="bsi_standard", + source_url="https://www.bsi.bund.de/SharedDocs/Downloads/DE/BSI/Publikationen/TechnischeRichtlinien/TR03161/BSI-TR-03161-1.pdf?__blob=publicationFile&v=6", + description="Allgemeine Sicherheitsanforderungen (45 Pruefaspekte).", + requirement_count=45, + ), + Regulation( + code="BSI-TR-03161-2", + name="BSI-TR-03161 Teil 2", + full_name="BSI Technische Richtlinie - Web-Anwendungen", + regulation_type="bsi_standard", + source_url="https://www.bsi.bund.de/SharedDocs/Downloads/DE/BSI/Publikationen/TechnischeRichtlinien/TR03161/BSI-TR-03161-2.pdf?__blob=publicationFile&v=5", + description="Web-Sicherheit (40 Pruefaspekte).", + requirement_count=40, + ), + Regulation( + code="BSI-TR-03161-3", + name="BSI-TR-03161 Teil 3", + full_name="BSI Technische Richtlinie - Hintergrundsysteme", + regulation_type="bsi_standard", + source_url="https://www.bsi.bund.de/SharedDocs/Downloads/DE/BSI/Publikationen/TechnischeRichtlinien/TR03161/BSI-TR-03161-3.pdf?__blob=publicationFile&v=5", + description="Backend-Sicherheit (35 Pruefaspekte).", + requirement_count=35, + ), + # Additional regulations for financial sector and health + Regulation( + code="DORA", + name="DORA", + full_name="Verordnung (EU) 2022/2554 - Digital Operational Resilience Act", + regulation_type="eu_regulation", + source_url="https://eur-lex.europa.eu/eli/reg/2022/2554/oj/deu", + description="Digitale operationale Resilienz fuer den Finanzsektor. IKT-Risikomanagement, Vorfallmeldung, Resilienz-Tests.", + celex="32022R2554", + requirement_count=64, + ), + Regulation( + code="PSD2", + name="PSD2", + full_name="Richtlinie (EU) 2015/2366 - Zahlungsdiensterichtlinie", + regulation_type="eu_directive", + source_url="https://eur-lex.europa.eu/eli/dir/2015/2366/oj/deu", + description="Zahlungsdienste im Binnenmarkt. Starke Kundenauthentifizierung, Open Banking APIs.", + celex="32015L2366", + requirement_count=117, + ), + Regulation( + code="AMLR", + name="AML-Verordnung", + full_name="Verordnung (EU) 2024/1624 - Geldwaeschebekaempfung", + regulation_type="eu_regulation", + source_url="https://eur-lex.europa.eu/eli/reg/2024/1624/oj/deu", + description="Verhinderung der Nutzung des Finanzsystems zur Geldwaesche und Terrorismusfinanzierung.", + celex="32024R1624", + requirement_count=89, + ), + Regulation( + code="EHDS", + name="EHDS", + full_name="Verordnung (EU) 2025/327 - Europaeischer Gesundheitsdatenraum", + regulation_type="eu_regulation", + source_url="https://eur-lex.europa.eu/eli/reg/2025/327/oj/deu", + description="Europaeischer Raum fuer Gesundheitsdaten. Primaer- und Sekundaernutzung von Gesundheitsdaten.", + celex="32025R0327", + requirement_count=95, + ), + Regulation( + code="MiCA", + name="MiCA", + full_name="Verordnung (EU) 2023/1114 - Markets in Crypto-Assets", + regulation_type="eu_regulation", + source_url="https://eur-lex.europa.eu/eli/reg/2023/1114/oj/deu", + description="Regulierung von Kryptowerten, Stablecoins und Crypto-Asset-Dienstleistern.", + celex="32023R1114", + requirement_count=149, + ), + # ===================================================================== + # DACH National Laws — Deutschland (P1) + # ===================================================================== + Regulation( + code="DE_DDG", + name="Digitale-Dienste-Gesetz", + full_name="Digitale-Dienste-Gesetz (DDG)", + regulation_type="de_law", + source_url="https://www.gesetze-im-internet.de/ddg/", + description="Deutsches Umsetzungsgesetz zum DSA. Regelt Impressumspflicht (§5), Informationspflichten fuer digitale Dienste und Cookies.", + requirement_count=30, + ), + Regulation( + code="DE_BGB_AGB", + name="BGB AGB-Recht", + full_name="BGB §§305-310, 312-312k — AGB und Fernabsatz", + regulation_type="de_law", + source_url="https://www.gesetze-im-internet.de/bgb/", + description="Deutsches AGB-Recht (§§305-310 BGB) und Fernabsatzrecht (§§312-312k BGB). Klauselverbote, Inhaltskontrolle, Widerrufsrecht, Button-Loesung.", + local_path="DE_BGB_AGB.txt", + requirement_count=40, + ), + Regulation( + code="DE_EGBGB", + name="EGBGB Art. 246-248", + full_name="Einfuehrungsgesetz zum BGB — Informationspflichten", + regulation_type="de_law", + source_url="https://www.gesetze-im-internet.de/bgbeg/", + description="Informationspflichten bei Verbrauchervertraegen (Art. 246), Fernabsatz (Art. 246a), E-Commerce (Art. 246c).", + local_path="DE_EGBGB.txt", + requirement_count=20, + ), + Regulation( + code="DE_UWG", + name="UWG Deutschland", + full_name="Gesetz gegen den unlauteren Wettbewerb (UWG)", + regulation_type="de_law", + source_url="https://www.gesetze-im-internet.de/uwg_2004/", + description="Unlauterer Wettbewerb: irrefuehrende Werbung, Spam-Verbot, Preisangaben, Online-Marketing-Regeln.", + requirement_count=25, + ), + Regulation( + code="DE_HGB_RET", + name="HGB Aufbewahrung", + full_name="HGB §§238-261, 257 — Handelsbuecher und Aufbewahrungsfristen", + regulation_type="de_law", + source_url="https://www.gesetze-im-internet.de/hgb/", + description="Buchfuehrungspflicht, Aufbewahrungsfristen 6/10 Jahre, Anforderungen an elektronische Aufbewahrung.", + local_path="DE_HGB_RET.txt", + requirement_count=15, + ), + Regulation( + code="DE_AO_RET", + name="AO Aufbewahrung", + full_name="Abgabenordnung §§140-148 — Steuerliche Aufbewahrungspflichten", + regulation_type="de_law", + source_url="https://www.gesetze-im-internet.de/ao_1977/", + description="Steuerliche Buchfuehrungs- und Aufbewahrungspflichten. 6/10 Jahre Fristen, Datenzugriff durch Finanzbehoerden.", + local_path="DE_AO_RET.txt", + requirement_count=12, + ), + Regulation( + code="DE_TKG", + name="TKG 2021", + full_name="Telekommunikationsgesetz 2021", + regulation_type="de_law", + source_url="https://www.gesetze-im-internet.de/tkg_2021/", + description="Telekommunikationsregulierung: Kundenschutz, Datenschutz, Vertragslaufzeiten, Netzinfrastruktur.", + requirement_count=45, + ), + # ===================================================================== + # DACH National Laws — Oesterreich (P1) + # ===================================================================== + Regulation( + code="AT_ECG", + name="E-Commerce-Gesetz AT", + full_name="E-Commerce-Gesetz (ECG) Oesterreich", + regulation_type="at_law", + source_url="https://www.ris.bka.gv.at/GeltendeFassung.wxe?Abfrage=Bundesnormen&Gesetzesnummer=20001703", + description="Oesterreichisches E-Commerce-Gesetz: Impressum/Offenlegungspflicht (§5), Informationspflichten, Haftung von Diensteanbietern.", + language="de", + requirement_count=30, + ), + Regulation( + code="AT_TKG", + name="TKG 2021 AT", + full_name="Telekommunikationsgesetz 2021 Oesterreich", + regulation_type="at_law", + source_url="https://www.ris.bka.gv.at/GeltendeFassung.wxe?Abfrage=Bundesnormen&Gesetzesnummer=20011678", + description="Oesterreichisches TKG: Cookie-Bestimmungen (§165), Kommunikationsgeheimnis, Endgeraetezugriff.", + language="de", + requirement_count=40, + ), + Regulation( + code="AT_KSCHG", + name="KSchG Oesterreich", + full_name="Konsumentenschutzgesetz (KSchG) Oesterreich", + regulation_type="at_law", + source_url="https://www.ris.bka.gv.at/GeltendeFassung.wxe?Abfrage=Bundesnormen&Gesetzesnummer=10002462", + description="Konsumentenschutz: AGB-Kontrolle (§6 Klauselverbote, §9 Verbandsklage), Ruecktrittsrecht, Informationspflichten.", + language="de", + requirement_count=35, + ), + Regulation( + code="AT_FAGG", + name="FAGG Oesterreich", + full_name="Fern- und Auswaertsgeschaefte-Gesetz (FAGG) Oesterreich", + regulation_type="at_law", + source_url="https://www.ris.bka.gv.at/GeltendeFassung.wxe?Abfrage=Bundesnormen&Gesetzesnummer=20008847", + description="Fernabsatzrecht: Informationspflichten, Widerrufsrecht 14 Tage, Button-Loesung, Ausnahmen.", + language="de", + requirement_count=20, + ), + Regulation( + code="AT_UGB_RET", + name="UGB Aufbewahrung AT", + full_name="UGB §§189-216, 212 — Rechnungslegung und Aufbewahrung Oesterreich", + regulation_type="at_law", + source_url="https://www.ris.bka.gv.at/GeltendeFassung.wxe?Abfrage=Bundesnormen&Gesetzesnummer=10001702", + description="Oesterreichische Rechnungslegungspflicht und Aufbewahrungsfristen (7 Jahre). Buchfuehrung, Jahresabschluss.", + local_path="AT_UGB_RET.txt", + language="de", + requirement_count=15, + ), + Regulation( + code="AT_BAO_RET", + name="BAO §132 AT", + full_name="Bundesabgabenordnung §132 — Aufbewahrung Oesterreich", + regulation_type="at_law", + source_url="https://www.ris.bka.gv.at/GeltendeFassung.wxe?Abfrage=Bundesnormen&Gesetzesnummer=10003940", + description="Steuerliche Aufbewahrungspflicht 7 Jahre fuer Buecher, Aufzeichnungen und Belege. Grundstuecke 22 Jahre.", + language="de", + requirement_count=5, + ), + Regulation( + code="AT_MEDIENG", + name="MedienG §§24-25 AT", + full_name="Mediengesetz §§24-25 Oesterreich — Impressum und Offenlegung", + regulation_type="at_law", + source_url="https://www.ris.bka.gv.at/GeltendeFassung.wxe?Abfrage=Bundesnormen&Gesetzesnummer=10000719", + description="Impressum/Offenlegungspflicht fuer periodische Medien und Websites in Oesterreich.", + language="de", + requirement_count=10, + ), + # ===================================================================== + # DACH National Laws — Schweiz (P1) + # ===================================================================== + Regulation( + code="CH_DSV", + name="DSV Schweiz", + full_name="Datenschutzverordnung (DSV) Schweiz — SR 235.11", + regulation_type="ch_law", + source_url="https://www.fedlex.admin.ch/eli/cc/2022/568/de", + description="Ausfuehrungsverordnung zum revDSG: Meldepflichten, DSFA-Verfahren, Auslandtransfers, technische Massnahmen.", + language="de", + requirement_count=30, + ), + Regulation( + code="CH_OR_AGB", + name="OR AGB/Aufbewahrung CH", + full_name="Obligationenrecht — AGB-Kontrolle und Aufbewahrung Schweiz (SR 220)", + regulation_type="ch_law", + source_url="https://www.fedlex.admin.ch/eli/cc/27/317_321_377/de", + description="Art. 8 OR (AGB-Inhaltskontrolle), Art. 19/20 (Vertragsfreiheit), Art. 957-958f (Buchfuehrung, 10 Jahre Aufbewahrung).", + local_path="CH_OR_AGB.txt", + language="de", + requirement_count=20, + ), + Regulation( + code="CH_UWG", + name="UWG Schweiz", + full_name="Bundesgesetz gegen den unlauteren Wettbewerb Schweiz (SR 241)", + regulation_type="ch_law", + source_url="https://www.fedlex.admin.ch/eli/cc/1988/223_223_223/de", + description="Lauterkeitsrecht: Impressumspflicht, irrefuehrende Werbung, aggressive Verkaufsmethoden, AGB-Transparenz.", + language="de", + requirement_count=20, + ), + Regulation( + code="CH_FMG", + name="FMG Schweiz", + full_name="Fernmeldegesetz Schweiz (SR 784.10)", + regulation_type="ch_law", + source_url="https://www.fedlex.admin.ch/eli/cc/1997/2187_2187_2187/de", + description="Telekommunikationsregulierung: Fernmeldegeheimnis, Cookies/Tracking (Art. 45c), Spam-Verbot, Datenschutz.", + language="de", + requirement_count=25, + ), + # ===================================================================== + # Deutschland P2 + # ===================================================================== + Regulation( + code="DE_PANGV", + name="PAngV", + full_name="Preisangabenverordnung (PAngV 2022)", + regulation_type="de_law", + source_url="https://www.gesetze-im-internet.de/pangv_2022/", + description="Preisangaben: Gesamtpreis, Grundpreis, Streichpreise (§11), Online-Preisauszeichnung.", + requirement_count=15, + ), + Regulation( + code="DE_DLINFOV", + name="DL-InfoV", + full_name="Dienstleistungs-Informationspflichten-Verordnung", + regulation_type="de_law", + source_url="https://www.gesetze-im-internet.de/dlinfov/", + description="Informationspflichten fuer Dienstleister: Identitaet, Kontakt, Berufshaftpflicht, AGB-Zugang.", + requirement_count=10, + ), + Regulation( + code="DE_BETRVG", + name="BetrVG §87", + full_name="Betriebsverfassungsgesetz §87 Abs.1 Nr.6", + regulation_type="de_law", + source_url="https://www.gesetze-im-internet.de/betrvg/", + description="Mitbestimmung bei technischer Ueberwachung: Betriebsrat-Beteiligung bei IT-Systemen, die Arbeitnehmerverhalten ueberwachen koennen.", + requirement_count=5, + ), + # ===================================================================== + # Oesterreich P2 + # ===================================================================== + Regulation( + code="AT_ABGB_AGB", + name="ABGB AGB-Recht AT", + full_name="ABGB §§861-879, 864a — AGB-Kontrolle Oesterreich", + regulation_type="at_law", + source_url="https://www.ris.bka.gv.at/GeltendeFassung.wxe?Abfrage=Bundesnormen&Gesetzesnummer=10001622", + description="Geltungskontrolle (§864a), Sittenwidrigkeitskontrolle (§879 Abs.3), allgemeine Vertragsregeln.", + local_path="AT_ABGB_AGB.txt", + language="de", + requirement_count=10, + ), + Regulation( + code="AT_UWG", + name="UWG Oesterreich", + full_name="Bundesgesetz gegen den unlauteren Wettbewerb Oesterreich", + regulation_type="at_law", + source_url="https://www.ris.bka.gv.at/GeltendeFassung.wxe?Abfrage=Bundesnormen&Gesetzesnummer=10002665", + description="Lauterkeitsrecht AT: irrefuehrende Geschaeftspraktiken, aggressive Praktiken, Preisauszeichnung.", + language="de", + requirement_count=15, + ), + # ===================================================================== + # Schweiz P2 + # ===================================================================== + Regulation( + code="CH_GEBUV", + name="GeBuV Schweiz", + full_name="Geschaeftsbuecher-Verordnung Schweiz (SR 221.431)", + regulation_type="ch_law", + source_url="https://www.fedlex.admin.ch/eli/cc/2002/468_468_468/de", + description="Ausfuehrungsvorschriften zur Buchfuehrung: elektronische Aufbewahrung, Integritaet, Datentraeger.", + language="de", + requirement_count=10, + ), + Regulation( + code="CH_ZERTES", + name="ZertES Schweiz", + full_name="Bundesgesetz ueber die elektronische Signatur (SR 943.03)", + regulation_type="ch_law", + source_url="https://www.fedlex.admin.ch/eli/cc/2016/752/de", + description="Elektronische Signatur und Zertifizierung: Qualifizierte Signaturen, Zertifizierungsdiensteanbieter.", + language="de", + requirement_count=10, + ), + # ===================================================================== + # Deutschland P3 + # ===================================================================== + Regulation( + code="DE_GESCHGEHG", + name="GeschGehG", + full_name="Gesetz zum Schutz von Geschaeftsgeheimnissen", + regulation_type="de_law", + source_url="https://www.gesetze-im-internet.de/geschgehg/", + description="Schutz von Geschaeftsgeheimnissen: Definition, angemessene Geheimhaltungsmassnahmen, Reverse Engineering.", + requirement_count=10, + ), + Regulation( + code="DE_BSIG", + name="BSI-Gesetz", + full_name="Gesetz ueber das Bundesamt fuer Sicherheit in der Informationstechnik (BSIG)", + regulation_type="de_law", + source_url="https://www.gesetze-im-internet.de/bsig_2009/", + description="BSI-Aufgaben, KRITIS-Meldepflichten, IT-Sicherheitsstandards, Zertifizierung.", + requirement_count=20, + ), + Regulation( + code="DE_USTG_RET", + name="UStG §14b", + full_name="Umsatzsteuergesetz §14b — Aufbewahrung von Rechnungen", + regulation_type="de_law", + source_url="https://www.gesetze-im-internet.de/ustg_1980/", + description="Aufbewahrungspflicht fuer Rechnungen: 10 Jahre, Grundstuecke 20 Jahre, elektronische Aufbewahrung.", + local_path="DE_USTG_RET.txt", + requirement_count=5, + ), + # ===================================================================== + # Schweiz P3 + # ===================================================================== + Regulation( + code="CH_ZGB_PERS", + name="ZGB Persoenlichkeitsschutz CH", + full_name="Zivilgesetzbuch Art. 28-28l — Persoenlichkeitsschutz Schweiz (SR 210)", + regulation_type="ch_law", + source_url="https://www.fedlex.admin.ch/eli/cc/24/233_245_233/de", + description="Persoenlichkeitsschutz: Recht am eigenen Bild, Schutz der Privatsphaere, Gegendarstellungsrecht.", + language="de", + requirement_count=8, + ), + # ===================================================================== + # 3 fehlgeschlagene Quellen mit alternativen URLs nachholen + # ===================================================================== + Regulation( + code="LU_DPA_LAW", + name="Datenschutzgesetz Luxemburg", + full_name="Loi du 1er aout 2018 — Datenschutzgesetz Luxemburg", + regulation_type="national_law", + source_url="https://legilux.public.lu/eli/etat/leg/loi/2018/08/01/a686/jo", + description="Luxemburgisches Datenschutzgesetz: Organisation der CNPD, nationale DSGVO-Ergaenzung.", + language="fr", + requirement_count=40, + ), + Regulation( + code="DK_DATABESKYTTELSESLOVEN", + name="Databeskyttelsesloven DK", + full_name="Databeskyttelsesloven — Datenschutzgesetz Daenemark", + regulation_type="national_law", + source_url="https://www.retsinformation.dk/eli/lta/2018/502", + description="Daenisches Datenschutzgesetz als ergaenzende Bestimmungen zur DSGVO. Reguliert durch Datatilsynet.", + language="da", + requirement_count=30, + ), + Regulation( + code="EDPB_GUIDELINES_1_2022", + name="EDPB GL Bussgelder", + full_name="EDPB Leitlinien 04/2022 zur Berechnung von Bussgeldern nach der DSGVO", + regulation_type="eu_guideline", + source_url="https://www.edpb.europa.eu/system/files/2023-05/edpb_guidelines_042022_calculationofadministrativefines_en.pdf", + description="EDPB-Leitlinien zur Berechnung von Verwaltungsbussgeldern unter der DSGVO.", + language="en", + requirement_count=15, + ), +] diff --git a/klausur-service/backend/worksheet_editor_ai.py b/klausur-service/backend/worksheet_editor_ai.py new file mode 100644 index 0000000..a6cb56e --- /dev/null +++ b/klausur-service/backend/worksheet_editor_ai.py @@ -0,0 +1,485 @@ +""" +Worksheet Editor AI — AI image generation and AI worksheet modification. +""" + +import io +import json +import base64 +import logging +import re +import time +import random +from typing import List, Dict + +import httpx + +from worksheet_editor_models import ( + AIImageRequest, + AIImageResponse, + AIImageStyle, + AIModifyRequest, + AIModifyResponse, + OLLAMA_URL, + STYLE_PROMPTS, +) + +logger = logging.getLogger(__name__) + + +# ============================================= +# AI IMAGE GENERATION +# ============================================= + +async def generate_ai_image_logic(request: AIImageRequest) -> AIImageResponse: + """ + Generate an AI image using Ollama with a text-to-image model. + + Falls back to a placeholder if Ollama is not available. + """ + from fastapi import HTTPException + + try: + # Build enhanced prompt with style + style_modifier = STYLE_PROMPTS.get(request.style, "") + enhanced_prompt = f"{request.prompt}, {style_modifier}" + + logger.info(f"Generating AI image: {enhanced_prompt[:100]}...") + + # Check if Ollama is available + async with httpx.AsyncClient(timeout=10.0) as check_client: + try: + health_response = await check_client.get(f"{OLLAMA_URL}/api/tags") + if health_response.status_code != 200: + raise HTTPException(status_code=503, detail="Ollama service not available") + except httpx.ConnectError: + logger.warning("Ollama not reachable, returning placeholder") + return _generate_placeholder_image(request, enhanced_prompt) + + try: + async with httpx.AsyncClient(timeout=300.0) as client: + tags_response = await client.get(f"{OLLAMA_URL}/api/tags") + available_models = [m.get("name", "") for m in tags_response.json().get("models", [])] + + sd_model = None + for model in available_models: + if "stable" in model.lower() or "sd" in model.lower() or "diffusion" in model.lower(): + sd_model = model + break + + if not sd_model: + logger.warning("No Stable Diffusion model found in Ollama") + return _generate_placeholder_image(request, enhanced_prompt) + + logger.info(f"SD model found: {sd_model}, but image generation API not implemented") + return _generate_placeholder_image(request, enhanced_prompt) + + except Exception as e: + logger.error(f"Image generation failed: {e}") + return _generate_placeholder_image(request, enhanced_prompt) + + except HTTPException: + raise + except Exception as e: + logger.error(f"AI image generation error: {e}") + raise HTTPException(status_code=500, detail=str(e)) + + +def _generate_placeholder_image(request: AIImageRequest, prompt: str) -> AIImageResponse: + """ + Generate a placeholder image when AI generation is not available. + Creates a simple SVG-based placeholder with the prompt text. + """ + from PIL import Image, ImageDraw, ImageFont + + width, height = request.width, request.height + + style_colors = { + AIImageStyle.REALISTIC: ("#2563eb", "#dbeafe"), + AIImageStyle.CARTOON: ("#f97316", "#ffedd5"), + AIImageStyle.SKETCH: ("#6b7280", "#f3f4f6"), + AIImageStyle.CLIPART: ("#8b5cf6", "#ede9fe"), + AIImageStyle.EDUCATIONAL: ("#059669", "#d1fae5"), + } + + fg_color, bg_color = style_colors.get(request.style, ("#6366f1", "#e0e7ff")) + + img = Image.new('RGB', (width, height), bg_color) + draw = ImageDraw.Draw(img) + + draw.rectangle([5, 5, width-6, height-6], outline=fg_color, width=3) + + cx, cy = width // 2, height // 2 - 30 + draw.ellipse([cx-40, cy-40, cx+40, cy+40], outline=fg_color, width=3) + draw.line([cx-20, cy-10, cx+20, cy-10], fill=fg_color, width=3) + draw.line([cx, cy-10, cx, cy+20], fill=fg_color, width=3) + + try: + font = ImageFont.truetype("/usr/share/fonts/truetype/dejavu/DejaVuSans.ttf", 14) + except Exception: + font = ImageFont.load_default() + + max_chars = 40 + lines = [] + words = prompt[:200].split() + current_line = "" + for word in words: + if len(current_line) + len(word) + 1 <= max_chars: + current_line += (" " + word if current_line else word) + else: + if current_line: + lines.append(current_line) + current_line = word + if current_line: + lines.append(current_line) + + text_y = cy + 60 + for line in lines[:4]: + bbox = draw.textbbox((0, 0), line, font=font) + text_width = bbox[2] - bbox[0] + draw.text((cx - text_width // 2, text_y), line, fill=fg_color, font=font) + text_y += 20 + + badge_text = "KI-Bild (Platzhalter)" + try: + badge_font = ImageFont.truetype("/usr/share/fonts/truetype/dejavu/DejaVuSans.ttf", 10) + except Exception: + badge_font = font + draw.rectangle([10, height-30, 150, height-10], fill=fg_color) + draw.text((15, height-27), badge_text, fill="white", font=badge_font) + + buffer = io.BytesIO() + img.save(buffer, format='PNG') + buffer.seek(0) + + image_base64 = f"data:image/png;base64,{base64.b64encode(buffer.getvalue()).decode('utf-8')}" + + return AIImageResponse( + image_base64=image_base64, + prompt_used=prompt, + error="AI image generation not available. Using placeholder." + ) + + +# ============================================= +# AI WORKSHEET MODIFICATION +# ============================================= + +async def modify_worksheet_with_ai_logic(request: AIModifyRequest) -> AIModifyResponse: + """ + Modify a worksheet using AI based on natural language prompt. + """ + try: + logger.info(f"AI modify request: {request.prompt[:100]}...") + + try: + canvas_data = json.loads(request.canvas_json) + except json.JSONDecodeError: + return AIModifyResponse( + message="Fehler beim Parsen des Canvas", + error="Invalid canvas JSON" + ) + + system_prompt = """Du bist ein Assistent fuer die Bearbeitung von Arbeitsblaettern. +Du erhaeltst den aktuellen Zustand eines Canvas im JSON-Format und eine Anweisung des Nutzers. +Deine Aufgabe ist es, die gewuenschten Aenderungen am Canvas vorzunehmen. + +Der Canvas verwendet Fabric.js. Hier sind die wichtigsten Objekttypen: +- i-text: Interaktiver Text mit fontFamily, fontSize, fill, left, top +- rect: Rechteck mit left, top, width, height, fill, stroke, strokeWidth +- circle: Kreis mit left, top, radius, fill, stroke, strokeWidth +- line: Linie mit x1, y1, x2, y2, stroke, strokeWidth + +Das Canvas ist 794x1123 Pixel (A4 bei 96 DPI). + +Antworte NUR mit einem JSON-Objekt in diesem Format: +{ + "action": "modify" oder "add" oder "delete" oder "info", + "objects": [...], // Neue/modifizierte Objekte (bei modify/add) + "message": "Kurze Beschreibung der Aenderung" +} + +Wenn du Objekte hinzufuegst, generiere eindeutige IDs im Format "obj__". +""" + + user_prompt = f"""Aktueller Canvas-Zustand: +```json +{json.dumps(canvas_data, indent=2)[:5000]} +``` + +Nutzer-Anweisung: {request.prompt} + +Fuehre die Aenderung durch und antworte mit dem JSON-Objekt.""" + + try: + async with httpx.AsyncClient(timeout=120.0) as client: + response = await client.post( + f"{OLLAMA_URL}/api/generate", + json={ + "model": request.model, + "prompt": user_prompt, + "system": system_prompt, + "stream": False, + "options": { + "temperature": 0.3, + "num_predict": 4096 + } + } + ) + + if response.status_code != 200: + logger.warning(f"Ollama error: {response.status_code}, trying local fallback") + return _handle_simple_modification(request.prompt, canvas_data) + + ai_response = response.json().get("response", "") + + except httpx.ConnectError: + logger.warning("Ollama not reachable") + return _handle_simple_modification(request.prompt, canvas_data) + except httpx.TimeoutException: + logger.warning("Ollama timeout, trying local fallback") + return _handle_simple_modification(request.prompt, canvas_data) + + try: + json_start = ai_response.find('{') + json_end = ai_response.rfind('}') + 1 + + if json_start == -1 or json_end <= json_start: + logger.warning(f"No JSON found in AI response: {ai_response[:200]}") + return AIModifyResponse( + message="KI konnte die Anfrage nicht verarbeiten", + error="No JSON in response" + ) + + ai_json = json.loads(ai_response[json_start:json_end]) + action = ai_json.get("action", "info") + message = ai_json.get("message", "Aenderungen angewendet") + new_objects = ai_json.get("objects", []) + + if action == "info": + return AIModifyResponse(message=message) + + if action == "add" and new_objects: + existing_objects = canvas_data.get("objects", []) + existing_objects.extend(new_objects) + canvas_data["objects"] = existing_objects + return AIModifyResponse( + modified_canvas_json=json.dumps(canvas_data), + message=message + ) + + if action == "modify" and new_objects: + existing_objects = canvas_data.get("objects", []) + new_ids = {obj.get("id") for obj in new_objects if obj.get("id")} + kept_objects = [obj for obj in existing_objects if obj.get("id") not in new_ids] + kept_objects.extend(new_objects) + canvas_data["objects"] = kept_objects + return AIModifyResponse( + modified_canvas_json=json.dumps(canvas_data), + message=message + ) + + if action == "delete": + delete_ids = ai_json.get("delete_ids", []) + if delete_ids: + existing_objects = canvas_data.get("objects", []) + canvas_data["objects"] = [obj for obj in existing_objects if obj.get("id") not in delete_ids] + return AIModifyResponse( + modified_canvas_json=json.dumps(canvas_data), + message=message + ) + + return AIModifyResponse(message=message) + + except json.JSONDecodeError as e: + logger.error(f"Failed to parse AI JSON: {e}") + return AIModifyResponse( + message="Fehler beim Verarbeiten der KI-Antwort", + error=str(e) + ) + + except Exception as e: + logger.error(f"AI modify error: {e}") + return AIModifyResponse( + message="Ein unerwarteter Fehler ist aufgetreten", + error=str(e) + ) + + +def _handle_simple_modification(prompt: str, canvas_data: dict) -> AIModifyResponse: + """ + Handle simple modifications locally when Ollama is not available. + Supports basic commands like adding headings, lines, etc. + """ + prompt_lower = prompt.lower() + objects = canvas_data.get("objects", []) + + def generate_id(): + return f"obj_{int(time.time()*1000)}_{random.randint(1000, 9999)}" + + # Add heading + if "ueberschrift" in prompt_lower or "titel" in prompt_lower or "heading" in prompt_lower: + text_match = re.search(r'"([^"]+)"', prompt) + text = text_match.group(1) if text_match else "Ueberschrift" + + new_text = { + "type": "i-text", "id": generate_id(), "text": text, + "left": 397, "top": 50, "originX": "center", + "fontFamily": "Arial", "fontSize": 28, "fontWeight": "bold", "fill": "#000000" + } + objects.append(new_text) + canvas_data["objects"] = objects + return AIModifyResponse( + modified_canvas_json=json.dumps(canvas_data), + message=f"Ueberschrift '{text}' hinzugefuegt" + ) + + # Add lines for writing + if "linie" in prompt_lower or "line" in prompt_lower or "schreib" in prompt_lower: + num_match = re.search(r'(\d+)', prompt) + num_lines = int(num_match.group(1)) if num_match else 5 + num_lines = min(num_lines, 20) + + start_y = 150 + line_spacing = 40 + + for i in range(num_lines): + new_line = { + "type": "line", "id": generate_id(), + "x1": 60, "y1": start_y + i * line_spacing, + "x2": 734, "y2": start_y + i * line_spacing, + "stroke": "#cccccc", "strokeWidth": 1 + } + objects.append(new_line) + + canvas_data["objects"] = objects + return AIModifyResponse( + modified_canvas_json=json.dumps(canvas_data), + message=f"{num_lines} Schreiblinien hinzugefuegt" + ) + + # Make text bigger + if "groesser" in prompt_lower or "bigger" in prompt_lower or "larger" in prompt_lower: + modified = 0 + for obj in objects: + if obj.get("type") in ["i-text", "text", "textbox"]: + current_size = obj.get("fontSize", 16) + obj["fontSize"] = int(current_size * 1.25) + modified += 1 + + canvas_data["objects"] = objects + if modified > 0: + return AIModifyResponse( + modified_canvas_json=json.dumps(canvas_data), + message=f"{modified} Texte vergroessert" + ) + + # Center elements + if "zentrier" in prompt_lower or "center" in prompt_lower or "mitte" in prompt_lower: + center_x = 397 + for obj in objects: + if not obj.get("isGrid"): + obj["left"] = center_x + obj["originX"] = "center" + + canvas_data["objects"] = objects + return AIModifyResponse( + modified_canvas_json=json.dumps(canvas_data), + message="Elemente zentriert" + ) + + # Add numbering + if "nummer" in prompt_lower or "nummerier" in prompt_lower or "1-10" in prompt_lower: + range_match = re.search(r'(\d+)\s*[-bis]+\s*(\d+)', prompt) + if range_match: + start, end = int(range_match.group(1)), int(range_match.group(2)) + else: + start, end = 1, 10 + + y = 100 + for i in range(start, min(end + 1, start + 20)): + new_text = { + "type": "i-text", "id": generate_id(), "text": f"{i}.", + "left": 40, "top": y, "fontFamily": "Arial", "fontSize": 14, "fill": "#000000" + } + objects.append(new_text) + y += 35 + + canvas_data["objects"] = objects + return AIModifyResponse( + modified_canvas_json=json.dumps(canvas_data), + message=f"Nummerierung {start}-{end} hinzugefuegt" + ) + + # Add rectangle/box + if "rechteck" in prompt_lower or "box" in prompt_lower or "kasten" in prompt_lower: + new_rect = { + "type": "rect", "id": generate_id(), + "left": 100, "top": 200, "width": 200, "height": 100, + "fill": "transparent", "stroke": "#000000", "strokeWidth": 2 + } + objects.append(new_rect) + canvas_data["objects"] = objects + return AIModifyResponse( + modified_canvas_json=json.dumps(canvas_data), + message="Rechteck hinzugefuegt" + ) + + # Add grid/raster + if "raster" in prompt_lower or "grid" in prompt_lower or "tabelle" in prompt_lower: + dim_match = re.search(r'(\d+)\s*[x/\u00d7\*mal by]\s*(\d+)', prompt_lower) + if dim_match: + cols = int(dim_match.group(1)) + rows = int(dim_match.group(2)) + else: + nums = re.findall(r'(\d+)', prompt) + if len(nums) >= 2: + cols, rows = int(nums[0]), int(nums[1]) + else: + cols, rows = 3, 4 + + cols = min(max(1, cols), 10) + rows = min(max(1, rows), 15) + + canvas_width = 794 + canvas_height = 1123 + margin = 60 + available_width = canvas_width - 2 * margin + available_height = canvas_height - 2 * margin - 80 + + cell_width = available_width / cols + cell_height = min(available_height / rows, 80) + + start_x = margin + start_y = 120 + + grid_objects = [] + for r in range(rows + 1): + y = start_y + r * cell_height + grid_objects.append({ + "type": "line", "id": generate_id(), + "x1": start_x, "y1": y, + "x2": start_x + cols * cell_width, "y2": y, + "stroke": "#666666", "strokeWidth": 1, "isGrid": True + }) + + for c in range(cols + 1): + x = start_x + c * cell_width + grid_objects.append({ + "type": "line", "id": generate_id(), + "x1": x, "y1": start_y, + "x2": x, "y2": start_y + rows * cell_height, + "stroke": "#666666", "strokeWidth": 1, "isGrid": True + }) + + objects.extend(grid_objects) + canvas_data["objects"] = objects + return AIModifyResponse( + modified_canvas_json=json.dumps(canvas_data), + message=f"{cols}x{rows} Raster hinzugefuegt ({cols} Spalten, {rows} Zeilen)" + ) + + # Default: Ollama needed + return AIModifyResponse( + message="Diese Aenderung erfordert den KI-Service. Bitte stellen Sie sicher, dass Ollama laeuft.", + error="Complex modification requires Ollama" + ) diff --git a/klausur-service/backend/worksheet_editor_api.py b/klausur-service/backend/worksheet_editor_api.py index bc96131..875d78b 100644 --- a/klausur-service/backend/worksheet_editor_api.py +++ b/klausur-service/backend/worksheet_editor_api.py @@ -5,122 +5,60 @@ Provides endpoints for: - AI Image generation via Ollama/Stable Diffusion - Worksheet Save/Load - PDF Export + +Split modules: +- worksheet_editor_models: Enums, Pydantic models, configuration +- worksheet_editor_ai: AI image generation and AI worksheet modification +- worksheet_editor_reconstruct: Document reconstruction from vocab sessions """ import os import io -import uuid import json -import base64 import logging from datetime import datetime, timezone -from typing import Optional, List, Dict, Any -from enum import Enum -from dataclasses import dataclass, field, asdict +import uuid -from fastapi import APIRouter, HTTPException, Request, BackgroundTasks -from fastapi.responses import FileResponse, StreamingResponse -from pydantic import BaseModel, Field +from fastapi import APIRouter, HTTPException +from fastapi.responses import StreamingResponse import httpx -# PDF Generation -try: - from reportlab.lib import colors - from reportlab.lib.pagesizes import A4 - from reportlab.lib.units import mm - from reportlab.pdfgen import canvas - from reportlab.lib.styles import getSampleStyleSheet - REPORTLAB_AVAILABLE = True -except ImportError: - REPORTLAB_AVAILABLE = False +# Re-export everything from sub-modules for backward compatibility +from worksheet_editor_models import ( # noqa: F401 + AIImageStyle, + WorksheetStatus, + AIImageRequest, + AIImageResponse, + PageData, + PageFormat, + WorksheetSaveRequest, + WorksheetResponse, + AIModifyRequest, + AIModifyResponse, + ReconstructRequest, + ReconstructResponse, + worksheets_db, + OLLAMA_URL, + SD_MODEL, + WORKSHEET_STORAGE_DIR, + STYLE_PROMPTS, + REPORTLAB_AVAILABLE, +) + +from worksheet_editor_ai import ( # noqa: F401 + generate_ai_image_logic, + _generate_placeholder_image, + modify_worksheet_with_ai_logic, + _handle_simple_modification, +) + +from worksheet_editor_reconstruct import ( # noqa: F401 + reconstruct_document_logic, + _detect_image_regions, +) logger = logging.getLogger(__name__) -# ============================================= -# CONFIGURATION -# ============================================= - -OLLAMA_URL = os.getenv("OLLAMA_URL", "http://host.docker.internal:11434") -SD_MODEL = os.getenv("SD_MODEL", "stable-diffusion") # or specific SD model -WORKSHEET_STORAGE_DIR = os.getenv("WORKSHEET_STORAGE_DIR", - os.path.join(os.path.dirname(os.path.abspath(__file__)), "worksheet-storage")) - -# Ensure storage directory exists -os.makedirs(WORKSHEET_STORAGE_DIR, exist_ok=True) - -# ============================================= -# ENUMS & MODELS -# ============================================= - -class AIImageStyle(str, Enum): - REALISTIC = "realistic" - CARTOON = "cartoon" - SKETCH = "sketch" - CLIPART = "clipart" - EDUCATIONAL = "educational" - -class WorksheetStatus(str, Enum): - DRAFT = "draft" - PUBLISHED = "published" - ARCHIVED = "archived" - -# Style prompt modifiers -STYLE_PROMPTS = { - AIImageStyle.REALISTIC: "photorealistic, high detail, professional photography", - AIImageStyle.CARTOON: "cartoon style, colorful, child-friendly, simple shapes", - AIImageStyle.SKETCH: "pencil sketch, hand-drawn, black and white, artistic", - AIImageStyle.CLIPART: "clipart style, flat design, simple, vector-like", - AIImageStyle.EDUCATIONAL: "educational illustration, clear, informative, textbook style" -} - -# ============================================= -# REQUEST/RESPONSE MODELS -# ============================================= - -class AIImageRequest(BaseModel): - prompt: str = Field(..., min_length=3, max_length=500) - style: AIImageStyle = AIImageStyle.EDUCATIONAL - width: int = Field(512, ge=256, le=1024) - height: int = Field(512, ge=256, le=1024) - -class AIImageResponse(BaseModel): - image_base64: str - prompt_used: str - error: Optional[str] = None - -class PageData(BaseModel): - id: str - index: int - canvasJSON: str - -class PageFormat(BaseModel): - width: float = 210 - height: float = 297 - orientation: str = "portrait" - margins: Dict[str, float] = {"top": 15, "right": 15, "bottom": 15, "left": 15} - -class WorksheetSaveRequest(BaseModel): - id: Optional[str] = None - title: str - description: Optional[str] = None - pages: List[PageData] - pageFormat: Optional[PageFormat] = None - -class WorksheetResponse(BaseModel): - id: str - title: str - description: Optional[str] - pages: List[PageData] - pageFormat: PageFormat - createdAt: str - updatedAt: str - -# ============================================= -# IN-MEMORY STORAGE (Development) -# ============================================= - -worksheets_db: Dict[str, Dict] = {} - # ============================================= # ROUTER # ============================================= @@ -143,144 +81,7 @@ async def generate_ai_image(request: AIImageRequest): Falls back to a placeholder if Ollama is not available. """ - try: - # Build enhanced prompt with style - style_modifier = STYLE_PROMPTS.get(request.style, "") - enhanced_prompt = f"{request.prompt}, {style_modifier}" - - logger.info(f"Generating AI image: {enhanced_prompt[:100]}...") - - # Check if Ollama is available - async with httpx.AsyncClient(timeout=10.0) as check_client: - try: - health_response = await check_client.get(f"{OLLAMA_URL}/api/tags") - if health_response.status_code != 200: - raise HTTPException(status_code=503, detail="Ollama service not available") - except httpx.ConnectError: - logger.warning("Ollama not reachable, returning placeholder") - # Return a placeholder image (simple colored rectangle) - return _generate_placeholder_image(request, enhanced_prompt) - - # Try to generate with Stable Diffusion via Ollama - # Note: Ollama doesn't natively support SD, this is a placeholder for when it does - # or when using a compatible endpoint - - try: - async with httpx.AsyncClient(timeout=300.0) as client: - # Check if SD model is available - tags_response = await client.get(f"{OLLAMA_URL}/api/tags") - available_models = [m.get("name", "") for m in tags_response.json().get("models", [])] - - # Look for SD-compatible model - sd_model = None - for model in available_models: - if "stable" in model.lower() or "sd" in model.lower() or "diffusion" in model.lower(): - sd_model = model - break - - if not sd_model: - logger.warning("No Stable Diffusion model found in Ollama") - return _generate_placeholder_image(request, enhanced_prompt) - - # Generate image (this would need Ollama's image generation API) - # For now, return placeholder - logger.info(f"SD model found: {sd_model}, but image generation API not implemented") - return _generate_placeholder_image(request, enhanced_prompt) - - except Exception as e: - logger.error(f"Image generation failed: {e}") - return _generate_placeholder_image(request, enhanced_prompt) - - except HTTPException: - raise - except Exception as e: - logger.error(f"AI image generation error: {e}") - raise HTTPException(status_code=500, detail=str(e)) - - -def _generate_placeholder_image(request: AIImageRequest, prompt: str) -> AIImageResponse: - """ - Generate a placeholder image when AI generation is not available. - Creates a simple SVG-based placeholder with the prompt text. - """ - from PIL import Image, ImageDraw, ImageFont - - # Create image - width, height = request.width, request.height - - # Style-based colors - style_colors = { - AIImageStyle.REALISTIC: ("#2563eb", "#dbeafe"), - AIImageStyle.CARTOON: ("#f97316", "#ffedd5"), - AIImageStyle.SKETCH: ("#6b7280", "#f3f4f6"), - AIImageStyle.CLIPART: ("#8b5cf6", "#ede9fe"), - AIImageStyle.EDUCATIONAL: ("#059669", "#d1fae5"), - } - - fg_color, bg_color = style_colors.get(request.style, ("#6366f1", "#e0e7ff")) - - # Create image with Pillow - img = Image.new('RGB', (width, height), bg_color) - draw = ImageDraw.Draw(img) - - # Draw border - draw.rectangle([5, 5, width-6, height-6], outline=fg_color, width=3) - - # Draw icon (simple shapes) - cx, cy = width // 2, height // 2 - 30 - draw.ellipse([cx-40, cy-40, cx+40, cy+40], outline=fg_color, width=3) - draw.line([cx-20, cy-10, cx+20, cy-10], fill=fg_color, width=3) - draw.line([cx, cy-10, cx, cy+20], fill=fg_color, width=3) - - # Draw text - try: - font = ImageFont.truetype("/usr/share/fonts/truetype/dejavu/DejaVuSans.ttf", 14) - except: - font = ImageFont.load_default() - - # Wrap text - max_chars = 40 - lines = [] - words = prompt[:200].split() - current_line = "" - for word in words: - if len(current_line) + len(word) + 1 <= max_chars: - current_line += (" " + word if current_line else word) - else: - if current_line: - lines.append(current_line) - current_line = word - if current_line: - lines.append(current_line) - - text_y = cy + 60 - for line in lines[:4]: # Max 4 lines - bbox = draw.textbbox((0, 0), line, font=font) - text_width = bbox[2] - bbox[0] - draw.text((cx - text_width // 2, text_y), line, fill=fg_color, font=font) - text_y += 20 - - # Draw "AI Placeholder" badge - badge_text = "KI-Bild (Platzhalter)" - try: - badge_font = ImageFont.truetype("/usr/share/fonts/truetype/dejavu/DejaVuSans.ttf", 10) - except: - badge_font = font - draw.rectangle([10, height-30, 150, height-10], fill=fg_color) - draw.text((15, height-27), badge_text, fill="white", font=badge_font) - - # Convert to base64 - buffer = io.BytesIO() - img.save(buffer, format='PNG') - buffer.seek(0) - - image_base64 = f"data:image/png;base64,{base64.b64encode(buffer.getvalue()).decode('utf-8')}" - - return AIImageResponse( - image_base64=image_base64, - prompt_used=prompt, - error="AI image generation not available. Using placeholder." - ) + return await generate_ai_image_logic(request) # ============================================= @@ -298,10 +99,8 @@ async def save_worksheet(request: WorksheetSaveRequest): try: now = datetime.now(timezone.utc).isoformat() - # Generate or use existing ID worksheet_id = request.id or f"ws_{uuid.uuid4().hex[:12]}" - # Build worksheet data worksheet = { "id": worksheet_id, "title": request.title, @@ -312,10 +111,8 @@ async def save_worksheet(request: WorksheetSaveRequest): "updatedAt": now } - # Save to in-memory storage worksheets_db[worksheet_id] = worksheet - # Also persist to file filepath = os.path.join(WORKSHEET_STORAGE_DIR, f"{worksheet_id}.json") with open(filepath, 'w', encoding='utf-8') as f: json.dump(worksheet, f, ensure_ascii=False, indent=2) @@ -331,20 +128,16 @@ async def save_worksheet(request: WorksheetSaveRequest): @router.get("/{worksheet_id}", response_model=WorksheetResponse) async def get_worksheet(worksheet_id: str): - """ - Load a worksheet document by ID. - """ + """Load a worksheet document by ID.""" try: - # Try in-memory first if worksheet_id in worksheets_db: return WorksheetResponse(**worksheets_db[worksheet_id]) - # Try file storage filepath = os.path.join(WORKSHEET_STORAGE_DIR, f"{worksheet_id}.json") if os.path.exists(filepath): with open(filepath, 'r', encoding='utf-8') as f: worksheet = json.load(f) - worksheets_db[worksheet_id] = worksheet # Cache it + worksheets_db[worksheet_id] = worksheet return WorksheetResponse(**worksheet) raise HTTPException(status_code=404, detail="Worksheet not found") @@ -358,13 +151,10 @@ async def get_worksheet(worksheet_id: str): @router.get("/list/all") async def list_worksheets(): - """ - List all available worksheets. - """ + """List all available worksheets.""" try: worksheets = [] - # Load from file storage for filename in os.listdir(WORKSHEET_STORAGE_DIR): if filename.endswith('.json'): filepath = os.path.join(WORKSHEET_STORAGE_DIR, filename) @@ -382,7 +172,6 @@ async def list_worksheets(): except Exception as e: logger.warning(f"Failed to load {filename}: {e}") - # Sort by updatedAt descending worksheets.sort(key=lambda x: x.get("updatedAt", ""), reverse=True) return {"worksheets": worksheets, "total": len(worksheets)} @@ -394,15 +183,11 @@ async def list_worksheets(): @router.delete("/{worksheet_id}") async def delete_worksheet(worksheet_id: str): - """ - Delete a worksheet document. - """ + """Delete a worksheet document.""" try: - # Remove from memory if worksheet_id in worksheets_db: del worksheets_db[worksheet_id] - # Remove file filepath = os.path.join(WORKSHEET_STORAGE_DIR, f"{worksheet_id}.json") if os.path.exists(filepath): os.remove(filepath) @@ -434,7 +219,9 @@ async def export_worksheet_pdf(worksheet_id: str): raise HTTPException(status_code=501, detail="PDF export not available (reportlab not installed)") try: - # Load worksheet + from reportlab.lib.pagesizes import A4 + from reportlab.pdfgen import canvas + worksheet = worksheets_db.get(worksheet_id) if not worksheet: filepath = os.path.join(WORKSHEET_STORAGE_DIR, f"{worksheet_id}.json") @@ -444,21 +231,18 @@ async def export_worksheet_pdf(worksheet_id: str): else: raise HTTPException(status_code=404, detail="Worksheet not found") - # Create PDF buffer = io.BytesIO() c = canvas.Canvas(buffer, pagesize=A4) page_width, page_height = A4 for page_data in worksheet.get("pages", []): - # Add title on first page if page_data.get("index", 0) == 0: c.setFont("Helvetica-Bold", 18) c.drawString(50, page_height - 50, worksheet.get("title", "Arbeitsblatt")) c.setFont("Helvetica", 10) c.drawString(50, page_height - 70, f"Erstellt: {worksheet.get('createdAt', '')[:10]}") - # Parse canvas JSON and render basic elements canvas_json_str = page_data.get("canvasJSON", "{}") if canvas_json_str: try: @@ -469,34 +253,28 @@ async def export_worksheet_pdf(worksheet_id: str): obj_type = obj.get("type", "") if obj_type in ["text", "i-text", "textbox"]: - # Render text text = obj.get("text", "") left = obj.get("left", 50) top = obj.get("top", 100) font_size = obj.get("fontSize", 12) - # Convert from canvas coords to PDF coords - pdf_x = left * 0.75 # Approximate scale + pdf_x = left * 0.75 pdf_y = page_height - (top * 0.75) c.setFont("Helvetica", min(font_size, 24)) c.drawString(pdf_x, pdf_y, text[:100]) elif obj_type == "rect": - # Render rectangle left = obj.get("left", 0) * 0.75 top = obj.get("top", 0) * 0.75 width = obj.get("width", 50) * 0.75 height = obj.get("height", 30) * 0.75 - c.rect(left, page_height - top - height, width, height) elif obj_type == "circle": - # Render circle left = obj.get("left", 0) * 0.75 top = obj.get("top", 0) * 0.75 radius = obj.get("radius", 25) * 0.75 - c.circle(left + radius, page_height - top - radius, radius) except json.JSONDecodeError: @@ -526,16 +304,6 @@ async def export_worksheet_pdf(worksheet_id: str): # AI WORKSHEET MODIFICATION # ============================================= -class AIModifyRequest(BaseModel): - prompt: str = Field(..., min_length=3, max_length=1000) - canvas_json: str - model: str = "qwen2.5vl:32b" - -class AIModifyResponse(BaseModel): - modified_canvas_json: Optional[str] = None - message: str - error: Optional[str] = None - @router.post("/ai-modify", response_model=AIModifyResponse) async def modify_worksheet_with_ai(request: AIModifyRequest): """ @@ -544,397 +312,7 @@ async def modify_worksheet_with_ai(request: AIModifyRequest): Uses Ollama with qwen2.5vl:32b to understand the canvas state and generate modifications based on the user's request. """ - try: - logger.info(f"AI modify request: {request.prompt[:100]}...") - - # Parse current canvas state - try: - canvas_data = json.loads(request.canvas_json) - except json.JSONDecodeError: - return AIModifyResponse( - message="Fehler beim Parsen des Canvas", - error="Invalid canvas JSON" - ) - - # Build system prompt for the AI - system_prompt = """Du bist ein Assistent fuer die Bearbeitung von Arbeitsblaettern. -Du erhaeltst den aktuellen Zustand eines Canvas im JSON-Format und eine Anweisung des Nutzers. -Deine Aufgabe ist es, die gewuenschten Aenderungen am Canvas vorzunehmen. - -Der Canvas verwendet Fabric.js. Hier sind die wichtigsten Objekttypen: -- i-text: Interaktiver Text mit fontFamily, fontSize, fill, left, top -- rect: Rechteck mit left, top, width, height, fill, stroke, strokeWidth -- circle: Kreis mit left, top, radius, fill, stroke, strokeWidth -- line: Linie mit x1, y1, x2, y2, stroke, strokeWidth - -Das Canvas ist 794x1123 Pixel (A4 bei 96 DPI). - -Antworte NUR mit einem JSON-Objekt in diesem Format: -{ - "action": "modify" oder "add" oder "delete" oder "info", - "objects": [...], // Neue/modifizierte Objekte (bei modify/add) - "message": "Kurze Beschreibung der Aenderung" -} - -Wenn du Objekte hinzufuegst, generiere eindeutige IDs im Format "obj__". -""" - - user_prompt = f"""Aktueller Canvas-Zustand: -```json -{json.dumps(canvas_data, indent=2)[:5000]} -``` - -Nutzer-Anweisung: {request.prompt} - -Fuehre die Aenderung durch und antworte mit dem JSON-Objekt.""" - - # Call Ollama - try: - async with httpx.AsyncClient(timeout=120.0) as client: - response = await client.post( - f"{OLLAMA_URL}/api/generate", - json={ - "model": request.model, - "prompt": user_prompt, - "system": system_prompt, - "stream": False, - "options": { - "temperature": 0.3, - "num_predict": 4096 - } - } - ) - - if response.status_code != 200: - logger.warning(f"Ollama error: {response.status_code}, trying local fallback") - # Fallback: Try to handle simple requests locally - return _handle_simple_modification(request.prompt, canvas_data) - - result = response.json() - ai_response = result.get("response", "") - - except httpx.ConnectError: - logger.warning("Ollama not reachable") - # Fallback: Try to handle simple requests locally - return _handle_simple_modification(request.prompt, canvas_data) - except httpx.TimeoutException: - logger.warning("Ollama timeout, trying local fallback") - # Fallback: Try to handle simple requests locally - return _handle_simple_modification(request.prompt, canvas_data) - - # Parse AI response - try: - # Find JSON in response - json_start = ai_response.find('{') - json_end = ai_response.rfind('}') + 1 - - if json_start == -1 or json_end <= json_start: - logger.warning(f"No JSON found in AI response: {ai_response[:200]}") - return AIModifyResponse( - message="KI konnte die Anfrage nicht verarbeiten", - error="No JSON in response" - ) - - ai_json = json.loads(ai_response[json_start:json_end]) - action = ai_json.get("action", "info") - message = ai_json.get("message", "Aenderungen angewendet") - new_objects = ai_json.get("objects", []) - - if action == "info": - return AIModifyResponse(message=message) - - if action == "add" and new_objects: - # Add new objects to canvas - existing_objects = canvas_data.get("objects", []) - existing_objects.extend(new_objects) - canvas_data["objects"] = existing_objects - - return AIModifyResponse( - modified_canvas_json=json.dumps(canvas_data), - message=message - ) - - if action == "modify" and new_objects: - # Replace matching objects or add new ones - existing_objects = canvas_data.get("objects", []) - new_ids = {obj.get("id") for obj in new_objects if obj.get("id")} - - # Keep objects that aren't being modified - kept_objects = [obj for obj in existing_objects if obj.get("id") not in new_ids] - kept_objects.extend(new_objects) - canvas_data["objects"] = kept_objects - - return AIModifyResponse( - modified_canvas_json=json.dumps(canvas_data), - message=message - ) - - if action == "delete": - # Delete objects by ID - delete_ids = ai_json.get("delete_ids", []) - if delete_ids: - existing_objects = canvas_data.get("objects", []) - canvas_data["objects"] = [obj for obj in existing_objects if obj.get("id") not in delete_ids] - - return AIModifyResponse( - modified_canvas_json=json.dumps(canvas_data), - message=message - ) - - return AIModifyResponse(message=message) - - except json.JSONDecodeError as e: - logger.error(f"Failed to parse AI JSON: {e}") - return AIModifyResponse( - message="Fehler beim Verarbeiten der KI-Antwort", - error=str(e) - ) - - except Exception as e: - logger.error(f"AI modify error: {e}") - return AIModifyResponse( - message="Ein unerwarteter Fehler ist aufgetreten", - error=str(e) - ) - - -def _handle_simple_modification(prompt: str, canvas_data: dict) -> AIModifyResponse: - """ - Handle simple modifications locally when Ollama is not available. - Supports basic commands like adding headings, lines, etc. - """ - import time - import random - - prompt_lower = prompt.lower() - objects = canvas_data.get("objects", []) - - def generate_id(): - return f"obj_{int(time.time()*1000)}_{random.randint(1000, 9999)}" - - # Add heading - if "ueberschrift" in prompt_lower or "titel" in prompt_lower or "heading" in prompt_lower: - # Extract text if provided in quotes - import re - text_match = re.search(r'"([^"]+)"', prompt) - text = text_match.group(1) if text_match else "Ueberschrift" - - new_text = { - "type": "i-text", - "id": generate_id(), - "text": text, - "left": 397, # Center of A4 - "top": 50, - "originX": "center", - "fontFamily": "Arial", - "fontSize": 28, - "fontWeight": "bold", - "fill": "#000000" - } - objects.append(new_text) - canvas_data["objects"] = objects - - return AIModifyResponse( - modified_canvas_json=json.dumps(canvas_data), - message=f"Ueberschrift '{text}' hinzugefuegt" - ) - - # Add lines for writing - if "linie" in prompt_lower or "line" in prompt_lower or "schreib" in prompt_lower: - # Count how many lines - import re - num_match = re.search(r'(\d+)', prompt) - num_lines = int(num_match.group(1)) if num_match else 5 - num_lines = min(num_lines, 20) # Max 20 lines - - start_y = 150 - line_spacing = 40 - - for i in range(num_lines): - new_line = { - "type": "line", - "id": generate_id(), - "x1": 60, - "y1": start_y + i * line_spacing, - "x2": 734, - "y2": start_y + i * line_spacing, - "stroke": "#cccccc", - "strokeWidth": 1 - } - objects.append(new_line) - - canvas_data["objects"] = objects - - return AIModifyResponse( - modified_canvas_json=json.dumps(canvas_data), - message=f"{num_lines} Schreiblinien hinzugefuegt" - ) - - # Make text bigger - if "groesser" in prompt_lower or "bigger" in prompt_lower or "larger" in prompt_lower: - modified = 0 - for obj in objects: - if obj.get("type") in ["i-text", "text", "textbox"]: - current_size = obj.get("fontSize", 16) - obj["fontSize"] = int(current_size * 1.25) - modified += 1 - - canvas_data["objects"] = objects - - if modified > 0: - return AIModifyResponse( - modified_canvas_json=json.dumps(canvas_data), - message=f"{modified} Texte vergroessert" - ) - - # Center elements - if "zentrier" in prompt_lower or "center" in prompt_lower or "mitte" in prompt_lower: - center_x = 397 - for obj in objects: - if not obj.get("isGrid"): - obj["left"] = center_x - obj["originX"] = "center" - - canvas_data["objects"] = objects - - return AIModifyResponse( - modified_canvas_json=json.dumps(canvas_data), - message="Elemente zentriert" - ) - - # Add numbering - if "nummer" in prompt_lower or "nummerier" in prompt_lower or "1-10" in prompt_lower: - import re - range_match = re.search(r'(\d+)\s*[-bis]+\s*(\d+)', prompt) - if range_match: - start, end = int(range_match.group(1)), int(range_match.group(2)) - else: - start, end = 1, 10 - - y = 100 - for i in range(start, min(end + 1, start + 20)): - new_text = { - "type": "i-text", - "id": generate_id(), - "text": f"{i}.", - "left": 40, - "top": y, - "fontFamily": "Arial", - "fontSize": 14, - "fill": "#000000" - } - objects.append(new_text) - y += 35 - - canvas_data["objects"] = objects - - return AIModifyResponse( - modified_canvas_json=json.dumps(canvas_data), - message=f"Nummerierung {start}-{end} hinzugefuegt" - ) - - # Add rectangle/box - if "rechteck" in prompt_lower or "box" in prompt_lower or "kasten" in prompt_lower: - new_rect = { - "type": "rect", - "id": generate_id(), - "left": 100, - "top": 200, - "width": 200, - "height": 100, - "fill": "transparent", - "stroke": "#000000", - "strokeWidth": 2 - } - objects.append(new_rect) - canvas_data["objects"] = objects - - return AIModifyResponse( - modified_canvas_json=json.dumps(canvas_data), - message="Rechteck hinzugefuegt" - ) - - # Add grid/raster - if "raster" in prompt_lower or "grid" in prompt_lower or "tabelle" in prompt_lower: - import re - # Parse dimensions like "3x4", "3/4", "3 mal 4", "3 by 4" - dim_match = re.search(r'(\d+)\s*[x/×\*mal by]\s*(\d+)', prompt_lower) - if dim_match: - cols = int(dim_match.group(1)) - rows = int(dim_match.group(2)) - else: - # Try single numbers - nums = re.findall(r'(\d+)', prompt) - if len(nums) >= 2: - cols, rows = int(nums[0]), int(nums[1]) - else: - cols, rows = 3, 4 # Default grid - - # Limit grid size - cols = min(max(1, cols), 10) - rows = min(max(1, rows), 15) - - # Canvas dimensions (A4 at 96 DPI) - canvas_width = 794 - canvas_height = 1123 - - # Grid positioning - margin = 60 - available_width = canvas_width - 2 * margin - available_height = canvas_height - 2 * margin - 80 # Leave space for header - - cell_width = available_width / cols - cell_height = min(available_height / rows, 80) # Max cell height - - start_x = margin - start_y = 120 # Below potential header - - # Create grid lines - grid_objects = [] - - # Horizontal lines - for r in range(rows + 1): - y = start_y + r * cell_height - grid_objects.append({ - "type": "line", - "id": generate_id(), - "x1": start_x, - "y1": y, - "x2": start_x + cols * cell_width, - "y2": y, - "stroke": "#666666", - "strokeWidth": 1, - "isGrid": True - }) - - # Vertical lines - for c in range(cols + 1): - x = start_x + c * cell_width - grid_objects.append({ - "type": "line", - "id": generate_id(), - "x1": x, - "y1": start_y, - "x2": x, - "y2": start_y + rows * cell_height, - "stroke": "#666666", - "strokeWidth": 1, - "isGrid": True - }) - - objects.extend(grid_objects) - canvas_data["objects"] = objects - - return AIModifyResponse( - modified_canvas_json=json.dumps(canvas_data), - message=f"{cols}x{rows} Raster hinzugefuegt ({cols} Spalten, {rows} Zeilen)" - ) - - # Default: Ollama needed - return AIModifyResponse( - message="Diese Aenderung erfordert den KI-Service. Bitte stellen Sie sicher, dass Ollama laeuft.", - error="Complex modification requires Ollama" - ) + return await modify_worksheet_with_ai_logic(request) # ============================================= @@ -943,9 +321,7 @@ def _handle_simple_modification(prompt: str, canvas_data: dict) -> AIModifyRespo @router.get("/health/check") async def health_check(): - """ - Check worksheet editor API health and dependencies. - """ + """Check worksheet editor API health and dependencies.""" status = { "status": "healthy", "ollama": False, @@ -954,12 +330,11 @@ async def health_check(): "worksheets_count": len(worksheets_db) } - # Check Ollama try: async with httpx.AsyncClient(timeout=5.0) as client: response = await client.get(f"{OLLAMA_URL}/api/tags") status["ollama"] = response.status_code == 200 - except: + except Exception: pass return status @@ -969,221 +344,15 @@ async def health_check(): # DOCUMENT RECONSTRUCTION FROM VOCAB SESSION # ============================================= -class ReconstructRequest(BaseModel): - session_id: str - page_number: int = 1 - include_images: bool = True - regenerate_graphics: bool = False - -class ReconstructResponse(BaseModel): - canvas_json: str - page_width: int - page_height: int - elements_count: int - vocabulary_matched: int - message: str - error: Optional[str] = None - - @router.post("/reconstruct-from-session", response_model=ReconstructResponse) async def reconstruct_document_from_session(request: ReconstructRequest): """ Reconstruct a document from a vocab session into Fabric.js canvas format. - This endpoint: - 1. Loads the original PDF from the vocab session - 2. Runs OCR with position tracking - 3. Uses vision LLM to understand layout (headers, images, columns) - 4. Creates Fabric.js canvas JSON with positioned elements - 5. Maps extracted vocabulary to their positions - Returns canvas JSON ready to load into the worksheet editor. """ try: - # Import vocab session storage - from vocab_worksheet_api import _sessions, convert_pdf_page_to_image - - # Check if session exists - if request.session_id not in _sessions: - raise HTTPException(status_code=404, detail=f"Session {request.session_id} not found") - - session = _sessions[request.session_id] - - # Check if PDF data exists - if not session.get("pdf_data"): - raise HTTPException(status_code=400, detail="Session has no PDF data") - - pdf_data = session["pdf_data"] - page_count = session.get("pdf_page_count", 1) - - if request.page_number < 1 or request.page_number > page_count: - raise HTTPException( - status_code=400, - detail=f"Page {request.page_number} not found. PDF has {page_count} pages." - ) - - # Get extracted vocabulary for this page - vocabulary = session.get("vocabulary", []) - page_vocab = [v for v in vocabulary if v.get("source_page") == request.page_number] - - logger.info(f"Reconstructing page {request.page_number} from session {request.session_id}") - logger.info(f"Found {len(page_vocab)} vocabulary items for this page") - - # Convert PDF page to image (async function) - image_bytes = await convert_pdf_page_to_image(pdf_data, request.page_number) - if not image_bytes: - raise HTTPException(status_code=500, detail="Failed to convert PDF page to image") - - # Get image dimensions - from PIL import Image - img = Image.open(io.BytesIO(image_bytes)) - img_width, img_height = img.size - - # Run OCR with positions - from hybrid_vocab_extractor import run_paddle_ocr, OCRRegion - ocr_regions, raw_text = run_paddle_ocr(image_bytes) - - logger.info(f"OCR found {len(ocr_regions)} text regions") - - # Scale factor: Convert image pixels to A4 canvas pixels (794x1123) - A4_WIDTH = 794 - A4_HEIGHT = 1123 - scale_x = A4_WIDTH / img_width - scale_y = A4_HEIGHT / img_height - - # Build Fabric.js objects - fabric_objects = [] - - # 1. Add white background - fabric_objects.append({ - "type": "rect", - "left": 0, - "top": 0, - "width": A4_WIDTH, - "height": A4_HEIGHT, - "fill": "#ffffff", - "selectable": False, - "evented": False, - "isBackground": True - }) - - # 2. Group OCR regions by Y-coordinate to detect rows - sorted_regions = sorted(ocr_regions, key=lambda r: (r.y1, r.x1)) - - # 3. Detect headers (larger text at top) - headers = [] - body_regions = [] - - for region in sorted_regions: - height = region.y2 - region.y1 - # Headers are typically taller and near the top - if region.y1 < img_height * 0.15 and height > 30: - headers.append(region) - else: - body_regions.append(region) - - # 4. Create text objects for each region - vocab_matched = 0 - - for region in sorted_regions: - # Scale positions to A4 - left = int(region.x1 * scale_x) - top = int(region.y1 * scale_y) - - # Determine if this is a header - is_header = region in headers - - # Determine font size based on region height - region_height = region.y2 - region.y1 - base_font_size = max(10, min(32, int(region_height * scale_y * 0.8))) - - if is_header: - base_font_size = max(base_font_size, 24) - - # Check if this text matches vocabulary - is_vocab = False - vocab_match = None - for v in page_vocab: - if v.get("english", "").lower() in region.text.lower() or \ - v.get("german", "").lower() in region.text.lower(): - is_vocab = True - vocab_match = v - vocab_matched += 1 - break - - # Create Fabric.js text object - text_obj = { - "type": "i-text", - "id": f"text_{uuid.uuid4().hex[:8]}", - "left": left, - "top": top, - "text": region.text, - "fontFamily": "Arial", - "fontSize": base_font_size, - "fontWeight": "bold" if is_header else "normal", - "fill": "#000000", - "originX": "left", - "originY": "top", - } - - # Add metadata for vocabulary items - if is_vocab and vocab_match: - text_obj["isVocabulary"] = True - text_obj["vocabularyId"] = vocab_match.get("id") - text_obj["english"] = vocab_match.get("english") - text_obj["german"] = vocab_match.get("german") - - fabric_objects.append(text_obj) - - # 5. If include_images, try to detect and extract image regions - if request.include_images: - image_regions = await _detect_image_regions(image_bytes, ocr_regions, img_width, img_height) - - for i, img_region in enumerate(image_regions): - # Extract image region from original - img_x1 = int(img_region["x1"]) - img_y1 = int(img_region["y1"]) - img_x2 = int(img_region["x2"]) - img_y2 = int(img_region["y2"]) - - # Crop the region - cropped = img.crop((img_x1, img_y1, img_x2, img_y2)) - - # Convert to base64 - buffer = io.BytesIO() - cropped.save(buffer, format='PNG') - buffer.seek(0) - img_base64 = f"data:image/png;base64,{base64.b64encode(buffer.getvalue()).decode('utf-8')}" - - # Create Fabric.js image object - fabric_objects.append({ - "type": "image", - "id": f"img_{uuid.uuid4().hex[:8]}", - "left": int(img_x1 * scale_x), - "top": int(img_y1 * scale_y), - "width": int((img_x2 - img_x1) * scale_x), - "height": int((img_y2 - img_y1) * scale_y), - "src": img_base64, - "scaleX": 1, - "scaleY": 1, - }) - - # Build canvas JSON - canvas_data = { - "version": "6.0.0", - "objects": fabric_objects, - "background": "#ffffff" - } - - return ReconstructResponse( - canvas_json=json.dumps(canvas_data), - page_width=A4_WIDTH, - page_height=A4_HEIGHT, - elements_count=len(fabric_objects), - vocabulary_matched=vocab_matched, - message=f"Reconstructed page {request.page_number} with {len(fabric_objects)} elements, {vocab_matched} vocabulary items matched" - ) - + return await reconstruct_document_logic(request) except HTTPException: raise except Exception as e: @@ -1193,101 +362,15 @@ async def reconstruct_document_from_session(request: ReconstructRequest): raise HTTPException(status_code=500, detail=str(e)) -async def _detect_image_regions( - image_bytes: bytes, - ocr_regions: list, - img_width: int, - img_height: int -) -> List[Dict]: - """ - Detect image/graphic regions in the document. - - Uses a simple approach: - 1. Find large gaps between text regions (potential image areas) - 2. Use edge detection to find bounded regions - 3. Filter out text areas - """ - from PIL import Image - import numpy as np - - try: - img = Image.open(io.BytesIO(image_bytes)) - img_array = np.array(img.convert('L')) # Grayscale - - # Create a mask of text regions - text_mask = np.ones_like(img_array, dtype=bool) - for region in ocr_regions: - x1 = max(0, region.x1 - 5) - y1 = max(0, region.y1 - 5) - x2 = min(img_width, region.x2 + 5) - y2 = min(img_height, region.y2 + 5) - text_mask[y1:y2, x1:x2] = False - - # Find contours in non-text areas - # Simple approach: look for rectangular regions with significant content - image_regions = [] - - # Use edge detection - import cv2 - edges = cv2.Canny(img_array, 50, 150) - - # Apply text mask - edges[~text_mask] = 0 - - # Find contours - contours, _ = cv2.findContours(edges, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE) - - for contour in contours: - x, y, w, h = cv2.boundingRect(contour) - - # Filter: minimum size for images (at least 50x50 pixels) - if w > 50 and h > 50: - # Filter: not too large (not the whole page) - if w < img_width * 0.9 and h < img_height * 0.9: - # Check if this region has actual content (not just edges) - region_content = img_array[y:y+h, x:x+w] - variance = np.var(region_content) - - if variance > 500: # Has enough visual content - image_regions.append({ - "x1": x, - "y1": y, - "x2": x + w, - "y2": y + h - }) - - # Remove overlapping regions (keep larger ones) - filtered_regions = [] - for region in sorted(image_regions, key=lambda r: (r["x2"]-r["x1"])*(r["y2"]-r["y1"]), reverse=True): - overlaps = False - for existing in filtered_regions: - # Check overlap - if not (region["x2"] < existing["x1"] or region["x1"] > existing["x2"] or - region["y2"] < existing["y1"] or region["y1"] > existing["y2"]): - overlaps = True - break - if not overlaps: - filtered_regions.append(region) - - logger.info(f"Detected {len(filtered_regions)} image regions") - return filtered_regions[:10] # Limit to 10 images max - - except Exception as e: - logger.warning(f"Image region detection failed: {e}") - return [] - - @router.get("/sessions/available") async def get_available_sessions(): - """ - Get list of available vocab sessions that can be reconstructed. - """ + """Get list of available vocab sessions that can be reconstructed.""" try: from vocab_worksheet_api import _sessions available = [] for session_id, session in _sessions.items(): - if session.get("pdf_data"): # Only sessions with PDF + if session.get("pdf_data"): available.append({ "id": session_id, "name": session.get("name", "Unnamed"), diff --git a/klausur-service/backend/worksheet_editor_models.py b/klausur-service/backend/worksheet_editor_models.py new file mode 100644 index 0000000..468d36e --- /dev/null +++ b/klausur-service/backend/worksheet_editor_models.py @@ -0,0 +1,133 @@ +""" +Worksheet Editor Models — Enums, Pydantic models, and configuration. +""" + +import os +import logging +from typing import Optional, List, Dict +from enum import Enum + +from pydantic import BaseModel, Field + +logger = logging.getLogger(__name__) + +# ============================================= +# CONFIGURATION +# ============================================= + +OLLAMA_URL = os.getenv("OLLAMA_URL", "http://host.docker.internal:11434") +SD_MODEL = os.getenv("SD_MODEL", "stable-diffusion") # or specific SD model +WORKSHEET_STORAGE_DIR = os.getenv("WORKSHEET_STORAGE_DIR", + os.path.join(os.path.dirname(os.path.abspath(__file__)), "worksheet-storage")) + +# Ensure storage directory exists +os.makedirs(WORKSHEET_STORAGE_DIR, exist_ok=True) + +# ============================================= +# ENUMS & MODELS +# ============================================= + +class AIImageStyle(str, Enum): + REALISTIC = "realistic" + CARTOON = "cartoon" + SKETCH = "sketch" + CLIPART = "clipart" + EDUCATIONAL = "educational" + +class WorksheetStatus(str, Enum): + DRAFT = "draft" + PUBLISHED = "published" + ARCHIVED = "archived" + +# Style prompt modifiers +STYLE_PROMPTS = { + AIImageStyle.REALISTIC: "photorealistic, high detail, professional photography", + AIImageStyle.CARTOON: "cartoon style, colorful, child-friendly, simple shapes", + AIImageStyle.SKETCH: "pencil sketch, hand-drawn, black and white, artistic", + AIImageStyle.CLIPART: "clipart style, flat design, simple, vector-like", + AIImageStyle.EDUCATIONAL: "educational illustration, clear, informative, textbook style" +} + +# ============================================= +# REQUEST/RESPONSE MODELS +# ============================================= + +class AIImageRequest(BaseModel): + prompt: str = Field(..., min_length=3, max_length=500) + style: AIImageStyle = AIImageStyle.EDUCATIONAL + width: int = Field(512, ge=256, le=1024) + height: int = Field(512, ge=256, le=1024) + +class AIImageResponse(BaseModel): + image_base64: str + prompt_used: str + error: Optional[str] = None + +class PageData(BaseModel): + id: str + index: int + canvasJSON: str + +class PageFormat(BaseModel): + width: float = 210 + height: float = 297 + orientation: str = "portrait" + margins: Dict[str, float] = {"top": 15, "right": 15, "bottom": 15, "left": 15} + +class WorksheetSaveRequest(BaseModel): + id: Optional[str] = None + title: str + description: Optional[str] = None + pages: List[PageData] + pageFormat: Optional[PageFormat] = None + +class WorksheetResponse(BaseModel): + id: str + title: str + description: Optional[str] + pages: List[PageData] + pageFormat: PageFormat + createdAt: str + updatedAt: str + +class AIModifyRequest(BaseModel): + prompt: str = Field(..., min_length=3, max_length=1000) + canvas_json: str + model: str = "qwen2.5vl:32b" + +class AIModifyResponse(BaseModel): + modified_canvas_json: Optional[str] = None + message: str + error: Optional[str] = None + +class ReconstructRequest(BaseModel): + session_id: str + page_number: int = 1 + include_images: bool = True + regenerate_graphics: bool = False + +class ReconstructResponse(BaseModel): + canvas_json: str + page_width: int + page_height: int + elements_count: int + vocabulary_matched: int + message: str + error: Optional[str] = None + +# ============================================= +# IN-MEMORY STORAGE (Development) +# ============================================= + +worksheets_db: Dict[str, Dict] = {} + +# PDF Generation availability +try: + from reportlab.lib import colors # noqa: F401 + from reportlab.lib.pagesizes import A4 # noqa: F401 + from reportlab.lib.units import mm # noqa: F401 + from reportlab.pdfgen import canvas # noqa: F401 + from reportlab.lib.styles import getSampleStyleSheet # noqa: F401 + REPORTLAB_AVAILABLE = True +except ImportError: + REPORTLAB_AVAILABLE = False diff --git a/klausur-service/backend/worksheet_editor_reconstruct.py b/klausur-service/backend/worksheet_editor_reconstruct.py new file mode 100644 index 0000000..b17f2c2 --- /dev/null +++ b/klausur-service/backend/worksheet_editor_reconstruct.py @@ -0,0 +1,255 @@ +""" +Worksheet Editor Reconstruct — Document reconstruction from vocab sessions. +""" + +import io +import uuid +import base64 +import logging +from typing import List, Dict + +import numpy as np + +from worksheet_editor_models import ( + ReconstructRequest, + ReconstructResponse, +) + +logger = logging.getLogger(__name__) + + +async def reconstruct_document_logic(request: ReconstructRequest) -> ReconstructResponse: + """ + Reconstruct a document from a vocab session into Fabric.js canvas format. + + This function: + 1. Loads the original PDF from the vocab session + 2. Runs OCR with position tracking + 3. Creates Fabric.js canvas JSON with positioned elements + 4. Maps extracted vocabulary to their positions + + Returns ReconstructResponse ready to send to the client. + """ + from fastapi import HTTPException + from vocab_worksheet_api import _sessions, convert_pdf_page_to_image + + # Check if session exists + if request.session_id not in _sessions: + raise HTTPException(status_code=404, detail=f"Session {request.session_id} not found") + + session = _sessions[request.session_id] + + if not session.get("pdf_data"): + raise HTTPException(status_code=400, detail="Session has no PDF data") + + pdf_data = session["pdf_data"] + page_count = session.get("pdf_page_count", 1) + + if request.page_number < 1 or request.page_number > page_count: + raise HTTPException( + status_code=400, + detail=f"Page {request.page_number} not found. PDF has {page_count} pages." + ) + + vocabulary = session.get("vocabulary", []) + page_vocab = [v for v in vocabulary if v.get("source_page") == request.page_number] + + logger.info(f"Reconstructing page {request.page_number} from session {request.session_id}") + logger.info(f"Found {len(page_vocab)} vocabulary items for this page") + + image_bytes = await convert_pdf_page_to_image(pdf_data, request.page_number) + if not image_bytes: + raise HTTPException(status_code=500, detail="Failed to convert PDF page to image") + + from PIL import Image + img = Image.open(io.BytesIO(image_bytes)) + img_width, img_height = img.size + + from hybrid_vocab_extractor import run_paddle_ocr + ocr_regions, raw_text = run_paddle_ocr(image_bytes) + + logger.info(f"OCR found {len(ocr_regions)} text regions") + + A4_WIDTH = 794 + A4_HEIGHT = 1123 + scale_x = A4_WIDTH / img_width + scale_y = A4_HEIGHT / img_height + + fabric_objects = [] + + # 1. Add white background + fabric_objects.append({ + "type": "rect", "left": 0, "top": 0, + "width": A4_WIDTH, "height": A4_HEIGHT, + "fill": "#ffffff", "selectable": False, + "evented": False, "isBackground": True + }) + + # 2. Group OCR regions by Y-coordinate to detect rows + sorted_regions = sorted(ocr_regions, key=lambda r: (r.y1, r.x1)) + + # 3. Detect headers (larger text at top) + headers = [] + for region in sorted_regions: + height = region.y2 - region.y1 + if region.y1 < img_height * 0.15 and height > 30: + headers.append(region) + + # 4. Create text objects for each region + vocab_matched = 0 + + for region in sorted_regions: + left = int(region.x1 * scale_x) + top = int(region.y1 * scale_y) + + is_header = region in headers + + region_height = region.y2 - region.y1 + base_font_size = max(10, min(32, int(region_height * scale_y * 0.8))) + + if is_header: + base_font_size = max(base_font_size, 24) + + is_vocab = False + vocab_match = None + for v in page_vocab: + if v.get("english", "").lower() in region.text.lower() or \ + v.get("german", "").lower() in region.text.lower(): + is_vocab = True + vocab_match = v + vocab_matched += 1 + break + + text_obj = { + "type": "i-text", + "id": f"text_{uuid.uuid4().hex[:8]}", + "left": left, "top": top, + "text": region.text, + "fontFamily": "Arial", + "fontSize": base_font_size, + "fontWeight": "bold" if is_header else "normal", + "fill": "#000000", + "originX": "left", "originY": "top", + } + + if is_vocab and vocab_match: + text_obj["isVocabulary"] = True + text_obj["vocabularyId"] = vocab_match.get("id") + text_obj["english"] = vocab_match.get("english") + text_obj["german"] = vocab_match.get("german") + + fabric_objects.append(text_obj) + + # 5. If include_images, detect and extract image regions + if request.include_images: + image_regions = await _detect_image_regions(image_bytes, ocr_regions, img_width, img_height) + + for i, img_region in enumerate(image_regions): + img_x1 = int(img_region["x1"]) + img_y1 = int(img_region["y1"]) + img_x2 = int(img_region["x2"]) + img_y2 = int(img_region["y2"]) + + cropped = img.crop((img_x1, img_y1, img_x2, img_y2)) + + buffer = io.BytesIO() + cropped.save(buffer, format='PNG') + buffer.seek(0) + img_base64 = f"data:image/png;base64,{base64.b64encode(buffer.getvalue()).decode('utf-8')}" + + fabric_objects.append({ + "type": "image", + "id": f"img_{uuid.uuid4().hex[:8]}", + "left": int(img_x1 * scale_x), + "top": int(img_y1 * scale_y), + "width": int((img_x2 - img_x1) * scale_x), + "height": int((img_y2 - img_y1) * scale_y), + "src": img_base64, + "scaleX": 1, "scaleY": 1, + }) + + import json + canvas_data = { + "version": "6.0.0", + "objects": fabric_objects, + "background": "#ffffff" + } + + return ReconstructResponse( + canvas_json=json.dumps(canvas_data), + page_width=A4_WIDTH, + page_height=A4_HEIGHT, + elements_count=len(fabric_objects), + vocabulary_matched=vocab_matched, + message=f"Reconstructed page {request.page_number} with {len(fabric_objects)} elements, " + f"{vocab_matched} vocabulary items matched" + ) + + +async def _detect_image_regions( + image_bytes: bytes, + ocr_regions: list, + img_width: int, + img_height: int +) -> List[Dict]: + """ + Detect image/graphic regions in the document. + + Uses a simple approach: + 1. Find large gaps between text regions (potential image areas) + 2. Use edge detection to find bounded regions + 3. Filter out text areas + """ + from PIL import Image + import cv2 + + try: + img = Image.open(io.BytesIO(image_bytes)) + img_array = np.array(img.convert('L')) + + text_mask = np.ones_like(img_array, dtype=bool) + for region in ocr_regions: + x1 = max(0, region.x1 - 5) + y1 = max(0, region.y1 - 5) + x2 = min(img_width, region.x2 + 5) + y2 = min(img_height, region.y2 + 5) + text_mask[y1:y2, x1:x2] = False + + image_regions = [] + + edges = cv2.Canny(img_array, 50, 150) + edges[~text_mask] = 0 + + contours, _ = cv2.findContours(edges, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE) + + for contour in contours: + x, y, w, h = cv2.boundingRect(contour) + + if w > 50 and h > 50: + if w < img_width * 0.9 and h < img_height * 0.9: + region_content = img_array[y:y+h, x:x+w] + variance = np.var(region_content) + + if variance > 500: + image_regions.append({ + "x1": x, "y1": y, + "x2": x + w, "y2": y + h + }) + + filtered_regions = [] + for region in sorted(image_regions, key=lambda r: (r["x2"]-r["x1"])*(r["y2"]-r["y1"]), reverse=True): + overlaps = False + for existing in filtered_regions: + if not (region["x2"] < existing["x1"] or region["x1"] > existing["x2"] or + region["y2"] < existing["y1"] or region["y1"] > existing["y2"]): + overlaps = True + break + if not overlaps: + filtered_regions.append(region) + + logger.info(f"Detected {len(filtered_regions)} image regions") + return filtered_regions[:10] + + except Exception as e: + logger.warning(f"Image region detection failed: {e}") + return []