From d093a4d3883aaa6d8b87c23827c4f6a327f4df2f Mon Sep 17 00:00:00 2001 From: Benjamin Admin Date: Sat, 25 Apr 2026 23:19:11 +0200 Subject: [PATCH] Restructure: Move final 12 root files into packages (klausur-service) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ocr/spell/ (3): smart_spell, core, text upload/ (3): api, chunked, mobile crawler/ (3): github, github_core, github_parsers + unified_grid → grid/, tesseract_extractor → ocr/engines/, htr_api → ocr/pipeline/ 12 shims added. Only main.py, config.py, storage + RAG files remain at root. Co-Authored-By: Claude Opus 4.6 (1M context) --- klausur-service/backend/crawler/__init__.py | 6 + klausur-service/backend/crawler/github.py | 35 ++ .../backend/crawler/github_core.py | 411 +++++++++++++++++ .../backend/crawler/github_parsers.py | 303 +++++++++++++ klausur-service/backend/github_crawler.py | 39 +- .../backend/github_crawler_core.py | 415 +---------------- .../backend/github_crawler_parsers.py | 307 +------------ klausur-service/backend/grid/unified.py | 425 +++++++++++++++++ .../backend/handwriting_htr_api.py | 280 +----------- .../ocr/engines/tesseract_extractor.py | 346 ++++++++++++++ .../backend/ocr/pipeline/htr_api.py | 276 +++++++++++ klausur-service/backend/ocr/spell/__init__.py | 7 + klausur-service/backend/ocr/spell/core.py | 298 ++++++++++++ .../backend/ocr/spell/smart_spell.py | 25 + klausur-service/backend/ocr/spell/text.py | 289 ++++++++++++ klausur-service/backend/smart_spell.py | 29 +- klausur-service/backend/smart_spell_core.py | 302 +----------- klausur-service/backend/smart_spell_text.py | 293 +----------- .../backend/tesseract_vocab_extractor.py | 350 +------------- klausur-service/backend/unified_grid.py | 429 +----------------- klausur-service/backend/upload/__init__.py | 6 + klausur-service/backend/upload/api.py | 29 ++ klausur-service/backend/upload/chunked.py | 320 +++++++++++++ klausur-service/backend/upload/mobile.py | 292 ++++++++++++ klausur-service/backend/upload_api.py | 33 +- klausur-service/backend/upload_api_chunked.py | 324 +------------ klausur-service/backend/upload_api_mobile.py | 296 +----------- 27 files changed, 3116 insertions(+), 3049 deletions(-) create mode 100644 klausur-service/backend/crawler/__init__.py create mode 100644 klausur-service/backend/crawler/github.py create mode 100644 klausur-service/backend/crawler/github_core.py create mode 100644 klausur-service/backend/crawler/github_parsers.py create mode 100644 klausur-service/backend/grid/unified.py create mode 100644 klausur-service/backend/ocr/engines/tesseract_extractor.py create mode 100644 klausur-service/backend/ocr/pipeline/htr_api.py create mode 100644 klausur-service/backend/ocr/spell/__init__.py create mode 100644 klausur-service/backend/ocr/spell/core.py create mode 100644 klausur-service/backend/ocr/spell/smart_spell.py create mode 100644 klausur-service/backend/ocr/spell/text.py create mode 100644 klausur-service/backend/upload/__init__.py create mode 100644 klausur-service/backend/upload/api.py create mode 100644 klausur-service/backend/upload/chunked.py create mode 100644 klausur-service/backend/upload/mobile.py diff --git a/klausur-service/backend/crawler/__init__.py b/klausur-service/backend/crawler/__init__.py new file mode 100644 index 0000000..3dfe7a9 --- /dev/null +++ b/klausur-service/backend/crawler/__init__.py @@ -0,0 +1,6 @@ +""" +Crawler package — GitHub repository crawler for legal templates. + +Moved from backend/ flat modules (github_crawler*.py). 
+Backward-compatible shim files remain at the old locations. +""" diff --git a/klausur-service/backend/crawler/github.py b/klausur-service/backend/crawler/github.py new file mode 100644 index 0000000..b8b35de --- /dev/null +++ b/klausur-service/backend/crawler/github.py @@ -0,0 +1,35 @@ +""" +GitHub Repository Crawler — Barrel Re-export + +Split into: +- github_crawler_parsers.py — ExtractedDocument, MarkdownParser, HTMLParser, JSONParser +- github_crawler_core.py — GitHubCrawler, RepositoryDownloader, crawl_source + +All public names are re-exported here for backward compatibility. +""" + +# Parsers +from .github_parsers import ( # noqa: F401 + ExtractedDocument, + MarkdownParser, + HTMLParser, + JSONParser, +) + +# Crawler and downloader +from .github_core import ( # noqa: F401 + GITHUB_API_URL, + GITLAB_API_URL, + GITHUB_TOKEN, + MAX_FILE_SIZE, + REQUEST_TIMEOUT, + RATE_LIMIT_DELAY, + GitHubCrawler, + RepositoryDownloader, + crawl_source, + main, +) + +if __name__ == "__main__": + import asyncio + asyncio.run(main()) diff --git a/klausur-service/backend/crawler/github_core.py b/klausur-service/backend/crawler/github_core.py new file mode 100644 index 0000000..6b7b3fd --- /dev/null +++ b/klausur-service/backend/crawler/github_core.py @@ -0,0 +1,411 @@ +""" +GitHub Crawler - Core Crawler and Downloader + +GitHubCrawler for API-based repository crawling and RepositoryDownloader +for ZIP-based local extraction. + +Extracted from github_crawler.py to keep files under 500 LOC. +""" + +import asyncio +import logging +import os +import shutil +import tempfile +import zipfile +from fnmatch import fnmatch +from pathlib import Path +from typing import Any, AsyncGenerator, Dict, List, Optional, Tuple +from urllib.parse import urlparse + +import httpx + +from template_sources import SourceConfig +from .github_parsers import ( + ExtractedDocument, + MarkdownParser, + HTMLParser, + JSONParser, +) + +logger = logging.getLogger(__name__) + +# Configuration +GITHUB_API_URL = "https://api.github.com" +GITLAB_API_URL = "https://gitlab.com/api/v4" +GITHUB_TOKEN = os.getenv("GITHUB_TOKEN", "") +MAX_FILE_SIZE = 1024 * 1024 # 1 MB max file size +REQUEST_TIMEOUT = 60.0 +RATE_LIMIT_DELAY = 1.0 + + +class GitHubCrawler: + """Crawl GitHub repositories for legal templates.""" + + def __init__(self, token: Optional[str] = None): + self.token = token or GITHUB_TOKEN + self.headers = { + "Accept": "application/vnd.github.v3+json", + "User-Agent": "LegalTemplatesCrawler/1.0", + } + if self.token: + self.headers["Authorization"] = f"token {self.token}" + + self.http_client: Optional[httpx.AsyncClient] = None + + async def __aenter__(self): + self.http_client = httpx.AsyncClient( + timeout=REQUEST_TIMEOUT, + headers=self.headers, + follow_redirects=True, + ) + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb): + if self.http_client: + await self.http_client.aclose() + + def _parse_repo_url(self, url: str) -> Tuple[str, str, str]: + """Parse repository URL into owner, repo, and host.""" + parsed = urlparse(url) + path_parts = parsed.path.strip('/').split('/') + + if len(path_parts) < 2: + raise ValueError(f"Invalid repository URL: {url}") + + owner = path_parts[0] + repo = path_parts[1].replace('.git', '') + + if 'gitlab' in parsed.netloc: + host = 'gitlab' + else: + host = 'github' + + return owner, repo, host + + async def get_default_branch(self, owner: str, repo: str) -> str: + """Get the default branch of a repository.""" + if not self.http_client: + raise RuntimeError("Crawler not initialized. 
Use 'async with' context.") + + url = f"{GITHUB_API_URL}/repos/{owner}/{repo}" + response = await self.http_client.get(url) + response.raise_for_status() + data = response.json() + return data.get("default_branch", "main") + + async def get_latest_commit(self, owner: str, repo: str, branch: str = "main") -> str: + """Get the latest commit SHA for a branch.""" + if not self.http_client: + raise RuntimeError("Crawler not initialized. Use 'async with' context.") + + url = f"{GITHUB_API_URL}/repos/{owner}/{repo}/commits/{branch}" + response = await self.http_client.get(url) + response.raise_for_status() + data = response.json() + return data.get("sha", "") + + async def list_files( + self, + owner: str, + repo: str, + path: str = "", + branch: str = "main", + patterns: List[str] = None, + exclude_patterns: List[str] = None, + ) -> List[Dict[str, Any]]: + """List files in a repository matching the given patterns.""" + if not self.http_client: + raise RuntimeError("Crawler not initialized. Use 'async with' context.") + + patterns = patterns or ["*.md", "*.txt", "*.html"] + exclude_patterns = exclude_patterns or [] + + url = f"{GITHUB_API_URL}/repos/{owner}/{repo}/git/trees/{branch}?recursive=1" + response = await self.http_client.get(url) + response.raise_for_status() + data = response.json() + + files = [] + for item in data.get("tree", []): + if item["type"] != "blob": + continue + + file_path = item["path"] + + excluded = any(fnmatch(file_path, pattern) for pattern in exclude_patterns) + if excluded: + continue + + matched = any(fnmatch(file_path, pattern) for pattern in patterns) + if not matched: + continue + + if item.get("size", 0) > MAX_FILE_SIZE: + logger.warning(f"Skipping large file: {file_path} ({item['size']} bytes)") + continue + + files.append({ + "path": file_path, + "sha": item["sha"], + "size": item.get("size", 0), + "url": item.get("url", ""), + }) + + return files + + async def get_file_content(self, owner: str, repo: str, path: str, branch: str = "main") -> str: + """Get the content of a file from a repository.""" + if not self.http_client: + raise RuntimeError("Crawler not initialized. 
Use 'async with' context.") + + url = f"https://raw.githubusercontent.com/{owner}/{repo}/{branch}/{path}" + response = await self.http_client.get(url) + response.raise_for_status() + return response.text + + async def crawl_repository( + self, + source: SourceConfig, + ) -> AsyncGenerator[ExtractedDocument, None]: + """Crawl a repository and yield extracted documents.""" + if not source.repo_url: + logger.warning(f"No repo URL for source: {source.name}") + return + + try: + owner, repo, host = self._parse_repo_url(source.repo_url) + except ValueError as e: + logger.error(f"Failed to parse repo URL for {source.name}: {e}") + return + + if host == "gitlab": + logger.info(f"GitLab repos not yet supported: {source.name}") + return + + logger.info(f"Crawling repository: {owner}/{repo}") + + try: + branch = await self.get_default_branch(owner, repo) + commit_sha = await self.get_latest_commit(owner, repo, branch) + + await asyncio.sleep(RATE_LIMIT_DELAY) + + files = await self.list_files( + owner, repo, + branch=branch, + patterns=source.file_patterns, + exclude_patterns=source.exclude_patterns, + ) + + logger.info(f"Found {len(files)} matching files in {source.name}") + + for file_info in files: + await asyncio.sleep(RATE_LIMIT_DELAY) + + try: + content = await self.get_file_content( + owner, repo, file_info["path"], branch + ) + + file_path = file_info["path"] + source_url = f"https://github.com/{owner}/{repo}/blob/{branch}/{file_path}" + + if file_path.endswith('.md'): + doc = MarkdownParser.parse(content, file_path) + doc.source_url = source_url + doc.source_commit = commit_sha + yield doc + + elif file_path.endswith('.html') or file_path.endswith('.htm'): + doc = HTMLParser.parse(content, file_path) + doc.source_url = source_url + doc.source_commit = commit_sha + yield doc + + elif file_path.endswith('.json'): + docs = JSONParser.parse(content, file_path) + for doc in docs: + doc.source_url = source_url + doc.source_commit = commit_sha + yield doc + + elif file_path.endswith('.txt'): + yield ExtractedDocument( + text=content, + title=Path(file_path).stem, + file_path=file_path, + file_type="text", + source_url=source_url, + source_commit=commit_sha, + language=MarkdownParser._detect_language(content), + placeholders=MarkdownParser._find_placeholders(content), + ) + + except httpx.HTTPError as e: + logger.warning(f"Failed to fetch {file_path}: {e}") + continue + except Exception as e: + logger.error(f"Error processing {file_path}: {e}") + continue + + except httpx.HTTPError as e: + logger.error(f"HTTP error crawling {source.name}: {e}") + except Exception as e: + logger.error(f"Error crawling {source.name}: {e}") + + +class RepositoryDownloader: + """Download and extract repository archives.""" + + def __init__(self): + self.http_client: Optional[httpx.AsyncClient] = None + + async def __aenter__(self): + self.http_client = httpx.AsyncClient( + timeout=120.0, + follow_redirects=True, + ) + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb): + if self.http_client: + await self.http_client.aclose() + + async def download_zip(self, repo_url: str, branch: str = "main") -> Path: + """Download repository as ZIP and extract to temp directory.""" + if not self.http_client: + raise RuntimeError("Downloader not initialized. 
Use 'async with' context.") + + parsed = urlparse(repo_url) + path_parts = parsed.path.strip('/').split('/') + owner = path_parts[0] + repo = path_parts[1].replace('.git', '') + + zip_url = f"https://github.com/{owner}/{repo}/archive/refs/heads/{branch}.zip" + + logger.info(f"Downloading ZIP from {zip_url}") + + response = await self.http_client.get(zip_url) + response.raise_for_status() + + temp_dir = Path(tempfile.mkdtemp()) + zip_path = temp_dir / f"{repo}.zip" + + with open(zip_path, 'wb') as f: + f.write(response.content) + + extract_dir = temp_dir / repo + with zipfile.ZipFile(zip_path, 'r') as zip_ref: + zip_ref.extractall(temp_dir) + + extracted_dirs = list(temp_dir.glob(f"{repo}-*")) + if extracted_dirs: + return extracted_dirs[0] + + return extract_dir + + async def crawl_local_directory( + self, + directory: Path, + source: SourceConfig, + base_url: str, + ) -> AsyncGenerator[ExtractedDocument, None]: + """Crawl a local directory for documents.""" + patterns = source.file_patterns or ["*.md", "*.txt", "*.html"] + exclude_patterns = source.exclude_patterns or [] + + for pattern in patterns: + for file_path in directory.rglob(pattern.replace("**/", "")): + if not file_path.is_file(): + continue + + rel_path = str(file_path.relative_to(directory)) + + excluded = any(fnmatch(rel_path, ep) for ep in exclude_patterns) + if excluded: + continue + + if file_path.stat().st_size > MAX_FILE_SIZE: + continue + + try: + content = file_path.read_text(encoding='utf-8') + except UnicodeDecodeError: + try: + content = file_path.read_text(encoding='latin-1') + except Exception: + continue + + source_url = f"{base_url}/{rel_path}" + + if file_path.suffix == '.md': + doc = MarkdownParser.parse(content, rel_path) + doc.source_url = source_url + yield doc + + elif file_path.suffix in ['.html', '.htm']: + doc = HTMLParser.parse(content, rel_path) + doc.source_url = source_url + yield doc + + elif file_path.suffix == '.json': + docs = JSONParser.parse(content, rel_path) + for doc in docs: + doc.source_url = source_url + yield doc + + elif file_path.suffix == '.txt': + yield ExtractedDocument( + text=content, + title=file_path.stem, + file_path=rel_path, + file_type="text", + source_url=source_url, + language=MarkdownParser._detect_language(content), + placeholders=MarkdownParser._find_placeholders(content), + ) + + def cleanup(self, directory: Path): + """Clean up temporary directory.""" + if directory.exists(): + shutil.rmtree(directory, ignore_errors=True) + + +async def crawl_source(source: SourceConfig) -> List[ExtractedDocument]: + """Crawl a source configuration and return all extracted documents.""" + documents = [] + + if source.repo_url: + async with GitHubCrawler() as crawler: + async for doc in crawler.crawl_repository(source): + documents.append(doc) + + return documents + + +# CLI for testing +async def main(): + """Test crawler with a sample source.""" + from template_sources import TEMPLATE_SOURCES + + source = next(s for s in TEMPLATE_SOURCES if s.name == "github-site-policy") + + async with GitHubCrawler() as crawler: + count = 0 + async for doc in crawler.crawl_repository(source): + count += 1 + print(f"\n{'='*60}") + print(f"Title: {doc.title}") + print(f"Path: {doc.file_path}") + print(f"Type: {doc.file_type}") + print(f"Language: {doc.language}") + print(f"URL: {doc.source_url}") + print(f"Placeholders: {doc.placeholders[:5] if doc.placeholders else 'None'}") + print(f"Text preview: {doc.text[:200]}...") + + print(f"\n\nTotal documents: {count}") + + +if __name__ == "__main__": + 
asyncio.run(main()) diff --git a/klausur-service/backend/crawler/github_parsers.py b/klausur-service/backend/crawler/github_parsers.py new file mode 100644 index 0000000..416c9eb --- /dev/null +++ b/klausur-service/backend/crawler/github_parsers.py @@ -0,0 +1,303 @@ +""" +GitHub Crawler - Document Parsers + +Markdown, HTML, and JSON parsers for extracting structured content +from legal template documents. + +Extracted from github_crawler.py to keep files under 500 LOC. +""" + +import hashlib +import json +import re +from dataclasses import dataclass, field +from pathlib import Path +from typing import Any, Dict, List, Optional + + +@dataclass +class ExtractedDocument: + """A document extracted from a repository.""" + text: str + title: str + file_path: str + file_type: str # "markdown", "html", "json", "text" + source_url: str + source_commit: Optional[str] = None + source_hash: str = "" # SHA256 of original content + sections: List[Dict[str, Any]] = field(default_factory=list) + placeholders: List[str] = field(default_factory=list) + language: str = "en" + metadata: Dict[str, Any] = field(default_factory=dict) + + def __post_init__(self): + if not self.source_hash and self.text: + self.source_hash = hashlib.sha256(self.text.encode()).hexdigest() + + +class MarkdownParser: + """Parse Markdown files into structured content.""" + + # Common placeholder patterns + PLACEHOLDER_PATTERNS = [ + r'\[([A-Z_]+)\]', # [COMPANY_NAME] + r'\{([a-z_]+)\}', # {company_name} + r'\{\{([a-z_]+)\}\}', # {{company_name}} + r'__([A-Z_]+)__', # __COMPANY_NAME__ + r'<([A-Z_]+)>', # + ] + + @classmethod + def parse(cls, content: str, filename: str = "") -> ExtractedDocument: + """Parse markdown content into an ExtractedDocument.""" + title = cls._extract_title(content, filename) + sections = cls._extract_sections(content) + placeholders = cls._find_placeholders(content) + language = cls._detect_language(content) + clean_text = cls._clean_for_indexing(content) + + return ExtractedDocument( + text=clean_text, + title=title, + file_path=filename, + file_type="markdown", + source_url="", + sections=sections, + placeholders=placeholders, + language=language, + ) + + @classmethod + def _extract_title(cls, content: str, filename: str) -> str: + """Extract title from markdown heading or filename.""" + h1_match = re.search(r'^#\s+(.+)$', content, re.MULTILINE) + if h1_match: + return h1_match.group(1).strip() + + frontmatter_match = re.search( + r'^---\s*\n.*?title:\s*["\']?(.+?)["\']?\s*\n.*?---', + content, re.DOTALL + ) + if frontmatter_match: + return frontmatter_match.group(1).strip() + + if filename: + name = Path(filename).stem + return name.replace('-', ' ').replace('_', ' ').title() + + return "Untitled" + + @classmethod + def _extract_sections(cls, content: str) -> List[Dict[str, Any]]: + """Extract sections from markdown content.""" + sections = [] + current_section = {"heading": "", "level": 0, "content": "", "start": 0} + + for match in re.finditer(r'^(#{1,6})\s+(.+)$', content, re.MULTILINE): + if current_section["heading"] or current_section["content"].strip(): + current_section["content"] = current_section["content"].strip() + sections.append(current_section.copy()) + + level = len(match.group(1)) + heading = match.group(2).strip() + current_section = { + "heading": heading, + "level": level, + "content": "", + "start": match.end(), + } + + if current_section["heading"] or current_section["content"].strip(): + current_section["content"] = content[current_section["start"]:].strip() + 
sections.append(current_section) + + return sections + + @classmethod + def _find_placeholders(cls, content: str) -> List[str]: + """Find placeholder patterns in content.""" + placeholders = set() + for pattern in cls.PLACEHOLDER_PATTERNS: + for match in re.finditer(pattern, content): + placeholder = match.group(0) + placeholders.add(placeholder) + return sorted(list(placeholders)) + + @classmethod + def _detect_language(cls, content: str) -> str: + """Detect language from content.""" + german_indicators = [ + 'Datenschutz', 'Impressum', 'Nutzungsbedingungen', 'Haftung', + 'Widerruf', 'Verantwortlicher', 'personenbezogene', 'Verarbeitung', + 'und', 'der', 'die', 'das', 'ist', 'wird', 'werden', 'sind', + ] + + lower_content = content.lower() + german_count = sum(1 for word in german_indicators if word.lower() in lower_content) + + if german_count >= 3: + return "de" + return "en" + + @classmethod + def _clean_for_indexing(cls, content: str) -> str: + """Clean markdown content for text indexing.""" + content = re.sub(r'^---\s*\n.*?---\s*\n', '', content, flags=re.DOTALL) + content = re.sub(r'', '', content, flags=re.DOTALL) + content = re.sub(r'<[^>]+>', '', content) + content = re.sub(r'\*\*(.+?)\*\*', r'\1', content) + content = re.sub(r'\*(.+?)\*', r'\1', content) + content = re.sub(r'`(.+?)`', r'\1', content) + content = re.sub(r'~~(.+?)~~', r'\1', content) + content = re.sub(r'\[([^\]]+)\]\([^)]+\)', r'\1', content) + content = re.sub(r'!\[([^\]]*)\]\([^)]+\)', r'\1', content) + content = re.sub(r'\n{3,}', '\n\n', content) + content = re.sub(r' +', ' ', content) + + return content.strip() + + +class HTMLParser: + """Parse HTML files into structured content.""" + + @classmethod + def parse(cls, content: str, filename: str = "") -> ExtractedDocument: + """Parse HTML content into an ExtractedDocument.""" + title_match = re.search(r'(.+?)', content, re.IGNORECASE) + title = title_match.group(1) if title_match else Path(filename).stem + + text = cls._html_to_text(content) + placeholders = MarkdownParser._find_placeholders(text) + + lang_match = re.search(r']*lang=["\']([a-z]{2})["\']', content, re.IGNORECASE) + language = lang_match.group(1) if lang_match else MarkdownParser._detect_language(text) + + return ExtractedDocument( + text=text, + title=title, + file_path=filename, + file_type="html", + source_url="", + placeholders=placeholders, + language=language, + ) + + @classmethod + def _html_to_text(cls, html: str) -> str: + """Convert HTML to clean text.""" + html = re.sub(r']*>.*?', '', html, flags=re.DOTALL | re.IGNORECASE) + html = re.sub(r']*>.*?', '', html, flags=re.DOTALL | re.IGNORECASE) + html = re.sub(r'', '', html, flags=re.DOTALL) + + html = html.replace(' ', ' ') + html = html.replace('&', '&') + html = html.replace('<', '<') + html = html.replace('>', '>') + html = html.replace('"', '"') + html = html.replace(''', "'") + + html = re.sub(r'', '\n', html, flags=re.IGNORECASE) + html = re.sub(r'
</p>
', '\n\n', html, flags=re.IGNORECASE) + html = re.sub(r'', '\n', html, flags=re.IGNORECASE) + html = re.sub(r'', '\n\n', html, flags=re.IGNORECASE) + html = re.sub(r'', '\n', html, flags=re.IGNORECASE) + + html = re.sub(r'<[^>]+>', '', html) + + html = re.sub(r'[ \t]+', ' ', html) + html = re.sub(r'\n[ \t]+', '\n', html) + html = re.sub(r'[ \t]+\n', '\n', html) + html = re.sub(r'\n{3,}', '\n\n', html) + + return html.strip() + + +class JSONParser: + """Parse JSON files containing legal template data.""" + + @classmethod + def parse(cls, content: str, filename: str = "") -> List[ExtractedDocument]: + """Parse JSON content into ExtractedDocuments.""" + try: + data = json.loads(content) + except json.JSONDecodeError as e: + import logging + logging.getLogger(__name__).warning(f"Failed to parse JSON from {filename}: {e}") + return [] + + documents = [] + + if isinstance(data, dict): + documents.extend(cls._parse_dict(data, filename)) + elif isinstance(data, list): + for i, item in enumerate(data): + if isinstance(item, dict): + docs = cls._parse_dict(item, f"{filename}[{i}]") + documents.extend(docs) + + return documents + + @classmethod + def _parse_dict(cls, data: dict, filename: str) -> List[ExtractedDocument]: + """Parse a dictionary into documents.""" + documents = [] + + text_keys = ['text', 'content', 'body', 'description', 'value'] + title_keys = ['title', 'name', 'heading', 'label', 'key'] + + text = "" + for key in text_keys: + if key in data and isinstance(data[key], str): + text = data[key] + break + + if not text: + for key, value in data.items(): + if isinstance(value, dict): + nested_docs = cls._parse_dict(value, f"{filename}.{key}") + documents.extend(nested_docs) + elif isinstance(value, list): + for i, item in enumerate(value): + if isinstance(item, dict): + nested_docs = cls._parse_dict(item, f"{filename}.{key}[{i}]") + documents.extend(nested_docs) + elif isinstance(item, str) and len(item) > 50: + documents.append(ExtractedDocument( + text=item, + title=f"{key} {i+1}", + file_path=filename, + file_type="json", + source_url="", + language=MarkdownParser._detect_language(item), + )) + return documents + + title = "" + for key in title_keys: + if key in data and isinstance(data[key], str): + title = data[key] + break + + if not title: + title = Path(filename).stem + + metadata = {} + for key, value in data.items(): + if key not in text_keys + title_keys and not isinstance(value, (dict, list)): + metadata[key] = value + + placeholders = MarkdownParser._find_placeholders(text) + language = data.get('lang', data.get('language', MarkdownParser._detect_language(text))) + + documents.append(ExtractedDocument( + text=text, + title=title, + file_path=filename, + file_type="json", + source_url="", + placeholders=placeholders, + language=language, + metadata=metadata, + )) + + return documents diff --git a/klausur-service/backend/github_crawler.py b/klausur-service/backend/github_crawler.py index bba1705..359f340 100644 --- a/klausur-service/backend/github_crawler.py +++ b/klausur-service/backend/github_crawler.py @@ -1,35 +1,4 @@ -""" -GitHub Repository Crawler — Barrel Re-export - -Split into: -- github_crawler_parsers.py — ExtractedDocument, MarkdownParser, HTMLParser, JSONParser -- github_crawler_core.py — GitHubCrawler, RepositoryDownloader, crawl_source - -All public names are re-exported here for backward compatibility. 
-""" - -# Parsers -from github_crawler_parsers import ( # noqa: F401 - ExtractedDocument, - MarkdownParser, - HTMLParser, - JSONParser, -) - -# Crawler and downloader -from github_crawler_core import ( # noqa: F401 - GITHUB_API_URL, - GITLAB_API_URL, - GITHUB_TOKEN, - MAX_FILE_SIZE, - REQUEST_TIMEOUT, - RATE_LIMIT_DELAY, - GitHubCrawler, - RepositoryDownloader, - crawl_source, - main, -) - -if __name__ == "__main__": - import asyncio - asyncio.run(main()) +# Backward-compat shim -- module moved to crawler/github.py +import importlib as _importlib +import sys as _sys +_sys.modules[__name__] = _importlib.import_module("crawler.github") diff --git a/klausur-service/backend/github_crawler_core.py b/klausur-service/backend/github_crawler_core.py index 4152d30..fb231b2 100644 --- a/klausur-service/backend/github_crawler_core.py +++ b/klausur-service/backend/github_crawler_core.py @@ -1,411 +1,4 @@ -""" -GitHub Crawler - Core Crawler and Downloader - -GitHubCrawler for API-based repository crawling and RepositoryDownloader -for ZIP-based local extraction. - -Extracted from github_crawler.py to keep files under 500 LOC. -""" - -import asyncio -import logging -import os -import shutil -import tempfile -import zipfile -from fnmatch import fnmatch -from pathlib import Path -from typing import Any, AsyncGenerator, Dict, List, Optional, Tuple -from urllib.parse import urlparse - -import httpx - -from template_sources import SourceConfig -from github_crawler_parsers import ( - ExtractedDocument, - MarkdownParser, - HTMLParser, - JSONParser, -) - -logger = logging.getLogger(__name__) - -# Configuration -GITHUB_API_URL = "https://api.github.com" -GITLAB_API_URL = "https://gitlab.com/api/v4" -GITHUB_TOKEN = os.getenv("GITHUB_TOKEN", "") -MAX_FILE_SIZE = 1024 * 1024 # 1 MB max file size -REQUEST_TIMEOUT = 60.0 -RATE_LIMIT_DELAY = 1.0 - - -class GitHubCrawler: - """Crawl GitHub repositories for legal templates.""" - - def __init__(self, token: Optional[str] = None): - self.token = token or GITHUB_TOKEN - self.headers = { - "Accept": "application/vnd.github.v3+json", - "User-Agent": "LegalTemplatesCrawler/1.0", - } - if self.token: - self.headers["Authorization"] = f"token {self.token}" - - self.http_client: Optional[httpx.AsyncClient] = None - - async def __aenter__(self): - self.http_client = httpx.AsyncClient( - timeout=REQUEST_TIMEOUT, - headers=self.headers, - follow_redirects=True, - ) - return self - - async def __aexit__(self, exc_type, exc_val, exc_tb): - if self.http_client: - await self.http_client.aclose() - - def _parse_repo_url(self, url: str) -> Tuple[str, str, str]: - """Parse repository URL into owner, repo, and host.""" - parsed = urlparse(url) - path_parts = parsed.path.strip('/').split('/') - - if len(path_parts) < 2: - raise ValueError(f"Invalid repository URL: {url}") - - owner = path_parts[0] - repo = path_parts[1].replace('.git', '') - - if 'gitlab' in parsed.netloc: - host = 'gitlab' - else: - host = 'github' - - return owner, repo, host - - async def get_default_branch(self, owner: str, repo: str) -> str: - """Get the default branch of a repository.""" - if not self.http_client: - raise RuntimeError("Crawler not initialized. 
Use 'async with' context.") - - url = f"{GITHUB_API_URL}/repos/{owner}/{repo}" - response = await self.http_client.get(url) - response.raise_for_status() - data = response.json() - return data.get("default_branch", "main") - - async def get_latest_commit(self, owner: str, repo: str, branch: str = "main") -> str: - """Get the latest commit SHA for a branch.""" - if not self.http_client: - raise RuntimeError("Crawler not initialized. Use 'async with' context.") - - url = f"{GITHUB_API_URL}/repos/{owner}/{repo}/commits/{branch}" - response = await self.http_client.get(url) - response.raise_for_status() - data = response.json() - return data.get("sha", "") - - async def list_files( - self, - owner: str, - repo: str, - path: str = "", - branch: str = "main", - patterns: List[str] = None, - exclude_patterns: List[str] = None, - ) -> List[Dict[str, Any]]: - """List files in a repository matching the given patterns.""" - if not self.http_client: - raise RuntimeError("Crawler not initialized. Use 'async with' context.") - - patterns = patterns or ["*.md", "*.txt", "*.html"] - exclude_patterns = exclude_patterns or [] - - url = f"{GITHUB_API_URL}/repos/{owner}/{repo}/git/trees/{branch}?recursive=1" - response = await self.http_client.get(url) - response.raise_for_status() - data = response.json() - - files = [] - for item in data.get("tree", []): - if item["type"] != "blob": - continue - - file_path = item["path"] - - excluded = any(fnmatch(file_path, pattern) for pattern in exclude_patterns) - if excluded: - continue - - matched = any(fnmatch(file_path, pattern) for pattern in patterns) - if not matched: - continue - - if item.get("size", 0) > MAX_FILE_SIZE: - logger.warning(f"Skipping large file: {file_path} ({item['size']} bytes)") - continue - - files.append({ - "path": file_path, - "sha": item["sha"], - "size": item.get("size", 0), - "url": item.get("url", ""), - }) - - return files - - async def get_file_content(self, owner: str, repo: str, path: str, branch: str = "main") -> str: - """Get the content of a file from a repository.""" - if not self.http_client: - raise RuntimeError("Crawler not initialized. 
Use 'async with' context.") - - url = f"https://raw.githubusercontent.com/{owner}/{repo}/{branch}/{path}" - response = await self.http_client.get(url) - response.raise_for_status() - return response.text - - async def crawl_repository( - self, - source: SourceConfig, - ) -> AsyncGenerator[ExtractedDocument, None]: - """Crawl a repository and yield extracted documents.""" - if not source.repo_url: - logger.warning(f"No repo URL for source: {source.name}") - return - - try: - owner, repo, host = self._parse_repo_url(source.repo_url) - except ValueError as e: - logger.error(f"Failed to parse repo URL for {source.name}: {e}") - return - - if host == "gitlab": - logger.info(f"GitLab repos not yet supported: {source.name}") - return - - logger.info(f"Crawling repository: {owner}/{repo}") - - try: - branch = await self.get_default_branch(owner, repo) - commit_sha = await self.get_latest_commit(owner, repo, branch) - - await asyncio.sleep(RATE_LIMIT_DELAY) - - files = await self.list_files( - owner, repo, - branch=branch, - patterns=source.file_patterns, - exclude_patterns=source.exclude_patterns, - ) - - logger.info(f"Found {len(files)} matching files in {source.name}") - - for file_info in files: - await asyncio.sleep(RATE_LIMIT_DELAY) - - try: - content = await self.get_file_content( - owner, repo, file_info["path"], branch - ) - - file_path = file_info["path"] - source_url = f"https://github.com/{owner}/{repo}/blob/{branch}/{file_path}" - - if file_path.endswith('.md'): - doc = MarkdownParser.parse(content, file_path) - doc.source_url = source_url - doc.source_commit = commit_sha - yield doc - - elif file_path.endswith('.html') or file_path.endswith('.htm'): - doc = HTMLParser.parse(content, file_path) - doc.source_url = source_url - doc.source_commit = commit_sha - yield doc - - elif file_path.endswith('.json'): - docs = JSONParser.parse(content, file_path) - for doc in docs: - doc.source_url = source_url - doc.source_commit = commit_sha - yield doc - - elif file_path.endswith('.txt'): - yield ExtractedDocument( - text=content, - title=Path(file_path).stem, - file_path=file_path, - file_type="text", - source_url=source_url, - source_commit=commit_sha, - language=MarkdownParser._detect_language(content), - placeholders=MarkdownParser._find_placeholders(content), - ) - - except httpx.HTTPError as e: - logger.warning(f"Failed to fetch {file_path}: {e}") - continue - except Exception as e: - logger.error(f"Error processing {file_path}: {e}") - continue - - except httpx.HTTPError as e: - logger.error(f"HTTP error crawling {source.name}: {e}") - except Exception as e: - logger.error(f"Error crawling {source.name}: {e}") - - -class RepositoryDownloader: - """Download and extract repository archives.""" - - def __init__(self): - self.http_client: Optional[httpx.AsyncClient] = None - - async def __aenter__(self): - self.http_client = httpx.AsyncClient( - timeout=120.0, - follow_redirects=True, - ) - return self - - async def __aexit__(self, exc_type, exc_val, exc_tb): - if self.http_client: - await self.http_client.aclose() - - async def download_zip(self, repo_url: str, branch: str = "main") -> Path: - """Download repository as ZIP and extract to temp directory.""" - if not self.http_client: - raise RuntimeError("Downloader not initialized. 
Use 'async with' context.") - - parsed = urlparse(repo_url) - path_parts = parsed.path.strip('/').split('/') - owner = path_parts[0] - repo = path_parts[1].replace('.git', '') - - zip_url = f"https://github.com/{owner}/{repo}/archive/refs/heads/{branch}.zip" - - logger.info(f"Downloading ZIP from {zip_url}") - - response = await self.http_client.get(zip_url) - response.raise_for_status() - - temp_dir = Path(tempfile.mkdtemp()) - zip_path = temp_dir / f"{repo}.zip" - - with open(zip_path, 'wb') as f: - f.write(response.content) - - extract_dir = temp_dir / repo - with zipfile.ZipFile(zip_path, 'r') as zip_ref: - zip_ref.extractall(temp_dir) - - extracted_dirs = list(temp_dir.glob(f"{repo}-*")) - if extracted_dirs: - return extracted_dirs[0] - - return extract_dir - - async def crawl_local_directory( - self, - directory: Path, - source: SourceConfig, - base_url: str, - ) -> AsyncGenerator[ExtractedDocument, None]: - """Crawl a local directory for documents.""" - patterns = source.file_patterns or ["*.md", "*.txt", "*.html"] - exclude_patterns = source.exclude_patterns or [] - - for pattern in patterns: - for file_path in directory.rglob(pattern.replace("**/", "")): - if not file_path.is_file(): - continue - - rel_path = str(file_path.relative_to(directory)) - - excluded = any(fnmatch(rel_path, ep) for ep in exclude_patterns) - if excluded: - continue - - if file_path.stat().st_size > MAX_FILE_SIZE: - continue - - try: - content = file_path.read_text(encoding='utf-8') - except UnicodeDecodeError: - try: - content = file_path.read_text(encoding='latin-1') - except Exception: - continue - - source_url = f"{base_url}/{rel_path}" - - if file_path.suffix == '.md': - doc = MarkdownParser.parse(content, rel_path) - doc.source_url = source_url - yield doc - - elif file_path.suffix in ['.html', '.htm']: - doc = HTMLParser.parse(content, rel_path) - doc.source_url = source_url - yield doc - - elif file_path.suffix == '.json': - docs = JSONParser.parse(content, rel_path) - for doc in docs: - doc.source_url = source_url - yield doc - - elif file_path.suffix == '.txt': - yield ExtractedDocument( - text=content, - title=file_path.stem, - file_path=rel_path, - file_type="text", - source_url=source_url, - language=MarkdownParser._detect_language(content), - placeholders=MarkdownParser._find_placeholders(content), - ) - - def cleanup(self, directory: Path): - """Clean up temporary directory.""" - if directory.exists(): - shutil.rmtree(directory, ignore_errors=True) - - -async def crawl_source(source: SourceConfig) -> List[ExtractedDocument]: - """Crawl a source configuration and return all extracted documents.""" - documents = [] - - if source.repo_url: - async with GitHubCrawler() as crawler: - async for doc in crawler.crawl_repository(source): - documents.append(doc) - - return documents - - -# CLI for testing -async def main(): - """Test crawler with a sample source.""" - from template_sources import TEMPLATE_SOURCES - - source = next(s for s in TEMPLATE_SOURCES if s.name == "github-site-policy") - - async with GitHubCrawler() as crawler: - count = 0 - async for doc in crawler.crawl_repository(source): - count += 1 - print(f"\n{'='*60}") - print(f"Title: {doc.title}") - print(f"Path: {doc.file_path}") - print(f"Type: {doc.file_type}") - print(f"Language: {doc.language}") - print(f"URL: {doc.source_url}") - print(f"Placeholders: {doc.placeholders[:5] if doc.placeholders else 'None'}") - print(f"Text preview: {doc.text[:200]}...") - - print(f"\n\nTotal documents: {count}") - - -if __name__ == "__main__": - 
asyncio.run(main()) +# Backward-compat shim -- module moved to crawler/github_core.py +import importlib as _importlib +import sys as _sys +_sys.modules[__name__] = _importlib.import_module("crawler.github_core") diff --git a/klausur-service/backend/github_crawler_parsers.py b/klausur-service/backend/github_crawler_parsers.py index 416c9eb..972b99f 100644 --- a/klausur-service/backend/github_crawler_parsers.py +++ b/klausur-service/backend/github_crawler_parsers.py @@ -1,303 +1,4 @@ -""" -GitHub Crawler - Document Parsers - -Markdown, HTML, and JSON parsers for extracting structured content -from legal template documents. - -Extracted from github_crawler.py to keep files under 500 LOC. -""" - -import hashlib -import json -import re -from dataclasses import dataclass, field -from pathlib import Path -from typing import Any, Dict, List, Optional - - -@dataclass -class ExtractedDocument: - """A document extracted from a repository.""" - text: str - title: str - file_path: str - file_type: str # "markdown", "html", "json", "text" - source_url: str - source_commit: Optional[str] = None - source_hash: str = "" # SHA256 of original content - sections: List[Dict[str, Any]] = field(default_factory=list) - placeholders: List[str] = field(default_factory=list) - language: str = "en" - metadata: Dict[str, Any] = field(default_factory=dict) - - def __post_init__(self): - if not self.source_hash and self.text: - self.source_hash = hashlib.sha256(self.text.encode()).hexdigest() - - -class MarkdownParser: - """Parse Markdown files into structured content.""" - - # Common placeholder patterns - PLACEHOLDER_PATTERNS = [ - r'\[([A-Z_]+)\]', # [COMPANY_NAME] - r'\{([a-z_]+)\}', # {company_name} - r'\{\{([a-z_]+)\}\}', # {{company_name}} - r'__([A-Z_]+)__', # __COMPANY_NAME__ - r'<([A-Z_]+)>', # - ] - - @classmethod - def parse(cls, content: str, filename: str = "") -> ExtractedDocument: - """Parse markdown content into an ExtractedDocument.""" - title = cls._extract_title(content, filename) - sections = cls._extract_sections(content) - placeholders = cls._find_placeholders(content) - language = cls._detect_language(content) - clean_text = cls._clean_for_indexing(content) - - return ExtractedDocument( - text=clean_text, - title=title, - file_path=filename, - file_type="markdown", - source_url="", - sections=sections, - placeholders=placeholders, - language=language, - ) - - @classmethod - def _extract_title(cls, content: str, filename: str) -> str: - """Extract title from markdown heading or filename.""" - h1_match = re.search(r'^#\s+(.+)$', content, re.MULTILINE) - if h1_match: - return h1_match.group(1).strip() - - frontmatter_match = re.search( - r'^---\s*\n.*?title:\s*["\']?(.+?)["\']?\s*\n.*?---', - content, re.DOTALL - ) - if frontmatter_match: - return frontmatter_match.group(1).strip() - - if filename: - name = Path(filename).stem - return name.replace('-', ' ').replace('_', ' ').title() - - return "Untitled" - - @classmethod - def _extract_sections(cls, content: str) -> List[Dict[str, Any]]: - """Extract sections from markdown content.""" - sections = [] - current_section = {"heading": "", "level": 0, "content": "", "start": 0} - - for match in re.finditer(r'^(#{1,6})\s+(.+)$', content, re.MULTILINE): - if current_section["heading"] or current_section["content"].strip(): - current_section["content"] = current_section["content"].strip() - sections.append(current_section.copy()) - - level = len(match.group(1)) - heading = match.group(2).strip() - current_section = { - "heading": heading, - "level": 
level, - "content": "", - "start": match.end(), - } - - if current_section["heading"] or current_section["content"].strip(): - current_section["content"] = content[current_section["start"]:].strip() - sections.append(current_section) - - return sections - - @classmethod - def _find_placeholders(cls, content: str) -> List[str]: - """Find placeholder patterns in content.""" - placeholders = set() - for pattern in cls.PLACEHOLDER_PATTERNS: - for match in re.finditer(pattern, content): - placeholder = match.group(0) - placeholders.add(placeholder) - return sorted(list(placeholders)) - - @classmethod - def _detect_language(cls, content: str) -> str: - """Detect language from content.""" - german_indicators = [ - 'Datenschutz', 'Impressum', 'Nutzungsbedingungen', 'Haftung', - 'Widerruf', 'Verantwortlicher', 'personenbezogene', 'Verarbeitung', - 'und', 'der', 'die', 'das', 'ist', 'wird', 'werden', 'sind', - ] - - lower_content = content.lower() - german_count = sum(1 for word in german_indicators if word.lower() in lower_content) - - if german_count >= 3: - return "de" - return "en" - - @classmethod - def _clean_for_indexing(cls, content: str) -> str: - """Clean markdown content for text indexing.""" - content = re.sub(r'^---\s*\n.*?---\s*\n', '', content, flags=re.DOTALL) - content = re.sub(r'', '', content, flags=re.DOTALL) - content = re.sub(r'<[^>]+>', '', content) - content = re.sub(r'\*\*(.+?)\*\*', r'\1', content) - content = re.sub(r'\*(.+?)\*', r'\1', content) - content = re.sub(r'`(.+?)`', r'\1', content) - content = re.sub(r'~~(.+?)~~', r'\1', content) - content = re.sub(r'\[([^\]]+)\]\([^)]+\)', r'\1', content) - content = re.sub(r'!\[([^\]]*)\]\([^)]+\)', r'\1', content) - content = re.sub(r'\n{3,}', '\n\n', content) - content = re.sub(r' +', ' ', content) - - return content.strip() - - -class HTMLParser: - """Parse HTML files into structured content.""" - - @classmethod - def parse(cls, content: str, filename: str = "") -> ExtractedDocument: - """Parse HTML content into an ExtractedDocument.""" - title_match = re.search(r'(.+?)', content, re.IGNORECASE) - title = title_match.group(1) if title_match else Path(filename).stem - - text = cls._html_to_text(content) - placeholders = MarkdownParser._find_placeholders(text) - - lang_match = re.search(r']*lang=["\']([a-z]{2})["\']', content, re.IGNORECASE) - language = lang_match.group(1) if lang_match else MarkdownParser._detect_language(text) - - return ExtractedDocument( - text=text, - title=title, - file_path=filename, - file_type="html", - source_url="", - placeholders=placeholders, - language=language, - ) - - @classmethod - def _html_to_text(cls, html: str) -> str: - """Convert HTML to clean text.""" - html = re.sub(r']*>.*?', '', html, flags=re.DOTALL | re.IGNORECASE) - html = re.sub(r']*>.*?', '', html, flags=re.DOTALL | re.IGNORECASE) - html = re.sub(r'', '', html, flags=re.DOTALL) - - html = html.replace(' ', ' ') - html = html.replace('&', '&') - html = html.replace('<', '<') - html = html.replace('>', '>') - html = html.replace('"', '"') - html = html.replace(''', "'") - - html = re.sub(r'', '\n', html, flags=re.IGNORECASE) - html = re.sub(r'
</p>
', '\n\n', html, flags=re.IGNORECASE) - html = re.sub(r'', '\n', html, flags=re.IGNORECASE) - html = re.sub(r'', '\n\n', html, flags=re.IGNORECASE) - html = re.sub(r'', '\n', html, flags=re.IGNORECASE) - - html = re.sub(r'<[^>]+>', '', html) - - html = re.sub(r'[ \t]+', ' ', html) - html = re.sub(r'\n[ \t]+', '\n', html) - html = re.sub(r'[ \t]+\n', '\n', html) - html = re.sub(r'\n{3,}', '\n\n', html) - - return html.strip() - - -class JSONParser: - """Parse JSON files containing legal template data.""" - - @classmethod - def parse(cls, content: str, filename: str = "") -> List[ExtractedDocument]: - """Parse JSON content into ExtractedDocuments.""" - try: - data = json.loads(content) - except json.JSONDecodeError as e: - import logging - logging.getLogger(__name__).warning(f"Failed to parse JSON from {filename}: {e}") - return [] - - documents = [] - - if isinstance(data, dict): - documents.extend(cls._parse_dict(data, filename)) - elif isinstance(data, list): - for i, item in enumerate(data): - if isinstance(item, dict): - docs = cls._parse_dict(item, f"{filename}[{i}]") - documents.extend(docs) - - return documents - - @classmethod - def _parse_dict(cls, data: dict, filename: str) -> List[ExtractedDocument]: - """Parse a dictionary into documents.""" - documents = [] - - text_keys = ['text', 'content', 'body', 'description', 'value'] - title_keys = ['title', 'name', 'heading', 'label', 'key'] - - text = "" - for key in text_keys: - if key in data and isinstance(data[key], str): - text = data[key] - break - - if not text: - for key, value in data.items(): - if isinstance(value, dict): - nested_docs = cls._parse_dict(value, f"{filename}.{key}") - documents.extend(nested_docs) - elif isinstance(value, list): - for i, item in enumerate(value): - if isinstance(item, dict): - nested_docs = cls._parse_dict(item, f"{filename}.{key}[{i}]") - documents.extend(nested_docs) - elif isinstance(item, str) and len(item) > 50: - documents.append(ExtractedDocument( - text=item, - title=f"{key} {i+1}", - file_path=filename, - file_type="json", - source_url="", - language=MarkdownParser._detect_language(item), - )) - return documents - - title = "" - for key in title_keys: - if key in data and isinstance(data[key], str): - title = data[key] - break - - if not title: - title = Path(filename).stem - - metadata = {} - for key, value in data.items(): - if key not in text_keys + title_keys and not isinstance(value, (dict, list)): - metadata[key] = value - - placeholders = MarkdownParser._find_placeholders(text) - language = data.get('lang', data.get('language', MarkdownParser._detect_language(text))) - - documents.append(ExtractedDocument( - text=text, - title=title, - file_path=filename, - file_type="json", - source_url="", - placeholders=placeholders, - language=language, - metadata=metadata, - )) - - return documents +# Backward-compat shim -- module moved to crawler/github_parsers.py +import importlib as _importlib +import sys as _sys +_sys.modules[__name__] = _importlib.import_module("crawler.github_parsers") diff --git a/klausur-service/backend/grid/unified.py b/klausur-service/backend/grid/unified.py new file mode 100644 index 0000000..48e9fa4 --- /dev/null +++ b/klausur-service/backend/grid/unified.py @@ -0,0 +1,425 @@ +""" +Unified Grid Builder — merges multi-zone grid into a single Excel-like grid. 
+ +Takes content zone + box zones and produces one unified zone where: +- All content rows use the dominant row height +- Full-width boxes are integrated directly (box rows replace standard rows) +- Partial-width boxes: extra rows inserted if box has more lines than standard +- Box-origin cells carry metadata (bg_color, border) for visual distinction + +The result is a single-zone StructuredGrid that can be: +- Rendered in an Excel-like editor +- Exported to Excel/CSV +- Edited with unified row/column numbering +""" + +import logging +import math +import statistics +from typing import Any, Dict, List, Optional, Tuple + +logger = logging.getLogger(__name__) + + +def _compute_dominant_row_height(content_zone: Dict) -> float: + """Median of content row-to-row spacings, excluding box-gap jumps.""" + rows = content_zone.get("rows", []) + if len(rows) < 2: + return 47.0 + + spacings = [] + for i in range(len(rows) - 1): + y1 = rows[i].get("y_min_px", rows[i].get("y_min", 0)) + y2 = rows[i + 1].get("y_min_px", rows[i + 1].get("y_min", 0)) + d = y2 - y1 + if 0 < d < 100: # exclude box-gap jumps + spacings.append(d) + + if not spacings: + return 47.0 + spacings.sort() + return spacings[len(spacings) // 2] + + +def _classify_boxes( + box_zones: List[Dict], + content_width: float, +) -> List[Dict]: + """Classify each box as full_width or partial_width.""" + result = [] + for bz in box_zones: + bb = bz.get("bbox_px", {}) + bw = bb.get("w", 0) + bx = bb.get("x", 0) + + if bw >= content_width * 0.85: + classification = "full_width" + side = "center" + else: + classification = "partial_width" + # Determine which side of the page the box is on + page_center = content_width / 2 + box_center = bx + bw / 2 + side = "right" if box_center > page_center else "left" + + # Count total text lines in box (including \n within cells) + total_lines = sum( + (c.get("text", "").count("\n") + 1) + for c in bz.get("cells", []) + ) + + result.append({ + "zone": bz, + "classification": classification, + "side": side, + "y_start": bb.get("y", 0), + "y_end": bb.get("y", 0) + bb.get("h", 0), + "total_lines": total_lines, + "bg_hex": bz.get("box_bg_hex", ""), + "bg_color": bz.get("box_bg_color", ""), + }) + return result + + +def build_unified_grid( + zones: List[Dict], + image_width: int, + image_height: int, + layout_metrics: Dict, +) -> Dict[str, Any]: + """Build a single-zone unified grid from multi-zone grid data. + + Returns a StructuredGrid with one zone containing all rows and cells. 
+ """ + content_zone = None + box_zones = [] + for z in zones: + if z.get("zone_type") == "content": + content_zone = z + elif z.get("zone_type") == "box": + box_zones.append(z) + + if not content_zone: + logger.warning("build_unified_grid: no content zone found") + return {"zones": zones} # fallback: return as-is + + box_zones.sort(key=lambda b: b.get("bbox_px", {}).get("y", 0)) + + dominant_h = _compute_dominant_row_height(content_zone) + content_bbox = content_zone.get("bbox_px", {}) + content_width = content_bbox.get("w", image_width) + content_x = content_bbox.get("x", 0) + content_cols = content_zone.get("columns", []) + num_cols = len(content_cols) + + box_infos = _classify_boxes(box_zones, content_width) + + logger.info( + "build_unified_grid: dominant_h=%.1f, %d content rows, %d boxes (%s)", + dominant_h, len(content_zone.get("rows", [])), len(box_infos), + [b["classification"] for b in box_infos], + ) + + # --- Build unified row list + cell list --- + unified_rows: List[Dict] = [] + unified_cells: List[Dict] = [] + unified_row_idx = 0 + + # Content rows and cells indexed by row_index + content_rows = content_zone.get("rows", []) + content_cells = content_zone.get("cells", []) + content_cells_by_row: Dict[int, List[Dict]] = {} + for c in content_cells: + content_cells_by_row.setdefault(c.get("row_index", -1), []).append(c) + + # Track which content rows we've processed + content_row_ptr = 0 + + for bi, box_info in enumerate(box_infos): + bz = box_info["zone"] + by_start = box_info["y_start"] + by_end = box_info["y_end"] + + # --- Add content rows ABOVE this box --- + while content_row_ptr < len(content_rows): + cr = content_rows[content_row_ptr] + cry = cr.get("y_min_px", cr.get("y_min", 0)) + if cry >= by_start: + break + # Add this content row + _add_content_row( + unified_rows, unified_cells, unified_row_idx, + cr, content_cells_by_row, dominant_h, image_height, + ) + unified_row_idx += 1 + content_row_ptr += 1 + + # --- Add box rows --- + if box_info["classification"] == "full_width": + # Full-width box: integrate box rows directly + _add_full_width_box( + unified_rows, unified_cells, unified_row_idx, + bz, box_info, dominant_h, num_cols, image_height, + ) + unified_row_idx += len(bz.get("rows", [])) + # Skip content rows that overlap with this box + while content_row_ptr < len(content_rows): + cr = content_rows[content_row_ptr] + cry = cr.get("y_min_px", cr.get("y_min", 0)) + if cry > by_end: + break + content_row_ptr += 1 + + else: + # Partial-width box: merge with adjacent content rows + unified_row_idx = _add_partial_width_box( + unified_rows, unified_cells, unified_row_idx, + bz, box_info, content_rows, content_cells_by_row, + content_row_ptr, dominant_h, num_cols, image_height, + content_x, content_width, + ) + # Advance content pointer past box region + while content_row_ptr < len(content_rows): + cr = content_rows[content_row_ptr] + cry = cr.get("y_min_px", cr.get("y_min", 0)) + if cry > by_end: + break + content_row_ptr += 1 + + # --- Add remaining content rows BELOW all boxes --- + while content_row_ptr < len(content_rows): + cr = content_rows[content_row_ptr] + _add_content_row( + unified_rows, unified_cells, unified_row_idx, + cr, content_cells_by_row, dominant_h, image_height, + ) + unified_row_idx += 1 + content_row_ptr += 1 + + # --- Build unified zone --- + unified_zone = { + "zone_index": 0, + "zone_type": "unified", + "bbox_px": content_bbox, + "bbox_pct": content_zone.get("bbox_pct", {}), + "border": None, + "word_count": sum(len(c.get("word_boxes", [])) for 
c in unified_cells), + "columns": content_cols, + "rows": unified_rows, + "cells": unified_cells, + "header_rows": [], + } + + logger.info( + "build_unified_grid: %d unified rows, %d cells (from %d content + %d box zones)", + len(unified_rows), len(unified_cells), + len(content_rows), len(box_zones), + ) + + return { + "zones": [unified_zone], + "image_width": image_width, + "image_height": image_height, + "layout_metrics": layout_metrics, + "summary": { + "total_zones": 1, + "total_columns": num_cols, + "total_rows": len(unified_rows), + "total_cells": len(unified_cells), + }, + "is_unified": True, + "dominant_row_h": dominant_h, + } + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +def _make_row(idx: int, y: float, h: float, img_h: int, is_header: bool = False) -> Dict: + return { + "index": idx, + "row_index": idx, + "y_min_px": round(y), + "y_max_px": round(y + h), + "y_min_pct": round(y / img_h * 100, 2) if img_h else 0, + "y_max_pct": round((y + h) / img_h * 100, 2) if img_h else 0, + "is_header": is_header, + } + + +def _remap_cell(cell: Dict, new_row: int, new_col: int = None, + source_type: str = "content", box_region: Dict = None) -> Dict: + """Create a new cell dict with remapped indices.""" + c = dict(cell) + c["row_index"] = new_row + if new_col is not None: + c["col_index"] = new_col + c["cell_id"] = f"U_R{new_row:02d}_C{c.get('col_index', 0)}" + c["source_zone_type"] = source_type + if box_region: + c["box_region"] = box_region + return c + + +def _add_content_row( + unified_rows, unified_cells, row_idx, + content_row, cells_by_row, dominant_h, img_h, +): + """Add a single content row to the unified grid.""" + y = content_row.get("y_min_px", content_row.get("y_min", 0)) + is_hdr = content_row.get("is_header", False) + unified_rows.append(_make_row(row_idx, y, dominant_h, img_h, is_hdr)) + + for cell in cells_by_row.get(content_row.get("index", -1), []): + unified_cells.append(_remap_cell(cell, row_idx, source_type="content")) + + +def _add_full_width_box( + unified_rows, unified_cells, start_row_idx, + box_zone, box_info, dominant_h, num_cols, img_h, +): + """Add a full-width box's rows to the unified grid.""" + box_rows = box_zone.get("rows", []) + box_cells = box_zone.get("cells", []) + box_region = {"bg_hex": box_info["bg_hex"], "bg_color": box_info["bg_color"], "border": True} + + # Distribute box height evenly among its rows + box_h = box_info["y_end"] - box_info["y_start"] + row_h = box_h / len(box_rows) if box_rows else dominant_h + + for i, br in enumerate(box_rows): + y = box_info["y_start"] + i * row_h + new_idx = start_row_idx + i + is_hdr = br.get("is_header", False) + unified_rows.append(_make_row(new_idx, y, row_h, img_h, is_hdr)) + + for cell in box_cells: + if cell.get("row_index") == br.get("index", i): + unified_cells.append( + _remap_cell(cell, new_idx, source_type="box", box_region=box_region) + ) + + +def _add_partial_width_box( + unified_rows, unified_cells, start_row_idx, + box_zone, box_info, content_rows, content_cells_by_row, + content_row_ptr, dominant_h, num_cols, img_h, + content_x, content_width, +) -> int: + """Add a partial-width box merged with content rows. + + Returns the next unified_row_idx after processing. 
+ """ + by_start = box_info["y_start"] + by_end = box_info["y_end"] + box_h = by_end - by_start + box_region = {"bg_hex": box_info["bg_hex"], "bg_color": box_info["bg_color"], "border": True} + + # Content rows in the box's Y range + overlap_content_rows = [] + ptr = content_row_ptr + while ptr < len(content_rows): + cr = content_rows[ptr] + cry = cr.get("y_min_px", cr.get("y_min", 0)) + if cry > by_end: + break + if cry >= by_start: + overlap_content_rows.append(cr) + ptr += 1 + + # How many standard rows fit in the box height + standard_rows = max(1, math.floor(box_h / dominant_h)) + # How many text lines the box actually has + box_text_lines = box_info["total_lines"] + # Extra rows needed + extra_rows = max(0, box_text_lines - standard_rows) + total_rows_for_region = standard_rows + extra_rows + + logger.info( + "partial box: standard=%d, box_lines=%d, extra=%d, content_overlap=%d", + standard_rows, box_text_lines, extra_rows, len(overlap_content_rows), + ) + + # Determine which columns the box occupies + box_bb = box_zone.get("bbox_px", {}) + box_x = box_bb.get("x", 0) + box_w = box_bb.get("w", 0) + + # Map box to content columns: find which content columns overlap + box_col_start = 0 + box_col_end = num_cols + content_cols_list = [] + for z_col_idx in range(num_cols): + # Find the column definition by checking all column entries + # Simple heuristic: if box starts past halfway, it's the right columns + pass + + # Simpler approach: box on right side → last N columns + # box on left side → first N columns + if box_info["side"] == "right": + # Box starts at x=box_x. Find first content column that overlaps + box_col_start = num_cols # default: beyond all columns + for z in (box_zone.get("columns") or [{"index": 0}]): + pass + # Use content column positions to determine overlap + content_cols_data = [ + {"idx": c.get("index", i), "x_min": c.get("x_min_px", 0), "x_max": c.get("x_max_px", 0)} + for i, c in enumerate(content_rows[0:0] or []) # placeholder + ] + # Simple: split columns at midpoint + box_col_start = num_cols // 2 # right half + box_col_end = num_cols + else: + box_col_start = 0 + box_col_end = num_cols // 2 + + # Build rows for this region + box_cells = box_zone.get("cells", []) + box_rows = box_zone.get("rows", []) + row_idx = start_row_idx + + # Expand box cell texts with \n into individual lines for row mapping + box_lines: List[Tuple[str, Dict]] = [] # (text_line, parent_cell) + for bc in sorted(box_cells, key=lambda c: c.get("row_index", 0)): + text = bc.get("text", "") + for line in text.split("\n"): + box_lines.append((line.strip(), bc)) + + for i in range(total_rows_for_region): + y = by_start + i * dominant_h + unified_rows.append(_make_row(row_idx, y, dominant_h, img_h)) + + # Content cells for this row (from overlapping content rows) + if i < len(overlap_content_rows): + cr = overlap_content_rows[i] + for cell in content_cells_by_row.get(cr.get("index", -1), []): + # Only include cells from columns NOT covered by the box + ci = cell.get("col_index", 0) + if ci < box_col_start or ci >= box_col_end: + unified_cells.append(_remap_cell(cell, row_idx, source_type="content")) + + # Box cell for this row + if i < len(box_lines): + line_text, parent_cell = box_lines[i] + box_cell = { + "cell_id": f"U_R{row_idx:02d}_C{box_col_start}", + "row_index": row_idx, + "col_index": box_col_start, + "col_type": "spanning_header" if (box_col_end - box_col_start) > 1 else parent_cell.get("col_type", "column_1"), + "colspan": box_col_end - box_col_start, + "text": line_text, + 
"confidence": parent_cell.get("confidence", 0), + "bbox_px": parent_cell.get("bbox_px", {}), + "bbox_pct": parent_cell.get("bbox_pct", {}), + "word_boxes": [], + "ocr_engine": parent_cell.get("ocr_engine", ""), + "is_bold": parent_cell.get("is_bold", False), + "source_zone_type": "box", + "box_region": box_region, + } + unified_cells.append(box_cell) + + row_idx += 1 + + return row_idx diff --git a/klausur-service/backend/handwriting_htr_api.py b/klausur-service/backend/handwriting_htr_api.py index 2976069..4e692ce 100644 --- a/klausur-service/backend/handwriting_htr_api.py +++ b/klausur-service/backend/handwriting_htr_api.py @@ -1,276 +1,4 @@ -""" -Handwriting HTR API - Hochwertige Handschriftenerkennung (HTR) fuer Klausurkorrekturen. - -Endpoints: -- POST /api/v1/htr/recognize - Bild hochladen → handgeschriebener Text -- POST /api/v1/htr/recognize-session - OCR-Pipeline Session als Quelle nutzen - -Modell-Strategie: - 1. qwen2.5vl:32b via Ollama (primaer, hoechste Qualitaet als VLM) - 2. microsoft/trocr-large-handwritten (Fallback, offline, kein Ollama) - -DATENSCHUTZ: Alle Verarbeitung erfolgt lokal auf dem Mac Mini. -""" - -import io -import os -import logging -import time -import base64 -from typing import Optional - -import cv2 -import numpy as np -from fastapi import APIRouter, HTTPException, Query, UploadFile, File -from pydantic import BaseModel - -logger = logging.getLogger(__name__) - -router = APIRouter(prefix="/api/v1/htr", tags=["HTR"]) - -OLLAMA_BASE_URL = os.getenv("OLLAMA_BASE_URL", "http://host.docker.internal:11434") -OLLAMA_HTR_MODEL = os.getenv("OLLAMA_HTR_MODEL", "qwen2.5vl:32b") -HTR_FALLBACK_MODEL = os.getenv("HTR_FALLBACK_MODEL", "trocr-large") - - -# --------------------------------------------------------------------------- -# Pydantic Models -# --------------------------------------------------------------------------- - -class HTRSessionRequest(BaseModel): - session_id: str - model: str = "auto" # "auto" | "qwen2.5vl" | "trocr-large" - use_clean: bool = True # Prefer clean_png (after handwriting removal) - - -# --------------------------------------------------------------------------- -# Preprocessing -# --------------------------------------------------------------------------- - -def _preprocess_for_htr(img_bgr: np.ndarray) -> np.ndarray: - """ - CLAHE contrast enhancement + upscale to improve HTR accuracy. - Returns grayscale enhanced image. 
- """ - gray = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2GRAY) - clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8)) - enhanced = clahe.apply(gray) - - # Upscale if image is too small - h, w = enhanced.shape - if min(h, w) < 800: - scale = 800 / min(h, w) - enhanced = cv2.resize( - enhanced, None, fx=scale, fy=scale, - interpolation=cv2.INTER_CUBIC - ) - - return enhanced - - -def _bgr_to_png_bytes(img_bgr: np.ndarray) -> bytes: - """Convert BGR ndarray to PNG bytes.""" - success, buf = cv2.imencode(".png", img_bgr) - if not success: - raise RuntimeError("Failed to encode image to PNG") - return buf.tobytes() - - -def _preprocess_image_bytes(image_bytes: bytes) -> bytes: - """Load image, apply HTR preprocessing, return PNG bytes.""" - arr = np.frombuffer(image_bytes, dtype=np.uint8) - img_bgr = cv2.imdecode(arr, cv2.IMREAD_COLOR) - if img_bgr is None: - raise ValueError("Could not decode image") - - enhanced = _preprocess_for_htr(img_bgr) - # Convert grayscale back to BGR for encoding - enhanced_bgr = cv2.cvtColor(enhanced, cv2.COLOR_GRAY2BGR) - return _bgr_to_png_bytes(enhanced_bgr) - - -# --------------------------------------------------------------------------- -# Backend: Ollama qwen2.5vl -# --------------------------------------------------------------------------- - -async def _recognize_with_qwen_vl(image_bytes: bytes, language: str) -> Optional[str]: - """ - Send image to Ollama qwen2.5vl:32b for HTR. - Returns extracted text or None on error. - """ - import httpx - - lang_hint = { - "de": "Deutsch", - "en": "Englisch", - "de+en": "Deutsch und Englisch", - }.get(language, "Deutsch") - - prompt = ( - f"Du bist ein OCR-Experte fuer handgeschriebenen Text auf {lang_hint}. " - "Lies den Text im Bild exakt ab — korrigiere KEINE Rechtschreibfehler. " - "Antworte NUR mit dem erkannten Text, ohne Erklaerungen." - ) - - img_b64 = base64.b64encode(image_bytes).decode("utf-8") - - payload = { - "model": OLLAMA_HTR_MODEL, - "prompt": prompt, - "images": [img_b64], - "stream": False, - } - - try: - async with httpx.AsyncClient(timeout=120.0) as client: - resp = await client.post(f"{OLLAMA_BASE_URL}/api/generate", json=payload) - resp.raise_for_status() - data = resp.json() - return data.get("response", "").strip() - except Exception as e: - logger.warning(f"Ollama qwen2.5vl HTR failed: {e}") - return None - - -# --------------------------------------------------------------------------- -# Backend: TrOCR-large fallback -# --------------------------------------------------------------------------- - -async def _recognize_with_trocr_large(image_bytes: bytes) -> Optional[str]: - """ - Use microsoft/trocr-large-handwritten via trocr_service.py. - Returns extracted text or None on error. - """ - try: - from services.trocr_service import run_trocr_ocr, _check_trocr_available - if not _check_trocr_available(): - logger.warning("TrOCR not available for HTR fallback") - return None - - text, confidence = await run_trocr_ocr(image_bytes, handwritten=True, size="large") - return text.strip() if text else None - except Exception as e: - logger.warning(f"TrOCR-large HTR failed: {e}") - return None - - -# --------------------------------------------------------------------------- -# Core recognition logic -# --------------------------------------------------------------------------- - -async def _do_recognize( - image_bytes: bytes, - model: str = "auto", - preprocess: bool = True, - language: str = "de", -) -> dict: - """ - Core HTR logic: preprocess → try Ollama → fallback to TrOCR-large. 
- Returns dict with text, model_used, processing_time_ms. - """ - t0 = time.monotonic() - - if preprocess: - try: - image_bytes = _preprocess_image_bytes(image_bytes) - except Exception as e: - logger.warning(f"HTR preprocessing failed, using raw image: {e}") - - text: Optional[str] = None - model_used: str = "none" - - use_qwen = model in ("auto", "qwen2.5vl") - use_trocr = model in ("auto", "trocr-large") or (use_qwen and text is None) - - if use_qwen: - text = await _recognize_with_qwen_vl(image_bytes, language) - if text is not None: - model_used = f"qwen2.5vl ({OLLAMA_HTR_MODEL})" - - if text is None and (use_trocr or model == "trocr-large"): - text = await _recognize_with_trocr_large(image_bytes) - if text is not None: - model_used = "trocr-large-handwritten" - - if text is None: - text = "" - model_used = "none (all backends failed)" - - elapsed_ms = int((time.monotonic() - t0) * 1000) - - return { - "text": text, - "model_used": model_used, - "processing_time_ms": elapsed_ms, - "language": language, - "preprocessed": preprocess, - } - - -# --------------------------------------------------------------------------- -# Endpoints -# --------------------------------------------------------------------------- - -@router.post("/recognize") -async def recognize_handwriting( - file: UploadFile = File(...), - model: str = Query("auto", description="auto | qwen2.5vl | trocr-large"), - preprocess: bool = Query(True, description="Apply CLAHE + upscale before recognition"), - language: str = Query("de", description="de | en | de+en"), -): - """ - Upload an image and get back the handwritten text as plain text. - - Tries qwen2.5vl:32b via Ollama first, falls back to TrOCR-large-handwritten. - """ - if model not in ("auto", "qwen2.5vl", "trocr-large"): - raise HTTPException(status_code=400, detail="model must be one of: auto, qwen2.5vl, trocr-large") - if language not in ("de", "en", "de+en"): - raise HTTPException(status_code=400, detail="language must be one of: de, en, de+en") - - image_bytes = await file.read() - if not image_bytes: - raise HTTPException(status_code=400, detail="Empty file") - - return await _do_recognize(image_bytes, model=model, preprocess=preprocess, language=language) - - -@router.post("/recognize-session") -async def recognize_from_session(req: HTRSessionRequest): - """ - Use an OCR-Pipeline session as image source for HTR. - - Set use_clean=true to prefer the clean image (after handwriting removal step). - This is useful when you want to do HTR on isolated handwriting regions. 
- """ - from ocr_pipeline_session_store import get_session_db, get_session_image - - session = await get_session_db(req.session_id) - if not session: - raise HTTPException(status_code=404, detail=f"Session {req.session_id} not found") - - # Choose source image - image_bytes: Optional[bytes] = None - source_used: str = "" - - if req.use_clean: - image_bytes = await get_session_image(req.session_id, "clean") - if image_bytes: - source_used = "clean" - - if not image_bytes: - image_bytes = await get_session_image(req.session_id, "deskewed") - if image_bytes: - source_used = "deskewed" - - if not image_bytes: - image_bytes = await get_session_image(req.session_id, "original") - source_used = "original" - - if not image_bytes: - raise HTTPException(status_code=404, detail="No image available in session") - - result = await _do_recognize(image_bytes, model=req.model) - result["session_id"] = req.session_id - result["source_image"] = source_used - return result +# Backward-compat shim -- module moved to ocr/pipeline/htr_api.py +import importlib as _importlib +import sys as _sys +_sys.modules[__name__] = _importlib.import_module("ocr.pipeline.htr_api") diff --git a/klausur-service/backend/ocr/engines/tesseract_extractor.py b/klausur-service/backend/ocr/engines/tesseract_extractor.py new file mode 100644 index 0000000..23ac32e --- /dev/null +++ b/klausur-service/backend/ocr/engines/tesseract_extractor.py @@ -0,0 +1,346 @@ +""" +Tesseract-based OCR extraction with word-level bounding boxes. + +Uses Tesseract for spatial information (WHERE text is) while +the Vision LLM handles semantic understanding (WHAT the text means). + +Tesseract runs natively on ARM64 via Debian's apt package. + +Lizenz: Apache 2.0 (kommerziell nutzbar) +""" + +import io +import logging +from typing import List, Dict, Any, Optional +from difflib import SequenceMatcher + +logger = logging.getLogger(__name__) + +try: + import pytesseract + from PIL import Image + TESSERACT_AVAILABLE = True +except ImportError: + TESSERACT_AVAILABLE = False + logger.warning("pytesseract or Pillow not installed - Tesseract OCR unavailable") + + +async def extract_bounding_boxes(image_bytes: bytes, lang: str = "eng+deu") -> dict: + """Run Tesseract OCR and return word-level bounding boxes. + + Args: + image_bytes: PNG/JPEG image as bytes. + lang: Tesseract language string (e.g. "eng+deu"). + + Returns: + Dict with 'words' list and 'image_width'/'image_height'. + """ + if not TESSERACT_AVAILABLE: + return {"words": [], "image_width": 0, "image_height": 0, "error": "Tesseract not available"} + + image = Image.open(io.BytesIO(image_bytes)) + data = pytesseract.image_to_data(image, lang=lang, output_type=pytesseract.Output.DICT) + + words = [] + for i in range(len(data['text'])): + text = data['text'][i].strip() + conf = int(data['conf'][i]) + if not text or conf < 20: + continue + words.append({ + "text": text, + "left": data['left'][i], + "top": data['top'][i], + "width": data['width'][i], + "height": data['height'][i], + "conf": conf, + "block_num": data['block_num'][i], + "par_num": data['par_num'][i], + "line_num": data['line_num'][i], + "word_num": data['word_num'][i], + }) + + return { + "words": words, + "image_width": image.width, + "image_height": image.height, + } + + +def group_words_into_lines(words: List[dict], y_tolerance_px: int = 15) -> List[List[dict]]: + """Group words by their Y position into lines. + + Args: + words: List of word dicts from extract_bounding_boxes. 
+ y_tolerance_px: Max pixel distance to consider words on the same line. + + Returns: + List of lines, each line is a list of words sorted by X position. + """ + if not words: + return [] + + # Sort by Y then X + sorted_words = sorted(words, key=lambda w: (w['top'], w['left'])) + + lines: List[List[dict]] = [] + current_line: List[dict] = [sorted_words[0]] + current_y = sorted_words[0]['top'] + + for word in sorted_words[1:]: + if abs(word['top'] - current_y) <= y_tolerance_px: + current_line.append(word) + else: + current_line.sort(key=lambda w: w['left']) + lines.append(current_line) + current_line = [word] + current_y = word['top'] + + if current_line: + current_line.sort(key=lambda w: w['left']) + lines.append(current_line) + + return lines + + +def detect_columns(lines: List[List[dict]], image_width: int) -> Dict[str, Any]: + """Detect column boundaries from word positions. + + Typical vocab table: Left=English, Middle=German, Right=Example sentences. + + Returns: + Dict with column boundaries and type assignments. + """ + if not lines or image_width == 0: + return {"columns": [], "column_types": []} + + # Collect all word X positions + all_x_positions = [] + for line in lines: + for word in line: + all_x_positions.append(word['left']) + + if not all_x_positions: + return {"columns": [], "column_types": []} + + # Find X-position clusters (column starts) + all_x_positions.sort() + + # Simple gap-based column detection + min_gap = image_width * 0.08 # 8% of page width = column gap + clusters = [] + current_cluster = [all_x_positions[0]] + + for x in all_x_positions[1:]: + if x - current_cluster[-1] > min_gap: + clusters.append(current_cluster) + current_cluster = [x] + else: + current_cluster.append(x) + + if current_cluster: + clusters.append(current_cluster) + + # Each cluster represents a column start + columns = [] + for cluster in clusters: + col_start = min(cluster) + columns.append({ + "x_start": col_start, + "x_start_pct": col_start / image_width * 100, + "word_count": len(cluster), + }) + + # Assign column types based on position (left→right: EN, DE, Example) + type_map = ["english", "german", "example"] + column_types = [] + for i, col in enumerate(columns): + if i < len(type_map): + column_types.append(type_map[i]) + else: + column_types.append("unknown") + + return { + "columns": columns, + "column_types": column_types, + } + + +def words_to_vocab_entries(lines: List[List[dict]], columns: List[dict], + column_types: List[str], image_width: int, + image_height: int) -> List[dict]: + """Convert grouped words into vocabulary entries using column positions. + + Args: + lines: Grouped word lines from group_words_into_lines. + columns: Column boundaries from detect_columns. + column_types: Column type assignments. + image_width: Image width in pixels. + image_height: Image height in pixels. + + Returns: + List of vocabulary entry dicts with english/german/example fields. 
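+
+    Example (illustrative): a line whose words land in the first two
+    columns may yield
+        {"english": "to run", "german": "laufen", "example": "",
+         "english_bbox": {"x_pct": 5.2, "y_pct": 12.0, "w_pct": 8.1, "h_pct": 1.6},
+         "german_bbox": {...}}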
+ """ + if not columns or not lines: + return [] + + # Build column boundaries for word assignment + col_boundaries = [] + for i, col in enumerate(columns): + start = col['x_start'] + if i + 1 < len(columns): + end = columns[i + 1]['x_start'] + else: + end = image_width + col_boundaries.append((start, end, column_types[i] if i < len(column_types) else "unknown")) + + entries = [] + for line in lines: + entry = {"english": "", "german": "", "example": ""} + line_words_by_col: Dict[str, List[str]] = {"english": [], "german": [], "example": []} + line_bbox: Dict[str, Optional[dict]] = {} + + for word in line: + word_center_x = word['left'] + word['width'] / 2 + assigned_type = "unknown" + for start, end, col_type in col_boundaries: + if start <= word_center_x < end: + assigned_type = col_type + break + + if assigned_type in line_words_by_col: + line_words_by_col[assigned_type].append(word['text']) + # Track bounding box for the column + if assigned_type not in line_bbox or line_bbox[assigned_type] is None: + line_bbox[assigned_type] = { + "left": word['left'], + "top": word['top'], + "right": word['left'] + word['width'], + "bottom": word['top'] + word['height'], + } + else: + bb = line_bbox[assigned_type] + bb['left'] = min(bb['left'], word['left']) + bb['top'] = min(bb['top'], word['top']) + bb['right'] = max(bb['right'], word['left'] + word['width']) + bb['bottom'] = max(bb['bottom'], word['top'] + word['height']) + + for col_type in ["english", "german", "example"]: + if line_words_by_col[col_type]: + entry[col_type] = " ".join(line_words_by_col[col_type]) + if line_bbox.get(col_type): + bb = line_bbox[col_type] + entry[f"{col_type}_bbox"] = { + "x_pct": bb['left'] / image_width * 100, + "y_pct": bb['top'] / image_height * 100, + "w_pct": (bb['right'] - bb['left']) / image_width * 100, + "h_pct": (bb['bottom'] - bb['top']) / image_height * 100, + } + + # Only add if at least one column has content + if entry["english"] or entry["german"]: + entries.append(entry) + + return entries + + +def match_positions_to_vocab(tess_words: List[dict], llm_vocab: List[dict], + image_w: int, image_h: int, + threshold: float = 0.6) -> List[dict]: + """Match Tesseract bounding boxes to LLM vocabulary entries. + + For each LLM vocab entry, find the best-matching Tesseract word + and attach its bounding box coordinates. + + Args: + tess_words: Word list from Tesseract with pixel coordinates. + llm_vocab: Vocabulary list from Vision LLM. + image_w: Image width in pixels. + image_h: Image height in pixels. + threshold: Minimum similarity ratio for a match. + + Returns: + llm_vocab list with bbox_x_pct/bbox_y_pct/bbox_w_pct/bbox_h_pct added. 
+ """ + if not tess_words or not llm_vocab or image_w == 0 or image_h == 0: + return llm_vocab + + for entry in llm_vocab: + english = entry.get("english", "").lower().strip() + german = entry.get("german", "").lower().strip() + + if not english and not german: + continue + + # Try to match English word first, then German + for field in ["english", "german"]: + search_text = entry.get(field, "").lower().strip() + if not search_text: + continue + + best_word = None + best_ratio = 0.0 + + for word in tess_words: + ratio = SequenceMatcher(None, search_text, word['text'].lower()).ratio() + if ratio > best_ratio: + best_ratio = ratio + best_word = word + + if best_word and best_ratio >= threshold: + entry[f"bbox_x_pct"] = best_word['left'] / image_w * 100 + entry[f"bbox_y_pct"] = best_word['top'] / image_h * 100 + entry[f"bbox_w_pct"] = best_word['width'] / image_w * 100 + entry[f"bbox_h_pct"] = best_word['height'] / image_h * 100 + entry["bbox_match_field"] = field + entry["bbox_match_ratio"] = round(best_ratio, 3) + break # Found a match, no need to try the other field + + return llm_vocab + + +async def run_tesseract_pipeline(image_bytes: bytes, lang: str = "eng+deu") -> dict: + """Full Tesseract pipeline: extract words, group lines, detect columns, build vocab. + + Args: + image_bytes: PNG/JPEG image as bytes. + lang: Tesseract language string. + + Returns: + Dict with 'vocabulary', 'words', 'lines', 'columns', 'image_width', 'image_height'. + """ + # Step 1: Extract bounding boxes + bbox_data = await extract_bounding_boxes(image_bytes, lang=lang) + + if bbox_data.get("error"): + return bbox_data + + words = bbox_data["words"] + image_w = bbox_data["image_width"] + image_h = bbox_data["image_height"] + + # Step 2: Group into lines + lines = group_words_into_lines(words) + + # Step 3: Detect columns + col_info = detect_columns(lines, image_w) + + # Step 4: Build vocabulary entries + vocab = words_to_vocab_entries( + lines, + col_info["columns"], + col_info["column_types"], + image_w, + image_h, + ) + + return { + "vocabulary": vocab, + "words": words, + "lines_count": len(lines), + "columns": col_info["columns"], + "column_types": col_info["column_types"], + "image_width": image_w, + "image_height": image_h, + "word_count": len(words), + } diff --git a/klausur-service/backend/ocr/pipeline/htr_api.py b/klausur-service/backend/ocr/pipeline/htr_api.py new file mode 100644 index 0000000..2976069 --- /dev/null +++ b/klausur-service/backend/ocr/pipeline/htr_api.py @@ -0,0 +1,276 @@ +""" +Handwriting HTR API - Hochwertige Handschriftenerkennung (HTR) fuer Klausurkorrekturen. + +Endpoints: +- POST /api/v1/htr/recognize - Bild hochladen → handgeschriebener Text +- POST /api/v1/htr/recognize-session - OCR-Pipeline Session als Quelle nutzen + +Modell-Strategie: + 1. qwen2.5vl:32b via Ollama (primaer, hoechste Qualitaet als VLM) + 2. microsoft/trocr-large-handwritten (Fallback, offline, kein Ollama) + +DATENSCHUTZ: Alle Verarbeitung erfolgt lokal auf dem Mac Mini. 
+""" + +import io +import os +import logging +import time +import base64 +from typing import Optional + +import cv2 +import numpy as np +from fastapi import APIRouter, HTTPException, Query, UploadFile, File +from pydantic import BaseModel + +logger = logging.getLogger(__name__) + +router = APIRouter(prefix="/api/v1/htr", tags=["HTR"]) + +OLLAMA_BASE_URL = os.getenv("OLLAMA_BASE_URL", "http://host.docker.internal:11434") +OLLAMA_HTR_MODEL = os.getenv("OLLAMA_HTR_MODEL", "qwen2.5vl:32b") +HTR_FALLBACK_MODEL = os.getenv("HTR_FALLBACK_MODEL", "trocr-large") + + +# --------------------------------------------------------------------------- +# Pydantic Models +# --------------------------------------------------------------------------- + +class HTRSessionRequest(BaseModel): + session_id: str + model: str = "auto" # "auto" | "qwen2.5vl" | "trocr-large" + use_clean: bool = True # Prefer clean_png (after handwriting removal) + + +# --------------------------------------------------------------------------- +# Preprocessing +# --------------------------------------------------------------------------- + +def _preprocess_for_htr(img_bgr: np.ndarray) -> np.ndarray: + """ + CLAHE contrast enhancement + upscale to improve HTR accuracy. + Returns grayscale enhanced image. + """ + gray = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2GRAY) + clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8)) + enhanced = clahe.apply(gray) + + # Upscale if image is too small + h, w = enhanced.shape + if min(h, w) < 800: + scale = 800 / min(h, w) + enhanced = cv2.resize( + enhanced, None, fx=scale, fy=scale, + interpolation=cv2.INTER_CUBIC + ) + + return enhanced + + +def _bgr_to_png_bytes(img_bgr: np.ndarray) -> bytes: + """Convert BGR ndarray to PNG bytes.""" + success, buf = cv2.imencode(".png", img_bgr) + if not success: + raise RuntimeError("Failed to encode image to PNG") + return buf.tobytes() + + +def _preprocess_image_bytes(image_bytes: bytes) -> bytes: + """Load image, apply HTR preprocessing, return PNG bytes.""" + arr = np.frombuffer(image_bytes, dtype=np.uint8) + img_bgr = cv2.imdecode(arr, cv2.IMREAD_COLOR) + if img_bgr is None: + raise ValueError("Could not decode image") + + enhanced = _preprocess_for_htr(img_bgr) + # Convert grayscale back to BGR for encoding + enhanced_bgr = cv2.cvtColor(enhanced, cv2.COLOR_GRAY2BGR) + return _bgr_to_png_bytes(enhanced_bgr) + + +# --------------------------------------------------------------------------- +# Backend: Ollama qwen2.5vl +# --------------------------------------------------------------------------- + +async def _recognize_with_qwen_vl(image_bytes: bytes, language: str) -> Optional[str]: + """ + Send image to Ollama qwen2.5vl:32b for HTR. + Returns extracted text or None on error. + """ + import httpx + + lang_hint = { + "de": "Deutsch", + "en": "Englisch", + "de+en": "Deutsch und Englisch", + }.get(language, "Deutsch") + + prompt = ( + f"Du bist ein OCR-Experte fuer handgeschriebenen Text auf {lang_hint}. " + "Lies den Text im Bild exakt ab — korrigiere KEINE Rechtschreibfehler. " + "Antworte NUR mit dem erkannten Text, ohne Erklaerungen." 
+ ) + + img_b64 = base64.b64encode(image_bytes).decode("utf-8") + + payload = { + "model": OLLAMA_HTR_MODEL, + "prompt": prompt, + "images": [img_b64], + "stream": False, + } + + try: + async with httpx.AsyncClient(timeout=120.0) as client: + resp = await client.post(f"{OLLAMA_BASE_URL}/api/generate", json=payload) + resp.raise_for_status() + data = resp.json() + return data.get("response", "").strip() + except Exception as e: + logger.warning(f"Ollama qwen2.5vl HTR failed: {e}") + return None + + +# --------------------------------------------------------------------------- +# Backend: TrOCR-large fallback +# --------------------------------------------------------------------------- + +async def _recognize_with_trocr_large(image_bytes: bytes) -> Optional[str]: + """ + Use microsoft/trocr-large-handwritten via trocr_service.py. + Returns extracted text or None on error. + """ + try: + from services.trocr_service import run_trocr_ocr, _check_trocr_available + if not _check_trocr_available(): + logger.warning("TrOCR not available for HTR fallback") + return None + + text, confidence = await run_trocr_ocr(image_bytes, handwritten=True, size="large") + return text.strip() if text else None + except Exception as e: + logger.warning(f"TrOCR-large HTR failed: {e}") + return None + + +# --------------------------------------------------------------------------- +# Core recognition logic +# --------------------------------------------------------------------------- + +async def _do_recognize( + image_bytes: bytes, + model: str = "auto", + preprocess: bool = True, + language: str = "de", +) -> dict: + """ + Core HTR logic: preprocess → try Ollama → fallback to TrOCR-large. + Returns dict with text, model_used, processing_time_ms. + """ + t0 = time.monotonic() + + if preprocess: + try: + image_bytes = _preprocess_image_bytes(image_bytes) + except Exception as e: + logger.warning(f"HTR preprocessing failed, using raw image: {e}") + + text: Optional[str] = None + model_used: str = "none" + + use_qwen = model in ("auto", "qwen2.5vl") + use_trocr = model in ("auto", "trocr-large") or (use_qwen and text is None) + + if use_qwen: + text = await _recognize_with_qwen_vl(image_bytes, language) + if text is not None: + model_used = f"qwen2.5vl ({OLLAMA_HTR_MODEL})" + + if text is None and (use_trocr or model == "trocr-large"): + text = await _recognize_with_trocr_large(image_bytes) + if text is not None: + model_used = "trocr-large-handwritten" + + if text is None: + text = "" + model_used = "none (all backends failed)" + + elapsed_ms = int((time.monotonic() - t0) * 1000) + + return { + "text": text, + "model_used": model_used, + "processing_time_ms": elapsed_ms, + "language": language, + "preprocessed": preprocess, + } + + +# --------------------------------------------------------------------------- +# Endpoints +# --------------------------------------------------------------------------- + +@router.post("/recognize") +async def recognize_handwriting( + file: UploadFile = File(...), + model: str = Query("auto", description="auto | qwen2.5vl | trocr-large"), + preprocess: bool = Query(True, description="Apply CLAHE + upscale before recognition"), + language: str = Query("de", description="de | en | de+en"), +): + """ + Upload an image and get back the handwritten text as plain text. + + Tries qwen2.5vl:32b via Ollama first, falls back to TrOCR-large-handwritten. 
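+
+    JSON response shape (values illustrative):
+        {"text": "...", "model_used": "qwen2.5vl (qwen2.5vl:32b)",
+         "processing_time_ms": 2400, "language": "de", "preprocessed": true}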
+ """ + if model not in ("auto", "qwen2.5vl", "trocr-large"): + raise HTTPException(status_code=400, detail="model must be one of: auto, qwen2.5vl, trocr-large") + if language not in ("de", "en", "de+en"): + raise HTTPException(status_code=400, detail="language must be one of: de, en, de+en") + + image_bytes = await file.read() + if not image_bytes: + raise HTTPException(status_code=400, detail="Empty file") + + return await _do_recognize(image_bytes, model=model, preprocess=preprocess, language=language) + + +@router.post("/recognize-session") +async def recognize_from_session(req: HTRSessionRequest): + """ + Use an OCR-Pipeline session as image source for HTR. + + Set use_clean=true to prefer the clean image (after handwriting removal step). + This is useful when you want to do HTR on isolated handwriting regions. + """ + from ocr_pipeline_session_store import get_session_db, get_session_image + + session = await get_session_db(req.session_id) + if not session: + raise HTTPException(status_code=404, detail=f"Session {req.session_id} not found") + + # Choose source image + image_bytes: Optional[bytes] = None + source_used: str = "" + + if req.use_clean: + image_bytes = await get_session_image(req.session_id, "clean") + if image_bytes: + source_used = "clean" + + if not image_bytes: + image_bytes = await get_session_image(req.session_id, "deskewed") + if image_bytes: + source_used = "deskewed" + + if not image_bytes: + image_bytes = await get_session_image(req.session_id, "original") + source_used = "original" + + if not image_bytes: + raise HTTPException(status_code=404, detail="No image available in session") + + result = await _do_recognize(image_bytes, model=req.model) + result["session_id"] = req.session_id + result["source_image"] = source_used + return result diff --git a/klausur-service/backend/ocr/spell/__init__.py b/klausur-service/backend/ocr/spell/__init__.py new file mode 100644 index 0000000..cb10785 --- /dev/null +++ b/klausur-service/backend/ocr/spell/__init__.py @@ -0,0 +1,7 @@ +""" +OCR spell-checking sub-package — language-aware OCR correction. + +Moved from backend/ flat modules (smart_spell*.py). +Backward-compatible shim files remain at the old locations. +""" +from .smart_spell import * # noqa: F401,F403 diff --git a/klausur-service/backend/ocr/spell/core.py b/klausur-service/backend/ocr/spell/core.py new file mode 100644 index 0000000..9f2fa7d --- /dev/null +++ b/klausur-service/backend/ocr/spell/core.py @@ -0,0 +1,298 @@ +""" +SmartSpellChecker Core — init, data types, language detection, word correction. + +Extracted from smart_spell.py for modularity. 
+ +Lizenz: Apache 2.0 (kommerziell nutzbar) +""" + +import logging +import re +from dataclasses import dataclass, field +from typing import Dict, List, Literal, Optional, Set, Tuple + +logger = logging.getLogger(__name__) + +# --------------------------------------------------------------------------- +# Init +# --------------------------------------------------------------------------- + +try: + from spellchecker import SpellChecker as _SpellChecker + _en_spell = _SpellChecker(language='en', distance=1) + _de_spell = _SpellChecker(language='de', distance=1) + _AVAILABLE = True +except ImportError: + _AVAILABLE = False + logger.warning("pyspellchecker not installed — SmartSpellChecker disabled") + +Lang = Literal["en", "de", "both", "unknown"] + +# --------------------------------------------------------------------------- +# Bigram context for a/I disambiguation +# --------------------------------------------------------------------------- + +# Words that commonly follow "I" (subject pronoun -> verb/modal) +_I_FOLLOWERS: frozenset = frozenset({ + "am", "was", "have", "had", "do", "did", "will", "would", "can", + "could", "should", "shall", "may", "might", "must", + "think", "know", "see", "want", "need", "like", "love", "hate", + "go", "went", "come", "came", "say", "said", "get", "got", + "make", "made", "take", "took", "give", "gave", "tell", "told", + "feel", "felt", "find", "found", "believe", "hope", "wish", + "remember", "forget", "understand", "mean", "meant", + "don't", "didn't", "can't", "won't", "couldn't", "wouldn't", + "shouldn't", "haven't", "hadn't", "isn't", "wasn't", + "really", "just", "also", "always", "never", "often", "sometimes", +}) + +# Words that commonly follow "a" (article -> noun/adjective) +_A_FOLLOWERS: frozenset = frozenset({ + "lot", "few", "little", "bit", "good", "bad", "great", "new", "old", + "long", "short", "big", "small", "large", "huge", "tiny", + "nice", "beautiful", "wonderful", "terrible", "horrible", + "man", "woman", "boy", "girl", "child", "dog", "cat", "bird", + "book", "car", "house", "room", "school", "teacher", "student", + "day", "week", "month", "year", "time", "place", "way", + "friend", "family", "person", "problem", "question", "story", + "very", "really", "quite", "rather", "pretty", "single", +}) + +# Digit->letter substitutions (OCR confusion) +_DIGIT_SUBS: Dict[str, List[str]] = { + '0': ['o', 'O'], + '1': ['l', 'I'], + '5': ['s', 'S'], + '6': ['g', 'G'], + '8': ['b', 'B'], + '|': ['I', 'l'], + '/': ['l'], # italic 'l' misread as slash (e.g. 
"p/" -> "pl") +} +_SUSPICIOUS_CHARS = frozenset(_DIGIT_SUBS.keys()) + +# Umlaut confusion: OCR drops dots (u->u, a->a, o->o) +_UMLAUT_MAP = { + 'a': '\u00e4', 'o': '\u00f6', 'u': '\u00fc', 'i': '\u00fc', + 'A': '\u00c4', 'O': '\u00d6', 'U': '\u00dc', 'I': '\u00dc', +} + +# Tokenizer -- includes | and / so OCR artifacts like "p/" are treated as words +_TOKEN_RE = re.compile(r"([A-Za-z\u00c4\u00d6\u00dc\u00e4\u00f6\u00fc\u00df'|/]+)([^A-Za-z\u00c4\u00d6\u00dc\u00e4\u00f6\u00fc\u00df'|/]*)") + + +# --------------------------------------------------------------------------- +# Data types +# --------------------------------------------------------------------------- + +@dataclass +class CorrectionResult: + original: str + corrected: str + lang_detected: Lang + changed: bool + changes: List[str] = field(default_factory=list) + + +# --------------------------------------------------------------------------- +# Core class — language detection and word-level correction +# --------------------------------------------------------------------------- + +class _SmartSpellCoreBase: + """Base class with language detection and single-word correction. + + Not intended for direct use — SmartSpellChecker inherits from this. + """ + + def __init__(self): + if not _AVAILABLE: + raise RuntimeError("pyspellchecker not installed") + self.en = _en_spell + self.de = _de_spell + + # --- Language detection --- + + def detect_word_lang(self, word: str) -> Lang: + """Detect language of a single word using dual-dict heuristic.""" + w = word.lower().strip(".,;:!?\"'()") + if not w: + return "unknown" + in_en = bool(self.en.known([w])) + in_de = bool(self.de.known([w])) + if in_en and in_de: + return "both" + if in_en: + return "en" + if in_de: + return "de" + return "unknown" + + def detect_text_lang(self, text: str) -> Lang: + """Detect dominant language of a text string (sentence/phrase).""" + words = re.findall(r"[A-Za-z\u00c4\u00d6\u00dc\u00e4\u00f6\u00fc\u00df]+", text) + if not words: + return "unknown" + + en_count = 0 + de_count = 0 + for w in words: + lang = self.detect_word_lang(w) + if lang == "en": + en_count += 1 + elif lang == "de": + de_count += 1 + # "both" doesn't count for either + + if en_count > de_count: + return "en" + if de_count > en_count: + return "de" + if en_count == de_count and en_count > 0: + return "both" + return "unknown" + + # --- Single-word correction --- + + def _known(self, word: str) -> bool: + """True if word is known in EN or DE dictionary, or is a known abbreviation.""" + w = word.lower() + if bool(self.en.known([w])) or bool(self.de.known([w])): + return True + # Also accept known abbreviations (sth, sb, adj, etc.) + try: + from cv_ocr_engines import _KNOWN_ABBREVIATIONS + if w in _KNOWN_ABBREVIATIONS: + return True + except ImportError: + pass + return False + + def _word_freq(self, word: str) -> float: + """Get word frequency (max of EN and DE).""" + w = word.lower() + return max(self.en.word_usage_frequency(w), self.de.word_usage_frequency(w)) + + def _known_in(self, word: str, lang: str) -> bool: + """True if word is known in a specific language dictionary.""" + w = word.lower() + spell = self.en if lang == "en" else self.de + return bool(spell.known([w])) + + def correct_word(self, word: str, lang: str = "en", + prev_word: str = "", next_word: str = "") -> Optional[str]: + """Correct a single word for the given language. + + Returns None if no correction needed, or the corrected string. 
+ """ + if not word or not word.strip(): + return None + + # Skip numbers, abbreviations with dots, very short tokens + if word.isdigit() or '.' in word: + return None + + # Skip IPA/phonetic content in brackets + if '[' in word or ']' in word: + return None + + has_suspicious = any(ch in _SUSPICIOUS_CHARS for ch in word) + + # 1. Already known -> no fix + if self._known(word): + # But check a/I disambiguation for single-char words + if word.lower() in ('l', '|') and next_word: + return self._disambiguate_a_I(word, next_word) + return None + + # 2. Digit/pipe substitution + if has_suspicious: + if word == '|': + return 'I' + # Try single-char substitutions + for i, ch in enumerate(word): + if ch not in _DIGIT_SUBS: + continue + for replacement in _DIGIT_SUBS[ch]: + candidate = word[:i] + replacement + word[i + 1:] + if self._known(candidate): + return candidate + # Try multi-char substitution (e.g., "sch00l" -> "school") + multi = self._try_multi_digit_sub(word) + if multi: + return multi + + # 3. Umlaut correction (German) + if lang == "de" and len(word) >= 3 and word.isalpha(): + umlaut_fix = self._try_umlaut_fix(word) + if umlaut_fix: + return umlaut_fix + + # 4. General spell correction + if not has_suspicious and len(word) >= 3 and word.isalpha(): + # Safety: don't correct if the word is valid in the OTHER language + other_lang = "de" if lang == "en" else "en" + if self._known_in(word, other_lang): + return None + if other_lang == "de" and self._try_umlaut_fix(word): + return None # has a valid DE umlaut variant -> don't touch + + spell = self.en if lang == "en" else self.de + correction = spell.correction(word.lower()) + if correction and correction != word.lower(): + if word[0].isupper(): + correction = correction[0].upper() + correction[1:] + if self._known(correction): + return correction + + return None + + # --- Multi-digit substitution --- + + def _try_multi_digit_sub(self, word: str) -> Optional[str]: + """Try replacing multiple digits simultaneously using BFS.""" + positions = [(i, ch) for i, ch in enumerate(word) if ch in _DIGIT_SUBS] + if not positions or len(positions) > 4: + return None + + # BFS over substitution combinations + queue = [list(word)] + for pos, ch in positions: + next_queue = [] + for current in queue: + # Keep original + next_queue.append(current[:]) + # Try each substitution + for repl in _DIGIT_SUBS[ch]: + variant = current[:] + variant[pos] = repl + next_queue.append(variant) + queue = next_queue + + # Check which combinations produce known words + for combo in queue: + candidate = "".join(combo) + if candidate != word and self._known(candidate): + return candidate + + return None + + # --- Umlaut fix --- + + def _try_umlaut_fix(self, word: str) -> Optional[str]: + """Try single-char umlaut substitutions for German words.""" + for i, ch in enumerate(word): + if ch in _UMLAUT_MAP: + candidate = word[:i] + _UMLAUT_MAP[ch] + word[i + 1:] + if self._known(candidate): + return candidate + return None + + # --- a/I disambiguation --- + + def _disambiguate_a_I(self, token: str, next_word: str) -> Optional[str]: + """Disambiguate 'a' vs 'I' (and OCR variants like 'l', '|').""" + nw = next_word.lower().strip(".,;:!?") + if nw in _I_FOLLOWERS: + return "I" + if nw in _A_FOLLOWERS: + return "a" + return None # uncertain, don't change diff --git a/klausur-service/backend/ocr/spell/smart_spell.py b/klausur-service/backend/ocr/spell/smart_spell.py new file mode 100644 index 0000000..2e72096 --- /dev/null +++ b/klausur-service/backend/ocr/spell/smart_spell.py @@ -0,0 
+1,25 @@ +""" +SmartSpellChecker — barrel re-export. + +All implementation split into: + smart_spell_core — init, data types, language detection, word correction + smart_spell_text — full text correction, boundary repair, context split + +Lizenz: Apache 2.0 (kommerziell nutzbar) +""" + +# Core: data types, lang detection (re-exported for tests) +from .core import ( # noqa: F401 + _AVAILABLE, + _DIGIT_SUBS, + _SUSPICIOUS_CHARS, + _UMLAUT_MAP, + _TOKEN_RE, + _I_FOLLOWERS, + _A_FOLLOWERS, + CorrectionResult, + Lang, +) + +# Text: SmartSpellChecker class (the main public API) +from .text import SmartSpellChecker # noqa: F401 diff --git a/klausur-service/backend/ocr/spell/text.py b/klausur-service/backend/ocr/spell/text.py new file mode 100644 index 0000000..a5081f4 --- /dev/null +++ b/klausur-service/backend/ocr/spell/text.py @@ -0,0 +1,289 @@ +""" +SmartSpellChecker Text — full text correction, boundary repair, context split. + +Extracted from smart_spell.py for modularity. + +Lizenz: Apache 2.0 (kommerziell nutzbar) +""" + +import re +from typing import Dict, List, Optional, Tuple + +from .core import ( + _SmartSpellCoreBase, + _TOKEN_RE, + CorrectionResult, + Lang, +) + + +class SmartSpellChecker(_SmartSpellCoreBase): + """Language-aware OCR spell checker using pyspellchecker (no LLM). + + Inherits single-word correction from _SmartSpellCoreBase. + Adds text-level passes: boundary repair, context split, full correction. + """ + + # --- Boundary repair (shifted word boundaries) --- + + def _try_boundary_repair(self, word1: str, word2: str) -> Optional[Tuple[str, str]]: + """Fix shifted word boundaries between adjacent tokens. + + OCR sometimes shifts the boundary: "at sth." -> "ats th." + Try moving 1-2 chars from end of word1 to start of word2 and vice versa. + Returns (fixed_word1, fixed_word2) or None. 
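+
+        E.g. ("ats", "th.") -> ("at", "sth."): the trailing "s" of word1
+        moves onto word2 and both results are known ("sth" via the
+        abbreviation list).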
+ """ + # Import known abbreviations for vocabulary context + try: + from cv_ocr_engines import _KNOWN_ABBREVIATIONS + except ImportError: + _KNOWN_ABBREVIATIONS = set() + + # Strip trailing punctuation for checking, preserve for result + w2_stripped = word2.rstrip(".,;:!?") + w2_punct = word2[len(w2_stripped):] + + # Try shifting 1-2 chars from word1 -> word2 + for shift in (1, 2): + if len(word1) <= shift: + continue + new_w1 = word1[:-shift] + new_w2_base = word1[-shift:] + w2_stripped + + w1_ok = self._known(new_w1) or new_w1.lower() in _KNOWN_ABBREVIATIONS + w2_ok = self._known(new_w2_base) or new_w2_base.lower() in _KNOWN_ABBREVIATIONS + + if w1_ok and w2_ok: + return (new_w1, new_w2_base + w2_punct) + + # Try shifting 1-2 chars from word2 -> word1 + for shift in (1, 2): + if len(w2_stripped) <= shift: + continue + new_w1 = word1 + w2_stripped[:shift] + new_w2_base = w2_stripped[shift:] + + w1_ok = self._known(new_w1) or new_w1.lower() in _KNOWN_ABBREVIATIONS + w2_ok = self._known(new_w2_base) or new_w2_base.lower() in _KNOWN_ABBREVIATIONS + + if w1_ok and w2_ok: + return (new_w1, new_w2_base + w2_punct) + + return None + + # --- Context-based word split for ambiguous merges --- + + # Patterns where a valid word is actually "a" + adjective/noun + _ARTICLE_SPLIT_CANDIDATES = { + # word -> (article, remainder) -- only when followed by a compatible word + "anew": ("a", "new"), + "areal": ("a", "real"), + "alive": None, # genuinely one word, never split + "alone": None, + "aware": None, + "alike": None, + "apart": None, + "aside": None, + "above": None, + "about": None, + "among": None, + "along": None, + } + + def _try_context_split(self, word: str, next_word: str, + prev_word: str) -> Optional[str]: + """Split words like 'anew' -> 'a new' when context indicates a merge. + + Only splits when: + - The word is in the split candidates list + - The following word makes sense as a noun (for "a + adj + noun" pattern) + - OR the word is unknown and can be split into article + known word + """ + w_lower = word.lower() + + # Check explicit candidates + if w_lower in self._ARTICLE_SPLIT_CANDIDATES: + split = self._ARTICLE_SPLIT_CANDIDATES[w_lower] + if split is None: + return None # explicitly marked as "don't split" + article, remainder = split + # Only split if followed by a word (noun pattern) + if next_word and next_word[0].islower(): + return f"{article} {remainder}" + # Also split if remainder + next_word makes a common phrase + if next_word and self._known(next_word): + return f"{article} {remainder}" + + # Generic: if word starts with 'a' and rest is a known adjective/word + if (len(word) >= 4 and word[0].lower() == 'a' + and not self._known(word) # only for UNKNOWN words + and self._known(word[1:])): + return f"a {word[1:]}" + + return None + + # --- Full text correction --- + + def correct_text(self, text: str, lang: str = "en") -> CorrectionResult: + """Correct a full text string (field value). + + Three passes: + 1. Boundary repair -- fix shifted word boundaries between adjacent tokens + 2. Context split -- split ambiguous merges (anew -> a new) + 3. Per-word correction -- spell check individual words + """ + if not text or not text.strip(): + return CorrectionResult(text, text, "unknown", False) + + detected = self.detect_text_lang(text) if lang == "auto" else lang + effective_lang = detected if detected in ("en", "de") else "en" + + changes: List[str] = [] + tokens = list(_TOKEN_RE.finditer(text)) + + # Extract token list: [(word, separator), ...] 
+ token_list: List[List[str]] = [] # [[word, sep], ...] + for m in tokens: + token_list.append([m.group(1), m.group(2)]) + + # --- Pass 1: Boundary repair between adjacent unknown words --- + # Import abbreviations for the heuristic below + try: + from cv_ocr_engines import _KNOWN_ABBREVIATIONS as _ABBREVS + except ImportError: + _ABBREVS = set() + + for i in range(len(token_list) - 1): + w1 = token_list[i][0] + w2_raw = token_list[i + 1][0] + + # Skip boundary repair for IPA/bracket content + # Brackets may be in the token OR in the adjacent separators + sep_before_w1 = token_list[i - 1][1] if i > 0 else "" + sep_after_w1 = token_list[i][1] + sep_after_w2 = token_list[i + 1][1] + has_bracket = ( + '[' in w1 or ']' in w1 or '[' in w2_raw or ']' in w2_raw + or ']' in sep_after_w1 # w1 text was inside [brackets] + or '[' in sep_after_w1 # w2 starts a bracket + or ']' in sep_after_w2 # w2 text was inside [brackets] + or '[' in sep_before_w1 # w1 starts a bracket + ) + if has_bracket: + continue + + # Include trailing punct from separator in w2 for abbreviation matching + w2_with_punct = w2_raw + token_list[i + 1][1].rstrip(" ") + + # Try boundary repair -- always, even if both words are valid. + # Use word-frequency scoring to decide if repair is better. + repair = self._try_boundary_repair(w1, w2_with_punct) + if not repair and w2_with_punct != w2_raw: + repair = self._try_boundary_repair(w1, w2_raw) + if repair: + new_w1, new_w2_full = repair + new_w2_base = new_w2_full.rstrip(".,;:!?") + + # Frequency-based scoring: product of word frequencies + # Higher product = more common word pair = better + old_freq = self._word_freq(w1) * self._word_freq(w2_raw) + new_freq = self._word_freq(new_w1) * self._word_freq(new_w2_base) + + # Abbreviation bonus: if repair produces a known abbreviation + has_abbrev = new_w1.lower() in _ABBREVS or new_w2_base.lower() in _ABBREVS + if has_abbrev: + # Accept abbreviation repair ONLY if at least one of the + # original words is rare/unknown (prevents "Can I" -> "Ca nI" + # where both original words are common and correct). 
+ RARE_THRESHOLD = 1e-6 + orig_both_common = ( + self._word_freq(w1) > RARE_THRESHOLD + and self._word_freq(w2_raw) > RARE_THRESHOLD + ) + if not orig_both_common: + new_freq = max(new_freq, old_freq * 10) + else: + has_abbrev = False # both originals common -> don't trust + + # Accept if repair produces a more frequent word pair + # (threshold: at least 5x more frequent to avoid false positives) + if new_freq > old_freq * 5: + new_w2_punct = new_w2_full[len(new_w2_base):] + changes.append(f"{w1} {w2_raw}\u2192{new_w1} {new_w2_base}") + token_list[i][0] = new_w1 + token_list[i + 1][0] = new_w2_base + if new_w2_punct: + token_list[i + 1][1] = new_w2_punct + token_list[i + 1][1].lstrip(".,;:!?") + + # --- Pass 2: Context split (anew -> a new) --- + expanded: List[List[str]] = [] + for i, (word, sep) in enumerate(token_list): + next_word = token_list[i + 1][0] if i + 1 < len(token_list) else "" + prev_word = token_list[i - 1][0] if i > 0 else "" + split = self._try_context_split(word, next_word, prev_word) + if split and split != word: + changes.append(f"{word}\u2192{split}") + expanded.append([split, sep]) + else: + expanded.append([word, sep]) + token_list = expanded + + # --- Pass 3: Per-word correction --- + parts: List[str] = [] + + # Preserve any leading text before the first token match + first_start = tokens[0].start() if tokens else 0 + if first_start > 0: + parts.append(text[:first_start]) + + for i, (word, sep) in enumerate(token_list): + # Skip words inside IPA brackets (brackets land in separators) + prev_sep = token_list[i - 1][1] if i > 0 else "" + if '[' in prev_sep or ']' in sep: + parts.append(word) + parts.append(sep) + continue + + next_word = token_list[i + 1][0] if i + 1 < len(token_list) else "" + prev_word = token_list[i - 1][0] if i > 0 else "" + + correction = self.correct_word( + word, lang=effective_lang, + prev_word=prev_word, next_word=next_word, + ) + if correction and correction != word: + changes.append(f"{word}\u2192{correction}") + parts.append(correction) + else: + parts.append(word) + parts.append(sep) + + # Append any trailing text + last_end = tokens[-1].end() if tokens else 0 + if last_end < len(text): + parts.append(text[last_end:]) + + corrected = "".join(parts) + return CorrectionResult( + original=text, + corrected=corrected, + lang_detected=detected, + changed=corrected != text, + changes=changes, + ) + + # --- Vocabulary entry correction --- + + def correct_vocab_entry(self, english: str, german: str, + example: str = "") -> Dict[str, CorrectionResult]: + """Correct a full vocabulary entry (EN + DE + example). + + Uses column position to determine language -- the most reliable signal. + """ + results = {} + results["english"] = self.correct_text(english, lang="en") + results["german"] = self.correct_text(german, lang="de") + if example: + # For examples, auto-detect language + results["example"] = self.correct_text(example, lang="auto") + return results diff --git a/klausur-service/backend/smart_spell.py b/klausur-service/backend/smart_spell.py index 1926500..6ad63f0 100644 --- a/klausur-service/backend/smart_spell.py +++ b/klausur-service/backend/smart_spell.py @@ -1,25 +1,4 @@ -""" -SmartSpellChecker — barrel re-export. 
- -All implementation split into: - smart_spell_core — init, data types, language detection, word correction - smart_spell_text — full text correction, boundary repair, context split - -Lizenz: Apache 2.0 (kommerziell nutzbar) -""" - -# Core: data types, lang detection (re-exported for tests) -from smart_spell_core import ( # noqa: F401 - _AVAILABLE, - _DIGIT_SUBS, - _SUSPICIOUS_CHARS, - _UMLAUT_MAP, - _TOKEN_RE, - _I_FOLLOWERS, - _A_FOLLOWERS, - CorrectionResult, - Lang, -) - -# Text: SmartSpellChecker class (the main public API) -from smart_spell_text import SmartSpellChecker # noqa: F401 +# Backward-compat shim -- module moved to ocr/spell/smart_spell.py +import importlib as _importlib +import sys as _sys +_sys.modules[__name__] = _importlib.import_module("ocr.spell.smart_spell") diff --git a/klausur-service/backend/smart_spell_core.py b/klausur-service/backend/smart_spell_core.py index 9f2fa7d..0082387 100644 --- a/klausur-service/backend/smart_spell_core.py +++ b/klausur-service/backend/smart_spell_core.py @@ -1,298 +1,4 @@ -""" -SmartSpellChecker Core — init, data types, language detection, word correction. - -Extracted from smart_spell.py for modularity. - -Lizenz: Apache 2.0 (kommerziell nutzbar) -""" - -import logging -import re -from dataclasses import dataclass, field -from typing import Dict, List, Literal, Optional, Set, Tuple - -logger = logging.getLogger(__name__) - -# --------------------------------------------------------------------------- -# Init -# --------------------------------------------------------------------------- - -try: - from spellchecker import SpellChecker as _SpellChecker - _en_spell = _SpellChecker(language='en', distance=1) - _de_spell = _SpellChecker(language='de', distance=1) - _AVAILABLE = True -except ImportError: - _AVAILABLE = False - logger.warning("pyspellchecker not installed — SmartSpellChecker disabled") - -Lang = Literal["en", "de", "both", "unknown"] - -# --------------------------------------------------------------------------- -# Bigram context for a/I disambiguation -# --------------------------------------------------------------------------- - -# Words that commonly follow "I" (subject pronoun -> verb/modal) -_I_FOLLOWERS: frozenset = frozenset({ - "am", "was", "have", "had", "do", "did", "will", "would", "can", - "could", "should", "shall", "may", "might", "must", - "think", "know", "see", "want", "need", "like", "love", "hate", - "go", "went", "come", "came", "say", "said", "get", "got", - "make", "made", "take", "took", "give", "gave", "tell", "told", - "feel", "felt", "find", "found", "believe", "hope", "wish", - "remember", "forget", "understand", "mean", "meant", - "don't", "didn't", "can't", "won't", "couldn't", "wouldn't", - "shouldn't", "haven't", "hadn't", "isn't", "wasn't", - "really", "just", "also", "always", "never", "often", "sometimes", -}) - -# Words that commonly follow "a" (article -> noun/adjective) -_A_FOLLOWERS: frozenset = frozenset({ - "lot", "few", "little", "bit", "good", "bad", "great", "new", "old", - "long", "short", "big", "small", "large", "huge", "tiny", - "nice", "beautiful", "wonderful", "terrible", "horrible", - "man", "woman", "boy", "girl", "child", "dog", "cat", "bird", - "book", "car", "house", "room", "school", "teacher", "student", - "day", "week", "month", "year", "time", "place", "way", - "friend", "family", "person", "problem", "question", "story", - "very", "really", "quite", "rather", "pretty", "single", -}) - -# Digit->letter substitutions (OCR confusion) -_DIGIT_SUBS: Dict[str, 
List[str]] = { - '0': ['o', 'O'], - '1': ['l', 'I'], - '5': ['s', 'S'], - '6': ['g', 'G'], - '8': ['b', 'B'], - '|': ['I', 'l'], - '/': ['l'], # italic 'l' misread as slash (e.g. "p/" -> "pl") -} -_SUSPICIOUS_CHARS = frozenset(_DIGIT_SUBS.keys()) - -# Umlaut confusion: OCR drops dots (u->u, a->a, o->o) -_UMLAUT_MAP = { - 'a': '\u00e4', 'o': '\u00f6', 'u': '\u00fc', 'i': '\u00fc', - 'A': '\u00c4', 'O': '\u00d6', 'U': '\u00dc', 'I': '\u00dc', -} - -# Tokenizer -- includes | and / so OCR artifacts like "p/" are treated as words -_TOKEN_RE = re.compile(r"([A-Za-z\u00c4\u00d6\u00dc\u00e4\u00f6\u00fc\u00df'|/]+)([^A-Za-z\u00c4\u00d6\u00dc\u00e4\u00f6\u00fc\u00df'|/]*)") - - -# --------------------------------------------------------------------------- -# Data types -# --------------------------------------------------------------------------- - -@dataclass -class CorrectionResult: - original: str - corrected: str - lang_detected: Lang - changed: bool - changes: List[str] = field(default_factory=list) - - -# --------------------------------------------------------------------------- -# Core class — language detection and word-level correction -# --------------------------------------------------------------------------- - -class _SmartSpellCoreBase: - """Base class with language detection and single-word correction. - - Not intended for direct use — SmartSpellChecker inherits from this. - """ - - def __init__(self): - if not _AVAILABLE: - raise RuntimeError("pyspellchecker not installed") - self.en = _en_spell - self.de = _de_spell - - # --- Language detection --- - - def detect_word_lang(self, word: str) -> Lang: - """Detect language of a single word using dual-dict heuristic.""" - w = word.lower().strip(".,;:!?\"'()") - if not w: - return "unknown" - in_en = bool(self.en.known([w])) - in_de = bool(self.de.known([w])) - if in_en and in_de: - return "both" - if in_en: - return "en" - if in_de: - return "de" - return "unknown" - - def detect_text_lang(self, text: str) -> Lang: - """Detect dominant language of a text string (sentence/phrase).""" - words = re.findall(r"[A-Za-z\u00c4\u00d6\u00dc\u00e4\u00f6\u00fc\u00df]+", text) - if not words: - return "unknown" - - en_count = 0 - de_count = 0 - for w in words: - lang = self.detect_word_lang(w) - if lang == "en": - en_count += 1 - elif lang == "de": - de_count += 1 - # "both" doesn't count for either - - if en_count > de_count: - return "en" - if de_count > en_count: - return "de" - if en_count == de_count and en_count > 0: - return "both" - return "unknown" - - # --- Single-word correction --- - - def _known(self, word: str) -> bool: - """True if word is known in EN or DE dictionary, or is a known abbreviation.""" - w = word.lower() - if bool(self.en.known([w])) or bool(self.de.known([w])): - return True - # Also accept known abbreviations (sth, sb, adj, etc.) 
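        # (editor note, not part of the patch) e.g. _known("sth") is True via the
        # abbreviation set even though neither spell dictionary lists "sth"; the
        # optional import below degrades gracefully when cv_ocr_engines is absent.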
- try: - from cv_ocr_engines import _KNOWN_ABBREVIATIONS - if w in _KNOWN_ABBREVIATIONS: - return True - except ImportError: - pass - return False - - def _word_freq(self, word: str) -> float: - """Get word frequency (max of EN and DE).""" - w = word.lower() - return max(self.en.word_usage_frequency(w), self.de.word_usage_frequency(w)) - - def _known_in(self, word: str, lang: str) -> bool: - """True if word is known in a specific language dictionary.""" - w = word.lower() - spell = self.en if lang == "en" else self.de - return bool(spell.known([w])) - - def correct_word(self, word: str, lang: str = "en", - prev_word: str = "", next_word: str = "") -> Optional[str]: - """Correct a single word for the given language. - - Returns None if no correction needed, or the corrected string. - """ - if not word or not word.strip(): - return None - - # Skip numbers, abbreviations with dots, very short tokens - if word.isdigit() or '.' in word: - return None - - # Skip IPA/phonetic content in brackets - if '[' in word or ']' in word: - return None - - has_suspicious = any(ch in _SUSPICIOUS_CHARS for ch in word) - - # 1. Already known -> no fix - if self._known(word): - # But check a/I disambiguation for single-char words - if word.lower() in ('l', '|') and next_word: - return self._disambiguate_a_I(word, next_word) - return None - - # 2. Digit/pipe substitution - if has_suspicious: - if word == '|': - return 'I' - # Try single-char substitutions - for i, ch in enumerate(word): - if ch not in _DIGIT_SUBS: - continue - for replacement in _DIGIT_SUBS[ch]: - candidate = word[:i] + replacement + word[i + 1:] - if self._known(candidate): - return candidate - # Try multi-char substitution (e.g., "sch00l" -> "school") - multi = self._try_multi_digit_sub(word) - if multi: - return multi - - # 3. Umlaut correction (German) - if lang == "de" and len(word) >= 3 and word.isalpha(): - umlaut_fix = self._try_umlaut_fix(word) - if umlaut_fix: - return umlaut_fix - - # 4. 
General spell correction - if not has_suspicious and len(word) >= 3 and word.isalpha(): - # Safety: don't correct if the word is valid in the OTHER language - other_lang = "de" if lang == "en" else "en" - if self._known_in(word, other_lang): - return None - if other_lang == "de" and self._try_umlaut_fix(word): - return None # has a valid DE umlaut variant -> don't touch - - spell = self.en if lang == "en" else self.de - correction = spell.correction(word.lower()) - if correction and correction != word.lower(): - if word[0].isupper(): - correction = correction[0].upper() + correction[1:] - if self._known(correction): - return correction - - return None - - # --- Multi-digit substitution --- - - def _try_multi_digit_sub(self, word: str) -> Optional[str]: - """Try replacing multiple digits simultaneously using BFS.""" - positions = [(i, ch) for i, ch in enumerate(word) if ch in _DIGIT_SUBS] - if not positions or len(positions) > 4: - return None - - # BFS over substitution combinations - queue = [list(word)] - for pos, ch in positions: - next_queue = [] - for current in queue: - # Keep original - next_queue.append(current[:]) - # Try each substitution - for repl in _DIGIT_SUBS[ch]: - variant = current[:] - variant[pos] = repl - next_queue.append(variant) - queue = next_queue - - # Check which combinations produce known words - for combo in queue: - candidate = "".join(combo) - if candidate != word and self._known(candidate): - return candidate - - return None - - # --- Umlaut fix --- - - def _try_umlaut_fix(self, word: str) -> Optional[str]: - """Try single-char umlaut substitutions for German words.""" - for i, ch in enumerate(word): - if ch in _UMLAUT_MAP: - candidate = word[:i] + _UMLAUT_MAP[ch] + word[i + 1:] - if self._known(candidate): - return candidate - return None - - # --- a/I disambiguation --- - - def _disambiguate_a_I(self, token: str, next_word: str) -> Optional[str]: - """Disambiguate 'a' vs 'I' (and OCR variants like 'l', '|').""" - nw = next_word.lower().strip(".,;:!?") - if nw in _I_FOLLOWERS: - return "I" - if nw in _A_FOLLOWERS: - return "a" - return None # uncertain, don't change +# Backward-compat shim -- module moved to ocr/spell/core.py +import importlib as _importlib +import sys as _sys +_sys.modules[__name__] = _importlib.import_module("ocr.spell.core") diff --git a/klausur-service/backend/smart_spell_text.py b/klausur-service/backend/smart_spell_text.py index 7628e61..c3dc96f 100644 --- a/klausur-service/backend/smart_spell_text.py +++ b/klausur-service/backend/smart_spell_text.py @@ -1,289 +1,4 @@ -""" -SmartSpellChecker Text — full text correction, boundary repair, context split. - -Extracted from smart_spell.py for modularity. - -Lizenz: Apache 2.0 (kommerziell nutzbar) -""" - -import re -from typing import Dict, List, Optional, Tuple - -from smart_spell_core import ( - _SmartSpellCoreBase, - _TOKEN_RE, - CorrectionResult, - Lang, -) - - -class SmartSpellChecker(_SmartSpellCoreBase): - """Language-aware OCR spell checker using pyspellchecker (no LLM). - - Inherits single-word correction from _SmartSpellCoreBase. - Adds text-level passes: boundary repair, context split, full correction. - """ - - # --- Boundary repair (shifted word boundaries) --- - - def _try_boundary_repair(self, word1: str, word2: str) -> Optional[Tuple[str, str]]: - """Fix shifted word boundaries between adjacent tokens. - - OCR sometimes shifts the boundary: "at sth." -> "ats th." - Try moving 1-2 chars from end of word1 to start of word2 and vice versa. 
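        Worked example (editor sketch): for ("ats", "th."), the shifted form of
        "at sth." from above, moving the trailing "s" of word1 onto word2
        yields ("at", "sth."): "at" is a dictionary word and "sth" is a known
        abbreviation, so both checks pass.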
- Returns (fixed_word1, fixed_word2) or None. - """ - # Import known abbreviations for vocabulary context - try: - from cv_ocr_engines import _KNOWN_ABBREVIATIONS - except ImportError: - _KNOWN_ABBREVIATIONS = set() - - # Strip trailing punctuation for checking, preserve for result - w2_stripped = word2.rstrip(".,;:!?") - w2_punct = word2[len(w2_stripped):] - - # Try shifting 1-2 chars from word1 -> word2 - for shift in (1, 2): - if len(word1) <= shift: - continue - new_w1 = word1[:-shift] - new_w2_base = word1[-shift:] + w2_stripped - - w1_ok = self._known(new_w1) or new_w1.lower() in _KNOWN_ABBREVIATIONS - w2_ok = self._known(new_w2_base) or new_w2_base.lower() in _KNOWN_ABBREVIATIONS - - if w1_ok and w2_ok: - return (new_w1, new_w2_base + w2_punct) - - # Try shifting 1-2 chars from word2 -> word1 - for shift in (1, 2): - if len(w2_stripped) <= shift: - continue - new_w1 = word1 + w2_stripped[:shift] - new_w2_base = w2_stripped[shift:] - - w1_ok = self._known(new_w1) or new_w1.lower() in _KNOWN_ABBREVIATIONS - w2_ok = self._known(new_w2_base) or new_w2_base.lower() in _KNOWN_ABBREVIATIONS - - if w1_ok and w2_ok: - return (new_w1, new_w2_base + w2_punct) - - return None - - # --- Context-based word split for ambiguous merges --- - - # Patterns where a valid word is actually "a" + adjective/noun - _ARTICLE_SPLIT_CANDIDATES = { - # word -> (article, remainder) -- only when followed by a compatible word - "anew": ("a", "new"), - "areal": ("a", "real"), - "alive": None, # genuinely one word, never split - "alone": None, - "aware": None, - "alike": None, - "apart": None, - "aside": None, - "above": None, - "about": None, - "among": None, - "along": None, - } - - def _try_context_split(self, word: str, next_word: str, - prev_word: str) -> Optional[str]: - """Split words like 'anew' -> 'a new' when context indicates a merge. - - Only splits when: - - The word is in the split candidates list - - The following word makes sense as a noun (for "a + adj + noun" pattern) - - OR the word is unknown and can be split into article + known word - """ - w_lower = word.lower() - - # Check explicit candidates - if w_lower in self._ARTICLE_SPLIT_CANDIDATES: - split = self._ARTICLE_SPLIT_CANDIDATES[w_lower] - if split is None: - return None # explicitly marked as "don't split" - article, remainder = split - # Only split if followed by a word (noun pattern) - if next_word and next_word[0].islower(): - return f"{article} {remainder}" - # Also split if remainder + next_word makes a common phrase - if next_word and self._known(next_word): - return f"{article} {remainder}" - - # Generic: if word starts with 'a' and rest is a known adjective/word - if (len(word) >= 4 and word[0].lower() == 'a' - and not self._known(word) # only for UNKNOWN words - and self._known(word[1:])): - return f"a {word[1:]}" - - return None - - # --- Full text correction --- - - def correct_text(self, text: str, lang: str = "en") -> CorrectionResult: - """Correct a full text string (field value). - - Three passes: - 1. Boundary repair -- fix shifted word boundaries between adjacent tokens - 2. Context split -- split ambiguous merges (anew -> a new) - 3. 
Per-word correction -- spell check individual words - """ - if not text or not text.strip(): - return CorrectionResult(text, text, "unknown", False) - - detected = self.detect_text_lang(text) if lang == "auto" else lang - effective_lang = detected if detected in ("en", "de") else "en" - - changes: List[str] = [] - tokens = list(_TOKEN_RE.finditer(text)) - - # Extract token list: [(word, separator), ...] - token_list: List[List[str]] = [] # [[word, sep], ...] - for m in tokens: - token_list.append([m.group(1), m.group(2)]) - - # --- Pass 1: Boundary repair between adjacent unknown words --- - # Import abbreviations for the heuristic below - try: - from cv_ocr_engines import _KNOWN_ABBREVIATIONS as _ABBREVS - except ImportError: - _ABBREVS = set() - - for i in range(len(token_list) - 1): - w1 = token_list[i][0] - w2_raw = token_list[i + 1][0] - - # Skip boundary repair for IPA/bracket content - # Brackets may be in the token OR in the adjacent separators - sep_before_w1 = token_list[i - 1][1] if i > 0 else "" - sep_after_w1 = token_list[i][1] - sep_after_w2 = token_list[i + 1][1] - has_bracket = ( - '[' in w1 or ']' in w1 or '[' in w2_raw or ']' in w2_raw - or ']' in sep_after_w1 # w1 text was inside [brackets] - or '[' in sep_after_w1 # w2 starts a bracket - or ']' in sep_after_w2 # w2 text was inside [brackets] - or '[' in sep_before_w1 # w1 starts a bracket - ) - if has_bracket: - continue - - # Include trailing punct from separator in w2 for abbreviation matching - w2_with_punct = w2_raw + token_list[i + 1][1].rstrip(" ") - - # Try boundary repair -- always, even if both words are valid. - # Use word-frequency scoring to decide if repair is better. - repair = self._try_boundary_repair(w1, w2_with_punct) - if not repair and w2_with_punct != w2_raw: - repair = self._try_boundary_repair(w1, w2_raw) - if repair: - new_w1, new_w2_full = repair - new_w2_base = new_w2_full.rstrip(".,;:!?") - - # Frequency-based scoring: product of word frequencies - # Higher product = more common word pair = better - old_freq = self._word_freq(w1) * self._word_freq(w2_raw) - new_freq = self._word_freq(new_w1) * self._word_freq(new_w2_base) - - # Abbreviation bonus: if repair produces a known abbreviation - has_abbrev = new_w1.lower() in _ABBREVS or new_w2_base.lower() in _ABBREVS - if has_abbrev: - # Accept abbreviation repair ONLY if at least one of the - # original words is rare/unknown (prevents "Can I" -> "Ca nI" - # where both original words are common and correct). 
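                    # (editor note) e.g. for ("Can", "I") both frequencies exceed
                    # RARE_THRESHOLD, so the 10x bonus is withheld and the bogus
                    # repair "Ca nI" fails the 5x frequency test below.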
- RARE_THRESHOLD = 1e-6 - orig_both_common = ( - self._word_freq(w1) > RARE_THRESHOLD - and self._word_freq(w2_raw) > RARE_THRESHOLD - ) - if not orig_both_common: - new_freq = max(new_freq, old_freq * 10) - else: - has_abbrev = False # both originals common -> don't trust - - # Accept if repair produces a more frequent word pair - # (threshold: at least 5x more frequent to avoid false positives) - if new_freq > old_freq * 5: - new_w2_punct = new_w2_full[len(new_w2_base):] - changes.append(f"{w1} {w2_raw}\u2192{new_w1} {new_w2_base}") - token_list[i][0] = new_w1 - token_list[i + 1][0] = new_w2_base - if new_w2_punct: - token_list[i + 1][1] = new_w2_punct + token_list[i + 1][1].lstrip(".,;:!?") - - # --- Pass 2: Context split (anew -> a new) --- - expanded: List[List[str]] = [] - for i, (word, sep) in enumerate(token_list): - next_word = token_list[i + 1][0] if i + 1 < len(token_list) else "" - prev_word = token_list[i - 1][0] if i > 0 else "" - split = self._try_context_split(word, next_word, prev_word) - if split and split != word: - changes.append(f"{word}\u2192{split}") - expanded.append([split, sep]) - else: - expanded.append([word, sep]) - token_list = expanded - - # --- Pass 3: Per-word correction --- - parts: List[str] = [] - - # Preserve any leading text before the first token match - first_start = tokens[0].start() if tokens else 0 - if first_start > 0: - parts.append(text[:first_start]) - - for i, (word, sep) in enumerate(token_list): - # Skip words inside IPA brackets (brackets land in separators) - prev_sep = token_list[i - 1][1] if i > 0 else "" - if '[' in prev_sep or ']' in sep: - parts.append(word) - parts.append(sep) - continue - - next_word = token_list[i + 1][0] if i + 1 < len(token_list) else "" - prev_word = token_list[i - 1][0] if i > 0 else "" - - correction = self.correct_word( - word, lang=effective_lang, - prev_word=prev_word, next_word=next_word, - ) - if correction and correction != word: - changes.append(f"{word}\u2192{correction}") - parts.append(correction) - else: - parts.append(word) - parts.append(sep) - - # Append any trailing text - last_end = tokens[-1].end() if tokens else 0 - if last_end < len(text): - parts.append(text[last_end:]) - - corrected = "".join(parts) - return CorrectionResult( - original=text, - corrected=corrected, - lang_detected=detected, - changed=corrected != text, - changes=changes, - ) - - # --- Vocabulary entry correction --- - - def correct_vocab_entry(self, english: str, german: str, - example: str = "") -> Dict[str, CorrectionResult]: - """Correct a full vocabulary entry (EN + DE + example). - - Uses column position to determine language -- the most reliable signal. - """ - results = {} - results["english"] = self.correct_text(english, lang="en") - results["german"] = self.correct_text(german, lang="de") - if example: - # For examples, auto-detect language - results["example"] = self.correct_text(example, lang="auto") - return results +# Backward-compat shim -- module moved to ocr/spell/text.py +import importlib as _importlib +import sys as _sys +_sys.modules[__name__] = _importlib.import_module("ocr.spell.text") diff --git a/klausur-service/backend/tesseract_vocab_extractor.py b/klausur-service/backend/tesseract_vocab_extractor.py index 23ac32e..fc63392 100644 --- a/klausur-service/backend/tesseract_vocab_extractor.py +++ b/klausur-service/backend/tesseract_vocab_extractor.py @@ -1,346 +1,4 @@ -""" -Tesseract-based OCR extraction with word-level bounding boxes. 
- -Uses Tesseract for spatial information (WHERE text is) while -the Vision LLM handles semantic understanding (WHAT the text means). - -Tesseract runs natively on ARM64 via Debian's apt package. - -Lizenz: Apache 2.0 (kommerziell nutzbar) -""" - -import io -import logging -from typing import List, Dict, Any, Optional -from difflib import SequenceMatcher - -logger = logging.getLogger(__name__) - -try: - import pytesseract - from PIL import Image - TESSERACT_AVAILABLE = True -except ImportError: - TESSERACT_AVAILABLE = False - logger.warning("pytesseract or Pillow not installed - Tesseract OCR unavailable") - - -async def extract_bounding_boxes(image_bytes: bytes, lang: str = "eng+deu") -> dict: - """Run Tesseract OCR and return word-level bounding boxes. - - Args: - image_bytes: PNG/JPEG image as bytes. - lang: Tesseract language string (e.g. "eng+deu"). - - Returns: - Dict with 'words' list and 'image_width'/'image_height'. - """ - if not TESSERACT_AVAILABLE: - return {"words": [], "image_width": 0, "image_height": 0, "error": "Tesseract not available"} - - image = Image.open(io.BytesIO(image_bytes)) - data = pytesseract.image_to_data(image, lang=lang, output_type=pytesseract.Output.DICT) - - words = [] - for i in range(len(data['text'])): - text = data['text'][i].strip() - conf = int(data['conf'][i]) - if not text or conf < 20: - continue - words.append({ - "text": text, - "left": data['left'][i], - "top": data['top'][i], - "width": data['width'][i], - "height": data['height'][i], - "conf": conf, - "block_num": data['block_num'][i], - "par_num": data['par_num'][i], - "line_num": data['line_num'][i], - "word_num": data['word_num'][i], - }) - - return { - "words": words, - "image_width": image.width, - "image_height": image.height, - } - - -def group_words_into_lines(words: List[dict], y_tolerance_px: int = 15) -> List[List[dict]]: - """Group words by their Y position into lines. - - Args: - words: List of word dicts from extract_bounding_boxes. - y_tolerance_px: Max pixel distance to consider words on the same line. - - Returns: - List of lines, each line is a list of words sorted by X position. - """ - if not words: - return [] - - # Sort by Y then X - sorted_words = sorted(words, key=lambda w: (w['top'], w['left'])) - - lines: List[List[dict]] = [] - current_line: List[dict] = [sorted_words[0]] - current_y = sorted_words[0]['top'] - - for word in sorted_words[1:]: - if abs(word['top'] - current_y) <= y_tolerance_px: - current_line.append(word) - else: - current_line.sort(key=lambda w: w['left']) - lines.append(current_line) - current_line = [word] - current_y = word['top'] - - if current_line: - current_line.sort(key=lambda w: w['left']) - lines.append(current_line) - - return lines - - -def detect_columns(lines: List[List[dict]], image_width: int) -> Dict[str, Any]: - """Detect column boundaries from word positions. - - Typical vocab table: Left=English, Middle=German, Right=Example sentences. - - Returns: - Dict with column boundaries and type assignments. 
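        Example (editor sketch, illustrative values): a three-column vocab
        sheet on a 1600 px wide page might yield
            {"columns": [{"x_start": 80, ...}, {"x_start": 620, ...},
                         {"x_start": 1150, ...}],
             "column_types": ["english", "german", "example"]}
        since each gap exceeds min_gap = 8% of the page width (128 px).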
- """ - if not lines or image_width == 0: - return {"columns": [], "column_types": []} - - # Collect all word X positions - all_x_positions = [] - for line in lines: - for word in line: - all_x_positions.append(word['left']) - - if not all_x_positions: - return {"columns": [], "column_types": []} - - # Find X-position clusters (column starts) - all_x_positions.sort() - - # Simple gap-based column detection - min_gap = image_width * 0.08 # 8% of page width = column gap - clusters = [] - current_cluster = [all_x_positions[0]] - - for x in all_x_positions[1:]: - if x - current_cluster[-1] > min_gap: - clusters.append(current_cluster) - current_cluster = [x] - else: - current_cluster.append(x) - - if current_cluster: - clusters.append(current_cluster) - - # Each cluster represents a column start - columns = [] - for cluster in clusters: - col_start = min(cluster) - columns.append({ - "x_start": col_start, - "x_start_pct": col_start / image_width * 100, - "word_count": len(cluster), - }) - - # Assign column types based on position (left→right: EN, DE, Example) - type_map = ["english", "german", "example"] - column_types = [] - for i, col in enumerate(columns): - if i < len(type_map): - column_types.append(type_map[i]) - else: - column_types.append("unknown") - - return { - "columns": columns, - "column_types": column_types, - } - - -def words_to_vocab_entries(lines: List[List[dict]], columns: List[dict], - column_types: List[str], image_width: int, - image_height: int) -> List[dict]: - """Convert grouped words into vocabulary entries using column positions. - - Args: - lines: Grouped word lines from group_words_into_lines. - columns: Column boundaries from detect_columns. - column_types: Column type assignments. - image_width: Image width in pixels. - image_height: Image height in pixels. - - Returns: - List of vocabulary entry dicts with english/german/example fields. 
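        Example entry (editor sketch, illustrative values):
            {"english": "house", "german": "Haus", "example": "",
             "english_bbox": {"x_pct": 5.0, "y_pct": 12.3, "w_pct": 8.1, "h_pct": 2.0},
             "german_bbox": {...}}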
- """ - if not columns or not lines: - return [] - - # Build column boundaries for word assignment - col_boundaries = [] - for i, col in enumerate(columns): - start = col['x_start'] - if i + 1 < len(columns): - end = columns[i + 1]['x_start'] - else: - end = image_width - col_boundaries.append((start, end, column_types[i] if i < len(column_types) else "unknown")) - - entries = [] - for line in lines: - entry = {"english": "", "german": "", "example": ""} - line_words_by_col: Dict[str, List[str]] = {"english": [], "german": [], "example": []} - line_bbox: Dict[str, Optional[dict]] = {} - - for word in line: - word_center_x = word['left'] + word['width'] / 2 - assigned_type = "unknown" - for start, end, col_type in col_boundaries: - if start <= word_center_x < end: - assigned_type = col_type - break - - if assigned_type in line_words_by_col: - line_words_by_col[assigned_type].append(word['text']) - # Track bounding box for the column - if assigned_type not in line_bbox or line_bbox[assigned_type] is None: - line_bbox[assigned_type] = { - "left": word['left'], - "top": word['top'], - "right": word['left'] + word['width'], - "bottom": word['top'] + word['height'], - } - else: - bb = line_bbox[assigned_type] - bb['left'] = min(bb['left'], word['left']) - bb['top'] = min(bb['top'], word['top']) - bb['right'] = max(bb['right'], word['left'] + word['width']) - bb['bottom'] = max(bb['bottom'], word['top'] + word['height']) - - for col_type in ["english", "german", "example"]: - if line_words_by_col[col_type]: - entry[col_type] = " ".join(line_words_by_col[col_type]) - if line_bbox.get(col_type): - bb = line_bbox[col_type] - entry[f"{col_type}_bbox"] = { - "x_pct": bb['left'] / image_width * 100, - "y_pct": bb['top'] / image_height * 100, - "w_pct": (bb['right'] - bb['left']) / image_width * 100, - "h_pct": (bb['bottom'] - bb['top']) / image_height * 100, - } - - # Only add if at least one column has content - if entry["english"] or entry["german"]: - entries.append(entry) - - return entries - - -def match_positions_to_vocab(tess_words: List[dict], llm_vocab: List[dict], - image_w: int, image_h: int, - threshold: float = 0.6) -> List[dict]: - """Match Tesseract bounding boxes to LLM vocabulary entries. - - For each LLM vocab entry, find the best-matching Tesseract word - and attach its bounding box coordinates. - - Args: - tess_words: Word list from Tesseract with pixel coordinates. - llm_vocab: Vocabulary list from Vision LLM. - image_w: Image width in pixels. - image_h: Image height in pixels. - threshold: Minimum similarity ratio for a match. - - Returns: - llm_vocab list with bbox_x_pct/bbox_y_pct/bbox_w_pct/bbox_h_pct added. 
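        Example (editor sketch): for entry {"english": "house"} and a Tesseract
        word "hause", SequenceMatcher gives ratio 0.8 >= 0.6, so that word's box
        is attached with bbox_match_field="english" and bbox_match_ratio=0.8.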
- """ - if not tess_words or not llm_vocab or image_w == 0 or image_h == 0: - return llm_vocab - - for entry in llm_vocab: - english = entry.get("english", "").lower().strip() - german = entry.get("german", "").lower().strip() - - if not english and not german: - continue - - # Try to match English word first, then German - for field in ["english", "german"]: - search_text = entry.get(field, "").lower().strip() - if not search_text: - continue - - best_word = None - best_ratio = 0.0 - - for word in tess_words: - ratio = SequenceMatcher(None, search_text, word['text'].lower()).ratio() - if ratio > best_ratio: - best_ratio = ratio - best_word = word - - if best_word and best_ratio >= threshold: - entry[f"bbox_x_pct"] = best_word['left'] / image_w * 100 - entry[f"bbox_y_pct"] = best_word['top'] / image_h * 100 - entry[f"bbox_w_pct"] = best_word['width'] / image_w * 100 - entry[f"bbox_h_pct"] = best_word['height'] / image_h * 100 - entry["bbox_match_field"] = field - entry["bbox_match_ratio"] = round(best_ratio, 3) - break # Found a match, no need to try the other field - - return llm_vocab - - -async def run_tesseract_pipeline(image_bytes: bytes, lang: str = "eng+deu") -> dict: - """Full Tesseract pipeline: extract words, group lines, detect columns, build vocab. - - Args: - image_bytes: PNG/JPEG image as bytes. - lang: Tesseract language string. - - Returns: - Dict with 'vocabulary', 'words', 'lines', 'columns', 'image_width', 'image_height'. - """ - # Step 1: Extract bounding boxes - bbox_data = await extract_bounding_boxes(image_bytes, lang=lang) - - if bbox_data.get("error"): - return bbox_data - - words = bbox_data["words"] - image_w = bbox_data["image_width"] - image_h = bbox_data["image_height"] - - # Step 2: Group into lines - lines = group_words_into_lines(words) - - # Step 3: Detect columns - col_info = detect_columns(lines, image_w) - - # Step 4: Build vocabulary entries - vocab = words_to_vocab_entries( - lines, - col_info["columns"], - col_info["column_types"], - image_w, - image_h, - ) - - return { - "vocabulary": vocab, - "words": words, - "lines_count": len(lines), - "columns": col_info["columns"], - "column_types": col_info["column_types"], - "image_width": image_w, - "image_height": image_h, - "word_count": len(words), - } +# Backward-compat shim -- module moved to ocr/engines/tesseract_extractor.py +import importlib as _importlib +import sys as _sys +_sys.modules[__name__] = _importlib.import_module("ocr.engines.tesseract_extractor") diff --git a/klausur-service/backend/unified_grid.py b/klausur-service/backend/unified_grid.py index 48e9fa4..e0aa370 100644 --- a/klausur-service/backend/unified_grid.py +++ b/klausur-service/backend/unified_grid.py @@ -1,425 +1,4 @@ -""" -Unified Grid Builder — merges multi-zone grid into a single Excel-like grid. 
- -Takes content zone + box zones and produces one unified zone where: -- All content rows use the dominant row height -- Full-width boxes are integrated directly (box rows replace standard rows) -- Partial-width boxes: extra rows inserted if box has more lines than standard -- Box-origin cells carry metadata (bg_color, border) for visual distinction - -The result is a single-zone StructuredGrid that can be: -- Rendered in an Excel-like editor -- Exported to Excel/CSV -- Edited with unified row/column numbering -""" - -import logging -import math -import statistics -from typing import Any, Dict, List, Optional, Tuple - -logger = logging.getLogger(__name__) - - -def _compute_dominant_row_height(content_zone: Dict) -> float: - """Median of content row-to-row spacings, excluding box-gap jumps.""" - rows = content_zone.get("rows", []) - if len(rows) < 2: - return 47.0 - - spacings = [] - for i in range(len(rows) - 1): - y1 = rows[i].get("y_min_px", rows[i].get("y_min", 0)) - y2 = rows[i + 1].get("y_min_px", rows[i + 1].get("y_min", 0)) - d = y2 - y1 - if 0 < d < 100: # exclude box-gap jumps - spacings.append(d) - - if not spacings: - return 47.0 - spacings.sort() - return spacings[len(spacings) // 2] - - -def _classify_boxes( - box_zones: List[Dict], - content_width: float, -) -> List[Dict]: - """Classify each box as full_width or partial_width.""" - result = [] - for bz in box_zones: - bb = bz.get("bbox_px", {}) - bw = bb.get("w", 0) - bx = bb.get("x", 0) - - if bw >= content_width * 0.85: - classification = "full_width" - side = "center" - else: - classification = "partial_width" - # Determine which side of the page the box is on - page_center = content_width / 2 - box_center = bx + bw / 2 - side = "right" if box_center > page_center else "left" - - # Count total text lines in box (including \n within cells) - total_lines = sum( - (c.get("text", "").count("\n") + 1) - for c in bz.get("cells", []) - ) - - result.append({ - "zone": bz, - "classification": classification, - "side": side, - "y_start": bb.get("y", 0), - "y_end": bb.get("y", 0) + bb.get("h", 0), - "total_lines": total_lines, - "bg_hex": bz.get("box_bg_hex", ""), - "bg_color": bz.get("box_bg_color", ""), - }) - return result - - -def build_unified_grid( - zones: List[Dict], - image_width: int, - image_height: int, - layout_metrics: Dict, -) -> Dict[str, Any]: - """Build a single-zone unified grid from multi-zone grid data. - - Returns a StructuredGrid with one zone containing all rows and cells. 
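    Shape sketch (keys as produced by the return statement below):
        {"zones": [<single unified zone>], "image_width": ..., "image_height": ...,
         "layout_metrics": ..., "summary": {"total_zones": 1, ...},
         "is_unified": True, "dominant_row_h": <median content row spacing, px>}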
- """ - content_zone = None - box_zones = [] - for z in zones: - if z.get("zone_type") == "content": - content_zone = z - elif z.get("zone_type") == "box": - box_zones.append(z) - - if not content_zone: - logger.warning("build_unified_grid: no content zone found") - return {"zones": zones} # fallback: return as-is - - box_zones.sort(key=lambda b: b.get("bbox_px", {}).get("y", 0)) - - dominant_h = _compute_dominant_row_height(content_zone) - content_bbox = content_zone.get("bbox_px", {}) - content_width = content_bbox.get("w", image_width) - content_x = content_bbox.get("x", 0) - content_cols = content_zone.get("columns", []) - num_cols = len(content_cols) - - box_infos = _classify_boxes(box_zones, content_width) - - logger.info( - "build_unified_grid: dominant_h=%.1f, %d content rows, %d boxes (%s)", - dominant_h, len(content_zone.get("rows", [])), len(box_infos), - [b["classification"] for b in box_infos], - ) - - # --- Build unified row list + cell list --- - unified_rows: List[Dict] = [] - unified_cells: List[Dict] = [] - unified_row_idx = 0 - - # Content rows and cells indexed by row_index - content_rows = content_zone.get("rows", []) - content_cells = content_zone.get("cells", []) - content_cells_by_row: Dict[int, List[Dict]] = {} - for c in content_cells: - content_cells_by_row.setdefault(c.get("row_index", -1), []).append(c) - - # Track which content rows we've processed - content_row_ptr = 0 - - for bi, box_info in enumerate(box_infos): - bz = box_info["zone"] - by_start = box_info["y_start"] - by_end = box_info["y_end"] - - # --- Add content rows ABOVE this box --- - while content_row_ptr < len(content_rows): - cr = content_rows[content_row_ptr] - cry = cr.get("y_min_px", cr.get("y_min", 0)) - if cry >= by_start: - break - # Add this content row - _add_content_row( - unified_rows, unified_cells, unified_row_idx, - cr, content_cells_by_row, dominant_h, image_height, - ) - unified_row_idx += 1 - content_row_ptr += 1 - - # --- Add box rows --- - if box_info["classification"] == "full_width": - # Full-width box: integrate box rows directly - _add_full_width_box( - unified_rows, unified_cells, unified_row_idx, - bz, box_info, dominant_h, num_cols, image_height, - ) - unified_row_idx += len(bz.get("rows", [])) - # Skip content rows that overlap with this box - while content_row_ptr < len(content_rows): - cr = content_rows[content_row_ptr] - cry = cr.get("y_min_px", cr.get("y_min", 0)) - if cry > by_end: - break - content_row_ptr += 1 - - else: - # Partial-width box: merge with adjacent content rows - unified_row_idx = _add_partial_width_box( - unified_rows, unified_cells, unified_row_idx, - bz, box_info, content_rows, content_cells_by_row, - content_row_ptr, dominant_h, num_cols, image_height, - content_x, content_width, - ) - # Advance content pointer past box region - while content_row_ptr < len(content_rows): - cr = content_rows[content_row_ptr] - cry = cr.get("y_min_px", cr.get("y_min", 0)) - if cry > by_end: - break - content_row_ptr += 1 - - # --- Add remaining content rows BELOW all boxes --- - while content_row_ptr < len(content_rows): - cr = content_rows[content_row_ptr] - _add_content_row( - unified_rows, unified_cells, unified_row_idx, - cr, content_cells_by_row, dominant_h, image_height, - ) - unified_row_idx += 1 - content_row_ptr += 1 - - # --- Build unified zone --- - unified_zone = { - "zone_index": 0, - "zone_type": "unified", - "bbox_px": content_bbox, - "bbox_pct": content_zone.get("bbox_pct", {}), - "border": None, - "word_count": sum(len(c.get("word_boxes", [])) for 
c in unified_cells), - "columns": content_cols, - "rows": unified_rows, - "cells": unified_cells, - "header_rows": [], - } - - logger.info( - "build_unified_grid: %d unified rows, %d cells (from %d content + %d box zones)", - len(unified_rows), len(unified_cells), - len(content_rows), len(box_zones), - ) - - return { - "zones": [unified_zone], - "image_width": image_width, - "image_height": image_height, - "layout_metrics": layout_metrics, - "summary": { - "total_zones": 1, - "total_columns": num_cols, - "total_rows": len(unified_rows), - "total_cells": len(unified_cells), - }, - "is_unified": True, - "dominant_row_h": dominant_h, - } - - -# --------------------------------------------------------------------------- -# Helpers -# --------------------------------------------------------------------------- - -def _make_row(idx: int, y: float, h: float, img_h: int, is_header: bool = False) -> Dict: - return { - "index": idx, - "row_index": idx, - "y_min_px": round(y), - "y_max_px": round(y + h), - "y_min_pct": round(y / img_h * 100, 2) if img_h else 0, - "y_max_pct": round((y + h) / img_h * 100, 2) if img_h else 0, - "is_header": is_header, - } - - -def _remap_cell(cell: Dict, new_row: int, new_col: int = None, - source_type: str = "content", box_region: Dict = None) -> Dict: - """Create a new cell dict with remapped indices.""" - c = dict(cell) - c["row_index"] = new_row - if new_col is not None: - c["col_index"] = new_col - c["cell_id"] = f"U_R{new_row:02d}_C{c.get('col_index', 0)}" - c["source_zone_type"] = source_type - if box_region: - c["box_region"] = box_region - return c - - -def _add_content_row( - unified_rows, unified_cells, row_idx, - content_row, cells_by_row, dominant_h, img_h, -): - """Add a single content row to the unified grid.""" - y = content_row.get("y_min_px", content_row.get("y_min", 0)) - is_hdr = content_row.get("is_header", False) - unified_rows.append(_make_row(row_idx, y, dominant_h, img_h, is_hdr)) - - for cell in cells_by_row.get(content_row.get("index", -1), []): - unified_cells.append(_remap_cell(cell, row_idx, source_type="content")) - - -def _add_full_width_box( - unified_rows, unified_cells, start_row_idx, - box_zone, box_info, dominant_h, num_cols, img_h, -): - """Add a full-width box's rows to the unified grid.""" - box_rows = box_zone.get("rows", []) - box_cells = box_zone.get("cells", []) - box_region = {"bg_hex": box_info["bg_hex"], "bg_color": box_info["bg_color"], "border": True} - - # Distribute box height evenly among its rows - box_h = box_info["y_end"] - box_info["y_start"] - row_h = box_h / len(box_rows) if box_rows else dominant_h - - for i, br in enumerate(box_rows): - y = box_info["y_start"] + i * row_h - new_idx = start_row_idx + i - is_hdr = br.get("is_header", False) - unified_rows.append(_make_row(new_idx, y, row_h, img_h, is_hdr)) - - for cell in box_cells: - if cell.get("row_index") == br.get("index", i): - unified_cells.append( - _remap_cell(cell, new_idx, source_type="box", box_region=box_region) - ) - - -def _add_partial_width_box( - unified_rows, unified_cells, start_row_idx, - box_zone, box_info, content_rows, content_cells_by_row, - content_row_ptr, dominant_h, num_cols, img_h, - content_x, content_width, -) -> int: - """Add a partial-width box merged with content rows. - - Returns the next unified_row_idx after processing. 
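    Example (editor sketch): a 140 px tall box with dominant_h = 47 px gives
    standard_rows = floor(140 / 47) = 2; if the box text spans 5 lines,
    extra_rows = 5 - 2 = 3 and the region occupies 5 unified rows in total.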
- """ - by_start = box_info["y_start"] - by_end = box_info["y_end"] - box_h = by_end - by_start - box_region = {"bg_hex": box_info["bg_hex"], "bg_color": box_info["bg_color"], "border": True} - - # Content rows in the box's Y range - overlap_content_rows = [] - ptr = content_row_ptr - while ptr < len(content_rows): - cr = content_rows[ptr] - cry = cr.get("y_min_px", cr.get("y_min", 0)) - if cry > by_end: - break - if cry >= by_start: - overlap_content_rows.append(cr) - ptr += 1 - - # How many standard rows fit in the box height - standard_rows = max(1, math.floor(box_h / dominant_h)) - # How many text lines the box actually has - box_text_lines = box_info["total_lines"] - # Extra rows needed - extra_rows = max(0, box_text_lines - standard_rows) - total_rows_for_region = standard_rows + extra_rows - - logger.info( - "partial box: standard=%d, box_lines=%d, extra=%d, content_overlap=%d", - standard_rows, box_text_lines, extra_rows, len(overlap_content_rows), - ) - - # Determine which columns the box occupies - box_bb = box_zone.get("bbox_px", {}) - box_x = box_bb.get("x", 0) - box_w = box_bb.get("w", 0) - - # Map box to content columns: find which content columns overlap - box_col_start = 0 - box_col_end = num_cols - content_cols_list = [] - for z_col_idx in range(num_cols): - # Find the column definition by checking all column entries - # Simple heuristic: if box starts past halfway, it's the right columns - pass - - # Simpler approach: box on right side → last N columns - # box on left side → first N columns - if box_info["side"] == "right": - # Box starts at x=box_x. Find first content column that overlaps - box_col_start = num_cols # default: beyond all columns - for z in (box_zone.get("columns") or [{"index": 0}]): - pass - # Use content column positions to determine overlap - content_cols_data = [ - {"idx": c.get("index", i), "x_min": c.get("x_min_px", 0), "x_max": c.get("x_max_px", 0)} - for i, c in enumerate(content_rows[0:0] or []) # placeholder - ] - # Simple: split columns at midpoint - box_col_start = num_cols // 2 # right half - box_col_end = num_cols - else: - box_col_start = 0 - box_col_end = num_cols // 2 - - # Build rows for this region - box_cells = box_zone.get("cells", []) - box_rows = box_zone.get("rows", []) - row_idx = start_row_idx - - # Expand box cell texts with \n into individual lines for row mapping - box_lines: List[Tuple[str, Dict]] = [] # (text_line, parent_cell) - for bc in sorted(box_cells, key=lambda c: c.get("row_index", 0)): - text = bc.get("text", "") - for line in text.split("\n"): - box_lines.append((line.strip(), bc)) - - for i in range(total_rows_for_region): - y = by_start + i * dominant_h - unified_rows.append(_make_row(row_idx, y, dominant_h, img_h)) - - # Content cells for this row (from overlapping content rows) - if i < len(overlap_content_rows): - cr = overlap_content_rows[i] - for cell in content_cells_by_row.get(cr.get("index", -1), []): - # Only include cells from columns NOT covered by the box - ci = cell.get("col_index", 0) - if ci < box_col_start or ci >= box_col_end: - unified_cells.append(_remap_cell(cell, row_idx, source_type="content")) - - # Box cell for this row - if i < len(box_lines): - line_text, parent_cell = box_lines[i] - box_cell = { - "cell_id": f"U_R{row_idx:02d}_C{box_col_start}", - "row_index": row_idx, - "col_index": box_col_start, - "col_type": "spanning_header" if (box_col_end - box_col_start) > 1 else parent_cell.get("col_type", "column_1"), - "colspan": box_col_end - box_col_start, - "text": line_text, - 
"confidence": parent_cell.get("confidence", 0), - "bbox_px": parent_cell.get("bbox_px", {}), - "bbox_pct": parent_cell.get("bbox_pct", {}), - "word_boxes": [], - "ocr_engine": parent_cell.get("ocr_engine", ""), - "is_bold": parent_cell.get("is_bold", False), - "source_zone_type": "box", - "box_region": box_region, - } - unified_cells.append(box_cell) - - row_idx += 1 - - return row_idx +# Backward-compat shim -- module moved to grid/unified.py +import importlib as _importlib +import sys as _sys +_sys.modules[__name__] = _importlib.import_module("grid.unified") diff --git a/klausur-service/backend/upload/__init__.py b/klausur-service/backend/upload/__init__.py new file mode 100644 index 0000000..d184f47 --- /dev/null +++ b/klausur-service/backend/upload/__init__.py @@ -0,0 +1,6 @@ +""" +Upload package — chunked and mobile upload endpoints. + +Moved from backend/ flat modules (upload_api*.py). +Backward-compatible shim files remain at the old locations. +""" diff --git a/klausur-service/backend/upload/api.py b/klausur-service/backend/upload/api.py new file mode 100644 index 0000000..eff5080 --- /dev/null +++ b/klausur-service/backend/upload/api.py @@ -0,0 +1,29 @@ +""" +Mobile Upload API — barrel re-export. + +All implementation split into: + upload_api_chunked — chunked upload endpoints (init, chunk, finalize, simple, status, cancel, list) + upload_api_mobile — mobile HTML upload page + +DSGVO-konform: Data stays local in WLAN, no external transmission. +""" + +from fastapi import APIRouter + +from .chunked import ( # noqa: F401 + router as _chunked_router, + UPLOAD_DIR, + CHUNK_DIR, + EH_UPLOAD_DIR, + _upload_sessions, + InitUploadRequest, + InitUploadResponse, + ChunkUploadResponse, + FinalizeResponse, +) +from .mobile import router as _mobile_router # noqa: F401 + +# Composite router that includes both sub-routers +router = APIRouter() +router.include_router(_chunked_router) +router.include_router(_mobile_router) diff --git a/klausur-service/backend/upload/chunked.py b/klausur-service/backend/upload/chunked.py new file mode 100644 index 0000000..13ddfff --- /dev/null +++ b/klausur-service/backend/upload/chunked.py @@ -0,0 +1,320 @@ +""" +Chunked Upload API — init, chunk, finalize, simple upload, status, cancel, list. + +Extracted from upload_api.py for modularity. + +DSGVO-konform: Data stays local in WLAN, no external transmission. 
+""" + +import os +import uuid +import shutil +import hashlib +from pathlib import Path +from datetime import datetime, timezone +from typing import Dict, Optional + +from fastapi import APIRouter, HTTPException, UploadFile, File, Form +from pydantic import BaseModel + +# Configuration +UPLOAD_DIR = Path(os.getenv("UPLOAD_DIR", "/app/uploads")) +CHUNK_DIR = Path(os.getenv("CHUNK_DIR", "/app/chunks")) +EH_UPLOAD_DIR = Path(os.getenv("EH_UPLOAD_DIR", "/app/eh-uploads")) + +# Ensure directories exist +UPLOAD_DIR.mkdir(parents=True, exist_ok=True) +CHUNK_DIR.mkdir(parents=True, exist_ok=True) +EH_UPLOAD_DIR.mkdir(parents=True, exist_ok=True) + +# In-memory storage for upload sessions (for simplicity) +# In production, use Redis or database +_upload_sessions: Dict[str, dict] = {} + +router = APIRouter(prefix="/api/v1/upload", tags=["Mobile Upload"]) + + +class InitUploadRequest(BaseModel): + filename: str + filesize: int + chunks: int + destination: str = "klausur" # "klausur" or "rag" + + +class InitUploadResponse(BaseModel): + upload_id: str + chunk_size: int + total_chunks: int + message: str + + +class ChunkUploadResponse(BaseModel): + upload_id: str + chunk_index: int + received: bool + chunks_received: int + total_chunks: int + + +class FinalizeResponse(BaseModel): + upload_id: str + filename: str + filepath: str + filesize: int + checksum: str + message: str + + +@router.post("/init", response_model=InitUploadResponse) +async def init_upload(request: InitUploadRequest): + """ + Initialize a chunked upload session. + + Returns an upload_id that must be used for subsequent chunk uploads. + """ + upload_id = str(uuid.uuid4()) + + # Create session directory + session_dir = CHUNK_DIR / upload_id + session_dir.mkdir(parents=True, exist_ok=True) + + # Store session info + _upload_sessions[upload_id] = { + "filename": request.filename, + "filesize": request.filesize, + "total_chunks": request.chunks, + "received_chunks": set(), + "destination": request.destination, + "session_dir": str(session_dir), + "created_at": datetime.now(timezone.utc).isoformat(), + } + + return InitUploadResponse( + upload_id=upload_id, + chunk_size=5 * 1024 * 1024, # 5 MB + total_chunks=request.chunks, + message="Upload-Session erstellt" + ) + + +@router.post("/chunk", response_model=ChunkUploadResponse) +async def upload_chunk( + chunk: UploadFile = File(...), + upload_id: str = Form(...), + chunk_index: int = Form(...) +): + """ + Upload a single chunk of a file. + + Chunks are stored temporarily until finalize is called. + """ + if upload_id not in _upload_sessions: + raise HTTPException(status_code=404, detail="Upload-Session nicht gefunden") + + session = _upload_sessions[upload_id] + + if chunk_index < 0 or chunk_index >= session["total_chunks"]: + raise HTTPException( + status_code=400, + detail=f"Ungueltiger Chunk-Index: {chunk_index}" + ) + + # Save chunk + chunk_path = Path(session["session_dir"]) / f"chunk_{chunk_index:05d}" + + with open(chunk_path, "wb") as f: + content = await chunk.read() + f.write(content) + + # Track received chunks + session["received_chunks"].add(chunk_index) + + return ChunkUploadResponse( + upload_id=upload_id, + chunk_index=chunk_index, + received=True, + chunks_received=len(session["received_chunks"]), + total_chunks=session["total_chunks"] + ) + + +@router.post("/finalize", response_model=FinalizeResponse) +async def finalize_upload(upload_id: str = Form(...)): + """ + Finalize the upload by combining all chunks into a single file. 
+ + Validates that all chunks were received and calculates checksum. + """ + if upload_id not in _upload_sessions: + raise HTTPException(status_code=404, detail="Upload-Session nicht gefunden") + + session = _upload_sessions[upload_id] + + # Check if all chunks received + if len(session["received_chunks"]) != session["total_chunks"]: + missing = session["total_chunks"] - len(session["received_chunks"]) + raise HTTPException( + status_code=400, + detail=f"Nicht alle Chunks empfangen. Fehlend: {missing}" + ) + + # Determine destination directory + if session["destination"] == "rag": + dest_dir = EH_UPLOAD_DIR + else: + dest_dir = UPLOAD_DIR + + # Generate unique filename + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") + safe_filename = session["filename"].replace(" ", "_") + final_filename = f"{timestamp}_{safe_filename}" + final_path = dest_dir / final_filename + + # Combine chunks + hasher = hashlib.sha256() + total_size = 0 + + with open(final_path, "wb") as outfile: + for i in range(session["total_chunks"]): + chunk_path = Path(session["session_dir"]) / f"chunk_{i:05d}" + + if not chunk_path.exists(): + raise HTTPException( + status_code=500, + detail=f"Chunk {i} nicht gefunden" + ) + + with open(chunk_path, "rb") as infile: + data = infile.read() + outfile.write(data) + hasher.update(data) + total_size += len(data) + + # Clean up chunks + shutil.rmtree(session["session_dir"], ignore_errors=True) + del _upload_sessions[upload_id] + + checksum = hasher.hexdigest() + + return FinalizeResponse( + upload_id=upload_id, + filename=final_filename, + filepath=str(final_path), + filesize=total_size, + checksum=checksum, + message="Upload erfolgreich abgeschlossen" + ) + + +@router.post("/simple") +async def simple_upload( + file: UploadFile = File(...), + destination: str = Form("klausur") +): + """ + Simple single-request upload for smaller files (<10MB). + + For larger files, use the chunked upload endpoints. + """ + # Determine destination directory + if destination == "rag": + dest_dir = EH_UPLOAD_DIR + else: + dest_dir = UPLOAD_DIR + + # Generate unique filename + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") + safe_filename = file.filename.replace(" ", "_") if file.filename else "upload.pdf" + final_filename = f"{timestamp}_{safe_filename}" + final_path = dest_dir / final_filename + + # Calculate checksum while writing + hasher = hashlib.sha256() + total_size = 0 + + with open(final_path, "wb") as f: + while True: + chunk = await file.read(1024 * 1024) # Read 1MB at a time + if not chunk: + break + f.write(chunk) + hasher.update(chunk) + total_size += len(chunk) + + return { + "filename": final_filename, + "filepath": str(final_path), + "filesize": total_size, + "checksum": hasher.hexdigest(), + "message": "Upload erfolgreich" + } + + +@router.get("/status/{upload_id}") +async def get_upload_status(upload_id: str): + """ + Get the status of an ongoing upload. 
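    Response sketch (field names from the dict below; values illustrative):
        {"upload_id": "<uuid>", "filename": "scan.pdf", "total_chunks": 3,
         "received_chunks": 2, "progress_percent": 66.7,
         "destination": "klausur", "created_at": "<ISO-8601 UTC timestamp>"}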
+ """ + if upload_id not in _upload_sessions: + raise HTTPException(status_code=404, detail="Upload-Session nicht gefunden") + + session = _upload_sessions[upload_id] + + return { + "upload_id": upload_id, + "filename": session["filename"], + "total_chunks": session["total_chunks"], + "received_chunks": len(session["received_chunks"]), + "progress_percent": round( + len(session["received_chunks"]) / session["total_chunks"] * 100, 1 + ), + "destination": session["destination"], + "created_at": session["created_at"] + } + + +@router.delete("/cancel/{upload_id}") +async def cancel_upload(upload_id: str): + """ + Cancel an ongoing upload and clean up temporary files. + """ + if upload_id not in _upload_sessions: + raise HTTPException(status_code=404, detail="Upload-Session nicht gefunden") + + session = _upload_sessions[upload_id] + + # Clean up chunks + shutil.rmtree(session["session_dir"], ignore_errors=True) + del _upload_sessions[upload_id] + + return {"message": "Upload abgebrochen", "upload_id": upload_id} + + +@router.get("/list") +async def list_uploads(destination: str = "klausur"): + """ + List all uploaded files in the specified destination. + """ + if destination == "rag": + dest_dir = EH_UPLOAD_DIR + else: + dest_dir = UPLOAD_DIR + + files = [] + + for f in dest_dir.iterdir(): + if f.is_file() and f.suffix.lower() == ".pdf": + stat = f.stat() + files.append({ + "filename": f.name, + "size": stat.st_size, + "modified": datetime.fromtimestamp(stat.st_mtime).isoformat(), + }) + + files.sort(key=lambda x: x["modified"], reverse=True) + + return { + "destination": destination, + "count": len(files), + "files": files[:50] # Limit to 50 most recent + } diff --git a/klausur-service/backend/upload/mobile.py b/klausur-service/backend/upload/mobile.py new file mode 100644 index 0000000..8ddd423 --- /dev/null +++ b/klausur-service/backend/upload/mobile.py @@ -0,0 +1,292 @@ +""" +Mobile Upload HTML Page — serves the mobile upload UI directly from klausur-service. + +Extracted from upload_api.py for modularity. + +DSGVO-konform: Data stays local in WLAN, no external transmission. +""" + +from fastapi import APIRouter +from fastapi.responses import HTMLResponse + +router = APIRouter(prefix="/api/v1/upload", tags=["Mobile Upload"]) + + +@router.get("/mobile", response_class=HTMLResponse) +async def mobile_upload_page(): + """ + Serve the mobile upload page directly from the klausur-service. + This allows mobile devices to upload without needing the Next.js website. + """ + html_content = ''' + + + + + + BreakPilot Upload + + + +
    [inline HTML/CSS/JS stripped from this extract. The page shows a
    "BreakPilot Upload" header with a "DSGVO-konform" badge, a PDF
    select/drop area ("PDF-Dateien hochladen" / "Tippen zum Auswaehlen oder
    hierher ziehen"; files up to 200 MB are uploaded in chunks
    automatically), a "Hinweise:" list (files travel only over the local
    WLAN, nothing is sent to the internet, PDF is the only supported
    format), and a status line "Server: wird ermittelt...".]
+
+
+
+'''
+    return HTMLResponse(content=html_content)
diff --git a/klausur-service/backend/upload_api.py b/klausur-service/backend/upload_api.py
index 98e6f1a..79c5359 100644
--- a/klausur-service/backend/upload_api.py
+++ b/klausur-service/backend/upload_api.py
@@ -1,29 +1,4 @@
-"""
-Mobile Upload API — barrel re-export.
-
-All implementation split into:
-    upload_api_chunked — chunked upload endpoints (init, chunk, finalize, simple, status, cancel, list)
-    upload_api_mobile — mobile HTML upload page
-
-DSGVO-konform: Data stays local in WLAN, no external transmission.
-"""
-
-from fastapi import APIRouter
-
-from upload_api_chunked import (  # noqa: F401
-    router as _chunked_router,
-    UPLOAD_DIR,
-    CHUNK_DIR,
-    EH_UPLOAD_DIR,
-    _upload_sessions,
-    InitUploadRequest,
-    InitUploadResponse,
-    ChunkUploadResponse,
-    FinalizeResponse,
-)
-from upload_api_mobile import router as _mobile_router  # noqa: F401
-
-# Composite router that includes both sub-routers
-router = APIRouter()
-router.include_router(_chunked_router)
-router.include_router(_mobile_router)
+# Backward-compat shim -- module moved to upload/api.py
+import importlib as _importlib
+import sys as _sys
+_sys.modules[__name__] = _importlib.import_module("upload.api")
diff --git a/klausur-service/backend/upload_api_chunked.py b/klausur-service/backend/upload_api_chunked.py
index 13ddfff..7478b87 100644
--- a/klausur-service/backend/upload_api_chunked.py
+++ b/klausur-service/backend/upload_api_chunked.py
@@ -1,320 +1,4 @@
-"""
-Chunked Upload API — init, chunk, finalize, simple upload, status, cancel, list.
-
-Extracted from upload_api.py for modularity.
-
-DSGVO-konform: Data stays local in WLAN, no external transmission.
-"""
-
-import os
-import uuid
-import shutil
-import hashlib
-from pathlib import Path
-from datetime import datetime, timezone
-from typing import Dict, Optional
-
-from fastapi import APIRouter, HTTPException, UploadFile, File, Form
-from pydantic import BaseModel
-
-# Configuration
-UPLOAD_DIR = Path(os.getenv("UPLOAD_DIR", "/app/uploads"))
-CHUNK_DIR = Path(os.getenv("CHUNK_DIR", "/app/chunks"))
-EH_UPLOAD_DIR = Path(os.getenv("EH_UPLOAD_DIR", "/app/eh-uploads"))
-
-# Ensure directories exist
-UPLOAD_DIR.mkdir(parents=True, exist_ok=True)
-CHUNK_DIR.mkdir(parents=True, exist_ok=True)
-EH_UPLOAD_DIR.mkdir(parents=True, exist_ok=True)
-
-# In-memory storage for upload sessions (for simplicity)
-# In production, use Redis or database
-_upload_sessions: Dict[str, dict] = {}
-
-router = APIRouter(prefix="/api/v1/upload", tags=["Mobile Upload"])
-
-
-class InitUploadRequest(BaseModel):
-    filename: str
-    filesize: int
-    chunks: int
-    destination: str = "klausur"  # "klausur" or "rag"
-
-
-class InitUploadResponse(BaseModel):
-    upload_id: str
-    chunk_size: int
-    total_chunks: int
-    message: str
-
-
-class ChunkUploadResponse(BaseModel):
-    upload_id: str
-    chunk_index: int
-    received: bool
-    chunks_received: int
-    total_chunks: int
-
-
-class FinalizeResponse(BaseModel):
-    upload_id: str
-    filename: str
-    filepath: str
-    filesize: int
-    checksum: str
-    message: str
-
-
-@router.post("/init", response_model=InitUploadResponse)
-async def init_upload(request: InitUploadRequest):
-    """
-    Initialize a chunked upload session.
-
-    Returns an upload_id that must be used for subsequent chunk uploads.
-    """
-    upload_id = str(uuid.uuid4())
-
-    # Create session directory
-    session_dir = CHUNK_DIR / upload_id
-    session_dir.mkdir(parents=True, exist_ok=True)
-
-    # Store session info
-    _upload_sessions[upload_id] = {
-        "filename": request.filename,
-        "filesize": request.filesize,
-        "total_chunks": request.chunks,
-        "received_chunks": set(),
-        "destination": request.destination,
-        "session_dir": str(session_dir),
-        "created_at": datetime.now(timezone.utc).isoformat(),
-    }
-
-    return InitUploadResponse(
-        upload_id=upload_id,
-        chunk_size=5 * 1024 * 1024,  # 5 MB
-        total_chunks=request.chunks,
-        message="Upload-Session erstellt"
-    )
-
-
-@router.post("/chunk", response_model=ChunkUploadResponse)
-async def upload_chunk(
-    chunk: UploadFile = File(...),
-    upload_id: str = Form(...),
-    chunk_index: int = Form(...)
-):
-    """
-    Upload a single chunk of a file.
-
-    Chunks are stored temporarily until finalize is called.
-    """
-    if upload_id not in _upload_sessions:
-        raise HTTPException(status_code=404, detail="Upload-Session nicht gefunden")
-
-    session = _upload_sessions[upload_id]
-
-    if chunk_index < 0 or chunk_index >= session["total_chunks"]:
-        raise HTTPException(
-            status_code=400,
-            detail=f"Ungueltiger Chunk-Index: {chunk_index}"
-        )
-
-    # Save chunk
-    chunk_path = Path(session["session_dir"]) / f"chunk_{chunk_index:05d}"
-
-    with open(chunk_path, "wb") as f:
-        content = await chunk.read()
-        f.write(content)
-
-    # Track received chunks
-    session["received_chunks"].add(chunk_index)
-
-    return ChunkUploadResponse(
-        upload_id=upload_id,
-        chunk_index=chunk_index,
-        received=True,
-        chunks_received=len(session["received_chunks"]),
-        total_chunks=session["total_chunks"]
-    )
-
-
-@router.post("/finalize", response_model=FinalizeResponse)
-async def finalize_upload(upload_id: str = Form(...)):
-    """
-    Finalize the upload by combining all chunks into a single file.
-
-    Validates that all chunks were received and calculates checksum.
-    """
-    if upload_id not in _upload_sessions:
-        raise HTTPException(status_code=404, detail="Upload-Session nicht gefunden")
-
-    session = _upload_sessions[upload_id]
-
-    # Check if all chunks received
-    if len(session["received_chunks"]) != session["total_chunks"]:
-        missing = session["total_chunks"] - len(session["received_chunks"])
-        raise HTTPException(
-            status_code=400,
-            detail=f"Nicht alle Chunks empfangen. Fehlend: {missing}"
-        )
-
-    # Determine destination directory
-    if session["destination"] == "rag":
-        dest_dir = EH_UPLOAD_DIR
-    else:
-        dest_dir = UPLOAD_DIR
-
-    # Generate unique filename
-    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
-    safe_filename = session["filename"].replace(" ", "_")
-    final_filename = f"{timestamp}_{safe_filename}"
-    final_path = dest_dir / final_filename
-
-    # Combine chunks
-    hasher = hashlib.sha256()
-    total_size = 0
-
-    with open(final_path, "wb") as outfile:
-        for i in range(session["total_chunks"]):
-            chunk_path = Path(session["session_dir"]) / f"chunk_{i:05d}"
-
-            if not chunk_path.exists():
-                raise HTTPException(
-                    status_code=500,
-                    detail=f"Chunk {i} nicht gefunden"
-                )
-
-            with open(chunk_path, "rb") as infile:
-                data = infile.read()
-                outfile.write(data)
-                hasher.update(data)
-                total_size += len(data)
-
-    # Clean up chunks
-    shutil.rmtree(session["session_dir"], ignore_errors=True)
-    del _upload_sessions[upload_id]
-
-    checksum = hasher.hexdigest()
-
-    return FinalizeResponse(
-        upload_id=upload_id,
-        filename=final_filename,
-        filepath=str(final_path),
-        filesize=total_size,
-        checksum=checksum,
-        message="Upload erfolgreich abgeschlossen"
-    )
-
-
-@router.post("/simple")
-async def simple_upload(
-    file: UploadFile = File(...),
-    destination: str = Form("klausur")
-):
-    """
-    Simple single-request upload for smaller files (<10MB).
-
-    For larger files, use the chunked upload endpoints.
-    """
-    # Determine destination directory
-    if destination == "rag":
-        dest_dir = EH_UPLOAD_DIR
-    else:
-        dest_dir = UPLOAD_DIR
-
-    # Generate unique filename
-    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
-    safe_filename = file.filename.replace(" ", "_") if file.filename else "upload.pdf"
-    final_filename = f"{timestamp}_{safe_filename}"
-    final_path = dest_dir / final_filename
-
-    # Calculate checksum while writing
-    hasher = hashlib.sha256()
-    total_size = 0
-
-    with open(final_path, "wb") as f:
-        while True:
-            chunk = await file.read(1024 * 1024)  # Read 1MB at a time
-            if not chunk:
-                break
-            f.write(chunk)
-            hasher.update(chunk)
-            total_size += len(chunk)
-
-    return {
-        "filename": final_filename,
-        "filepath": str(final_path),
-        "filesize": total_size,
-        "checksum": hasher.hexdigest(),
-        "message": "Upload erfolgreich"
-    }
-
-
-@router.get("/status/{upload_id}")
-async def get_upload_status(upload_id: str):
-    """
-    Get the status of an ongoing upload.
-    """
-    if upload_id not in _upload_sessions:
-        raise HTTPException(status_code=404, detail="Upload-Session nicht gefunden")
-
-    session = _upload_sessions[upload_id]
-
-    return {
-        "upload_id": upload_id,
-        "filename": session["filename"],
-        "total_chunks": session["total_chunks"],
-        "received_chunks": len(session["received_chunks"]),
-        "progress_percent": round(
-            len(session["received_chunks"]) / session["total_chunks"] * 100, 1
-        ),
-        "destination": session["destination"],
-        "created_at": session["created_at"]
-    }
-
-
-@router.delete("/cancel/{upload_id}")
-async def cancel_upload(upload_id: str):
-    """
-    Cancel an ongoing upload and clean up temporary files.
-    """
-    if upload_id not in _upload_sessions:
-        raise HTTPException(status_code=404, detail="Upload-Session nicht gefunden")
-
-    session = _upload_sessions[upload_id]
-
-    # Clean up chunks
-    shutil.rmtree(session["session_dir"], ignore_errors=True)
-    del _upload_sessions[upload_id]
-
-    return {"message": "Upload abgebrochen", "upload_id": upload_id}
-
-
-@router.get("/list")
-async def list_uploads(destination: str = "klausur"):
-    """
-    List all uploaded files in the specified destination.
-    """
-    if destination == "rag":
-        dest_dir = EH_UPLOAD_DIR
-    else:
-        dest_dir = UPLOAD_DIR
-
-    files = []
-
-    for f in dest_dir.iterdir():
-        if f.is_file() and f.suffix.lower() == ".pdf":
-            stat = f.stat()
-            files.append({
-                "filename": f.name,
-                "size": stat.st_size,
-                "modified": datetime.fromtimestamp(stat.st_mtime).isoformat(),
-            })
-
-    files.sort(key=lambda x: x["modified"], reverse=True)
-
-    return {
-        "destination": destination,
-        "count": len(files),
-        "files": files[:50]  # Limit to 50 most recent
-    }
+# Backward-compat shim -- module moved to upload/chunked.py
+import importlib as _importlib
+import sys as _sys
+_sys.modules[__name__] = _importlib.import_module("upload.chunked")
diff --git a/klausur-service/backend/upload_api_mobile.py b/klausur-service/backend/upload_api_mobile.py
index 8ddd423..10fd87e 100644
--- a/klausur-service/backend/upload_api_mobile.py
+++ b/klausur-service/backend/upload_api_mobile.py
@@ -1,292 +1,4 @@
-"""
-Mobile Upload HTML Page — serves the mobile upload UI directly from klausur-service.
-
-Extracted from upload_api.py for modularity.
-
-DSGVO-konform: Data stays local in WLAN, no external transmission.
-"""
-
-from fastapi import APIRouter
-from fastapi.responses import HTMLResponse
-
-router = APIRouter(prefix="/api/v1/upload", tags=["Mobile Upload"])
-
-
-@router.get("/mobile", response_class=HTMLResponse)
-async def mobile_upload_page():
-    """
-    Serve the mobile upload page directly from the klausur-service.
-    This allows mobile devices to upload without needing the Next.js website.
-    """
-    html_content = '''
[... ~270 removed lines of the page's inline HTML/CSS/JS elided. Recoverable UI strings:
 header "BreakPilot Upload" with badge "DSGVO-konform"; drop zone "PDF-Dateien hochladen",
 "Tippen zum Auswaehlen oder hierher ziehen", "Grosse Dateien bis 200 MB werden automatisch
 in Teilen hochgeladen"; a "Hinweise:" list (files are transferred only within the local
 WLAN, nothing is sent to the internet, PDF is the only supported format); status line
 "Server: wird ermittelt..." ...]
-'''
-    return HTMLResponse(content=html_content)
+# Backward-compat shim -- module moved to upload/mobile.py
+import importlib as _importlib
+import sys as _sys
+_sys.modules[__name__] = _importlib.import_module("upload.mobile")
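All twelve backward-compat shims in this patch use the same one-liner: on first import, the
old flat module swaps its own entry in sys.modules for the relocated package module. A minimal
sketch of the effect, assuming backend/ is on sys.path and FastAPI is installed (names as in
this patch):

import sys

import upload_api   # executing the shim replaces sys.modules["upload_api"]
import upload.api   # the real module

# Both names now refer to the very same module object, so every attribute,
# including private ones such as _upload_sessions, stays reachable under
# the old import path.
assert sys.modules["upload_api"] is sys.modules["upload.api"]
assert upload_api.router is upload.api.router

Because the replacement happens while the shim is executing, the import machinery binds the
name upload_api to the substituted module rather than the shim itself; a star re-export would
not preserve private names or module identity checks.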
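For reference, the chunked flow that moved to upload/chunked.py is a plain three-step
protocol: POST /init with filename/filesize/chunks as JSON, POST /chunk once per part as
multipart form data, POST /finalize to concatenate the parts and return a sha256 checksum.
A hypothetical client sketch; the base URL is a placeholder, endpoint paths and field names
are taken from the handlers above, and a real client for 200 MB files would stream from disk
instead of reading everything into memory:

import hashlib
import math
from pathlib import Path

import httpx

BASE_URL = "http://klausur-service:8000/api/v1/upload"  # placeholder host/port
CHUNK_SIZE = 5 * 1024 * 1024  # matches InitUploadResponse.chunk_size

def upload_pdf(path: Path, destination: str = "klausur") -> dict:
    data = path.read_bytes()
    total_chunks = max(1, math.ceil(len(data) / CHUNK_SIZE))

    with httpx.Client(timeout=60.0) as client:
        # 1) init: register filename/size/chunk count, receive an upload_id
        init = client.post(f"{BASE_URL}/init", json={
            "filename": path.name,
            "filesize": len(data),
            "chunks": total_chunks,
            "destination": destination,
        }).json()
        upload_id = init["upload_id"]

        # 2) chunk: one multipart request per part; order does not matter,
        #    since the server tracks received indices in a set
        for i in range(total_chunks):
            part = data[i * CHUNK_SIZE:(i + 1) * CHUNK_SIZE]
            client.post(
                f"{BASE_URL}/chunk",
                data={"upload_id": upload_id, "chunk_index": i},
                files={"chunk": (f"chunk_{i}", part)},
            ).raise_for_status()

        # 3) finalize: server concatenates chunks and reports the checksum
        final = client.post(f"{BASE_URL}/finalize",
                            data={"upload_id": upload_id}).json()

    # End-to-end integrity check against the server-side sha256
    assert final["checksum"] == hashlib.sha256(data).hexdigest()
    return final

If an upload is interrupted, GET /status/{upload_id} reports chunks_received, so the client
can resend only the missing indices before calling finalize.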
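The moved code still keeps _upload_sessions in a process-local dict, with the comment
"In production, use Redis or database". One way that swap could look, sketched with
redis-py; the key layout, URL, and TTL are illustrative only and not part of this patch:

import json

import redis  # redis-py

r = redis.Redis.from_url("redis://localhost:6379/0", decode_responses=True)

def init_session(upload_id: str, meta: dict) -> None:
    # One JSON blob per session, expiring after 24h so abandoned
    # uploads clean themselves up
    r.set(f"upload:{upload_id}:meta", json.dumps(meta), ex=24 * 3600)

def mark_chunk(upload_id: str, chunk_index: int) -> int:
    # A Redis set mirrors the current received_chunks set semantics,
    # so re-sending a chunk stays idempotent
    r.sadd(f"upload:{upload_id}:chunks", chunk_index)
    return r.scard(f"upload:{upload_id}:chunks")  # == chunks_received

def all_received(upload_id: str) -> bool:
    meta = json.loads(r.get(f"upload:{upload_id}:meta"))
    return r.scard(f"upload:{upload_id}:chunks") == meta["total_chunks"]

Unlike the in-memory dict, this survives restarts and works across multiple uvicorn workers,
which is what the original comment is pointing at.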
- - - -''' - return HTMLResponse(content=html_content) +# Backward-compat shim -- module moved to upload/mobile.py +import importlib as _importlib +import sys as _sys +_sys.modules[__name__] = _importlib.import_module("upload.mobile")