[split-required] Split remaining Python monoliths (Phase 1 continued)

klausur-service (7 monoliths):
- grid_editor_helpers.py (1,737 → 5 files: columns, filters, headers, zones)
- cv_cell_grid.py (1,675 → 7 files: build, legacy, streaming, merge, vocab)
- worksheet_editor_api.py (1,305 → 4 files: models, AI, reconstruct, routes)
- legal_corpus_ingestion.py (1,280 → 3 files: registry, chunking, ingestion)
- cv_review.py (1,248 → 4 files: pipeline, spell, LLM, barrel)
- cv_preprocessing.py (1,166 → 3 files: deskew, dewarp, barrel)
- rbac.py, admin_api.py, routes/eh.py remain (next batch)

backend-lehrer (1 monolith):
- classroom_engine/repository.py (1,705 → 7 files by domain)

All re-export barrels preserve backward compatibility.
Zero import errors verified.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
Benjamin Admin
2026-04-24 22:47:59 +02:00
parent 0b37c5e692
commit b2a0126f14
34 changed files with 9264 additions and 9164 deletions

View File

@@ -17,6 +17,7 @@
# Pure Data Registries (keine Logik, nur Daten-Definitionen)
**/dsfa_sources_registry.py | owner=klausur | reason=Pure data registry (license + source definitions, no logic) | review=2027-01-01
**/legal_corpus_registry.py | owner=klausur | reason=Pure data registry (Regulation dataclass + 47 regulation definitions, no logic) | review=2027-01-01
**/backlog/backlog-items.ts | owner=admin-lehrer | reason=Pure data array (506 LOC, no logic, only BacklogItem[] literals) | review=2027-01-01
**/lib/module-registry-data.ts | owner=admin-lehrer | reason=Pure data array (510 LOC, no logic, only BackendModule[] literals) | review=2027-01-01

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,453 @@
"""
Teacher Context, Schoolyear Event & Recurring Routine Repositories.
CRUD-Operationen fuer Schuljahres-Kontext (Phase 8).
"""
from datetime import datetime
from typing import Optional, List, Dict, Any
from sqlalchemy.orm import Session as DBSession
from .context_models import (
TeacherContextDB, SchoolyearEventDB, RecurringRoutineDB,
MacroPhaseEnum, EventTypeEnum, EventStatusEnum,
RoutineTypeEnum, RecurrencePatternEnum,
FEDERAL_STATES, SCHOOL_TYPES,
)
class TeacherContextRepository:
    """Repository for teacher-context CRUD operations (phase 8).

    Manages exactly one TeacherContextDB row per teacher via
    get-or-create semantics; every mutating method commits immediately.
    """
    def __init__(self, db: DBSession):
        # Active SQLAlchemy session; owned by the caller, never closed here.
        self.db = db
    # ==================== CREATE / GET-OR-CREATE ====================
    def get_or_create(self, teacher_id: str) -> TeacherContextDB:
        """
        Fetch the context of a teacher, or create a new one.
        Args:
            teacher_id: ID of the teacher
        Returns:
            TeacherContextDB model; a newly created row starts in the
            ONBOARDING macro phase.
        """
        context = self.get_by_teacher_id(teacher_id)
        if context:
            return context
        # Create a new context
        from uuid import uuid4
        context = TeacherContextDB(
            id=str(uuid4()),
            teacher_id=teacher_id,
            macro_phase=MacroPhaseEnum.ONBOARDING,
        )
        self.db.add(context)
        self.db.commit()
        self.db.refresh(context)
        return context
    # ==================== READ ====================
    def get_by_teacher_id(self, teacher_id: str) -> Optional[TeacherContextDB]:
        """Return the context row of a teacher, or None if none exists."""
        return self.db.query(TeacherContextDB).filter(
            TeacherContextDB.teacher_id == teacher_id
        ).first()
    # ==================== UPDATE ====================
    def update_context(
        self,
        teacher_id: str,
        federal_state: Optional[str] = None,
        school_type: Optional[str] = None,
        schoolyear: Optional[str] = None,
        schoolyear_start: Optional[datetime] = None,
        macro_phase: Optional[str] = None,
        current_week: Optional[int] = None,
    ) -> Optional[TeacherContextDB]:
        """Update a teacher's context; only non-None arguments are written.

        The context is created on the fly if it does not exist yet.

        Raises:
            ValueError: if macro_phase is not a valid MacroPhaseEnum value.
        """
        context = self.get_or_create(teacher_id)
        if federal_state is not None:
            context.federal_state = federal_state
        if school_type is not None:
            context.school_type = school_type
        if schoolyear is not None:
            context.schoolyear = schoolyear
        if schoolyear_start is not None:
            context.schoolyear_start = schoolyear_start
        if macro_phase is not None:
            context.macro_phase = MacroPhaseEnum(macro_phase)
        if current_week is not None:
            context.current_week = current_week
        self.db.commit()
        self.db.refresh(context)
        return context
    def complete_onboarding(self, teacher_id: str) -> TeacherContextDB:
        """Mark onboarding as completed and advance to SCHULJAHRESSTART."""
        context = self.get_or_create(teacher_id)
        context.onboarding_completed = True
        context.macro_phase = MacroPhaseEnum.SCHULJAHRESSTART
        self.db.commit()
        self.db.refresh(context)
        return context
    def update_flags(
        self,
        teacher_id: str,
        has_classes: Optional[bool] = None,
        has_schedule: Optional[bool] = None,
        is_exam_period: Optional[bool] = None,
        is_before_holidays: Optional[bool] = None,
    ) -> TeacherContextDB:
        """Update the status flags of a context; None leaves a flag untouched."""
        context = self.get_or_create(teacher_id)
        if has_classes is not None:
            context.has_classes = has_classes
        if has_schedule is not None:
            context.has_schedule = has_schedule
        if is_exam_period is not None:
            context.is_exam_period = is_exam_period
        if is_before_holidays is not None:
            context.is_before_holidays = is_before_holidays
        self.db.commit()
        self.db.refresh(context)
        return context
    def to_dict(self, context: TeacherContextDB) -> Dict[str, Any]:
        """Convert the DB model to a JSON-serializable dictionary."""
        return {
            "id": context.id,
            "teacher_id": context.teacher_id,
            "school": {
                "federal_state": context.federal_state,
                # Display names fall back to "" for unknown codes.
                "federal_state_name": FEDERAL_STATES.get(context.federal_state, ""),
                "school_type": context.school_type,
                "school_type_name": SCHOOL_TYPES.get(context.school_type, ""),
            },
            "school_year": {
                "id": context.schoolyear,
                "start": context.schoolyear_start.isoformat() if context.schoolyear_start else None,
                "current_week": context.current_week,
            },
            "macro_phase": {
                "id": context.macro_phase.value,
                "label": self._get_phase_label(context.macro_phase),
            },
            "flags": {
                "onboarding_completed": context.onboarding_completed,
                "has_classes": context.has_classes,
                "has_schedule": context.has_schedule,
                "is_exam_period": context.is_exam_period,
                "is_before_holidays": context.is_before_holidays,
            },
            "created_at": context.created_at.isoformat() if context.created_at else None,
            "updated_at": context.updated_at.isoformat() if context.updated_at else None,
        }
    def _get_phase_label(self, phase: MacroPhaseEnum) -> str:
        """Return the German display name of a macro phase.

        Unknown phases fall back to the raw enum value.
        """
        labels = {
            MacroPhaseEnum.ONBOARDING: "Einrichtung",
            MacroPhaseEnum.SCHULJAHRESSTART: "Schuljahresstart",
            MacroPhaseEnum.UNTERRICHTSAUFBAU: "Unterrichtsaufbau",
            MacroPhaseEnum.LEISTUNGSPHASE_1: "Leistungsphase 1",
            MacroPhaseEnum.HALBJAHRESABSCHLUSS: "Halbjahresabschluss",
            MacroPhaseEnum.LEISTUNGSPHASE_2: "Leistungsphase 2",
            MacroPhaseEnum.JAHRESABSCHLUSS: "Jahresabschluss",
        }
        return labels.get(phase, phase.value)
class SchoolyearEventRepository:
    """Repository for school-year events (phase 8).

    Plain CRUD over SchoolyearEventDB; every mutating method commits
    immediately on the injected session.
    """
    def __init__(self, db: DBSession):
        # Active SQLAlchemy session; owned by the caller, never closed here.
        self.db = db
    def create(
        self,
        teacher_id: str,
        title: str,
        start_date: datetime,
        event_type: str = "other",
        end_date: Optional[datetime] = None,
        class_id: Optional[str] = None,
        subject: Optional[str] = None,
        description: str = "",
        needs_preparation: bool = True,
        reminder_days_before: int = 7,
        extra_data: Optional[Dict[str, Any]] = None,
    ) -> SchoolyearEventDB:
        """Create a new school-year event.

        Raises:
            ValueError: if event_type is not a valid EventTypeEnum value.
        """
        from uuid import uuid4
        event = SchoolyearEventDB(
            id=str(uuid4()),
            teacher_id=teacher_id,
            title=title,
            event_type=EventTypeEnum(event_type),
            start_date=start_date,
            end_date=end_date,
            class_id=class_id,
            subject=subject,
            description=description,
            needs_preparation=needs_preparation,
            reminder_days_before=reminder_days_before,
            extra_data=extra_data or {},
        )
        self.db.add(event)
        self.db.commit()
        self.db.refresh(event)
        return event
    def get_by_id(self, event_id: str) -> Optional[SchoolyearEventDB]:
        """Return an event by ID, or None if it does not exist."""
        return self.db.query(SchoolyearEventDB).filter(
            SchoolyearEventDB.id == event_id
        ).first()
    def get_by_teacher(
        self,
        teacher_id: str,
        status: Optional[str] = None,
        event_type: Optional[str] = None,
        limit: int = 50,
    ) -> List[SchoolyearEventDB]:
        """List a teacher's events, ordered by start date ascending.

        Raises:
            ValueError: if status/event_type is not a valid enum value.
        """
        query = self.db.query(SchoolyearEventDB).filter(
            SchoolyearEventDB.teacher_id == teacher_id
        )
        if status:
            query = query.filter(SchoolyearEventDB.status == EventStatusEnum(status))
        if event_type:
            query = query.filter(SchoolyearEventDB.event_type == EventTypeEnum(event_type))
        return query.order_by(SchoolyearEventDB.start_date).limit(limit).all()
    def get_upcoming(
        self,
        teacher_id: str,
        days: int = 30,
        limit: int = 10,
    ) -> List[SchoolyearEventDB]:
        """Return non-cancelled events starting within the next X days."""
        from datetime import timedelta
        # Naive UTC timestamps, consistent with the rest of the module.
        now = datetime.utcnow()
        end = now + timedelta(days=days)
        return self.db.query(SchoolyearEventDB).filter(
            SchoolyearEventDB.teacher_id == teacher_id,
            SchoolyearEventDB.start_date >= now,
            SchoolyearEventDB.start_date <= end,
            SchoolyearEventDB.status != EventStatusEnum.CANCELLED,
        ).order_by(SchoolyearEventDB.start_date).limit(limit).all()
    def update_status(
        self,
        event_id: str,
        status: str,
        preparation_done: Optional[bool] = None,
    ) -> Optional[SchoolyearEventDB]:
        """Update the status of an event; returns None if it does not exist.

        Raises:
            ValueError: if status is not a valid EventStatusEnum value.
        """
        event = self.get_by_id(event_id)
        if not event:
            return None
        event.status = EventStatusEnum(status)
        if preparation_done is not None:
            event.preparation_done = preparation_done
        self.db.commit()
        self.db.refresh(event)
        return event
    def delete(self, event_id: str) -> bool:
        """Delete an event; returns True on success, False if not found."""
        event = self.get_by_id(event_id)
        if not event:
            return False
        self.db.delete(event)
        self.db.commit()
        return True
    def to_dict(self, event: SchoolyearEventDB) -> Dict[str, Any]:
        """Convert the DB model to a JSON-serializable dictionary."""
        return {
            "id": event.id,
            "teacher_id": event.teacher_id,
            "event_type": event.event_type.value,
            "title": event.title,
            "description": event.description,
            "start_date": event.start_date.isoformat() if event.start_date else None,
            "end_date": event.end_date.isoformat() if event.end_date else None,
            "class_id": event.class_id,
            "subject": event.subject,
            "status": event.status.value,
            "needs_preparation": event.needs_preparation,
            "preparation_done": event.preparation_done,
            "reminder_days_before": event.reminder_days_before,
            "extra_data": event.extra_data,
            "created_at": event.created_at.isoformat() if event.created_at else None,
        }
class RecurringRoutineRepository:
    """Repository for recurring routines (phase 8).

    CRUD over RecurringRoutineDB plus Python-side "which routines run
    today" matching; every mutating method commits immediately.
    """
    def __init__(self, db: DBSession):
        # Active SQLAlchemy session; owned by the caller, never closed here.
        self.db = db
    @staticmethod
    def _parse_time(time_of_day: Optional[str]):
        """Parse an "HH:MM" string into a datetime.time.

        Shared by create() and update() (the parsing was previously
        duplicated). None/empty input yields None; a bare hour such as
        "14" is tolerated and treated as "14:00".

        Raises:
            ValueError: if the hour/minute parts are not integers or out
                of range for datetime.time.
        """
        if not time_of_day:
            return None
        from datetime import time as dt_time
        parts = time_of_day.split(":")
        minute = int(parts[1]) if len(parts) > 1 else 0
        return dt_time(int(parts[0]), minute)
    def create(
        self,
        teacher_id: str,
        title: str,
        routine_type: str = "other",
        recurrence_pattern: str = "weekly",
        day_of_week: Optional[int] = None,
        day_of_month: Optional[int] = None,
        time_of_day: Optional[str] = None,  # format: "14:00"
        duration_minutes: int = 60,
        description: str = "",
        valid_from: Optional[datetime] = None,
        valid_until: Optional[datetime] = None,
    ) -> RecurringRoutineDB:
        """Create a new recurring routine.

        Raises:
            ValueError: if routine_type or recurrence_pattern is not a
                valid enum value, or time_of_day cannot be parsed.
        """
        from uuid import uuid4
        routine = RecurringRoutineDB(
            id=str(uuid4()),
            teacher_id=teacher_id,
            title=title,
            routine_type=RoutineTypeEnum(routine_type),
            recurrence_pattern=RecurrencePatternEnum(recurrence_pattern),
            day_of_week=day_of_week,
            day_of_month=day_of_month,
            time_of_day=self._parse_time(time_of_day),
            duration_minutes=duration_minutes,
            description=description,
            valid_from=valid_from,
            valid_until=valid_until,
        )
        self.db.add(routine)
        self.db.commit()
        self.db.refresh(routine)
        return routine
    def get_by_id(self, routine_id: str) -> Optional[RecurringRoutineDB]:
        """Return a routine by ID, or None if it does not exist."""
        return self.db.query(RecurringRoutineDB).filter(
            RecurringRoutineDB.id == routine_id
        ).first()
    def get_by_teacher(
        self,
        teacher_id: str,
        is_active: Optional[bool] = True,
        routine_type: Optional[str] = None,
    ) -> List[RecurringRoutineDB]:
        """List a teacher's routines.

        Args:
            teacher_id: ID of the teacher.
            is_active: filter on the active flag; pass None to skip it.
            routine_type: optional RoutineTypeEnum value to filter on.
        """
        query = self.db.query(RecurringRoutineDB).filter(
            RecurringRoutineDB.teacher_id == teacher_id
        )
        if is_active is not None:
            query = query.filter(RecurringRoutineDB.is_active == is_active)
        if routine_type:
            query = query.filter(RecurringRoutineDB.routine_type == RoutineTypeEnum(routine_type))
        return query.all()
    def get_today(self, teacher_id: str) -> List[RecurringRoutineDB]:
        """Return the active routines that occur today.

        Matching happens in Python over the teacher's active routines,
        using naive UTC "today".
        """
        today = datetime.utcnow()
        day_of_week = today.weekday()  # 0 = Monday
        day_of_month = today.day
        routines = self.get_by_teacher(teacher_id, is_active=True)
        today_routines = []
        for routine in routines:
            if routine.recurrence_pattern == RecurrencePatternEnum.DAILY:
                today_routines.append(routine)
            elif routine.recurrence_pattern == RecurrencePatternEnum.WEEKLY:
                if routine.day_of_week == day_of_week:
                    today_routines.append(routine)
            elif routine.recurrence_pattern == RecurrencePatternEnum.BIWEEKLY:
                # Simplified: only the weekday is checked (exact biweekly
                # logic would need the start date to know which week applies).
                if routine.day_of_week == day_of_week:
                    today_routines.append(routine)
            elif routine.recurrence_pattern == RecurrencePatternEnum.MONTHLY:
                if routine.day_of_month == day_of_month:
                    today_routines.append(routine)
        return today_routines
    def update(
        self,
        routine_id: str,
        title: Optional[str] = None,
        is_active: Optional[bool] = None,
        day_of_week: Optional[int] = None,
        time_of_day: Optional[str] = None,
    ) -> Optional[RecurringRoutineDB]:
        """Update a routine; only non-None arguments are written.

        Returns None if the routine does not exist.
        """
        routine = self.get_by_id(routine_id)
        if not routine:
            return None
        if title is not None:
            routine.title = title
        if is_active is not None:
            routine.is_active = is_active
        if day_of_week is not None:
            routine.day_of_week = day_of_week
        if time_of_day is not None:
            routine.time_of_day = self._parse_time(time_of_day)
        self.db.commit()
        self.db.refresh(routine)
        return routine
    def delete(self, routine_id: str) -> bool:
        """Delete a routine; returns True on success, False if not found."""
        routine = self.get_by_id(routine_id)
        if not routine:
            return False
        self.db.delete(routine)
        self.db.commit()
        return True
    def to_dict(self, routine: RecurringRoutineDB) -> Dict[str, Any]:
        """Convert the DB model to a JSON-serializable dictionary."""
        return {
            "id": routine.id,
            "teacher_id": routine.teacher_id,
            "routine_type": routine.routine_type.value,
            "title": routine.title,
            "description": routine.description,
            "recurrence_pattern": routine.recurrence_pattern.value,
            "day_of_week": routine.day_of_week,
            "day_of_month": routine.day_of_month,
            "time_of_day": routine.time_of_day.isoformat() if routine.time_of_day else None,
            "duration_minutes": routine.duration_minutes,
            "is_active": routine.is_active,
            "valid_from": routine.valid_from.isoformat() if routine.valid_from else None,
            "valid_until": routine.valid_until.isoformat() if routine.valid_until else None,
            "created_at": routine.created_at.isoformat() if routine.created_at else None,
        }

View File

@@ -0,0 +1,182 @@
"""
Teacher Feedback Repository.
CRUD-Operationen fuer Lehrer-Feedback (Phase 7).
Ermoeglicht Lehrern, Bugs, Feature-Requests und Verbesserungen zu melden.
"""
from datetime import datetime
from typing import Optional, List, Dict, Any
from sqlalchemy.orm import Session as DBSession
from .db_models import (
TeacherFeedbackDB, FeedbackTypeEnum, FeedbackStatusEnum,
FeedbackPriorityEnum,
)
class TeacherFeedbackRepository:
    """
    Repository for teacher-feedback CRUD operations.
    Lets teachers submit feedback (bugs, feature requests, improvements)
    directly from the teacher frontend. All writes commit immediately.
    """
    def __init__(self, db: DBSession):
        # Active SQLAlchemy session; owned by the caller, never closed here.
        self.db = db
    def create(
        self,
        teacher_id: str,
        title: str,
        description: str,
        feedback_type: str = "improvement",
        priority: str = "medium",
        teacher_name: str = "",
        teacher_email: str = "",
        context_url: str = "",
        context_phase: str = "",
        context_session_id: Optional[str] = None,
        user_agent: str = "",
        related_feature: Optional[str] = None,
    ) -> TeacherFeedbackDB:
        """Create a new feedback entry (status starts as NEW).

        Raises:
            ValueError: if feedback_type or priority is not a valid
                enum value.
        """
        import uuid
        db_feedback = TeacherFeedbackDB(
            id=str(uuid.uuid4()),
            teacher_id=teacher_id,
            teacher_name=teacher_name,
            teacher_email=teacher_email,
            title=title,
            description=description,
            feedback_type=FeedbackTypeEnum(feedback_type),
            priority=FeedbackPriorityEnum(priority),
            status=FeedbackStatusEnum.NEW,
            related_feature=related_feature,
            context_url=context_url,
            context_phase=context_phase,
            context_session_id=context_session_id,
            user_agent=user_agent,
        )
        self.db.add(db_feedback)
        self.db.commit()
        self.db.refresh(db_feedback)
        return db_feedback
    def get_by_id(self, feedback_id: str) -> Optional[TeacherFeedbackDB]:
        """Return a feedback entry by ID, or None if it does not exist."""
        return self.db.query(TeacherFeedbackDB).filter(
            TeacherFeedbackDB.id == feedback_id
        ).first()
    def get_all(
        self,
        status: Optional[str] = None,
        feedback_type: Optional[str] = None,
        limit: int = 100,
        offset: int = 0
    ) -> List[TeacherFeedbackDB]:
        """List all feedback entries, newest first, with optional filters."""
        query = self.db.query(TeacherFeedbackDB)
        if status:
            query = query.filter(TeacherFeedbackDB.status == FeedbackStatusEnum(status))
        if feedback_type:
            query = query.filter(TeacherFeedbackDB.feedback_type == FeedbackTypeEnum(feedback_type))
        return query.order_by(
            TeacherFeedbackDB.created_at.desc()
        ).offset(offset).limit(limit).all()
    def get_by_teacher(self, teacher_id: str, limit: int = 50) -> List[TeacherFeedbackDB]:
        """List the feedback of one teacher, newest first."""
        return self.db.query(TeacherFeedbackDB).filter(
            TeacherFeedbackDB.teacher_id == teacher_id
        ).order_by(
            TeacherFeedbackDB.created_at.desc()
        ).limit(limit).all()
    def update_status(
        self,
        feedback_id: str,
        status: str,
        response: Optional[str] = None,
        responded_by: Optional[str] = None
    ) -> Optional[TeacherFeedbackDB]:
        """Update the status of a feedback entry; None if it does not exist.

        A non-empty response also stamps responded_at/responded_by.

        Raises:
            ValueError: if status is not a valid FeedbackStatusEnum value.
        """
        db_feedback = self.get_by_id(feedback_id)
        if not db_feedback:
            return None
        db_feedback.status = FeedbackStatusEnum(status)
        if response:
            db_feedback.response = response
            # Naive UTC timestamp, consistent with the rest of the module.
            db_feedback.responded_at = datetime.utcnow()
            db_feedback.responded_by = responded_by
        self.db.commit()
        self.db.refresh(db_feedback)
        return db_feedback
    def delete(self, feedback_id: str) -> bool:
        """Delete a feedback entry; True on success, False if not found."""
        db_feedback = self.get_by_id(feedback_id)
        if not db_feedback:
            return False
        self.db.delete(db_feedback)
        self.db.commit()
        return True
    def get_stats(self) -> Dict[str, Any]:
        """Return counts of all feedback grouped by status, type, priority.

        NOTE(review): loads every row into memory to count; fine for small
        tables, consider a SQL GROUP BY if feedback volume grows.
        """
        all_feedback = self.db.query(TeacherFeedbackDB).all()
        stats = {
            "total": len(all_feedback),
            "by_status": {},
            "by_type": {},
            "by_priority": {},
        }
        for fb in all_feedback:
            # By status
            status = fb.status.value
            stats["by_status"][status] = stats["by_status"].get(status, 0) + 1
            # By type
            fb_type = fb.feedback_type.value
            stats["by_type"][fb_type] = stats["by_type"].get(fb_type, 0) + 1
            # By priority
            priority = fb.priority.value
            stats["by_priority"][priority] = stats["by_priority"].get(priority, 0) + 1
        return stats
    def to_dict(self, db_feedback: TeacherFeedbackDB) -> Dict[str, Any]:
        """Convert the DB model to a JSON-serializable dictionary."""
        return {
            "id": db_feedback.id,
            "teacher_id": db_feedback.teacher_id,
            "teacher_name": db_feedback.teacher_name,
            "teacher_email": db_feedback.teacher_email,
            "title": db_feedback.title,
            "description": db_feedback.description,
            "feedback_type": db_feedback.feedback_type.value,
            "priority": db_feedback.priority.value,
            "status": db_feedback.status.value,
            "related_feature": db_feedback.related_feature,
            "context_url": db_feedback.context_url,
            "context_phase": db_feedback.context_phase,
            "context_session_id": db_feedback.context_session_id,
            "user_agent": db_feedback.user_agent,
            "response": db_feedback.response,
            "responded_at": db_feedback.responded_at.isoformat() if db_feedback.responded_at else None,
            "responded_by": db_feedback.responded_by,
            "created_at": db_feedback.created_at.isoformat() if db_feedback.created_at else None,
            "updated_at": db_feedback.updated_at.isoformat() if db_feedback.updated_at else None,
        }

View File

@@ -0,0 +1,382 @@
"""
Homework & Material Repositories.
CRUD-Operationen fuer Hausaufgaben (Feature f20) und Phasen-Materialien (Feature f19).
"""
from datetime import datetime
from typing import Optional, List
from sqlalchemy.orm import Session as DBSession
from .db_models import (
HomeworkDB, HomeworkStatusEnum, PhaseMaterialDB, MaterialTypeEnum,
)
from .models import (
Homework, HomeworkStatus, PhaseMaterial, MaterialType,
)
class HomeworkRepository:
    """Repository for homework tracking (feature f20).

    CRUD over HomeworkDB rows plus conversion back to the Homework
    dataclass; every mutating method commits immediately.
    """
    def __init__(self, db: DBSession):
        # SQLAlchemy session supplied (and owned) by the caller.
        self.db = db
    # ==================== CREATE ====================
    def create(self, homework: Homework) -> HomeworkDB:
        """Insert a new homework row built from the given dataclass."""
        record = HomeworkDB(
            id=homework.homework_id,
            teacher_id=homework.teacher_id,
            class_id=homework.class_id,
            subject=homework.subject,
            title=homework.title,
            description=homework.description,
            session_id=homework.session_id,
            due_date=homework.due_date,
            status=HomeworkStatusEnum(homework.status.value),
        )
        self.db.add(record)
        self.db.commit()
        self.db.refresh(record)
        return record
    # ==================== READ ====================
    def get_by_id(self, homework_id: str) -> Optional[HomeworkDB]:
        """Look up one homework row by primary key, or None."""
        return (
            self.db.query(HomeworkDB)
            .filter(HomeworkDB.id == homework_id)
            .first()
        )
    def get_by_teacher(
        self,
        teacher_id: str,
        status: Optional[str] = None,
        limit: int = 50
    ) -> List[HomeworkDB]:
        """List a teacher's homework, optionally restricted to one status."""
        conditions = [HomeworkDB.teacher_id == teacher_id]
        if status:
            conditions.append(HomeworkDB.status == HomeworkStatusEnum(status))
        return (
            self.db.query(HomeworkDB)
            .filter(*conditions)
            .order_by(
                HomeworkDB.due_date.asc().nullslast(),
                HomeworkDB.created_at.desc(),
            )
            .limit(limit)
            .all()
        )
    def get_by_class(
        self,
        class_id: str,
        teacher_id: str,
        include_completed: bool = False,
        limit: int = 20
    ) -> List[HomeworkDB]:
        """List homework of one class, hiding completed items by default."""
        conditions = [
            HomeworkDB.class_id == class_id,
            HomeworkDB.teacher_id == teacher_id,
        ]
        if not include_completed:
            conditions.append(HomeworkDB.status != HomeworkStatusEnum.COMPLETED)
        return (
            self.db.query(HomeworkDB)
            .filter(*conditions)
            .order_by(
                HomeworkDB.due_date.asc().nullslast(),
                HomeworkDB.created_at.desc(),
            )
            .limit(limit)
            .all()
        )
    def get_by_session(self, session_id: str) -> List[HomeworkDB]:
        """List every homework item attached to a session, newest first."""
        return (
            self.db.query(HomeworkDB)
            .filter(HomeworkDB.session_id == session_id)
            .order_by(HomeworkDB.created_at.desc())
            .all()
        )
    def get_pending(
        self,
        teacher_id: str,
        days_ahead: int = 7
    ) -> List[HomeworkDB]:
        """Return assigned/in-progress homework due within the next X days."""
        from datetime import timedelta
        deadline = datetime.utcnow() + timedelta(days=days_ahead)
        open_states = [HomeworkStatusEnum.ASSIGNED, HomeworkStatusEnum.IN_PROGRESS]
        return (
            self.db.query(HomeworkDB)
            .filter(
                HomeworkDB.teacher_id == teacher_id,
                HomeworkDB.status.in_(open_states),
                HomeworkDB.due_date <= deadline,
            )
            .order_by(HomeworkDB.due_date.asc())
            .all()
        )
    # ==================== UPDATE ====================
    def update_status(
        self,
        homework_id: str,
        status: HomeworkStatus
    ) -> Optional[HomeworkDB]:
        """Change the status of one homework; None if it does not exist."""
        record = self.get_by_id(homework_id)
        if not record:
            return None
        record.status = HomeworkStatusEnum(status.value)
        self.db.commit()
        self.db.refresh(record)
        return record
    def update(self, homework: Homework) -> Optional[HomeworkDB]:
        """Overwrite the editable fields of an existing homework row."""
        record = self.get_by_id(homework.homework_id)
        if not record:
            return None
        record.title = homework.title
        record.description = homework.description
        record.due_date = homework.due_date
        record.status = HomeworkStatusEnum(homework.status.value)
        self.db.commit()
        self.db.refresh(record)
        return record
    # ==================== DELETE ====================
    def delete(self, homework_id: str) -> bool:
        """Remove one homework row; True when a row was deleted."""
        record = self.get_by_id(homework_id)
        if not record:
            return False
        self.db.delete(record)
        self.db.commit()
        return True
    # ==================== CONVERSION ====================
    def to_dataclass(self, db_homework: HomeworkDB) -> Homework:
        """Map a DB row back to the Homework dataclass."""
        return Homework(
            homework_id=db_homework.id,
            teacher_id=db_homework.teacher_id,
            class_id=db_homework.class_id,
            subject=db_homework.subject,
            title=db_homework.title,
            description=db_homework.description or "",
            session_id=db_homework.session_id,
            due_date=db_homework.due_date,
            status=HomeworkStatus(db_homework.status.value),
            created_at=db_homework.created_at,
            updated_at=db_homework.updated_at,
        )
class MaterialRepository:
    """Repository for phase materials (feature f19).

    CRUD over PhaseMaterialDB plus usage-count bookkeeping; every
    mutating method commits immediately.
    """
    def __init__(self, db: DBSession):
        # Active SQLAlchemy session; owned by the caller, never closed here.
        self.db = db
    # ==================== CREATE ====================
    def create(self, material: PhaseMaterial) -> PhaseMaterialDB:
        """Insert a new material row built from the given dataclass."""
        db_material = PhaseMaterialDB(
            id=material.material_id,
            teacher_id=material.teacher_id,
            title=material.title,
            material_type=MaterialTypeEnum(material.material_type.value),
            url=material.url,
            description=material.description,
            phase=material.phase,
            subject=material.subject,
            grade_level=material.grade_level,
            tags=material.tags,
            is_public=material.is_public,
            usage_count=material.usage_count,
            session_id=material.session_id,
        )
        self.db.add(db_material)
        self.db.commit()
        self.db.refresh(db_material)
        return db_material
    # ==================== READ ====================
    def get_by_id(self, material_id: str) -> Optional[PhaseMaterialDB]:
        """Return a material by ID, or None if it does not exist."""
        return self.db.query(PhaseMaterialDB).filter(
            PhaseMaterialDB.id == material_id
        ).first()
    def get_by_teacher(
        self,
        teacher_id: str,
        phase: Optional[str] = None,
        subject: Optional[str] = None,
        limit: int = 50
    ) -> List[PhaseMaterialDB]:
        """List a teacher's materials, most-used first, newest breaking ties."""
        query = self.db.query(PhaseMaterialDB).filter(
            PhaseMaterialDB.teacher_id == teacher_id
        )
        if phase:
            query = query.filter(PhaseMaterialDB.phase == phase)
        if subject:
            query = query.filter(PhaseMaterialDB.subject == subject)
        return query.order_by(
            PhaseMaterialDB.usage_count.desc(),
            PhaseMaterialDB.created_at.desc()
        ).limit(limit).all()
    def get_by_phase(
        self,
        phase: str,
        teacher_id: str,
        include_public: bool = True
    ) -> List[PhaseMaterialDB]:
        """List materials for one lesson phase.

        With include_public, other teachers' public materials are included
        and results are ordered by usage count; otherwise only the
        teacher's own materials, newest first.
        """
        if include_public:
            # SQLAlchemy needs `== True` here (column expression, not identity).
            return self.db.query(PhaseMaterialDB).filter(
                PhaseMaterialDB.phase == phase,
                (PhaseMaterialDB.teacher_id == teacher_id) |
                (PhaseMaterialDB.is_public == True)
            ).order_by(
                PhaseMaterialDB.usage_count.desc()
            ).all()
        else:
            return self.db.query(PhaseMaterialDB).filter(
                PhaseMaterialDB.phase == phase,
                PhaseMaterialDB.teacher_id == teacher_id
            ).order_by(
                PhaseMaterialDB.created_at.desc()
            ).all()
    def get_by_session(self, session_id: str) -> List[PhaseMaterialDB]:
        """List all materials of a session, ordered by phase then age."""
        return self.db.query(PhaseMaterialDB).filter(
            PhaseMaterialDB.session_id == session_id
        ).order_by(PhaseMaterialDB.phase, PhaseMaterialDB.created_at).all()
    def get_public_materials(
        self,
        phase: Optional[str] = None,
        subject: Optional[str] = None,
        limit: int = 20
    ) -> List[PhaseMaterialDB]:
        """List public materials, most-used first, with optional filters."""
        query = self.db.query(PhaseMaterialDB).filter(
            PhaseMaterialDB.is_public == True
        )
        if phase:
            query = query.filter(PhaseMaterialDB.phase == phase)
        if subject:
            query = query.filter(PhaseMaterialDB.subject == subject)
        return query.order_by(
            PhaseMaterialDB.usage_count.desc()
        ).limit(limit).all()
    def search_by_tags(
        self,
        tags: List[str],
        teacher_id: Optional[str] = None
    ) -> List[PhaseMaterialDB]:
        """Search materials matching ANY of the given tags (max 50 results).

        With a teacher_id, the teacher's own plus public materials are
        searched; otherwise only public ones.
        """
        query = self.db.query(PhaseMaterialDB)
        if teacher_id:
            query = query.filter(
                (PhaseMaterialDB.teacher_id == teacher_id) |
                (PhaseMaterialDB.is_public == True)
            )
        else:
            query = query.filter(PhaseMaterialDB.is_public == True)
        # Filter by tags - simplified implementation: tag matching happens
        # in Python after fetching all candidate rows.
        results = []
        for material in query.all():
            if material.tags and any(tag in material.tags for tag in tags):
                results.append(material)
        return results[:50]
    # ==================== UPDATE ====================
    def update(self, material: PhaseMaterial) -> Optional[PhaseMaterialDB]:
        """Overwrite the editable fields of a material; None if not found.

        usage_count and session_id are deliberately NOT touched here.
        """
        db_material = self.get_by_id(material.material_id)
        if not db_material:
            return None
        db_material.title = material.title
        db_material.material_type = MaterialTypeEnum(material.material_type.value)
        db_material.url = material.url
        db_material.description = material.description
        db_material.phase = material.phase
        db_material.subject = material.subject
        db_material.grade_level = material.grade_level
        db_material.tags = material.tags
        db_material.is_public = material.is_public
        self.db.commit()
        self.db.refresh(db_material)
        return db_material
    def increment_usage(self, material_id: str) -> Optional[PhaseMaterialDB]:
        """Increase the usage counter of a material; None if not found."""
        db_material = self.get_by_id(material_id)
        if not db_material:
            return None
        db_material.usage_count += 1
        self.db.commit()
        self.db.refresh(db_material)
        return db_material
    def attach_to_session(
        self,
        material_id: str,
        session_id: str
    ) -> Optional[PhaseMaterialDB]:
        """Link a material to a session (also counts as one usage)."""
        db_material = self.get_by_id(material_id)
        if not db_material:
            return None
        db_material.session_id = session_id
        db_material.usage_count += 1
        self.db.commit()
        self.db.refresh(db_material)
        return db_material
    # ==================== DELETE ====================
    def delete(self, material_id: str) -> bool:
        """Delete a material; returns True on success, False if not found."""
        db_material = self.get_by_id(material_id)
        if not db_material:
            return False
        self.db.delete(db_material)
        self.db.commit()
        return True
    # ==================== CONVERSION ====================
    def to_dataclass(self, db_material: PhaseMaterialDB) -> PhaseMaterial:
        """Map a DB row back to the PhaseMaterial dataclass."""
        return PhaseMaterial(
            material_id=db_material.id,
            teacher_id=db_material.teacher_id,
            title=db_material.title,
            material_type=MaterialType(db_material.material_type.value),
            url=db_material.url,
            description=db_material.description or "",
            phase=db_material.phase,
            subject=db_material.subject or "",
            grade_level=db_material.grade_level or "",
            tags=db_material.tags or [],
            is_public=db_material.is_public,
            usage_count=db_material.usage_count,
            session_id=db_material.session_id,
            created_at=db_material.created_at,
            updated_at=db_material.updated_at,
        )

View File

@@ -0,0 +1,315 @@
"""
Reflection & Analytics Repositories.
CRUD-Operationen fuer Lesson-Reflections und Analytics-Abfragen (Phase 5).
"""
from datetime import datetime
from typing import Optional, List, Dict, Any
from sqlalchemy.orm import Session as DBSession
from .db_models import LessonSessionDB, LessonPhaseEnum, LessonReflectionDB
from .analytics import (
LessonReflection, SessionSummary, TeacherAnalytics, AnalyticsCalculator,
)
class ReflectionRepository:
"""Repository fuer LessonReflection CRUD-Operationen."""
def __init__(self, db: DBSession):
self.db = db
# ==================== CREATE ====================
def create(self, reflection: LessonReflection) -> LessonReflectionDB:
"""Erstellt eine neue Reflection."""
db_reflection = LessonReflectionDB(
id=reflection.reflection_id,
session_id=reflection.session_id,
teacher_id=reflection.teacher_id,
notes=reflection.notes,
overall_rating=reflection.overall_rating,
what_worked=reflection.what_worked,
improvements=reflection.improvements,
notes_for_next_lesson=reflection.notes_for_next_lesson,
)
self.db.add(db_reflection)
self.db.commit()
self.db.refresh(db_reflection)
return db_reflection
# ==================== READ ====================
def get_by_id(self, reflection_id: str) -> Optional[LessonReflectionDB]:
"""Holt eine Reflection nach ID."""
return self.db.query(LessonReflectionDB).filter(
LessonReflectionDB.id == reflection_id
).first()
def get_by_session(self, session_id: str) -> Optional[LessonReflectionDB]:
"""Holt die Reflection einer Session."""
return self.db.query(LessonReflectionDB).filter(
LessonReflectionDB.session_id == session_id
).first()
def get_by_teacher(
self,
teacher_id: str,
limit: int = 20,
offset: int = 0
) -> List[LessonReflectionDB]:
"""Holt alle Reflections eines Lehrers."""
return self.db.query(LessonReflectionDB).filter(
LessonReflectionDB.teacher_id == teacher_id
).order_by(
LessonReflectionDB.created_at.desc()
).offset(offset).limit(limit).all()
# ==================== UPDATE ====================
def update(self, reflection: LessonReflection) -> Optional[LessonReflectionDB]:
"""Aktualisiert eine Reflection."""
db_reflection = self.get_by_id(reflection.reflection_id)
if not db_reflection:
return None
db_reflection.notes = reflection.notes
db_reflection.overall_rating = reflection.overall_rating
db_reflection.what_worked = reflection.what_worked
db_reflection.improvements = reflection.improvements
db_reflection.notes_for_next_lesson = reflection.notes_for_next_lesson
self.db.commit()
self.db.refresh(db_reflection)
return db_reflection
# ==================== DELETE ====================
def delete(self, reflection_id: str) -> bool:
    """Remove a reflection; True on success, False if it did not exist."""
    record = self.get_by_id(reflection_id)
    if not record:
        return False
    self.db.delete(record)
    self.db.commit()
    return True
# ==================== CONVERSION ====================
def to_dataclass(self, db_reflection: LessonReflectionDB) -> LessonReflection:
    """Convert a DB model to its dataclass twin.

    NULLable columns are normalised to safe defaults (empty string /
    empty list) so callers never see None for text or list fields.
    """
    return LessonReflection(
        reflection_id=db_reflection.id,
        session_id=db_reflection.session_id,
        teacher_id=db_reflection.teacher_id,
        notes=db_reflection.notes or "",
        overall_rating=db_reflection.overall_rating,
        what_worked=db_reflection.what_worked or [],
        improvements=db_reflection.improvements or [],
        notes_for_next_lesson=db_reflection.notes_for_next_lesson or "",
        created_at=db_reflection.created_at,
        updated_at=db_reflection.updated_at,
    )
class AnalyticsRepository:
    """Repository for analytics queries.

    Read-only: aggregates ``LessonSessionDB`` rows; the actual
    number-crunching is delegated to ``AnalyticsCalculator``.
    """

    def __init__(self, db: DBSession):
        # SQLAlchemy session used for all read queries.
        self.db = db

    def get_session_summary(self, session_id: str) -> Optional[SessionSummary]:
        """
        Compute the summary of a completed session.

        Args:
            session_id: ID of the session

        Returns:
            SessionSummary, or None if the session was not found
        """
        db_session = self.db.query(LessonSessionDB).filter(
            LessonSessionDB.id == session_id
        ).first()
        if not db_session:
            return None
        # Collect the raw session fields the calculator needs.
        session_data = {
            "session_id": db_session.id,
            "teacher_id": db_session.teacher_id,
            "class_id": db_session.class_id,
            "subject": db_session.subject,
            "topic": db_session.topic,
            "lesson_started_at": db_session.lesson_started_at,
            "lesson_ended_at": db_session.lesson_ended_at,
            "phase_durations": db_session.phase_durations or {},
        }
        # Phase history from the DB JSON column (may be empty).
        phase_history = db_session.phase_history or []
        # Delegate the summary computation.
        return AnalyticsCalculator.calculate_session_summary(
            session_data, phase_history
        )

    def get_teacher_analytics(
        self,
        teacher_id: str,
        period_start: Optional[datetime] = None,
        period_end: Optional[datetime] = None
    ) -> TeacherAnalytics:
        """
        Compute aggregated statistics for one teacher.

        Args:
            teacher_id: ID of the teacher
            period_start: start of the period (default: 30 days back)
            period_end: end of the period (default: now)

        Returns:
            TeacherAnalytics with aggregated statistics
        """
        from datetime import timedelta
        # NOTE(review): datetime.utcnow() is naive and deprecated since
        # Python 3.12; kept because DB timestamps appear to be naive UTC
        # as well -- confirm before migrating to datetime.now(timezone.utc).
        if not period_end:
            period_end = datetime.utcnow()
        if not period_start:
            period_start = period_end - timedelta(days=30)
        # All sessions started within the period (inclusive bounds).
        sessions_query = self.db.query(LessonSessionDB).filter(
            LessonSessionDB.teacher_id == teacher_id,
            LessonSessionDB.lesson_started_at >= period_start,
            LessonSessionDB.lesson_started_at <= period_end
        ).all()
        # Convert ORM rows to plain dicts for the calculator.
        sessions_data = []
        for db_session in sessions_query:
            sessions_data.append({
                "session_id": db_session.id,
                "teacher_id": db_session.teacher_id,
                "class_id": db_session.class_id,
                "subject": db_session.subject,
                "topic": db_session.topic,
                "lesson_started_at": db_session.lesson_started_at,
                "lesson_ended_at": db_session.lesson_ended_at,
                "phase_durations": db_session.phase_durations or {},
                "phase_history": db_session.phase_history or [],
            })
        return AnalyticsCalculator.calculate_teacher_analytics(
            sessions_data, period_start, period_end
        )

    def get_phase_duration_trends(
        self,
        teacher_id: str,
        phase: str,
        limit: int = 20
    ) -> List[Dict[str, Any]]:
        """
        Return duration trends for one specific phase.

        Args:
            teacher_id: ID of the teacher
            phase: phase id (einstieg, erarbeitung, etc.)
            limit: max number of data points

        Returns:
            list of data points [{date, planned, actual, difference}]
        """
        sessions = self.db.query(LessonSessionDB).filter(
            LessonSessionDB.teacher_id == teacher_id,
            LessonSessionDB.current_phase == LessonPhaseEnum.ENDED
        ).order_by(
            LessonSessionDB.lesson_ended_at.desc()
        ).limit(limit).all()
        trends = []
        for db_session in sessions:
            history = db_session.phase_history or []
            for entry in history:
                if entry.get("phase") == phase:
                    # phase_durations stores minutes -> convert to seconds.
                    planned = (db_session.phase_durations or {}).get(phase, 0) * 60
                    actual = entry.get("duration_seconds", 0) or 0
                    trends.append({
                        "date": db_session.lesson_started_at.isoformat() if db_session.lesson_started_at else None,
                        "session_id": db_session.id,
                        "subject": db_session.subject,
                        "planned_seconds": planned,
                        "actual_seconds": actual,
                        "difference_seconds": actual - planned,
                    })
                    # Only the first matching history entry per session counts.
                    break
        return list(reversed(trends))  # sort chronologically

    def get_overtime_analysis(
        self,
        teacher_id: str,
        limit: int = 30
    ) -> Dict[str, Any]:
        """
        Analyse overtime patterns across the most recent sessions.

        Args:
            teacher_id: ID of the teacher
            limit: number of sessions to analyse

        Returns:
            dict with overtime statistics per phase
        """
        sessions = self.db.query(LessonSessionDB).filter(
            LessonSessionDB.teacher_id == teacher_id,
            LessonSessionDB.current_phase == LessonPhaseEnum.ENDED
        ).order_by(
            LessonSessionDB.lesson_ended_at.desc()
        ).limit(limit).all()
        # Collected overtime seconds, keyed by known phase id.
        phase_overtime: Dict[str, List[int]] = {
            "einstieg": [],
            "erarbeitung": [],
            "sicherung": [],
            "transfer": [],
            "reflexion": [],
        }
        for db_session in sessions:
            history = db_session.phase_history or []
            phase_durations = db_session.phase_durations or {}
            for entry in history:
                phase = entry.get("phase", "")
                if phase in phase_overtime:
                    # planned minutes -> seconds; overtime clamped at 0
                    # (finishing early never counts as negative overtime).
                    planned = phase_durations.get(phase, 0) * 60
                    actual = entry.get("duration_seconds", 0) or 0
                    overtime = max(0, actual - planned)
                    phase_overtime[phase].append(overtime)
        # Compute per-phase statistics
        result = {}
        for phase, overtimes in phase_overtime.items():
            if overtimes:
                result[phase] = {
                    "count": len([o for o in overtimes if o > 0]),
                    "total": len(overtimes),
                    "avg_overtime_seconds": sum(overtimes) / len(overtimes),
                    "max_overtime_seconds": max(overtimes),
                    "overtime_percentage": len([o for o in overtimes if o > 0]) / len(overtimes) * 100,
                }
            else:
                result[phase] = {
                    "count": 0,
                    "total": 0,
                    "avg_overtime_seconds": 0,
                    "max_overtime_seconds": 0,
                    "overtime_percentage": 0,
                }
        return result

View File

@@ -0,0 +1,248 @@
"""
Session & Teacher Settings Repositories.
CRUD-Operationen fuer LessonSessions und Lehrer-Einstellungen.
"""
from typing import Optional, List, Dict
from sqlalchemy.orm import Session as DBSession
from .db_models import (
LessonSessionDB, LessonPhaseEnum, TeacherSettingsDB,
)
from .models import (
LessonSession, LessonPhase, get_default_durations,
)
class SessionRepository:
    """CRUD repository for LessonSession rows."""

    def __init__(self, db: DBSession):
        self.db = db

    # ==================== CREATE ====================
    def create(self, session: LessonSession) -> LessonSessionDB:
        """
        Insert a new session row built from the given dataclass.

        Args:
            session: LessonSession dataclass

        Returns:
            the persisted LessonSessionDB model
        """
        record = LessonSessionDB(
            id=session.session_id,
            teacher_id=session.teacher_id,
            class_id=session.class_id,
            subject=session.subject,
            topic=session.topic,
            current_phase=LessonPhaseEnum(session.current_phase.value),
            is_paused=session.is_paused,
            lesson_started_at=session.lesson_started_at,
            lesson_ended_at=session.lesson_ended_at,
            phase_started_at=session.phase_started_at,
            pause_started_at=session.pause_started_at,
            total_paused_seconds=session.total_paused_seconds,
            phase_durations=session.phase_durations,
            phase_history=session.phase_history,
            notes=session.notes,
            homework=session.homework,
        )
        self.db.add(record)
        self.db.commit()
        self.db.refresh(record)
        return record

    # ==================== READ ====================
    def get_by_id(self, session_id: str) -> Optional[LessonSessionDB]:
        """Fetch one session by primary key, or None."""
        query = self.db.query(LessonSessionDB)
        return query.filter(LessonSessionDB.id == session_id).first()

    def get_active_by_teacher(self, teacher_id: str) -> List[LessonSessionDB]:
        """All sessions of a teacher that have not reached ENDED yet."""
        query = self.db.query(LessonSessionDB).filter(
            LessonSessionDB.teacher_id == teacher_id,
            LessonSessionDB.current_phase != LessonPhaseEnum.ENDED,
        )
        return query.all()

    def get_history_by_teacher(
        self,
        teacher_id: str,
        limit: int = 20,
        offset: int = 0
    ) -> List[LessonSessionDB]:
        """Paged history of a teacher's ENDED sessions (feature f17)."""
        query = (
            self.db.query(LessonSessionDB)
            .filter(
                LessonSessionDB.teacher_id == teacher_id,
                LessonSessionDB.current_phase == LessonPhaseEnum.ENDED,
            )
            .order_by(LessonSessionDB.lesson_ended_at.desc())
        )
        return query.offset(offset).limit(limit).all()

    def get_by_class(
        self,
        class_id: str,
        limit: int = 20
    ) -> List[LessonSessionDB]:
        """Most recently created sessions of one class."""
        query = (
            self.db.query(LessonSessionDB)
            .filter(LessonSessionDB.class_id == class_id)
            .order_by(LessonSessionDB.created_at.desc())
        )
        return query.limit(limit).all()

    # ==================== UPDATE ====================
    def update(self, session: LessonSession) -> Optional[LessonSessionDB]:
        """
        Write all mutable fields of the dataclass back to the DB row.

        Args:
            session: LessonSession dataclass carrying the new values

        Returns:
            the refreshed LessonSessionDB, or None if the id is unknown
        """
        record = self.get_by_id(session.session_id)
        if not record:
            return None
        record.current_phase = LessonPhaseEnum(session.current_phase.value)
        record.is_paused = session.is_paused
        record.lesson_started_at = session.lesson_started_at
        record.lesson_ended_at = session.lesson_ended_at
        record.phase_started_at = session.phase_started_at
        record.pause_started_at = session.pause_started_at
        record.total_paused_seconds = session.total_paused_seconds
        record.phase_durations = session.phase_durations
        record.phase_history = session.phase_history
        record.notes = session.notes
        record.homework = session.homework
        self.db.commit()
        self.db.refresh(record)
        return record

    def update_notes(
        self,
        session_id: str,
        notes: str,
        homework: str
    ) -> Optional[LessonSessionDB]:
        """Update only the notes and homework fields of a session."""
        record = self.get_by_id(session_id)
        if not record:
            return None
        record.notes = notes
        record.homework = homework
        self.db.commit()
        self.db.refresh(record)
        return record

    # ==================== DELETE ====================
    def delete(self, session_id: str) -> bool:
        """Delete a session; True on success, False if it did not exist."""
        record = self.get_by_id(session_id)
        if not record:
            return False
        self.db.delete(record)
        self.db.commit()
        return True

    # ==================== CONVERSION ====================
    def to_dataclass(self, db_session: LessonSessionDB) -> LessonSession:
        """
        Convert a DB model into its LessonSession dataclass twin.

        NULLable columns are normalised to safe defaults (empty string,
        empty list, 0, default phase durations).
        """
        return LessonSession(
            session_id=db_session.id,
            teacher_id=db_session.teacher_id,
            class_id=db_session.class_id,
            subject=db_session.subject,
            topic=db_session.topic,
            current_phase=LessonPhase(db_session.current_phase.value),
            phase_started_at=db_session.phase_started_at,
            lesson_started_at=db_session.lesson_started_at,
            lesson_ended_at=db_session.lesson_ended_at,
            is_paused=db_session.is_paused,
            pause_started_at=db_session.pause_started_at,
            total_paused_seconds=db_session.total_paused_seconds or 0,
            phase_durations=db_session.phase_durations or get_default_durations(),
            phase_history=db_session.phase_history or [],
            notes=db_session.notes or "",
            homework=db_session.homework or "",
        )
class TeacherSettingsRepository:
    """Repository for per-teacher settings (feature f16)."""

    def __init__(self, db: DBSession):
        self.db = db

    def get_or_create(self, teacher_id: str) -> TeacherSettingsDB:
        """Return the settings row for a teacher, creating defaults on first use."""
        row = self.db.query(TeacherSettingsDB).filter(
            TeacherSettingsDB.teacher_id == teacher_id
        ).first()
        if row:
            return row
        row = TeacherSettingsDB(
            teacher_id=teacher_id,
            default_phase_durations=get_default_durations(),
        )
        self.db.add(row)
        self.db.commit()
        self.db.refresh(row)
        return row

    def update_phase_durations(
        self,
        teacher_id: str,
        durations: Dict[str, int]
    ) -> TeacherSettingsDB:
        """Replace the teacher's default phase durations."""
        row = self.get_or_create(teacher_id)
        row.default_phase_durations = durations
        self.db.commit()
        self.db.refresh(row)
        return row

    def update_preferences(
        self,
        teacher_id: str,
        audio_enabled: Optional[bool] = None,
        high_contrast: Optional[bool] = None,
        show_statistics: Optional[bool] = None
    ) -> TeacherSettingsDB:
        """Partially update UI preferences; None arguments remain untouched."""
        row = self.get_or_create(teacher_id)
        if audio_enabled is not None:
            row.audio_enabled = audio_enabled
        if high_contrast is not None:
            row.high_contrast = high_contrast
        if show_statistics is not None:
            row.show_statistics = show_statistics
        self.db.commit()
        self.db.refresh(row)
        return row

View File

@@ -0,0 +1,167 @@
"""
Template Repository.
CRUD-Operationen fuer Stunden-Vorlagen (Feature f37).
"""
from typing import Optional, List
from sqlalchemy.orm import Session as DBSession
from .db_models import LessonTemplateDB
from .models import LessonTemplate, get_default_durations
class TemplateRepository:
    """Repository for lesson templates (feature f37).

    Improvement over the extracted original: boolean column comparisons use
    SQLAlchemy's idiomatic ``Column.is_(True)`` instead of ``== True``
    (same result set for a NOT NULL boolean column, lint-clean per E712).
    """

    def __init__(self, db: DBSession):
        self.db = db

    # ==================== CREATE ====================
    def create(self, template: LessonTemplate) -> LessonTemplateDB:
        """Persist a new template and return the stored row."""
        db_template = LessonTemplateDB(
            id=template.template_id,
            teacher_id=template.teacher_id,
            name=template.name,
            description=template.description,
            subject=template.subject,
            grade_level=template.grade_level,
            phase_durations=template.phase_durations,
            default_topic=template.default_topic,
            default_notes=template.default_notes,
            is_public=template.is_public,
            usage_count=template.usage_count,
        )
        self.db.add(db_template)
        self.db.commit()
        self.db.refresh(db_template)
        return db_template

    # ==================== READ ====================
    def get_by_id(self, template_id: str) -> Optional[LessonTemplateDB]:
        """Fetch a template by primary key, or None."""
        return self.db.query(LessonTemplateDB).filter(
            LessonTemplateDB.id == template_id
        ).first()

    def get_by_teacher(
        self,
        teacher_id: str,
        include_public: bool = True
    ) -> List[LessonTemplateDB]:
        """
        All templates visible to a teacher.

        Args:
            teacher_id: ID of the teacher
            include_public: also include public templates of other teachers

        Returns:
            with include_public: own + public templates, most-used first;
            otherwise: own templates only, newest first.
        """
        if include_public:
            return self.db.query(LessonTemplateDB).filter(
                (LessonTemplateDB.teacher_id == teacher_id) |
                (LessonTemplateDB.is_public.is_(True))
            ).order_by(
                LessonTemplateDB.usage_count.desc()
            ).all()
        return self.db.query(LessonTemplateDB).filter(
            LessonTemplateDB.teacher_id == teacher_id
        ).order_by(
            LessonTemplateDB.created_at.desc()
        ).all()

    def get_public_templates(self, limit: int = 20) -> List[LessonTemplateDB]:
        """Public templates, sorted by popularity (usage count)."""
        return self.db.query(LessonTemplateDB).filter(
            LessonTemplateDB.is_public.is_(True)
        ).order_by(
            LessonTemplateDB.usage_count.desc()
        ).limit(limit).all()

    def get_by_subject(
        self,
        subject: str,
        teacher_id: Optional[str] = None
    ) -> List[LessonTemplateDB]:
        """Templates for a given subject (own + public, or public only)."""
        query = self.db.query(LessonTemplateDB).filter(
            LessonTemplateDB.subject == subject
        )
        if teacher_id:
            query = query.filter(
                (LessonTemplateDB.teacher_id == teacher_id) |
                (LessonTemplateDB.is_public.is_(True))
            )
        else:
            query = query.filter(LessonTemplateDB.is_public.is_(True))
        return query.order_by(
            LessonTemplateDB.usage_count.desc()
        ).all()

    # ==================== UPDATE ====================
    def update(self, template: LessonTemplate) -> Optional[LessonTemplateDB]:
        """Write all mutable fields back to the row; None if id is unknown."""
        db_template = self.get_by_id(template.template_id)
        if not db_template:
            return None
        db_template.name = template.name
        db_template.description = template.description
        db_template.subject = template.subject
        db_template.grade_level = template.grade_level
        db_template.phase_durations = template.phase_durations
        db_template.default_topic = template.default_topic
        db_template.default_notes = template.default_notes
        db_template.is_public = template.is_public
        self.db.commit()
        self.db.refresh(db_template)
        return db_template

    def increment_usage(self, template_id: str) -> Optional[LessonTemplateDB]:
        """Increment a template's usage counter.

        NOTE(review): read-modify-write; concurrent increments may lose
        updates. Acceptable for a popularity counter, but an atomic
        ``UPDATE ... SET usage_count = usage_count + 1`` would be exact.
        """
        db_template = self.get_by_id(template_id)
        if not db_template:
            return None
        db_template.usage_count += 1
        self.db.commit()
        self.db.refresh(db_template)
        return db_template

    # ==================== DELETE ====================
    def delete(self, template_id: str) -> bool:
        """Delete a template; True on success, False if it did not exist."""
        db_template = self.get_by_id(template_id)
        if not db_template:
            return False
        self.db.delete(db_template)
        self.db.commit()
        return True

    # ==================== CONVERSION ====================
    def to_dataclass(self, db_template: LessonTemplateDB) -> LessonTemplate:
        """Convert a DB row to its dataclass twin, normalising NULL columns."""
        return LessonTemplate(
            template_id=db_template.id,
            teacher_id=db_template.teacher_id,
            name=db_template.name,
            description=db_template.description or "",
            subject=db_template.subject or "",
            grade_level=db_template.grade_level or "",
            phase_durations=db_template.phase_durations or get_default_durations(),
            default_topic=db_template.default_topic or "",
            default_notes=db_template.default_notes or "",
            is_public=db_template.is_public,
            usage_count=db_template.usage_count,
            created_at=db_template.created_at,
            updated_at=db_template.updated_at,
        )

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,498 @@
"""
Cell-grid construction v2 (hybrid: broad columns via word lookup, narrow via cell-crop).
Extracted from cv_cell_grid.py.
Lizenz: Apache 2.0 — DATENSCHUTZ: Alle Verarbeitung erfolgt lokal.
"""
import logging
from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import Any, Dict, List, Optional, Tuple
import numpy as np
from cv_vocab_types import PageRegion, RowGeometry
from cv_ocr_engines import (
    RAPIDOCR_AVAILABLE,
    _assign_row_words_to_columns,
    _clean_cell_text,
    _clean_cell_text_lite,
    _words_to_reading_order_text,
    _words_to_spaced_text,
    ocr_region,
    ocr_region_lighton,
    ocr_region_rapid,
    ocr_region_trocr,
)
from cv_cell_grid_helpers import (
_MIN_WORD_CONF,
_ensure_minimum_crop_size,
_heal_row_gaps,
_is_artifact_row,
_select_psm_for_column,
)
logger = logging.getLogger(__name__)
try:
import cv2
except ImportError:
cv2 = None # type: ignore[assignment]
# ---------------------------------------------------------------------------
# _ocr_cell_crop — isolated cell-crop OCR for v2 hybrid mode
# ---------------------------------------------------------------------------
def _ocr_cell_crop(
    row_idx: int,
    col_idx: int,
    row: RowGeometry,
    col: PageRegion,
    ocr_img: np.ndarray,
    img_bgr: Optional[np.ndarray],
    img_w: int,
    img_h: int,
    engine_name: str,
    lang: str,
    lang_map: Dict[str, str],
) -> Dict[str, Any]:
    """OCR a single cell by cropping the exact column x row intersection.

    No padding beyond cell boundaries -> no neighbour bleeding.

    Args:
        row_idx, col_idx: grid coordinates; used for cell_id and logging.
        row, col: geometry of the row band and column band to intersect.
        ocr_img: preprocessed page image used by the Tesseract paths and
            the pixel-density emptiness check.
        img_bgr: original BGR page image; required by trocr/lighton/rapid.
        img_w, img_h: full page dimensions in pixels.
        engine_name: 'tesseract', 'rapid', 'lighton', or 'trocr-*'.
        lang: fallback Tesseract language string.
        lang_map: column-type -> Tesseract language override.

    Returns:
        Cell dict (cell_id, row/col indices, text, confidence, bbox_px,
        bbox_pct, ocr_engine, is_bold, optionally word_boxes). An
        empty-text cell is returned on any skip condition.
    """
    # Display bbox: exact column x row intersection
    disp_x = col.x
    disp_y = row.y
    disp_w = col.width
    disp_h = row.height
    # Crop boundaries: add small internal padding (3px each side) to avoid
    # clipping characters near column/row edges (e.g. parentheses, descenders).
    # Stays within image bounds but may extend slightly beyond strict cell.
    # 3px is small enough to avoid neighbour content at typical scan DPI (200-300).
    _PAD = 3
    cx = max(0, disp_x - _PAD)
    cy = max(0, disp_y - _PAD)
    cx2 = min(img_w, disp_x + disp_w + _PAD)
    cy2 = min(img_h, disp_y + disp_h + _PAD)
    cw = cx2 - cx
    ch = cy2 - cy
    # Template returned for every skip path; also the base of the result.
    empty_cell = {
        'cell_id': f"R{row_idx:02d}_C{col_idx}",
        'row_index': row_idx,
        'col_index': col_idx,
        'col_type': col.type,
        'text': '',
        'confidence': 0.0,
        'bbox_px': {'x': disp_x, 'y': disp_y, 'w': disp_w, 'h': disp_h},
        'bbox_pct': {
            'x': round(disp_x / img_w * 100, 2) if img_w else 0,
            'y': round(disp_y / img_h * 100, 2) if img_h else 0,
            'w': round(disp_w / img_w * 100, 2) if img_w else 0,
            'h': round(disp_h / img_h * 100, 2) if img_h else 0,
        },
        'ocr_engine': 'cell_crop_v2',
        'is_bold': False,
    }
    if cw <= 0 or ch <= 0:
        logger.debug("_ocr_cell_crop R%02d_C%d: zero-size crop (%dx%d)", row_idx, col_idx, cw, ch)
        return empty_cell
    # --- Pixel-density check: skip truly empty cells ---
    # Fewer than 0.5% dark pixels (< 180 gray) counts as blank.
    if ocr_img is not None:
        crop = ocr_img[cy:cy + ch, cx:cx + cw]
        if crop.size > 0:
            dark_ratio = float(np.count_nonzero(crop < 180)) / crop.size
            if dark_ratio < 0.005:
                logger.debug("_ocr_cell_crop R%02d_C%d: skip empty (dark_ratio=%.4f, crop=%dx%d)",
                             row_idx, col_idx, dark_ratio, cw, ch)
                return empty_cell
    # --- Prepare crop for OCR ---
    cell_lang = lang_map.get(col.type, lang)
    psm = _select_psm_for_column(col.type, col.width, row.height)
    text = ''
    avg_conf = 0.0
    used_engine = 'cell_crop_v2'
    # NOTE(review): trocr/lighton are handed absolute-coordinate regions and
    # their word positions are not remapped below, so they presumably return
    # absolute image coordinates -- confirm in cv_ocr_engines.
    if engine_name in ("trocr-printed", "trocr-handwritten") and img_bgr is not None:
        cell_region = PageRegion(type=col.type, x=cx, y=cy, width=cw, height=ch)
        words = ocr_region_trocr(img_bgr, cell_region,
                                 handwritten=(engine_name == "trocr-handwritten"))
    elif engine_name == "lighton" and img_bgr is not None:
        cell_region = PageRegion(type=col.type, x=cx, y=cy, width=cw, height=ch)
        words = ocr_region_lighton(img_bgr, cell_region)
    elif engine_name == "rapid" and img_bgr is not None:
        # Upscale small BGR crops for RapidOCR.
        bgr_crop = img_bgr[cy:cy + ch, cx:cx + cw]
        if bgr_crop.size == 0:
            words = []
        else:
            crop_h, crop_w = bgr_crop.shape[:2]
            if crop_h < 80:
                # Force 3x upscale for short rows — small chars need more pixels
                scale = 3.0
                bgr_up = cv2.resize(bgr_crop, None, fx=scale, fy=scale,
                                    interpolation=cv2.INTER_CUBIC)
            else:
                bgr_up = _ensure_minimum_crop_size(bgr_crop, min_dim=150, max_scale=3)
            up_h, up_w = bgr_up.shape[:2]
            scale_x = up_w / max(crop_w, 1)
            scale_y = up_h / max(crop_h, 1)
            was_scaled = (up_w != crop_w or up_h != crop_h)
            logger.debug("_ocr_cell_crop R%02d_C%d: rapid %dx%d -> %dx%d (scale=%.1fx)",
                         row_idx, col_idx, crop_w, crop_h, up_w, up_h, scale_y)
            tmp_region = PageRegion(type=col.type, x=0, y=0, width=up_w, height=up_h)
            words = ocr_region_rapid(bgr_up, tmp_region)
            # Remap positions back to original image coords
            if words and was_scaled:
                for w in words:
                    w['left'] = int(w['left'] / scale_x) + cx
                    w['top'] = int(w['top'] / scale_y) + cy
                    w['width'] = int(w['width'] / scale_x)
                    w['height'] = int(w['height'] / scale_y)
            elif words:
                for w in words:
                    w['left'] += cx
                    w['top'] += cy
    else:
        # Tesseract: upscale tiny crops for better recognition
        if ocr_img is not None:
            crop_slice = ocr_img[cy:cy + ch, cx:cx + cw]
            upscaled = _ensure_minimum_crop_size(crop_slice)
            up_h, up_w = upscaled.shape[:2]
            tmp_region = PageRegion(type=col.type, x=0, y=0, width=up_w, height=up_h)
            words = ocr_region(upscaled, tmp_region, lang=cell_lang, psm=psm)
            # Remap word positions back to original image coordinates
            if words and (up_w != cw or up_h != ch):
                sx = cw / max(up_w, 1)
                sy = ch / max(up_h, 1)
                for w in words:
                    w['left'] = int(w['left'] * sx) + cx
                    w['top'] = int(w['top'] * sy) + cy
                    w['width'] = int(w['width'] * sx)
                    w['height'] = int(w['height'] * sy)
            elif words:
                for w in words:
                    w['left'] += cx
                    w['top'] += cy
        else:
            words = []
    # Filter low-confidence words
    if words:
        words = [w for w in words if w.get('conf', 0) >= _MIN_WORD_CONF]
    if words:
        y_tol = max(15, ch)
        text = _words_to_reading_order_text(words, y_tolerance_px=y_tol)
        avg_conf = round(sum(w['conf'] for w in words) / len(words), 1)
        logger.debug("_ocr_cell_crop R%02d_C%d: OCR raw text=%r conf=%.1f nwords=%d crop=%dx%d psm=%s engine=%s",
                     row_idx, col_idx, text, avg_conf, len(words), cw, ch, psm, engine_name)
    else:
        logger.debug("_ocr_cell_crop R%02d_C%d: OCR returned NO words (crop=%dx%d psm=%s engine=%s)",
                     row_idx, col_idx, cw, ch, psm, engine_name)
    # --- PSM 7 fallback for still-empty Tesseract cells ---
    # Retries the crop as a single text line; only adopted when it yields text.
    if not text.strip() and engine_name == "tesseract" and ocr_img is not None:
        crop_slice = ocr_img[cy:cy + ch, cx:cx + cw]
        upscaled = _ensure_minimum_crop_size(crop_slice)
        up_h, up_w = upscaled.shape[:2]
        tmp_region = PageRegion(type=col.type, x=0, y=0, width=up_w, height=up_h)
        psm7_words = ocr_region(upscaled, tmp_region, lang=cell_lang, psm=7)
        if psm7_words:
            psm7_words = [w for w in psm7_words if w.get('conf', 0) >= _MIN_WORD_CONF]
            if psm7_words:
                p7_text = _words_to_reading_order_text(psm7_words, y_tolerance_px=10)
                if p7_text.strip():
                    text = p7_text
                    avg_conf = round(
                        sum(w['conf'] for w in psm7_words) / len(psm7_words), 1
                    )
                    used_engine = 'cell_crop_v2_psm7'
                    # Remap PSM7 word positions back to original image coords
                    if up_w != cw or up_h != ch:
                        sx = cw / max(up_w, 1)
                        sy = ch / max(up_h, 1)
                        for w in psm7_words:
                            w['left'] = int(w['left'] * sx) + cx
                            w['top'] = int(w['top'] * sy) + cy
                            w['width'] = int(w['width'] * sx)
                            w['height'] = int(w['height'] * sy)
                    else:
                        for w in psm7_words:
                            w['left'] += cx
                            w['top'] += cy
                    words = psm7_words
    # --- Noise filter ---
    if text.strip():
        pre_filter = text
        text = _clean_cell_text_lite(text)
        if not text:
            logger.debug("_ocr_cell_crop R%02d_C%d: _clean_cell_text_lite REMOVED %r",
                         row_idx, col_idx, pre_filter)
            avg_conf = 0.0
    result = dict(empty_cell)
    result['text'] = text
    result['confidence'] = avg_conf
    result['ocr_engine'] = used_engine
    # Store individual word bounding boxes (absolute image coordinates)
    # for pixel-accurate overlay positioning in the frontend.
    if words and text.strip():
        result['word_boxes'] = [
            {
                'text': w.get('text', ''),
                'left': w['left'],
                'top': w['top'],
                'width': w['width'],
                'height': w['height'],
                'conf': w.get('conf', 0),
            }
            for w in words
            if w.get('text', '').strip()
        ]
    return result
# Threshold: columns narrower than this (% of image width) use single-cell
# crop OCR instead of full-page word assignment.
# Rationale: isolated crops prevent neighbour bleeding into narrow columns
# (see build_cell_grid_v2 docstring).
_NARROW_COL_THRESHOLD_PCT = 15.0
# ---------------------------------------------------------------------------
# build_cell_grid_v2 — hybrid grid builder (current default)
# ---------------------------------------------------------------------------
def build_cell_grid_v2(
    ocr_img: np.ndarray,
    column_regions: List[PageRegion],
    row_geometries: List[RowGeometry],
    img_w: int,
    img_h: int,
    lang: str = "eng+deu",
    ocr_engine: str = "auto",
    img_bgr: Optional[np.ndarray] = None,
    skip_heal_gaps: bool = False,
) -> Tuple[List[Dict[str, Any]], List[Dict[str, Any]]]:
    """Hybrid Grid: full-page OCR for broad columns, cell-crop for narrow ones.

    Drop-in replacement for build_cell_grid() -- same signature & return type.

    Strategy:
    - Broad columns (>15% image width): Use pre-assigned full-page Tesseract
      words (from row.words). Handles IPA brackets, punctuation, sentence
      continuity correctly.
    - Narrow columns (<15% image width): Use isolated cell-crop OCR to prevent
      neighbour bleeding from adjacent broad columns.

    Args:
        ocr_img: preprocessed page image for Tesseract cell-crop OCR.
        column_regions: detected column bands (zone types filter below).
        row_geometries: detected row bands (content/header/footer).
        img_w, img_h: page dimensions in pixels.
        lang: Tesseract language string for columns without override.
        ocr_engine: 'auto', 'tesseract', 'rapid', 'lighton', 'trocr-*'.
        img_bgr: original BGR image; needed by non-Tesseract engines.
        skip_heal_gaps: keep rows at exact detected positions instead of
            expanding them to fill vertical gaps.

    Returns:
        (cells, columns_meta) -- sorted cell dicts plus per-column metadata.
    """
    engine_name = "tesseract"
    if ocr_engine in ("trocr-printed", "trocr-handwritten", "lighton"):
        engine_name = ocr_engine
    elif ocr_engine == "rapid" and RAPIDOCR_AVAILABLE:
        engine_name = "rapid"
    logger.info(f"build_cell_grid_v2: using OCR engine '{engine_name}' (hybrid mode)")
    # Filter to content rows only
    content_rows = [r for r in row_geometries if r.row_type == 'content']
    if not content_rows:
        logger.warning("build_cell_grid_v2: no content rows found")
        return [], []
    # Filter phantom rows (word_count=0) and artifact rows
    before = len(content_rows)
    content_rows = [r for r in content_rows if r.word_count > 0]
    skipped = before - len(content_rows)
    if skipped > 0:
        logger.info(f"build_cell_grid_v2: skipped {skipped} phantom rows (word_count=0)")
    if not content_rows:
        logger.warning("build_cell_grid_v2: no content rows with words found")
        return [], []
    before_art = len(content_rows)
    content_rows = [r for r in content_rows if not _is_artifact_row(r)]
    artifact_skipped = before_art - len(content_rows)
    if artifact_skipped > 0:
        logger.info(f"build_cell_grid_v2: skipped {artifact_skipped} artifact rows")
    if not content_rows:
        logger.warning("build_cell_grid_v2: no content rows after artifact filtering")
        return [], []
    # Filter columns: drop ignore/margin/header/footer zones.
    _skip_types = {'column_ignore', 'header', 'footer', 'margin_top',
                   'margin_bottom', 'margin_left', 'margin_right'}
    relevant_cols = [c for c in column_regions if c.type not in _skip_types]
    if not relevant_cols:
        logger.warning("build_cell_grid_v2: no usable columns found")
        return [], []
    # Heal row gaps -- use header/footer boundaries
    content_rows.sort(key=lambda r: r.y)
    header_rows = [r for r in row_geometries if r.row_type == 'header']
    footer_rows = [r for r in row_geometries if r.row_type == 'footer']
    if header_rows:
        top_bound = max(r.y + r.height for r in header_rows)
    else:
        top_bound = content_rows[0].y
    if footer_rows:
        bottom_bound = min(r.y for r in footer_rows)
    else:
        bottom_bound = content_rows[-1].y + content_rows[-1].height
    # skip_heal_gaps: When True, keep cell positions at their exact row geometry
    # positions without expanding to fill gaps from removed rows.
    if not skip_heal_gaps:
        _heal_row_gaps(content_rows, top_bound=top_bound, bottom_bound=bottom_bound)
    relevant_cols.sort(key=lambda c: c.x)
    columns_meta = [
        {'index': ci, 'type': c.type, 'x': c.x, 'width': c.width}
        for ci, c in enumerate(relevant_cols)
    ]
    # Per-column-type Tesseract language overrides.
    lang_map = {
        'column_en': 'eng',
        'column_de': 'deu',
        'column_example': 'eng+deu',
    }
    # --- Classify columns as broad vs narrow ---
    narrow_col_indices = set()
    for ci, col in enumerate(relevant_cols):
        col_pct = (col.width / img_w * 100) if img_w > 0 else 0
        if col_pct < _NARROW_COL_THRESHOLD_PCT:
            narrow_col_indices.add(ci)
    broad_col_count = len(relevant_cols) - len(narrow_col_indices)
    logger.info(f"build_cell_grid_v2: {broad_col_count} broad columns (full-page), "
                f"{len(narrow_col_indices)} narrow columns (cell-crop)")
    # --- Phase 1: Broad columns via full-page word assignment ---
    cells: List[Dict[str, Any]] = []
    for row_idx, row in enumerate(content_rows):
        # Assign full-page words to columns for this row
        col_words = _assign_row_words_to_columns(row, relevant_cols)
        for col_idx, col in enumerate(relevant_cols):
            if col_idx not in narrow_col_indices:
                # BROAD column: use pre-assigned full-page words
                words = col_words.get(col_idx, [])
                # Filter low-confidence words
                words = [w for w in words if w.get('conf', 0) >= _MIN_WORD_CONF]
                # Single full-width column (box sub-session): preserve spacing
                is_single_full_column = (
                    len(relevant_cols) == 1
                    and img_w > 0
                    and relevant_cols[0].width / img_w > 0.9
                )
                if words:
                    y_tol = max(15, row.height)
                    if is_single_full_column:
                        text = _words_to_spaced_text(words, y_tolerance_px=y_tol)
                        logger.info(f"R{row_idx:02d}: {len(words)} words, "
                                    f"text={text!r:.100}")
                    else:
                        text = _words_to_reading_order_text(words, y_tolerance_px=y_tol)
                    avg_conf = round(sum(w['conf'] for w in words) / len(words), 1)
                else:
                    text = ''
                    avg_conf = 0.0
                    if is_single_full_column:
                        logger.info(f"R{row_idx:02d}: 0 words (row has "
                                    f"{row.word_count} total, y={row.y}..{row.y+row.height})")
                # Apply noise filter -- but NOT for single-column sub-sessions
                if not is_single_full_column:
                    text = _clean_cell_text(text)
                cell = {
                    'cell_id': f"R{row_idx:02d}_C{col_idx}",
                    'row_index': row_idx,
                    'col_index': col_idx,
                    'col_type': col.type,
                    'text': text,
                    'confidence': avg_conf,
                    'bbox_px': {
                        'x': col.x, 'y': row.y,
                        'w': col.width, 'h': row.height,
                    },
                    'bbox_pct': {
                        'x': round(col.x / img_w * 100, 2) if img_w else 0,
                        'y': round(row.y / img_h * 100, 2) if img_h else 0,
                        'w': round(col.width / img_w * 100, 2) if img_w else 0,
                        'h': round(row.height / img_h * 100, 2) if img_h else 0,
                    },
                    'ocr_engine': 'word_lookup',
                    'is_bold': False,
                }
                # Store word bounding boxes for pixel-accurate overlay
                if words and text.strip():
                    cell['word_boxes'] = [
                        {
                            'text': w.get('text', ''),
                            'left': w['left'],
                            'top': w['top'],
                            'width': w['width'],
                            'height': w['height'],
                            'conf': w.get('conf', 0),
                        }
                        for w in words
                        if w.get('text', '').strip()
                    ]
                cells.append(cell)
    # --- Phase 2: Narrow columns via cell-crop OCR (parallel) ---
    narrow_tasks = []
    for row_idx, row in enumerate(content_rows):
        for col_idx, col in enumerate(relevant_cols):
            if col_idx in narrow_col_indices:
                narrow_tasks.append((row_idx, col_idx, row, col))
    if narrow_tasks:
        # Tesseract releases the GIL in its C core; heavier engines get
        # fewer workers to bound memory use.
        max_workers = 4 if engine_name == "tesseract" else 2
        with ThreadPoolExecutor(max_workers=max_workers) as pool:
            futures = {
                pool.submit(
                    _ocr_cell_crop,
                    ri, ci, row, col,
                    ocr_img, img_bgr, img_w, img_h,
                    engine_name, lang, lang_map,
                ): (ri, ci)
                for ri, ci, row, col in narrow_tasks
            }
            for future in as_completed(futures):
                try:
                    cell = future.result()
                    cells.append(cell)
                except Exception as e:
                    ri, ci = futures[future]
                    logger.error(f"build_cell_grid_v2: narrow cell R{ri:02d}_C{ci} failed: {e}")
    # Sort cells by (row_index, col_index)
    cells.sort(key=lambda c: (c['row_index'], c['col_index']))
    # Remove all-empty rows
    rows_with_text: set = set()
    for cell in cells:
        if cell['text'].strip():
            rows_with_text.add(cell['row_index'])
    before_filter = len(cells)
    cells = [c for c in cells if c['row_index'] in rows_with_text]
    # Approximate count: assumes a full set of column cells per removed row
    # (failed narrow cells were never appended, so this can undercount).
    empty_rows_removed = (before_filter - len(cells)) // max(len(relevant_cols), 1)
    if empty_rows_removed > 0:
        logger.info(f"build_cell_grid_v2: removed {empty_rows_removed} all-empty rows")
    logger.info(f"build_cell_grid_v2: {len(cells)} cells from "
                f"{len(content_rows)} rows x {len(relevant_cols)} columns, "
                f"engine={engine_name} (hybrid)")
    return cells, columns_meta

View File

@@ -0,0 +1,136 @@
"""
Shared helpers for cell-grid construction (v2 + legacy).
Extracted from cv_cell_grid.py — used by both cv_cell_grid_build and
cv_cell_grid_legacy.
Lizenz: Apache 2.0 (kommerziell nutzbar)
DATENSCHUTZ: Alle Verarbeitung erfolgt lokal.
"""
import logging
from typing import List
import numpy as np
from cv_vocab_types import RowGeometry
logger = logging.getLogger(__name__)
try:
import cv2
except ImportError:
cv2 = None # type: ignore[assignment]
# Minimum OCR word confidence to keep (used across multiple functions)
_MIN_WORD_CONF = 30
def _compute_cell_padding(col_width: int, img_w: int) -> int:
"""Adaptive padding for OCR crops based on column width.
Narrow columns (page_ref, marker) need more surrounding context so
Tesseract can segment characters correctly. Wide columns keep the
minimal 4 px padding to avoid pulling in neighbours.
"""
col_pct = col_width / img_w * 100 if img_w > 0 else 100
if col_pct < 5:
return max(20, col_width // 2)
if col_pct < 10:
return max(12, col_width // 4)
if col_pct < 15:
return 8
return 4
def _ensure_minimum_crop_size(crop: np.ndarray, min_dim: int = 150,
max_scale: int = 3) -> np.ndarray:
"""Upscale tiny crops so Tesseract gets enough pixel data.
If either dimension is below *min_dim*, the crop is bicubic-upscaled
so the smallest dimension reaches *min_dim* (capped at *max_scale* x).
"""
h, w = crop.shape[:2]
if h >= min_dim and w >= min_dim:
return crop
scale = min(max_scale, max(min_dim / max(h, 1), min_dim / max(w, 1)))
if scale <= 1.0:
return crop
new_w = int(w * scale)
new_h = int(h * scale)
return cv2.resize(crop, (new_w, new_h), interpolation=cv2.INTER_CUBIC)
def _select_psm_for_column(col_type: str, col_width: int,
row_height: int) -> int:
"""Choose the best Tesseract PSM for a given column geometry.
- page_ref columns are almost always single short tokens -> PSM 8
- Very narrow or short cells -> PSM 7 (single text line)
- Everything else -> PSM 6 (uniform block)
"""
if col_type in ('page_ref', 'marker'):
return 8 # single word
if col_width < 100 or row_height < 30:
return 7 # single line
return 6 # uniform block
def _is_artifact_row(row: RowGeometry) -> bool:
    """Return True when *row* holds only scan artifacts, not real text.

    Scanner shadows and speckle noise typically yield only single-character
    detections; a genuine content row always contains at least one token of
    two or more characters after stripping whitespace.
    """
    if not row.word_count:
        return True
    # Real content <=> at least one multi-character token survives stripping.
    return not any(
        len(word.get('text', '').strip()) > 1 for word in row.words
    )
def _heal_row_gaps(
    rows: List[RowGeometry],
    top_bound: int,
    bottom_bound: int,
) -> None:
    """Stretch each row's y/height in place so rows tile the content area.

    Removing empty/artifact rows leaves vertical gaps between the surviving
    content rows.  Each row is mutated to extend to the midpoint of the gap
    towards its neighbours so that OCR crops cover the full available area.
    The first row always snaps to *top_bound*, the last to *bottom_bound*.
    """
    if not rows:
        return
    rows.sort(key=lambda r: r.y)
    # Snapshot (top, bottom) pairs before any mutation happens.
    spans = [(r.y, r.y + r.height) for r in rows]
    last = len(rows) - 1
    for i, row in enumerate(rows):
        top, bottom = spans[i]
        # New top: midpoint of the gap to the previous row (or the bound).
        if i == 0:
            new_top = top_bound
        else:
            prev_bottom = spans[i - 1][1]
            gap_above = top - prev_bottom
            new_top = prev_bottom + gap_above // 2 if gap_above > 1 else top
        # New bottom: midpoint of the gap to the next row (or the bound).
        if i == last:
            new_bottom = bottom_bound
        else:
            next_top = spans[i + 1][0]
            gap_below = next_top - bottom
            new_bottom = bottom + gap_below // 2 if gap_below > 1 else bottom
        row.y = new_top
        row.height = max(5, new_bottom - new_top)
    logger.debug(
        f"_heal_row_gaps: {len(rows)} rows -> y range "
        f"[{rows[0].y}..{rows[-1].y + rows[-1].height}] "
        f"(bounds: top={top_bound}, bottom={bottom_bound})"
    )

View File

@@ -0,0 +1,436 @@
"""
Legacy cell-grid construction (v1) -- DEPRECATED, kept for backward compat.
Extracted from cv_cell_grid.py. Prefer build_cell_grid_v2 for new code.
Lizenz: Apache 2.0 (kommerziell nutzbar)
DATENSCHUTZ: Alle Verarbeitung erfolgt lokal.
"""
import logging
from typing import Any, Dict, List, Optional, Tuple

import numpy as np

from cv_vocab_types import PageRegion, RowGeometry
from cv_ocr_engines import (
    RAPIDOCR_AVAILABLE,
    _assign_row_words_to_columns,
    _clean_cell_text,
    _words_to_reading_order_text,
    # FIX: ocr_region is called by the Tesseract fallback paths below but was
    # dropped from this import during the module split, causing a NameError at
    # runtime as soon as any cell needed cell-crop OCR.  Assumed to be exported
    # by cv_ocr_engines alongside its ocr_region_* siblings.
    ocr_region,
    ocr_region_lighton,
    ocr_region_rapid,
    ocr_region_trocr,
)
from cv_cell_grid_helpers import (
    _MIN_WORD_CONF,
    _compute_cell_padding,
    _ensure_minimum_crop_size,
    _heal_row_gaps,
    _is_artifact_row,
    _select_psm_for_column,
)

logger = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
# _ocr_single_cell — legacy per-cell OCR with multi-level fallback
# ---------------------------------------------------------------------------
def _ocr_single_cell(
    row_idx: int,
    col_idx: int,
    row: RowGeometry,
    col: PageRegion,
    ocr_img: np.ndarray,
    img_bgr: Optional[np.ndarray],
    img_w: int,
    img_h: int,
    use_rapid: bool,
    engine_name: str,
    lang: str,
    lang_map: Dict[str, str],
    preassigned_words: Optional[List[Dict]] = None,
) -> Dict[str, Any]:
    """Populate a single cell (column x row intersection) via word lookup.

    Resolution order for the cell text:
      1. PRIMARY: words preassigned from the full-page Tesseract pass
         (``preassigned_words``), filtered by ``_MIN_WORD_CONF``.
      2. FALLBACK: per-cell crop OCR, run only when the crop actually has
         ink (>0.5% of pixels darker than grey level 180).  Narrow columns
         are upscaled first; the engine depends on engine_name/use_rapid.
      3. SECONDARY FALLBACK: Tesseract PSM 7 (single line) on the crop.
      4. TERTIARY FALLBACK: RapidOCR over the whole row strip, keeping the
         words that horizontally overlap this column (narrow columns only).
    Finally ``_clean_cell_text`` clears cells that contain only OCR noise.

    Returns:
        Cell dict with text, confidence, pixel/percent bboxes and the
        ``ocr_engine`` label that produced the accepted text.
    """
    # Display bbox: exact column x row intersection (no padding)
    disp_x = col.x
    disp_y = row.y
    disp_w = col.width
    disp_h = row.height
    # OCR crop: adaptive padding -- narrow columns get more context
    pad = _compute_cell_padding(col.width, img_w)
    cell_x = max(0, col.x - pad)
    cell_y = max(0, row.y - pad)
    cell_w = min(col.width + 2 * pad, img_w - cell_x)
    cell_h = min(row.height + 2 * pad, img_h - cell_y)
    is_narrow = (col.width / img_w * 100) < 15 if img_w > 0 else False
    # Degenerate geometry -> return an empty cell without attempting OCR.
    if disp_w <= 0 or disp_h <= 0:
        return {
            'cell_id': f"R{row_idx:02d}_C{col_idx}",
            'row_index': row_idx,
            'col_index': col_idx,
            'col_type': col.type,
            'text': '',
            'confidence': 0.0,
            'bbox_px': {'x': col.x, 'y': row.y, 'w': col.width, 'h': row.height},
            'bbox_pct': {
                'x': round(col.x / img_w * 100, 2),
                'y': round(row.y / img_h * 100, 2),
                'w': round(col.width / img_w * 100, 2),
                'h': round(row.height / img_h * 100, 2),
            },
            'ocr_engine': 'word_lookup',
        }
    # --- PRIMARY: Word-lookup from full-page Tesseract ---
    words = preassigned_words if preassigned_words is not None else []
    used_engine = 'word_lookup'
    # Filter low-confidence words
    if words:
        words = [w for w in words if w.get('conf', 0) >= _MIN_WORD_CONF]
    if words:
        # Generous y-tolerance: the whole row height counts as one line.
        y_tol = max(15, row.height)
        text = _words_to_reading_order_text(words, y_tolerance_px=y_tol)
        avg_conf = round(sum(w['conf'] for w in words) / len(words), 1)
    else:
        text = ''
        avg_conf = 0.0
    # --- FALLBACK: Cell-OCR for empty cells ---
    # Only re-OCR when the crop actually contains dark pixels (cheap ink test).
    _run_fallback = False
    if not text.strip() and cell_w > 0 and cell_h > 0:
        if ocr_img is not None:
            crop = ocr_img[cell_y:cell_y + cell_h, cell_x:cell_x + cell_w]
            if crop.size > 0:
                dark_ratio = float(np.count_nonzero(crop < 180)) / crop.size
                _run_fallback = dark_ratio > 0.005
    if _run_fallback:
        # For narrow columns, upscale the crop before OCR
        if is_narrow and ocr_img is not None:
            _crop_slice = ocr_img[cell_y:cell_y + cell_h, cell_x:cell_x + cell_w]
            _upscaled = _ensure_minimum_crop_size(_crop_slice)
            if _upscaled is not _crop_slice:
                _up_h, _up_w = _upscaled.shape[:2]
                _tmp_region = PageRegion(
                    type=col.type, x=0, y=0, width=_up_w, height=_up_h,
                )
                _cell_psm = _select_psm_for_column(col.type, col.width, row.height)
                cell_lang = lang_map.get(col.type, lang)
                # NOTE(review): ocr_region is not imported in this module's
                # header -- presumably exported by cv_ocr_engines; confirm.
                fallback_words = ocr_region(_upscaled, _tmp_region,
                                            lang=cell_lang, psm=_cell_psm)
                # Remap word positions back to original image coordinates
                _sx = cell_w / max(_up_w, 1)
                _sy = cell_h / max(_up_h, 1)
                for _fw in (fallback_words or []):
                    _fw['left'] = int(_fw['left'] * _sx) + cell_x
                    _fw['top'] = int(_fw['top'] * _sy) + cell_y
                    _fw['width'] = int(_fw['width'] * _sx)
                    _fw['height'] = int(_fw['height'] * _sy)
            else:
                # Crop was already large enough -- OCR it in place.
                cell_region = PageRegion(
                    type=col.type, x=cell_x, y=cell_y,
                    width=cell_w, height=cell_h,
                )
                _cell_psm = _select_psm_for_column(col.type, col.width, row.height)
                cell_lang = lang_map.get(col.type, lang)
                fallback_words = ocr_region(ocr_img, cell_region,
                                            lang=cell_lang, psm=_cell_psm)
        else:
            # Wide column: pick the engine the caller resolved earlier.
            cell_region = PageRegion(
                type=col.type,
                x=cell_x, y=cell_y,
                width=cell_w, height=cell_h,
            )
            if engine_name in ("trocr-printed", "trocr-handwritten") and img_bgr is not None:
                fallback_words = ocr_region_trocr(img_bgr, cell_region, handwritten=(engine_name == "trocr-handwritten"))
            elif engine_name == "lighton" and img_bgr is not None:
                fallback_words = ocr_region_lighton(img_bgr, cell_region)
            elif use_rapid and img_bgr is not None:
                fallback_words = ocr_region_rapid(img_bgr, cell_region)
            else:
                _cell_psm = _select_psm_for_column(col.type, col.width, row.height)
                cell_lang = lang_map.get(col.type, lang)
                fallback_words = ocr_region(ocr_img, cell_region,
                                            lang=cell_lang, psm=_cell_psm)
        if fallback_words:
            fallback_words = [w for w in fallback_words if w.get('conf', 0) >= _MIN_WORD_CONF]
        if fallback_words:
            # Tighter y-tolerance here: half the average detected word height.
            fb_avg_h = sum(w['height'] for w in fallback_words) / len(fallback_words)
            fb_y_tol = max(10, int(fb_avg_h * 0.5))
            fb_text = _words_to_reading_order_text(fallback_words, y_tolerance_px=fb_y_tol)
            if fb_text.strip():
                text = fb_text
                avg_conf = round(
                    sum(w['conf'] for w in fallback_words) / len(fallback_words), 1
                )
                used_engine = 'cell_ocr_fallback'
    # --- SECONDARY FALLBACK: PSM=7 (single line) for still-empty cells ---
    if not text.strip() and _run_fallback and not use_rapid:
        _fb_region = PageRegion(
            type=col.type, x=cell_x, y=cell_y,
            width=cell_w, height=cell_h,
        )
        cell_lang = lang_map.get(col.type, lang)
        psm7_words = ocr_region(ocr_img, _fb_region, lang=cell_lang, psm=7)
        if psm7_words:
            psm7_words = [w for w in psm7_words if w.get('conf', 0) >= _MIN_WORD_CONF]
        if psm7_words:
            p7_text = _words_to_reading_order_text(psm7_words, y_tolerance_px=10)
            if p7_text.strip():
                text = p7_text
                avg_conf = round(
                    sum(w['conf'] for w in psm7_words) / len(psm7_words), 1
                )
                used_engine = 'cell_ocr_psm7'
    # --- TERTIARY FALLBACK: Row-strip re-OCR for narrow columns ---
    # RapidOCR over the whole row; keep words overlapping this column >30%.
    if not text.strip() and is_narrow and img_bgr is not None:
        row_region = PageRegion(
            type='_row_strip', x=0, y=row.y,
            width=img_w, height=row.height,
        )
        strip_words = ocr_region_rapid(img_bgr, row_region)
        if strip_words:
            col_left = col.x
            col_right = col.x + col.width
            col_words = []
            for sw in strip_words:
                sw_left = sw.get('left', 0)
                sw_right = sw_left + sw.get('width', 0)
                overlap = max(0, min(sw_right, col_right) - max(sw_left, col_left))
                if overlap > sw.get('width', 1) * 0.3:
                    col_words.append(sw)
            if col_words:
                col_words = [w for w in col_words if w.get('conf', 0) >= _MIN_WORD_CONF]
            if col_words:
                rs_text = _words_to_reading_order_text(col_words, y_tolerance_px=row.height)
                if rs_text.strip():
                    text = rs_text
                    avg_conf = round(
                        sum(w['conf'] for w in col_words) / len(col_words), 1
                    )
                    used_engine = 'row_strip_rapid'
    # --- NOISE FILTER: clear cells that contain only OCR artifacts ---
    if text.strip():
        text = _clean_cell_text(text)
        if not text:
            avg_conf = 0.0
    return {
        'cell_id': f"R{row_idx:02d}_C{col_idx}",
        'row_index': row_idx,
        'col_index': col_idx,
        'col_type': col.type,
        'text': text,
        'confidence': avg_conf,
        'bbox_px': {'x': disp_x, 'y': disp_y, 'w': disp_w, 'h': disp_h},
        'bbox_pct': {
            'x': round(disp_x / img_w * 100, 2),
            'y': round(disp_y / img_h * 100, 2),
            'w': round(disp_w / img_w * 100, 2),
            'h': round(disp_h / img_h * 100, 2),
        },
        'ocr_engine': used_engine,
    }
# ---------------------------------------------------------------------------
# build_cell_grid — legacy grid builder (DEPRECATED)
# ---------------------------------------------------------------------------
def build_cell_grid(
    ocr_img: np.ndarray,
    column_regions: List[PageRegion],
    row_geometries: List[RowGeometry],
    img_w: int,
    img_h: int,
    lang: str = "eng+deu",
    ocr_engine: str = "auto",
    img_bgr: Optional[np.ndarray] = None,
) -> Tuple[List[Dict[str, Any]], List[Dict[str, Any]]]:
    """Generic Cell-Grid: Columns x Rows -> cells with OCR text.

    DEPRECATED: Use build_cell_grid_v2 instead.

    Pipeline:
      1. Resolve the OCR engine ('auto' prefers RapidOCR when available).
      2. Keep only 'content' rows with real words, drop artifact rows,
         then heal the vertical gaps left by removed rows.
      3. OCR every column x row intersection via _ocr_single_cell.
      4. Batch fallback: columns with >= 3 still-empty (but inked) cells
         are re-OCR'd as one vertical strip and matched back per row.
      5. Drop rows whose cells are all empty.

    Args:
        ocr_img: Preprocessed grayscale image used for Tesseract.
        column_regions: Detected columns; layout/margin types are skipped.
        row_geometries: Detected rows; only 'content' rows are used.
        img_w: Image width in pixels.
        img_h: Image height in pixels.
        lang: Default Tesseract language string.
        ocr_engine: 'auto' | 'tesseract' | 'rapid' | 'trocr-printed' |
            'trocr-handwritten' | 'lighton'.
        img_bgr: Original BGR image (required by non-Tesseract engines).

    Returns:
        Tuple (cells, columns_meta).
    """
    # Resolve engine choice
    use_rapid = False
    if ocr_engine in ("trocr-printed", "trocr-handwritten", "lighton"):
        engine_name = ocr_engine
    elif ocr_engine == "auto":
        use_rapid = RAPIDOCR_AVAILABLE and img_bgr is not None
        engine_name = "rapid" if use_rapid else "tesseract"
    elif ocr_engine == "rapid":
        if not RAPIDOCR_AVAILABLE:
            logger.warning("RapidOCR requested but not available, falling back to Tesseract")
        else:
            use_rapid = True
        engine_name = "rapid" if use_rapid else "tesseract"
    else:
        engine_name = "tesseract"
    logger.info(f"build_cell_grid: using OCR engine '{engine_name}'")
    # Filter to content rows only (skip header/footer)
    content_rows = [r for r in row_geometries if r.row_type == 'content']
    if not content_rows:
        logger.warning("build_cell_grid: no content rows found")
        return [], []
    before = len(content_rows)
    content_rows = [r for r in content_rows if r.word_count > 0]
    skipped = before - len(content_rows)
    if skipped > 0:
        logger.info(f"build_cell_grid: skipped {skipped} phantom rows (word_count=0)")
    if not content_rows:
        logger.warning("build_cell_grid: no content rows with words found")
        return [], []
    _skip_types = {'column_ignore', 'header', 'footer', 'margin_top', 'margin_bottom', 'margin_left', 'margin_right'}
    relevant_cols = [c for c in column_regions if c.type not in _skip_types]
    if not relevant_cols:
        logger.warning("build_cell_grid: no usable columns found")
        return [], []
    before_art = len(content_rows)
    content_rows = [r for r in content_rows if not _is_artifact_row(r)]
    artifact_skipped = before_art - len(content_rows)
    if artifact_skipped > 0:
        logger.info(f"build_cell_grid: skipped {artifact_skipped} artifact rows (all single-char words)")
    if not content_rows:
        logger.warning("build_cell_grid: no content rows after artifact filtering")
        return [], []
    # Expand surviving rows into the gaps left by removed ones (column extent
    # bounds the healed area here; v2 uses header/footer bounds instead).
    _heal_row_gaps(
        content_rows,
        top_bound=min(c.y for c in relevant_cols),
        bottom_bound=max(c.y + c.height for c in relevant_cols),
    )
    relevant_cols.sort(key=lambda c: c.x)
    columns_meta = [
        {
            'index': col_idx,
            'type': col.type,
            'x': col.x,
            'width': col.width,
        }
        for col_idx, col in enumerate(relevant_cols)
    ]
    # Per-column Tesseract language overrides.
    lang_map = {
        'column_en': 'eng',
        'column_de': 'deu',
        'column_example': 'eng+deu',
    }
    cells: List[Dict[str, Any]] = []
    for row_idx, row in enumerate(content_rows):
        col_words = _assign_row_words_to_columns(row, relevant_cols)
        for col_idx, col in enumerate(relevant_cols):
            cell = _ocr_single_cell(
                row_idx, col_idx, row, col,
                ocr_img, img_bgr, img_w, img_h,
                use_rapid, engine_name, lang, lang_map,
                preassigned_words=col_words[col_idx],
            )
            cells.append(cell)
    # --- BATCH FALLBACK: re-OCR empty cells by column strip ---
    # Collect cells that are still empty but visibly contain ink.
    empty_by_col: Dict[int, List[int]] = {}
    for ci, cell in enumerate(cells):
        if not cell['text'].strip() and cell.get('ocr_engine') != 'cell_ocr_psm7':
            bpx = cell['bbox_px']
            x, y, w, h = bpx['x'], bpx['y'], bpx['w'], bpx['h']
            if w > 0 and h > 0 and ocr_img is not None:
                crop = ocr_img[y:y + h, x:x + w]
                if crop.size > 0:
                    dark_ratio = float(np.count_nonzero(crop < 180)) / crop.size
                    if dark_ratio > 0.005:
                        empty_by_col.setdefault(cell['col_index'], []).append(ci)
    for col_idx, cell_indices in empty_by_col.items():
        # Only worth a strip OCR when the column has several empty cells.
        if len(cell_indices) < 3:
            continue
        min_y = min(cells[ci]['bbox_px']['y'] for ci in cell_indices)
        max_y_h = max(cells[ci]['bbox_px']['y'] + cells[ci]['bbox_px']['h'] for ci in cell_indices)
        col_x = cells[cell_indices[0]]['bbox_px']['x']
        col_w = cells[cell_indices[0]]['bbox_px']['w']
        strip_region = PageRegion(
            type=relevant_cols[col_idx].type,
            x=col_x, y=min_y,
            width=col_w, height=max_y_h - min_y,
        )
        strip_lang = lang_map.get(relevant_cols[col_idx].type, lang)
        if engine_name in ("trocr-printed", "trocr-handwritten") and img_bgr is not None:
            strip_words = ocr_region_trocr(img_bgr, strip_region, handwritten=(engine_name == "trocr-handwritten"))
        elif engine_name == "lighton" and img_bgr is not None:
            strip_words = ocr_region_lighton(img_bgr, strip_region)
        elif use_rapid and img_bgr is not None:
            strip_words = ocr_region_rapid(img_bgr, strip_region)
        else:
            strip_words = ocr_region(ocr_img, strip_region, lang=strip_lang, psm=6)
        if not strip_words:
            continue
        # Consistency fix: use the shared module constant instead of the
        # previously hard-coded 30 (same value, no behavior change).
        strip_words = [w for w in strip_words if w.get('conf', 0) >= _MIN_WORD_CONF]
        if not strip_words:
            continue
        for ci in cell_indices:
            cell_y = cells[ci]['bbox_px']['y']
            cell_h = cells[ci]['bbox_px']['h']
            cell_mid_y = cell_y + cell_h / 2
            # Match strip words back to cells by vertical-center proximity.
            matched_words = [
                w for w in strip_words
                if abs((w['top'] + w['height'] / 2) - cell_mid_y) < cell_h * 0.8
            ]
            if matched_words:
                matched_words.sort(key=lambda w: w['left'])
                batch_text = ' '.join(w['text'] for w in matched_words)
                batch_text = _clean_cell_text(batch_text)
                if batch_text.strip():
                    cells[ci]['text'] = batch_text
                    cells[ci]['confidence'] = round(
                        sum(w['conf'] for w in matched_words) / len(matched_words), 1
                    )
                    cells[ci]['ocr_engine'] = 'batch_column_ocr'
        batch_filled = sum(1 for ci in cell_indices if cells[ci]['text'].strip())
        if batch_filled > 0:
            logger.info(
                f"build_cell_grid: batch OCR filled {batch_filled}/{len(cell_indices)} "
                f"empty cells in column {col_idx}"
            )
    # Remove all-empty rows
    rows_with_text: set = set()
    for cell in cells:
        if cell['text'].strip():
            rows_with_text.add(cell['row_index'])
    before_filter = len(cells)
    cells = [c for c in cells if c['row_index'] in rows_with_text]
    empty_rows_removed = (before_filter - len(cells)) // max(len(relevant_cols), 1)
    if empty_rows_removed > 0:
        logger.info(f"build_cell_grid: removed {empty_rows_removed} all-empty rows after OCR")
    logger.info(f"build_cell_grid: {len(cells)} cells from "
                f"{len(content_rows)} rows x {len(relevant_cols)} columns, "
                f"engine={engine_name}")
    return cells, columns_meta

View File

@@ -0,0 +1,235 @@
"""
Row-merging logic for vocabulary entries (phonetic, wrapped, continuation rows).
Extracted from cv_cell_grid.py.
Lizenz: Apache 2.0 (kommerziell nutzbar)
DATENSCHUTZ: Alle Verarbeitung erfolgt lokal.
"""
import logging
import re
from typing import Any, Dict, List
from cv_ocr_engines import _RE_ALPHA
logger = logging.getLogger(__name__)
# Regex: line starts with phonetic bracket content only (no real word before it).
# NOTE(review): this pattern is not referenced anywhere in this module --
# _is_phonetic_only_text implements the check procedurally instead.  Candidate
# for removal; confirm no external module imports it first.
_PHONETIC_ONLY_RE = re.compile(
    r'''^\s*[\[\('"]*[^\]]*[\])\s]*$'''
)
def _is_phonetic_only_text(text: str) -> bool:
"""Check if text consists only of phonetic transcription.
Phonetic-only patterns:
['mani serva] -> True
[dance] -> True
["a:mand] -> True
almond ['a:mand] -> False (has real word before bracket)
Mandel -> False
"""
t = text.strip()
if not t:
return False
# Must contain at least one bracket
if '[' not in t and ']' not in t:
return False
# Remove all bracket content and surrounding punctuation/whitespace
without_brackets = re.sub(r"\[.*?\]", '', t)
without_brackets = re.sub(r"[\[\]'\"()\s]", '', without_brackets)
# If nothing meaningful remains, it's phonetic-only
alpha_remaining = ''.join(_RE_ALPHA.findall(without_brackets))
return len(alpha_remaining) < 2
def _merge_phonetic_continuation_rows(
entries: List[Dict[str, Any]],
) -> List[Dict[str, Any]]:
"""Merge rows that contain only phonetic transcription into previous entry.
In dictionary pages, phonetic transcription sometimes wraps to the next
row. E.g.:
Row 28: EN="it's a money-saver" DE="es spart Kosten"
Row 29: EN="['mani serva]" DE=""
Row 29 is phonetic-only -> merge into row 28's EN field.
"""
if len(entries) < 2:
return entries
merged: List[Dict[str, Any]] = []
for entry in entries:
en = (entry.get('english') or '').strip()
de = (entry.get('german') or '').strip()
ex = (entry.get('example') or '').strip()
# Check if this entry is phonetic-only (EN has only phonetics, DE empty)
if merged and _is_phonetic_only_text(en) and not de:
prev = merged[-1]
prev_en = (prev.get('english') or '').strip()
# Append phonetic to previous entry's EN
if prev_en:
prev['english'] = prev_en + ' ' + en
else:
prev['english'] = en
# If there was an example, append to previous too
if ex:
prev_ex = (prev.get('example') or '').strip()
prev['example'] = (prev_ex + ' ' + ex).strip() if prev_ex else ex
logger.debug(
f"Merged phonetic row {entry.get('row_index')} "
f"into previous entry: {prev['english']!r}"
)
continue
merged.append(entry)
return merged
def _merge_wrapped_rows(
entries: List[Dict[str, Any]],
) -> List[Dict[str, Any]]:
"""Merge rows where the primary column (EN) is empty -- cell wrap continuation.
In textbook vocabulary tables, columns are often narrow, so the author
wraps text within a cell. OCR treats each physical line as a separate row.
The key indicator: if the EN column is empty but DE/example have text,
this row is a continuation of the previous row's cells.
Example (original textbook has ONE row):
Row 2: EN="take part (in)" DE="teilnehmen (an), mitmachen" EX="More than 200 singers took"
Row 3: EN="" DE="(bei)" EX="part in the concert."
-> Merged: EN="take part (in)" DE="teilnehmen (an), mitmachen (bei)" EX="..."
Also handles the reverse case: DE empty but EN has text (wrap in EN column).
"""
if len(entries) < 2:
return entries
merged: List[Dict[str, Any]] = []
for entry in entries:
en = (entry.get('english') or '').strip()
de = (entry.get('german') or '').strip()
ex = (entry.get('example') or '').strip()
if not merged:
merged.append(entry)
continue
prev = merged[-1]
prev_en = (prev.get('english') or '').strip()
prev_de = (prev.get('german') or '').strip()
prev_ex = (prev.get('example') or '').strip()
# Case 1: EN is empty -> continuation of previous row
if not en and (de or ex) and prev_en:
if de:
if prev_de.endswith(','):
sep = ' '
elif prev_de.endswith(('-', '(')):
sep = ''
else:
sep = ' '
prev['german'] = (prev_de + sep + de).strip()
if ex:
sep = ' ' if prev_ex else ''
prev['example'] = (prev_ex + sep + ex).strip()
logger.debug(
f"Merged wrapped row {entry.get('row_index')} into previous "
f"(empty EN): DE={prev['german']!r}, EX={prev.get('example', '')!r}"
)
continue
# Case 2: DE is empty, EN has text that looks like continuation
if en and not de and prev_de:
is_paren = en.startswith('(')
first_alpha = next((c for c in en if c.isalpha()), '')
starts_lower = first_alpha and first_alpha.islower()
if (is_paren or starts_lower) and len(en.split()) < 5:
sep = ' ' if prev_en and not prev_en.endswith((',', '-', '(')) else ''
prev['english'] = (prev_en + sep + en).strip()
if ex:
sep2 = ' ' if prev_ex else ''
prev['example'] = (prev_ex + sep2 + ex).strip()
logger.debug(
f"Merged wrapped row {entry.get('row_index')} into previous "
f"(empty DE): EN={prev['english']!r}"
)
continue
merged.append(entry)
if len(merged) < len(entries):
logger.info(
f"_merge_wrapped_rows: merged {len(entries) - len(merged)} "
f"continuation rows ({len(entries)} -> {len(merged)})"
)
return merged
def _merge_continuation_rows(
entries: List[Dict[str, Any]],
) -> List[Dict[str, Any]]:
"""Merge multi-line vocabulary entries where text wraps to the next row.
A row is a continuation of the previous entry when:
- EN has text, but DE is empty
- EN starts with a lowercase letter (not a new vocab entry)
- Previous entry's EN does NOT end with a sentence terminator (.!?)
- The continuation text has fewer than 4 words (not an example sentence)
- The row was not already merged as phonetic
Example:
Row 5: EN="to put up" DE="aufstellen"
Row 6: EN="with sth." DE=""
-> Merged: EN="to put up with sth." DE="aufstellen"
"""
if len(entries) < 2:
return entries
merged: List[Dict[str, Any]] = []
for entry in entries:
en = (entry.get('english') or '').strip()
de = (entry.get('german') or '').strip()
if merged and en and not de:
# Check: not phonetic (already handled)
if _is_phonetic_only_text(en):
merged.append(entry)
continue
# Check: starts with lowercase
first_alpha = next((c for c in en if c.isalpha()), '')
starts_lower = first_alpha and first_alpha.islower()
# Check: fewer than 4 words (not an example sentence)
word_count = len(en.split())
is_short = word_count < 4
# Check: previous entry doesn't end with sentence terminator
prev = merged[-1]
prev_en = (prev.get('english') or '').strip()
prev_ends_sentence = prev_en and prev_en[-1] in '.!?'
if starts_lower and is_short and not prev_ends_sentence:
# Merge into previous entry
prev['english'] = (prev_en + ' ' + en).strip()
# Merge example if present
ex = (entry.get('example') or '').strip()
if ex:
prev_ex = (prev.get('example') or '').strip()
prev['example'] = (prev_ex + ' ' + ex).strip() if prev_ex else ex
logger.debug(
f"Merged continuation row {entry.get('row_index')} "
f"into previous entry: {prev['english']!r}"
)
continue
merged.append(entry)
return merged

View File

@@ -0,0 +1,217 @@
"""
Streaming variants of cell-grid builders (v2 + legacy).
Extracted from cv_cell_grid.py. These yield cells one-by-one as OCR'd,
useful for progress reporting.
Lizenz: Apache 2.0 (kommerziell nutzbar)
DATENSCHUTZ: Alle Verarbeitung erfolgt lokal.
"""
import logging
from typing import Any, Dict, Generator, List, Optional, Tuple
import numpy as np
from cv_vocab_types import PageRegion, RowGeometry
from cv_ocr_engines import (
RAPIDOCR_AVAILABLE,
_assign_row_words_to_columns,
)
from cv_cell_grid_helpers import (
_heal_row_gaps,
_is_artifact_row,
)
from cv_cell_grid_build import _ocr_cell_crop
from cv_cell_grid_legacy import _ocr_single_cell
logger = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
# build_cell_grid_v2_streaming
# ---------------------------------------------------------------------------
def build_cell_grid_v2_streaming(
    ocr_img: np.ndarray,
    column_regions: List[PageRegion],
    row_geometries: List[RowGeometry],
    img_w: int,
    img_h: int,
    lang: str = "eng+deu",
    ocr_engine: str = "auto",
    img_bgr: Optional[np.ndarray] = None,
) -> Generator[Tuple[Dict[str, Any], List[Dict[str, Any]], int], None, None]:
    """Streaming variant of build_cell_grid_v2 -- yields each cell as OCR'd.

    Differences from the batch v2 builder:
      * 'auto' always resolves to Tesseract here (no RapidOCR probing).
      * Cells are OCR'd strictly sequentially in row-major order, so
        callers can report progress as cells_done / total_cells.
      * Silent early exit: when no usable rows or columns remain after
        filtering, the generator simply yields nothing (no warning logs).

    Yields:
        (cell_dict, columns_meta, total_cells)
    """
    use_rapid = False
    if ocr_engine in ("trocr-printed", "trocr-handwritten", "lighton"):
        engine_name = ocr_engine
    elif ocr_engine == "auto":
        engine_name = "tesseract"
    elif ocr_engine == "rapid":
        if not RAPIDOCR_AVAILABLE:
            logger.warning("RapidOCR requested but not available, falling back to Tesseract")
        else:
            use_rapid = True
        engine_name = "rapid" if use_rapid else "tesseract"
    else:
        engine_name = "tesseract"
    # Keep only content rows that carry at least one detected word.
    content_rows = [r for r in row_geometries if r.row_type == 'content']
    if not content_rows:
        return
    content_rows = [r for r in content_rows if r.word_count > 0]
    if not content_rows:
        return
    # Drop layout/margin columns -- only text-bearing columns produce cells.
    _skip_types = {'column_ignore', 'header', 'footer', 'margin_top',
                   'margin_bottom', 'margin_left', 'margin_right'}
    relevant_cols = [c for c in column_regions if c.type not in _skip_types]
    if not relevant_cols:
        return
    # Remove rows made up solely of single-character artifacts.
    content_rows = [r for r in content_rows if not _is_artifact_row(r)]
    if not content_rows:
        return
    # Use header/footer boundaries for heal_row_gaps
    content_rows.sort(key=lambda r: r.y)
    header_rows = [r for r in row_geometries if r.row_type == 'header']
    footer_rows = [r for r in row_geometries if r.row_type == 'footer']
    if header_rows:
        top_bound = max(r.y + r.height for r in header_rows)
    else:
        top_bound = content_rows[0].y
    if footer_rows:
        bottom_bound = min(r.y for r in footer_rows)
    else:
        bottom_bound = content_rows[-1].y + content_rows[-1].height
    _heal_row_gaps(content_rows, top_bound=top_bound, bottom_bound=bottom_bound)
    relevant_cols.sort(key=lambda c: c.x)
    columns_meta = [
        {'index': ci, 'type': c.type, 'x': c.x, 'width': c.width}
        for ci, c in enumerate(relevant_cols)
    ]
    # Per-column Tesseract language overrides.
    lang_map = {
        'column_en': 'eng',
        'column_de': 'deu',
        'column_example': 'eng+deu',
    }
    total_cells = len(content_rows) * len(relevant_cols)
    for row_idx, row in enumerate(content_rows):
        for col_idx, col in enumerate(relevant_cols):
            cell = _ocr_cell_crop(
                row_idx, col_idx, row, col,
                ocr_img, img_bgr, img_w, img_h,
                engine_name, lang, lang_map,
            )
            yield cell, columns_meta, total_cells
# ---------------------------------------------------------------------------
# build_cell_grid_streaming — legacy streaming variant
# ---------------------------------------------------------------------------
def build_cell_grid_streaming(
    ocr_img: np.ndarray,
    column_regions: List[PageRegion],
    row_geometries: List[RowGeometry],
    img_w: int,
    img_h: int,
    lang: str = "eng+deu",
    ocr_engine: str = "auto",
    img_bgr: Optional[np.ndarray] = None,
) -> Generator[Tuple[Dict[str, Any], List[Dict[str, Any]], int], None, None]:
    """Like build_cell_grid(), but yields each cell as it is OCR'd.

    DEPRECATED: Use build_cell_grid_v2_streaming instead.

    Mirrors the legacy batch builder's per-cell pipeline (word lookup plus
    the _ocr_single_cell fallback chain) but omits the batch column re-OCR
    and the all-empty-row removal passes, since cells are emitted as soon
    as they are produced.

    Yields:
        (cell_dict, columns_meta, total_cells) for each cell.
    """
    # Resolve the OCR engine ('auto' prefers RapidOCR when available).
    use_rapid = False
    if ocr_engine in ("trocr-printed", "trocr-handwritten", "lighton"):
        engine_name = ocr_engine
    elif ocr_engine == "auto":
        use_rapid = RAPIDOCR_AVAILABLE and img_bgr is not None
        engine_name = "rapid" if use_rapid else "tesseract"
    elif ocr_engine == "rapid":
        if not RAPIDOCR_AVAILABLE:
            logger.warning("RapidOCR requested but not available, falling back to Tesseract")
        else:
            use_rapid = True
        engine_name = "rapid" if use_rapid else "tesseract"
    else:
        engine_name = "tesseract"
    # Keep only content rows that carry at least one detected word.
    content_rows = [r for r in row_geometries if r.row_type == 'content']
    if not content_rows:
        return
    before = len(content_rows)
    content_rows = [r for r in content_rows if r.word_count > 0]
    skipped = before - len(content_rows)
    if skipped > 0:
        logger.info(f"build_cell_grid_streaming: skipped {skipped} phantom rows (word_count=0)")
    if not content_rows:
        return
    # Drop layout/margin columns -- only text-bearing columns produce cells.
    _skip_types = {'column_ignore', 'header', 'footer', 'margin_top', 'margin_bottom', 'margin_left', 'margin_right'}
    relevant_cols = [c for c in column_regions if c.type not in _skip_types]
    if not relevant_cols:
        return
    # Remove rows made up solely of single-character artifacts.
    before_art = len(content_rows)
    content_rows = [r for r in content_rows if not _is_artifact_row(r)]
    artifact_skipped = before_art - len(content_rows)
    if artifact_skipped > 0:
        logger.info(f"build_cell_grid_streaming: skipped {artifact_skipped} artifact rows")
    if not content_rows:
        return
    # Legacy behavior: heal gaps using the column extent as vertical bounds
    # (the v2 streaming builder uses header/footer bounds instead).
    _heal_row_gaps(
        content_rows,
        top_bound=min(c.y for c in relevant_cols),
        bottom_bound=max(c.y + c.height for c in relevant_cols),
    )
    relevant_cols.sort(key=lambda c: c.x)
    columns_meta = [
        {
            'index': col_idx,
            'type': col.type,
            'x': col.x,
            'width': col.width,
        }
        for col_idx, col in enumerate(relevant_cols)
    ]
    # Per-column Tesseract language overrides.
    lang_map = {
        'column_en': 'eng',
        'column_de': 'deu',
        'column_example': 'eng+deu',
    }
    total_cells = len(content_rows) * len(relevant_cols)
    for row_idx, row in enumerate(content_rows):
        # Pre-assign the full-page OCR words to columns once per row.
        col_words = _assign_row_words_to_columns(row, relevant_cols)
        for col_idx, col in enumerate(relevant_cols):
            cell = _ocr_single_cell(
                row_idx, col_idx, row, col,
                ocr_img, img_bgr, img_w, img_h,
                use_rapid, engine_name, lang, lang_map,
                preassigned_words=col_words[col_idx],
            )
            yield cell, columns_meta, total_cells

View File

@@ -0,0 +1,200 @@
"""
Vocabulary extraction: cells -> vocab entries, and build_word_grid wrapper.
Extracted from cv_cell_grid.py.
Lizenz: Apache 2.0 (kommerziell nutzbar)
DATENSCHUTZ: Alle Verarbeitung erfolgt lokal.
"""
import logging
from typing import Any, Dict, List
from cv_ocr_engines import (
_attach_example_sentences,
_fix_phonetic_brackets,
_split_comma_entries,
)
from cv_cell_grid_legacy import build_cell_grid
from cv_cell_grid_merge import (
_merge_continuation_rows,
_merge_phonetic_continuation_rows,
_merge_wrapped_rows,
)
logger = logging.getLogger(__name__)
def _cells_to_vocab_entries(
cells: List[Dict[str, Any]],
columns_meta: List[Dict[str, Any]],
) -> List[Dict[str, Any]]:
"""Map generic cells to vocab entries with english/german/example fields.
Groups cells by row_index, maps col_type -> field name, and produces
one entry per row (only rows with at least one non-empty field).
"""
col_type_to_field = {
'column_en': 'english',
'column_de': 'german',
'column_example': 'example',
'page_ref': 'source_page',
'column_marker': 'marker',
'column_text': 'text', # generic single-column (box sub-sessions)
}
bbox_key_map = {
'column_en': 'bbox_en',
'column_de': 'bbox_de',
'column_example': 'bbox_ex',
'page_ref': 'bbox_ref',
'column_marker': 'bbox_marker',
'column_text': 'bbox_text',
}
# Group cells by row_index
rows: Dict[int, List[Dict]] = {}
for cell in cells:
ri = cell['row_index']
rows.setdefault(ri, []).append(cell)
entries: List[Dict[str, Any]] = []
for row_idx in sorted(rows.keys()):
row_cells = rows[row_idx]
entry: Dict[str, Any] = {
'row_index': row_idx,
'english': '',
'german': '',
'example': '',
'text': '', # generic single-column (box sub-sessions)
'source_page': '',
'marker': '',
'confidence': 0.0,
'bbox': None,
'bbox_en': None,
'bbox_de': None,
'bbox_ex': None,
'bbox_ref': None,
'bbox_marker': None,
'bbox_text': None,
'ocr_engine': row_cells[0].get('ocr_engine', '') if row_cells else '',
}
confidences = []
for cell in row_cells:
col_type = cell['col_type']
field = col_type_to_field.get(col_type)
if field:
entry[field] = cell['text']
bbox_field = bbox_key_map.get(col_type)
if bbox_field:
entry[bbox_field] = cell['bbox_pct']
if cell['confidence'] > 0:
confidences.append(cell['confidence'])
# Compute row-level bbox as union of all cell bboxes
all_bboxes = [c['bbox_pct'] for c in row_cells if c.get('bbox_pct')]
if all_bboxes:
min_x = min(b['x'] for b in all_bboxes)
min_y = min(b['y'] for b in all_bboxes)
max_x2 = max(b['x'] + b['w'] for b in all_bboxes)
max_y2 = max(b['y'] + b['h'] for b in all_bboxes)
entry['bbox'] = {
'x': round(min_x, 2),
'y': round(min_y, 2),
'w': round(max_x2 - min_x, 2),
'h': round(max_y2 - min_y, 2),
}
entry['confidence'] = round(
sum(confidences) / len(confidences), 1
) if confidences else 0.0
# Only include if at least one mapped field has text
has_content = any(
entry.get(f)
for f in col_type_to_field.values()
)
if has_content:
entries.append(entry)
return entries
def build_word_grid(
    ocr_img,
    column_regions,
    row_geometries,
    img_w: int,
    img_h: int,
    lang: str = "eng+deu",
    ocr_engine: str = "auto",
    img_bgr=None,
    pronunciation: str = "british",
) -> List[Dict[str, Any]]:
    """Vocab-specific: Cell-Grid + Vocab-Mapping + Post-Processing.

    Wrapper around build_cell_grid() that adds vocabulary-specific logic:
    - Maps cells to english/german/example entries
    - Applies character confusion fixes, IPA lookup, comma splitting, etc.
    - Falls back to returning raw cells if no vocab columns detected.

    Args:
        ocr_img: Binarized full-page image (for Tesseract).
        column_regions: Classified columns from Step 3.
        row_geometries: Rows from Step 4.
        img_w, img_h: Image dimensions.
        lang: Default Tesseract language.
        ocr_engine: 'tesseract', 'rapid', or 'auto'.
        img_bgr: BGR color image (required for RapidOCR).
        pronunciation: 'british' or 'american' for IPA lookup.

    Returns:
        List of entry dicts with english/german/example text and bbox info (percent).
    """
    cells, columns_meta = build_cell_grid(
        ocr_img, column_regions, row_geometries, img_w, img_h,
        lang=lang, ocr_engine=ocr_engine, img_bgr=img_bgr,
    )
    if not cells:
        return []
    # Check if vocab layout is present
    col_types = {c['type'] for c in columns_meta}
    if not (col_types & {'column_en', 'column_de'}):
        logger.info("build_word_grid: no vocab columns -- returning raw cells")
        return cells
    # Vocab mapping: cells -> entries
    entries = _cells_to_vocab_entries(cells, columns_meta)
    # --- Post-processing pipeline (deterministic, no LLM) ---
    n_raw = len(entries)
    # 0. Merge cell-wrap continuation rows (empty primary column = text wrap)
    entries = _merge_wrapped_rows(entries)
    # 0a. Merge phonetic-only continuation rows into previous entry
    entries = _merge_phonetic_continuation_rows(entries)
    # 0b. Merge multi-line continuation rows (lowercase EN, empty DE)
    entries = _merge_continuation_rows(entries)
    # 1. Character confusion (| -> I, 1 -> I, 8 -> B) is now run in
    #    llm_review_entries_streaming so changes are visible to the user in Step 6.
    # 2. Replace OCR'd phonetics with dictionary IPA
    entries = _fix_phonetic_brackets(entries, pronunciation=pronunciation)
    # 3. Split comma-separated word forms (break, broke, broken -> 3 entries)
    entries = _split_comma_entries(entries)
    # 4. Attach example sentences (rows without DE -> examples for preceding entry)
    entries = _attach_example_sentences(entries)
    engine_name = cells[0].get('ocr_engine', 'unknown') if cells else 'unknown'
    # BUGFIX: the previous message interpolated len(entries) for both the
    # "raw" and "after" counts, so the before/after statistic was useless.
    logger.info(f"build_word_grid: {n_raw} raw entries -> {len(entries)} "
                f"after post-processing (engine={engine_name})")
    return entries

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,437 @@
"""
CV Preprocessing Deskew — Rotation correction via Hough lines, word alignment, and iterative projection.
Lizenz: Apache 2.0 (kommerziell nutzbar)
DATENSCHUTZ: Alle Verarbeitung erfolgt lokal.
"""
import logging
from collections import defaultdict
from typing import Any, Dict, Tuple
import numpy as np
from cv_vocab_types import (
CV2_AVAILABLE,
TESSERACT_AVAILABLE,
)
logger = logging.getLogger(__name__)
try:
import cv2
except ImportError:
cv2 = None # type: ignore[assignment]
try:
import pytesseract
from PIL import Image
except ImportError:
pytesseract = None # type: ignore[assignment]
Image = None # type: ignore[assignment,misc]
# =============================================================================
# Deskew via Hough Lines
# =============================================================================
def deskew_image(img: np.ndarray) -> Tuple[np.ndarray, float]:
    """Correct page rotation via Hough line detection.

    Detects near-horizontal line segments in the inverted-Otsu binary
    image, takes the median of their angles (clamped to +/-5 degrees),
    and rotates the page back by that amount.

    Args:
        img: BGR image.

    Returns:
        Tuple of (corrected image, detected angle in degrees).
    """
    grayscale = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
    _, inverted = cv2.threshold(
        grayscale, 0, 255, cv2.THRESH_BINARY_INV + cv2.THRESH_OTSU)
    segments = cv2.HoughLinesP(
        inverted, 1, np.pi / 180, threshold=100,
        minLineLength=img.shape[1] // 4, maxLineGap=20)
    # Too few segments -> not enough evidence to rotate safely.
    if segments is None or len(segments) < 3:
        return img, 0.0
    candidate_angles = []
    for seg in segments:
        x1, y1, x2, y2 = seg[0]
        seg_angle = np.degrees(np.arctan2(y2 - y1, x2 - x1))
        # Only near-horizontal segments (<15 deg) vote on the skew angle.
        if abs(seg_angle) < 15:
            candidate_angles.append(seg_angle)
    if not candidate_angles:
        return img, 0.0
    angle = float(np.median(candidate_angles))
    # Clamp implausibly large rotations to +/-5 degrees.
    if abs(angle) > 5.0:
        angle = 5.0 * np.sign(angle)
    # Below 0.1 degrees a rotation is not worth the interpolation blur.
    if abs(angle) < 0.1:
        return img, 0.0
    h, w = img.shape[:2]
    rotation = cv2.getRotationMatrix2D((w // 2, h // 2), angle, 1.0)
    corrected = cv2.warpAffine(img, rotation, (w, h),
                               flags=cv2.INTER_LINEAR,
                               borderMode=cv2.BORDER_REPLICATE)
    logger.info(f"Deskew: corrected {angle:.2f}\u00b0 rotation")
    return corrected, angle
# =============================================================================
# Deskew via Word Alignment
# =============================================================================
def deskew_image_by_word_alignment(
    image_data: bytes,
    lang: str = "eng+deu",
    downscale_factor: float = 0.5,
) -> Tuple[bytes, float]:
    """Correct rotation by fitting a line through left-most word starts per text line.

    More robust than Hough-based deskew for vocabulary worksheets where text lines
    have consistent left-alignment: on a rotated page the left margin drifts
    linearly with y, and the drift slope gives the rotation angle.

    Args:
        image_data: Raw image bytes (PNG/JPEG).
        lang: Tesseract language string for the quick pass.
        downscale_factor: Shrink factor for the quick Tesseract pass (0.5 = 50%).

    Returns:
        Tuple of (rotated image as PNG bytes, detected angle in degrees).
        On decode/OCR failure or when no correction is warranted, the
        original bytes are returned unchanged with angle 0.0.
    """
    # Requires both OpenCV (decode/rotate) and Tesseract (word positions).
    if not CV2_AVAILABLE or not TESSERACT_AVAILABLE:
        return image_data, 0.0
    img_array = np.frombuffer(image_data, dtype=np.uint8)
    img = cv2.imdecode(img_array, cv2.IMREAD_COLOR)
    if img is None:
        logger.warning("deskew_by_word_alignment: could not decode image")
        return image_data, 0.0
    orig_h, orig_w = img.shape[:2]
    # Quick OCR pass on a downscaled copy: only word geometry is needed,
    # so full-resolution recognition accuracy is not required.
    small_w = int(orig_w * downscale_factor)
    small_h = int(orig_h * downscale_factor)
    small = cv2.resize(img, (small_w, small_h), interpolation=cv2.INTER_AREA)
    pil_small = Image.fromarray(cv2.cvtColor(small, cv2.COLOR_BGR2RGB))
    try:
        data = pytesseract.image_to_data(
            pil_small, lang=lang, config="--psm 6 --oem 3",
            output_type=pytesseract.Output.DICT,
        )
    except Exception as e:
        logger.warning(f"deskew_by_word_alignment: Tesseract failed: {e}")
        return image_data, 0.0
    # Group words by their Tesseract text line (block/paragraph/line triple),
    # dropping empty and low-confidence (<20) detections.
    line_groups: Dict[tuple, list] = defaultdict(list)
    for i in range(len(data["text"])):
        text = (data["text"][i] or "").strip()
        conf = int(data["conf"][i])
        if not text or conf < 20:
            continue
        key = (data["block_num"][i], data["par_num"][i], data["line_num"][i])
        line_groups[key].append(i)
    # Need a minimum number of text lines for a meaningful fit.
    if len(line_groups) < 5:
        logger.info(f"deskew_by_word_alignment: only {len(line_groups)} lines, skipping")
        return image_data, 0.0
    # One sample point per line: the left edge of its left-most word, paired
    # with that word's vertical centre (scaled back to full resolution).
    scale = 1.0 / downscale_factor
    points = []
    for key, indices in line_groups.items():
        best_idx = min(indices, key=lambda i: data["left"][i])
        lx = data["left"][best_idx] * scale
        top = data["top"][best_idx] * scale
        h = data["height"][best_idx] * scale
        cy = top + h / 2.0
        points.append((lx, cy))
    xs = np.array([p[0] for p in points])
    ys = np.array([p[1] for p in points])
    # Outlier filter: keep only lines whose left edge lies within 3% of the
    # image width around the median (indented lines would skew the fit).
    median_x = float(np.median(xs))
    tolerance = orig_w * 0.03
    mask = np.abs(xs - median_x) <= tolerance
    filtered_xs = xs[mask]
    filtered_ys = ys[mask]
    if len(filtered_xs) < 5:
        logger.info(f"deskew_by_word_alignment: only {len(filtered_xs)} aligned points after filter, skipping")
        return image_data, 0.0
    # Fit x = slope*y + b: the left margin's drift over y is the rotation.
    coeffs = np.polyfit(filtered_ys, filtered_xs, 1)
    slope = coeffs[0]
    angle_rad = np.arctan(slope)
    angle_deg = float(np.degrees(angle_rad))
    # Clamp to +/-5 degrees -- larger detections are almost certainly bogus.
    angle_deg = max(-5.0, min(5.0, angle_deg))
    logger.info(f"deskew_by_word_alignment: detected {angle_deg:.2f}\u00b0 from {len(filtered_xs)} points "
                f"(total lines: {len(line_groups)})")
    # Below 0.05 degrees a rotation would only add interpolation blur.
    if abs(angle_deg) < 0.05:
        return image_data, 0.0
    center = (orig_w // 2, orig_h // 2)
    M = cv2.getRotationMatrix2D(center, angle_deg, 1.0)
    rotated = cv2.warpAffine(img, M, (orig_w, orig_h),
                             flags=cv2.INTER_LINEAR,
                             borderMode=cv2.BORDER_REPLICATE)
    success, png_buf = cv2.imencode(".png", rotated)
    if not success:
        logger.warning("deskew_by_word_alignment: PNG encoding failed")
        return image_data, 0.0
    return png_buf.tobytes(), angle_deg
# =============================================================================
# Projection Gradient Scoring
# =============================================================================
def _projection_gradient_score(profile: np.ndarray) -> float:
"""Score a projection profile by the L2-norm of its first derivative."""
diff = np.diff(profile)
return float(np.sum(diff * diff))
# =============================================================================
# Iterative Deskew (Vertical-Edge Projection)
# =============================================================================
def deskew_image_iterative(
    img: np.ndarray,
    coarse_range: float = 5.0,
    coarse_step: float = 0.1,
    fine_range: float = 0.15,
    fine_step: float = 0.02,
) -> Tuple[np.ndarray, float, Dict[str, Any]]:
    """Iterative deskew using vertical-edge projection optimisation.

    Rotates a vertical-edge map through a coarse-then-fine angle sweep and
    picks the angle whose column-sum profile scores highest with
    _projection_gradient_score() (sharpest transitions).

    Args:
        img: BGR image (full resolution).
        coarse_range: half-range in degrees for the coarse sweep.
        coarse_step: step size in degrees for the coarse sweep.
        fine_range: half-range around the coarse winner for the fine sweep.
        fine_step: step size in degrees for the fine sweep.

    Returns:
        (rotated_bgr, angle_degrees, debug_dict)
    """
    h, w = img.shape[:2]
    debug: Dict[str, Any] = {}
    gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
    # Work on the central crop (15-85% height, 10-90% width) so page
    # borders and scanner shadows do not dominate the edge map.
    y_lo, y_hi = int(h * 0.15), int(h * 0.85)
    x_lo, x_hi = int(w * 0.10), int(w * 0.90)
    gray_crop = gray[y_lo:y_hi, x_lo:x_hi]
    # Vertical edges only (x-gradient); normalised to 0-255.
    sobel_x = cv2.Sobel(gray_crop, cv2.CV_64F, 1, 0, ksize=3)
    edges = np.abs(sobel_x)
    edge_max = edges.max()
    if edge_max > 0:
        edges = (edges / edge_max * 255).astype(np.uint8)
    else:
        return img, 0.0, {"error": "no edges detected"}
    crop_h, crop_w = edges.shape[:2]
    crop_center = (crop_w // 2, crop_h // 2)
    # Trim ~3% at each side after every trial rotation to discard
    # replicated-border artefacts introduced by warpAffine.
    trim_y = max(4, int(crop_h * 0.03))
    trim_x = max(4, int(crop_w * 0.03))
    def _sweep_edges(angles: np.ndarray) -> list:
        # Score each candidate angle: rotate the edge map, column-sum it,
        # and measure how sharp the resulting profile is.
        results = []
        for angle in angles:
            if abs(angle) < 1e-6:
                rotated = edges
            else:
                M = cv2.getRotationMatrix2D(crop_center, angle, 1.0)
                rotated = cv2.warpAffine(edges, M, (crop_w, crop_h),
                                         flags=cv2.INTER_NEAREST,
                                         borderMode=cv2.BORDER_REPLICATE)
            trimmed = rotated[trim_y:-trim_y, trim_x:-trim_x]
            v_profile = np.sum(trimmed, axis=0, dtype=np.float64)
            score = _projection_gradient_score(v_profile)
            results.append((float(angle), score))
        return results
    # Coarse sweep over the full range, then a fine sweep around the winner.
    coarse_angles = np.arange(-coarse_range, coarse_range + coarse_step * 0.5, coarse_step)
    coarse_results = _sweep_edges(coarse_angles)
    best_coarse = max(coarse_results, key=lambda x: x[1])
    best_coarse_angle, best_coarse_score = best_coarse
    debug["coarse_best_angle"] = round(best_coarse_angle, 2)
    debug["coarse_best_score"] = round(best_coarse_score, 1)
    debug["coarse_scores"] = [(round(a, 2), round(s, 1)) for a, s in coarse_results]
    fine_lo = best_coarse_angle - fine_range
    fine_hi = best_coarse_angle + fine_range
    fine_angles = np.arange(fine_lo, fine_hi + fine_step * 0.5, fine_step)
    fine_results = _sweep_edges(fine_angles)
    best_fine = max(fine_results, key=lambda x: x[1])
    best_fine_angle, best_fine_score = best_fine
    debug["fine_best_angle"] = round(best_fine_angle, 2)
    debug["fine_best_score"] = round(best_fine_score, 1)
    debug["fine_scores"] = [(round(a, 2), round(s, 1)) for a, s in fine_results]
    final_angle = best_fine_angle
    # Safety clamp: never rotate by more than +/-5 degrees.
    final_angle = max(-5.0, min(5.0, final_angle))
    logger.info(f"deskew_iterative: coarse={best_coarse_angle:.2f}\u00b0 fine={best_fine_angle:.2f}\u00b0 -> {final_angle:.2f}\u00b0")
    # Below 0.05 degrees the correction is not worth the resampling.
    if abs(final_angle) < 0.05:
        return img, 0.0, debug
    center = (w // 2, h // 2)
    M = cv2.getRotationMatrix2D(center, final_angle, 1.0)
    rotated = cv2.warpAffine(img, M, (w, h),
                             flags=cv2.INTER_LINEAR,
                             borderMode=cv2.BORDER_REPLICATE)
    return rotated, final_angle, debug
# =============================================================================
# Text-Line Slope Measurement
# =============================================================================
def _measure_textline_slope(img: np.ndarray) -> float:
    """Measure residual text-line slope via Tesseract word-position regression.

    Runs Tesseract on the grayscale page, fits a straight line through the
    word centres of every sufficiently wide text line, and returns the
    ~10%-trimmed mean of the per-line slopes in degrees (0.0 when there is
    not enough evidence or a backend is missing).
    """
    import math as _math
    if not TESSERACT_AVAILABLE or not CV2_AVAILABLE:
        return 0.0
    h, w = img.shape[:2]
    gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
    data = pytesseract.image_to_data(
        Image.fromarray(gray),
        output_type=pytesseract.Output.DICT,
        config="--psm 6",
    )
    # Collect word centres per Tesseract line (block/paragraph/line key),
    # skipping one-character words and low-confidence (<30) detections.
    lines: Dict[tuple, list] = {}
    for i in range(len(data["text"])):
        txt = (data["text"][i] or "").strip()
        if len(txt) < 2 or int(data["conf"][i]) < 30:
            continue
        key = (data["block_num"][i], data["par_num"][i], data["line_num"][i])
        cx = data["left"][i] + data["width"][i] / 2.0
        cy = data["top"][i] + data["height"][i] / 2.0
        lines.setdefault(key, []).append((cx, cy))
    slopes: list = []
    for pts in lines.values():
        # A reliable fit needs at least 3 words ...
        if len(pts) < 3:
            continue
        pts.sort(key=lambda p: p[0])
        xs = np.array([p[0] for p in pts], dtype=np.float64)
        ys = np.array([p[1] for p in pts], dtype=np.float64)
        # ... spanning at least 15% of the page width.
        if xs[-1] - xs[0] < w * 0.15:
            continue
        # Least-squares fit y = slope * x + intercept for this line.
        A = np.vstack([xs, np.ones_like(xs)]).T
        result = np.linalg.lstsq(A, ys, rcond=None)
        slope = result[0][0]
        slopes.append(_math.degrees(_math.atan(slope)))
    if len(slopes) < 3:
        return 0.0
    # Trimmed mean: drop the top and bottom ~10% to suppress outlier lines.
    slopes.sort()
    trim = max(1, len(slopes) // 10)
    trimmed = slopes[trim:-trim] if len(slopes) > 2 * trim else slopes
    if not trimmed:
        return 0.0
    return sum(trimmed) / len(trimmed)
# =============================================================================
# Two-Pass Deskew
# =============================================================================
def deskew_two_pass(
    img: np.ndarray,
    coarse_range: float = 5.0,
) -> Tuple[np.ndarray, float, Dict[str, Any]]:
    """Two-pass deskew: iterative projection + word-alignment residual check.

    Pass 1 (iterative projection) always runs; passes 2 and 3 only apply
    their residual correction when the measured residual is at least 0.3
    degrees. Each later pass degrades gracefully (logged, angle 0.0) on
    failure.

    Args:
        img: BGR image (full resolution).
        coarse_range: half-range in degrees for the pass-1 coarse sweep.

    Returns:
        (corrected_bgr, total_angle_degrees, debug_dict)
    """
    debug: Dict[str, Any] = {}
    # --- Pass 1: iterative projection ---
    corrected, angle1, dbg1 = deskew_image_iterative(
        img.copy(), coarse_range=coarse_range,
    )
    debug["pass1_angle"] = round(angle1, 3)
    debug["pass1_method"] = "iterative"
    debug["pass1_debug"] = dbg1
    # --- Pass 2: word-alignment residual check ---
    # Re-encodes to PNG because deskew_image_by_word_alignment takes bytes.
    angle2 = 0.0
    try:
        ok, buf = cv2.imencode(".png", corrected)
        if ok:
            corrected_bytes, angle2 = deskew_image_by_word_alignment(buf.tobytes())
            if abs(angle2) >= 0.3:
                arr2 = np.frombuffer(corrected_bytes, dtype=np.uint8)
                corrected2 = cv2.imdecode(arr2, cv2.IMREAD_COLOR)
                if corrected2 is not None:
                    corrected = corrected2
                    logger.info(f"deskew_two_pass: pass2 residual={angle2:.2f}\u00b0 applied "
                                f"(total={angle1 + angle2:.2f}\u00b0)")
                else:
                    # Decode failure: keep pass-1 result, discard the residual.
                    angle2 = 0.0
            else:
                logger.info(f"deskew_two_pass: pass2 residual={angle2:.2f}\u00b0 < 0.3\u00b0 -- skipped")
                angle2 = 0.0
    except Exception as e:
        logger.warning(f"deskew_two_pass: pass2 word-alignment failed: {e}")
        angle2 = 0.0
    # --- Pass 3: Tesseract text-line regression residual check ---
    angle3 = 0.0
    try:
        residual = _measure_textline_slope(corrected)
        debug["pass3_raw"] = round(residual, 3)
        if abs(residual) >= 0.3:
            h3, w3 = corrected.shape[:2]
            center3 = (w3 // 2, h3 // 2)
            M3 = cv2.getRotationMatrix2D(center3, residual, 1.0)
            corrected = cv2.warpAffine(
                corrected, M3, (w3, h3),
                flags=cv2.INTER_LINEAR,
                borderMode=cv2.BORDER_REPLICATE,
            )
            angle3 = residual
            logger.info("deskew_two_pass: pass3 text-line residual=%.2f\u00b0 applied", residual)
        else:
            logger.info("deskew_two_pass: pass3 text-line residual=%.2f\u00b0 < 0.3\u00b0 -- skipped", residual)
    except Exception as e:
        logger.warning("deskew_two_pass: pass3 text-line check failed: %s", e)
    # Reported angle is the sum of all corrections actually applied.
    total_angle = angle1 + angle2 + angle3
    debug["pass2_angle"] = round(angle2, 3)
    debug["pass2_method"] = "word_alignment"
    debug["pass3_angle"] = round(angle3, 3)
    debug["pass3_method"] = "textline_regression"
    debug["total_angle"] = round(total_angle, 3)
    logger.info(
        "deskew_two_pass: pass1=%.2f\u00b0 + pass2=%.2f\u00b0 + pass3=%.2f\u00b0 = %.2f\u00b0",
        angle1, angle2, angle3, total_angle,
    )
    return corrected, total_angle, debug

View File

@@ -0,0 +1,474 @@
"""
CV Preprocessing Dewarp — Vertical shear detection and correction.
Provides four shear detection methods (vertical edge, projection variance,
Hough lines, text-line drift), ensemble combination, quality gating,
and the main dewarp_image() function.
Lizenz: Apache 2.0 (kommerziell nutzbar)
DATENSCHUTZ: Alle Verarbeitung erfolgt lokal.
"""
import logging
import math
import time
from typing import Any, Dict, List, Tuple
import numpy as np
from cv_vocab_types import (
CV2_AVAILABLE,
TESSERACT_AVAILABLE,
)
logger = logging.getLogger(__name__)
try:
import cv2
except ImportError:
cv2 = None # type: ignore[assignment]
try:
import pytesseract
from PIL import Image
except ImportError:
pytesseract = None # type: ignore[assignment]
Image = None # type: ignore[assignment,misc]
# =============================================================================
# Shear Detection Methods
# =============================================================================
def _detect_shear_angle(img: np.ndarray) -> Dict[str, Any]:
    """Detect vertical shear angle via strongest vertical edge tracking (Method A).

    Splits the page into horizontal strips, locates the strongest vertical
    edge in the left 40% of each strip, and fits a straight line through
    those edge positions; the fitted slope gives the shear angle.
    """
    h, w = img.shape[:2]
    result = {"method": "vertical_edge", "shear_degrees": 0.0, "confidence": 0.0}
    gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
    # Vertical-edge map (x-gradient), binarised with Otsu.
    sobel_x = cv2.Sobel(gray, cv2.CV_64F, 1, 0, ksize=3)
    abs_sobel = np.abs(sobel_x).astype(np.uint8)
    _, binary = cv2.threshold(abs_sobel, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
    # Track the strongest left-side edge in each of 20 horizontal strips.
    num_strips = 20
    strip_h = h // num_strips
    edge_positions = []
    for i in range(num_strips):
        y_start = i * strip_h
        y_end = min((i + 1) * strip_h, h)
        strip = binary[y_start:y_end, :]
        projection = np.sum(strip, axis=0).astype(np.float64)
        if projection.max() == 0:
            continue
        # Only search the left 40% of the width (left border / margin rule).
        search_w = int(w * 0.4)
        left_proj = projection[:search_w]
        if left_proj.max() == 0:
            continue
        # Smooth the projection before argmax (kernel size must be odd).
        kernel_size = max(3, w // 100)
        if kernel_size % 2 == 0:
            kernel_size += 1
        smoothed = cv2.GaussianBlur(left_proj.reshape(1, -1), (kernel_size, 1), 0).flatten()
        x_pos = float(np.argmax(smoothed))
        y_center = (y_start + y_end) / 2.0
        edge_positions.append((y_center, x_pos))
    # Need enough strips with a detectable edge for a stable fit.
    if len(edge_positions) < 8:
        return result
    ys = np.array([p[0] for p in edge_positions])
    xs = np.array([p[1] for p in edge_positions])
    # Drop outliers more than 2 standard deviations from the median x.
    median_x = np.median(xs)
    std_x = max(np.std(xs), 1.0)
    mask = np.abs(xs - median_x) < 2 * std_x
    ys = ys[mask]
    xs = xs[mask]
    if len(ys) < 6:
        return result
    # Fit x = slope * y + b; the residual RMSE feeds the confidence score.
    straight_coeffs = np.polyfit(ys, xs, 1)
    slope = straight_coeffs[0]
    fitted = np.polyval(straight_coeffs, ys)
    residuals = xs - fitted
    rmse = float(np.sqrt(np.mean(residuals ** 2)))
    shear_degrees = math.degrees(math.atan(slope))
    # Confidence grows with the point count, shrinks with the fit error.
    confidence = min(1.0, len(ys) / 15.0) * max(0.5, 1.0 - rmse / 5.0)
    result["shear_degrees"] = round(shear_degrees, 3)
    result["confidence"] = round(float(confidence), 2)
    return result
def _detect_shear_by_projection(img: np.ndarray) -> Dict[str, Any]:
    """Detect shear angle by maximising variance of horizontal text-line projections (Method B).

    Applies trial shears to a half-size binarised copy and picks the angle
    whose row-sum profile has the highest variance: when text rows are
    perfectly horizontal their ink concentrates into sharp profile peaks.
    """
    result = {"method": "projection", "shear_degrees": 0.0, "confidence": 0.0}
    h, w = img.shape[:2]
    gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
    _, binary = cv2.threshold(gray, 0, 255, cv2.THRESH_BINARY_INV + cv2.THRESH_OTSU)
    # Half resolution is enough for the profile statistics and much faster.
    small = cv2.resize(binary, (w // 2, h // 2), interpolation=cv2.INTER_AREA)
    sh, sw = small.shape
    def _sweep_variance(angles_list):
        # Shear the binary image by each candidate angle (anchored at
        # mid-height) and score the variance of its row-sum profile.
        results = []
        for angle_deg in angles_list:
            if abs(angle_deg) < 0.001:
                rotated = small
            else:
                shear_tan = math.tan(math.radians(angle_deg))
                M = np.float32([[1, shear_tan, -sh / 2.0 * shear_tan], [0, 1, 0]])
                rotated = cv2.warpAffine(small, M, (sw, sh),
                                         flags=cv2.INTER_NEAREST,
                                         borderMode=cv2.BORDER_CONSTANT)
            profile = np.sum(rotated, axis=1).astype(float)
            results.append((angle_deg, float(np.var(profile))))
        return results
    # Coarse sweep -3..+3 deg in 0.5-deg steps, then a fine sweep of
    # +/-0.5 deg in 0.05-deg steps around the coarse winner.
    coarse_angles = [a * 0.5 for a in range(-6, 7)]
    coarse_results = _sweep_variance(coarse_angles)
    coarse_best = max(coarse_results, key=lambda x: x[1])
    fine_center = coarse_best[0]
    fine_angles = [fine_center + a * 0.05 for a in range(-10, 11)]
    fine_results = _sweep_variance(fine_angles)
    fine_best = max(fine_results, key=lambda x: x[1])
    best_angle = fine_best[0]
    best_variance = fine_best[1]
    # Confidence: how far the winning variance stands out above the mean
    # variance of all tested angles (0 when it does not stand out at all).
    variances = coarse_results + fine_results
    all_mean = sum(v for _, v in variances) / len(variances)
    if all_mean > 0 and best_variance > all_mean:
        confidence = min(1.0, (best_variance - all_mean) / (all_mean + 1.0) * 0.6)
    else:
        confidence = 0.0
    result["shear_degrees"] = round(best_angle, 3)
    result["confidence"] = round(max(0.0, min(1.0, confidence)), 2)
    return result
def _detect_shear_by_hough(img: np.ndarray) -> Dict[str, Any]:
    """Detect shear using Hough transform on printed table / ruled lines (Method C).

    Measures the tilt of long near-horizontal segments (table borders,
    ruled lines) and reports their length-weighted median angle.
    """
    result = {"method": "hough_lines", "shear_degrees": 0.0, "confidence": 0.0}
    h, w = img.shape[:2]
    gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
    edges = cv2.Canny(gray, 50, 150, apertureSize=3)
    # Only segments spanning at least 15% of the width are considered.
    min_len = int(w * 0.15)
    lines = cv2.HoughLinesP(
        edges, rho=1, theta=np.pi / 360,
        threshold=int(w * 0.08),
        minLineLength=min_len,
        maxLineGap=20,
    )
    if lines is None or len(lines) < 3:
        return result
    # Keep near-horizontal segments (<= 5 deg tilt), weighted by length.
    horizontal_angles: List[Tuple[float, float]] = []
    for line in lines:
        x1, y1, x2, y2 = line[0]
        if x1 == x2:
            continue  # perfectly vertical: horizontal tilt is undefined
        angle = float(np.degrees(np.arctan2(y2 - y1, x2 - x1)))
        if abs(angle) <= 5.0:
            length = float(np.sqrt((x2 - x1) ** 2 + (y2 - y1) ** 2))
            horizontal_angles.append((angle, length))
    if len(horizontal_angles) < 3:
        return result
    # Length-weighted median: sort by angle and find where the cumulative
    # weight crosses half the total weight.
    angles_arr = np.array([a for a, _ in horizontal_angles])
    weights_arr = np.array([l for _, l in horizontal_angles])
    sorted_idx = np.argsort(angles_arr)
    s_angles = angles_arr[sorted_idx]
    s_weights = weights_arr[sorted_idx]
    cum = np.cumsum(s_weights)
    mid_idx = int(np.searchsorted(cum, cum[-1] / 2.0))
    median_angle = float(s_angles[min(mid_idx, len(s_angles) - 1)])
    # Confidence: fraction of segments agreeing within 1 deg of the median,
    # scaled by a fixed 0.85 cap for this method.
    agree = sum(1 for a, _ in horizontal_angles if abs(a - median_angle) < 1.0)
    confidence = min(1.0, agree / max(len(horizontal_angles), 1)) * 0.85
    # Negated, presumably to match the shear sign convention of the other
    # detectors -- NOTE(review): confirm sign against Method A on a
    # known-sheared sample.
    shear_degrees = -median_angle
    result["shear_degrees"] = round(shear_degrees, 3)
    result["confidence"] = round(max(0.0, min(1.0, confidence)), 2)
    return result
def _detect_shear_by_text_lines(img: np.ndarray) -> Dict[str, Any]:
    """Detect shear by measuring text-line straightness (Method D).

    Clusters OCR words into left-aligned vertical columns, fits the
    horizontal drift of each column over y, and converts the median drift
    slope into a shear angle.
    """
    result = {"method": "text_lines", "shear_degrees": 0.0, "confidence": 0.0}
    h, w = img.shape[:2]
    # Half-resolution OCR pass: word geometry is enough for the drift fit.
    scale = 0.5
    small = cv2.resize(img, (int(w * scale), int(h * scale)),
                       interpolation=cv2.INTER_AREA)
    gray = cv2.cvtColor(small, cv2.COLOR_BGR2GRAY)
    pil_img = Image.fromarray(gray)
    try:
        data = pytesseract.image_to_data(
            pil_img, lang='eng+deu', config='--psm 11 --oem 3',
            output_type=pytesseract.Output.DICT,
        )
    except Exception:
        return result
    # Collect (left_x, centre_y, width) of confident multi-char words.
    words = []
    for i in range(len(data['text'])):
        text = data['text'][i].strip()
        conf = int(data['conf'][i])
        if not text or conf < 20 or len(text) < 2:
            continue
        left_x = float(data['left'][i])
        cy = data['top'][i] + data['height'][i] / 2.0
        word_w = float(data['width'][i])
        words.append((left_x, cy, word_w))
    if len(words) < 15:
        return result
    # Cluster words into vertical columns by similar left-x. The running
    # column position cur_x is an exponential moving average (0.8/0.2) so a
    # slowly drifting (sheared) column still counts as a single cluster.
    avg_w = sum(ww for _, _, ww in words) / len(words)
    x_tol = max(avg_w * 0.4, 8)
    words_by_x = sorted(words, key=lambda w: w[0])
    columns: List[List[Tuple[float, float]]] = []
    cur_col: List[Tuple[float, float]] = [(words_by_x[0][0], words_by_x[0][1])]
    cur_x = words_by_x[0][0]
    for lx, cy, _ in words_by_x[1:]:
        if abs(lx - cur_x) <= x_tol:
            cur_col.append((lx, cy))
            cur_x = cur_x * 0.8 + lx * 0.2
        else:
            # Column boundary: keep the cluster only if it has >= 5 words.
            if len(cur_col) >= 5:
                columns.append(cur_col)
            cur_col = [(lx, cy)]
            cur_x = lx
    if len(cur_col) >= 5:
        columns.append(cur_col)
    if len(columns) < 2:
        return result
    # Per column: fit x = drift * y + b; the slope is the horizontal drift.
    drifts = []
    for col in columns:
        ys = np.array([p[1] for p in col])
        xs = np.array([p[0] for p in col])
        y_range = ys.max() - ys.min()
        # The column must span at least 30% of the (scaled) page height.
        if y_range < h * scale * 0.3:
            continue
        coeffs = np.polyfit(ys, xs, 1)
        drifts.append(coeffs[0])
    if len(drifts) < 2:
        return result
    median_drift = float(np.median(drifts))
    shear_degrees = math.degrees(math.atan(median_drift))
    # Confidence: half from the number of usable columns, half from how
    # consistent the individual drift slopes are.
    drift_std = float(np.std(drifts))
    consistency = max(0.0, 1.0 - drift_std * 50)
    count_factor = min(1.0, len(drifts) / 4.0)
    confidence = count_factor * 0.5 + consistency * 0.5
    result["shear_degrees"] = round(shear_degrees, 3)
    result["confidence"] = round(max(0.0, min(1.0, confidence)), 2)
    logger.info("text_lines(v2): %d columns, %d drifts, median=%.4f, "
                "shear=%.3f\u00b0, conf=%.2f",
                len(columns), len(drifts), median_drift,
                shear_degrees, confidence)
    return result
# =============================================================================
# Quality Check and Shear Application
# =============================================================================
def _dewarp_quality_check(original: np.ndarray, corrected: np.ndarray) -> bool:
    """Return True iff the correction sharpened the horizontal projection.

    A well-aligned page has text rows that stack into pronounced peaks in
    the row-sum profile, i.e. a higher profile variance; the corrected
    image must beat the original on that metric.
    """
    def _row_profile_variance(image: np.ndarray) -> float:
        gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
        _, binary = cv2.threshold(gray, 0, 255,
                                  cv2.THRESH_BINARY_INV + cv2.THRESH_OTSU)
        half = cv2.resize(binary, (binary.shape[1] // 2, binary.shape[0] // 2),
                          interpolation=cv2.INTER_AREA)
        return float(np.var(np.sum(half, axis=1).astype(float)))
    return _row_profile_variance(corrected) > _row_profile_variance(original)
def _apply_shear(img: np.ndarray, shear_degrees: float) -> np.ndarray:
    """Shear the image horizontally by *shear_degrees*, anchored at mid-height.

    The translation term in the affine matrix keeps the vertical centre
    line fixed so the content does not drift sideways.
    """
    height, width = img.shape[:2]
    t = math.tan(math.radians(shear_degrees))
    shear_matrix = np.float32([
        [1, t, -height / 2.0 * t],
        [0, 1, 0],
    ])
    return cv2.warpAffine(img, shear_matrix, (width, height),
                          flags=cv2.INTER_LINEAR,
                          borderMode=cv2.BORDER_REPLICATE)
# =============================================================================
# Ensemble Shear Combination
# =============================================================================
def _ensemble_shear(detections: List[Dict[str, Any]]) -> Tuple[float, float, str]:
"""Combine multiple shear detections into a single weighted estimate (v2)."""
_MIN_CONF = 0.35
_METHOD_WEIGHT_BOOST = {"text_lines": 1.5}
accepted = []
for d in detections:
if d["confidence"] < _MIN_CONF:
continue
boost = _METHOD_WEIGHT_BOOST.get(d["method"], 1.0)
effective_conf = d["confidence"] * boost
accepted.append((d["shear_degrees"], effective_conf, d["method"]))
if not accepted:
return 0.0, 0.0, "none"
if len(accepted) == 1:
deg, conf, method = accepted[0]
return deg, min(conf, 1.0), method
total_w = sum(c for _, c, _ in accepted)
w_mean = sum(d * c for d, c, _ in accepted) / total_w
filtered = [(d, c, m) for d, c, m in accepted if abs(d - w_mean) <= 1.0]
if not filtered:
filtered = accepted
total_w2 = sum(c for _, c, _ in filtered)
final_deg = sum(d * c for d, c, _ in filtered) / total_w2
avg_conf = total_w2 / len(filtered)
spread = max(d for d, _, _ in filtered) - min(d for d, _, _ in filtered)
agreement_bonus = 0.15 if spread < 0.5 else 0.0
ensemble_conf = min(1.0, avg_conf + agreement_bonus)
methods_str = "+".join(m for _, _, m in filtered)
return round(final_deg, 3), round(min(ensemble_conf, 1.0), 2), methods_str
# =============================================================================
# Main Dewarp Function
# =============================================================================
def dewarp_image(img: np.ndarray, use_ensemble: bool = True) -> Tuple[np.ndarray, Dict[str, Any]]:
    """Correct vertical shear after deskew (v2 with quality gate).

    Methods (all run in ~150ms total):
    A. _detect_shear_angle() -- vertical edge profile (~50ms)
    B. _detect_shear_by_projection() -- horizontal text-line variance (~30ms)
    C. _detect_shear_by_hough() -- Hough lines on table borders (~20ms)
    D. _detect_shear_by_text_lines() -- text-line straightness (~50ms)

    The individual estimates are fused by _ensemble_shear(); a correction
    is only applied when |shear| >= 0.08 deg and confidence >= 0.4, and
    corrections >= 0.5 deg must also pass _dewarp_quality_check().

    Args:
        img: BGR image (already deskewed).
        use_ensemble: If False, fall back to single-method behaviour (method A only).

    Returns:
        Tuple of (corrected_image, dewarp_info).
    """
    no_correction = {
        "method": "none",
        "shear_degrees": 0.0,
        "confidence": 0.0,
        "detections": [],
    }
    if not CV2_AVAILABLE:
        return img, no_correction
    t0 = time.time()
    if use_ensemble:
        det_a = _detect_shear_angle(img)
        det_b = _detect_shear_by_projection(img)
        det_c = _detect_shear_by_hough(img)
        det_d = _detect_shear_by_text_lines(img)
        detections = [det_a, det_b, det_c, det_d]
        shear_deg, confidence, method = _ensemble_shear(detections)
    else:
        # Legacy single-method path: vertical edge tracking only.
        det_a = _detect_shear_angle(img)
        detections = [det_a]
        shear_deg = det_a["shear_degrees"]
        confidence = det_a["confidence"]
        method = det_a["method"]
    duration = time.time() - t0
    logger.info(
        "dewarp: ensemble shear=%.3f\u00b0 conf=%.2f method=%s (%.2fs) | "
        "A=%.3f/%.2f B=%.3f/%.2f C=%.3f/%.2f D=%.3f/%.2f",
        shear_deg, confidence, method, duration,
        detections[0]["shear_degrees"], detections[0]["confidence"],
        detections[1]["shear_degrees"] if len(detections) > 1 else 0.0,
        detections[1]["confidence"] if len(detections) > 1 else 0.0,
        detections[2]["shear_degrees"] if len(detections) > 2 else 0.0,
        detections[2]["confidence"] if len(detections) > 2 else 0.0,
        detections[3]["shear_degrees"] if len(detections) > 3 else 0.0,
        detections[3]["confidence"] if len(detections) > 3 else 0.0,
    )
    # Per-method results are always reported back for debugging/UI.
    _all_detections = [
        {"method": d["method"], "shear_degrees": d["shear_degrees"],
         "confidence": d["confidence"]}
        for d in detections
    ]
    # Gate 1: tiny or low-confidence estimates are not applied.
    if abs(shear_deg) < 0.08 or confidence < 0.4:
        no_correction["detections"] = _all_detections
        return img, no_correction
    corrected = _apply_shear(img, -shear_deg)
    # Gate 2: larger corrections must demonstrably improve alignment.
    if abs(shear_deg) >= 0.5 and not _dewarp_quality_check(img, corrected):
        logger.info("dewarp: quality gate REJECTED correction (%.3f\u00b0) -- "
                    "projection variance did not improve", shear_deg)
        no_correction["detections"] = _all_detections
        return img, no_correction
    info = {
        "method": method,
        "shear_degrees": shear_deg,
        "confidence": confidence,
        "detections": _all_detections,
    }
    return corrected, info
def dewarp_image_manual(img: np.ndarray, shear_degrees: float) -> np.ndarray:
    """Apply shear correction with a manually supplied angle.

    Angles whose magnitude is below 0.001 degrees are treated as zero and
    the input image is returned unchanged (avoids a pointless resample).
    """
    negligible = abs(shear_degrees) < 0.001
    return img if negligible else _apply_shear(img, -shear_degrees)

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,388 @@
"""
CV Review LLM — LLM-based OCR correction: prompt building, change detection, streaming.
Handles the LLM review path (REVIEW_ENGINE=llm) and shared utilities like
_entry_needs_review, _is_spurious_change, _diff_batch, and JSON parsing.
Lizenz: Apache 2.0 (kommerziell nutzbar)
DATENSCHUTZ: Alle Verarbeitung erfolgt lokal.
"""
import json
import logging
import os
import re
import time
from typing import Dict, List, Tuple
import httpx
logger = logging.getLogger(__name__)
# Ollama endpoint and review-model configuration (all overridable via env).
_OLLAMA_URL = os.getenv("OLLAMA_BASE_URL", "http://host.docker.internal:11434")
OLLAMA_REVIEW_MODEL = os.getenv("OLLAMA_REVIEW_MODEL", "qwen3:0.6b")
# Number of entries sent per LLM request in the streaming review path.
_REVIEW_BATCH_SIZE = int(os.getenv("OLLAMA_REVIEW_BATCH_SIZE", "20"))
# Logged at import time so operators can see which review model is active.
logger.info("LLM review model: %s (batch=%d)", OLLAMA_REVIEW_MODEL, _REVIEW_BATCH_SIZE)
# Review engine selector: deterministic spell-checker vs. LLM round-trip.
REVIEW_ENGINE = os.getenv("REVIEW_ENGINE", "spell") # "spell" (default) | "llm"
# Regex: entry contains IPA phonetic brackets like "dance [da:ns]"
_HAS_PHONETIC_RE = re.compile(r'\[.*?[\u02c8\u02cc\u02d0\u0283\u0292\u03b8\u00f0\u014b\u0251\u0252\u0254\u0259\u025c\u026a\u028a\u028c\u00e6].*?\]')
# Regex: digit adjacent to a letter -- OCR digit<->letter confusion
# NOTE(review): _OCR_DIGIT_IN_WORD_RE is not referenced in this module --
# presumably imported elsewhere; confirm before removing.
_OCR_DIGIT_IN_WORD_RE = re.compile(r'(?<=[A-Za-z\u00c4\u00d6\u00dc\u00e4\u00f6\u00fc\u00df])[01568]|[01568](?=[A-Za-z\u00c4\u00d6\u00dc\u00e4\u00f6\u00fc\u00df])')
def _entry_needs_review(entry: Dict) -> bool:
    """Decide whether *entry* should be sent through the review step.

    An entry qualifies when at least one language field is non-empty and
    neither field contains an IPA phonetic transcription (entries with
    phonetics are assumed to already be clean).
    """
    english = entry.get("english", "") or ""
    german = entry.get("german", "") or ""
    has_content = bool(english.strip()) or bool(german.strip())
    if not has_content:
        return False
    has_phonetics = any(_HAS_PHONETIC_RE.search(text) for text in (english, german))
    return not has_phonetics
def _build_llm_prompt(table_lines: List[Dict]) -> str:
    """Build the LLM correction prompt for a batch of entries.

    Args:
        table_lines: Compact entry dicts with keys "row", "en", "de", "ex"
            (short keys keep the prompt and the model's JSON answer small).

    Returns:
        German-language prompt that instructs the model to fix ONLY
        digit<->letter OCR confusions and to answer with a bare JSON array
        of the same length/order. The trailing "/no_think" line disables
        qwen3's chain-of-thought output.
    """
    # NOTE: the prompt text is part of the model contract -- do not reflow.
    return f"""Du bist ein OCR-Zeichenkorrektur-Werkzeug fuer Vokabeltabellen (Englisch-Deutsch).
DEINE EINZIGE AUFGABE: Einzelne Zeichen korrigieren, die vom OCR-Scanner als Ziffer statt als Buchstabe erkannt wurden.
NUR diese Korrekturen sind erlaubt:
- Ziffer 8 statt B: "8en" -> "Ben", "8uch" -> "Buch", "8all" -> "Ball"
- Ziffer 0 statt O oder o: "L0ndon" -> "London", "0ld" -> "Old"
- Ziffer 1 statt l oder I: "1ong" -> "long", "Ber1in" -> "Berlin"
- Ziffer 5 statt S oder s: "5tadt" -> "Stadt", "5ee" -> "See"
- Ziffer 6 statt G oder g: "6eld" -> "Geld"
- Senkrechter Strich | statt I oder l: "| want" -> "I want", "|ong" -> "long", "he| p" -> "help"
ABSOLUT VERBOTEN -- aendere NIEMALS:
- Woerter die korrekt geschrieben sind -- auch wenn du eine andere Schreibweise kennst
- Uebersetzungen -- du uebersetzt NICHTS, weder EN->DE noch DE->EN
- Korrekte englische Woerter (en-Spalte) -- auch wenn du eine Bedeutung kennst
- Korrekte deutsche Woerter (de-Spalte) -- auch wenn du sie anders sagen wuerdest
- Eigennamen: Ben, London, China, Africa, Shakespeare usw.
- Abkuerzungen: sth., sb., etc., e.g., i.e., v.t., smb. usw.
- Lautschrift in eckigen Klammern [...] -- diese NIEMALS beruehren
- Beispielsaetze in der ex-Spalte -- NIEMALS aendern
Wenn ein Wort keinen Ziffer-Buchstaben-Fehler enthaelt: gib es UNVERAENDERT zurueck und setze "corrected": false.
Antworte NUR mit dem JSON-Array. Kein Text davor oder danach.
Behalte die exakte Struktur (gleiche Anzahl Eintraege, gleiche Reihenfolge).
/no_think
Eingabe:
{json.dumps(table_lines, ensure_ascii=False, indent=2)}"""
def _is_spurious_change(old_val: str, new_val: str) -> bool:
"""Detect LLM changes that are likely wrong and should be discarded.
Only digit<->letter substitutions (0->O, 1->l, 5->S, 6->G, 8->B) are
legitimate OCR corrections. Everything else is rejected.
"""
if not old_val or not new_val:
return False
if old_val.lower() == new_val.lower():
return True
old_words = old_val.split()
new_words = new_val.split()
if abs(len(old_words) - len(new_words)) > 1:
return True
_OCR_CHAR_MAP = {
'0': set('oOgG'),
'1': set('lLiI'),
'5': set('sS'),
'6': set('gG'),
'8': set('bB'),
'|': set('lLiI1'),
'l': set('iI|1'),
}
has_valid_fix = False
if len(old_val) == len(new_val):
for oc, nc in zip(old_val, new_val):
if oc != nc:
if oc in _OCR_CHAR_MAP and nc in _OCR_CHAR_MAP[oc]:
has_valid_fix = True
elif nc in _OCR_CHAR_MAP and oc in _OCR_CHAR_MAP[nc]:
has_valid_fix = True
else:
_OCR_SUSPICIOUS_RE = re.compile(r'[|01568]')
if abs(len(old_val) - len(new_val)) <= 1 and _OCR_SUSPICIOUS_RE.search(old_val):
has_valid_fix = True
if not has_valid_fix:
return True
return False
def _diff_batch(originals: List[Dict], corrected: List[Dict]) -> Tuple[List[Dict], List[Dict]]:
    """Compare original entries with LLM-corrected ones, return (changes, corrected_entries)."""
    field_keys = (("english", "en"), ("german", "de"), ("example", "ex"))
    changes: List[Dict] = []
    entries_out: List[Dict] = []
    for idx, original in enumerate(originals):
        if idx >= len(corrected):
            # LLM returned fewer items than sent -- pass the rest through.
            entries_out.append(dict(original))
            continue
        proposal = corrected[idx]
        merged = dict(original)
        for field_name, key in field_keys:
            new_val = proposal.get(key, "").strip()
            old_val = (original.get(field_name, "") or "").strip()
            if not new_val or new_val == old_val:
                continue
            if _is_spurious_change(old_val, new_val):
                continue
            changes.append({
                "row_index": original.get("row_index", idx),
                "field": field_name,
                "old": old_val,
                "new": new_val,
            })
            merged[field_name] = new_val
            merged["llm_corrected"] = True
        entries_out.append(merged)
    return changes, entries_out
def _sanitize_for_json(text: str) -> str:
"""Remove or escape control characters that break JSON parsing."""
return re.sub(r'[\x00-\x08\x0b\x0c\x0e-\x1f\x7f]', ' ', text)
def _parse_llm_json_array(text: str) -> List[Dict]:
    """Extract JSON array from LLM response (handles markdown fences and qwen3 think-tags)."""
    cleaned = re.sub(r'<think>.*?</think>', '', text, flags=re.DOTALL)
    for fence_pattern in (r'```json\s*', r'```\s*'):
        cleaned = re.sub(fence_pattern, '', cleaned)
    cleaned = _sanitize_for_json(cleaned)
    match = re.search(r'\[.*\]', cleaned, re.DOTALL)
    if match is None:
        logger.warning("LLM review: no JSON array found in response (%.200s)", cleaned[:200])
        return []
    raw = match.group()
    try:
        return json.loads(raw)
    except (ValueError, json.JSONDecodeError) as e:
        logger.warning("LLM review: JSON parse failed: %s | raw snippet: %.200s", e, raw[:200])
        return []
async def llm_review_entries(
    entries: List[Dict],
    model: str = None,
) -> Dict:
    """OCR error correction. Uses spell-checker (REVIEW_ENGINE=spell) or LLM (REVIEW_ENGINE=llm).

    Args:
        entries: Vocabulary entry dicts ("english"/"german"/"example",
            optionally "row_index").
        model: Ollama model override; defaults to OLLAMA_REVIEW_MODEL.

    Returns:
        Dict with keys: entries_original, entries_corrected, changes,
        skipped_count, model_used, duration_ms.
    """
    # Local import avoids a circular dependency between the review modules.
    from cv_review_spell import spell_review_entries_sync, _SPELL_AVAILABLE
    if REVIEW_ENGINE == "spell" and _SPELL_AVAILABLE:
        return spell_review_entries_sync(entries)
    if REVIEW_ENGINE == "spell" and not _SPELL_AVAILABLE:
        logger.warning("REVIEW_ENGINE=spell but pyspellchecker not installed, using LLM")
    model = model or OLLAMA_REVIEW_MODEL
    # Only entries without IPA transcriptions are worth an LLM round-trip.
    reviewable = [(i, e) for i, e in enumerate(entries) if _entry_needs_review(e)]
    if not reviewable:
        # Nothing to review -- return a no-op result with copies.
        return {
            "entries_original": entries,
            "entries_corrected": [dict(e) for e in entries],
            "changes": [],
            "skipped_count": len(entries),
            "model_used": model,
            "duration_ms": 0,
        }
    review_entries = [e for _, e in reviewable]
    # Compact key names keep the prompt (and the model's answer) short.
    table_lines = [
        {"row": e.get("row_index", 0), "en": e.get("english", ""), "de": e.get("german", ""), "ex": e.get("example", "")}
        for e in review_entries
    ]
    logger.info("LLM review: sending %d/%d entries to %s (skipped %d without digit-pattern)",
                len(review_entries), len(entries), model, len(entries) - len(reviewable))
    prompt = _build_llm_prompt(table_lines)
    t0 = time.time()
    async with httpx.AsyncClient(timeout=300.0) as client:
        resp = await client.post(
            f"{_OLLAMA_URL}/api/chat",
            json={
                "model": model,
                "messages": [{"role": "user", "content": prompt}],
                "stream": False,
                "think": False,  # suppress qwen3 chain-of-thought output
                "options": {"temperature": 0.1, "num_predict": 8192},
            },
        )
        resp.raise_for_status()
        content = resp.json().get("message", {}).get("content", "")
    duration_ms = int((time.time() - t0) * 1000)
    logger.info("LLM review: response in %dms, raw length=%d chars", duration_ms, len(content))
    corrected = _parse_llm_json_array(content)
    changes, corrected_entries = _diff_batch(review_entries, corrected)
    # Splice the corrected subset back into a full copy of the input list.
    all_corrected = [dict(e) for e in entries]
    for batch_idx, (orig_idx, _) in enumerate(reviewable):
        if batch_idx < len(corrected_entries):
            all_corrected[orig_idx] = corrected_entries[batch_idx]
    return {
        "entries_original": entries,
        "entries_corrected": all_corrected,
        "changes": changes,
        "skipped_count": len(entries) - len(reviewable),
        "model_used": model,
        "duration_ms": duration_ms,
    }
async def llm_review_entries_streaming(
    entries: List[Dict],
    model: str = None,
    batch_size: int = _REVIEW_BATCH_SIZE,
):
    """Async generator: yield SSE events. Uses spell-checker or LLM depending on REVIEW_ENGINE.

    Phase 0 (always): Run _fix_character_confusion and emit any changes.

    Event protocol (plain dicts; the route layer serializes them):
        "meta"     -- counts + model info
        "batch"    -- per-batch changes + progress
        "complete" -- aggregate summary incl. fully corrected entries

    Args:
        entries: Vocabulary entry dicts; mutated in place by phase 0.
        model: Ollama model override; defaults to OLLAMA_REVIEW_MODEL.
        batch_size: Entries per LLM request (LLM path only).
    """
    # Local imports avoid circular dependencies between the CV modules.
    from cv_ocr_engines import _fix_character_confusion
    from cv_review_spell import spell_review_entries_streaming, _SPELL_AVAILABLE
    # Phase 0: deterministic character-confusion pass (mutates `entries`);
    # a snapshot taken beforehand lets us report the diffs as events.
    _CONF_FIELDS = ('english', 'german', 'example')
    originals = [{f: e.get(f, '') for f in _CONF_FIELDS} for e in entries]
    _fix_character_confusion(entries)
    char_changes = [
        {'row_index': i, 'field': f, 'old': originals[i][f], 'new': entries[i].get(f, '')}
        for i in range(len(entries))
        for f in _CONF_FIELDS
        if originals[i][f] != entries[i].get(f, '')
    ]
    if REVIEW_ENGINE == "spell" and _SPELL_AVAILABLE:
        # Spell path: delegate, but inject the phase-0 changes right after
        # the delegate's meta event so clients see them as the first batch.
        _meta_sent = False
        async for event in spell_review_entries_streaming(entries, batch_size):
            yield event
            if not _meta_sent and event.get('type') == 'meta' and char_changes:
                _meta_sent = True
                yield {
                    'type': 'batch',
                    'changes': char_changes,
                    'entries_reviewed': sorted({c['row_index'] for c in char_changes}),
                    'progress': {'current': 0, 'total': len(entries)},
                }
        return
    if REVIEW_ENGINE == "spell" and not _SPELL_AVAILABLE:
        logger.warning("REVIEW_ENGINE=spell but pyspellchecker not installed, using LLM")
    # LLM path
    # Phase-0 changes are emitted before "meta" here (unlike the spell path).
    if char_changes:
        yield {
            'type': 'batch',
            'changes': char_changes,
            'entries_reviewed': sorted({c['row_index'] for c in char_changes}),
            'progress': {'current': 0, 'total': len(entries)},
        }
    model = model or OLLAMA_REVIEW_MODEL
    # Partition entries into reviewable vs. skipped (IPA / empty).
    reviewable = []
    skipped_indices = []
    for i, e in enumerate(entries):
        if _entry_needs_review(e):
            reviewable.append((i, e))
        else:
            skipped_indices.append(i)
    total_to_review = len(reviewable)
    yield {
        "type": "meta",
        "total_entries": len(entries),
        "to_review": total_to_review,
        "skipped": len(skipped_indices),
        "model": model,
        "batch_size": batch_size,
    }
    all_changes = []
    all_corrected = [dict(e) for e in entries]
    total_duration_ms = 0
    reviewed_count = 0
    # One LLM request per batch; each batch is diffed and reported as soon
    # as the response arrives so the client can show incremental progress.
    for batch_start in range(0, total_to_review, batch_size):
        batch_items = reviewable[batch_start:batch_start + batch_size]
        batch_entries = [e for _, e in batch_items]
        table_lines = [
            {"row": e.get("row_index", 0), "en": e.get("english", ""), "de": e.get("german", ""), "ex": e.get("example", "")}
            for e in batch_entries
        ]
        prompt = _build_llm_prompt(table_lines)
        logger.info("LLM review streaming: batch %d -- sending %d entries to %s",
                    batch_start // batch_size, len(batch_entries), model)
        t0 = time.time()
        async with httpx.AsyncClient(timeout=300.0) as client:
            resp = await client.post(
                f"{_OLLAMA_URL}/api/chat",
                json={
                    "model": model,
                    "messages": [{"role": "user", "content": prompt}],
                    "stream": False,
                    "think": False,  # suppress qwen3 chain-of-thought output
                    "options": {"temperature": 0.1, "num_predict": 8192},
                },
            )
            resp.raise_for_status()
            content = resp.json().get("message", {}).get("content", "")
        batch_ms = int((time.time() - t0) * 1000)
        total_duration_ms += batch_ms
        corrected = _parse_llm_json_array(content)
        batch_changes, batch_corrected = _diff_batch(batch_entries, corrected)
        # Splice this batch's corrected entries back into the full list.
        for batch_idx, (orig_idx, _) in enumerate(batch_items):
            if batch_idx < len(batch_corrected):
                all_corrected[orig_idx] = batch_corrected[batch_idx]
        all_changes.extend(batch_changes)
        reviewed_count += len(batch_items)
        yield {
            "type": "batch",
            "batch_index": batch_start // batch_size,
            "entries_reviewed": [e.get("row_index", 0) for _, e in batch_items],
            "changes": batch_changes,
            "duration_ms": batch_ms,
            "progress": {"current": reviewed_count, "total": total_to_review},
        }
    yield {
        "type": "complete",
        "changes": all_changes,
        "model_used": model,
        "duration_ms": total_duration_ms,
        "total_entries": len(entries),
        "reviewed": total_to_review,
        "skipped": len(skipped_indices),
        "corrections_found": len(all_changes),
        "entries_corrected": all_corrected,
    }

View File

@@ -0,0 +1,430 @@
"""
CV Review Pipeline — Multi-pass OCR, line alignment, LLM post-correction, and orchestration.
Stages 6-8 of the CV vocabulary pipeline plus the main orchestrator.
Lizenz: Apache 2.0 (kommerziell nutzbar)
DATENSCHUTZ: Alle Verarbeitung erfolgt lokal.
"""
import logging
import time
from typing import Any, Dict, List, Optional
import numpy as np
from cv_vocab_types import (
CV_PIPELINE_AVAILABLE,
PageRegion,
PipelineResult,
VocabRow,
)
from cv_preprocessing import (
deskew_image,
dewarp_image,
render_image_high_res,
render_pdf_high_res,
)
from cv_layout import (
analyze_layout,
create_layout_image,
create_ocr_image,
)
from cv_ocr_engines import (
_group_words_into_lines,
)
logger = logging.getLogger(__name__)
try:
import cv2
except ImportError:
cv2 = None # type: ignore[assignment]
try:
import pytesseract
from PIL import Image
except ImportError:
pytesseract = None # type: ignore[assignment]
Image = None # type: ignore[assignment,misc]
# =============================================================================
# Stage 6: Multi-Pass OCR
# =============================================================================
def ocr_region(ocr_img: np.ndarray, region: PageRegion, lang: str,
               psm: int, fallback_psm: Optional[int] = None,
               min_confidence: float = 40.0) -> List[Dict[str, Any]]:
    """Run Tesseract OCR on a specific region with given PSM.

    Args:
        ocr_img: Binarized full-page image.
        region: Region to crop and OCR.
        lang: Tesseract language string.
        psm: Page Segmentation Mode.
        fallback_psm: If confidence too low, retry with this PSM per line.
        min_confidence: Minimum average confidence before fallback.

    Returns:
        List of word dicts with text, position, confidence (coordinates are
        in full-page space, not crop space).
    """
    crop = ocr_img[region.y:region.y + region.height,
                   region.x:region.x + region.width]
    # Degenerate region (zero area) -> nothing to OCR.
    if crop.size == 0:
        return []
    pil_img = Image.fromarray(crop)
    # OEM 3 selects Tesseract's default engine mode.
    config = f'--psm {psm} --oem 3'
    try:
        data = pytesseract.image_to_data(pil_img, lang=lang, config=config,
                                         output_type=pytesseract.Output.DICT)
    except Exception as e:
        # Best-effort: a failed region should not abort the whole page.
        logger.warning(f"Tesseract failed for region {region.type}: {e}")
        return []
    words = []
    for i in range(len(data['text'])):
        text = data['text'][i].strip()
        conf = int(data['conf'][i])
        # Drop empty tokens and near-zero-confidence noise (Tesseract
        # reports conf -1 for non-word boxes -- presumably; confirm).
        if not text or conf < 10:
            continue
        # Translate word boxes back into full-page coordinates.
        words.append({
            'text': text,
            'left': data['left'][i] + region.x,
            'top': data['top'][i] + region.y,
            'width': data['width'][i],
            'height': data['height'][i],
            'conf': conf,
            'region_type': region.type,
        })
    # If the whole-region pass was low-confidence, retry line by line.
    if words and fallback_psm is not None:
        avg_conf = sum(w['conf'] for w in words) / len(words)
        if avg_conf < min_confidence:
            logger.info(f"Region {region.type}: avg confidence {avg_conf:.0f}% < {min_confidence}%, "
                        f"trying fallback PSM {fallback_psm}")
            words = _ocr_region_line_by_line(ocr_img, region, lang, fallback_psm)
    return words
def _ocr_region_line_by_line(ocr_img: np.ndarray, region: PageRegion,
                             lang: str, psm: int) -> List[Dict[str, Any]]:
    """OCR a region line by line (fallback for low-confidence regions).

    Segments the region into text lines via a horizontal projection profile
    of the inverted (ink=white) crop, then runs Tesseract on each line
    separately with the given PSM.

    Args:
        ocr_img: Binarized full-page image.
        region: Region to crop and OCR.
        lang: Tesseract language string.
        psm: Page Segmentation Mode used per line.

    Returns:
        List of word dicts in full-page coordinates (same shape as
        ocr_region()).
    """
    crop = ocr_img[region.y:region.y + region.height,
                   region.x:region.x + region.width]
    if crop.size == 0:
        return []
    # Invert so ink pixels contribute to the row sums.
    inv = cv2.bitwise_not(crop)
    h_proj = np.sum(inv, axis=1)
    # Rows with >5% of the peak ink density count as text.
    threshold = np.max(h_proj) * 0.05 if np.max(h_proj) > 0 else 0
    lines = []
    in_text = False
    line_start = 0
    # Scan the profile for text-run start/end transitions.
    for y in range(len(h_proj)):
        if h_proj[y] > threshold and not in_text:
            line_start = y
            in_text = True
        elif h_proj[y] <= threshold and in_text:
            # Ignore runs shorter than 5 px (speckle / rule lines).
            if y - line_start > 5:
                lines.append((line_start, y))
            in_text = False
    # Close a text run that extends to the bottom edge of the crop.
    if in_text and len(h_proj) - line_start > 5:
        lines.append((line_start, len(h_proj)))
    all_words = []
    config = f'--psm {psm} --oem 3'
    for line_y_start, line_y_end in lines:
        # Small vertical padding so ascenders/descenders are not clipped.
        pad = 3
        y1 = max(0, line_y_start - pad)
        y2 = min(crop.shape[0], line_y_end + pad)
        line_crop = crop[y1:y2, :]
        if line_crop.size == 0:
            continue
        pil_img = Image.fromarray(line_crop)
        try:
            data = pytesseract.image_to_data(pil_img, lang=lang, config=config,
                                             output_type=pytesseract.Output.DICT)
        except Exception:
            # Best-effort fallback path: skip lines Tesseract chokes on.
            continue
        for i in range(len(data['text'])):
            text = data['text'][i].strip()
            conf = int(data['conf'][i])
            if not text or conf < 10:
                continue
            # Map line-crop coordinates back to full-page space.
            all_words.append({
                'text': text,
                'left': data['left'][i] + region.x,
                'top': data['top'][i] + region.y + y1,
                'width': data['width'][i],
                'height': data['height'][i],
                'conf': conf,
                'region_type': region.type,
            })
    return all_words
def run_multi_pass_ocr(ocr_img: np.ndarray,
regions: List[PageRegion],
lang: str = "eng+deu") -> Dict[str, List[Dict]]:
"""Run OCR on each detected region with optimized settings."""
results: Dict[str, List[Dict]] = {}
_ocr_skip = {'header', 'footer', 'margin_top', 'margin_bottom', 'margin_left', 'margin_right'}
for region in regions:
if region.type in _ocr_skip:
continue
if region.type == 'column_en':
words = ocr_region(ocr_img, region, lang='eng', psm=4)
elif region.type == 'column_de':
words = ocr_region(ocr_img, region, lang='deu', psm=4)
elif region.type == 'column_example':
words = ocr_region(ocr_img, region, lang=lang, psm=6,
fallback_psm=7, min_confidence=40.0)
else:
words = ocr_region(ocr_img, region, lang=lang, psm=6)
results[region.type] = words
logger.info(f"OCR {region.type}: {len(words)} words")
return results
# =============================================================================
# Stage 7: Line Alignment -> Vocabulary Entries
# =============================================================================
def match_lines_to_vocab(ocr_results: Dict[str, List[Dict]],
                         regions: List[PageRegion],
                         y_tolerance_px: int = 25) -> List[VocabRow]:
    """Align OCR results from different columns into vocabulary rows.

    Rows are anchored on the English column: for every English line, the
    nearest German and example lines (by vertical center distance, within
    *y_tolerance_px*) are attached. Unmatched example lines that sit just
    below a row are treated as wrapped continuations of that row's example.

    Args:
        ocr_results: Region-type -> word-dict list (from run_multi_pass_ocr).
        regions: Detected page regions (currently unused here; kept for
            interface stability -- TODO confirm before removing).
        y_tolerance_px: Max vertical distance for line-to-row matching.

    Returns:
        VocabRow list sorted by vertical position.
    """
    if 'column_en' not in ocr_results and 'column_de' not in ocr_results:
        logger.info("match_lines_to_vocab: no column_en/column_de in OCR results, returning empty")
        return []
    en_lines = _group_words_into_lines(ocr_results.get('column_en', []), y_tolerance_px)
    de_lines = _group_words_into_lines(ocr_results.get('column_de', []), y_tolerance_px)
    ex_lines = _group_words_into_lines(ocr_results.get('column_example', []), y_tolerance_px)
    # Small helpers over a "line" (list of word dicts).
    def line_y_center(line: List[Dict]) -> float:
        return sum(w['top'] + w['height'] / 2 for w in line) / len(line)
    def line_text(line: List[Dict]) -> str:
        return ' '.join(w['text'] for w in line)
    def line_confidence(line: List[Dict]) -> float:
        return sum(w['conf'] for w in line) / len(line) if line else 0
    vocab_rows: List[VocabRow] = []
    for en_line in en_lines:
        en_y = line_y_center(en_line)
        en_text = line_text(en_line)
        en_conf = line_confidence(en_line)
        # Single-character "lines" are almost always OCR speckle.
        if len(en_text.strip()) < 2:
            continue
        # Nearest German line within tolerance.
        de_text = ""
        de_conf = 0.0
        best_de_dist = float('inf')
        best_de_idx = -1
        for idx, de_line in enumerate(de_lines):
            dist = abs(line_y_center(de_line) - en_y)
            if dist < y_tolerance_px and dist < best_de_dist:
                best_de_dist = dist
                best_de_idx = idx
        if best_de_idx >= 0:
            de_text = line_text(de_lines[best_de_idx])
            de_conf = line_confidence(de_lines[best_de_idx])
        # Nearest example line within tolerance.
        ex_text = ""
        ex_conf = 0.0
        best_ex_dist = float('inf')
        best_ex_idx = -1
        for idx, ex_line in enumerate(ex_lines):
            dist = abs(line_y_center(ex_line) - en_y)
            if dist < y_tolerance_px and dist < best_ex_dist:
                best_ex_dist = dist
                best_ex_idx = idx
        if best_ex_idx >= 0:
            ex_text = line_text(ex_lines[best_ex_idx])
            ex_conf = line_confidence(ex_lines[best_ex_idx])
        # Average confidence over the columns that actually matched.
        avg_conf = en_conf
        conf_count = 1
        if de_conf > 0:
            avg_conf += de_conf
            conf_count += 1
        if ex_conf > 0:
            avg_conf += ex_conf
            conf_count += 1
        vocab_rows.append(VocabRow(
            english=en_text.strip(),
            german=de_text.strip(),
            example=ex_text.strip(),
            confidence=avg_conf / conf_count,
            y_position=int(en_y),
        ))
    # Handle multi-line wrapping in example column
    # Example lines not claimed above are appended to the nearest row ABOVE
    # them (within 3x tolerance), i.e. treated as wrapped continuations.
    matched_ex_ys = set()
    for row in vocab_rows:
        if row.example:
            matched_ex_ys.add(row.y_position)
    for ex_line in ex_lines:
        ex_y = line_y_center(ex_line)
        already_matched = any(abs(ex_y - y) < y_tolerance_px for y in matched_ex_ys)
        if already_matched:
            continue
        best_row = None
        best_dist = float('inf')
        for row in vocab_rows:
            dist = ex_y - row.y_position
            if 0 < dist < y_tolerance_px * 3 and dist < best_dist:
                best_dist = dist
                best_row = row
        if best_row:
            continuation = line_text(ex_line).strip()
            if continuation:
                best_row.example = (best_row.example + " " + continuation).strip()
    vocab_rows.sort(key=lambda r: r.y_position)
    return vocab_rows
# =============================================================================
# Stage 8: Optional LLM Post-Correction
# =============================================================================
async def llm_post_correct(img: np.ndarray, vocab_rows: List[VocabRow],
                           confidence_threshold: float = 50.0,
                           enabled: bool = False) -> List[VocabRow]:
    """Optionally send low-confidence regions to Qwen-VL for correction.

    Currently a stub: the vision-LLM round-trip is not implemented, so the
    rows are always returned unchanged. *img* and *confidence_threshold*
    are accepted now for interface stability once the feature lands.

    Args:
        img: Full-page image (to be used by the future LLM pass).
        vocab_rows: Aligned vocabulary rows from stage 7.
        confidence_threshold: Rows below this OCR confidence would be sent
            to the LLM (unused until implemented).
        enabled: Feature flag; when False the function is a no-op.

    Returns:
        The (unmodified) list of vocabulary rows.
    """
    if not enabled:
        return vocab_rows
    # BUGFIX: was an f-string without placeholders (lint rule F541).
    logger.info("LLM post-correction skipped (not yet implemented)")
    return vocab_rows
# =============================================================================
# Orchestrator
# =============================================================================
async def run_cv_pipeline(
    pdf_data: Optional[bytes] = None,
    image_data: Optional[bytes] = None,
    page_number: int = 0,
    zoom: float = 3.0,
    enable_dewarp: bool = True,
    enable_llm_correction: bool = False,
    lang: str = "eng+deu",
) -> PipelineResult:
    """Run the complete CV document reconstruction pipeline.

    Stages: render -> deskew -> dewarp (optional) -> dual image prep ->
    layout analysis -> multi-pass OCR -> line alignment -> optional LLM
    correction. Per-stage timings are recorded in result.stages.

    Args:
        pdf_data: Raw PDF bytes (mutually exclusive with image_data;
            pdf_data wins if both are given).
        image_data: Raw image bytes.
        page_number: PDF page to process (0-based).
        zoom: PDF render zoom factor.
        enable_dewarp: Run the shear-correction stage.
        enable_llm_correction: Run the (stub) LLM post-correction stage.
        lang: Tesseract language string for mixed-language regions.

    Returns:
        PipelineResult; on failure, result.error is set and the partial
        timings are kept.
    """
    if not CV_PIPELINE_AVAILABLE:
        return PipelineResult(error="CV pipeline not available (OpenCV or Tesseract missing)")
    result = PipelineResult()
    total_start = time.time()
    try:
        # Stage 1: Render
        t = time.time()
        if pdf_data:
            img = render_pdf_high_res(pdf_data, page_number, zoom)
        elif image_data:
            img = render_image_high_res(image_data)
        else:
            return PipelineResult(error="No input data (pdf_data or image_data required)")
        result.stages['render'] = round(time.time() - t, 2)
        result.image_width = img.shape[1]
        result.image_height = img.shape[0]
        logger.info(f"Stage 1 (render): {img.shape[1]}x{img.shape[0]} in {result.stages['render']}s")
        # Stage 2: Deskew
        t = time.time()
        img, angle = deskew_image(img)
        result.stages['deskew'] = round(time.time() - t, 2)
        logger.info(f"Stage 2 (deskew): {angle:.2f}\u00b0 in {result.stages['deskew']}s")
        # Stage 3: Dewarp
        if enable_dewarp:
            t = time.time()
            img, _dewarp_info = dewarp_image(img)
            result.stages['dewarp'] = round(time.time() - t, 2)
        # Stage 4: Dual image preparation
        # Separate binarized image for OCR and a cleaner one for layout.
        t = time.time()
        ocr_img = create_ocr_image(img)
        layout_img = create_layout_image(img)
        result.stages['image_prep'] = round(time.time() - t, 2)
        # Stage 5: Layout analysis
        t = time.time()
        regions = analyze_layout(layout_img, ocr_img)
        result.stages['layout'] = round(time.time() - t, 2)
        result.columns_detected = len([r for r in regions if r.type.startswith('column')])
        logger.info(f"Stage 5 (layout): {result.columns_detected} columns in {result.stages['layout']}s")
        # Stage 6: Multi-pass OCR
        t = time.time()
        ocr_results = run_multi_pass_ocr(ocr_img, regions, lang)
        result.stages['ocr'] = round(time.time() - t, 2)
        total_words = sum(len(w) for w in ocr_results.values())
        result.word_count = total_words
        logger.info(f"Stage 6 (OCR): {total_words} words in {result.stages['ocr']}s")
        # Stage 7: Line alignment
        t = time.time()
        vocab_rows = match_lines_to_vocab(ocr_results, regions)
        result.stages['alignment'] = round(time.time() - t, 2)
        # Stage 8: Optional LLM correction
        if enable_llm_correction:
            t = time.time()
            vocab_rows = await llm_post_correct(img, vocab_rows)
            result.stages['llm_correction'] = round(time.time() - t, 2)
        # Convert to output format
        # Rows with neither English nor German text are dropped.
        result.vocabulary = [
            {
                "english": row.english,
                "german": row.german,
                "example": row.example,
                "confidence": round(row.confidence, 1),
            }
            for row in vocab_rows
            if row.english or row.german
        ]
        result.duration_seconds = round(time.time() - total_start, 2)
        logger.info(f"CV Pipeline complete: {len(result.vocabulary)} entries in {result.duration_seconds}s")
    except Exception as e:
        # Top-level boundary: record the error in the result instead of
        # propagating, so callers always get a PipelineResult.
        logger.error(f"CV Pipeline error: {e}")
        import traceback
        logger.debug(traceback.format_exc())
        result.error = str(e)
        result.duration_seconds = round(time.time() - total_start, 2)
    return result

View File

@@ -0,0 +1,315 @@
"""
CV Review Spell — Rule-based OCR spell correction (no LLM).
Provides dictionary-backed digit-to-letter substitution, umlaut correction,
general spell correction, merged-word splitting, and page-ref normalization.
Lizenz: Apache 2.0 (kommerziell nutzbar)
DATENSCHUTZ: Alle Verarbeitung erfolgt lokal.
"""
import logging
import re
import time
from typing import Dict, List, Optional, Tuple
logger = logging.getLogger(__name__)
# Optional dependency: pyspellchecker provides the EN/DE dictionaries.
# distance=1 limits edit-distance search, keeping corrections fast and
# conservative. On ImportError the module degrades to no-op spell checks
# (_SPELL_AVAILABLE is consulted by every dictionary helper below).
try:
    from spellchecker import SpellChecker as _SpellChecker
    _en_spell = _SpellChecker(language='en', distance=1)
    _de_spell = _SpellChecker(language='de', distance=1)
    _SPELL_AVAILABLE = True
    logger.info("pyspellchecker loaded (EN+DE)")
except ImportError:
    _SPELL_AVAILABLE = False
    _en_spell = None  # type: ignore[assignment]
    _de_spell = None  # type: ignore[assignment]
    logger.warning("pyspellchecker not installed")
# ---- Page-Ref Normalization ----
# Normalizes OCR variants like "p-60", "p 61", "p60" -> "p.60"
_PAGE_REF_RE = re.compile(r'\bp[\s\-]?(\d+)', re.IGNORECASE)
def _normalize_page_ref(text: str) -> str:
    """Normalize page references: 'p-60' / 'p 61' / 'p60' -> 'p.60'."""
    if not text:
        return text

    def _as_dotted(match):
        # Always emits lowercase "p." regardless of the matched case.
        return "p." + match.group(1)

    return _PAGE_REF_RE.sub(_as_dotted, text)
# Suspicious OCR chars -> ordered list of most-likely correct replacements
# (order matters: the first dictionary-validated candidate wins).
_SPELL_SUBS: Dict[str, List[str]] = {
    '0': ['O', 'o'],
    '1': ['l', 'I'],
    '5': ['S', 's'],
    '6': ['G', 'g'],
    '8': ['B', 'b'],
    '|': ['I', 'l', '1'],
}
# Fast membership set of all suspicious characters.
_SPELL_SUSPICIOUS = frozenset(_SPELL_SUBS.keys())
# Tokenizer: word tokens (letters + pipe) alternating with separators
_SPELL_TOKEN_RE = re.compile(r'([A-Za-z\u00c4\u00d6\u00dc\u00e4\u00f6\u00fc\u00df|]+)([^A-Za-z\u00c4\u00d6\u00dc\u00e4\u00f6\u00fc\u00df|]*)')
def _spell_dict_knows(word: str) -> bool:
    """True if word is known in EN or DE dictionary."""
    if not _SPELL_AVAILABLE:
        return False
    lowered = word.lower()
    # any() short-circuits: the German lookup only runs on an EN miss.
    return any(bool(dictionary.known([lowered]))
               for dictionary in (_en_spell, _de_spell))
def _try_split_merged_word(token: str) -> Optional[str]:
    """Try to split a merged word like 'atmyschool' into 'at my school'.

    Uses dynamic programming to find the shortest sequence of dictionary
    words that covers the entire token. Only returns a result when the
    split produces at least 2 words and ALL parts are known dictionary words.
    Preserves original capitalisation by mapping back to the input string.

    Args:
        token: Alphabetic token suspected of being a merged word.

    Returns:
        Space-joined split string, or None if no full dictionary cover
        with >= 2 words exists (or the spellchecker is unavailable).
    """
    if not _SPELL_AVAILABLE or len(token) < 4:
        return None
    lower = token.lower()
    n = len(lower)
    # dp[i] = (word_lengths_list, score) for best split of lower[:i], or None
    dp: list = [None] * (n + 1)
    dp[0] = ([], 0)
    for i in range(1, n + 1):
        # Candidate words are capped at 20 characters.
        for j in range(max(0, i - 20), i):
            if dp[j] is None:
                continue
            candidate = lower[j:i]
            word_len = i - j
            # Single letters only count as words for 'a' and 'i'.
            if word_len == 1 and candidate not in ('a', 'i'):
                continue
            if _spell_dict_knows(candidate):
                prev_words, prev_sq = dp[j]
                new_words = prev_words + [word_len]
                new_sq = prev_sq + word_len * word_len
                # Preference key: fewer words first, then a larger sum of
                # squared lengths (favours splits with longer pieces).
                new_key = (-len(new_words), new_sq)
                if dp[i] is None:
                    dp[i] = (new_words, new_sq)
                else:
                    old_key = (-len(dp[i][0]), dp[i][1])
                    if new_key >= old_key:
                        dp[i] = (new_words, new_sq)
    # Require a full cover with at least two words.
    if dp[n] is None or len(dp[n][0]) < 2:
        return None
    # Re-slice the ORIGINAL token so capitalisation is preserved.
    result = []
    pos = 0
    for wlen in dp[n][0]:
        result.append(token[pos:pos + wlen])
        pos += wlen
    logger.debug("Split merged word: %r -> %r", token, " ".join(result))
    return " ".join(result)
def _spell_fix_token(token: str, field: str = "") -> Optional[str]:
    """Return corrected form of token, or None if no fix needed/possible.

    *field* is 'english' or 'german' -- used to pick the right dictionary.

    Correction cascade (first hit wins):
      1. Token already known -> no fix.
      2. Single digit/pipe -> letter substitution validated by the
         dictionary, plus a leading-capital heuristic ("5tadt" -> "Stadt").
      3. German umlaut confusion (a -> \u00e4 etc.), dictionary-validated.
      4. General distance-1 spell correction for alphabetic tokens.
      5. Dictionary-backed split of merged words ("atmyschool" ->
         "at my school").
    """
    has_suspicious = any(ch in _SPELL_SUSPICIOUS for ch in token)
    # 1. Already known word -> no fix needed
    if _spell_dict_knows(token):
        return None
    # 2. Digit/pipe substitution
    if has_suspicious:
        if token == '|':
            return 'I'
        for i, ch in enumerate(token):
            if ch not in _SPELL_SUBS:
                continue
            for replacement in _SPELL_SUBS[ch]:
                candidate = token[:i] + replacement + token[i + 1:]
                if _spell_dict_knows(candidate):
                    return candidate
        # Heuristic: a leading suspicious char followed by lowercase
        # letters is almost certainly a mis-read capital, even when the
        # result is not in the dictionary (proper nouns etc.).
        first = token[0]
        if first in _SPELL_SUBS and len(token) >= 2:
            rest = token[1:]
            if rest.isalpha() and rest.islower():
                # BUGFIX: removed the dead `isdigit()` guard on the
                # candidate -- every _SPELL_SUBS replacement is a letter,
                # so the guard could never reject anything.
                return _SPELL_SUBS[first][0] + rest
    # 3. OCR umlaut confusion
    if len(token) >= 3 and token.isalpha() and field == "german":
        _UMLAUT_SUBS = {'a': '\u00e4', 'o': '\u00f6', 'u': '\u00fc', 'i': '\u00fc',
                        'A': '\u00c4', 'O': '\u00d6', 'U': '\u00dc', 'I': '\u00dc'}
        for i, ch in enumerate(token):
            if ch in _UMLAUT_SUBS:
                candidate = token[:i] + _UMLAUT_SUBS[ch] + token[i + 1:]
                if _spell_dict_knows(candidate):
                    return candidate
    # 4. General spell correction for unknown words (no digits/pipes)
    if not has_suspicious and len(token) >= 3 and token.isalpha():
        spell = _en_spell if field == "english" else _de_spell if field == "german" else None
        if spell is not None:
            correction = spell.correction(token.lower())
            if correction and correction != token.lower():
                # Restore the original capitalisation of the first letter.
                if token[0].isupper():
                    correction = correction[0].upper() + correction[1:]
                if _spell_dict_knows(correction):
                    return correction
    # 5. Merged-word split
    if len(token) >= 4 and token.isalpha():
        split = _try_split_merged_word(token)
        if split:
            return split
    return None
def _spell_fix_field(text: str, field: str = "") -> Tuple[str, bool]:
    """Apply OCR corrections to a text field.

    Args:
        text: Raw field value.
        field: 'english' or 'german' -- forwarded to _spell_fix_token for
            dictionary selection.

    Returns:
        Tuple (fixed_text, was_changed).
    """
    if not text:
        return text, False
    has_suspicious = any(ch in text for ch in _SPELL_SUSPICIOUS)
    # Pure punctuation/digit fields without suspicious chars: nothing to do.
    if not has_suspicious and not any(c.isalpha() for c in text):
        return text, False
    # Pattern: | immediately before . or , -> numbered list prefix
    fixed = re.sub(r'(?<!\w)\|(?=[.,])', '1', text) if has_suspicious else text
    changed = fixed != text
    # Tokenize and fix word by word
    parts: List[str] = []
    pos = 0
    for m in _SPELL_TOKEN_RE.finditer(fixed):
        if m.start() > pos:
            # BUGFIX: the token regex only matches starting at a word
            # character, so a leading non-word prefix (e.g. "(" in
            # "(to) dance") was silently dropped. Preserve the gap.
            parts.append(fixed[pos:m.start()])
        token, sep = m.group(1), m.group(2)
        correction = _spell_fix_token(token, field=field)
        if correction:
            parts.append(correction)
            changed = True
        else:
            parts.append(token)
        parts.append(sep)
        pos = m.end()
    if pos < len(fixed):
        parts.append(fixed[pos:])
    return ''.join(parts), changed
def spell_review_entries_sync(entries: List[Dict]) -> Dict:
    """Rule-based OCR correction: spell-checker + structural heuristics.

    Deterministic -- never translates, never touches IPA, never hallucinates.
    Uses SmartSpellChecker for language-aware corrections with context-based
    disambiguation (a/I), multi-digit substitution, and cross-language guard.

    Args:
        entries: Vocabulary entry dicts ("english"/"german"/"example",
            optionally "row_index"/"source_page").

    Returns:
        Dict with keys: entries_original, entries_corrected, changes,
        skipped_count, model_used, duration_ms (same shape as the LLM path).
    """
    # Local import avoids a circular dependency between the review modules.
    from cv_review_llm import _entry_needs_review
    t0 = time.time()
    changes: List[Dict] = []
    all_corrected: List[Dict] = []
    # Use SmartSpellChecker if available
    _smart = None
    try:
        from smart_spell import SmartSpellChecker
        _smart = SmartSpellChecker()
        logger.debug("spell_review: using SmartSpellChecker")
    except Exception:
        # Best-effort: fall back to the legacy _spell_fix_field path.
        logger.debug("spell_review: SmartSpellChecker not available, using legacy")
    # Field -> SmartSpellChecker language code ("auto" = detect per text).
    _LANG_MAP = {"english": "en", "german": "de", "example": "auto"}
    for i, entry in enumerate(entries):
        e = dict(entry)
        # Page-ref normalization
        old_ref = (e.get("source_page") or "").strip()
        if old_ref:
            new_ref = _normalize_page_ref(old_ref)
            if new_ref != old_ref:
                changes.append({
                    "row_index": e.get("row_index", i),
                    "field": "source_page",
                    "old": old_ref,
                    "new": new_ref,
                })
                e["source_page"] = new_ref
                # Same flag as the LLM path so the UI renders uniformly.
                e["llm_corrected"] = True
        # Entries with IPA (or fully empty) skip the spell pass entirely.
        if not _entry_needs_review(e):
            all_corrected.append(e)
            continue
        for field_name in ("english", "german", "example"):
            old_val = (e.get(field_name) or "").strip()
            if not old_val:
                continue
            if _smart:
                lang_code = _LANG_MAP.get(field_name, "en")
                result = _smart.correct_text(old_val, lang=lang_code)
                new_val = result.corrected
                was_changed = result.changed
            else:
                # Legacy path treats the example column as German text.
                lang = "german" if field_name in ("german", "example") else "english"
                new_val, was_changed = _spell_fix_field(old_val, field=lang)
            if was_changed and new_val != old_val:
                changes.append({
                    "row_index": e.get("row_index", i),
                    "field": field_name,
                    "old": old_val,
                    "new": new_val,
                })
                e[field_name] = new_val
                e["llm_corrected"] = True
        all_corrected.append(e)
    duration_ms = int((time.time() - t0) * 1000)
    model_name = "smart-spell-checker" if _smart else "spell-checker"
    return {
        "entries_original": entries,
        "entries_corrected": all_corrected,
        "changes": changes,
        "skipped_count": 0,
        "model_used": model_name,
        "duration_ms": duration_ms,
    }
async def spell_review_entries_streaming(entries: List[Dict], batch_size: int = 50):
    """Async generator yielding SSE-compatible events for spell-checker review.

    Runs the synchronous spell review in one pass and emits three events:

    * ``meta``     -- announced before the review runs; the model is the
      generic ``"spell-checker"`` here because the concrete engine is only
      known after ``spell_review_entries_sync`` has executed.
    * ``batch``    -- a single batch covering every reviewed entry.
    * ``complete`` -- final summary with all changes and corrected entries.

    Args:
        entries: OCR vocabulary entries to review (list of dicts).
        batch_size: Advertised batch size in the ``meta`` event; the current
            implementation processes everything as one batch.

    Yields:
        Dict events with a ``type`` key (``meta`` / ``batch`` / ``complete``).
    """
    total = len(entries)
    yield {
        "type": "meta",
        "total_entries": total,
        "to_review": total,
        "skipped": 0,
        "model": "spell-checker",
        "batch_size": batch_size,
    }
    result = spell_review_entries_sync(entries)
    changes = result["changes"]
    yield {
        "type": "batch",
        "batch_index": 0,
        "entries_reviewed": [e.get("row_index", i) for i, e in enumerate(entries)],
        "changes": changes,
        "duration_ms": result["duration_ms"],
        "progress": {"current": total, "total": total},
    }
    yield {
        "type": "complete",
        "changes": changes,
        # BUG FIX: was hard-coded to "spell-checker", misreporting runs where
        # the sync pass used SmartSpellChecker ("smart-spell-checker").
        # Propagate the model the sync pass actually used.
        "model_used": result["model_used"],
        "duration_ms": result["duration_ms"],
        "total_entries": total,
        "reviewed": total,
        "skipped": 0,
        "corrections_found": len(changes),
        "entries_corrected": result["entries_corrected"],
    }

View File

@@ -0,0 +1,492 @@
"""
Grid Editor — column detection, cross-column splitting, marker merging.
Split from grid_editor_helpers.py for maintainability.
All functions are pure computation — no HTTP, DB, or session side effects.
Lizenz: Apache 2.0 (kommerziell nutzbar)
DATENSCHUTZ: Alle Verarbeitung erfolgt lokal.
"""
import logging
import re
from typing import Any, Dict, List, Optional
logger = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
# Cross-column word splitting
# ---------------------------------------------------------------------------
_spell_cache: Optional[Any] = None
_spell_loaded = False
def _is_recognized_word(text: str) -> bool:
"""Check if *text* is a recognized German or English word.
Uses the spellchecker library (same as cv_syllable_detect.py).
Returns True for real words like "oder", "Kabel", "Zeitung".
Returns False for OCR merge artifacts like "sichzie", "dasZimmer".
"""
global _spell_cache, _spell_loaded
if not text or len(text) < 2:
return False
if not _spell_loaded:
_spell_loaded = True
try:
from spellchecker import SpellChecker
_spell_cache = SpellChecker(language="de")
except Exception:
pass
if _spell_cache is None:
return False
return text.lower() in _spell_cache
def _split_cross_column_words(
    words: List[Dict],
    columns: List[Dict],
) -> List[Dict]:
    """Split word boxes that span across column boundaries.

    When OCR merges adjacent words from different columns (e.g. "sichzie"
    spanning Col 1 and Col 2, or "dasZimmer" crossing the boundary),
    split the word box at the column boundary so each piece is assigned
    to the correct column.

    Only splits when:
    - The word has significant overlap (>15% of its width) on both sides
    - AND the word is not a recognized real word (OCR merge artifact), OR
      the word contains a case transition (lowercase->uppercase) near the
      boundary indicating two merged words like "dasZimmer".

    Args:
        words: Word dicts with ``text``/``left``/``top``/``width``/``height``.
        columns: Column dicts with ``x_min``/``x_max``, ordered left-to-right.

    Returns:
        New list of word dicts; split words are replaced by two part-words,
        all other word dicts are passed through unchanged.
    """
    if len(columns) < 2:
        return words
    # Column boundaries = midpoints between adjacent column edges
    boundaries = []
    for i in range(len(columns) - 1):
        boundary = (columns[i]["x_max"] + columns[i + 1]["x_min"]) / 2
        boundaries.append(boundary)
    new_words: List[Dict] = []
    split_count = 0
    for w in words:
        w_left = w["left"]
        w_width = w["width"]
        w_right = w_left + w_width
        text = (w.get("text") or "").strip()
        # Very short or very narrow words cannot plausibly be two merged words.
        if not text or len(text) < 4 or w_width < 10:
            new_words.append(w)
            continue
        # Find the first boundary this word straddles significantly
        split_boundary = None
        for b in boundaries:
            if w_left < b < w_right:
                left_part = b - w_left
                right_part = w_right - b
                # Both sides must have at least 15% of the word width
                if left_part > w_width * 0.15 and right_part > w_width * 0.15:
                    split_boundary = b
                    break
        if split_boundary is None:
            new_words.append(w)
            continue
        # Compute approximate split position in the text (characters are
        # assumed roughly equal-width within one word box).
        left_width = split_boundary - w_left
        split_ratio = left_width / w_width
        approx_pos = len(text) * split_ratio
        # Strategy 1: look for a case transition (lowercase->uppercase) near
        # the approximate split point — e.g. "dasZimmer" splits at 'Z'.
        split_char = None
        search_lo = max(1, int(approx_pos) - 3)
        search_hi = min(len(text), int(approx_pos) + 2)
        for i in range(search_lo, search_hi):
            if text[i - 1].islower() and text[i].isupper():
                split_char = i
                break
        # Strategy 2: if no case transition, only split if the whole word
        # is NOT a real word (i.e. it's an OCR merge artifact like "sichzie").
        # Real words like "oder", "Kabel", "Zeitung" must not be split.
        if split_char is None:
            clean = re.sub(r"[,;:.!?]+$", "", text)  # strip trailing punct
            if _is_recognized_word(clean):
                new_words.append(w)
                continue
            # Not a real word — use floor of proportional position
            split_char = max(1, min(len(text) - 1, int(approx_pos)))
        left_text = text[:split_char].rstrip()
        right_text = text[split_char:].lstrip()
        # Refuse degenerate splits that would create 1-char fragments.
        if len(left_text) < 2 or len(right_text) < 2:
            new_words.append(w)
            continue
        # Pixel width is distributed proportionally; the rounding remainder
        # goes to the right-hand part so the two widths sum to the original.
        right_width = w_width - round(left_width)
        new_words.append({
            **w,
            "text": left_text,
            "width": round(left_width),
        })
        new_words.append({
            **w,
            "text": right_text,
            "left": round(split_boundary),
            "width": right_width,
        })
        split_count += 1
        logger.info(
            "split cross-column word %r -> %r + %r at boundary %.0f",
            text, left_text, right_text, split_boundary,
        )
    if split_count:
        logger.info("split %d cross-column word(s)", split_count)
    return new_words
def _cluster_columns_by_alignment(
    words: List[Dict],
    zone_w: int,
    rows: List[Dict],
) -> List[Dict[str, Any]]:
    """Detect columns by clustering left-edge alignment across rows.

    Hybrid approach:
    1. Group words by row, find "group start" positions within each row
       (words preceded by a large gap or first word in row)
    2. Cluster group-start left-edges by X-proximity across rows
    3. Filter by row coverage (how many rows have a group start here)
    4. Merge nearby clusters
    5. Build column boundaries

    This filters out mid-phrase word positions (e.g. IPA transcriptions,
    second words in multi-word entries) by only considering positions
    where a new word group begins within a row.

    Args:
        words: Word dicts with ``left``/``top``/``width``/``height``.
        zone_w: Zone width in pixels; relative thresholds derive from it.
        rows: Row dicts with ``index`` and ``y_center``.

    Returns:
        List of column dicts (``index``/``type``/``x_min``/``x_max``) ordered
        left-to-right; empty list when *words* or *rows* is empty; a single
        full-width column as fallback when no cluster survives filtering.
    """
    if not words or not rows:
        return []
    total_rows = len(rows)
    if total_rows == 0:
        return []
    # --- Group words by row ---
    # Each word is assigned to the row whose y_center is nearest to the
    # word's own vertical centre.
    row_words: Dict[int, List[Dict]] = {}
    for w in words:
        y_center = w["top"] + w["height"] / 2
        best = min(rows, key=lambda r: abs(r["y_center"] - y_center))
        row_words.setdefault(best["index"], []).append(w)
    # --- Compute adaptive gap threshold for group-start detection ---
    all_gaps: List[float] = []
    for ri, rw_list in row_words.items():
        sorted_rw = sorted(rw_list, key=lambda w: w["left"])
        for i in range(len(sorted_rw) - 1):
            right = sorted_rw[i]["left"] + sorted_rw[i]["width"]
            gap = sorted_rw[i + 1]["left"] - right
            if gap > 0:
                all_gaps.append(gap)
    if all_gaps:
        sorted_gaps = sorted(all_gaps)
        median_gap = sorted_gaps[len(sorted_gaps) // 2]
        heights = [w["height"] for w in words if w.get("height", 0) > 0]
        median_h = sorted(heights)[len(heights) // 2] if heights else 25
        # For small word counts (boxes, sub-zones): PaddleOCR returns
        # multi-word blocks, so ALL inter-word gaps are potential column
        # boundaries. Use a low threshold based on word height — any gap
        # wider than ~1x median word height is a column separator.
        if len(words) <= 60:
            gap_threshold = max(median_h * 1.0, 25)
            logger.info(
                "alignment columns (small zone): gap_threshold=%.0f "
                "(median_h=%.0f, %d words, %d gaps: %s)",
                gap_threshold, median_h, len(words), len(sorted_gaps),
                [int(g) for g in sorted_gaps[:10]],
            )
        else:
            # Standard approach for large zones (full pages)
            gap_threshold = max(median_gap * 3, median_h * 1.5, 30)
        # Cap at 25% of zone width
        max_gap = zone_w * 0.25
        if gap_threshold > max_gap > 30:
            logger.info("alignment columns: capping gap_threshold %.0f -> %.0f (25%% of zone_w=%d)", gap_threshold, max_gap, zone_w)
            gap_threshold = max_gap
    else:
        # No measurable inter-word gaps (e.g. one word per row): fixed fallback.
        gap_threshold = 50
    # --- Find group-start positions (left-edges that begin a new column) ---
    start_positions: List[tuple] = []  # (left_edge, row_index)
    for ri, rw_list in row_words.items():
        sorted_rw = sorted(rw_list, key=lambda w: w["left"])
        # First word in row is always a group start
        start_positions.append((sorted_rw[0]["left"], ri))
        for i in range(1, len(sorted_rw)):
            right_prev = sorted_rw[i - 1]["left"] + sorted_rw[i - 1]["width"]
            gap = sorted_rw[i]["left"] - right_prev
            if gap >= gap_threshold:
                start_positions.append((sorted_rw[i]["left"], ri))
    start_positions.sort(key=lambda x: x[0])
    logger.info(
        "alignment columns: %d group-start positions from %d words "
        "(gap_threshold=%.0f, %d rows)",
        len(start_positions), len(words), gap_threshold, total_rows,
    )
    if not start_positions:
        x_min = min(w["left"] for w in words)
        x_max = max(w["left"] + w["width"] for w in words)
        return [{"index": 0, "type": "column_text", "x_min": x_min, "x_max": x_max}]
    # --- Cluster group-start positions by X-proximity ---
    # Single left-to-right sweep: edges within `tolerance` of the previous
    # edge join the current cluster, otherwise the cluster is flushed.
    tolerance = max(10, int(zone_w * 0.01))
    clusters: List[Dict[str, Any]] = []
    cur_edges = [start_positions[0][0]]
    cur_rows = {start_positions[0][1]}
    for left, row_idx in start_positions[1:]:
        if left - cur_edges[-1] <= tolerance:
            cur_edges.append(left)
            cur_rows.add(row_idx)
        else:
            clusters.append({
                "mean_x": int(sum(cur_edges) / len(cur_edges)),
                "min_edge": min(cur_edges),
                "max_edge": max(cur_edges),
                "count": len(cur_edges),
                "distinct_rows": len(cur_rows),
                "row_coverage": len(cur_rows) / total_rows,
            })
            cur_edges = [left]
            cur_rows = {row_idx}
    # Flush the final cluster left open by the sweep above.
    clusters.append({
        "mean_x": int(sum(cur_edges) / len(cur_edges)),
        "min_edge": min(cur_edges),
        "max_edge": max(cur_edges),
        "count": len(cur_edges),
        "distinct_rows": len(cur_rows),
        "row_coverage": len(cur_rows) / total_rows,
    })
    # --- Filter by row coverage ---
    # These thresholds must be high enough to avoid false columns in flowing
    # text (random inter-word gaps) while still detecting real columns in
    # vocabulary worksheets (which typically have >80% row coverage).
    MIN_COVERAGE_PRIMARY = 0.35
    MIN_COVERAGE_SECONDARY = 0.12
    MIN_WORDS_SECONDARY = 4
    MIN_DISTINCT_ROWS = 3
    # Content boundary for left-margin detection
    content_x_min = min(w["left"] for w in words)
    content_x_max = max(w["left"] + w["width"] for w in words)
    content_span = content_x_max - content_x_min
    primary = [
        c for c in clusters
        if c["row_coverage"] >= MIN_COVERAGE_PRIMARY
        and c["distinct_rows"] >= MIN_DISTINCT_ROWS
    ]
    primary_ids = {id(c) for c in primary}
    secondary = [
        c for c in clusters
        if id(c) not in primary_ids
        and c["row_coverage"] >= MIN_COVERAGE_SECONDARY
        and c["count"] >= MIN_WORDS_SECONDARY
        and c["distinct_rows"] >= MIN_DISTINCT_ROWS
    ]
    # Tertiary: narrow left-margin columns (page refs, markers) that have
    # too few rows for secondary but are clearly left-aligned and separated
    # from the main content. These appear at the far left or far right and
    # have a large gap to the nearest significant cluster.
    used_ids = {id(c) for c in primary} | {id(c) for c in secondary}
    sig_xs = [c["mean_x"] for c in primary + secondary]
    # Tertiary: clusters that are clearly to the LEFT of the first
    # significant column (or RIGHT of the last). If words consistently
    # start at a position left of the established first column boundary,
    # they MUST be a separate column — regardless of how few rows they
    # cover. The only requirement is a clear spatial gap.
    MIN_COVERAGE_TERTIARY = 0.02  # at least 1 row effectively
    tertiary = []
    for c in clusters:
        if id(c) in used_ids:
            continue
        if c["distinct_rows"] < 1:
            continue
        if c["row_coverage"] < MIN_COVERAGE_TERTIARY:
            continue
        # Must be near left or right content margin (within 15%)
        rel_pos = (c["mean_x"] - content_x_min) / content_span if content_span else 0.5
        if not (rel_pos < 0.15 or rel_pos > 0.85):
            continue
        # Must have significant gap to nearest significant cluster
        if sig_xs:
            min_dist = min(abs(c["mean_x"] - sx) for sx in sig_xs)
            if min_dist < max(30, content_span * 0.02):
                continue
        tertiary.append(c)
    if tertiary:
        for c in tertiary:
            logger.info(
                " tertiary (margin) cluster: x=%d (range %d-%d), %d words, %d rows (%.0f%%)",
                c["mean_x"], c["min_edge"], c["max_edge"],
                c["count"], c["distinct_rows"], c["row_coverage"] * 100,
            )
    significant = sorted(primary + secondary + tertiary, key=lambda c: c["mean_x"])
    for c in significant:
        logger.info(
            " significant cluster: x=%d (range %d-%d), %d words, %d rows (%.0f%%)",
            c["mean_x"], c["min_edge"], c["max_edge"],
            c["count"], c["distinct_rows"], c["row_coverage"] * 100,
        )
    logger.info(
        "alignment columns: %d clusters, %d primary, %d secondary -> %d significant",
        len(clusters), len(primary), len(secondary), len(significant),
    )
    if not significant:
        # Fallback: single column covering all content
        x_min = min(w["left"] for w in words)
        x_max = max(w["left"] + w["width"] for w in words)
        return [{"index": 0, "type": "column_text", "x_min": x_min, "x_max": x_max}]
    # --- Merge nearby clusters ---
    merge_distance = max(25, int(zone_w * 0.03))
    merged = [significant[0].copy()]
    for s in significant[1:]:
        if s["mean_x"] - merged[-1]["mean_x"] < merge_distance:
            # Absorb into previous cluster: centre is the count-weighted mean.
            prev = merged[-1]
            total = prev["count"] + s["count"]
            prev["mean_x"] = (
                prev["mean_x"] * prev["count"] + s["mean_x"] * s["count"]
            ) // total
            prev["count"] = total
            prev["min_edge"] = min(prev["min_edge"], s["min_edge"])
            prev["max_edge"] = max(prev["max_edge"], s["max_edge"])
            prev["distinct_rows"] = max(prev["distinct_rows"], s["distinct_rows"])
        else:
            merged.append(s.copy())
    logger.info(
        "alignment columns: %d after merge (distance=%d)",
        len(merged), merge_distance,
    )
    # --- Build column boundaries ---
    # Each column starts slightly left of its cluster's min edge and ends
    # just before the next cluster (last column extends to content edge).
    margin = max(5, int(zone_w * 0.005))
    content_x_min = min(w["left"] for w in words)
    content_x_max = max(w["left"] + w["width"] for w in words)
    columns: List[Dict[str, Any]] = []
    for i, cluster in enumerate(merged):
        x_min = max(content_x_min, cluster["min_edge"] - margin)
        if i + 1 < len(merged):
            x_max = merged[i + 1]["min_edge"] - margin
        else:
            x_max = content_x_max
        columns.append({
            "index": i,
            "type": f"column_{i + 1}" if len(merged) > 1 else "column_text",
            "x_min": x_min,
            "x_max": x_max,
        })
    return columns
_MARKER_CHARS = set("*-+#>")
def _merge_inline_marker_columns(
columns: List[Dict],
words: List[Dict],
) -> List[Dict]:
"""Merge narrow marker columns (bullets, numbering) into adjacent text.
Bullet points (*, -) and numbering (1., 2.) create narrow columns
at the left edge of a zone. These are inline markers that indent text,
not real separate columns. Merge them with their right neighbour.
Does NOT merge columns containing alphabetic words like "to", "in",
"der", "die", "das" — those are legitimate content columns.
"""
if len(columns) < 2:
return columns
merged: List[Dict] = []
skip: set = set()
for i, col in enumerate(columns):
if i in skip:
continue
# Find words in this column
col_words = [
w for w in words
if col["x_min"] <= w["left"] + w["width"] / 2 < col["x_max"]
]
col_width = col["x_max"] - col["x_min"]
# Narrow column with mostly short words -> MIGHT be inline markers
if col_words and col_width < 80:
avg_len = sum(len(w.get("text", "")) for w in col_words) / len(col_words)
if avg_len <= 2 and i + 1 < len(columns):
# Check if words are actual markers (symbols/numbers) vs
# real alphabetic words like "to", "in", "der", "die"
texts = [(w.get("text") or "").strip() for w in col_words]
alpha_count = sum(
1 for t in texts
if t and t[0].isalpha() and t not in _MARKER_CHARS
)
alpha_ratio = alpha_count / len(texts) if texts else 0
# If >=50% of words are alphabetic, this is a real column
if alpha_ratio >= 0.5:
logger.info(
" kept narrow column %d (w=%d, avg_len=%.1f, "
"alpha=%.0f%%) -- contains real words",
i, col_width, avg_len, alpha_ratio * 100,
)
else:
# Merge into next column
next_col = columns[i + 1].copy()
next_col["x_min"] = col["x_min"]
merged.append(next_col)
skip.add(i + 1)
logger.info(
" merged inline marker column %d (w=%d, avg_len=%.1f) "
"into column %d",
i, col_width, avg_len, i + 1,
)
continue
merged.append(col)
# Re-index
for i, col in enumerate(merged):
col["index"] = i
col["type"] = f"column_{i + 1}" if len(merged) > 1 else "column_text"
return merged

View File

@@ -0,0 +1,402 @@
"""
Grid Editor — word/zone filtering, border ghosts, decorative margins, footers.
Split from grid_editor_helpers.py for maintainability.
All functions are pure computation — no HTTP, DB, or session side effects.
Lizenz: Apache 2.0 (kommerziell nutzbar)
DATENSCHUTZ: Alle Verarbeitung erfolgt lokal.
"""
import logging
from typing import Any, Dict, List, Optional, Tuple
logger = logging.getLogger(__name__)
def _filter_border_strip_words(words: List[Dict]) -> Tuple[List[Dict], int]:
"""Remove page-border decoration strip words BEFORE column detection.
Scans from each page edge inward to find the first significant x-gap
(>30 px). If the edge cluster contains <15 % of total words, those
words are removed as border-strip artifacts (alphabet letters,
illustration fragments).
Must run BEFORE ``_build_zone_grid`` so that column detection only
sees real content words and doesn't produce inflated row counts.
"""
if len(words) < 10:
return words, 0
sorted_words = sorted(words, key=lambda w: w.get("left", 0))
total = len(sorted_words)
# -- Left-edge scan (running max right-edge) --
left_count = 0
running_right = 0
for gi in range(total - 1):
running_right = max(
running_right,
sorted_words[gi].get("left", 0) + sorted_words[gi].get("width", 0),
)
if sorted_words[gi + 1].get("left", 0) - running_right > 30:
left_count = gi + 1
break
# -- Right-edge scan (running min left) --
right_count = 0
running_left = sorted_words[-1].get("left", 0)
for gi in range(total - 1, 0, -1):
running_left = min(running_left, sorted_words[gi].get("left", 0))
prev_right = (
sorted_words[gi - 1].get("left", 0)
+ sorted_words[gi - 1].get("width", 0)
)
if running_left - prev_right > 30:
right_count = total - gi
break
# Validate candidate strip: real border decorations are mostly short
# words (alphabet letters like "A", "Bb", stray marks). Multi-word
# content like "der Ranzen" or "die Schals" (continuation of German
# translations) must NOT be removed.
def _is_decorative_strip(candidates: List[Dict]) -> bool:
if not candidates:
return False
short = sum(1 for w in candidates if len((w.get("text") or "").strip()) <= 2)
return short / len(candidates) >= 0.45
strip_ids: set = set()
if left_count > 0 and left_count / total < 0.20:
candidates = sorted_words[:left_count]
if _is_decorative_strip(candidates):
strip_ids = {id(w) for w in candidates}
elif right_count > 0 and right_count / total < 0.20:
candidates = sorted_words[total - right_count:]
if _is_decorative_strip(candidates):
strip_ids = {id(w) for w in candidates}
if not strip_ids:
return words, 0
return [w for w in words if id(w) not in strip_ids], len(strip_ids)
# Characters that are typically OCR artefacts from box border lines.
# Intentionally excludes ! (red markers) and . , ; (real punctuation).
_GRID_GHOST_CHARS = set("|1lI[](){}/\\-\u2014\u2013_~=+")
def _filter_border_ghosts(
words: List[Dict],
boxes: List,
) -> tuple:
"""Remove words sitting on box borders that are OCR artefacts.
Returns (filtered_words, removed_count).
"""
if not boxes or not words:
return words, 0
# Build border bands from detected boxes
x_bands: List[tuple] = []
y_bands: List[tuple] = []
for b in boxes:
bt = (
b.border_thickness
if hasattr(b, "border_thickness")
else b.get("border_thickness", 3)
)
# Skip borderless boxes (images/graphics) -- no border line to produce ghosts
if bt == 0:
continue
bx = b.x if hasattr(b, "x") else b.get("x", 0)
by = b.y if hasattr(b, "y") else b.get("y", 0)
bw = b.width if hasattr(b, "width") else b.get("w", b.get("width", 0))
bh = b.height if hasattr(b, "height") else b.get("h", b.get("height", 0))
margin = max(bt * 2, 10) + 6
x_bands.append((bx - margin, bx + margin))
x_bands.append((bx + bw - margin, bx + bw + margin))
y_bands.append((by - margin, by + margin))
y_bands.append((by + bh - margin, by + bh + margin))
def _is_ghost(w: Dict) -> bool:
text = (w.get("text") or "").strip()
if not text:
return False
# Check if any word edge (not just center) touches a border band
w_left = w["left"]
w_right = w["left"] + w["width"]
w_top = w["top"]
w_bottom = w["top"] + w["height"]
on_border = (
any(lo <= w_left <= hi or lo <= w_right <= hi for lo, hi in x_bands)
or any(lo <= w_top <= hi or lo <= w_bottom <= hi for lo, hi in y_bands)
)
if not on_border:
return False
if len(text) == 1 and text in _GRID_GHOST_CHARS:
return True
return False
filtered = [w for w in words if not _is_ghost(w)]
return filtered, len(words) - len(filtered)
def _flatten_word_boxes(cells: List[Dict]) -> List[Dict]:
"""Extract all word_boxes from cells into a flat list of word dicts."""
words: List[Dict] = []
for cell in cells:
for wb in cell.get("word_boxes") or []:
if wb.get("text", "").strip():
words.append({
"text": wb["text"],
"left": wb["left"],
"top": wb["top"],
"width": wb["width"],
"height": wb["height"],
"conf": wb.get("conf", 0),
})
return words
def _words_in_zone(
words: List[Dict],
zone_y: int,
zone_h: int,
zone_x: int,
zone_w: int,
) -> List[Dict]:
"""Filter words whose Y-center falls within a zone's bounds."""
zone_y_end = zone_y + zone_h
zone_x_end = zone_x + zone_w
result = []
for w in words:
cy = w["top"] + w["height"] / 2
cx = w["left"] + w["width"] / 2
if zone_y <= cy <= zone_y_end and zone_x <= cx <= zone_x_end:
result.append(w)
return result
def _get_content_bounds(words: List[Dict]) -> tuple:
"""Get content bounds from word positions."""
if not words:
return 0, 0, 0, 0
x_min = min(w["left"] for w in words)
y_min = min(w["top"] for w in words)
x_max = max(w["left"] + w["width"] for w in words)
y_max = max(w["top"] + w["height"] for w in words)
return x_min, y_min, x_max - x_min, y_max - y_min
def _filter_decorative_margin(
    words: List[Dict],
    img_w: int,
    log: Any,
    session_id: str,
) -> Dict[str, Any]:
    """Remove words that belong to a decorative alphabet strip on a margin.

    Some vocabulary worksheets have a vertical A-Z alphabet graphic along
    the left or right edge. OCR reads each letter as an isolated single-
    character word. These decorative elements are not content and confuse
    column/row detection.

    Detection criteria (phase 1 -- find the strip using single-char words):
    - Words are in the outer 30% of the page (left or right)
    - Nearly all words are single characters (letters or digits)
    - At least 8 such words form a vertical strip (>=8 unique Y positions)
    - Average horizontal spread of the strip is small (< 80px)

    Phase 2 -- once a strip is confirmed, also remove any short word (<=3
    chars) in the same narrow x-range. This catches multi-char OCR
    artifacts like "Vv" that belong to the same decorative element.

    Modifies *words* in place. Only the FIRST side (left checked before
    right) that yields a confirmed strip is removed per call.

    Args:
        words: Word dicts; filtered in place when a strip is confirmed.
        img_w: Page width in pixels.
        log: Logger used for the removal notice.
        session_id: Session identifier used for log correlation.

    Returns:
        Dict with 'found' (bool), 'side' (str), 'letters_detected' (int).
    """
    no_strip: Dict[str, Any] = {"found": False, "side": "", "letters_detected": 0}
    if not words or img_w <= 0:
        return no_strip
    margin_cutoff = img_w * 0.30
    # Phase 1: find candidate strips using short words (1-2 chars).
    # OCR often reads alphabet sidebar letters as pairs ("Aa", "Bb")
    # rather than singles, so accept <=2-char words as strip candidates.
    left_strip = [
        w for w in words
        if len((w.get("text") or "").strip()) <= 2
        and w["left"] + w.get("width", 0) / 2 < margin_cutoff
    ]
    right_strip = [
        w for w in words
        if len((w.get("text") or "").strip()) <= 2
        and w["left"] + w.get("width", 0) / 2 > img_w - margin_cutoff
    ]
    for strip, side in [(left_strip, "left"), (right_strip, "right")]:
        if len(strip) < 6:
            continue
        # Check vertical distribution: should have many distinct Y positions
        # (Y centres are bucketed to 20 px so near-identical rows collapse).
        y_centers = sorted(set(
            int(w["top"] + w.get("height", 0) / 2) // 20 * 20  # bucket
            for w in strip
        ))
        if len(y_centers) < 6:
            continue
        # Check horizontal compactness
        x_positions = [w["left"] for w in strip]
        x_min = min(x_positions)
        x_max = max(x_positions)
        x_spread = x_max - x_min
        if x_spread > 80:
            continue
        # Phase 2: strip confirmed -- also collect short words in same x-range
        # Expand x-range slightly to catch neighbors (e.g. "Vv" next to "U")
        strip_x_lo = x_min - 20
        strip_x_hi = x_max + 60  # word width + tolerance
        all_strip_words = [
            w for w in words
            if len((w.get("text") or "").strip()) <= 3
            and strip_x_lo <= w["left"] <= strip_x_hi
            and (w["left"] + w.get("width", 0) / 2 < margin_cutoff
                 if side == "left"
                 else w["left"] + w.get("width", 0) / 2 > img_w - margin_cutoff)
        ]
        # Removal is identity-based so duplicate dict values elsewhere on
        # the page are unaffected.
        strip_set = set(id(w) for w in all_strip_words)
        before = len(words)
        words[:] = [w for w in words if id(w) not in strip_set]
        removed = before - len(words)
        if removed:
            log.info(
                "build-grid session %s: removed %d decorative %s-margin words "
                "(strip x=%d-%d)",
                session_id, removed, side, strip_x_lo, strip_x_hi,
            )
        return {"found": True, "side": side, "letters_detected": len(strip)}
    return no_strip
def _filter_footer_words(
words: List[Dict],
img_h: int,
log: Any,
session_id: str,
) -> Optional[Dict]:
"""Remove isolated words in the bottom 5% of the page (page numbers).
Modifies *words* in place and returns a page_number metadata dict
if a page number was extracted, or None.
"""
if not words or img_h <= 0:
return None
footer_y = img_h * 0.95
footer_words = [
w for w in words
if w["top"] + w.get("height", 0) / 2 > footer_y
]
if not footer_words:
return None
# Only remove if footer has very few words (<= 3) with short text
total_text = "".join((w.get("text") or "").strip() for w in footer_words)
if len(footer_words) <= 3 and len(total_text) <= 10:
# Extract page number metadata before removing
page_number_info = {
"text": total_text.strip(),
"y_pct": round(footer_words[0]["top"] / img_h * 100, 1),
}
# Try to parse as integer
digits = "".join(c for c in total_text if c.isdigit())
if digits:
page_number_info["number"] = int(digits)
footer_set = set(id(w) for w in footer_words)
words[:] = [w for w in words if id(w) not in footer_set]
log.info(
"build-grid session %s: extracted page number '%s' and removed %d footer words",
session_id, total_text, len(footer_words),
)
return page_number_info
return None
def _filter_header_junk(
words: List[Dict],
img_h: int,
log: Any,
session_id: str,
) -> None:
"""Remove OCR junk from header illustrations above the real content.
Textbook pages often have decorative header graphics (illustrations,
icons) that OCR reads as low-confidence junk characters. Real content
typically starts further down the page.
Algorithm:
1. Find the "content start" -- the first Y position where a dense
horizontal row of 3+ high-confidence words begins.
2. Above that line, remove words with conf < 75 and text <= 3 chars.
These are almost certainly OCR artifacts from illustrations.
Modifies *words* in place.
"""
if not words or img_h <= 0:
return
# --- Find content start: first horizontal row with >=3 high-conf words ---
# Sort words by Y
sorted_by_y = sorted(words, key=lambda w: w["top"])
content_start_y = 0
_ROW_TOLERANCE = img_h * 0.02 # words within 2% of page height = same row
_MIN_ROW_WORDS = 3
_MIN_CONF = 80
i = 0
while i < len(sorted_by_y):
row_y = sorted_by_y[i]["top"]
# Collect words in this row band
row_words = []
j = i
while j < len(sorted_by_y) and sorted_by_y[j]["top"] - row_y < _ROW_TOLERANCE:
row_words.append(sorted_by_y[j])
j += 1
# Count high-confidence words with real text (> 1 char)
high_conf = [
w for w in row_words
if w.get("conf", 0) >= _MIN_CONF
and len((w.get("text") or "").strip()) > 1
]
if len(high_conf) >= _MIN_ROW_WORDS:
content_start_y = row_y
break
i = j if j > i else i + 1
if content_start_y <= 0:
return # no clear content start found
# --- Remove low-conf short junk above content start ---
junk = [
w for w in words
if w["top"] + w.get("height", 0) < content_start_y
and w.get("conf", 0) < 75
and len((w.get("text") or "").strip()) <= 3
]
if not junk:
return
junk_set = set(id(w) for w in junk)
before = len(words)
words[:] = [w for w in words if id(w) not in junk_set]
removed = before - len(words)
if removed:
log.info(
"build-grid session %s: removed %d header junk words above y=%d "
"(content start)",
session_id, removed, content_start_y,
)

View File

@@ -0,0 +1,499 @@
"""
Grid Editor — header/heading detection and colspan (merged cell) detection.
Split from grid_editor_helpers.py. Pure computation, no HTTP/DB side effects.
Lizenz: Apache 2.0 | DATENSCHUTZ: Alle Verarbeitung erfolgt lokal.
"""
import logging
import re
from typing import Any, Dict, List, Optional
from cv_ocr_engines import _text_has_garbled_ipa
logger = logging.getLogger(__name__)
def _detect_heading_rows_by_color(zones_data: List[Dict], img_w: int, img_h: int) -> int:
    """Detect heading rows by color + height after color annotation.

    A row is a heading if:
    1. ALL word_boxes have color_name != 'black' (typically 'blue')
    2. Mean word height > 1.2x median height of all words in the zone

    Detected heading rows are merged into a single spanning cell.

    Args:
        zones_data: Zone dicts with ``cells``/``rows``/``columns``; the
            ``cells`` and ``rows`` lists are mutated in place.
        img_w: Page width in pixels (for the percent bbox; 0 yields 0%).
        img_h: Page height in pixels (for the percent bbox; 0 yields 0%).

    Returns count of headings detected.
    """
    heading_count = 0
    for z in zones_data:
        cells = z.get("cells", [])
        rows = z.get("rows", [])
        columns = z.get("columns", [])
        # Heading detection only makes sense in multi-column zones.
        if not cells or not rows or len(columns) < 2:
            continue
        # Compute median word height across the zone
        all_heights = []
        for cell in cells:
            for wb in cell.get("word_boxes") or []:
                h = wb.get("height", 0)
                if h > 0:
                    all_heights.append(h)
        if not all_heights:
            continue
        all_heights_sorted = sorted(all_heights)
        median_h = all_heights_sorted[len(all_heights_sorted) // 2]
        heading_row_indices = []
        for row in rows:
            if row.get("is_header"):
                continue  # already detected as header
            ri = row["index"]
            row_cells = [c for c in cells if c.get("row_index") == ri]
            row_wbs = [
                wb for cell in row_cells
                for wb in cell.get("word_boxes") or []
            ]
            if not row_wbs:
                continue
            # Condition 1: ALL words are non-black
            all_colored = all(
                wb.get("color_name", "black") != "black"
                for wb in row_wbs
            )
            if not all_colored:
                continue
            # Condition 2: mean height > 1.2x median
            mean_h = sum(wb.get("height", 0) for wb in row_wbs) / len(row_wbs)
            if mean_h <= median_h * 1.2:
                continue
            heading_row_indices.append(ri)
        # Merge heading cells into spanning cells
        for hri in heading_row_indices:
            header_cells = [c for c in cells if c.get("row_index") == hri]
            if len(header_cells) <= 1:
                # Single cell -- just mark it as heading
                if header_cells:
                    header_cells[0]["col_type"] = "heading"
                    heading_count += 1
                # Mark row as header
                for row in rows:
                    if row["index"] == hri:
                        row["is_header"] = True
                continue
            # Collect all word_boxes and text from all columns
            all_wb = []
            all_text_parts = []
            for hc in sorted(header_cells, key=lambda c: c["col_index"]):
                all_wb.extend(hc.get("word_boxes", []))
                if hc.get("text", "").strip():
                    all_text_parts.append(hc["text"].strip())
            # Remove all cells for this row, replace with one spanning cell
            z["cells"] = [c for c in z["cells"] if c.get("row_index") != hri]
            if all_wb:
                # Spanning-cell bbox = union of all heading word boxes.
                x_min = min(wb["left"] for wb in all_wb)
                y_min = min(wb["top"] for wb in all_wb)
                x_max = max(wb["left"] + wb["width"] for wb in all_wb)
                y_max = max(wb["top"] + wb["height"] for wb in all_wb)
                # Use the actual starting col_index from the first cell
                first_col = min(hc["col_index"] for hc in header_cells)
                zone_idx = z.get("zone_index", 0)
                z["cells"].append({
                    "cell_id": f"Z{zone_idx}_R{hri:02d}_C{first_col}",
                    "zone_index": zone_idx,
                    "row_index": hri,
                    "col_index": first_col,
                    "col_type": "heading",
                    "text": " ".join(all_text_parts),
                    "confidence": 0.0,
                    "bbox_px": {"x": x_min, "y": y_min,
                                "w": x_max - x_min, "h": y_max - y_min},
                    "bbox_pct": {
                        "x": round(x_min / img_w * 100, 2) if img_w else 0,
                        "y": round(y_min / img_h * 100, 2) if img_h else 0,
                        "w": round((x_max - x_min) / img_w * 100, 2) if img_w else 0,
                        "h": round((y_max - y_min) / img_h * 100, 2) if img_h else 0,
                    },
                    "word_boxes": all_wb,
                    "ocr_engine": "words_first",
                    "is_bold": True,
                })
            # Mark row as header
            for row in rows:
                if row["index"] == hri:
                    row["is_header"] = True
            heading_count += 1
    return heading_count
def _detect_heading_rows_by_single_cell(
    zones_data: List[Dict], img_w: int, img_h: int,
) -> int:
    """Detect heading rows that have only a single content cell.

    Black headings like "Theme" have normal color and height, so they are
    missed by ``_detect_heading_rows_by_color``. The distinguishing signal
    is that they occupy only one column while normal vocabulary rows fill
    at least 2-3 columns.

    A row qualifies as a heading if:
    1. It is not already marked as a header/heading.
    2. It has exactly ONE cell whose col_type starts with ``column_``
       (excluding column_1 / page_ref which only carries page numbers).
    3. That single cell is NOT in the last column (continuation/example
       lines like "2. Veraenderung, Wechsel" often sit alone in column_4).
    4. The text does not start with ``[`` (IPA continuation).
    5. The zone has >=3 columns and >=5 rows (avoids false positives in
       tiny zones).
    6. The majority of rows in the zone have >=2 content cells (ensures
       we are in a multi-column vocab layout).

    Args:
        zones_data: Zone dicts with "cells", "rows", "columns".  Mutated in
            place: all cells of a detected heading row are replaced by one
            spanning "heading" cell and the row gets ``is_header=True``.
        img_w: Page width in px (for bbox_pct computation; 0 yields 0%).
        img_h: Page height in px (for bbox_pct computation; 0 yields 0%).

    Returns:
        Number of heading rows detected and merged across all zones.
    """
    heading_count = 0
    for z in zones_data:
        cells = z.get("cells", [])
        rows = z.get("rows", [])
        columns = z.get("columns", [])
        # Guard 5: heuristic only makes sense for real multi-column tables.
        if len(columns) < 3 or len(rows) < 5:
            continue
        # Determine the last col_index (example/sentence column)
        col_indices = sorted(set(c.get("col_index", 0) for c in cells))
        if not col_indices:
            continue
        last_col = col_indices[-1]
        # Count content cells per row (column_* but not column_1/page_ref).
        # Exception: column_1 cells that contain a dictionary article word
        # (die/der/das etc.) ARE content -- they appear in dictionary layouts
        # where the leftmost column holds grammatical articles.
        _ARTICLE_WORDS = {
            "die", "der", "das", "dem", "den", "des", "ein", "eine",
            "the", "a", "an",
        }
        row_content_counts: Dict[int, int] = {}
        for cell in cells:
            ct = cell.get("col_type", "")
            if not ct.startswith("column_"):
                continue
            if ct == "column_1":
                ctext = (cell.get("text") or "").strip().lower()
                if ctext not in _ARTICLE_WORDS:
                    continue
            ri = cell.get("row_index", -1)
            row_content_counts[ri] = row_content_counts.get(ri, 0) + 1
        # Guard 6: majority of rows must have >=2 content cells
        multi_col_rows = sum(1 for cnt in row_content_counts.values() if cnt >= 2)
        if multi_col_rows < len(rows) * 0.4:
            continue
        # Exclude first and last non-header rows -- these are typically
        # page numbers or footer text, not headings.
        non_header_rows = [r for r in rows if not r.get("is_header")]
        if len(non_header_rows) < 3:
            continue
        first_ri = non_header_rows[0]["index"]
        last_ri = non_header_rows[-1]["index"]
        heading_row_indices = []
        for row in rows:
            if row.get("is_header"):
                continue
            ri = row["index"]
            if ri == first_ri or ri == last_ri:
                continue
            row_cells = [c for c in cells if c.get("row_index") == ri]
            # Guard 2: exactly one content cell in this row.
            content_cells = [
                c for c in row_cells
                if c.get("col_type", "").startswith("column_")
                and (c.get("col_type") != "column_1"
                     or (c.get("text") or "").strip().lower() in _ARTICLE_WORDS)
            ]
            if len(content_cells) != 1:
                continue
            cell = content_cells[0]
            # Guard 3: not in the last column (continuation/example lines)
            if cell.get("col_index") == last_col:
                continue
            text = (cell.get("text") or "").strip()
            # Guard 4: empty text or IPA continuation "[..." is never a heading.
            if not text or text.startswith("["):
                continue
            # Continuation lines start with "(" -- e.g. "(usw.)", "(TV-Serie)"
            if text.startswith("("):
                continue
            # Single cell NOT in the first content column is likely a
            # continuation/overflow line, not a heading. Real headings
            # ("Theme 1", "Unit 3: ...") appear in the first or second
            # content column.
            first_content_col = col_indices[0] if col_indices else 0
            if cell.get("col_index", 0) > first_content_col + 1:
                continue
            # Skip garbled IPA without brackets (e.g. "ska:f -- ska:vz")
            # but NOT text with real IPA symbols (e.g. "Theme [\u03b8\u02c8i\u02d0m]")
            _REAL_IPA_CHARS = set("\u02c8\u02cc\u0259\u026a\u025b\u0252\u028a\u028c\u00e6\u0251\u0254\u0283\u0292\u03b8\u00f0\u014b")
            if _text_has_garbled_ipa(text) and not any(c in _REAL_IPA_CHARS for c in text):
                continue
            # Guard: dictionary section headings are short (1-4 alpha chars
            # like "A", "Ab", "Zi", "Sch"). Longer text that starts
            # lowercase is a regular vocabulary word (e.g. "zentral") that
            # happens to appear alone in its row.
            alpha_only = re.sub(r'[^a-zA-Z\u00e4\u00f6\u00fc\u00c4\u00d6\u00dc\u00df\u1e9e]', '', text)
            if len(alpha_only) > 4 and text[0].islower():
                continue
            heading_row_indices.append(ri)
        # Guard: if >25% of eligible rows would become headings, the
        # heuristic is misfiring (e.g. sparse single-column layout where
        # most rows naturally have only 1 content cell).
        eligible_rows = len(non_header_rows) - 2  # minus first/last excluded
        if eligible_rows > 0 and len(heading_row_indices) > eligible_rows * 0.25:
            logger.debug(
                "Skipping single-cell heading detection for zone %s: "
                "%d/%d rows would be headings (>25%%)",
                z.get("zone_index"), len(heading_row_indices), eligible_rows,
            )
            continue
        # Merge each detected heading row into one spanning "heading" cell.
        for hri in heading_row_indices:
            header_cells = [c for c in cells if c.get("row_index") == hri]
            if not header_cells:
                continue
            # Collect all word_boxes and text, left-to-right by column
            all_wb = []
            all_text_parts = []
            for hc in sorted(header_cells, key=lambda c: c["col_index"]):
                all_wb.extend(hc.get("word_boxes", []))
                if hc.get("text", "").strip():
                    all_text_parts.append(hc["text"].strip())
            first_col_idx = min(hc["col_index"] for hc in header_cells)
            # Remove old cells for this row, add spanning heading cell
            z["cells"] = [c for c in z["cells"] if c.get("row_index") != hri]
            if all_wb:
                x_min = min(wb["left"] for wb in all_wb)
                y_min = min(wb["top"] for wb in all_wb)
                x_max = max(wb["left"] + wb["width"] for wb in all_wb)
                y_max = max(wb["top"] + wb["height"] for wb in all_wb)
            else:
                # Fallback to first cell bbox
                bp = header_cells[0].get("bbox_px", {})
                x_min = bp.get("x", 0)
                y_min = bp.get("y", 0)
                x_max = x_min + bp.get("w", 0)
                y_max = y_min + bp.get("h", 0)
            zone_idx = z.get("zone_index", 0)
            z["cells"].append({
                "cell_id": f"Z{zone_idx}_R{hri:02d}_C{first_col_idx}",
                "zone_index": zone_idx,
                "row_index": hri,
                "col_index": first_col_idx,
                "col_type": "heading",
                "text": " ".join(all_text_parts),
                "confidence": 0.0,
                "bbox_px": {"x": x_min, "y": y_min,
                            "w": x_max - x_min, "h": y_max - y_min},
                "bbox_pct": {
                    "x": round(x_min / img_w * 100, 2) if img_w else 0,
                    "y": round(y_min / img_h * 100, 2) if img_h else 0,
                    "w": round((x_max - x_min) / img_w * 100, 2) if img_w else 0,
                    "h": round((y_max - y_min) / img_h * 100, 2) if img_h else 0,
                },
                "word_boxes": all_wb,
                "ocr_engine": "words_first",
                # Single-cell headings are plain black text, not bold.
                "is_bold": False,
            })
            for row in rows:
                if row["index"] == hri:
                    row["is_header"] = True
            heading_count += 1
    return heading_count
def _detect_header_rows(
rows: List[Dict],
zone_words: List[Dict],
zone_y: int,
columns: Optional[List[Dict]] = None,
skip_first_row_header: bool = False,
) -> List[int]:
"""Detect header rows: first-row heuristic + spanning header detection.
A "spanning header" is a row whose words stretch across multiple column
boundaries (e.g. "Unit4: Bonnie Scotland" centred across 4 columns).
"""
if len(rows) < 2:
return []
headers = []
if not skip_first_row_header:
first_row = rows[0]
second_row = rows[1]
# Gap between first and second row > 0.5x average row height
avg_h = sum(r["y_max"] - r["y_min"] for r in rows) / len(rows)
gap = second_row["y_min"] - first_row["y_max"]
if gap > avg_h * 0.5:
headers.append(0)
# Also check if first row words are taller than average (bold/header text)
all_heights = [w["height"] for w in zone_words]
median_h = sorted(all_heights)[len(all_heights) // 2] if all_heights else 20
first_row_words = [
w for w in zone_words
if first_row["y_min"] <= w["top"] + w["height"] / 2 <= first_row["y_max"]
]
if first_row_words:
first_h = max(w["height"] for w in first_row_words)
if first_h > median_h * 1.3:
if 0 not in headers:
headers.append(0)
# Note: Spanning-header detection (rows spanning all columns) has been
# disabled because it produces too many false positives on vocabulary
# worksheets where IPA transcriptions or short entries naturally span
# multiple columns with few words. The first-row heuristic above is
# sufficient for detecting real headers.
return headers
def _detect_colspan_cells(
    zone_words: List[Dict],
    columns: List[Dict],
    rows: List[Dict],
    cells: List[Dict],
    img_w: int,
    img_h: int,
) -> List[Dict]:
    """Detect and merge cells that span multiple columns (colspan).

    A word-block (PaddleOCR phrase) that extends significantly past a column
    boundary into the next column indicates a merged cell. This replaces
    the incorrectly split cells with a single cell spanning multiple columns.
    Works for both full-page scans and box zones.

    Args:
        zone_words: ORIGINAL (pre-split) word boxes -- column splitting
            would have destroyed the span information this detector needs.
        columns: Column descriptors with "index", "x_min", "x_max".
        rows: Row descriptors used to assign word boxes to rows.
        cells: Grid cells built from the split words.
        img_w: Page width in px (for bbox_pct computation; 0 yields 0%).
        img_h: Page height in px (for bbox_pct computation; 0 yields 0%).

    Returns:
        A new cell list where every run of cells covered by a spanning
        word-block is replaced by one ``spanning_header`` cell carrying a
        ``colspan`` count; the input list is not mutated.
    """
    if len(columns) < 2 or not zone_words or not rows:
        return cells
    from cv_words_first import _assign_word_to_row
    # NOTE: the previous version pre-computed midpoints between adjacent
    # columns here, but never used them -- _cols_covered works directly on
    # the column extents, so that dead computation has been removed.
    def _cols_covered(w_left: float, w_right: float) -> List[int]:
        """Return list of column indices that a word-block covers."""
        covered = []
        for col in columns:
            col_mid = (col["x_min"] + col["x_max"]) / 2
            # Word covers a column if it extends past the column's midpoint
            if w_left < col_mid < w_right:
                covered.append(col["index"])
            # Also include column if word starts within it
            elif col["x_min"] <= w_left < col["x_max"]:
                covered.append(col["index"])
        return sorted(set(covered))
    # Group original word-blocks by row
    row_word_blocks: Dict[int, List[Dict]] = {}
    for w in zone_words:
        ri = _assign_word_to_row(w, rows)
        row_word_blocks.setdefault(ri, []).append(w)
    # For each row, check if any word-block spans multiple columns
    rows_to_merge: Dict[int, List[Dict]] = {}  # row_index -> list of spanning word-blocks
    for ri, wblocks in row_word_blocks.items():
        spanning = []
        for w in wblocks:
            w_left = w["left"]
            w_right = w_left + w["width"]
            covered = _cols_covered(w_left, w_right)
            if len(covered) >= 2:
                spanning.append({"word": w, "cols": covered})
        if spanning:
            rows_to_merge[ri] = spanning
    if not rows_to_merge:
        return cells
    # Merge cells for spanning rows
    new_cells = []
    for cell in cells:
        ri = cell.get("row_index", -1)
        if ri not in rows_to_merge:
            new_cells.append(cell)
            continue
        # Check if this cell's column is part of a spanning block
        ci = cell.get("col_index", -1)
        is_part_of_span = False
        for span in rows_to_merge[ri]:
            if ci in span["cols"]:
                is_part_of_span = True
                # Only emit the merged cell for the FIRST column in the span
                if ci == span["cols"][0]:
                    # Use the ORIGINAL word-block text (not the split cell texts
                    # which may have broken words like "euros a" + "nd cents")
                    orig_word = span["word"]
                    merged_text = orig_word.get("text", "").strip()
                    all_wb = [orig_word]
                    # Compute merged bbox (all_wb always has one element here)
                    if all_wb:
                        x_min = min(wb["left"] for wb in all_wb)
                        y_min = min(wb["top"] for wb in all_wb)
                        x_max = max(wb["left"] + wb["width"] for wb in all_wb)
                        y_max = max(wb["top"] + wb["height"] for wb in all_wb)
                    else:
                        x_min = y_min = x_max = y_max = 0
                    new_cells.append({
                        "cell_id": cell["cell_id"],
                        "row_index": ri,
                        "col_index": span["cols"][0],
                        "col_type": "spanning_header",
                        "colspan": len(span["cols"]),
                        "text": merged_text,
                        "confidence": cell.get("confidence", 0),
                        "bbox_px": {"x": x_min, "y": y_min,
                                    "w": x_max - x_min, "h": y_max - y_min},
                        "bbox_pct": {
                            "x": round(x_min / img_w * 100, 2) if img_w else 0,
                            "y": round(y_min / img_h * 100, 2) if img_h else 0,
                            "w": round((x_max - x_min) / img_w * 100, 2) if img_w else 0,
                            "h": round((y_max - y_min) / img_h * 100, 2) if img_h else 0,
                        },
                        "word_boxes": all_wb,
                        "ocr_engine": cell.get("ocr_engine", ""),
                        "is_bold": cell.get("is_bold", False),
                    })
                    logger.info(
                        "colspan detected: row %d, cols %s -> merged %d cells (%r)",
                        ri, span["cols"], len(span["cols"]), merged_text[:50],
                    )
                break
        if not is_part_of_span:
            new_cells.append(cell)
    return new_cells

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,389 @@
"""
Grid Editor — vertical divider detection, zone splitting/merging, zone grid building.
Split from grid_editor_helpers.py for maintainability.
All functions are pure computation — no HTTP, DB, or session side effects.
Lizenz: Apache 2.0 (kommerziell nutzbar)
DATENSCHUTZ: Alle Verarbeitung erfolgt lokal.
"""
import logging
import re
from typing import Any, Dict, List, Optional
from cv_vocab_types import PageZone
from cv_words_first import _cluster_rows, _build_cells
from grid_editor_columns import (
_cluster_columns_by_alignment,
_merge_inline_marker_columns,
_split_cross_column_words,
)
from grid_editor_headers import (
_detect_header_rows,
_detect_colspan_cells,
)
logger = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
# Vertical divider detection and zone splitting
# ---------------------------------------------------------------------------
_PIPE_RE_VSPLIT = re.compile(r"^\|+$")
def _detect_vertical_dividers(
words: List[Dict],
zone_x: int,
zone_w: int,
zone_y: int,
zone_h: int,
) -> List[float]:
"""Detect vertical divider lines from pipe word_boxes at consistent x.
Returns list of divider x-positions (empty if no dividers found).
"""
if not words or zone_w <= 0 or zone_h <= 0:
return []
# Collect pipe word_boxes
pipes = [
w for w in words
if _PIPE_RE_VSPLIT.match((w.get("text") or "").strip())
]
if len(pipes) < 5:
return []
# Cluster pipe x-centers by proximity
tolerance = max(15, int(zone_w * 0.02))
pipe_xs = sorted(w["left"] + w["width"] / 2 for w in pipes)
clusters: List[List[float]] = [[pipe_xs[0]]]
for x in pipe_xs[1:]:
if x - clusters[-1][-1] <= tolerance:
clusters[-1].append(x)
else:
clusters.append([x])
dividers: List[float] = []
for cluster in clusters:
if len(cluster) < 5:
continue
mean_x = sum(cluster) / len(cluster)
# Must be between 15% and 85% of zone width
rel_pos = (mean_x - zone_x) / zone_w
if rel_pos < 0.15 or rel_pos > 0.85:
continue
# Check vertical coverage: pipes must span >= 50% of zone height
cluster_pipes = [
w for w in pipes
if abs(w["left"] + w["width"] / 2 - mean_x) <= tolerance
]
ys = [w["top"] for w in cluster_pipes] + [w["top"] + w["height"] for w in cluster_pipes]
y_span = max(ys) - min(ys) if ys else 0
if y_span < zone_h * 0.5:
continue
dividers.append(mean_x)
return sorted(dividers)
def _split_zone_at_vertical_dividers(
    zone: "PageZone",
    divider_xs: List[float],
    vsplit_group_id: int,
) -> List["PageZone"]:
    """Split a PageZone at vertical divider positions into sub-zones.

    Each sub-zone keeps the parent's vertical extent, type, box and image
    overlays; it receives a layout hint (left/middle/right of the split)
    plus the shared vsplit group id.  Indices are re-assigned by the caller.
    """
    edges = [zone.x] + divider_xs + [zone.x + zone.width]
    last_slot = len(edges) - 2
    parts: List["PageZone"] = []
    for i in range(len(edges) - 1):
        if i == 0:
            hint = "left_of_vsplit"
        elif i == last_slot:
            hint = "right_of_vsplit"
        else:
            hint = "middle_of_vsplit"
        left_px = int(edges[i])
        right_px = int(edges[i + 1])
        parts.append(PageZone(
            index=0,  # re-indexed later
            zone_type=zone.zone_type,
            y=zone.y,
            height=zone.height,
            x=left_px,
            width=right_px - left_px,
            box=zone.box,
            image_overlays=zone.image_overlays,
            layout_hint=hint,
            vsplit_group=vsplit_group_id,
        ))
    return parts
def _merge_content_zones_across_boxes(
    zones: List,
    content_x: int,
    content_w: int,
) -> List:
    """Merge content zones separated by box zones into single zones.

    Box zones become image_overlays on the merged content zone.
    Pattern: [content, box*, content] -> [merged_content with overlay]
    Box zones NOT between two content zones stay as standalone zones.

    Args:
        zones: Ordered list of PageZone objects (top-to-bottom).
        content_x: Left edge in px assigned to every merged zone.
        content_w: Width in px assigned to every merged zone.

    Returns:
        New zone list with merges applied; ``index`` re-assigned 0..n-1.
    """
    if len(zones) < 3:
        return zones
    # Group consecutive runs of [content, box+, content]
    result: List = []
    i = 0
    while i < len(zones):
        z = zones[i]
        if z.zone_type != "content":
            # Non-content zone not sandwiched by content: keep standalone.
            result.append(z)
            i += 1
            continue
        # Start of a potential merge group: content zone
        group_contents = [z]
        group_boxes = []
        j = i + 1
        # Absorb [box, content] pairs -- only absorb a box if it's
        # confirmed to be followed by another content zone.
        while j < len(zones):
            if (zones[j].zone_type == "box"
                and j + 1 < len(zones)
                and zones[j + 1].zone_type == "content"):
                group_boxes.append(zones[j])
                group_contents.append(zones[j + 1])
                j += 2
            else:
                break
        if len(group_contents) >= 2 and group_boxes:
            # Merge: create one large content zone spanning all
            y_min = min(c.y for c in group_contents)
            y_max = max(c.y + c.height for c in group_contents)
            overlays = []
            for bz in group_boxes:
                # Each absorbed box becomes an overlay rectangle; its
                # detected border box (if any) is carried along.
                overlay = {
                    "y": bz.y,
                    "height": bz.height,
                    "x": bz.x,
                    "width": bz.width,
                }
                if bz.box:
                    overlay["box"] = {
                        "x": bz.box.x,
                        "y": bz.box.y,
                        "width": bz.box.width,
                        "height": bz.box.height,
                        "confidence": bz.box.confidence,
                        "border_thickness": bz.box.border_thickness,
                    }
                overlays.append(overlay)
            merged = PageZone(
                index=0,  # re-indexed below
                zone_type="content",
                y=y_min,
                height=y_max - y_min,
                x=content_x,
                width=content_w,
                image_overlays=overlays,
            )
            result.append(merged)
            # Resume scanning AFTER the last absorbed content zone.
            i = j
        else:
            # No merge possible -- emit just the content zone
            result.append(z)
            i += 1
    # Re-index zones
    for idx, z in enumerate(result):
        z.index = idx
    logger.info(
        "zone-merge: %d zones -> %d zones after merging across boxes",
        len(zones), len(result),
    )
    return result
def _build_zone_grid(
    zone_words: List[Dict],
    zone_x: int,
    zone_y: int,
    zone_w: int,
    zone_h: int,
    zone_index: int,
    img_w: int,
    img_h: int,
    global_columns: Optional[List[Dict]] = None,
    skip_first_row_header: bool = False,
) -> Dict[str, Any]:
    """Build columns, rows, cells for a single zone from its words.

    Args:
        zone_words: Word boxes belonging to this zone.
        zone_x: Zone left edge in px (not referenced in the body; kept for
            interface symmetry with the other zone helpers).
        zone_y: Zone top edge in px (forwarded to header detection).
        zone_w: Zone width in px (used for column clustering).
        zone_h: Zone height in px (not referenced in the body; see zone_x).
        zone_index: Index used to prefix cell ids ("Z<idx>_...").
        img_w: Page width in px for percent conversion (0 yields 0%).
        img_h: Page height in px for percent conversion (0 yields 0%).
        global_columns: If provided, use these pre-computed column boundaries
            instead of detecting columns per zone. Used for content zones so
            that all content zones (above/between/below boxes) share the same
            column structure. Box zones always detect columns independently.
        skip_first_row_header: Suppress the first-row header heuristic.

    Returns:
        Dict with "columns", "rows", "cells", "header_rows", plus the
        internal "_raw_columns" key (raw column dicts for propagation to
        other zones).
        NOTE(review): the two empty-input early returns omit "_raw_columns"
        -- callers should access it with .get().
    """
    if not zone_words:
        return {
            "columns": [],
            "rows": [],
            "cells": [],
            "header_rows": [],
        }
    # Cluster rows first (needed for column alignment analysis)
    rows = _cluster_rows(zone_words)
    # Diagnostic logging for small/medium zones (box zones typically have 40-60 words)
    if len(zone_words) <= 60:
        import statistics as _st
        _heights = [w['height'] for w in zone_words if w.get('height', 0) > 0]
        _med_h = _st.median(_heights) if _heights else 20
        # Mirrors the row-clustering tolerance for log readability only.
        _y_tol = max(_med_h * 0.5, 5)
        logger.info(
            "zone %d row-clustering: %d words, median_h=%.0f, y_tol=%.1f -> %d rows",
            zone_index, len(zone_words), _med_h, _y_tol, len(rows),
        )
        for w in sorted(zone_words, key=lambda ww: (ww['top'], ww['left'])):
            logger.info(
                "  zone %d word: y=%d x=%d h=%d w=%d '%s'",
                zone_index, w['top'], w['left'], w['height'], w['width'],
                w.get('text', '')[:40],
            )
        for r in rows:
            logger.info(
                "  zone %d row %d: y_min=%d y_max=%d y_center=%.0f",
                zone_index, r['index'], r['y_min'], r['y_max'], r['y_center'],
            )
    # Use global columns if provided, otherwise detect per zone
    columns = global_columns if global_columns else _cluster_columns_by_alignment(zone_words, zone_w, rows)
    # Merge inline marker columns (bullets, numbering) into adjacent text
    if not global_columns:
        columns = _merge_inline_marker_columns(columns, zone_words)
    if not columns or not rows:
        return {
            "columns": [],
            "rows": [],
            "cells": [],
            "header_rows": [],
        }
    # Split word boxes that straddle column boundaries (e.g. "sichzie"
    # spanning Col 1 + Col 2). Must happen after column detection and
    # before cell assignment.
    # Keep original words for colspan detection (split destroys span info).
    original_zone_words = zone_words
    if len(columns) >= 2:
        zone_words = _split_cross_column_words(zone_words, columns)
    # Build cells
    cells = _build_cells(zone_words, columns, rows, img_w, img_h)
    # --- Detect colspan (merged cells spanning multiple columns) ---
    # Uses the ORIGINAL (pre-split) words to detect word-blocks that span
    # multiple columns. _split_cross_column_words would have destroyed
    # this information by cutting words at column boundaries.
    if len(columns) >= 2:
        cells = _detect_colspan_cells(original_zone_words, columns, rows, cells, img_w, img_h)
    # Prefix cell IDs with zone index
    for cell in cells:
        cell["cell_id"] = f"Z{zone_index}_{cell['cell_id']}"
        cell["zone_index"] = zone_index
    # Detect header rows (pass columns for spanning header detection)
    header_rows = _detect_header_rows(rows, zone_words, zone_y, columns,
                                      skip_first_row_header=skip_first_row_header)
    # Merge cells in spanning header rows into a single col-0 cell
    if header_rows and len(columns) >= 2:
        for hri in header_rows:
            header_cells = [c for c in cells if c["row_index"] == hri]
            if len(header_cells) <= 1:
                continue
            # Collect all word_boxes and text from all columns
            all_wb = []
            all_text_parts = []
            for hc in sorted(header_cells, key=lambda c: c["col_index"]):
                all_wb.extend(hc.get("word_boxes", []))
                if hc.get("text", "").strip():
                    all_text_parts.append(hc["text"].strip())
            # Remove all header cells, replace with one spanning cell
            cells = [c for c in cells if c["row_index"] != hri]
            if all_wb:
                x_min = min(wb["left"] for wb in all_wb)
                y_min = min(wb["top"] for wb in all_wb)
                x_max = max(wb["left"] + wb["width"] for wb in all_wb)
                y_max = max(wb["top"] + wb["height"] for wb in all_wb)
                cells.append({
                    "cell_id": f"R{hri:02d}_C0",
                    "row_index": hri,
                    "col_index": 0,
                    "col_type": "spanning_header",
                    "text": " ".join(all_text_parts),
                    "confidence": 0.0,
                    "bbox_px": {"x": x_min, "y": y_min,
                                "w": x_max - x_min, "h": y_max - y_min},
                    "bbox_pct": {
                        "x": round(x_min / img_w * 100, 2) if img_w else 0,
                        "y": round(y_min / img_h * 100, 2) if img_h else 0,
                        "w": round((x_max - x_min) / img_w * 100, 2) if img_w else 0,
                        "h": round((y_max - y_min) / img_h * 100, 2) if img_h else 0,
                    },
                    "word_boxes": all_wb,
                    "ocr_engine": "words_first",
                    "is_bold": True,
                })
    # Convert columns to output format with percentages
    out_columns = []
    for col in columns:
        x_min = col["x_min"]
        x_max = col["x_max"]
        out_columns.append({
            "index": col["index"],
            "label": col["type"],
            "x_min_px": round(x_min),
            "x_max_px": round(x_max),
            "x_min_pct": round(x_min / img_w * 100, 2) if img_w else 0,
            "x_max_pct": round(x_max / img_w * 100, 2) if img_w else 0,
            "bold": False,
        })
    # Convert rows to output format with percentages
    out_rows = []
    for row in rows:
        out_rows.append({
            "index": row["index"],
            "y_min_px": round(row["y_min"]),
            "y_max_px": round(row["y_max"]),
            "y_min_pct": round(row["y_min"] / img_h * 100, 2) if img_h else 0,
            "y_max_pct": round(row["y_max"] / img_h * 100, 2) if img_h else 0,
            "is_header": row["index"] in header_rows,
        })
    return {
        "columns": out_columns,
        "rows": out_rows,
        "cells": cells,
        "header_rows": header_rows,
        "_raw_columns": columns,  # internal: for propagation to other zones
    }

View File

@@ -0,0 +1,197 @@
"""
Legal Corpus Chunking — Text splitting, semantic chunking, and HTML-to-text conversion.
Provides German-aware sentence splitting, paragraph splitting, semantic chunking
with overlap, and HTML-to-text conversion for legal document ingestion.
"""
import re
from typing import Dict, List, Optional, Tuple
# German abbreviations that don't end sentences
GERMAN_ABBREVIATIONS = {
    'bzw', 'ca', 'chr', 'd.h', 'dr', 'etc', 'evtl', 'ggf', 'inkl', 'max',
    'min', 'mio', 'mrd', 'nr', 'prof', 's', 'sog', 'u.a', 'u.ä', 'usw',
    'v.a', 'vgl', 'vs', 'z.b', 'z.t', 'zzgl', 'abs', 'art', 'aufl',
    'bd', 'betr', 'bzgl', 'dgl', 'ebd', 'hrsg', 'jg', 'kap', 'lt',
    'rdnr', 'rn', 'std', 'str', 'tel', 'ua', 'uvm', 'va', 'zb',
    'bsi', 'tr', 'owasp', 'iso', 'iec', 'din', 'en'
}


def split_into_sentences(text: str) -> List[str]:
    """Split text into sentences with German language support.

    Dots belonging to abbreviations, decimal numbers, ordinals and
    requirement IDs (e.g. "O.Data_1") are protected with placeholder
    tokens before splitting and restored afterwards.

    Args:
        text: Raw input text; internal whitespace is normalised first.

    Returns:
        List of stripped sentences; empty list for empty input.
    """
    if not text:
        return []
    text = re.sub(r'\s+', ' ', text).strip()
    # Protect abbreviations.  Substitute based on the MATCHED text
    # (m.group(0)) so the original capitalisation survives -- the previous
    # version substituted the lowercase abbreviation itself, which turned
    # "Dr. Müller" into "dr. Müller" in the output (IGNORECASE match).
    protected_text = text
    for abbrev in GERMAN_ABBREVIATIONS:
        pattern = re.compile(r'\b' + re.escape(abbrev) + r'\.', re.IGNORECASE)
        protected_text = pattern.sub(
            lambda m: m.group(0)[:-1].replace('.', '<DOT>') + '<ABBR>',
            protected_text,
        )
    # Protect decimal/ordinal numbers and requirement IDs (e.g., "O.Data_1")
    protected_text = re.sub(r'(\d)\.(\d)', r'\1<DECIMAL>\2', protected_text)
    protected_text = re.sub(r'(\d+)\.(\s)', r'\1<ORD>\2', protected_text)
    protected_text = re.sub(r'([A-Z])\.([A-Z])', r'\1<REQ>\2', protected_text)  # O.Data_1
    # Split on sentence endings (terminator stays with its sentence)
    sentence_pattern = r'(?<=[.!?])\s+(?=[A-ZÄÖÜ0-9])|(?<=[.!?])$'
    raw_sentences = re.split(sentence_pattern, protected_text)
    # Restore protected characters
    sentences = []
    for s in raw_sentences:
        s = s.replace('<DOT>', '.').replace('<ABBR>', '.').replace('<DECIMAL>', '.').replace('<ORD>', '.').replace('<REQ>', '.')
        s = s.strip()
        if s:
            sentences.append(s)
    return sentences
def split_into_paragraphs(text: str) -> List[str]:
    """Split text into paragraphs (separated by blank lines)."""
    if not text:
        return []
    return [part.strip() for part in re.split(r'\n\s*\n', text) if part.strip()]
def chunk_text_semantic(
    text: str,
    chunk_size: int = 1000,
    overlap: int = 200,
) -> List[Tuple[str, int]]:
    """
    Semantic chunking that respects paragraph and sentence boundaries.
    Matches NIBIS chunking strategy for consistency.

    Args:
        text: Input plain text.
        chunk_size: Target maximum chunk length in characters.
        overlap: Desired overlap in characters, converted to a sentence
            count at roughly 1 sentence per 100 characters (minimum 1).

    Returns:
        List of (chunk_text, start_position) tuples.

    NOTE(review): start positions are approximate -- ``chunk_start`` is not
    re-derived on every path (e.g. after an oversized sentence is emitted
    as its own chunk), so treat the offsets as hints, not exact indices.
    """
    if not text:
        return []
    # Short input: single chunk, no splitting needed.
    if len(text) <= chunk_size:
        return [(text.strip(), 0)]
    paragraphs = split_into_paragraphs(text)
    overlap_sentences = max(1, overlap // 100)  # Convert char overlap to sentence overlap
    chunks = []
    current_chunk_parts: List[str] = []
    current_chunk_length = 0
    chunk_start = 0
    position = 0
    for para in paragraphs:
        if len(para) > chunk_size:
            # Large paragraph: split into sentences
            sentences = split_into_sentences(para)
            for sentence in sentences:
                sentence_len = len(sentence)
                if sentence_len > chunk_size:
                    # Very long sentence: save current chunk first
                    if current_chunk_parts:
                        chunk_text = ' '.join(current_chunk_parts)
                        chunks.append((chunk_text, chunk_start))
                        # Carry the last N sentences into the next chunk.
                        overlap_buffer = current_chunk_parts[-overlap_sentences:] if overlap_sentences > 0 else []
                        current_chunk_parts = list(overlap_buffer)
                        current_chunk_length = sum(len(s) + 1 for s in current_chunk_parts)
                    # Add long sentence as its own chunk
                    chunks.append((sentence, position))
                    current_chunk_parts = [sentence]
                    current_chunk_length = len(sentence) + 1
                    position += sentence_len + 1
                    continue
                if current_chunk_length + sentence_len + 1 > chunk_size and current_chunk_parts:
                    # Current chunk is full, save it
                    chunk_text = ' '.join(current_chunk_parts)
                    chunks.append((chunk_text, chunk_start))
                    overlap_buffer = current_chunk_parts[-overlap_sentences:] if overlap_sentences > 0 else []
                    current_chunk_parts = list(overlap_buffer)
                    current_chunk_length = sum(len(s) + 1 for s in current_chunk_parts)
                    chunk_start = position - current_chunk_length
                current_chunk_parts.append(sentence)
                current_chunk_length += sentence_len + 1
                position += sentence_len + 1
        else:
            # Small paragraph: try to keep together
            para_len = len(para)
            if current_chunk_length + para_len + 2 > chunk_size and current_chunk_parts:
                chunk_text = ' '.join(current_chunk_parts)
                chunks.append((chunk_text, chunk_start))
                # Overlap from the last sentences of the previous part.
                last_para_sentences = split_into_sentences(current_chunk_parts[-1] if current_chunk_parts else "")
                overlap_buffer = last_para_sentences[-overlap_sentences:] if overlap_sentences > 0 and last_para_sentences else []
                current_chunk_parts = list(overlap_buffer)
                current_chunk_length = sum(len(s) + 1 for s in current_chunk_parts)
                chunk_start = position - current_chunk_length
            if current_chunk_parts:
                current_chunk_parts.append(para)
                current_chunk_length += para_len + 2
            else:
                current_chunk_parts = [para]
                current_chunk_length = para_len
                chunk_start = position
            position += para_len + 2
    # Don't forget the last chunk
    if current_chunk_parts:
        chunk_text = ' '.join(current_chunk_parts)
        chunks.append((chunk_text, chunk_start))
    # Clean up whitespace
    return [(re.sub(r'\s+', ' ', c).strip(), pos) for c, pos in chunks if c.strip()]
def extract_article_info(text: str) -> Optional[Dict]:
    """Extract article number and paragraph from text.

    Recognises "Artikel X" / "Art. X" and optionally "Absatz Y" / "Abs. Y".
    Returns None when no article reference is present.
    """
    article = re.search(r'(?:Artikel|Art\.?)\s+(\d+)', text)
    if not article:
        return None
    paragraph = re.search(r'(?:Absatz|Abs\.?)\s+(\d+)', text)
    return {
        "article": article.group(1),
        "paragraph": paragraph.group(1) if paragraph else None,
    }
def html_to_text(html_content: str) -> str:
    """Convert HTML to clean text.

    Pass order matters and was fixed here: tags are stripped BEFORE
    entities are resolved, and "&amp;" is resolved LAST.  The previous
    order unescaped entities first, which (a) deleted escaped markup such
    as "&lt;b&gt;" as if it were a real tag, and (b) double-unescaped
    "&amp;lt;" into "<" instead of "&lt;".

    Args:
        html_content: Raw HTML string.

    Returns:
        Plain text with paragraph breaks preserved as blank lines.
    """
    # Remove script and style elements including their content
    html_content = re.sub(r'<script[^>]*>.*?</script>', '', html_content, flags=re.DOTALL)
    html_content = re.sub(r'<style[^>]*>.*?</style>', '', html_content, flags=re.DOTALL)
    # Remove comments
    html_content = re.sub(r'<!--.*?-->', '', html_content, flags=re.DOTALL)
    # Convert breaks and paragraphs to newlines for better chunking
    html_content = re.sub(r'<br\s*/?>', '\n', html_content, flags=re.IGNORECASE)
    html_content = re.sub(r'</p>', '\n\n', html_content, flags=re.IGNORECASE)
    html_content = re.sub(r'</div>', '\n', html_content, flags=re.IGNORECASE)
    html_content = re.sub(r'</h[1-6]>', '\n\n', html_content, flags=re.IGNORECASE)
    # Remove remaining HTML tags
    text = re.sub(r'<[^>]+>', ' ', html_content)
    # Replace common HTML entities -- after tag removal, '&amp;' last so a
    # literal "&amp;lt;" becomes "&lt;" and is not unescaped twice.
    text = text.replace('&nbsp;', ' ')
    text = text.replace('&lt;', '<')
    text = text.replace('&gt;', '>')
    text = text.replace('&quot;', '"')
    text = text.replace('&amp;', '&')
    # Clean up whitespace (but preserve paragraph breaks)
    text = re.sub(r'[ \t]+', ' ', text)
    text = re.sub(r'\n[ \t]+', '\n', text)
    text = re.sub(r'[ \t]+\n', '\n', text)
    text = re.sub(r'\n{3,}', '\n\n', text)
    return text.strip()

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,608 @@
"""
Legal Corpus Registry — Regulation metadata and definitions.
Pure data module: contains the Regulation dataclass and the REGULATIONS list
with all EU regulations, DACH national laws, and EDPB guidelines.
"""
from dataclasses import dataclass
from typing import List, Optional
@dataclass
class Regulation:
    """Regulation metadata.

    Pure data container describing one legal source for corpus ingestion;
    instances are declared in the module-level REGULATIONS list.
    """
    code: str                         # stable identifier, e.g. "GDPR", "TDDDG"
    name: str                         # short display name, e.g. "DSGVO"
    full_name: str                    # official title of the regulation
    regulation_type: str              # e.g. "eu_regulation", "eu_directive", "de_law"
    source_url: str                   # canonical source/download URL
    description: str                  # one-sentence German summary
    celex: Optional[str] = None  # CELEX number for EUR-Lex direct access
    local_path: Optional[str] = None  # path to a local copy, presumably when pre-downloaded -- TODO confirm
    language: str = "de"              # language code of the source text
    requirement_count: int = 0        # number of requirements derived from this source
# All regulations from Compliance Hub (EU + DACH national laws + guidelines)
REGULATIONS: List[Regulation] = [
Regulation(
code="GDPR",
name="DSGVO",
full_name="Verordnung (EU) 2016/679 - Datenschutz-Grundverordnung",
regulation_type="eu_regulation",
source_url="https://eur-lex.europa.eu/eli/reg/2016/679/oj/deu",
description="Grundverordnung zum Schutz natuerlicher Personen bei der Verarbeitung personenbezogener Daten.",
celex="32016R0679",
requirement_count=99,
),
Regulation(
code="EPRIVACY",
name="ePrivacy-Richtlinie",
full_name="Richtlinie 2002/58/EG",
regulation_type="eu_directive",
source_url="https://eur-lex.europa.eu/eli/dir/2002/58/oj/deu",
description="Datenschutz in der elektronischen Kommunikation, Cookies und Tracking.",
celex="32002L0058",
requirement_count=25,
),
Regulation(
code="TDDDG",
name="TDDDG",
full_name="Telekommunikation-Digitale-Dienste-Datenschutz-Gesetz",
regulation_type="de_law",
source_url="https://www.gesetze-im-internet.de/ttdsg/TDDDG.pdf",
description="Deutsche Umsetzung der ePrivacy-Richtlinie (30 Paragraphen).",
requirement_count=30,
),
Regulation(
code="SCC",
name="Standardvertragsklauseln",
full_name="Durchfuehrungsbeschluss (EU) 2021/914",
regulation_type="eu_regulation",
source_url="https://eur-lex.europa.eu/eli/dec_impl/2021/914/oj/deu",
description="Standardvertragsklauseln fuer Drittlandtransfers.",
celex="32021D0914",
requirement_count=18,
),
Regulation(
code="DPF",
name="EU-US Data Privacy Framework",
full_name="Durchfuehrungsbeschluss (EU) 2023/1795",
regulation_type="eu_regulation",
source_url="https://eur-lex.europa.eu/eli/dec_impl/2023/1795/oj",
description="Angemessenheitsbeschluss fuer USA-Transfers.",
celex="32023D1795",
requirement_count=12,
),
Regulation(
code="AIACT",
name="EU AI Act",
full_name="Verordnung (EU) 2024/1689 - KI-Verordnung",
regulation_type="eu_regulation",
source_url="https://eur-lex.europa.eu/eli/reg/2024/1689/oj/deu",
description="EU-Verordnung zur Regulierung von KI-Systemen nach Risikostufen.",
celex="32024R1689",
requirement_count=85,
),
Regulation(
code="CRA",
name="Cyber Resilience Act",
full_name="Verordnung (EU) 2024/2847",
regulation_type="eu_regulation",
source_url="https://eur-lex.europa.eu/eli/reg/2024/2847/oj/deu",
description="Cybersicherheitsanforderungen, SBOM-Pflicht.",
celex="32024R2847",
requirement_count=45,
),
Regulation(
code="NIS2",
name="NIS2-Richtlinie",
full_name="Richtlinie (EU) 2022/2555",
regulation_type="eu_directive",
source_url="https://eur-lex.europa.eu/eli/dir/2022/2555/oj/deu",
description="Cybersicherheit fuer wesentliche Einrichtungen.",
celex="32022L2555",
requirement_count=46,
),
Regulation(
code="EUCSA",
name="EU Cybersecurity Act",
full_name="Verordnung (EU) 2019/881",
regulation_type="eu_regulation",
source_url="https://eur-lex.europa.eu/eli/reg/2019/881/oj/deu",
description="ENISA und Cybersicherheitszertifizierung.",
celex="32019R0881",
requirement_count=35,
),
Regulation(
code="DATAACT",
name="Data Act",
full_name="Verordnung (EU) 2023/2854",
regulation_type="eu_regulation",
source_url="https://eur-lex.europa.eu/eli/reg/2023/2854/oj/deu",
description="Fairer Datenzugang, IoT-Daten, Cloud-Wechsel.",
celex="32023R2854",
requirement_count=42,
),
Regulation(
code="DGA",
name="Data Governance Act",
full_name="Verordnung (EU) 2022/868",
regulation_type="eu_regulation",
source_url="https://eur-lex.europa.eu/eli/reg/2022/868/oj/deu",
description="Weiterverwendung oeffentlicher Daten.",
celex="32022R0868",
requirement_count=35,
),
Regulation(
code="DSA",
name="Digital Services Act",
full_name="Verordnung (EU) 2022/2065",
regulation_type="eu_regulation",
source_url="https://eur-lex.europa.eu/eli/reg/2022/2065/oj/deu",
description="Digitale Dienste, Transparenzpflichten.",
celex="32022R2065",
requirement_count=93,
),
Regulation(
code="EAA",
name="European Accessibility Act",
full_name="Richtlinie (EU) 2019/882",
regulation_type="eu_directive",
source_url="https://eur-lex.europa.eu/eli/dir/2019/882/oj/deu",
description="Barrierefreiheit digitaler Produkte.",
celex="32019L0882",
requirement_count=25,
),
Regulation(
code="DSM",
name="DSM-Urheberrechtsrichtlinie",
full_name="Richtlinie (EU) 2019/790",
regulation_type="eu_directive",
source_url="https://eur-lex.europa.eu/eli/dir/2019/790/oj/deu",
description="Urheberrecht, Text- und Data-Mining.",
celex="32019L0790",
requirement_count=22,
),
Regulation(
code="PLD",
name="Produkthaftungsrichtlinie",
full_name="Richtlinie (EU) 2024/2853",
regulation_type="eu_directive",
source_url="https://eur-lex.europa.eu/eli/dir/2024/2853/oj/deu",
description="Produkthaftung inkl. Software und KI.",
celex="32024L2853",
requirement_count=18,
),
Regulation(
code="GPSR",
name="General Product Safety",
full_name="Verordnung (EU) 2023/988",
regulation_type="eu_regulation",
source_url="https://eur-lex.europa.eu/eli/reg/2023/988/oj/deu",
description="Allgemeine Produktsicherheit.",
celex="32023R0988",
requirement_count=30,
),
Regulation(
code="BSI-TR-03161-1",
name="BSI-TR-03161 Teil 1",
full_name="BSI Technische Richtlinie - Allgemeine Anforderungen",
regulation_type="bsi_standard",
source_url="https://www.bsi.bund.de/SharedDocs/Downloads/DE/BSI/Publikationen/TechnischeRichtlinien/TR03161/BSI-TR-03161-1.pdf?__blob=publicationFile&v=6",
description="Allgemeine Sicherheitsanforderungen (45 Pruefaspekte).",
requirement_count=45,
),
Regulation(
code="BSI-TR-03161-2",
name="BSI-TR-03161 Teil 2",
full_name="BSI Technische Richtlinie - Web-Anwendungen",
regulation_type="bsi_standard",
source_url="https://www.bsi.bund.de/SharedDocs/Downloads/DE/BSI/Publikationen/TechnischeRichtlinien/TR03161/BSI-TR-03161-2.pdf?__blob=publicationFile&v=5",
description="Web-Sicherheit (40 Pruefaspekte).",
requirement_count=40,
),
Regulation(
code="BSI-TR-03161-3",
name="BSI-TR-03161 Teil 3",
full_name="BSI Technische Richtlinie - Hintergrundsysteme",
regulation_type="bsi_standard",
source_url="https://www.bsi.bund.de/SharedDocs/Downloads/DE/BSI/Publikationen/TechnischeRichtlinien/TR03161/BSI-TR-03161-3.pdf?__blob=publicationFile&v=5",
description="Backend-Sicherheit (35 Pruefaspekte).",
requirement_count=35,
),
# Additional regulations for financial sector and health
Regulation(
code="DORA",
name="DORA",
full_name="Verordnung (EU) 2022/2554 - Digital Operational Resilience Act",
regulation_type="eu_regulation",
source_url="https://eur-lex.europa.eu/eli/reg/2022/2554/oj/deu",
description="Digitale operationale Resilienz fuer den Finanzsektor. IKT-Risikomanagement, Vorfallmeldung, Resilienz-Tests.",
celex="32022R2554",
requirement_count=64,
),
Regulation(
code="PSD2",
name="PSD2",
full_name="Richtlinie (EU) 2015/2366 - Zahlungsdiensterichtlinie",
regulation_type="eu_directive",
source_url="https://eur-lex.europa.eu/eli/dir/2015/2366/oj/deu",
description="Zahlungsdienste im Binnenmarkt. Starke Kundenauthentifizierung, Open Banking APIs.",
celex="32015L2366",
requirement_count=117,
),
Regulation(
code="AMLR",
name="AML-Verordnung",
full_name="Verordnung (EU) 2024/1624 - Geldwaeschebekaempfung",
regulation_type="eu_regulation",
source_url="https://eur-lex.europa.eu/eli/reg/2024/1624/oj/deu",
description="Verhinderung der Nutzung des Finanzsystems zur Geldwaesche und Terrorismusfinanzierung.",
celex="32024R1624",
requirement_count=89,
),
Regulation(
code="EHDS",
name="EHDS",
full_name="Verordnung (EU) 2025/327 - Europaeischer Gesundheitsdatenraum",
regulation_type="eu_regulation",
source_url="https://eur-lex.europa.eu/eli/reg/2025/327/oj/deu",
description="Europaeischer Raum fuer Gesundheitsdaten. Primaer- und Sekundaernutzung von Gesundheitsdaten.",
celex="32025R0327",
requirement_count=95,
),
Regulation(
code="MiCA",
name="MiCA",
full_name="Verordnung (EU) 2023/1114 - Markets in Crypto-Assets",
regulation_type="eu_regulation",
source_url="https://eur-lex.europa.eu/eli/reg/2023/1114/oj/deu",
description="Regulierung von Kryptowerten, Stablecoins und Crypto-Asset-Dienstleistern.",
celex="32023R1114",
requirement_count=149,
),
# =====================================================================
# DACH National Laws — Deutschland (P1)
# =====================================================================
Regulation(
code="DE_DDG",
name="Digitale-Dienste-Gesetz",
full_name="Digitale-Dienste-Gesetz (DDG)",
regulation_type="de_law",
source_url="https://www.gesetze-im-internet.de/ddg/",
description="Deutsches Umsetzungsgesetz zum DSA. Regelt Impressumspflicht (§5), Informationspflichten fuer digitale Dienste und Cookies.",
requirement_count=30,
),
Regulation(
code="DE_BGB_AGB",
name="BGB AGB-Recht",
full_name="BGB §§305-310, 312-312k — AGB und Fernabsatz",
regulation_type="de_law",
source_url="https://www.gesetze-im-internet.de/bgb/",
description="Deutsches AGB-Recht (§§305-310 BGB) und Fernabsatzrecht (§§312-312k BGB). Klauselverbote, Inhaltskontrolle, Widerrufsrecht, Button-Loesung.",
local_path="DE_BGB_AGB.txt",
requirement_count=40,
),
Regulation(
code="DE_EGBGB",
name="EGBGB Art. 246-248",
full_name="Einfuehrungsgesetz zum BGB — Informationspflichten",
regulation_type="de_law",
source_url="https://www.gesetze-im-internet.de/bgbeg/",
description="Informationspflichten bei Verbrauchervertraegen (Art. 246), Fernabsatz (Art. 246a), E-Commerce (Art. 246c).",
local_path="DE_EGBGB.txt",
requirement_count=20,
),
Regulation(
code="DE_UWG",
name="UWG Deutschland",
full_name="Gesetz gegen den unlauteren Wettbewerb (UWG)",
regulation_type="de_law",
source_url="https://www.gesetze-im-internet.de/uwg_2004/",
description="Unlauterer Wettbewerb: irrefuehrende Werbung, Spam-Verbot, Preisangaben, Online-Marketing-Regeln.",
requirement_count=25,
),
Regulation(
code="DE_HGB_RET",
name="HGB Aufbewahrung",
full_name="HGB §§238-261, 257 — Handelsbuecher und Aufbewahrungsfristen",
regulation_type="de_law",
source_url="https://www.gesetze-im-internet.de/hgb/",
description="Buchfuehrungspflicht, Aufbewahrungsfristen 6/10 Jahre, Anforderungen an elektronische Aufbewahrung.",
local_path="DE_HGB_RET.txt",
requirement_count=15,
),
Regulation(
code="DE_AO_RET",
name="AO Aufbewahrung",
full_name="Abgabenordnung §§140-148 — Steuerliche Aufbewahrungspflichten",
regulation_type="de_law",
source_url="https://www.gesetze-im-internet.de/ao_1977/",
description="Steuerliche Buchfuehrungs- und Aufbewahrungspflichten. 6/10 Jahre Fristen, Datenzugriff durch Finanzbehoerden.",
local_path="DE_AO_RET.txt",
requirement_count=12,
),
Regulation(
code="DE_TKG",
name="TKG 2021",
full_name="Telekommunikationsgesetz 2021",
regulation_type="de_law",
source_url="https://www.gesetze-im-internet.de/tkg_2021/",
description="Telekommunikationsregulierung: Kundenschutz, Datenschutz, Vertragslaufzeiten, Netzinfrastruktur.",
requirement_count=45,
),
# =====================================================================
# DACH National Laws — Oesterreich (P1)
# =====================================================================
Regulation(
code="AT_ECG",
name="E-Commerce-Gesetz AT",
full_name="E-Commerce-Gesetz (ECG) Oesterreich",
regulation_type="at_law",
source_url="https://www.ris.bka.gv.at/GeltendeFassung.wxe?Abfrage=Bundesnormen&Gesetzesnummer=20001703",
description="Oesterreichisches E-Commerce-Gesetz: Impressum/Offenlegungspflicht (§5), Informationspflichten, Haftung von Diensteanbietern.",
language="de",
requirement_count=30,
),
Regulation(
code="AT_TKG",
name="TKG 2021 AT",
full_name="Telekommunikationsgesetz 2021 Oesterreich",
regulation_type="at_law",
source_url="https://www.ris.bka.gv.at/GeltendeFassung.wxe?Abfrage=Bundesnormen&Gesetzesnummer=20011678",
description="Oesterreichisches TKG: Cookie-Bestimmungen (§165), Kommunikationsgeheimnis, Endgeraetezugriff.",
language="de",
requirement_count=40,
),
Regulation(
code="AT_KSCHG",
name="KSchG Oesterreich",
full_name="Konsumentenschutzgesetz (KSchG) Oesterreich",
regulation_type="at_law",
source_url="https://www.ris.bka.gv.at/GeltendeFassung.wxe?Abfrage=Bundesnormen&Gesetzesnummer=10002462",
description="Konsumentenschutz: AGB-Kontrolle (§6 Klauselverbote, §9 Verbandsklage), Ruecktrittsrecht, Informationspflichten.",
language="de",
requirement_count=35,
),
Regulation(
code="AT_FAGG",
name="FAGG Oesterreich",
full_name="Fern- und Auswaertsgeschaefte-Gesetz (FAGG) Oesterreich",
regulation_type="at_law",
source_url="https://www.ris.bka.gv.at/GeltendeFassung.wxe?Abfrage=Bundesnormen&Gesetzesnummer=20008847",
description="Fernabsatzrecht: Informationspflichten, Widerrufsrecht 14 Tage, Button-Loesung, Ausnahmen.",
language="de",
requirement_count=20,
),
Regulation(
code="AT_UGB_RET",
name="UGB Aufbewahrung AT",
full_name="UGB §§189-216, 212 — Rechnungslegung und Aufbewahrung Oesterreich",
regulation_type="at_law",
source_url="https://www.ris.bka.gv.at/GeltendeFassung.wxe?Abfrage=Bundesnormen&Gesetzesnummer=10001702",
description="Oesterreichische Rechnungslegungspflicht und Aufbewahrungsfristen (7 Jahre). Buchfuehrung, Jahresabschluss.",
local_path="AT_UGB_RET.txt",
language="de",
requirement_count=15,
),
Regulation(
code="AT_BAO_RET",
name="BAO §132 AT",
full_name="Bundesabgabenordnung §132 — Aufbewahrung Oesterreich",
regulation_type="at_law",
source_url="https://www.ris.bka.gv.at/GeltendeFassung.wxe?Abfrage=Bundesnormen&Gesetzesnummer=10003940",
description="Steuerliche Aufbewahrungspflicht 7 Jahre fuer Buecher, Aufzeichnungen und Belege. Grundstuecke 22 Jahre.",
language="de",
requirement_count=5,
),
Regulation(
code="AT_MEDIENG",
name="MedienG §§24-25 AT",
full_name="Mediengesetz §§24-25 Oesterreich — Impressum und Offenlegung",
regulation_type="at_law",
source_url="https://www.ris.bka.gv.at/GeltendeFassung.wxe?Abfrage=Bundesnormen&Gesetzesnummer=10000719",
description="Impressum/Offenlegungspflicht fuer periodische Medien und Websites in Oesterreich.",
language="de",
requirement_count=10,
),
# =====================================================================
# DACH National Laws — Schweiz (P1)
# =====================================================================
Regulation(
code="CH_DSV",
name="DSV Schweiz",
full_name="Datenschutzverordnung (DSV) Schweiz — SR 235.11",
regulation_type="ch_law",
source_url="https://www.fedlex.admin.ch/eli/cc/2022/568/de",
description="Ausfuehrungsverordnung zum revDSG: Meldepflichten, DSFA-Verfahren, Auslandtransfers, technische Massnahmen.",
language="de",
requirement_count=30,
),
Regulation(
code="CH_OR_AGB",
name="OR AGB/Aufbewahrung CH",
full_name="Obligationenrecht — AGB-Kontrolle und Aufbewahrung Schweiz (SR 220)",
regulation_type="ch_law",
source_url="https://www.fedlex.admin.ch/eli/cc/27/317_321_377/de",
description="Art. 8 OR (AGB-Inhaltskontrolle), Art. 19/20 (Vertragsfreiheit), Art. 957-958f (Buchfuehrung, 10 Jahre Aufbewahrung).",
local_path="CH_OR_AGB.txt",
language="de",
requirement_count=20,
),
Regulation(
code="CH_UWG",
name="UWG Schweiz",
full_name="Bundesgesetz gegen den unlauteren Wettbewerb Schweiz (SR 241)",
regulation_type="ch_law",
source_url="https://www.fedlex.admin.ch/eli/cc/1988/223_223_223/de",
description="Lauterkeitsrecht: Impressumspflicht, irrefuehrende Werbung, aggressive Verkaufsmethoden, AGB-Transparenz.",
language="de",
requirement_count=20,
),
Regulation(
code="CH_FMG",
name="FMG Schweiz",
full_name="Fernmeldegesetz Schweiz (SR 784.10)",
regulation_type="ch_law",
source_url="https://www.fedlex.admin.ch/eli/cc/1997/2187_2187_2187/de",
description="Telekommunikationsregulierung: Fernmeldegeheimnis, Cookies/Tracking (Art. 45c), Spam-Verbot, Datenschutz.",
language="de",
requirement_count=25,
),
# =====================================================================
# Deutschland P2
# =====================================================================
Regulation(
code="DE_PANGV",
name="PAngV",
full_name="Preisangabenverordnung (PAngV 2022)",
regulation_type="de_law",
source_url="https://www.gesetze-im-internet.de/pangv_2022/",
description="Preisangaben: Gesamtpreis, Grundpreis, Streichpreise (§11), Online-Preisauszeichnung.",
requirement_count=15,
),
Regulation(
code="DE_DLINFOV",
name="DL-InfoV",
full_name="Dienstleistungs-Informationspflichten-Verordnung",
regulation_type="de_law",
source_url="https://www.gesetze-im-internet.de/dlinfov/",
description="Informationspflichten fuer Dienstleister: Identitaet, Kontakt, Berufshaftpflicht, AGB-Zugang.",
requirement_count=10,
),
Regulation(
code="DE_BETRVG",
name="BetrVG §87",
full_name="Betriebsverfassungsgesetz §87 Abs.1 Nr.6",
regulation_type="de_law",
source_url="https://www.gesetze-im-internet.de/betrvg/",
description="Mitbestimmung bei technischer Ueberwachung: Betriebsrat-Beteiligung bei IT-Systemen, die Arbeitnehmerverhalten ueberwachen koennen.",
requirement_count=5,
),
# =====================================================================
# Oesterreich P2
# =====================================================================
Regulation(
code="AT_ABGB_AGB",
name="ABGB AGB-Recht AT",
full_name="ABGB §§861-879, 864a — AGB-Kontrolle Oesterreich",
regulation_type="at_law",
source_url="https://www.ris.bka.gv.at/GeltendeFassung.wxe?Abfrage=Bundesnormen&Gesetzesnummer=10001622",
description="Geltungskontrolle (§864a), Sittenwidrigkeitskontrolle (§879 Abs.3), allgemeine Vertragsregeln.",
local_path="AT_ABGB_AGB.txt",
language="de",
requirement_count=10,
),
Regulation(
code="AT_UWG",
name="UWG Oesterreich",
full_name="Bundesgesetz gegen den unlauteren Wettbewerb Oesterreich",
regulation_type="at_law",
source_url="https://www.ris.bka.gv.at/GeltendeFassung.wxe?Abfrage=Bundesnormen&Gesetzesnummer=10002665",
description="Lauterkeitsrecht AT: irrefuehrende Geschaeftspraktiken, aggressive Praktiken, Preisauszeichnung.",
language="de",
requirement_count=15,
),
# =====================================================================
# Schweiz P2
# =====================================================================
Regulation(
code="CH_GEBUV",
name="GeBuV Schweiz",
full_name="Geschaeftsbuecher-Verordnung Schweiz (SR 221.431)",
regulation_type="ch_law",
source_url="https://www.fedlex.admin.ch/eli/cc/2002/468_468_468/de",
description="Ausfuehrungsvorschriften zur Buchfuehrung: elektronische Aufbewahrung, Integritaet, Datentraeger.",
language="de",
requirement_count=10,
),
Regulation(
code="CH_ZERTES",
name="ZertES Schweiz",
full_name="Bundesgesetz ueber die elektronische Signatur (SR 943.03)",
regulation_type="ch_law",
source_url="https://www.fedlex.admin.ch/eli/cc/2016/752/de",
description="Elektronische Signatur und Zertifizierung: Qualifizierte Signaturen, Zertifizierungsdiensteanbieter.",
language="de",
requirement_count=10,
),
# =====================================================================
# Deutschland P3
# =====================================================================
Regulation(
code="DE_GESCHGEHG",
name="GeschGehG",
full_name="Gesetz zum Schutz von Geschaeftsgeheimnissen",
regulation_type="de_law",
source_url="https://www.gesetze-im-internet.de/geschgehg/",
description="Schutz von Geschaeftsgeheimnissen: Definition, angemessene Geheimhaltungsmassnahmen, Reverse Engineering.",
requirement_count=10,
),
Regulation(
code="DE_BSIG",
name="BSI-Gesetz",
full_name="Gesetz ueber das Bundesamt fuer Sicherheit in der Informationstechnik (BSIG)",
regulation_type="de_law",
source_url="https://www.gesetze-im-internet.de/bsig_2009/",
description="BSI-Aufgaben, KRITIS-Meldepflichten, IT-Sicherheitsstandards, Zertifizierung.",
requirement_count=20,
),
Regulation(
code="DE_USTG_RET",
name="UStG §14b",
full_name="Umsatzsteuergesetz §14b — Aufbewahrung von Rechnungen",
regulation_type="de_law",
source_url="https://www.gesetze-im-internet.de/ustg_1980/",
description="Aufbewahrungspflicht fuer Rechnungen: 10 Jahre, Grundstuecke 20 Jahre, elektronische Aufbewahrung.",
local_path="DE_USTG_RET.txt",
requirement_count=5,
),
# =====================================================================
# Schweiz P3
# =====================================================================
Regulation(
code="CH_ZGB_PERS",
name="ZGB Persoenlichkeitsschutz CH",
full_name="Zivilgesetzbuch Art. 28-28l — Persoenlichkeitsschutz Schweiz (SR 210)",
regulation_type="ch_law",
source_url="https://www.fedlex.admin.ch/eli/cc/24/233_245_233/de",
description="Persoenlichkeitsschutz: Recht am eigenen Bild, Schutz der Privatsphaere, Gegendarstellungsrecht.",
language="de",
requirement_count=8,
),
# =====================================================================
# 3 fehlgeschlagene Quellen mit alternativen URLs nachholen
# =====================================================================
Regulation(
code="LU_DPA_LAW",
name="Datenschutzgesetz Luxemburg",
full_name="Loi du 1er aout 2018 — Datenschutzgesetz Luxemburg",
regulation_type="national_law",
source_url="https://legilux.public.lu/eli/etat/leg/loi/2018/08/01/a686/jo",
description="Luxemburgisches Datenschutzgesetz: Organisation der CNPD, nationale DSGVO-Ergaenzung.",
language="fr",
requirement_count=40,
),
Regulation(
code="DK_DATABESKYTTELSESLOVEN",
name="Databeskyttelsesloven DK",
full_name="Databeskyttelsesloven — Datenschutzgesetz Daenemark",
regulation_type="national_law",
source_url="https://www.retsinformation.dk/eli/lta/2018/502",
description="Daenisches Datenschutzgesetz als ergaenzende Bestimmungen zur DSGVO. Reguliert durch Datatilsynet.",
language="da",
requirement_count=30,
),
Regulation(
code="EDPB_GUIDELINES_1_2022",
name="EDPB GL Bussgelder",
full_name="EDPB Leitlinien 04/2022 zur Berechnung von Bussgeldern nach der DSGVO",
regulation_type="eu_guideline",
source_url="https://www.edpb.europa.eu/system/files/2023-05/edpb_guidelines_042022_calculationofadministrativefines_en.pdf",
description="EDPB-Leitlinien zur Berechnung von Verwaltungsbussgeldern unter der DSGVO.",
language="en",
requirement_count=15,
),
]

View File

@@ -0,0 +1,485 @@
"""
Worksheet Editor AI — AI image generation and AI worksheet modification.
"""
import io
import json
import base64
import logging
import re
import time
import random
from typing import List, Dict
import httpx
from worksheet_editor_models import (
AIImageRequest,
AIImageResponse,
AIImageStyle,
AIModifyRequest,
AIModifyResponse,
OLLAMA_URL,
STYLE_PROMPTS,
)
logger = logging.getLogger(__name__)
# =============================================
# AI IMAGE GENERATION
# =============================================
async def generate_ai_image_logic(request: AIImageRequest) -> AIImageResponse:
    """Generate an AI image via Ollama, falling back to a placeholder.

    Probes the Ollama HTTP API first; if it is unreachable, or no
    Stable-Diffusion-like model is installed, a locally drawn placeholder
    image is returned instead of failing the request.

    Raises:
        HTTPException: 503 when Ollama answers but is unhealthy,
            500 on unexpected errors.
    """
    from fastapi import HTTPException
    try:
        # Compose the final prompt from the user prompt plus a style suffix.
        style_suffix = STYLE_PROMPTS.get(request.style, "")
        full_prompt = f"{request.prompt}, {style_suffix}"
        logger.info(f"Generating AI image: {full_prompt[:100]}...")

        # Fast health probe with a short timeout.
        async with httpx.AsyncClient(timeout=10.0) as probe:
            try:
                probe_resp = await probe.get(f"{OLLAMA_URL}/api/tags")
                if probe_resp.status_code != 200:
                    raise HTTPException(status_code=503, detail="Ollama service not available")
            except httpx.ConnectError:
                logger.warning("Ollama not reachable, returning placeholder")
                return _generate_placeholder_image(request, full_prompt)

        try:
            # Long timeout for the (future) actual generation step.
            async with httpx.AsyncClient(timeout=300.0) as client:
                tags = await client.get(f"{OLLAMA_URL}/api/tags")
                models = [entry.get("name", "") for entry in tags.json().get("models", [])]
                # First installed model whose name hints at Stable Diffusion.
                sd_model = next(
                    (m for m in models
                     if any(token in m.lower() for token in ("stable", "sd", "diffusion"))),
                    None,
                )
                if sd_model is None:
                    logger.warning("No Stable Diffusion model found in Ollama")
                    return _generate_placeholder_image(request, full_prompt)
                # Image generation endpoint is not implemented yet; placeholder for now.
                logger.info(f"SD model found: {sd_model}, but image generation API not implemented")
                return _generate_placeholder_image(request, full_prompt)
        except Exception as e:
            logger.error(f"Image generation failed: {e}")
            return _generate_placeholder_image(request, full_prompt)
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"AI image generation error: {e}")
        raise HTTPException(status_code=500, detail=str(e))
def _generate_placeholder_image(request: AIImageRequest, prompt: str) -> AIImageResponse:
    """Build a deterministic placeholder PNG when AI generation is unavailable.

    Draws a styled frame, a simple image glyph, the (truncated) prompt
    wrapped over at most four lines, and a badge marking the result as a
    placeholder.  The bitmap is returned base64-encoded as a data URL
    inside an AIImageResponse whose ``error`` explains the fallback.
    """
    from PIL import Image, ImageDraw, ImageFont

    w, h = request.width, request.height

    # Foreground/background palette per requested style.
    palette = {
        AIImageStyle.REALISTIC: ("#2563eb", "#dbeafe"),
        AIImageStyle.CARTOON: ("#f97316", "#ffedd5"),
        AIImageStyle.SKETCH: ("#6b7280", "#f3f4f6"),
        AIImageStyle.CLIPART: ("#8b5cf6", "#ede9fe"),
        AIImageStyle.EDUCATIONAL: ("#059669", "#d1fae5"),
    }
    accent, backdrop = palette.get(request.style, ("#6366f1", "#e0e7ff"))

    canvas = Image.new('RGB', (w, h), backdrop)
    pen = ImageDraw.Draw(canvas)

    # Outer frame.
    pen.rectangle([5, 5, w - 6, h - 6], outline=accent, width=3)

    # "Image" glyph: circle with a cross, slightly above vertical center.
    mid_x, mid_y = w // 2, h // 2 - 30
    pen.ellipse([mid_x - 40, mid_y - 40, mid_x + 40, mid_y + 40], outline=accent, width=3)
    pen.line([mid_x - 20, mid_y - 10, mid_x + 20, mid_y - 10], fill=accent, width=3)
    pen.line([mid_x, mid_y - 10, mid_x, mid_y + 20], fill=accent, width=3)

    try:
        label_font = ImageFont.truetype("/usr/share/fonts/truetype/dejavu/DejaVuSans.ttf", 14)
    except Exception:
        label_font = ImageFont.load_default()

    # Greedy word-wrap of the first 200 prompt chars into <=40-char lines.
    wrapped = []
    line = ""
    for token in prompt[:200].split():
        if len(line) + len(token) + 1 <= 40:
            line = f"{line} {token}" if line else token
        else:
            if line:
                wrapped.append(line)
            line = token
    if line:
        wrapped.append(line)

    # Render at most four centered lines below the glyph.
    y = mid_y + 60
    for row in wrapped[:4]:
        box = pen.textbbox((0, 0), row, font=label_font)
        pen.text((mid_x - (box[2] - box[0]) // 2, y), row, fill=accent, font=label_font)
        y += 20

    # Placeholder badge in the lower-left corner.
    try:
        badge_font = ImageFont.truetype("/usr/share/fonts/truetype/dejavu/DejaVuSans.ttf", 10)
    except Exception:
        badge_font = label_font
    pen.rectangle([10, h - 30, 150, h - 10], fill=accent)
    pen.text((15, h - 27), "KI-Bild (Platzhalter)", fill="white", font=badge_font)

    out = io.BytesIO()
    canvas.save(out, format='PNG')
    out.seek(0)
    encoded = f"data:image/png;base64,{base64.b64encode(out.getvalue()).decode('utf-8')}"
    return AIImageResponse(
        image_base64=encoded,
        prompt_used=prompt,
        error="AI image generation not available. Using placeholder."
    )
# =============================================
# AI WORKSHEET MODIFICATION
# =============================================
async def modify_worksheet_with_ai_logic(request: AIModifyRequest) -> AIModifyResponse:
    """
    Modify a worksheet using AI based on a natural-language prompt.

    Flow: parse the Fabric.js canvas JSON from the request, send it
    together with the user prompt to a local Ollama model, then apply the
    model's structured answer (action: add / modify / delete / info) to
    the canvas.  When Ollama is unreachable, errors, or times out, the
    rule-based local fallback ``_handle_simple_modification`` is used.

    Returns:
        AIModifyResponse — on success ``modified_canvas_json`` carries the
        updated canvas; otherwise ``error`` describes the failure while
        ``message`` stays user-readable (German).
    """
    try:
        logger.info(f"AI modify request: {request.prompt[:100]}...")
        # The canvas arrives as a JSON string (Fabric.js serialization).
        try:
            canvas_data = json.loads(request.canvas_json)
        except json.JSONDecodeError:
            return AIModifyResponse(
                message="Fehler beim Parsen des Canvas",
                error="Invalid canvas JSON"
            )
        # System prompt: teaches the model the Fabric.js object schema and
        # forces a strict JSON-only answer format.
        # NOTE(review): internal whitespace of these prompt literals is
        # behavior-relevant; kept exactly as in the original.
        system_prompt = """Du bist ein Assistent fuer die Bearbeitung von Arbeitsblaettern.
Du erhaeltst den aktuellen Zustand eines Canvas im JSON-Format und eine Anweisung des Nutzers.
Deine Aufgabe ist es, die gewuenschten Aenderungen am Canvas vorzunehmen.
Der Canvas verwendet Fabric.js. Hier sind die wichtigsten Objekttypen:
- i-text: Interaktiver Text mit fontFamily, fontSize, fill, left, top
- rect: Rechteck mit left, top, width, height, fill, stroke, strokeWidth
- circle: Kreis mit left, top, radius, fill, stroke, strokeWidth
- line: Linie mit x1, y1, x2, y2, stroke, strokeWidth
Das Canvas ist 794x1123 Pixel (A4 bei 96 DPI).
Antworte NUR mit einem JSON-Objekt in diesem Format:
{
"action": "modify" oder "add" oder "delete" oder "info",
"objects": [...], // Neue/modifizierte Objekte (bei modify/add)
"message": "Kurze Beschreibung der Aenderung"
}
Wenn du Objekte hinzufuegst, generiere eindeutige IDs im Format "obj_<timestamp>_<random>".
"""
        # Canvas state is truncated to 5000 chars to bound prompt size.
        user_prompt = f"""Aktueller Canvas-Zustand:
```json
{json.dumps(canvas_data, indent=2)[:5000]}
```
Nutzer-Anweisung: {request.prompt}
Fuehre die Aenderung durch und antworte mit dem JSON-Objekt."""
        try:
            async with httpx.AsyncClient(timeout=120.0) as client:
                response = await client.post(
                    f"{OLLAMA_URL}/api/generate",
                    json={
                        "model": request.model,
                        "prompt": user_prompt,
                        "system": system_prompt,
                        "stream": False,
                        "options": {
                            # Low temperature for deterministic, schema-following output.
                            "temperature": 0.3,
                            "num_predict": 4096
                        }
                    }
                )
                if response.status_code != 200:
                    logger.warning(f"Ollama error: {response.status_code}, trying local fallback")
                    return _handle_simple_modification(request.prompt, canvas_data)
                ai_response = response.json().get("response", "")
        except httpx.ConnectError:
            logger.warning("Ollama not reachable")
            return _handle_simple_modification(request.prompt, canvas_data)
        except httpx.TimeoutException:
            logger.warning("Ollama timeout, trying local fallback")
            return _handle_simple_modification(request.prompt, canvas_data)
        try:
            # Extract the outermost {...} span — models often wrap the JSON in prose.
            json_start = ai_response.find('{')
            json_end = ai_response.rfind('}') + 1
            if json_start == -1 or json_end <= json_start:
                logger.warning(f"No JSON found in AI response: {ai_response[:200]}")
                return AIModifyResponse(
                    message="KI konnte die Anfrage nicht verarbeiten",
                    error="No JSON in response"
                )
            ai_json = json.loads(ai_response[json_start:json_end])
            action = ai_json.get("action", "info")
            message = ai_json.get("message", "Aenderungen angewendet")
            new_objects = ai_json.get("objects", [])
            if action == "info":
                # Informational answer only — canvas stays untouched.
                return AIModifyResponse(message=message)
            if action == "add" and new_objects:
                # Append the new objects to the existing canvas objects.
                existing_objects = canvas_data.get("objects", [])
                existing_objects.extend(new_objects)
                canvas_data["objects"] = existing_objects
                return AIModifyResponse(
                    modified_canvas_json=json.dumps(canvas_data),
                    message=message
                )
            if action == "modify" and new_objects:
                # Replace by id: drop every existing object whose id reappears
                # in the AI answer, then append the new versions.
                existing_objects = canvas_data.get("objects", [])
                new_ids = {obj.get("id") for obj in new_objects if obj.get("id")}
                kept_objects = [obj for obj in existing_objects if obj.get("id") not in new_ids]
                kept_objects.extend(new_objects)
                canvas_data["objects"] = kept_objects
                return AIModifyResponse(
                    modified_canvas_json=json.dumps(canvas_data),
                    message=message
                )
            if action == "delete":
                # Remove all objects whose id is listed in delete_ids.
                delete_ids = ai_json.get("delete_ids", [])
                if delete_ids:
                    existing_objects = canvas_data.get("objects", [])
                    canvas_data["objects"] = [obj for obj in existing_objects if obj.get("id") not in delete_ids]
                return AIModifyResponse(
                    modified_canvas_json=json.dumps(canvas_data),
                    message=message
                )
            # Unknown action — report the model's message without changes.
            return AIModifyResponse(message=message)
        except json.JSONDecodeError as e:
            logger.error(f"Failed to parse AI JSON: {e}")
            return AIModifyResponse(
                message="Fehler beim Verarbeiten der KI-Antwort",
                error=str(e)
            )
    except Exception as e:
        logger.error(f"AI modify error: {e}")
        return AIModifyResponse(
            message="Ein unerwarteter Fehler ist aufgetreten",
            error=str(e)
        )
def _handle_simple_modification(prompt: str, canvas_data: dict) -> AIModifyResponse:
    """
    Handle simple canvas modifications locally when Ollama is not available.

    Recognizes a small set of German/English keyword commands in ``prompt``
    and mutates ``canvas_data["objects"]`` in place:

    - "ueberschrift"/"titel"/"heading": add a bold, centered heading
      (quoted text in the prompt is used verbatim, else a placeholder).
    - "linie"/"line"/"schreib": add up to 20 horizontal writing lines.
    - "groesser"/"bigger"/"larger": scale all text objects up by 25%.
    - "zentrier"/"center"/"mitte": horizontally center all non-grid objects.
    - "nummer"/"nummerier"/"1-10": add a numbered list (at most 20 entries).
    - "rechteck"/"box"/"kasten": add an outlined rectangle.
    - "raster"/"grid"/"tabelle": add a cols x rows grid of lines.

    Args:
        prompt: Free-text user instruction (German or English keywords).
        canvas_data: Parsed Fabric.js canvas dict; its "objects" list is
            extended/modified in place.

    Returns:
        AIModifyResponse with the serialized modified canvas on success,
        or an error response asking for the Ollama service when the
        command is not recognized.
    """
    prompt_lower = prompt.lower()
    objects = canvas_data.get("objects", [])

    def generate_id():
        # Timestamp + random suffix keeps ids unique enough for canvas objects.
        return f"obj_{int(time.time()*1000)}_{random.randint(1000, 9999)}"

    # Add heading
    if "ueberschrift" in prompt_lower or "titel" in prompt_lower or "heading" in prompt_lower:
        text_match = re.search(r'"([^"]+)"', prompt)
        text = text_match.group(1) if text_match else "Ueberschrift"
        new_text = {
            "type": "i-text", "id": generate_id(), "text": text,
            "left": 397, "top": 50, "originX": "center",  # 397 = horizontal center of A4 @96dpi (794px)
            "fontFamily": "Arial", "fontSize": 28, "fontWeight": "bold", "fill": "#000000"
        }
        objects.append(new_text)
        canvas_data["objects"] = objects
        return AIModifyResponse(
            modified_canvas_json=json.dumps(canvas_data),
            message=f"Ueberschrift '{text}' hinzugefuegt"
        )

    # Add lines for writing
    if "linie" in prompt_lower or "line" in prompt_lower or "schreib" in prompt_lower:
        num_match = re.search(r'(\d+)', prompt)
        num_lines = int(num_match.group(1)) if num_match else 5
        num_lines = min(num_lines, 20)  # cap to keep the page usable
        start_y = 150
        line_spacing = 40
        for i in range(num_lines):
            new_line = {
                "type": "line", "id": generate_id(),
                "x1": 60, "y1": start_y + i * line_spacing,
                "x2": 734, "y2": start_y + i * line_spacing,
                "stroke": "#cccccc", "strokeWidth": 1
            }
            objects.append(new_line)
        canvas_data["objects"] = objects
        return AIModifyResponse(
            modified_canvas_json=json.dumps(canvas_data),
            message=f"{num_lines} Schreiblinien hinzugefuegt"
        )

    # Make text bigger (falls through to other commands if no text objects exist)
    if "groesser" in prompt_lower or "bigger" in prompt_lower or "larger" in prompt_lower:
        modified = 0
        for obj in objects:
            if obj.get("type") in ["i-text", "text", "textbox"]:
                current_size = obj.get("fontSize", 16)
                obj["fontSize"] = int(current_size * 1.25)
                modified += 1
        canvas_data["objects"] = objects
        if modified > 0:
            return AIModifyResponse(
                modified_canvas_json=json.dumps(canvas_data),
                message=f"{modified} Texte vergroessert"
            )

    # Center elements (grid lines keep their position)
    if "zentrier" in prompt_lower or "center" in prompt_lower or "mitte" in prompt_lower:
        center_x = 397
        for obj in objects:
            if not obj.get("isGrid"):
                obj["left"] = center_x
                obj["originX"] = "center"
        canvas_data["objects"] = objects
        return AIModifyResponse(
            modified_canvas_json=json.dumps(canvas_data),
            message="Elemente zentriert"
        )

    # Add numbering
    if "nummer" in prompt_lower or "nummerier" in prompt_lower or "1-10" in prompt_lower:
        range_match = re.search(r'(\d+)\s*[-bis]+\s*(\d+)', prompt)
        if range_match:
            start, end = int(range_match.group(1)), int(range_match.group(2))
        else:
            start, end = 1, 10
        y = 100
        for i in range(start, min(end + 1, start + 20)):  # at most 20 entries
            new_text = {
                "type": "i-text", "id": generate_id(), "text": f"{i}.",
                "left": 40, "top": y, "fontFamily": "Arial", "fontSize": 14, "fill": "#000000"
            }
            objects.append(new_text)
            y += 35
        canvas_data["objects"] = objects
        return AIModifyResponse(
            modified_canvas_json=json.dumps(canvas_data),
            message=f"Nummerierung {start}-{end} hinzugefuegt"
        )

    # Add rectangle/box
    if "rechteck" in prompt_lower or "box" in prompt_lower or "kasten" in prompt_lower:
        new_rect = {
            "type": "rect", "id": generate_id(),
            "left": 100, "top": 200, "width": 200, "height": 100,
            "fill": "transparent", "stroke": "#000000", "strokeWidth": 2
        }
        objects.append(new_rect)
        canvas_data["objects"] = objects
        return AIModifyResponse(
            modified_canvas_json=json.dumps(canvas_data),
            message="Rechteck hinzugefuegt"
        )

    # Add grid/raster
    if "raster" in prompt_lower or "grid" in prompt_lower or "tabelle" in prompt_lower:
        # Accept "3x4", "3/4", "3*4", "3×4", "3 mal 4", "3 by 4".
        # BUGFIX: the previous pattern used a character class "[x/×*mal by]"
        # which matched only a SINGLE character, so the word separators
        # "mal"/"by" only ever worked via the numeric fallback below.
        dim_match = re.search(r'(\d+)\s*(?:[x/\u00d7*]|mal|by)\s*(\d+)', prompt_lower)
        if dim_match:
            cols = int(dim_match.group(1))
            rows = int(dim_match.group(2))
        else:
            # Fallback: take the first two numbers anywhere in the prompt.
            nums = re.findall(r'(\d+)', prompt)
            if len(nums) >= 2:
                cols, rows = int(nums[0]), int(nums[1])
            else:
                cols, rows = 3, 4
        cols = min(max(1, cols), 10)
        rows = min(max(1, rows), 15)
        canvas_width = 794
        canvas_height = 1123
        margin = 60
        available_width = canvas_width - 2 * margin
        available_height = canvas_height - 2 * margin - 80  # leave headroom for a title
        cell_width = available_width / cols
        cell_height = min(available_height / rows, 80)
        start_x = margin
        start_y = 120
        grid_objects = []
        # Horizontal lines (rows + 1 boundaries)
        for r in range(rows + 1):
            y = start_y + r * cell_height
            grid_objects.append({
                "type": "line", "id": generate_id(),
                "x1": start_x, "y1": y,
                "x2": start_x + cols * cell_width, "y2": y,
                "stroke": "#666666", "strokeWidth": 1, "isGrid": True
            })
        # Vertical lines (cols + 1 boundaries)
        for c in range(cols + 1):
            x = start_x + c * cell_width
            grid_objects.append({
                "type": "line", "id": generate_id(),
                "x1": x, "y1": start_y,
                "x2": x, "y2": start_y + rows * cell_height,
                "stroke": "#666666", "strokeWidth": 1, "isGrid": True
            })
        objects.extend(grid_objects)
        canvas_data["objects"] = objects
        return AIModifyResponse(
            modified_canvas_json=json.dumps(canvas_data),
            message=f"{cols}x{rows} Raster hinzugefuegt ({cols} Spalten, {rows} Zeilen)"
        )

    # Default: Ollama needed
    return AIModifyResponse(
        message="Diese Aenderung erfordert den KI-Service. Bitte stellen Sie sicher, dass Ollama laeuft.",
        error="Complex modification requires Ollama"
    )

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,133 @@
"""
Worksheet Editor Models — Enums, Pydantic models, and configuration.
"""
import os
import logging
from typing import Optional, List, Dict
from enum import Enum
from pydantic import BaseModel, Field
logger = logging.getLogger(__name__)
# =============================================
# CONFIGURATION
# =============================================
# Base URL of the local Ollama server (host.docker.internal resolves to the
# Docker host when running inside a container).
OLLAMA_URL = os.getenv("OLLAMA_URL", "http://host.docker.internal:11434")
SD_MODEL = os.getenv("SD_MODEL", "stable-diffusion")  # or specific SD model
# Worksheets are persisted next to this module unless overridden via env var.
WORKSHEET_STORAGE_DIR = os.getenv("WORKSHEET_STORAGE_DIR",
    os.path.join(os.path.dirname(os.path.abspath(__file__)), "worksheet-storage"))
# Ensure storage directory exists (idempotent; runs once at import time)
os.makedirs(WORKSHEET_STORAGE_DIR, exist_ok=True)
# =============================================
# ENUMS & MODELS
# =============================================
class AIImageStyle(str, Enum):
    """Visual styles selectable for AI-generated worksheet images.

    Each member maps to a prompt modifier in STYLE_PROMPTS.
    """
    REALISTIC = "realistic"
    CARTOON = "cartoon"
    SKETCH = "sketch"
    CLIPART = "clipart"
    EDUCATIONAL = "educational"
class WorksheetStatus(str, Enum):
    """Lifecycle state of a worksheet (draft -> published -> archived)."""
    DRAFT = "draft"
    PUBLISHED = "published"
    ARCHIVED = "archived"
# Style prompt modifiers: appended to the user's prompt before it is sent to
# the image-generation backend, one entry per AIImageStyle member.
STYLE_PROMPTS = {
    AIImageStyle.REALISTIC: "photorealistic, high detail, professional photography",
    AIImageStyle.CARTOON: "cartoon style, colorful, child-friendly, simple shapes",
    AIImageStyle.SKETCH: "pencil sketch, hand-drawn, black and white, artistic",
    AIImageStyle.CLIPART: "clipart style, flat design, simple, vector-like",
    AIImageStyle.EDUCATIONAL: "educational illustration, clear, informative, textbook style"
}
# =============================================
# REQUEST/RESPONSE MODELS
# =============================================
class AIImageRequest(BaseModel):
    """Request payload for generating a worksheet image via AI."""
    prompt: str = Field(..., min_length=3, max_length=500)  # description of the desired image
    style: AIImageStyle = AIImageStyle.EDUCATIONAL
    width: int = Field(512, ge=256, le=1024)   # pixels
    height: int = Field(512, ge=256, le=1024)  # pixels
class AIImageResponse(BaseModel):
    """Result of an AI image generation call."""
    image_base64: str  # generated image encoded as base64
    prompt_used: str   # final prompt after style modifiers were applied
    error: Optional[str] = None  # set when generation failed
class PageData(BaseModel):
    """A single worksheet page: its identity, order, and canvas content."""
    id: str
    index: int  # zero-or-one-based ordering within the worksheet; set by the client
    canvasJSON: str  # serialized Fabric.js canvas state
class PageFormat(BaseModel):
    """Physical page geometry in millimetres (defaults to DIN A4 portrait).

    NOTE: Pydantic deep-copies mutable field defaults per instance, so the
    dict default below is safe (unlike a plain-function mutable default).
    """
    width: float = 210
    height: float = 297
    orientation: str = "portrait"
    margins: Dict[str, float] = {"top": 15, "right": 15, "bottom": 15, "left": 15}
class WorksheetSaveRequest(BaseModel):
    """Payload for creating or updating a worksheet."""
    id: Optional[str] = None  # None creates a new worksheet; otherwise updates
    title: str
    description: Optional[str] = None
    pages: List[PageData]
    pageFormat: Optional[PageFormat] = None  # server applies a default when None
class WorksheetResponse(BaseModel):
    """Full worksheet representation returned to the client."""
    id: str
    title: str
    description: Optional[str]
    pages: List[PageData]
    pageFormat: PageFormat
    createdAt: str  # ISO timestamp strings, serialized by the caller
    updatedAt: str
class AIModifyRequest(BaseModel):
    """Request to let the AI modify an existing canvas via a text prompt."""
    prompt: str = Field(..., min_length=3, max_length=1000)
    canvas_json: str  # current Fabric.js canvas state to modify
    model: str = "qwen2.5vl:32b"  # Ollama model name
class AIModifyResponse(BaseModel):
    """Result of an AI canvas modification."""
    modified_canvas_json: Optional[str] = None  # new canvas JSON; None if nothing changed
    message: str  # human-readable status (German, shown in the UI)
    error: Optional[str] = None  # technical error detail; None on success
class ReconstructRequest(BaseModel):
    """Request to rebuild one PDF page of a vocab session as editable canvas."""
    session_id: str
    page_number: int = 1  # 1-based page index
    include_images: bool = True  # crop and embed detected graphic regions
    regenerate_graphics: bool = False
class ReconstructResponse(BaseModel):
    """Reconstructed page as Fabric.js canvas JSON plus match statistics."""
    canvas_json: str
    page_width: int   # canvas pixel dimensions (A4 at screen resolution)
    page_height: int
    elements_count: int       # total Fabric objects placed on the canvas
    vocabulary_matched: int   # OCR regions matched to extracted vocabulary
    message: str
    error: Optional[str] = None
# =============================================
# IN-MEMORY STORAGE (Development)
# =============================================
# Keyed by worksheet id. NOTE: contents are lost on process restart and are
# not shared between workers -- development only.
worksheets_db: Dict[str, Dict] = {}
# PDF Generation availability: probe for reportlab at import time so PDF
# endpoints can fail gracefully instead of crashing on a missing dependency.
try:
    from reportlab.lib import colors  # noqa: F401
    from reportlab.lib.pagesizes import A4  # noqa: F401
    from reportlab.lib.units import mm  # noqa: F401
    from reportlab.pdfgen import canvas  # noqa: F401
    from reportlab.lib.styles import getSampleStyleSheet  # noqa: F401
    REPORTLAB_AVAILABLE = True
except ImportError:
    REPORTLAB_AVAILABLE = False

View File

@@ -0,0 +1,255 @@
"""
Worksheet Editor Reconstruct — Document reconstruction from vocab sessions.
"""
import io
import uuid
import base64
import logging
from typing import List, Dict
import numpy as np
from worksheet_editor_models import (
ReconstructRequest,
ReconstructResponse,
)
logger = logging.getLogger(__name__)
async def reconstruct_document_logic(request: ReconstructRequest) -> ReconstructResponse:
    """
    Reconstruct a document page from a vocab session into Fabric.js canvas format.

    Pipeline:
    1. Load the original PDF from the vocab session and render the requested
       page to a raster image.
    2. Run OCR with position tracking (PaddleOCR).
    3. Build Fabric.js canvas JSON with positioned text elements, scaled from
       image pixels to A4 screen coordinates (794x1123 px).
    4. Mark text regions whose content contains vocabulary extracted for this page.
    5. Optionally detect graphic regions, crop them, and embed them as
       base64-encoded images.

    Args:
        request: Session id, 1-based page number, and image-handling flags.

    Returns:
        ReconstructResponse with the serialized canvas and match statistics.

    Raises:
        HTTPException: 404 if the session is unknown, 400 for missing PDF data
            or an out-of-range page number, 500 if page rendering fails.
    """
    from fastapi import HTTPException
    from vocab_worksheet_api import _sessions, convert_pdf_page_to_image

    # Check if session exists
    if request.session_id not in _sessions:
        raise HTTPException(status_code=404, detail=f"Session {request.session_id} not found")
    session = _sessions[request.session_id]
    if not session.get("pdf_data"):
        raise HTTPException(status_code=400, detail="Session has no PDF data")
    pdf_data = session["pdf_data"]
    page_count = session.get("pdf_page_count", 1)
    if request.page_number < 1 or request.page_number > page_count:
        raise HTTPException(
            status_code=400,
            detail=f"Page {request.page_number} not found. PDF has {page_count} pages."
        )

    # Only vocabulary extracted from this specific page is matched below.
    vocabulary = session.get("vocabulary", [])
    page_vocab = [v for v in vocabulary if v.get("source_page") == request.page_number]
    logger.info(f"Reconstructing page {request.page_number} from session {request.session_id}")
    logger.info(f"Found {len(page_vocab)} vocabulary items for this page")

    # Render the page and run OCR on the raster image.
    image_bytes = await convert_pdf_page_to_image(pdf_data, request.page_number)
    if not image_bytes:
        raise HTTPException(status_code=500, detail="Failed to convert PDF page to image")
    from PIL import Image
    img = Image.open(io.BytesIO(image_bytes))
    img_width, img_height = img.size
    from hybrid_vocab_extractor import run_paddle_ocr
    ocr_regions, raw_text = run_paddle_ocr(image_bytes)
    logger.info(f"OCR found {len(ocr_regions)} text regions")

    # Target canvas: A4 at screen resolution (matches the frontend editor).
    A4_WIDTH = 794
    A4_HEIGHT = 1123
    scale_x = A4_WIDTH / img_width
    scale_y = A4_HEIGHT / img_height
    fabric_objects = []

    # 1. Add white background (locked so the user cannot select/move it)
    fabric_objects.append({
        "type": "rect", "left": 0, "top": 0,
        "width": A4_WIDTH, "height": A4_HEIGHT,
        "fill": "#ffffff", "selectable": False,
        "evented": False, "isBackground": True
    })

    # 2. Group OCR regions by Y-coordinate to detect rows
    sorted_regions = sorted(ocr_regions, key=lambda r: (r.y1, r.x1))

    # 3. Detect headers: tall text in the top 15% of the page
    headers = []
    for region in sorted_regions:
        height = region.y2 - region.y1
        if region.y1 < img_height * 0.15 and height > 30:
            headers.append(region)

    # 4. Create text objects for each region
    vocab_matched = 0
    for region in sorted_regions:
        left = int(region.x1 * scale_x)
        top = int(region.y1 * scale_y)
        is_header = region in headers
        region_height = region.y2 - region.y1
        # Derive font size from the OCR box height, clamped to a readable range.
        base_font_size = max(10, min(32, int(region_height * scale_y * 0.8)))
        if is_header:
            base_font_size = max(base_font_size, 24)
        is_vocab = False
        vocab_match = None
        region_text_lower = region.text.lower()
        for v in page_vocab:
            # BUGFIX: guard against empty/None vocab fields. "" is a substring
            # of every string, so the previous unguarded check marked EVERY
            # region as vocabulary whenever an entry had a missing field.
            english = (v.get("english") or "").lower()
            german = (v.get("german") or "").lower()
            if (english and english in region_text_lower) or \
               (german and german in region_text_lower):
                is_vocab = True
                vocab_match = v
                vocab_matched += 1
                break
        text_obj = {
            "type": "i-text",
            "id": f"text_{uuid.uuid4().hex[:8]}",
            "left": left, "top": top,
            "text": region.text,
            "fontFamily": "Arial",
            "fontSize": base_font_size,
            "fontWeight": "bold" if is_header else "normal",
            "fill": "#000000",
            "originX": "left", "originY": "top",
        }
        if is_vocab and vocab_match:
            # Tag the object so the editor can offer vocabulary-specific actions.
            text_obj["isVocabulary"] = True
            text_obj["vocabularyId"] = vocab_match.get("id")
            text_obj["english"] = vocab_match.get("english")
            text_obj["german"] = vocab_match.get("german")
        fabric_objects.append(text_obj)

    # 5. If include_images, detect and extract image regions
    if request.include_images:
        image_regions = await _detect_image_regions(image_bytes, ocr_regions, img_width, img_height)
        for i, img_region in enumerate(image_regions):
            img_x1 = int(img_region["x1"])
            img_y1 = int(img_region["y1"])
            img_x2 = int(img_region["x2"])
            img_y2 = int(img_region["y2"])
            cropped = img.crop((img_x1, img_y1, img_x2, img_y2))
            buffer = io.BytesIO()
            cropped.save(buffer, format='PNG')
            buffer.seek(0)
            img_base64 = f"data:image/png;base64,{base64.b64encode(buffer.getvalue()).decode('utf-8')}"
            fabric_objects.append({
                "type": "image",
                "id": f"img_{uuid.uuid4().hex[:8]}",
                "left": int(img_x1 * scale_x),
                "top": int(img_y1 * scale_y),
                "width": int((img_x2 - img_x1) * scale_x),
                "height": int((img_y2 - img_y1) * scale_y),
                "src": img_base64,
                "scaleX": 1, "scaleY": 1,
            })

    import json
    canvas_data = {
        "version": "6.0.0",
        "objects": fabric_objects,
        "background": "#ffffff"
    }
    return ReconstructResponse(
        canvas_json=json.dumps(canvas_data),
        page_width=A4_WIDTH,
        page_height=A4_HEIGHT,
        elements_count=len(fabric_objects),
        vocabulary_matched=vocab_matched,
        message=f"Reconstructed page {request.page_number} with {len(fabric_objects)} elements, "
                f"{vocab_matched} vocabulary items matched"
    )
async def _detect_image_regions(
    image_bytes: bytes,
    ocr_regions: list,
    img_width: int,
    img_height: int
) -> List[Dict]:
    """
    Locate candidate image/graphic areas on a scanned document page.

    Strategy: mask out every known text box (padded by 5px), run Canny edge
    detection on the remainder, and keep sufficiently large, high-variance
    bounding boxes (low variance means flat background, not a graphic).
    Overlapping candidates are pruned greedily, largest first, and at most
    ten regions are returned. Any failure yields an empty list (best-effort).
    """
    from PIL import Image
    import cv2
    try:
        pil_img = Image.open(io.BytesIO(image_bytes))
        gray = np.array(pil_img.convert('L'))

        # Boolean mask that is False wherever OCR found text (padded by 5px).
        non_text = np.ones_like(gray, dtype=bool)
        for box in ocr_regions:
            left = max(0, box.x1 - 5)
            top = max(0, box.y1 - 5)
            right = min(img_width, box.x2 + 5)
            bottom = min(img_height, box.y2 + 5)
            non_text[top:bottom, left:right] = False

        edge_map = cv2.Canny(gray, 50, 150)
        edge_map[~non_text] = 0  # suppress edges inside text areas

        contours, _ = cv2.findContours(edge_map, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)

        candidates: List[Dict] = []
        for cnt in contours:
            x, y, w, h = cv2.boundingRect(cnt)
            if w <= 50 or h <= 50:
                continue  # too small to be a meaningful graphic
            if w >= img_width * 0.9 or h >= img_height * 0.9:
                continue  # near page-sized box is almost certainly the page border
            patch = gray[y:y + h, x:x + w]
            if np.var(patch) > 500:  # uniform patches are background, skip them
                candidates.append({
                    "x1": x, "y1": y,
                    "x2": x + w, "y2": y + h
                })

        # Greedy non-overlap selection, largest area first.
        kept: List[Dict] = []
        by_area = sorted(
            candidates,
            key=lambda r: (r["x2"] - r["x1"]) * (r["y2"] - r["y1"]),
            reverse=True,
        )
        for cand in by_area:
            collides = any(
                not (cand["x2"] < prev["x1"] or cand["x1"] > prev["x2"] or
                     cand["y2"] < prev["y1"] or cand["y1"] > prev["y2"])
                for prev in kept
            )
            if not collides:
                kept.append(cand)

        logger.info(f"Detected {len(kept)} image regions")
        return kept[:10]
    except Exception as e:
        logger.warning(f"Image region detection failed: {e}")
        return []