Restructure: Move final 12 root files into packages (klausur-service)
Some checks failed
CI / go-lint (push) Has been skipped
CI / python-lint (push) Has been skipped
CI / nodejs-lint (push) Has been skipped
CI / test-go-school (push) Successful in 28s
CI / test-go-edu-search (push) Successful in 28s
CI / test-python-klausur (push) Failing after 2m23s
CI / test-python-agent-core (push) Successful in 19s
CI / test-nodejs-website (push) Successful in 19s

ocr/spell/  (3): smart_spell, core, text
upload/     (3): api, chunked, mobile
crawler/    (3): github, github_core, github_parsers
+ unified_grid → grid/, tesseract_extractor → ocr/engines/, htr_api → ocr/pipeline/

12 shims added. Only main.py, config.py, storage + RAG files remain at root.
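For callers, the shims keep the old import paths working: each old flat module replaces itself in sys.modules with its new package counterpart, so both paths resolve to the same module object. A minimal sketch (assuming the backend root is on sys.path and the old flat module was named github_crawler, as the docstrings in the diff below indicate):

# Old flat import path, still served by the shim left at the old location ...
import github_crawler as old_mod
# ... and the new package path; both resolve to the same module object.
from crawler import github as new_mod

assert old_mod is new_mod
assert old_mod.GitHubCrawler is new_mod.GitHubCrawler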

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Benjamin Admin
2026-04-25 23:19:11 +02:00
parent cba877c65a
commit d093a4d388
27 changed files with 3116 additions and 3049 deletions


@@ -0,0 +1,6 @@
"""
Crawler package — GitHub repository crawler for legal templates.
Moved from backend/ flat modules (github_crawler*.py).
Backward-compatible shim files remain at the old locations.
"""


@@ -0,0 +1,35 @@
"""
GitHub Repository Crawler — Barrel Re-export
Split into:
- github_parsers.py — ExtractedDocument, MarkdownParser, HTMLParser, JSONParser
- github_core.py — GitHubCrawler, RepositoryDownloader, crawl_source
All public names are re-exported here for backward compatibility.
"""
# Parsers
from .github_parsers import ( # noqa: F401
ExtractedDocument,
MarkdownParser,
HTMLParser,
JSONParser,
)
# Crawler and downloader
from .github_core import ( # noqa: F401
GITHUB_API_URL,
GITLAB_API_URL,
GITHUB_TOKEN,
MAX_FILE_SIZE,
REQUEST_TIMEOUT,
RATE_LIMIT_DELAY,
GitHubCrawler,
RepositoryDownloader,
crawl_source,
main,
)
if __name__ == "__main__":
import asyncio
asyncio.run(main())


@@ -0,0 +1,411 @@
"""
GitHub Crawler - Core Crawler and Downloader
GitHubCrawler for API-based repository crawling and RepositoryDownloader
for ZIP-based local extraction.
Extracted from github_crawler.py to keep files under 500 LOC.
"""
import asyncio
import logging
import os
import shutil
import tempfile
import zipfile
from fnmatch import fnmatch
from pathlib import Path
from typing import Any, AsyncGenerator, Dict, List, Optional, Tuple
from urllib.parse import urlparse
import httpx
from template_sources import SourceConfig
from .github_parsers import (
ExtractedDocument,
MarkdownParser,
HTMLParser,
JSONParser,
)
logger = logging.getLogger(__name__)
# Configuration
GITHUB_API_URL = "https://api.github.com"
GITLAB_API_URL = "https://gitlab.com/api/v4"
GITHUB_TOKEN = os.getenv("GITHUB_TOKEN", "")
MAX_FILE_SIZE = 1024 * 1024 # 1 MB max file size
REQUEST_TIMEOUT = 60.0
RATE_LIMIT_DELAY = 1.0
class GitHubCrawler:
"""Crawl GitHub repositories for legal templates."""
def __init__(self, token: Optional[str] = None):
self.token = token or GITHUB_TOKEN
self.headers = {
"Accept": "application/vnd.github.v3+json",
"User-Agent": "LegalTemplatesCrawler/1.0",
}
if self.token:
self.headers["Authorization"] = f"token {self.token}"
self.http_client: Optional[httpx.AsyncClient] = None
async def __aenter__(self):
self.http_client = httpx.AsyncClient(
timeout=REQUEST_TIMEOUT,
headers=self.headers,
follow_redirects=True,
)
return self
async def __aexit__(self, exc_type, exc_val, exc_tb):
if self.http_client:
await self.http_client.aclose()
def _parse_repo_url(self, url: str) -> Tuple[str, str, str]:
"""Parse repository URL into owner, repo, and host."""
parsed = urlparse(url)
path_parts = parsed.path.strip('/').split('/')
if len(path_parts) < 2:
raise ValueError(f"Invalid repository URL: {url}")
owner = path_parts[0]
repo = path_parts[1].replace('.git', '')
if 'gitlab' in parsed.netloc:
host = 'gitlab'
else:
host = 'github'
return owner, repo, host
async def get_default_branch(self, owner: str, repo: str) -> str:
"""Get the default branch of a repository."""
if not self.http_client:
raise RuntimeError("Crawler not initialized. Use 'async with' context.")
url = f"{GITHUB_API_URL}/repos/{owner}/{repo}"
response = await self.http_client.get(url)
response.raise_for_status()
data = response.json()
return data.get("default_branch", "main")
async def get_latest_commit(self, owner: str, repo: str, branch: str = "main") -> str:
"""Get the latest commit SHA for a branch."""
if not self.http_client:
raise RuntimeError("Crawler not initialized. Use 'async with' context.")
url = f"{GITHUB_API_URL}/repos/{owner}/{repo}/commits/{branch}"
response = await self.http_client.get(url)
response.raise_for_status()
data = response.json()
return data.get("sha", "")
async def list_files(
self,
owner: str,
repo: str,
path: str = "",
branch: str = "main",
patterns: Optional[List[str]] = None,
exclude_patterns: Optional[List[str]] = None,
) -> List[Dict[str, Any]]:
"""List files in a repository matching the given patterns."""
if not self.http_client:
raise RuntimeError("Crawler not initialized. Use 'async with' context.")
patterns = patterns or ["*.md", "*.txt", "*.html"]
exclude_patterns = exclude_patterns or []
url = f"{GITHUB_API_URL}/repos/{owner}/{repo}/git/trees/{branch}?recursive=1"
response = await self.http_client.get(url)
response.raise_for_status()
data = response.json()
files = []
for item in data.get("tree", []):
if item["type"] != "blob":
continue
file_path = item["path"]
excluded = any(fnmatch(file_path, pattern) for pattern in exclude_patterns)
if excluded:
continue
matched = any(fnmatch(file_path, pattern) for pattern in patterns)
if not matched:
continue
if item.get("size", 0) > MAX_FILE_SIZE:
logger.warning(f"Skipping large file: {file_path} ({item['size']} bytes)")
continue
files.append({
"path": file_path,
"sha": item["sha"],
"size": item.get("size", 0),
"url": item.get("url", ""),
})
return files
async def get_file_content(self, owner: str, repo: str, path: str, branch: str = "main") -> str:
"""Get the content of a file from a repository."""
if not self.http_client:
raise RuntimeError("Crawler not initialized. Use 'async with' context.")
url = f"https://raw.githubusercontent.com/{owner}/{repo}/{branch}/{path}"
response = await self.http_client.get(url)
response.raise_for_status()
return response.text
async def crawl_repository(
self,
source: SourceConfig,
) -> AsyncGenerator[ExtractedDocument, None]:
"""Crawl a repository and yield extracted documents."""
if not source.repo_url:
logger.warning(f"No repo URL for source: {source.name}")
return
try:
owner, repo, host = self._parse_repo_url(source.repo_url)
except ValueError as e:
logger.error(f"Failed to parse repo URL for {source.name}: {e}")
return
if host == "gitlab":
logger.info(f"GitLab repos not yet supported: {source.name}")
return
logger.info(f"Crawling repository: {owner}/{repo}")
try:
branch = await self.get_default_branch(owner, repo)
commit_sha = await self.get_latest_commit(owner, repo, branch)
await asyncio.sleep(RATE_LIMIT_DELAY)
files = await self.list_files(
owner, repo,
branch=branch,
patterns=source.file_patterns,
exclude_patterns=source.exclude_patterns,
)
logger.info(f"Found {len(files)} matching files in {source.name}")
for file_info in files:
await asyncio.sleep(RATE_LIMIT_DELAY)
try:
content = await self.get_file_content(
owner, repo, file_info["path"], branch
)
file_path = file_info["path"]
source_url = f"https://github.com/{owner}/{repo}/blob/{branch}/{file_path}"
if file_path.endswith('.md'):
doc = MarkdownParser.parse(content, file_path)
doc.source_url = source_url
doc.source_commit = commit_sha
yield doc
elif file_path.endswith('.html') or file_path.endswith('.htm'):
doc = HTMLParser.parse(content, file_path)
doc.source_url = source_url
doc.source_commit = commit_sha
yield doc
elif file_path.endswith('.json'):
docs = JSONParser.parse(content, file_path)
for doc in docs:
doc.source_url = source_url
doc.source_commit = commit_sha
yield doc
elif file_path.endswith('.txt'):
yield ExtractedDocument(
text=content,
title=Path(file_path).stem,
file_path=file_path,
file_type="text",
source_url=source_url,
source_commit=commit_sha,
language=MarkdownParser._detect_language(content),
placeholders=MarkdownParser._find_placeholders(content),
)
except httpx.HTTPError as e:
logger.warning(f"Failed to fetch {file_path}: {e}")
continue
except Exception as e:
logger.error(f"Error processing {file_path}: {e}")
continue
except httpx.HTTPError as e:
logger.error(f"HTTP error crawling {source.name}: {e}")
except Exception as e:
logger.error(f"Error crawling {source.name}: {e}")
class RepositoryDownloader:
"""Download and extract repository archives."""
def __init__(self):
self.http_client: Optional[httpx.AsyncClient] = None
async def __aenter__(self):
self.http_client = httpx.AsyncClient(
timeout=120.0,
follow_redirects=True,
)
return self
async def __aexit__(self, exc_type, exc_val, exc_tb):
if self.http_client:
await self.http_client.aclose()
async def download_zip(self, repo_url: str, branch: str = "main") -> Path:
"""Download repository as ZIP and extract to temp directory."""
if not self.http_client:
raise RuntimeError("Downloader not initialized. Use 'async with' context.")
parsed = urlparse(repo_url)
path_parts = parsed.path.strip('/').split('/')
owner = path_parts[0]
repo = path_parts[1].replace('.git', '')
zip_url = f"https://github.com/{owner}/{repo}/archive/refs/heads/{branch}.zip"
logger.info(f"Downloading ZIP from {zip_url}")
response = await self.http_client.get(zip_url)
response.raise_for_status()
temp_dir = Path(tempfile.mkdtemp())
zip_path = temp_dir / f"{repo}.zip"
with open(zip_path, 'wb') as f:
f.write(response.content)
extract_dir = temp_dir / repo
with zipfile.ZipFile(zip_path, 'r') as zip_ref:
zip_ref.extractall(temp_dir)
extracted_dirs = list(temp_dir.glob(f"{repo}-*"))
if extracted_dirs:
return extracted_dirs[0]
return extract_dir
async def crawl_local_directory(
self,
directory: Path,
source: SourceConfig,
base_url: str,
) -> AsyncGenerator[ExtractedDocument, None]:
"""Crawl a local directory for documents."""
patterns = source.file_patterns or ["*.md", "*.txt", "*.html"]
exclude_patterns = source.exclude_patterns or []
for pattern in patterns:
for file_path in directory.rglob(pattern.replace("**/", "")):
if not file_path.is_file():
continue
rel_path = str(file_path.relative_to(directory))
excluded = any(fnmatch(rel_path, ep) for ep in exclude_patterns)
if excluded:
continue
if file_path.stat().st_size > MAX_FILE_SIZE:
continue
try:
content = file_path.read_text(encoding='utf-8')
except UnicodeDecodeError:
try:
content = file_path.read_text(encoding='latin-1')
except Exception:
continue
source_url = f"{base_url}/{rel_path}"
if file_path.suffix == '.md':
doc = MarkdownParser.parse(content, rel_path)
doc.source_url = source_url
yield doc
elif file_path.suffix in ['.html', '.htm']:
doc = HTMLParser.parse(content, rel_path)
doc.source_url = source_url
yield doc
elif file_path.suffix == '.json':
docs = JSONParser.parse(content, rel_path)
for doc in docs:
doc.source_url = source_url
yield doc
elif file_path.suffix == '.txt':
yield ExtractedDocument(
text=content,
title=file_path.stem,
file_path=rel_path,
file_type="text",
source_url=source_url,
language=MarkdownParser._detect_language(content),
placeholders=MarkdownParser._find_placeholders(content),
)
def cleanup(self, directory: Path):
"""Clean up temporary directory."""
if directory.exists():
shutil.rmtree(directory, ignore_errors=True)
async def crawl_source(source: SourceConfig) -> List[ExtractedDocument]:
"""Crawl a source configuration and return all extracted documents."""
documents = []
if source.repo_url:
async with GitHubCrawler() as crawler:
async for doc in crawler.crawl_repository(source):
documents.append(doc)
return documents
# CLI for testing
async def main():
"""Test crawler with a sample source."""
from template_sources import TEMPLATE_SOURCES
source = next(s for s in TEMPLATE_SOURCES if s.name == "github-site-policy")
async with GitHubCrawler() as crawler:
count = 0
async for doc in crawler.crawl_repository(source):
count += 1
print(f"\n{'='*60}")
print(f"Title: {doc.title}")
print(f"Path: {doc.file_path}")
print(f"Type: {doc.file_type}")
print(f"Language: {doc.language}")
print(f"URL: {doc.source_url}")
print(f"Placeholders: {doc.placeholders[:5] if doc.placeholders else 'None'}")
print(f"Text preview: {doc.text[:200]}...")
print(f"\n\nTotal documents: {count}")
if __name__ == "__main__":
asyncio.run(main())
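Usage sketch for the relocated crawler (illustration only, not part of this commit). It assumes SourceConfig accepts these keyword fields — the crawler itself only reads name, repo_url, file_patterns and exclude_patterns — and that GITHUB_TOKEN is optional, as the code above suggests:

import asyncio
from crawler.github_core import GitHubCrawler
from template_sources import SourceConfig

async def demo() -> None:
    # Hypothetical source; the field values are placeholders.
    source = SourceConfig(
        name="github-site-policy",
        repo_url="https://github.com/github/site-policy",
        file_patterns=["*.md"],
        exclude_patterns=["**/translations/**"],
    )
    async with GitHubCrawler() as crawler:  # opens/closes the shared httpx.AsyncClient
        async for doc in crawler.crawl_repository(source):
            print(doc.title, doc.file_path, doc.language)
    # crawl_source(source) wraps the same loop and returns the documents as a list.

asyncio.run(demo())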


@@ -0,0 +1,303 @@
"""
GitHub Crawler - Document Parsers
Markdown, HTML, and JSON parsers for extracting structured content
from legal template documents.
Extracted from github_crawler.py to keep files under 500 LOC.
"""
import hashlib
import json
import re
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Dict, List, Optional
@dataclass
class ExtractedDocument:
"""A document extracted from a repository."""
text: str
title: str
file_path: str
file_type: str # "markdown", "html", "json", "text"
source_url: str
source_commit: Optional[str] = None
source_hash: str = "" # SHA256 of original content
sections: List[Dict[str, Any]] = field(default_factory=list)
placeholders: List[str] = field(default_factory=list)
language: str = "en"
metadata: Dict[str, Any] = field(default_factory=dict)
def __post_init__(self):
if not self.source_hash and self.text:
self.source_hash = hashlib.sha256(self.text.encode()).hexdigest()
class MarkdownParser:
"""Parse Markdown files into structured content."""
# Common placeholder patterns
PLACEHOLDER_PATTERNS = [
r'\[([A-Z_]+)\]', # [COMPANY_NAME]
r'\{([a-z_]+)\}', # {company_name}
r'\{\{([a-z_]+)\}\}', # {{company_name}}
r'__([A-Z_]+)__', # __COMPANY_NAME__
r'<([A-Z_]+)>', # <COMPANY_NAME>
]
@classmethod
def parse(cls, content: str, filename: str = "") -> ExtractedDocument:
"""Parse markdown content into an ExtractedDocument."""
title = cls._extract_title(content, filename)
sections = cls._extract_sections(content)
placeholders = cls._find_placeholders(content)
language = cls._detect_language(content)
clean_text = cls._clean_for_indexing(content)
return ExtractedDocument(
text=clean_text,
title=title,
file_path=filename,
file_type="markdown",
source_url="",
sections=sections,
placeholders=placeholders,
language=language,
)
@classmethod
def _extract_title(cls, content: str, filename: str) -> str:
"""Extract title from markdown heading or filename."""
h1_match = re.search(r'^#\s+(.+)$', content, re.MULTILINE)
if h1_match:
return h1_match.group(1).strip()
frontmatter_match = re.search(
r'^---\s*\n.*?title:\s*["\']?(.+?)["\']?\s*\n.*?---',
content, re.DOTALL
)
if frontmatter_match:
return frontmatter_match.group(1).strip()
if filename:
name = Path(filename).stem
return name.replace('-', ' ').replace('_', ' ').title()
return "Untitled"
@classmethod
def _extract_sections(cls, content: str) -> List[Dict[str, Any]]:
"""Extract sections from markdown content."""
sections = []
current_section = {"heading": "", "level": 0, "content": "", "start": 0}
for match in re.finditer(r'^(#{1,6})\s+(.+)$', content, re.MULTILINE):
if current_section["heading"] or current_section["content"].strip():
current_section["content"] = current_section["content"].strip()
sections.append(current_section.copy())
level = len(match.group(1))
heading = match.group(2).strip()
current_section = {
"heading": heading,
"level": level,
"content": "",
"start": match.end(),
}
if current_section["heading"] or current_section["content"].strip():
current_section["content"] = content[current_section["start"]:].strip()
sections.append(current_section)
return sections
@classmethod
def _find_placeholders(cls, content: str) -> List[str]:
"""Find placeholder patterns in content."""
placeholders = set()
for pattern in cls.PLACEHOLDER_PATTERNS:
for match in re.finditer(pattern, content):
placeholder = match.group(0)
placeholders.add(placeholder)
return sorted(list(placeholders))
@classmethod
def _detect_language(cls, content: str) -> str:
"""Detect language from content."""
german_indicators = [
'Datenschutz', 'Impressum', 'Nutzungsbedingungen', 'Haftung',
'Widerruf', 'Verantwortlicher', 'personenbezogene', 'Verarbeitung',
'und', 'der', 'die', 'das', 'ist', 'wird', 'werden', 'sind',
]
lower_content = content.lower()
german_count = sum(1 for word in german_indicators if word.lower() in lower_content)
if german_count >= 3:
return "de"
return "en"
@classmethod
def _clean_for_indexing(cls, content: str) -> str:
"""Clean markdown content for text indexing."""
content = re.sub(r'^---\s*\n.*?---\s*\n', '', content, flags=re.DOTALL)
content = re.sub(r'<!--.*?-->', '', content, flags=re.DOTALL)
content = re.sub(r'<[^>]+>', '', content)
content = re.sub(r'\*\*(.+?)\*\*', r'\1', content)
content = re.sub(r'\*(.+?)\*', r'\1', content)
content = re.sub(r'`(.+?)`', r'\1', content)
content = re.sub(r'~~(.+?)~~', r'\1', content)
content = re.sub(r'\[([^\]]+)\]\([^)]+\)', r'\1', content)
content = re.sub(r'!\[([^\]]*)\]\([^)]+\)', r'\1', content)
content = re.sub(r'\n{3,}', '\n\n', content)
content = re.sub(r' +', ' ', content)
return content.strip()
class HTMLParser:
"""Parse HTML files into structured content."""
@classmethod
def parse(cls, content: str, filename: str = "") -> ExtractedDocument:
"""Parse HTML content into an ExtractedDocument."""
title_match = re.search(r'<title>(.+?)</title>', content, re.IGNORECASE)
title = title_match.group(1) if title_match else Path(filename).stem
text = cls._html_to_text(content)
placeholders = MarkdownParser._find_placeholders(text)
lang_match = re.search(r'<html[^>]*lang=["\']([a-z]{2})["\']', content, re.IGNORECASE)
language = lang_match.group(1) if lang_match else MarkdownParser._detect_language(text)
return ExtractedDocument(
text=text,
title=title,
file_path=filename,
file_type="html",
source_url="",
placeholders=placeholders,
language=language,
)
@classmethod
def _html_to_text(cls, html: str) -> str:
"""Convert HTML to clean text."""
html = re.sub(r'<script[^>]*>.*?</script>', '', html, flags=re.DOTALL | re.IGNORECASE)
html = re.sub(r'<style[^>]*>.*?</style>', '', html, flags=re.DOTALL | re.IGNORECASE)
html = re.sub(r'<!--.*?-->', '', html, flags=re.DOTALL)
html = html.replace('&nbsp;', ' ')
html = html.replace('&amp;', '&')
html = html.replace('&lt;', '<')
html = html.replace('&gt;', '>')
html = html.replace('&quot;', '"')
html = html.replace('&apos;', "'")
html = re.sub(r'<br\s*/?>', '\n', html, flags=re.IGNORECASE)
html = re.sub(r'</p>', '\n\n', html, flags=re.IGNORECASE)
html = re.sub(r'</div>', '\n', html, flags=re.IGNORECASE)
html = re.sub(r'</h[1-6]>', '\n\n', html, flags=re.IGNORECASE)
html = re.sub(r'</li>', '\n', html, flags=re.IGNORECASE)
html = re.sub(r'<[^>]+>', '', html)
html = re.sub(r'[ \t]+', ' ', html)
html = re.sub(r'\n[ \t]+', '\n', html)
html = re.sub(r'[ \t]+\n', '\n', html)
html = re.sub(r'\n{3,}', '\n\n', html)
return html.strip()
class JSONParser:
"""Parse JSON files containing legal template data."""
@classmethod
def parse(cls, content: str, filename: str = "") -> List[ExtractedDocument]:
"""Parse JSON content into ExtractedDocuments."""
try:
data = json.loads(content)
except json.JSONDecodeError as e:
import logging
logging.getLogger(__name__).warning(f"Failed to parse JSON from {filename}: {e}")
return []
documents = []
if isinstance(data, dict):
documents.extend(cls._parse_dict(data, filename))
elif isinstance(data, list):
for i, item in enumerate(data):
if isinstance(item, dict):
docs = cls._parse_dict(item, f"{filename}[{i}]")
documents.extend(docs)
return documents
@classmethod
def _parse_dict(cls, data: dict, filename: str) -> List[ExtractedDocument]:
"""Parse a dictionary into documents."""
documents = []
text_keys = ['text', 'content', 'body', 'description', 'value']
title_keys = ['title', 'name', 'heading', 'label', 'key']
text = ""
for key in text_keys:
if key in data and isinstance(data[key], str):
text = data[key]
break
if not text:
for key, value in data.items():
if isinstance(value, dict):
nested_docs = cls._parse_dict(value, f"{filename}.{key}")
documents.extend(nested_docs)
elif isinstance(value, list):
for i, item in enumerate(value):
if isinstance(item, dict):
nested_docs = cls._parse_dict(item, f"{filename}.{key}[{i}]")
documents.extend(nested_docs)
elif isinstance(item, str) and len(item) > 50:
documents.append(ExtractedDocument(
text=item,
title=f"{key} {i+1}",
file_path=filename,
file_type="json",
source_url="",
language=MarkdownParser._detect_language(item),
))
return documents
title = ""
for key in title_keys:
if key in data and isinstance(data[key], str):
title = data[key]
break
if not title:
title = Path(filename).stem
metadata = {}
for key, value in data.items():
if key not in text_keys + title_keys and not isinstance(value, (dict, list)):
metadata[key] = value
placeholders = MarkdownParser._find_placeholders(text)
language = data.get('lang', data.get('language', MarkdownParser._detect_language(text)))
documents.append(ExtractedDocument(
text=text,
title=title,
file_path=filename,
file_type="json",
source_url="",
placeholders=placeholders,
language=language,
metadata=metadata,
))
return documents
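A quick illustration of the parser API above (not part of the commit); the sample text and the expected values in the comments are illustrative:

from crawler.github_parsers import MarkdownParser

sample = (
    "# Datenschutzerklaerung\n\n"
    "## Verantwortlicher\n"
    "Verantwortlicher ist die [COMPANY_NAME], erreichbar unter {{contact_email}}.\n"
)
doc = MarkdownParser.parse(sample, "policies/datenschutz.md")
print(doc.title)         # "Datenschutzerklaerung" (first H1)
print(doc.language)      # "de" - enough German indicator words are present
print(doc.placeholders)  # ["[COMPANY_NAME]", "{contact_email}", "{{contact_email}}"]
print(len(doc.sections)) # one entry per markdown heading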


@@ -1,35 +1,4 @@
"""
GitHub Repository Crawler — Barrel Re-export
Split into:
- github_crawler_parsers.py — ExtractedDocument, MarkdownParser, HTMLParser, JSONParser
- github_crawler_core.py — GitHubCrawler, RepositoryDownloader, crawl_source
All public names are re-exported here for backward compatibility.
"""
# Parsers
from github_crawler_parsers import ( # noqa: F401
ExtractedDocument,
MarkdownParser,
HTMLParser,
JSONParser,
)
# Crawler and downloader
from github_crawler_core import ( # noqa: F401
GITHUB_API_URL,
GITLAB_API_URL,
GITHUB_TOKEN,
MAX_FILE_SIZE,
REQUEST_TIMEOUT,
RATE_LIMIT_DELAY,
GitHubCrawler,
RepositoryDownloader,
crawl_source,
main,
)
if __name__ == "__main__":
import asyncio
asyncio.run(main())
# Backward-compat shim -- module moved to crawler/github.py
import importlib as _importlib
import sys as _sys
_sys.modules[__name__] = _importlib.import_module("crawler.github")


@@ -1,411 +1,4 @@
"""
GitHub Crawler - Core Crawler and Downloader
GitHubCrawler for API-based repository crawling and RepositoryDownloader
for ZIP-based local extraction.
Extracted from github_crawler.py to keep files under 500 LOC.
"""
import asyncio
import logging
import os
import shutil
import tempfile
import zipfile
from fnmatch import fnmatch
from pathlib import Path
from typing import Any, AsyncGenerator, Dict, List, Optional, Tuple
from urllib.parse import urlparse
import httpx
from template_sources import SourceConfig
from github_crawler_parsers import (
ExtractedDocument,
MarkdownParser,
HTMLParser,
JSONParser,
)
logger = logging.getLogger(__name__)
# Configuration
GITHUB_API_URL = "https://api.github.com"
GITLAB_API_URL = "https://gitlab.com/api/v4"
GITHUB_TOKEN = os.getenv("GITHUB_TOKEN", "")
MAX_FILE_SIZE = 1024 * 1024 # 1 MB max file size
REQUEST_TIMEOUT = 60.0
RATE_LIMIT_DELAY = 1.0
class GitHubCrawler:
"""Crawl GitHub repositories for legal templates."""
def __init__(self, token: Optional[str] = None):
self.token = token or GITHUB_TOKEN
self.headers = {
"Accept": "application/vnd.github.v3+json",
"User-Agent": "LegalTemplatesCrawler/1.0",
}
if self.token:
self.headers["Authorization"] = f"token {self.token}"
self.http_client: Optional[httpx.AsyncClient] = None
async def __aenter__(self):
self.http_client = httpx.AsyncClient(
timeout=REQUEST_TIMEOUT,
headers=self.headers,
follow_redirects=True,
)
return self
async def __aexit__(self, exc_type, exc_val, exc_tb):
if self.http_client:
await self.http_client.aclose()
def _parse_repo_url(self, url: str) -> Tuple[str, str, str]:
"""Parse repository URL into owner, repo, and host."""
parsed = urlparse(url)
path_parts = parsed.path.strip('/').split('/')
if len(path_parts) < 2:
raise ValueError(f"Invalid repository URL: {url}")
owner = path_parts[0]
repo = path_parts[1].replace('.git', '')
if 'gitlab' in parsed.netloc:
host = 'gitlab'
else:
host = 'github'
return owner, repo, host
async def get_default_branch(self, owner: str, repo: str) -> str:
"""Get the default branch of a repository."""
if not self.http_client:
raise RuntimeError("Crawler not initialized. Use 'async with' context.")
url = f"{GITHUB_API_URL}/repos/{owner}/{repo}"
response = await self.http_client.get(url)
response.raise_for_status()
data = response.json()
return data.get("default_branch", "main")
async def get_latest_commit(self, owner: str, repo: str, branch: str = "main") -> str:
"""Get the latest commit SHA for a branch."""
if not self.http_client:
raise RuntimeError("Crawler not initialized. Use 'async with' context.")
url = f"{GITHUB_API_URL}/repos/{owner}/{repo}/commits/{branch}"
response = await self.http_client.get(url)
response.raise_for_status()
data = response.json()
return data.get("sha", "")
async def list_files(
self,
owner: str,
repo: str,
path: str = "",
branch: str = "main",
patterns: List[str] = None,
exclude_patterns: List[str] = None,
) -> List[Dict[str, Any]]:
"""List files in a repository matching the given patterns."""
if not self.http_client:
raise RuntimeError("Crawler not initialized. Use 'async with' context.")
patterns = patterns or ["*.md", "*.txt", "*.html"]
exclude_patterns = exclude_patterns or []
url = f"{GITHUB_API_URL}/repos/{owner}/{repo}/git/trees/{branch}?recursive=1"
response = await self.http_client.get(url)
response.raise_for_status()
data = response.json()
files = []
for item in data.get("tree", []):
if item["type"] != "blob":
continue
file_path = item["path"]
excluded = any(fnmatch(file_path, pattern) for pattern in exclude_patterns)
if excluded:
continue
matched = any(fnmatch(file_path, pattern) for pattern in patterns)
if not matched:
continue
if item.get("size", 0) > MAX_FILE_SIZE:
logger.warning(f"Skipping large file: {file_path} ({item['size']} bytes)")
continue
files.append({
"path": file_path,
"sha": item["sha"],
"size": item.get("size", 0),
"url": item.get("url", ""),
})
return files
async def get_file_content(self, owner: str, repo: str, path: str, branch: str = "main") -> str:
"""Get the content of a file from a repository."""
if not self.http_client:
raise RuntimeError("Crawler not initialized. Use 'async with' context.")
url = f"https://raw.githubusercontent.com/{owner}/{repo}/{branch}/{path}"
response = await self.http_client.get(url)
response.raise_for_status()
return response.text
async def crawl_repository(
self,
source: SourceConfig,
) -> AsyncGenerator[ExtractedDocument, None]:
"""Crawl a repository and yield extracted documents."""
if not source.repo_url:
logger.warning(f"No repo URL for source: {source.name}")
return
try:
owner, repo, host = self._parse_repo_url(source.repo_url)
except ValueError as e:
logger.error(f"Failed to parse repo URL for {source.name}: {e}")
return
if host == "gitlab":
logger.info(f"GitLab repos not yet supported: {source.name}")
return
logger.info(f"Crawling repository: {owner}/{repo}")
try:
branch = await self.get_default_branch(owner, repo)
commit_sha = await self.get_latest_commit(owner, repo, branch)
await asyncio.sleep(RATE_LIMIT_DELAY)
files = await self.list_files(
owner, repo,
branch=branch,
patterns=source.file_patterns,
exclude_patterns=source.exclude_patterns,
)
logger.info(f"Found {len(files)} matching files in {source.name}")
for file_info in files:
await asyncio.sleep(RATE_LIMIT_DELAY)
try:
content = await self.get_file_content(
owner, repo, file_info["path"], branch
)
file_path = file_info["path"]
source_url = f"https://github.com/{owner}/{repo}/blob/{branch}/{file_path}"
if file_path.endswith('.md'):
doc = MarkdownParser.parse(content, file_path)
doc.source_url = source_url
doc.source_commit = commit_sha
yield doc
elif file_path.endswith('.html') or file_path.endswith('.htm'):
doc = HTMLParser.parse(content, file_path)
doc.source_url = source_url
doc.source_commit = commit_sha
yield doc
elif file_path.endswith('.json'):
docs = JSONParser.parse(content, file_path)
for doc in docs:
doc.source_url = source_url
doc.source_commit = commit_sha
yield doc
elif file_path.endswith('.txt'):
yield ExtractedDocument(
text=content,
title=Path(file_path).stem,
file_path=file_path,
file_type="text",
source_url=source_url,
source_commit=commit_sha,
language=MarkdownParser._detect_language(content),
placeholders=MarkdownParser._find_placeholders(content),
)
except httpx.HTTPError as e:
logger.warning(f"Failed to fetch {file_path}: {e}")
continue
except Exception as e:
logger.error(f"Error processing {file_path}: {e}")
continue
except httpx.HTTPError as e:
logger.error(f"HTTP error crawling {source.name}: {e}")
except Exception as e:
logger.error(f"Error crawling {source.name}: {e}")
class RepositoryDownloader:
"""Download and extract repository archives."""
def __init__(self):
self.http_client: Optional[httpx.AsyncClient] = None
async def __aenter__(self):
self.http_client = httpx.AsyncClient(
timeout=120.0,
follow_redirects=True,
)
return self
async def __aexit__(self, exc_type, exc_val, exc_tb):
if self.http_client:
await self.http_client.aclose()
async def download_zip(self, repo_url: str, branch: str = "main") -> Path:
"""Download repository as ZIP and extract to temp directory."""
if not self.http_client:
raise RuntimeError("Downloader not initialized. Use 'async with' context.")
parsed = urlparse(repo_url)
path_parts = parsed.path.strip('/').split('/')
owner = path_parts[0]
repo = path_parts[1].replace('.git', '')
zip_url = f"https://github.com/{owner}/{repo}/archive/refs/heads/{branch}.zip"
logger.info(f"Downloading ZIP from {zip_url}")
response = await self.http_client.get(zip_url)
response.raise_for_status()
temp_dir = Path(tempfile.mkdtemp())
zip_path = temp_dir / f"{repo}.zip"
with open(zip_path, 'wb') as f:
f.write(response.content)
extract_dir = temp_dir / repo
with zipfile.ZipFile(zip_path, 'r') as zip_ref:
zip_ref.extractall(temp_dir)
extracted_dirs = list(temp_dir.glob(f"{repo}-*"))
if extracted_dirs:
return extracted_dirs[0]
return extract_dir
async def crawl_local_directory(
self,
directory: Path,
source: SourceConfig,
base_url: str,
) -> AsyncGenerator[ExtractedDocument, None]:
"""Crawl a local directory for documents."""
patterns = source.file_patterns or ["*.md", "*.txt", "*.html"]
exclude_patterns = source.exclude_patterns or []
for pattern in patterns:
for file_path in directory.rglob(pattern.replace("**/", "")):
if not file_path.is_file():
continue
rel_path = str(file_path.relative_to(directory))
excluded = any(fnmatch(rel_path, ep) for ep in exclude_patterns)
if excluded:
continue
if file_path.stat().st_size > MAX_FILE_SIZE:
continue
try:
content = file_path.read_text(encoding='utf-8')
except UnicodeDecodeError:
try:
content = file_path.read_text(encoding='latin-1')
except Exception:
continue
source_url = f"{base_url}/{rel_path}"
if file_path.suffix == '.md':
doc = MarkdownParser.parse(content, rel_path)
doc.source_url = source_url
yield doc
elif file_path.suffix in ['.html', '.htm']:
doc = HTMLParser.parse(content, rel_path)
doc.source_url = source_url
yield doc
elif file_path.suffix == '.json':
docs = JSONParser.parse(content, rel_path)
for doc in docs:
doc.source_url = source_url
yield doc
elif file_path.suffix == '.txt':
yield ExtractedDocument(
text=content,
title=file_path.stem,
file_path=rel_path,
file_type="text",
source_url=source_url,
language=MarkdownParser._detect_language(content),
placeholders=MarkdownParser._find_placeholders(content),
)
def cleanup(self, directory: Path):
"""Clean up temporary directory."""
if directory.exists():
shutil.rmtree(directory, ignore_errors=True)
async def crawl_source(source: SourceConfig) -> List[ExtractedDocument]:
"""Crawl a source configuration and return all extracted documents."""
documents = []
if source.repo_url:
async with GitHubCrawler() as crawler:
async for doc in crawler.crawl_repository(source):
documents.append(doc)
return documents
# CLI for testing
async def main():
"""Test crawler with a sample source."""
from template_sources import TEMPLATE_SOURCES
source = next(s for s in TEMPLATE_SOURCES if s.name == "github-site-policy")
async with GitHubCrawler() as crawler:
count = 0
async for doc in crawler.crawl_repository(source):
count += 1
print(f"\n{'='*60}")
print(f"Title: {doc.title}")
print(f"Path: {doc.file_path}")
print(f"Type: {doc.file_type}")
print(f"Language: {doc.language}")
print(f"URL: {doc.source_url}")
print(f"Placeholders: {doc.placeholders[:5] if doc.placeholders else 'None'}")
print(f"Text preview: {doc.text[:200]}...")
print(f"\n\nTotal documents: {count}")
if __name__ == "__main__":
asyncio.run(main())
# Backward-compat shim -- module moved to crawler/github_core.py
import importlib as _importlib
import sys as _sys
_sys.modules[__name__] = _importlib.import_module("crawler.github_core")


@@ -1,303 +1,4 @@
"""
GitHub Crawler - Document Parsers
Markdown, HTML, and JSON parsers for extracting structured content
from legal template documents.
Extracted from github_crawler.py to keep files under 500 LOC.
"""
import hashlib
import json
import re
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Dict, List, Optional
@dataclass
class ExtractedDocument:
"""A document extracted from a repository."""
text: str
title: str
file_path: str
file_type: str # "markdown", "html", "json", "text"
source_url: str
source_commit: Optional[str] = None
source_hash: str = "" # SHA256 of original content
sections: List[Dict[str, Any]] = field(default_factory=list)
placeholders: List[str] = field(default_factory=list)
language: str = "en"
metadata: Dict[str, Any] = field(default_factory=dict)
def __post_init__(self):
if not self.source_hash and self.text:
self.source_hash = hashlib.sha256(self.text.encode()).hexdigest()
class MarkdownParser:
"""Parse Markdown files into structured content."""
# Common placeholder patterns
PLACEHOLDER_PATTERNS = [
r'\[([A-Z_]+)\]', # [COMPANY_NAME]
r'\{([a-z_]+)\}', # {company_name}
r'\{\{([a-z_]+)\}\}', # {{company_name}}
r'__([A-Z_]+)__', # __COMPANY_NAME__
r'<([A-Z_]+)>', # <COMPANY_NAME>
]
@classmethod
def parse(cls, content: str, filename: str = "") -> ExtractedDocument:
"""Parse markdown content into an ExtractedDocument."""
title = cls._extract_title(content, filename)
sections = cls._extract_sections(content)
placeholders = cls._find_placeholders(content)
language = cls._detect_language(content)
clean_text = cls._clean_for_indexing(content)
return ExtractedDocument(
text=clean_text,
title=title,
file_path=filename,
file_type="markdown",
source_url="",
sections=sections,
placeholders=placeholders,
language=language,
)
@classmethod
def _extract_title(cls, content: str, filename: str) -> str:
"""Extract title from markdown heading or filename."""
h1_match = re.search(r'^#\s+(.+)$', content, re.MULTILINE)
if h1_match:
return h1_match.group(1).strip()
frontmatter_match = re.search(
r'^---\s*\n.*?title:\s*["\']?(.+?)["\']?\s*\n.*?---',
content, re.DOTALL
)
if frontmatter_match:
return frontmatter_match.group(1).strip()
if filename:
name = Path(filename).stem
return name.replace('-', ' ').replace('_', ' ').title()
return "Untitled"
@classmethod
def _extract_sections(cls, content: str) -> List[Dict[str, Any]]:
"""Extract sections from markdown content."""
sections = []
current_section = {"heading": "", "level": 0, "content": "", "start": 0}
for match in re.finditer(r'^(#{1,6})\s+(.+)$', content, re.MULTILINE):
if current_section["heading"] or current_section["content"].strip():
current_section["content"] = current_section["content"].strip()
sections.append(current_section.copy())
level = len(match.group(1))
heading = match.group(2).strip()
current_section = {
"heading": heading,
"level": level,
"content": "",
"start": match.end(),
}
if current_section["heading"] or current_section["content"].strip():
current_section["content"] = content[current_section["start"]:].strip()
sections.append(current_section)
return sections
@classmethod
def _find_placeholders(cls, content: str) -> List[str]:
"""Find placeholder patterns in content."""
placeholders = set()
for pattern in cls.PLACEHOLDER_PATTERNS:
for match in re.finditer(pattern, content):
placeholder = match.group(0)
placeholders.add(placeholder)
return sorted(list(placeholders))
@classmethod
def _detect_language(cls, content: str) -> str:
"""Detect language from content."""
german_indicators = [
'Datenschutz', 'Impressum', 'Nutzungsbedingungen', 'Haftung',
'Widerruf', 'Verantwortlicher', 'personenbezogene', 'Verarbeitung',
'und', 'der', 'die', 'das', 'ist', 'wird', 'werden', 'sind',
]
lower_content = content.lower()
german_count = sum(1 for word in german_indicators if word.lower() in lower_content)
if german_count >= 3:
return "de"
return "en"
@classmethod
def _clean_for_indexing(cls, content: str) -> str:
"""Clean markdown content for text indexing."""
content = re.sub(r'^---\s*\n.*?---\s*\n', '', content, flags=re.DOTALL)
content = re.sub(r'<!--.*?-->', '', content, flags=re.DOTALL)
content = re.sub(r'<[^>]+>', '', content)
content = re.sub(r'\*\*(.+?)\*\*', r'\1', content)
content = re.sub(r'\*(.+?)\*', r'\1', content)
content = re.sub(r'`(.+?)`', r'\1', content)
content = re.sub(r'~~(.+?)~~', r'\1', content)
content = re.sub(r'\[([^\]]+)\]\([^)]+\)', r'\1', content)
content = re.sub(r'!\[([^\]]*)\]\([^)]+\)', r'\1', content)
content = re.sub(r'\n{3,}', '\n\n', content)
content = re.sub(r' +', ' ', content)
return content.strip()
class HTMLParser:
"""Parse HTML files into structured content."""
@classmethod
def parse(cls, content: str, filename: str = "") -> ExtractedDocument:
"""Parse HTML content into an ExtractedDocument."""
title_match = re.search(r'<title>(.+?)</title>', content, re.IGNORECASE)
title = title_match.group(1) if title_match else Path(filename).stem
text = cls._html_to_text(content)
placeholders = MarkdownParser._find_placeholders(text)
lang_match = re.search(r'<html[^>]*lang=["\']([a-z]{2})["\']', content, re.IGNORECASE)
language = lang_match.group(1) if lang_match else MarkdownParser._detect_language(text)
return ExtractedDocument(
text=text,
title=title,
file_path=filename,
file_type="html",
source_url="",
placeholders=placeholders,
language=language,
)
@classmethod
def _html_to_text(cls, html: str) -> str:
"""Convert HTML to clean text."""
html = re.sub(r'<script[^>]*>.*?</script>', '', html, flags=re.DOTALL | re.IGNORECASE)
html = re.sub(r'<style[^>]*>.*?</style>', '', html, flags=re.DOTALL | re.IGNORECASE)
html = re.sub(r'<!--.*?-->', '', html, flags=re.DOTALL)
html = html.replace('&nbsp;', ' ')
html = html.replace('&amp;', '&')
html = html.replace('&lt;', '<')
html = html.replace('&gt;', '>')
html = html.replace('&quot;', '"')
html = html.replace('&apos;', "'")
html = re.sub(r'<br\s*/?>', '\n', html, flags=re.IGNORECASE)
html = re.sub(r'</p>', '\n\n', html, flags=re.IGNORECASE)
html = re.sub(r'</div>', '\n', html, flags=re.IGNORECASE)
html = re.sub(r'</h[1-6]>', '\n\n', html, flags=re.IGNORECASE)
html = re.sub(r'</li>', '\n', html, flags=re.IGNORECASE)
html = re.sub(r'<[^>]+>', '', html)
html = re.sub(r'[ \t]+', ' ', html)
html = re.sub(r'\n[ \t]+', '\n', html)
html = re.sub(r'[ \t]+\n', '\n', html)
html = re.sub(r'\n{3,}', '\n\n', html)
return html.strip()
class JSONParser:
"""Parse JSON files containing legal template data."""
@classmethod
def parse(cls, content: str, filename: str = "") -> List[ExtractedDocument]:
"""Parse JSON content into ExtractedDocuments."""
try:
data = json.loads(content)
except json.JSONDecodeError as e:
import logging
logging.getLogger(__name__).warning(f"Failed to parse JSON from {filename}: {e}")
return []
documents = []
if isinstance(data, dict):
documents.extend(cls._parse_dict(data, filename))
elif isinstance(data, list):
for i, item in enumerate(data):
if isinstance(item, dict):
docs = cls._parse_dict(item, f"{filename}[{i}]")
documents.extend(docs)
return documents
@classmethod
def _parse_dict(cls, data: dict, filename: str) -> List[ExtractedDocument]:
"""Parse a dictionary into documents."""
documents = []
text_keys = ['text', 'content', 'body', 'description', 'value']
title_keys = ['title', 'name', 'heading', 'label', 'key']
text = ""
for key in text_keys:
if key in data and isinstance(data[key], str):
text = data[key]
break
if not text:
for key, value in data.items():
if isinstance(value, dict):
nested_docs = cls._parse_dict(value, f"{filename}.{key}")
documents.extend(nested_docs)
elif isinstance(value, list):
for i, item in enumerate(value):
if isinstance(item, dict):
nested_docs = cls._parse_dict(item, f"{filename}.{key}[{i}]")
documents.extend(nested_docs)
elif isinstance(item, str) and len(item) > 50:
documents.append(ExtractedDocument(
text=item,
title=f"{key} {i+1}",
file_path=filename,
file_type="json",
source_url="",
language=MarkdownParser._detect_language(item),
))
return documents
title = ""
for key in title_keys:
if key in data and isinstance(data[key], str):
title = data[key]
break
if not title:
title = Path(filename).stem
metadata = {}
for key, value in data.items():
if key not in text_keys + title_keys and not isinstance(value, (dict, list)):
metadata[key] = value
placeholders = MarkdownParser._find_placeholders(text)
language = data.get('lang', data.get('language', MarkdownParser._detect_language(text)))
documents.append(ExtractedDocument(
text=text,
title=title,
file_path=filename,
file_type="json",
source_url="",
placeholders=placeholders,
language=language,
metadata=metadata,
))
return documents
# Backward-compat shim -- module moved to crawler/github_parsers.py
import importlib as _importlib
import sys as _sys
_sys.modules[__name__] = _importlib.import_module("crawler.github_parsers")


@@ -0,0 +1,425 @@
"""
Unified Grid Builder — merges multi-zone grid into a single Excel-like grid.
Takes content zone + box zones and produces one unified zone where:
- All content rows use the dominant row height
- Full-width boxes are integrated directly (box rows replace standard rows)
- Partial-width boxes: extra rows inserted if box has more lines than standard
- Box-origin cells carry metadata (bg_color, border) for visual distinction
The result is a single-zone StructuredGrid that can be:
- Rendered in an Excel-like editor
- Exported to Excel/CSV
- Edited with unified row/column numbering
"""
import logging
import math
import statistics
from typing import Any, Dict, List, Optional, Tuple
logger = logging.getLogger(__name__)
def _compute_dominant_row_height(content_zone: Dict) -> float:
"""Median of content row-to-row spacings, excluding box-gap jumps."""
rows = content_zone.get("rows", [])
if len(rows) < 2:
return 47.0
spacings = []
for i in range(len(rows) - 1):
y1 = rows[i].get("y_min_px", rows[i].get("y_min", 0))
y2 = rows[i + 1].get("y_min_px", rows[i + 1].get("y_min", 0))
d = y2 - y1
if 0 < d < 100: # exclude box-gap jumps
spacings.append(d)
if not spacings:
return 47.0
spacings.sort()
return spacings[len(spacings) // 2]
def _classify_boxes(
box_zones: List[Dict],
content_width: float,
) -> List[Dict]:
"""Classify each box as full_width or partial_width."""
result = []
for bz in box_zones:
bb = bz.get("bbox_px", {})
bw = bb.get("w", 0)
bx = bb.get("x", 0)
if bw >= content_width * 0.85:
classification = "full_width"
side = "center"
else:
classification = "partial_width"
# Determine which side of the page the box is on
page_center = content_width / 2
box_center = bx + bw / 2
side = "right" if box_center > page_center else "left"
# Count total text lines in box (including \n within cells)
total_lines = sum(
(c.get("text", "").count("\n") + 1)
for c in bz.get("cells", [])
)
result.append({
"zone": bz,
"classification": classification,
"side": side,
"y_start": bb.get("y", 0),
"y_end": bb.get("y", 0) + bb.get("h", 0),
"total_lines": total_lines,
"bg_hex": bz.get("box_bg_hex", ""),
"bg_color": bz.get("box_bg_color", ""),
})
return result
def build_unified_grid(
zones: List[Dict],
image_width: int,
image_height: int,
layout_metrics: Dict,
) -> Dict[str, Any]:
"""Build a single-zone unified grid from multi-zone grid data.
Returns a StructuredGrid with one zone containing all rows and cells.
"""
content_zone = None
box_zones = []
for z in zones:
if z.get("zone_type") == "content":
content_zone = z
elif z.get("zone_type") == "box":
box_zones.append(z)
if not content_zone:
logger.warning("build_unified_grid: no content zone found")
return {"zones": zones} # fallback: return as-is
box_zones.sort(key=lambda b: b.get("bbox_px", {}).get("y", 0))
dominant_h = _compute_dominant_row_height(content_zone)
content_bbox = content_zone.get("bbox_px", {})
content_width = content_bbox.get("w", image_width)
content_x = content_bbox.get("x", 0)
content_cols = content_zone.get("columns", [])
num_cols = len(content_cols)
box_infos = _classify_boxes(box_zones, content_width)
logger.info(
"build_unified_grid: dominant_h=%.1f, %d content rows, %d boxes (%s)",
dominant_h, len(content_zone.get("rows", [])), len(box_infos),
[b["classification"] for b in box_infos],
)
# --- Build unified row list + cell list ---
unified_rows: List[Dict] = []
unified_cells: List[Dict] = []
unified_row_idx = 0
# Content rows and cells indexed by row_index
content_rows = content_zone.get("rows", [])
content_cells = content_zone.get("cells", [])
content_cells_by_row: Dict[int, List[Dict]] = {}
for c in content_cells:
content_cells_by_row.setdefault(c.get("row_index", -1), []).append(c)
# Track which content rows we've processed
content_row_ptr = 0
for bi, box_info in enumerate(box_infos):
bz = box_info["zone"]
by_start = box_info["y_start"]
by_end = box_info["y_end"]
# --- Add content rows ABOVE this box ---
while content_row_ptr < len(content_rows):
cr = content_rows[content_row_ptr]
cry = cr.get("y_min_px", cr.get("y_min", 0))
if cry >= by_start:
break
# Add this content row
_add_content_row(
unified_rows, unified_cells, unified_row_idx,
cr, content_cells_by_row, dominant_h, image_height,
)
unified_row_idx += 1
content_row_ptr += 1
# --- Add box rows ---
if box_info["classification"] == "full_width":
# Full-width box: integrate box rows directly
_add_full_width_box(
unified_rows, unified_cells, unified_row_idx,
bz, box_info, dominant_h, num_cols, image_height,
)
unified_row_idx += len(bz.get("rows", []))
# Skip content rows that overlap with this box
while content_row_ptr < len(content_rows):
cr = content_rows[content_row_ptr]
cry = cr.get("y_min_px", cr.get("y_min", 0))
if cry > by_end:
break
content_row_ptr += 1
else:
# Partial-width box: merge with adjacent content rows
unified_row_idx = _add_partial_width_box(
unified_rows, unified_cells, unified_row_idx,
bz, box_info, content_rows, content_cells_by_row,
content_row_ptr, dominant_h, num_cols, image_height,
content_x, content_width,
)
# Advance content pointer past box region
while content_row_ptr < len(content_rows):
cr = content_rows[content_row_ptr]
cry = cr.get("y_min_px", cr.get("y_min", 0))
if cry > by_end:
break
content_row_ptr += 1
# --- Add remaining content rows BELOW all boxes ---
while content_row_ptr < len(content_rows):
cr = content_rows[content_row_ptr]
_add_content_row(
unified_rows, unified_cells, unified_row_idx,
cr, content_cells_by_row, dominant_h, image_height,
)
unified_row_idx += 1
content_row_ptr += 1
# --- Build unified zone ---
unified_zone = {
"zone_index": 0,
"zone_type": "unified",
"bbox_px": content_bbox,
"bbox_pct": content_zone.get("bbox_pct", {}),
"border": None,
"word_count": sum(len(c.get("word_boxes", [])) for c in unified_cells),
"columns": content_cols,
"rows": unified_rows,
"cells": unified_cells,
"header_rows": [],
}
logger.info(
"build_unified_grid: %d unified rows, %d cells (from %d content + %d box zones)",
len(unified_rows), len(unified_cells),
len(content_rows), len(box_zones),
)
return {
"zones": [unified_zone],
"image_width": image_width,
"image_height": image_height,
"layout_metrics": layout_metrics,
"summary": {
"total_zones": 1,
"total_columns": num_cols,
"total_rows": len(unified_rows),
"total_cells": len(unified_cells),
},
"is_unified": True,
"dominant_row_h": dominant_h,
}
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _make_row(idx: int, y: float, h: float, img_h: int, is_header: bool = False) -> Dict:
return {
"index": idx,
"row_index": idx,
"y_min_px": round(y),
"y_max_px": round(y + h),
"y_min_pct": round(y / img_h * 100, 2) if img_h else 0,
"y_max_pct": round((y + h) / img_h * 100, 2) if img_h else 0,
"is_header": is_header,
}
def _remap_cell(cell: Dict, new_row: int, new_col: Optional[int] = None,
source_type: str = "content", box_region: Optional[Dict] = None) -> Dict:
"""Create a new cell dict with remapped indices."""
c = dict(cell)
c["row_index"] = new_row
if new_col is not None:
c["col_index"] = new_col
c["cell_id"] = f"U_R{new_row:02d}_C{c.get('col_index', 0)}"
c["source_zone_type"] = source_type
if box_region:
c["box_region"] = box_region
return c
def _add_content_row(
unified_rows, unified_cells, row_idx,
content_row, cells_by_row, dominant_h, img_h,
):
"""Add a single content row to the unified grid."""
y = content_row.get("y_min_px", content_row.get("y_min", 0))
is_hdr = content_row.get("is_header", False)
unified_rows.append(_make_row(row_idx, y, dominant_h, img_h, is_hdr))
for cell in cells_by_row.get(content_row.get("index", -1), []):
unified_cells.append(_remap_cell(cell, row_idx, source_type="content"))
def _add_full_width_box(
unified_rows, unified_cells, start_row_idx,
box_zone, box_info, dominant_h, num_cols, img_h,
):
"""Add a full-width box's rows to the unified grid."""
box_rows = box_zone.get("rows", [])
box_cells = box_zone.get("cells", [])
box_region = {"bg_hex": box_info["bg_hex"], "bg_color": box_info["bg_color"], "border": True}
# Distribute box height evenly among its rows
box_h = box_info["y_end"] - box_info["y_start"]
row_h = box_h / len(box_rows) if box_rows else dominant_h
for i, br in enumerate(box_rows):
y = box_info["y_start"] + i * row_h
new_idx = start_row_idx + i
is_hdr = br.get("is_header", False)
unified_rows.append(_make_row(new_idx, y, row_h, img_h, is_hdr))
for cell in box_cells:
if cell.get("row_index") == br.get("index", i):
unified_cells.append(
_remap_cell(cell, new_idx, source_type="box", box_region=box_region)
)
def _add_partial_width_box(
unified_rows, unified_cells, start_row_idx,
box_zone, box_info, content_rows, content_cells_by_row,
content_row_ptr, dominant_h, num_cols, img_h,
content_x, content_width,
) -> int:
"""Add a partial-width box merged with content rows.
Returns the next unified_row_idx after processing.
"""
by_start = box_info["y_start"]
by_end = box_info["y_end"]
box_h = by_end - by_start
box_region = {"bg_hex": box_info["bg_hex"], "bg_color": box_info["bg_color"], "border": True}
# Content rows in the box's Y range
overlap_content_rows = []
ptr = content_row_ptr
while ptr < len(content_rows):
cr = content_rows[ptr]
cry = cr.get("y_min_px", cr.get("y_min", 0))
if cry > by_end:
break
if cry >= by_start:
overlap_content_rows.append(cr)
ptr += 1
# How many standard rows fit in the box height
standard_rows = max(1, math.floor(box_h / dominant_h))
# How many text lines the box actually has
box_text_lines = box_info["total_lines"]
# Extra rows needed
extra_rows = max(0, box_text_lines - standard_rows)
total_rows_for_region = standard_rows + extra_rows
logger.info(
"partial box: standard=%d, box_lines=%d, extra=%d, content_overlap=%d",
standard_rows, box_text_lines, extra_rows, len(overlap_content_rows),
)
# Determine which columns the box occupies. Simple heuristic: a box on the
# right side takes the last half of the columns, a box on the left side the
# first half. (A more precise mapping would use the content columns' x-ranges.)
if box_info["side"] == "right":
box_col_start = num_cols // 2
box_col_end = num_cols
else:
box_col_start = 0
box_col_end = num_cols // 2
# Build rows for this region
box_cells = box_zone.get("cells", [])
box_rows = box_zone.get("rows", [])
row_idx = start_row_idx
# Expand box cell texts with \n into individual lines for row mapping
box_lines: List[Tuple[str, Dict]] = [] # (text_line, parent_cell)
for bc in sorted(box_cells, key=lambda c: c.get("row_index", 0)):
text = bc.get("text", "")
for line in text.split("\n"):
box_lines.append((line.strip(), bc))
for i in range(total_rows_for_region):
y = by_start + i * dominant_h
unified_rows.append(_make_row(row_idx, y, dominant_h, img_h))
# Content cells for this row (from overlapping content rows)
if i < len(overlap_content_rows):
cr = overlap_content_rows[i]
for cell in content_cells_by_row.get(cr.get("index", -1), []):
# Only include cells from columns NOT covered by the box
ci = cell.get("col_index", 0)
if ci < box_col_start or ci >= box_col_end:
unified_cells.append(_remap_cell(cell, row_idx, source_type="content"))
# Box cell for this row
if i < len(box_lines):
line_text, parent_cell = box_lines[i]
box_cell = {
"cell_id": f"U_R{row_idx:02d}_C{box_col_start}",
"row_index": row_idx,
"col_index": box_col_start,
"col_type": "spanning_header" if (box_col_end - box_col_start) > 1 else parent_cell.get("col_type", "column_1"),
"colspan": box_col_end - box_col_start,
"text": line_text,
"confidence": parent_cell.get("confidence", 0),
"bbox_px": parent_cell.get("bbox_px", {}),
"bbox_pct": parent_cell.get("bbox_pct", {}),
"word_boxes": [],
"ocr_engine": parent_cell.get("ocr_engine", ""),
"is_bold": parent_cell.get("is_bold", False),
"source_zone_type": "box",
"box_region": box_region,
}
unified_cells.append(box_cell)
row_idx += 1
return row_idx
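A minimal sketch of driving the unified-grid builder (illustration only, not part of the commit). The zone dicts are trimmed to the keys the function actually reads, and the import path assumes the module landed at grid/unified_grid.py as the commit message states:

from grid.unified_grid import build_unified_grid

content_zone = {
    "zone_type": "content",
    "bbox_px": {"x": 0, "y": 0, "w": 1000, "h": 200},
    "columns": [{"index": 0}, {"index": 1}],
    "rows": [
        {"index": 0, "y_min_px": 10, "is_header": True},
        {"index": 1, "y_min_px": 57},
    ],
    "cells": [
        {"row_index": 0, "col_index": 0, "text": "Aufgabe", "word_boxes": []},
        {"row_index": 1, "col_index": 0, "text": "1a", "word_boxes": []},
    ],
}
grid = build_unified_grid([content_zone], image_width=1000, image_height=1400, layout_metrics={})
print(grid["summary"])         # {'total_zones': 1, 'total_columns': 2, 'total_rows': 2, 'total_cells': 2}
print(grid["dominant_row_h"])  # 47 - the single row spacing in this toy input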


@@ -1,276 +1,4 @@
"""
Handwriting HTR API - high-quality handwritten text recognition (HTR) for exam correction.
Endpoints:
- POST /api/v1/htr/recognize - upload an image → handwritten text
- POST /api/v1/htr/recognize-session - use an OCR-pipeline session as the source
Model strategy:
1. qwen2.5vl:32b via Ollama (primary, highest quality, as a VLM)
2. microsoft/trocr-large-handwritten (fallback, offline, no Ollama required)
DATA PRIVACY: all processing happens locally on the Mac Mini.
"""
import io
import os
import logging
import time
import base64
from typing import Optional
import cv2
import numpy as np
from fastapi import APIRouter, HTTPException, Query, UploadFile, File
from pydantic import BaseModel
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/api/v1/htr", tags=["HTR"])
OLLAMA_BASE_URL = os.getenv("OLLAMA_BASE_URL", "http://host.docker.internal:11434")
OLLAMA_HTR_MODEL = os.getenv("OLLAMA_HTR_MODEL", "qwen2.5vl:32b")
HTR_FALLBACK_MODEL = os.getenv("HTR_FALLBACK_MODEL", "trocr-large")
# ---------------------------------------------------------------------------
# Pydantic Models
# ---------------------------------------------------------------------------
class HTRSessionRequest(BaseModel):
session_id: str
model: str = "auto" # "auto" | "qwen2.5vl" | "trocr-large"
use_clean: bool = True # Prefer clean_png (after handwriting removal)
# ---------------------------------------------------------------------------
# Preprocessing
# ---------------------------------------------------------------------------
def _preprocess_for_htr(img_bgr: np.ndarray) -> np.ndarray:
"""
CLAHE contrast enhancement + upscale to improve HTR accuracy.
Returns grayscale enhanced image.
"""
gray = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2GRAY)
clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
enhanced = clahe.apply(gray)
# Upscale if image is too small
h, w = enhanced.shape
if min(h, w) < 800:
scale = 800 / min(h, w)
enhanced = cv2.resize(
enhanced, None, fx=scale, fy=scale,
interpolation=cv2.INTER_CUBIC
)
return enhanced
def _bgr_to_png_bytes(img_bgr: np.ndarray) -> bytes:
"""Convert BGR ndarray to PNG bytes."""
success, buf = cv2.imencode(".png", img_bgr)
if not success:
raise RuntimeError("Failed to encode image to PNG")
return buf.tobytes()
def _preprocess_image_bytes(image_bytes: bytes) -> bytes:
"""Load image, apply HTR preprocessing, return PNG bytes."""
arr = np.frombuffer(image_bytes, dtype=np.uint8)
img_bgr = cv2.imdecode(arr, cv2.IMREAD_COLOR)
if img_bgr is None:
raise ValueError("Could not decode image")
enhanced = _preprocess_for_htr(img_bgr)
# Convert grayscale back to BGR for encoding
enhanced_bgr = cv2.cvtColor(enhanced, cv2.COLOR_GRAY2BGR)
return _bgr_to_png_bytes(enhanced_bgr)
# ---------------------------------------------------------------------------
# Backend: Ollama qwen2.5vl
# ---------------------------------------------------------------------------
async def _recognize_with_qwen_vl(image_bytes: bytes, language: str) -> Optional[str]:
"""
Send image to Ollama qwen2.5vl:32b for HTR.
Returns extracted text or None on error.
"""
import httpx
lang_hint = {
"de": "Deutsch",
"en": "Englisch",
"de+en": "Deutsch und Englisch",
}.get(language, "Deutsch")
prompt = (
f"Du bist ein OCR-Experte fuer handgeschriebenen Text auf {lang_hint}. "
"Lies den Text im Bild exakt ab — korrigiere KEINE Rechtschreibfehler. "
"Antworte NUR mit dem erkannten Text, ohne Erklaerungen."
)
img_b64 = base64.b64encode(image_bytes).decode("utf-8")
payload = {
"model": OLLAMA_HTR_MODEL,
"prompt": prompt,
"images": [img_b64],
"stream": False,
}
try:
async with httpx.AsyncClient(timeout=120.0) as client:
resp = await client.post(f"{OLLAMA_BASE_URL}/api/generate", json=payload)
resp.raise_for_status()
data = resp.json()
return data.get("response", "").strip()
except Exception as e:
logger.warning(f"Ollama qwen2.5vl HTR failed: {e}")
return None
# ---------------------------------------------------------------------------
# Backend: TrOCR-large fallback
# ---------------------------------------------------------------------------
async def _recognize_with_trocr_large(image_bytes: bytes) -> Optional[str]:
"""
Use microsoft/trocr-large-handwritten via trocr_service.py.
Returns extracted text or None on error.
"""
try:
from services.trocr_service import run_trocr_ocr, _check_trocr_available
if not _check_trocr_available():
logger.warning("TrOCR not available for HTR fallback")
return None
text, confidence = await run_trocr_ocr(image_bytes, handwritten=True, size="large")
return text.strip() if text else None
except Exception as e:
logger.warning(f"TrOCR-large HTR failed: {e}")
return None
# ---------------------------------------------------------------------------
# Core recognition logic
# ---------------------------------------------------------------------------
async def _do_recognize(
image_bytes: bytes,
model: str = "auto",
preprocess: bool = True,
language: str = "de",
) -> dict:
"""
Core HTR logic: preprocess → try Ollama → fallback to TrOCR-large.
Returns dict with text, model_used, processing_time_ms.
"""
t0 = time.monotonic()
if preprocess:
try:
image_bytes = _preprocess_image_bytes(image_bytes)
except Exception as e:
logger.warning(f"HTR preprocessing failed, using raw image: {e}")
text: Optional[str] = None
model_used: str = "none"
use_qwen = model in ("auto", "qwen2.5vl")
use_trocr = model in ("auto", "trocr-large") or (use_qwen and text is None)
if use_qwen:
text = await _recognize_with_qwen_vl(image_bytes, language)
if text is not None:
model_used = f"qwen2.5vl ({OLLAMA_HTR_MODEL})"
if text is None and (use_trocr or model == "trocr-large"):
text = await _recognize_with_trocr_large(image_bytes)
if text is not None:
model_used = "trocr-large-handwritten"
if text is None:
text = ""
model_used = "none (all backends failed)"
elapsed_ms = int((time.monotonic() - t0) * 1000)
return {
"text": text,
"model_used": model_used,
"processing_time_ms": elapsed_ms,
"language": language,
"preprocessed": preprocess,
}
# ---------------------------------------------------------------------------
# Endpoints
# ---------------------------------------------------------------------------
@router.post("/recognize")
async def recognize_handwriting(
file: UploadFile = File(...),
model: str = Query("auto", description="auto | qwen2.5vl | trocr-large"),
preprocess: bool = Query(True, description="Apply CLAHE + upscale before recognition"),
language: str = Query("de", description="de | en | de+en"),
):
"""
Upload an image and get back the handwritten text as plain text.
Tries qwen2.5vl:32b via Ollama first, falls back to TrOCR-large-handwritten.
"""
if model not in ("auto", "qwen2.5vl", "trocr-large"):
raise HTTPException(status_code=400, detail="model must be one of: auto, qwen2.5vl, trocr-large")
if language not in ("de", "en", "de+en"):
raise HTTPException(status_code=400, detail="language must be one of: de, en, de+en")
image_bytes = await file.read()
if not image_bytes:
raise HTTPException(status_code=400, detail="Empty file")
return await _do_recognize(image_bytes, model=model, preprocess=preprocess, language=language)
@router.post("/recognize-session")
async def recognize_from_session(req: HTRSessionRequest):
"""
Use an OCR-Pipeline session as image source for HTR.
Set use_clean=true to prefer the clean image (after handwriting removal step).
This is useful when you want to do HTR on isolated handwriting regions.
"""
from ocr_pipeline_session_store import get_session_db, get_session_image
session = await get_session_db(req.session_id)
if not session:
raise HTTPException(status_code=404, detail=f"Session {req.session_id} not found")
# Choose source image
image_bytes: Optional[bytes] = None
source_used: str = ""
if req.use_clean:
image_bytes = await get_session_image(req.session_id, "clean")
if image_bytes:
source_used = "clean"
if not image_bytes:
image_bytes = await get_session_image(req.session_id, "deskewed")
if image_bytes:
source_used = "deskewed"
if not image_bytes:
image_bytes = await get_session_image(req.session_id, "original")
source_used = "original"
if not image_bytes:
raise HTTPException(status_code=404, detail="No image available in session")
result = await _do_recognize(image_bytes, model=req.model)
result["session_id"] = req.session_id
result["source_image"] = source_used
return result
# Backward-compat shim -- module moved to ocr/pipeline/htr_api.py
import importlib as _importlib
import sys as _sys
_sys.modules[__name__] = _importlib.import_module("ocr.pipeline.htr_api")
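
As a usage illustration (not part of this commit), a minimal async client for the /api/v1/htr/recognize route defined above. The base URL, timeout, and image path are assumptions; the query parameters and response keys mirror the endpoint:

import asyncio
import httpx

async def recognize_example() -> None:
    # Placeholder host/port and file name -- adjust to the actual deployment.
    url = "http://localhost:8000/api/v1/htr/recognize"
    with open("exam_page.png", "rb") as f:
        files = {"file": ("exam_page.png", f, "image/png")}
        params = {"model": "auto", "preprocess": True, "language": "de"}
        async with httpx.AsyncClient(timeout=180.0) as client:
            resp = await client.post(url, params=params, files=files)
    resp.raise_for_status()
    data = resp.json()
    print(data["model_used"], data["processing_time_ms"], "ms")
    print(data["text"])

asyncio.run(recognize_example())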

View File

@@ -0,0 +1,346 @@
"""
Tesseract-based OCR extraction with word-level bounding boxes.
Uses Tesseract for spatial information (WHERE text is) while
the Vision LLM handles semantic understanding (WHAT the text means).
Tesseract runs natively on ARM64 via Debian's apt package.
Lizenz: Apache 2.0 (kommerziell nutzbar)
"""
import io
import logging
from typing import List, Dict, Any, Optional
from difflib import SequenceMatcher
logger = logging.getLogger(__name__)
try:
import pytesseract
from PIL import Image
TESSERACT_AVAILABLE = True
except ImportError:
TESSERACT_AVAILABLE = False
logger.warning("pytesseract or Pillow not installed - Tesseract OCR unavailable")
async def extract_bounding_boxes(image_bytes: bytes, lang: str = "eng+deu") -> dict:
"""Run Tesseract OCR and return word-level bounding boxes.
Args:
image_bytes: PNG/JPEG image as bytes.
lang: Tesseract language string (e.g. "eng+deu").
Returns:
Dict with 'words' list and 'image_width'/'image_height'.
"""
if not TESSERACT_AVAILABLE:
return {"words": [], "image_width": 0, "image_height": 0, "error": "Tesseract not available"}
image = Image.open(io.BytesIO(image_bytes))
data = pytesseract.image_to_data(image, lang=lang, output_type=pytesseract.Output.DICT)
words = []
for i in range(len(data['text'])):
text = data['text'][i].strip()
conf = int(data['conf'][i])
if not text or conf < 20:
continue
words.append({
"text": text,
"left": data['left'][i],
"top": data['top'][i],
"width": data['width'][i],
"height": data['height'][i],
"conf": conf,
"block_num": data['block_num'][i],
"par_num": data['par_num'][i],
"line_num": data['line_num'][i],
"word_num": data['word_num'][i],
})
return {
"words": words,
"image_width": image.width,
"image_height": image.height,
}
def group_words_into_lines(words: List[dict], y_tolerance_px: int = 15) -> List[List[dict]]:
"""Group words by their Y position into lines.
Args:
words: List of word dicts from extract_bounding_boxes.
y_tolerance_px: Max pixel distance to consider words on the same line.
Returns:
List of lines, each line is a list of words sorted by X position.
"""
if not words:
return []
# Sort by Y then X
sorted_words = sorted(words, key=lambda w: (w['top'], w['left']))
lines: List[List[dict]] = []
current_line: List[dict] = [sorted_words[0]]
current_y = sorted_words[0]['top']
for word in sorted_words[1:]:
if abs(word['top'] - current_y) <= y_tolerance_px:
current_line.append(word)
else:
current_line.sort(key=lambda w: w['left'])
lines.append(current_line)
current_line = [word]
current_y = word['top']
if current_line:
current_line.sort(key=lambda w: w['left'])
lines.append(current_line)
return lines
def detect_columns(lines: List[List[dict]], image_width: int) -> Dict[str, Any]:
"""Detect column boundaries from word positions.
Typical vocab table: Left=English, Middle=German, Right=Example sentences.
Returns:
Dict with column boundaries and type assignments.
"""
if not lines or image_width == 0:
return {"columns": [], "column_types": []}
# Collect all word X positions
all_x_positions = []
for line in lines:
for word in line:
all_x_positions.append(word['left'])
if not all_x_positions:
return {"columns": [], "column_types": []}
# Find X-position clusters (column starts)
all_x_positions.sort()
# Simple gap-based column detection
min_gap = image_width * 0.08 # 8% of page width = column gap
clusters = []
current_cluster = [all_x_positions[0]]
for x in all_x_positions[1:]:
if x - current_cluster[-1] > min_gap:
clusters.append(current_cluster)
current_cluster = [x]
else:
current_cluster.append(x)
if current_cluster:
clusters.append(current_cluster)
# Each cluster represents a column start
columns = []
for cluster in clusters:
col_start = min(cluster)
columns.append({
"x_start": col_start,
"x_start_pct": col_start / image_width * 100,
"word_count": len(cluster),
})
# Assign column types based on position (left→right: EN, DE, Example)
type_map = ["english", "german", "example"]
column_types = []
for i, col in enumerate(columns):
if i < len(type_map):
column_types.append(type_map[i])
else:
column_types.append("unknown")
return {
"columns": columns,
"column_types": column_types,
}
def words_to_vocab_entries(lines: List[List[dict]], columns: List[dict],
column_types: List[str], image_width: int,
image_height: int) -> List[dict]:
"""Convert grouped words into vocabulary entries using column positions.
Args:
lines: Grouped word lines from group_words_into_lines.
columns: Column boundaries from detect_columns.
column_types: Column type assignments.
image_width: Image width in pixels.
image_height: Image height in pixels.
Returns:
List of vocabulary entry dicts with english/german/example fields.
"""
if not columns or not lines:
return []
# Build column boundaries for word assignment
col_boundaries = []
for i, col in enumerate(columns):
start = col['x_start']
if i + 1 < len(columns):
end = columns[i + 1]['x_start']
else:
end = image_width
col_boundaries.append((start, end, column_types[i] if i < len(column_types) else "unknown"))
entries = []
for line in lines:
entry = {"english": "", "german": "", "example": ""}
line_words_by_col: Dict[str, List[str]] = {"english": [], "german": [], "example": []}
line_bbox: Dict[str, Optional[dict]] = {}
for word in line:
word_center_x = word['left'] + word['width'] / 2
assigned_type = "unknown"
for start, end, col_type in col_boundaries:
if start <= word_center_x < end:
assigned_type = col_type
break
if assigned_type in line_words_by_col:
line_words_by_col[assigned_type].append(word['text'])
# Track bounding box for the column
if assigned_type not in line_bbox or line_bbox[assigned_type] is None:
line_bbox[assigned_type] = {
"left": word['left'],
"top": word['top'],
"right": word['left'] + word['width'],
"bottom": word['top'] + word['height'],
}
else:
bb = line_bbox[assigned_type]
bb['left'] = min(bb['left'], word['left'])
bb['top'] = min(bb['top'], word['top'])
bb['right'] = max(bb['right'], word['left'] + word['width'])
bb['bottom'] = max(bb['bottom'], word['top'] + word['height'])
for col_type in ["english", "german", "example"]:
if line_words_by_col[col_type]:
entry[col_type] = " ".join(line_words_by_col[col_type])
if line_bbox.get(col_type):
bb = line_bbox[col_type]
entry[f"{col_type}_bbox"] = {
"x_pct": bb['left'] / image_width * 100,
"y_pct": bb['top'] / image_height * 100,
"w_pct": (bb['right'] - bb['left']) / image_width * 100,
"h_pct": (bb['bottom'] - bb['top']) / image_height * 100,
}
# Only add if at least one column has content
if entry["english"] or entry["german"]:
entries.append(entry)
return entries
def match_positions_to_vocab(tess_words: List[dict], llm_vocab: List[dict],
image_w: int, image_h: int,
threshold: float = 0.6) -> List[dict]:
"""Match Tesseract bounding boxes to LLM vocabulary entries.
For each LLM vocab entry, find the best-matching Tesseract word
and attach its bounding box coordinates.
Args:
tess_words: Word list from Tesseract with pixel coordinates.
llm_vocab: Vocabulary list from Vision LLM.
image_w: Image width in pixels.
image_h: Image height in pixels.
threshold: Minimum similarity ratio for a match.
Returns:
llm_vocab list with bbox_x_pct/bbox_y_pct/bbox_w_pct/bbox_h_pct added.
"""
if not tess_words or not llm_vocab or image_w == 0 or image_h == 0:
return llm_vocab
for entry in llm_vocab:
english = entry.get("english", "").lower().strip()
german = entry.get("german", "").lower().strip()
if not english and not german:
continue
# Try to match English word first, then German
for field in ["english", "german"]:
search_text = entry.get(field, "").lower().strip()
if not search_text:
continue
best_word = None
best_ratio = 0.0
for word in tess_words:
ratio = SequenceMatcher(None, search_text, word['text'].lower()).ratio()
if ratio > best_ratio:
best_ratio = ratio
best_word = word
if best_word and best_ratio >= threshold:
entry[f"bbox_x_pct"] = best_word['left'] / image_w * 100
entry[f"bbox_y_pct"] = best_word['top'] / image_h * 100
entry[f"bbox_w_pct"] = best_word['width'] / image_w * 100
entry[f"bbox_h_pct"] = best_word['height'] / image_h * 100
entry["bbox_match_field"] = field
entry["bbox_match_ratio"] = round(best_ratio, 3)
break # Found a match, no need to try the other field
return llm_vocab
async def run_tesseract_pipeline(image_bytes: bytes, lang: str = "eng+deu") -> dict:
"""Full Tesseract pipeline: extract words, group lines, detect columns, build vocab.
Args:
image_bytes: PNG/JPEG image as bytes.
lang: Tesseract language string.
Returns:
Dict with 'vocabulary', 'words', 'lines', 'columns', 'image_width', 'image_height'.
"""
# Step 1: Extract bounding boxes
bbox_data = await extract_bounding_boxes(image_bytes, lang=lang)
if bbox_data.get("error"):
return bbox_data
words = bbox_data["words"]
image_w = bbox_data["image_width"]
image_h = bbox_data["image_height"]
# Step 2: Group into lines
lines = group_words_into_lines(words)
# Step 3: Detect columns
col_info = detect_columns(lines, image_w)
# Step 4: Build vocabulary entries
vocab = words_to_vocab_entries(
lines,
col_info["columns"],
col_info["column_types"],
image_w,
image_h,
)
return {
"vocabulary": vocab,
"words": words,
"lines_count": len(lines),
"columns": col_info["columns"],
"column_types": col_info["column_types"],
"image_width": image_w,
"image_height": image_h,
"word_count": len(words),
}
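
For reference (not part of this commit), a short sketch of driving run_tesseract_pipeline directly. The import path follows the tesseract_extractor → ocr/engines/ move from the commit message; the file name is a placeholder:

import asyncio
from ocr.engines.tesseract_extractor import run_tesseract_pipeline

async def demo() -> None:
    with open("vocab_table.png", "rb") as f:   # hypothetical scanned vocab table
        image_bytes = f.read()
    result = await run_tesseract_pipeline(image_bytes, lang="eng+deu")
    print("columns:", result.get("column_types"))
    print("words:", result.get("word_count"))
    for entry in result.get("vocabulary", [])[:3]:
        print(entry.get("english", ""), "->", entry.get("german", ""))

asyncio.run(demo())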

View File

@@ -0,0 +1,276 @@
"""
Handwriting HTR API - Hochwertige Handschriftenerkennung (HTR) fuer Klausurkorrekturen.
Endpoints:
- POST /api/v1/htr/recognize - Bild hochladen → handgeschriebener Text
- POST /api/v1/htr/recognize-session - OCR-Pipeline Session als Quelle nutzen
Modell-Strategie:
1. qwen2.5vl:32b via Ollama (primaer, hoechste Qualitaet als VLM)
2. microsoft/trocr-large-handwritten (Fallback, offline, kein Ollama)
DATENSCHUTZ: Alle Verarbeitung erfolgt lokal auf dem Mac Mini.
"""
import io
import os
import logging
import time
import base64
from typing import Optional
import cv2
import numpy as np
from fastapi import APIRouter, HTTPException, Query, UploadFile, File
from pydantic import BaseModel
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/api/v1/htr", tags=["HTR"])
OLLAMA_BASE_URL = os.getenv("OLLAMA_BASE_URL", "http://host.docker.internal:11434")
OLLAMA_HTR_MODEL = os.getenv("OLLAMA_HTR_MODEL", "qwen2.5vl:32b")
HTR_FALLBACK_MODEL = os.getenv("HTR_FALLBACK_MODEL", "trocr-large")
# ---------------------------------------------------------------------------
# Pydantic Models
# ---------------------------------------------------------------------------
class HTRSessionRequest(BaseModel):
session_id: str
model: str = "auto" # "auto" | "qwen2.5vl" | "trocr-large"
use_clean: bool = True # Prefer clean_png (after handwriting removal)
# ---------------------------------------------------------------------------
# Preprocessing
# ---------------------------------------------------------------------------
def _preprocess_for_htr(img_bgr: np.ndarray) -> np.ndarray:
"""
CLAHE contrast enhancement + upscale to improve HTR accuracy.
Returns grayscale enhanced image.
"""
gray = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2GRAY)
clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
enhanced = clahe.apply(gray)
# Upscale if image is too small
h, w = enhanced.shape
if min(h, w) < 800:
scale = 800 / min(h, w)
enhanced = cv2.resize(
enhanced, None, fx=scale, fy=scale,
interpolation=cv2.INTER_CUBIC
)
return enhanced
def _bgr_to_png_bytes(img_bgr: np.ndarray) -> bytes:
"""Convert BGR ndarray to PNG bytes."""
success, buf = cv2.imencode(".png", img_bgr)
if not success:
raise RuntimeError("Failed to encode image to PNG")
return buf.tobytes()
def _preprocess_image_bytes(image_bytes: bytes) -> bytes:
"""Load image, apply HTR preprocessing, return PNG bytes."""
arr = np.frombuffer(image_bytes, dtype=np.uint8)
img_bgr = cv2.imdecode(arr, cv2.IMREAD_COLOR)
if img_bgr is None:
raise ValueError("Could not decode image")
enhanced = _preprocess_for_htr(img_bgr)
# Convert grayscale back to BGR for encoding
enhanced_bgr = cv2.cvtColor(enhanced, cv2.COLOR_GRAY2BGR)
return _bgr_to_png_bytes(enhanced_bgr)
# ---------------------------------------------------------------------------
# Backend: Ollama qwen2.5vl
# ---------------------------------------------------------------------------
async def _recognize_with_qwen_vl(image_bytes: bytes, language: str) -> Optional[str]:
"""
Send image to Ollama qwen2.5vl:32b for HTR.
Returns extracted text or None on error.
"""
import httpx
lang_hint = {
"de": "Deutsch",
"en": "Englisch",
"de+en": "Deutsch und Englisch",
}.get(language, "Deutsch")
prompt = (
f"Du bist ein OCR-Experte fuer handgeschriebenen Text auf {lang_hint}. "
"Lies den Text im Bild exakt ab — korrigiere KEINE Rechtschreibfehler. "
"Antworte NUR mit dem erkannten Text, ohne Erklaerungen."
)
img_b64 = base64.b64encode(image_bytes).decode("utf-8")
payload = {
"model": OLLAMA_HTR_MODEL,
"prompt": prompt,
"images": [img_b64],
"stream": False,
}
try:
async with httpx.AsyncClient(timeout=120.0) as client:
resp = await client.post(f"{OLLAMA_BASE_URL}/api/generate", json=payload)
resp.raise_for_status()
data = resp.json()
return data.get("response", "").strip()
except Exception as e:
logger.warning(f"Ollama qwen2.5vl HTR failed: {e}")
return None
# ---------------------------------------------------------------------------
# Backend: TrOCR-large fallback
# ---------------------------------------------------------------------------
async def _recognize_with_trocr_large(image_bytes: bytes) -> Optional[str]:
"""
Use microsoft/trocr-large-handwritten via trocr_service.py.
Returns extracted text or None on error.
"""
try:
from services.trocr_service import run_trocr_ocr, _check_trocr_available
if not _check_trocr_available():
logger.warning("TrOCR not available for HTR fallback")
return None
text, confidence = await run_trocr_ocr(image_bytes, handwritten=True, size="large")
return text.strip() if text else None
except Exception as e:
logger.warning(f"TrOCR-large HTR failed: {e}")
return None
# ---------------------------------------------------------------------------
# Core recognition logic
# ---------------------------------------------------------------------------
async def _do_recognize(
image_bytes: bytes,
model: str = "auto",
preprocess: bool = True,
language: str = "de",
) -> dict:
"""
Core HTR logic: preprocess → try Ollama → fallback to TrOCR-large.
Returns dict with text, model_used, processing_time_ms.
"""
t0 = time.monotonic()
if preprocess:
try:
image_bytes = _preprocess_image_bytes(image_bytes)
except Exception as e:
logger.warning(f"HTR preprocessing failed, using raw image: {e}")
text: Optional[str] = None
model_used: str = "none"
use_qwen = model in ("auto", "qwen2.5vl")
use_trocr = model in ("auto", "trocr-large") or (use_qwen and text is None)
if use_qwen:
text = await _recognize_with_qwen_vl(image_bytes, language)
if text is not None:
model_used = f"qwen2.5vl ({OLLAMA_HTR_MODEL})"
if text is None and (use_trocr or model == "trocr-large"):
text = await _recognize_with_trocr_large(image_bytes)
if text is not None:
model_used = "trocr-large-handwritten"
if text is None:
text = ""
model_used = "none (all backends failed)"
elapsed_ms = int((time.monotonic() - t0) * 1000)
return {
"text": text,
"model_used": model_used,
"processing_time_ms": elapsed_ms,
"language": language,
"preprocessed": preprocess,
}
# ---------------------------------------------------------------------------
# Endpoints
# ---------------------------------------------------------------------------
@router.post("/recognize")
async def recognize_handwriting(
file: UploadFile = File(...),
model: str = Query("auto", description="auto | qwen2.5vl | trocr-large"),
preprocess: bool = Query(True, description="Apply CLAHE + upscale before recognition"),
language: str = Query("de", description="de | en | de+en"),
):
"""
Upload an image and get back the handwritten text as plain text.
Tries qwen2.5vl:32b via Ollama first, falls back to TrOCR-large-handwritten.
"""
if model not in ("auto", "qwen2.5vl", "trocr-large"):
raise HTTPException(status_code=400, detail="model must be one of: auto, qwen2.5vl, trocr-large")
if language not in ("de", "en", "de+en"):
raise HTTPException(status_code=400, detail="language must be one of: de, en, de+en")
image_bytes = await file.read()
if not image_bytes:
raise HTTPException(status_code=400, detail="Empty file")
return await _do_recognize(image_bytes, model=model, preprocess=preprocess, language=language)
@router.post("/recognize-session")
async def recognize_from_session(req: HTRSessionRequest):
"""
Use an OCR-Pipeline session as image source for HTR.
Set use_clean=true to prefer the clean image (after handwriting removal step).
This is useful when you want to do HTR on isolated handwriting regions.
"""
from ocr_pipeline_session_store import get_session_db, get_session_image
session = await get_session_db(req.session_id)
if not session:
raise HTTPException(status_code=404, detail=f"Session {req.session_id} not found")
# Choose source image
image_bytes: Optional[bytes] = None
source_used: str = ""
if req.use_clean:
image_bytes = await get_session_image(req.session_id, "clean")
if image_bytes:
source_used = "clean"
if not image_bytes:
image_bytes = await get_session_image(req.session_id, "deskewed")
if image_bytes:
source_used = "deskewed"
if not image_bytes:
image_bytes = await get_session_image(req.session_id, "original")
source_used = "original"
if not image_bytes:
raise HTTPException(status_code=404, detail="No image available in session")
result = await _do_recognize(image_bytes, model=req.model)
result["session_id"] = req.session_id
result["source_image"] = source_used
return result
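
Complementing the upload route, a hedged sketch (not part of this commit) of the session-based flow. The session id and base URL are placeholders; the JSON body mirrors HTRSessionRequest above:

import asyncio
import httpx

async def recognize_session_example() -> None:
    payload = {"session_id": "sess-1234", "model": "auto", "use_clean": True}  # hypothetical id
    async with httpx.AsyncClient(timeout=180.0) as client:
        resp = await client.post(
            "http://localhost:8000/api/v1/htr/recognize-session",  # assumed host/port
            json=payload,
        )
    resp.raise_for_status()
    data = resp.json()
    print(data["source_image"], "->", data["model_used"])
    print(data["text"])

asyncio.run(recognize_session_example())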

View File

@@ -0,0 +1,7 @@
"""
OCR spell-checking sub-package — language-aware OCR correction.
Moved from backend/ flat modules (smart_spell*.py).
Backward-compatible shim files remain at the old locations.
"""
from .smart_spell import * # noqa: F401,F403

View File

@@ -0,0 +1,298 @@
"""
SmartSpellChecker Core — init, data types, language detection, word correction.
Extracted from smart_spell.py for modularity.
Lizenz: Apache 2.0 (kommerziell nutzbar)
"""
import logging
import re
from dataclasses import dataclass, field
from typing import Dict, List, Literal, Optional, Set, Tuple
logger = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
# Init
# ---------------------------------------------------------------------------
try:
from spellchecker import SpellChecker as _SpellChecker
_en_spell = _SpellChecker(language='en', distance=1)
_de_spell = _SpellChecker(language='de', distance=1)
_AVAILABLE = True
except ImportError:
_AVAILABLE = False
logger.warning("pyspellchecker not installed — SmartSpellChecker disabled")
Lang = Literal["en", "de", "both", "unknown"]
# ---------------------------------------------------------------------------
# Bigram context for a/I disambiguation
# ---------------------------------------------------------------------------
# Words that commonly follow "I" (subject pronoun -> verb/modal)
_I_FOLLOWERS: frozenset = frozenset({
"am", "was", "have", "had", "do", "did", "will", "would", "can",
"could", "should", "shall", "may", "might", "must",
"think", "know", "see", "want", "need", "like", "love", "hate",
"go", "went", "come", "came", "say", "said", "get", "got",
"make", "made", "take", "took", "give", "gave", "tell", "told",
"feel", "felt", "find", "found", "believe", "hope", "wish",
"remember", "forget", "understand", "mean", "meant",
"don't", "didn't", "can't", "won't", "couldn't", "wouldn't",
"shouldn't", "haven't", "hadn't", "isn't", "wasn't",
"really", "just", "also", "always", "never", "often", "sometimes",
})
# Words that commonly follow "a" (article -> noun/adjective)
_A_FOLLOWERS: frozenset = frozenset({
"lot", "few", "little", "bit", "good", "bad", "great", "new", "old",
"long", "short", "big", "small", "large", "huge", "tiny",
"nice", "beautiful", "wonderful", "terrible", "horrible",
"man", "woman", "boy", "girl", "child", "dog", "cat", "bird",
"book", "car", "house", "room", "school", "teacher", "student",
"day", "week", "month", "year", "time", "place", "way",
"friend", "family", "person", "problem", "question", "story",
"very", "really", "quite", "rather", "pretty", "single",
})
# Digit->letter substitutions (OCR confusion)
_DIGIT_SUBS: Dict[str, List[str]] = {
'0': ['o', 'O'],
'1': ['l', 'I'],
'5': ['s', 'S'],
'6': ['g', 'G'],
'8': ['b', 'B'],
'|': ['I', 'l'],
'/': ['l'], # italic 'l' misread as slash (e.g. "p/" -> "pl")
}
_SUSPICIOUS_CHARS = frozenset(_DIGIT_SUBS.keys())
# Umlaut confusion: OCR drops the dots (ü->u, ä->a, ö->o)
_UMLAUT_MAP = {
'a': '\u00e4', 'o': '\u00f6', 'u': '\u00fc', 'i': '\u00fc',
'A': '\u00c4', 'O': '\u00d6', 'U': '\u00dc', 'I': '\u00dc',
}
# Tokenizer -- includes | and / so OCR artifacts like "p/" are treated as words
_TOKEN_RE = re.compile(r"([A-Za-z\u00c4\u00d6\u00dc\u00e4\u00f6\u00fc\u00df'|/]+)([^A-Za-z\u00c4\u00d6\u00dc\u00e4\u00f6\u00fc\u00df'|/]*)")
# ---------------------------------------------------------------------------
# Data types
# ---------------------------------------------------------------------------
@dataclass
class CorrectionResult:
original: str
corrected: str
lang_detected: Lang
changed: bool
changes: List[str] = field(default_factory=list)
# ---------------------------------------------------------------------------
# Core class — language detection and word-level correction
# ---------------------------------------------------------------------------
class _SmartSpellCoreBase:
"""Base class with language detection and single-word correction.
Not intended for direct use — SmartSpellChecker inherits from this.
"""
def __init__(self):
if not _AVAILABLE:
raise RuntimeError("pyspellchecker not installed")
self.en = _en_spell
self.de = _de_spell
# --- Language detection ---
def detect_word_lang(self, word: str) -> Lang:
"""Detect language of a single word using dual-dict heuristic."""
w = word.lower().strip(".,;:!?\"'()")
if not w:
return "unknown"
in_en = bool(self.en.known([w]))
in_de = bool(self.de.known([w]))
if in_en and in_de:
return "both"
if in_en:
return "en"
if in_de:
return "de"
return "unknown"
def detect_text_lang(self, text: str) -> Lang:
"""Detect dominant language of a text string (sentence/phrase)."""
words = re.findall(r"[A-Za-z\u00c4\u00d6\u00dc\u00e4\u00f6\u00fc\u00df]+", text)
if not words:
return "unknown"
en_count = 0
de_count = 0
for w in words:
lang = self.detect_word_lang(w)
if lang == "en":
en_count += 1
elif lang == "de":
de_count += 1
# "both" doesn't count for either
if en_count > de_count:
return "en"
if de_count > en_count:
return "de"
if en_count == de_count and en_count > 0:
return "both"
return "unknown"
# --- Single-word correction ---
def _known(self, word: str) -> bool:
"""True if word is known in EN or DE dictionary, or is a known abbreviation."""
w = word.lower()
if bool(self.en.known([w])) or bool(self.de.known([w])):
return True
# Also accept known abbreviations (sth, sb, adj, etc.)
try:
from cv_ocr_engines import _KNOWN_ABBREVIATIONS
if w in _KNOWN_ABBREVIATIONS:
return True
except ImportError:
pass
return False
def _word_freq(self, word: str) -> float:
"""Get word frequency (max of EN and DE)."""
w = word.lower()
return max(self.en.word_usage_frequency(w), self.de.word_usage_frequency(w))
def _known_in(self, word: str, lang: str) -> bool:
"""True if word is known in a specific language dictionary."""
w = word.lower()
spell = self.en if lang == "en" else self.de
return bool(spell.known([w]))
def correct_word(self, word: str, lang: str = "en",
prev_word: str = "", next_word: str = "") -> Optional[str]:
"""Correct a single word for the given language.
Returns None if no correction needed, or the corrected string.
"""
if not word or not word.strip():
return None
# Skip numbers, abbreviations with dots, very short tokens
if word.isdigit() or '.' in word:
return None
# Skip IPA/phonetic content in brackets
if '[' in word or ']' in word:
return None
has_suspicious = any(ch in _SUSPICIOUS_CHARS for ch in word)
# 1. Already known -> no fix
if self._known(word):
# But check a/I disambiguation for single-char words
if word.lower() in ('l', '|') and next_word:
return self._disambiguate_a_I(word, next_word)
return None
# 2. Digit/pipe substitution
if has_suspicious:
if word == '|':
return 'I'
# Try single-char substitutions
for i, ch in enumerate(word):
if ch not in _DIGIT_SUBS:
continue
for replacement in _DIGIT_SUBS[ch]:
candidate = word[:i] + replacement + word[i + 1:]
if self._known(candidate):
return candidate
# Try multi-char substitution (e.g., "sch00l" -> "school")
multi = self._try_multi_digit_sub(word)
if multi:
return multi
# 3. Umlaut correction (German)
if lang == "de" and len(word) >= 3 and word.isalpha():
umlaut_fix = self._try_umlaut_fix(word)
if umlaut_fix:
return umlaut_fix
# 4. General spell correction
if not has_suspicious and len(word) >= 3 and word.isalpha():
# Safety: don't correct if the word is valid in the OTHER language
other_lang = "de" if lang == "en" else "en"
if self._known_in(word, other_lang):
return None
if other_lang == "de" and self._try_umlaut_fix(word):
return None # has a valid DE umlaut variant -> don't touch
spell = self.en if lang == "en" else self.de
correction = spell.correction(word.lower())
if correction and correction != word.lower():
if word[0].isupper():
correction = correction[0].upper() + correction[1:]
if self._known(correction):
return correction
return None
# --- Multi-digit substitution ---
def _try_multi_digit_sub(self, word: str) -> Optional[str]:
"""Try replacing multiple digits simultaneously using BFS."""
positions = [(i, ch) for i, ch in enumerate(word) if ch in _DIGIT_SUBS]
if not positions or len(positions) > 4:
return None
# BFS over substitution combinations
queue = [list(word)]
for pos, ch in positions:
next_queue = []
for current in queue:
# Keep original
next_queue.append(current[:])
# Try each substitution
for repl in _DIGIT_SUBS[ch]:
variant = current[:]
variant[pos] = repl
next_queue.append(variant)
queue = next_queue
# Check which combinations produce known words
for combo in queue:
candidate = "".join(combo)
if candidate != word and self._known(candidate):
return candidate
return None
# --- Umlaut fix ---
def _try_umlaut_fix(self, word: str) -> Optional[str]:
"""Try single-char umlaut substitutions for German words."""
for i, ch in enumerate(word):
if ch in _UMLAUT_MAP:
candidate = word[:i] + _UMLAUT_MAP[ch] + word[i + 1:]
if self._known(candidate):
return candidate
return None
# --- a/I disambiguation ---
def _disambiguate_a_I(self, token: str, next_word: str) -> Optional[str]:
"""Disambiguate 'a' vs 'I' (and OCR variants like 'l', '|')."""
nw = next_word.lower().strip(".,;:!?")
if nw in _I_FOLLOWERS:
return "I"
if nw in _A_FOLLOWERS:
return "a"
return None # uncertain, don't change
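
A hedged, illustrative sketch (not part of this commit) of the single-word layer, accessed through the public SmartSpellChecker re-exported by ocr/spell. It assumes pyspellchecker is installed; the expected outputs depend on its bundled dictionaries:

from ocr.spell import SmartSpellChecker

checker = SmartSpellChecker()

# Digit/pipe substitutions from _DIGIT_SUBS:
print(checker.correct_word("sch00l", lang="en"))   # expected: "school" (multi-digit sub)
print(checker.correct_word("schu1e", lang="de"))   # expected: "schule" (1 -> l)
print(checker.correct_word("|", lang="en"))        # expected: "I"

# Dual-dictionary language detection:
print(checker.detect_word_lang("house"))           # typically "en"
print(checker.detect_word_lang("Haus"))            # typically "de"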

View File

@@ -0,0 +1,25 @@
"""
SmartSpellChecker — barrel re-export.
All implementation split into:
core — init, data types, language detection, word correction
text — full text correction, boundary repair, context split
Lizenz: Apache 2.0 (kommerziell nutzbar)
"""
# Core: data types, lang detection (re-exported for tests)
from .core import ( # noqa: F401
_AVAILABLE,
_DIGIT_SUBS,
_SUSPICIOUS_CHARS,
_UMLAUT_MAP,
_TOKEN_RE,
_I_FOLLOWERS,
_A_FOLLOWERS,
CorrectionResult,
Lang,
)
# Text: SmartSpellChecker class (the main public API)
from .text import SmartSpellChecker # noqa: F401

View File

@@ -0,0 +1,289 @@
"""
SmartSpellChecker Text — full text correction, boundary repair, context split.
Extracted from smart_spell.py for modularity.
Lizenz: Apache 2.0 (kommerziell nutzbar)
"""
import re
from typing import Dict, List, Optional, Tuple
from .core import (
_SmartSpellCoreBase,
_TOKEN_RE,
CorrectionResult,
Lang,
)
class SmartSpellChecker(_SmartSpellCoreBase):
"""Language-aware OCR spell checker using pyspellchecker (no LLM).
Inherits single-word correction from _SmartSpellCoreBase.
Adds text-level passes: boundary repair, context split, full correction.
"""
# --- Boundary repair (shifted word boundaries) ---
def _try_boundary_repair(self, word1: str, word2: str) -> Optional[Tuple[str, str]]:
"""Fix shifted word boundaries between adjacent tokens.
OCR sometimes shifts the boundary: "at sth." -> "ats th."
Try moving 1-2 chars from end of word1 to start of word2 and vice versa.
Returns (fixed_word1, fixed_word2) or None.
"""
# Import known abbreviations for vocabulary context
try:
from cv_ocr_engines import _KNOWN_ABBREVIATIONS
except ImportError:
_KNOWN_ABBREVIATIONS = set()
# Strip trailing punctuation for checking, preserve for result
w2_stripped = word2.rstrip(".,;:!?")
w2_punct = word2[len(w2_stripped):]
# Try shifting 1-2 chars from word1 -> word2
for shift in (1, 2):
if len(word1) <= shift:
continue
new_w1 = word1[:-shift]
new_w2_base = word1[-shift:] + w2_stripped
w1_ok = self._known(new_w1) or new_w1.lower() in _KNOWN_ABBREVIATIONS
w2_ok = self._known(new_w2_base) or new_w2_base.lower() in _KNOWN_ABBREVIATIONS
if w1_ok and w2_ok:
return (new_w1, new_w2_base + w2_punct)
# Try shifting 1-2 chars from word2 -> word1
for shift in (1, 2):
if len(w2_stripped) <= shift:
continue
new_w1 = word1 + w2_stripped[:shift]
new_w2_base = w2_stripped[shift:]
w1_ok = self._known(new_w1) or new_w1.lower() in _KNOWN_ABBREVIATIONS
w2_ok = self._known(new_w2_base) or new_w2_base.lower() in _KNOWN_ABBREVIATIONS
if w1_ok and w2_ok:
return (new_w1, new_w2_base + w2_punct)
return None
# --- Context-based word split for ambiguous merges ---
# Patterns where a valid word is actually "a" + adjective/noun
_ARTICLE_SPLIT_CANDIDATES = {
# word -> (article, remainder) -- only when followed by a compatible word
"anew": ("a", "new"),
"areal": ("a", "real"),
"alive": None, # genuinely one word, never split
"alone": None,
"aware": None,
"alike": None,
"apart": None,
"aside": None,
"above": None,
"about": None,
"among": None,
"along": None,
}
def _try_context_split(self, word: str, next_word: str,
prev_word: str) -> Optional[str]:
"""Split words like 'anew' -> 'a new' when context indicates a merge.
Only splits when:
- The word is in the split candidates list
- The following word makes sense as a noun (for "a + adj + noun" pattern)
- OR the word is unknown and can be split into article + known word
"""
w_lower = word.lower()
# Check explicit candidates
if w_lower in self._ARTICLE_SPLIT_CANDIDATES:
split = self._ARTICLE_SPLIT_CANDIDATES[w_lower]
if split is None:
return None # explicitly marked as "don't split"
article, remainder = split
# Only split if followed by a word (noun pattern)
if next_word and next_word[0].islower():
return f"{article} {remainder}"
# Also split if remainder + next_word makes a common phrase
if next_word and self._known(next_word):
return f"{article} {remainder}"
# Generic: if word starts with 'a' and rest is a known adjective/word
if (len(word) >= 4 and word[0].lower() == 'a'
and not self._known(word) # only for UNKNOWN words
and self._known(word[1:])):
return f"a {word[1:]}"
return None
# --- Full text correction ---
def correct_text(self, text: str, lang: str = "en") -> CorrectionResult:
"""Correct a full text string (field value).
Three passes:
1. Boundary repair -- fix shifted word boundaries between adjacent tokens
2. Context split -- split ambiguous merges (anew -> a new)
3. Per-word correction -- spell check individual words
"""
if not text or not text.strip():
return CorrectionResult(text, text, "unknown", False)
detected = self.detect_text_lang(text) if lang == "auto" else lang
effective_lang = detected if detected in ("en", "de") else "en"
changes: List[str] = []
tokens = list(_TOKEN_RE.finditer(text))
# Extract token list: [(word, separator), ...]
token_list: List[List[str]] = [] # [[word, sep], ...]
for m in tokens:
token_list.append([m.group(1), m.group(2)])
# --- Pass 1: Boundary repair between adjacent unknown words ---
# Import abbreviations for the heuristic below
try:
from cv_ocr_engines import _KNOWN_ABBREVIATIONS as _ABBREVS
except ImportError:
_ABBREVS = set()
for i in range(len(token_list) - 1):
w1 = token_list[i][0]
w2_raw = token_list[i + 1][0]
# Skip boundary repair for IPA/bracket content
# Brackets may be in the token OR in the adjacent separators
sep_before_w1 = token_list[i - 1][1] if i > 0 else ""
sep_after_w1 = token_list[i][1]
sep_after_w2 = token_list[i + 1][1]
has_bracket = (
'[' in w1 or ']' in w1 or '[' in w2_raw or ']' in w2_raw
or ']' in sep_after_w1 # w1 text was inside [brackets]
or '[' in sep_after_w1 # w2 starts a bracket
or ']' in sep_after_w2 # w2 text was inside [brackets]
or '[' in sep_before_w1 # w1 starts a bracket
)
if has_bracket:
continue
# Include trailing punct from separator in w2 for abbreviation matching
w2_with_punct = w2_raw + token_list[i + 1][1].rstrip(" ")
# Try boundary repair -- always, even if both words are valid.
# Use word-frequency scoring to decide if repair is better.
repair = self._try_boundary_repair(w1, w2_with_punct)
if not repair and w2_with_punct != w2_raw:
repair = self._try_boundary_repair(w1, w2_raw)
if repair:
new_w1, new_w2_full = repair
new_w2_base = new_w2_full.rstrip(".,;:!?")
# Frequency-based scoring: product of word frequencies
# Higher product = more common word pair = better
old_freq = self._word_freq(w1) * self._word_freq(w2_raw)
new_freq = self._word_freq(new_w1) * self._word_freq(new_w2_base)
# Abbreviation bonus: if repair produces a known abbreviation
has_abbrev = new_w1.lower() in _ABBREVS or new_w2_base.lower() in _ABBREVS
if has_abbrev:
# Accept abbreviation repair ONLY if at least one of the
# original words is rare/unknown (prevents "Can I" -> "Ca nI"
# where both original words are common and correct).
RARE_THRESHOLD = 1e-6
orig_both_common = (
self._word_freq(w1) > RARE_THRESHOLD
and self._word_freq(w2_raw) > RARE_THRESHOLD
)
if not orig_both_common:
new_freq = max(new_freq, old_freq * 10)
else:
has_abbrev = False # both originals common -> don't trust
# Accept if repair produces a more frequent word pair
# (threshold: at least 5x more frequent to avoid false positives)
if new_freq > old_freq * 5:
new_w2_punct = new_w2_full[len(new_w2_base):]
changes.append(f"{w1} {w2_raw}\u2192{new_w1} {new_w2_base}")
token_list[i][0] = new_w1
token_list[i + 1][0] = new_w2_base
if new_w2_punct:
token_list[i + 1][1] = new_w2_punct + token_list[i + 1][1].lstrip(".,;:!?")
# --- Pass 2: Context split (anew -> a new) ---
expanded: List[List[str]] = []
for i, (word, sep) in enumerate(token_list):
next_word = token_list[i + 1][0] if i + 1 < len(token_list) else ""
prev_word = token_list[i - 1][0] if i > 0 else ""
split = self._try_context_split(word, next_word, prev_word)
if split and split != word:
changes.append(f"{word}\u2192{split}")
expanded.append([split, sep])
else:
expanded.append([word, sep])
token_list = expanded
# --- Pass 3: Per-word correction ---
parts: List[str] = []
# Preserve any leading text before the first token match
first_start = tokens[0].start() if tokens else 0
if first_start > 0:
parts.append(text[:first_start])
for i, (word, sep) in enumerate(token_list):
# Skip words inside IPA brackets (brackets land in separators)
prev_sep = token_list[i - 1][1] if i > 0 else ""
if '[' in prev_sep or ']' in sep:
parts.append(word)
parts.append(sep)
continue
next_word = token_list[i + 1][0] if i + 1 < len(token_list) else ""
prev_word = token_list[i - 1][0] if i > 0 else ""
correction = self.correct_word(
word, lang=effective_lang,
prev_word=prev_word, next_word=next_word,
)
if correction and correction != word:
changes.append(f"{word}\u2192{correction}")
parts.append(correction)
else:
parts.append(word)
parts.append(sep)
# Append any trailing text
last_end = tokens[-1].end() if tokens else 0
if last_end < len(text):
parts.append(text[last_end:])
corrected = "".join(parts)
return CorrectionResult(
original=text,
corrected=corrected,
lang_detected=detected,
changed=corrected != text,
changes=changes,
)
# --- Vocabulary entry correction ---
def correct_vocab_entry(self, english: str, german: str,
example: str = "") -> Dict[str, CorrectionResult]:
"""Correct a full vocabulary entry (EN + DE + example).
Uses column position to determine language -- the most reliable signal.
"""
results = {}
results["english"] = self.correct_text(english, lang="en")
results["german"] = self.correct_text(german, lang="de")
if example:
# For examples, auto-detect language
results["example"] = self.correct_text(example, lang="auto")
return results
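
And a hedged sketch (not part of this commit) of the text-level passes exposed by SmartSpellChecker. The sample strings are made up; whether a given correction fires depends on the pyspellchecker dictionaries:

from ocr.spell import SmartSpellChecker

checker = SmartSpellChecker()

# Pass 2 (context split): "anew" followed by a lowercase noun is split.
res = checker.correct_text("She bought anew car.", lang="en")
print(res.corrected)   # expected: "She bought a new car."
print(res.changes)     # e.g. ["anew→a new"]

# Pass 3 (per-word correction) on a plain misspelling.
print(checker.correct_text("a beatiful day", lang="en").corrected)  # expected: "a beautiful day"

# Column position drives the language for vocabulary entries.
entry = checker.correct_vocab_entry(english="beatiful", german="Haus")
print(entry["english"].corrected, entry["german"].changed)  # expected: "beautiful" False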

View File

@@ -1,25 +1,4 @@
"""
SmartSpellChecker — barrel re-export.
All implementation split into:
smart_spell_core — init, data types, language detection, word correction
smart_spell_text — full text correction, boundary repair, context split
Lizenz: Apache 2.0 (kommerziell nutzbar)
"""
# Core: data types, lang detection (re-exported for tests)
from smart_spell_core import ( # noqa: F401
_AVAILABLE,
_DIGIT_SUBS,
_SUSPICIOUS_CHARS,
_UMLAUT_MAP,
_TOKEN_RE,
_I_FOLLOWERS,
_A_FOLLOWERS,
CorrectionResult,
Lang,
)
# Text: SmartSpellChecker class (the main public API)
from smart_spell_text import SmartSpellChecker # noqa: F401
# Backward-compat shim -- module moved to ocr/spell/smart_spell.py
import importlib as _importlib
import sys as _sys
_sys.modules[__name__] = _importlib.import_module("ocr.spell.smart_spell")
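
To make the shim mechanism above concrete, a hedged sketch (not part of this commit) of what a legacy import now resolves to. It assumes the backend root is on sys.path so both the flat shim and the ocr package are importable:

import sys
import smart_spell                      # executes the shim above
import ocr.spell.smart_spell as new_mod

# The shim swapped itself out in sys.modules, so the legacy name now refers
# to the relocated module object.
assert sys.modules["smart_spell"] is new_mod
assert smart_spell is new_mod

# Old call sites keep working unchanged.
from smart_spell import SmartSpellChecker
print(SmartSpellChecker.__module__)     # expected: "ocr.spell.text"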

View File

@@ -1,298 +1,4 @@
"""
SmartSpellChecker Core — init, data types, language detection, word correction.
Extracted from smart_spell.py for modularity.
Lizenz: Apache 2.0 (kommerziell nutzbar)
"""
import logging
import re
from dataclasses import dataclass, field
from typing import Dict, List, Literal, Optional, Set, Tuple
logger = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
# Init
# ---------------------------------------------------------------------------
try:
from spellchecker import SpellChecker as _SpellChecker
_en_spell = _SpellChecker(language='en', distance=1)
_de_spell = _SpellChecker(language='de', distance=1)
_AVAILABLE = True
except ImportError:
_AVAILABLE = False
logger.warning("pyspellchecker not installed — SmartSpellChecker disabled")
Lang = Literal["en", "de", "both", "unknown"]
# ---------------------------------------------------------------------------
# Bigram context for a/I disambiguation
# ---------------------------------------------------------------------------
# Words that commonly follow "I" (subject pronoun -> verb/modal)
_I_FOLLOWERS: frozenset = frozenset({
"am", "was", "have", "had", "do", "did", "will", "would", "can",
"could", "should", "shall", "may", "might", "must",
"think", "know", "see", "want", "need", "like", "love", "hate",
"go", "went", "come", "came", "say", "said", "get", "got",
"make", "made", "take", "took", "give", "gave", "tell", "told",
"feel", "felt", "find", "found", "believe", "hope", "wish",
"remember", "forget", "understand", "mean", "meant",
"don't", "didn't", "can't", "won't", "couldn't", "wouldn't",
"shouldn't", "haven't", "hadn't", "isn't", "wasn't",
"really", "just", "also", "always", "never", "often", "sometimes",
})
# Words that commonly follow "a" (article -> noun/adjective)
_A_FOLLOWERS: frozenset = frozenset({
"lot", "few", "little", "bit", "good", "bad", "great", "new", "old",
"long", "short", "big", "small", "large", "huge", "tiny",
"nice", "beautiful", "wonderful", "terrible", "horrible",
"man", "woman", "boy", "girl", "child", "dog", "cat", "bird",
"book", "car", "house", "room", "school", "teacher", "student",
"day", "week", "month", "year", "time", "place", "way",
"friend", "family", "person", "problem", "question", "story",
"very", "really", "quite", "rather", "pretty", "single",
})
# Digit->letter substitutions (OCR confusion)
_DIGIT_SUBS: Dict[str, List[str]] = {
'0': ['o', 'O'],
'1': ['l', 'I'],
'5': ['s', 'S'],
'6': ['g', 'G'],
'8': ['b', 'B'],
'|': ['I', 'l'],
'/': ['l'], # italic 'l' misread as slash (e.g. "p/" -> "pl")
}
_SUSPICIOUS_CHARS = frozenset(_DIGIT_SUBS.keys())
# Umlaut confusion: OCR drops the dots (ü->u, ä->a, ö->o)
_UMLAUT_MAP = {
'a': '\u00e4', 'o': '\u00f6', 'u': '\u00fc', 'i': '\u00fc',
'A': '\u00c4', 'O': '\u00d6', 'U': '\u00dc', 'I': '\u00dc',
}
# Tokenizer -- includes | and / so OCR artifacts like "p/" are treated as words
_TOKEN_RE = re.compile(r"([A-Za-z\u00c4\u00d6\u00dc\u00e4\u00f6\u00fc\u00df'|/]+)([^A-Za-z\u00c4\u00d6\u00dc\u00e4\u00f6\u00fc\u00df'|/]*)")
# ---------------------------------------------------------------------------
# Data types
# ---------------------------------------------------------------------------
@dataclass
class CorrectionResult:
original: str
corrected: str
lang_detected: Lang
changed: bool
changes: List[str] = field(default_factory=list)
# ---------------------------------------------------------------------------
# Core class — language detection and word-level correction
# ---------------------------------------------------------------------------
class _SmartSpellCoreBase:
"""Base class with language detection and single-word correction.
Not intended for direct use — SmartSpellChecker inherits from this.
"""
def __init__(self):
if not _AVAILABLE:
raise RuntimeError("pyspellchecker not installed")
self.en = _en_spell
self.de = _de_spell
# --- Language detection ---
def detect_word_lang(self, word: str) -> Lang:
"""Detect language of a single word using dual-dict heuristic."""
w = word.lower().strip(".,;:!?\"'()")
if not w:
return "unknown"
in_en = bool(self.en.known([w]))
in_de = bool(self.de.known([w]))
if in_en and in_de:
return "both"
if in_en:
return "en"
if in_de:
return "de"
return "unknown"
def detect_text_lang(self, text: str) -> Lang:
"""Detect dominant language of a text string (sentence/phrase)."""
words = re.findall(r"[A-Za-z\u00c4\u00d6\u00dc\u00e4\u00f6\u00fc\u00df]+", text)
if not words:
return "unknown"
en_count = 0
de_count = 0
for w in words:
lang = self.detect_word_lang(w)
if lang == "en":
en_count += 1
elif lang == "de":
de_count += 1
# "both" doesn't count for either
if en_count > de_count:
return "en"
if de_count > en_count:
return "de"
if en_count == de_count and en_count > 0:
return "both"
return "unknown"
# --- Single-word correction ---
def _known(self, word: str) -> bool:
"""True if word is known in EN or DE dictionary, or is a known abbreviation."""
w = word.lower()
if bool(self.en.known([w])) or bool(self.de.known([w])):
return True
# Also accept known abbreviations (sth, sb, adj, etc.)
try:
from cv_ocr_engines import _KNOWN_ABBREVIATIONS
if w in _KNOWN_ABBREVIATIONS:
return True
except ImportError:
pass
return False
def _word_freq(self, word: str) -> float:
"""Get word frequency (max of EN and DE)."""
w = word.lower()
return max(self.en.word_usage_frequency(w), self.de.word_usage_frequency(w))
def _known_in(self, word: str, lang: str) -> bool:
"""True if word is known in a specific language dictionary."""
w = word.lower()
spell = self.en if lang == "en" else self.de
return bool(spell.known([w]))
def correct_word(self, word: str, lang: str = "en",
prev_word: str = "", next_word: str = "") -> Optional[str]:
"""Correct a single word for the given language.
Returns None if no correction needed, or the corrected string.
"""
if not word or not word.strip():
return None
# Skip numbers, abbreviations with dots, very short tokens
if word.isdigit() or '.' in word:
return None
# Skip IPA/phonetic content in brackets
if '[' in word or ']' in word:
return None
has_suspicious = any(ch in _SUSPICIOUS_CHARS for ch in word)
# 1. Already known -> no fix
if self._known(word):
# But check a/I disambiguation for single-char words
if word.lower() in ('l', '|') and next_word:
return self._disambiguate_a_I(word, next_word)
return None
# 2. Digit/pipe substitution
if has_suspicious:
if word == '|':
return 'I'
# Try single-char substitutions
for i, ch in enumerate(word):
if ch not in _DIGIT_SUBS:
continue
for replacement in _DIGIT_SUBS[ch]:
candidate = word[:i] + replacement + word[i + 1:]
if self._known(candidate):
return candidate
# Try multi-char substitution (e.g., "sch00l" -> "school")
multi = self._try_multi_digit_sub(word)
if multi:
return multi
# 3. Umlaut correction (German)
if lang == "de" and len(word) >= 3 and word.isalpha():
umlaut_fix = self._try_umlaut_fix(word)
if umlaut_fix:
return umlaut_fix
# 4. General spell correction
if not has_suspicious and len(word) >= 3 and word.isalpha():
# Safety: don't correct if the word is valid in the OTHER language
other_lang = "de" if lang == "en" else "en"
if self._known_in(word, other_lang):
return None
if other_lang == "de" and self._try_umlaut_fix(word):
return None # has a valid DE umlaut variant -> don't touch
spell = self.en if lang == "en" else self.de
correction = spell.correction(word.lower())
if correction and correction != word.lower():
if word[0].isupper():
correction = correction[0].upper() + correction[1:]
if self._known(correction):
return correction
return None
# --- Multi-digit substitution ---
def _try_multi_digit_sub(self, word: str) -> Optional[str]:
"""Try replacing multiple digits simultaneously using BFS."""
positions = [(i, ch) for i, ch in enumerate(word) if ch in _DIGIT_SUBS]
if not positions or len(positions) > 4:
return None
# BFS over substitution combinations
queue = [list(word)]
for pos, ch in positions:
next_queue = []
for current in queue:
# Keep original
next_queue.append(current[:])
# Try each substitution
for repl in _DIGIT_SUBS[ch]:
variant = current[:]
variant[pos] = repl
next_queue.append(variant)
queue = next_queue
# Check which combinations produce known words
for combo in queue:
candidate = "".join(combo)
if candidate != word and self._known(candidate):
return candidate
return None
# --- Umlaut fix ---
def _try_umlaut_fix(self, word: str) -> Optional[str]:
"""Try single-char umlaut substitutions for German words."""
for i, ch in enumerate(word):
if ch in _UMLAUT_MAP:
candidate = word[:i] + _UMLAUT_MAP[ch] + word[i + 1:]
if self._known(candidate):
return candidate
return None
# --- a/I disambiguation ---
def _disambiguate_a_I(self, token: str, next_word: str) -> Optional[str]:
"""Disambiguate 'a' vs 'I' (and OCR variants like 'l', '|')."""
nw = next_word.lower().strip(".,;:!?")
if nw in _I_FOLLOWERS:
return "I"
if nw in _A_FOLLOWERS:
return "a"
return None # uncertain, don't change
# Backward-compat shim -- module moved to ocr/spell/core.py
import importlib as _importlib
import sys as _sys
_sys.modules[__name__] = _importlib.import_module("ocr.spell.core")

View File

@@ -1,289 +1,4 @@
"""
SmartSpellChecker Text — full text correction, boundary repair, context split.
Extracted from smart_spell.py for modularity.
Lizenz: Apache 2.0 (kommerziell nutzbar)
"""
import re
from typing import Dict, List, Optional, Tuple
from smart_spell_core import (
_SmartSpellCoreBase,
_TOKEN_RE,
CorrectionResult,
Lang,
)
class SmartSpellChecker(_SmartSpellCoreBase):
"""Language-aware OCR spell checker using pyspellchecker (no LLM).
Inherits single-word correction from _SmartSpellCoreBase.
Adds text-level passes: boundary repair, context split, full correction.
"""
# --- Boundary repair (shifted word boundaries) ---
def _try_boundary_repair(self, word1: str, word2: str) -> Optional[Tuple[str, str]]:
"""Fix shifted word boundaries between adjacent tokens.
OCR sometimes shifts the boundary: "at sth." -> "ats th."
Try moving 1-2 chars from end of word1 to start of word2 and vice versa.
Returns (fixed_word1, fixed_word2) or None.
"""
# Import known abbreviations for vocabulary context
try:
from cv_ocr_engines import _KNOWN_ABBREVIATIONS
except ImportError:
_KNOWN_ABBREVIATIONS = set()
# Strip trailing punctuation for checking, preserve for result
w2_stripped = word2.rstrip(".,;:!?")
w2_punct = word2[len(w2_stripped):]
# Try shifting 1-2 chars from word1 -> word2
for shift in (1, 2):
if len(word1) <= shift:
continue
new_w1 = word1[:-shift]
new_w2_base = word1[-shift:] + w2_stripped
w1_ok = self._known(new_w1) or new_w1.lower() in _KNOWN_ABBREVIATIONS
w2_ok = self._known(new_w2_base) or new_w2_base.lower() in _KNOWN_ABBREVIATIONS
if w1_ok and w2_ok:
return (new_w1, new_w2_base + w2_punct)
# Try shifting 1-2 chars from word2 -> word1
for shift in (1, 2):
if len(w2_stripped) <= shift:
continue
new_w1 = word1 + w2_stripped[:shift]
new_w2_base = w2_stripped[shift:]
w1_ok = self._known(new_w1) or new_w1.lower() in _KNOWN_ABBREVIATIONS
w2_ok = self._known(new_w2_base) or new_w2_base.lower() in _KNOWN_ABBREVIATIONS
if w1_ok and w2_ok:
return (new_w1, new_w2_base + w2_punct)
return None
# --- Context-based word split for ambiguous merges ---
# Patterns where a valid word is actually "a" + adjective/noun
_ARTICLE_SPLIT_CANDIDATES = {
# word -> (article, remainder) -- only when followed by a compatible word
"anew": ("a", "new"),
"areal": ("a", "real"),
"alive": None, # genuinely one word, never split
"alone": None,
"aware": None,
"alike": None,
"apart": None,
"aside": None,
"above": None,
"about": None,
"among": None,
"along": None,
}
def _try_context_split(self, word: str, next_word: str,
prev_word: str) -> Optional[str]:
"""Split words like 'anew' -> 'a new' when context indicates a merge.
Only splits when:
- The word is in the split candidates list
- The following word makes sense as a noun (for "a + adj + noun" pattern)
- OR the word is unknown and can be split into article + known word
"""
w_lower = word.lower()
# Check explicit candidates
if w_lower in self._ARTICLE_SPLIT_CANDIDATES:
split = self._ARTICLE_SPLIT_CANDIDATES[w_lower]
if split is None:
return None # explicitly marked as "don't split"
article, remainder = split
# Only split if followed by a word (noun pattern)
if next_word and next_word[0].islower():
return f"{article} {remainder}"
# Also split if remainder + next_word makes a common phrase
if next_word and self._known(next_word):
return f"{article} {remainder}"
# Generic: if word starts with 'a' and rest is a known adjective/word
if (len(word) >= 4 and word[0].lower() == 'a'
and not self._known(word) # only for UNKNOWN words
and self._known(word[1:])):
return f"a {word[1:]}"
return None
# --- Full text correction ---
def correct_text(self, text: str, lang: str = "en") -> CorrectionResult:
"""Correct a full text string (field value).
Three passes:
1. Boundary repair -- fix shifted word boundaries between adjacent tokens
2. Context split -- split ambiguous merges (anew -> a new)
3. Per-word correction -- spell check individual words
"""
if not text or not text.strip():
return CorrectionResult(text, text, "unknown", False)
detected = self.detect_text_lang(text) if lang == "auto" else lang
effective_lang = detected if detected in ("en", "de") else "en"
changes: List[str] = []
tokens = list(_TOKEN_RE.finditer(text))
# Extract token list: [(word, separator), ...]
token_list: List[List[str]] = [] # [[word, sep], ...]
for m in tokens:
token_list.append([m.group(1), m.group(2)])
# --- Pass 1: Boundary repair between adjacent unknown words ---
# Import abbreviations for the heuristic below
try:
from cv_ocr_engines import _KNOWN_ABBREVIATIONS as _ABBREVS
except ImportError:
_ABBREVS = set()
for i in range(len(token_list) - 1):
w1 = token_list[i][0]
w2_raw = token_list[i + 1][0]
# Skip boundary repair for IPA/bracket content
# Brackets may be in the token OR in the adjacent separators
sep_before_w1 = token_list[i - 1][1] if i > 0 else ""
sep_after_w1 = token_list[i][1]
sep_after_w2 = token_list[i + 1][1]
has_bracket = (
'[' in w1 or ']' in w1 or '[' in w2_raw or ']' in w2_raw
or ']' in sep_after_w1 # w1 text was inside [brackets]
or '[' in sep_after_w1 # w2 starts a bracket
or ']' in sep_after_w2 # w2 text was inside [brackets]
or '[' in sep_before_w1 # w1 starts a bracket
)
if has_bracket:
continue
# Include trailing punct from separator in w2 for abbreviation matching
w2_with_punct = w2_raw + token_list[i + 1][1].rstrip(" ")
# Try boundary repair -- always, even if both words are valid.
# Use word-frequency scoring to decide if repair is better.
repair = self._try_boundary_repair(w1, w2_with_punct)
if not repair and w2_with_punct != w2_raw:
repair = self._try_boundary_repair(w1, w2_raw)
if repair:
new_w1, new_w2_full = repair
new_w2_base = new_w2_full.rstrip(".,;:!?")
# Frequency-based scoring: product of word frequencies
# Higher product = more common word pair = better
old_freq = self._word_freq(w1) * self._word_freq(w2_raw)
new_freq = self._word_freq(new_w1) * self._word_freq(new_w2_base)
# Abbreviation bonus: if repair produces a known abbreviation
has_abbrev = new_w1.lower() in _ABBREVS or new_w2_base.lower() in _ABBREVS
if has_abbrev:
# Accept abbreviation repair ONLY if at least one of the
# original words is rare/unknown (prevents "Can I" -> "Ca nI"
# where both original words are common and correct).
RARE_THRESHOLD = 1e-6
orig_both_common = (
self._word_freq(w1) > RARE_THRESHOLD
and self._word_freq(w2_raw) > RARE_THRESHOLD
)
if not orig_both_common:
new_freq = max(new_freq, old_freq * 10)
else:
has_abbrev = False # both originals common -> don't trust
# Accept if repair produces a more frequent word pair
# (threshold: at least 5x more frequent to avoid false positives)
if new_freq > old_freq * 5:
new_w2_punct = new_w2_full[len(new_w2_base):]
changes.append(f"{w1} {w2_raw}\u2192{new_w1} {new_w2_base}")
token_list[i][0] = new_w1
token_list[i + 1][0] = new_w2_base
if new_w2_punct:
token_list[i + 1][1] = new_w2_punct + token_list[i + 1][1].lstrip(".,;:!?")
# --- Pass 2: Context split (anew -> a new) ---
expanded: List[List[str]] = []
for i, (word, sep) in enumerate(token_list):
next_word = token_list[i + 1][0] if i + 1 < len(token_list) else ""
prev_word = token_list[i - 1][0] if i > 0 else ""
split = self._try_context_split(word, next_word, prev_word)
if split and split != word:
changes.append(f"{word}\u2192{split}")
expanded.append([split, sep])
else:
expanded.append([word, sep])
token_list = expanded
# --- Pass 3: Per-word correction ---
parts: List[str] = []
# Preserve any leading text before the first token match
first_start = tokens[0].start() if tokens else 0
if first_start > 0:
parts.append(text[:first_start])
for i, (word, sep) in enumerate(token_list):
# Skip words inside IPA brackets (brackets land in separators)
prev_sep = token_list[i - 1][1] if i > 0 else ""
if '[' in prev_sep or ']' in sep:
parts.append(word)
parts.append(sep)
continue
next_word = token_list[i + 1][0] if i + 1 < len(token_list) else ""
prev_word = token_list[i - 1][0] if i > 0 else ""
correction = self.correct_word(
word, lang=effective_lang,
prev_word=prev_word, next_word=next_word,
)
if correction and correction != word:
changes.append(f"{word}\u2192{correction}")
parts.append(correction)
else:
parts.append(word)
parts.append(sep)
# Append any trailing text
last_end = tokens[-1].end() if tokens else 0
if last_end < len(text):
parts.append(text[last_end:])
corrected = "".join(parts)
return CorrectionResult(
original=text,
corrected=corrected,
lang_detected=detected,
changed=corrected != text,
changes=changes,
)
# --- Vocabulary entry correction ---
def correct_vocab_entry(self, english: str, german: str,
example: str = "") -> Dict[str, CorrectionResult]:
"""Correct a full vocabulary entry (EN + DE + example).
Uses column position to determine language -- the most reliable signal.
"""
results = {}
results["english"] = self.correct_text(english, lang="en")
results["german"] = self.correct_text(german, lang="de")
if example:
# For examples, auto-detect language
results["example"] = self.correct_text(example, lang="auto")
return results
# Backward-compat shim -- module moved to ocr/spell/text.py
import importlib as _importlib
import sys as _sys
_sys.modules[__name__] = _importlib.import_module("ocr.spell.text")

View File

@@ -1,346 +1,4 @@
"""
Tesseract-based OCR extraction with word-level bounding boxes.
Uses Tesseract for spatial information (WHERE text is) while
the Vision LLM handles semantic understanding (WHAT the text means).
Tesseract runs natively on ARM64 via Debian's apt package.
License: Apache 2.0 (commercially usable)
"""
import io
import logging
from typing import List, Dict, Any, Optional
from difflib import SequenceMatcher
logger = logging.getLogger(__name__)
try:
import pytesseract
from PIL import Image
TESSERACT_AVAILABLE = True
except ImportError:
TESSERACT_AVAILABLE = False
logger.warning("pytesseract or Pillow not installed - Tesseract OCR unavailable")
async def extract_bounding_boxes(image_bytes: bytes, lang: str = "eng+deu") -> dict:
"""Run Tesseract OCR and return word-level bounding boxes.
Args:
image_bytes: PNG/JPEG image as bytes.
lang: Tesseract language string (e.g. "eng+deu").
Returns:
Dict with 'words' list and 'image_width'/'image_height'.
"""
if not TESSERACT_AVAILABLE:
return {"words": [], "image_width": 0, "image_height": 0, "error": "Tesseract not available"}
image = Image.open(io.BytesIO(image_bytes))
data = pytesseract.image_to_data(image, lang=lang, output_type=pytesseract.Output.DICT)
words = []
for i in range(len(data['text'])):
text = data['text'][i].strip()
conf = int(data['conf'][i])
if not text or conf < 20:
continue
words.append({
"text": text,
"left": data['left'][i],
"top": data['top'][i],
"width": data['width'][i],
"height": data['height'][i],
"conf": conf,
"block_num": data['block_num'][i],
"par_num": data['par_num'][i],
"line_num": data['line_num'][i],
"word_num": data['word_num'][i],
})
return {
"words": words,
"image_width": image.width,
"image_height": image.height,
}
def group_words_into_lines(words: List[dict], y_tolerance_px: int = 15) -> List[List[dict]]:
"""Group words by their Y position into lines.
Args:
words: List of word dicts from extract_bounding_boxes.
y_tolerance_px: Max pixel distance to consider words on the same line.
Returns:
List of lines, each line is a list of words sorted by X position.
"""
if not words:
return []
# Sort by Y then X
sorted_words = sorted(words, key=lambda w: (w['top'], w['left']))
lines: List[List[dict]] = []
current_line: List[dict] = [sorted_words[0]]
current_y = sorted_words[0]['top']
for word in sorted_words[1:]:
if abs(word['top'] - current_y) <= y_tolerance_px:
current_line.append(word)
else:
current_line.sort(key=lambda w: w['left'])
lines.append(current_line)
current_line = [word]
current_y = word['top']
if current_line:
current_line.sort(key=lambda w: w['left'])
lines.append(current_line)
return lines
def detect_columns(lines: List[List[dict]], image_width: int) -> Dict[str, Any]:
"""Detect column boundaries from word positions.
Typical vocab table: Left=English, Middle=German, Right=Example sentences.
Returns:
Dict with column boundaries and type assignments.
"""
if not lines or image_width == 0:
return {"columns": [], "column_types": []}
# Collect all word X positions
all_x_positions = []
for line in lines:
for word in line:
all_x_positions.append(word['left'])
if not all_x_positions:
return {"columns": [], "column_types": []}
# Find X-position clusters (column starts)
all_x_positions.sort()
# Simple gap-based column detection
min_gap = image_width * 0.08 # 8% of page width = column gap
clusters = []
current_cluster = [all_x_positions[0]]
for x in all_x_positions[1:]:
if x - current_cluster[-1] > min_gap:
clusters.append(current_cluster)
current_cluster = [x]
else:
current_cluster.append(x)
if current_cluster:
clusters.append(current_cluster)
# Each cluster represents a column start
columns = []
for cluster in clusters:
col_start = min(cluster)
columns.append({
"x_start": col_start,
"x_start_pct": col_start / image_width * 100,
"word_count": len(cluster),
})
# Assign column types based on position (left→right: EN, DE, Example)
type_map = ["english", "german", "example"]
column_types = []
for i, col in enumerate(columns):
if i < len(type_map):
column_types.append(type_map[i])
else:
column_types.append("unknown")
return {
"columns": columns,
"column_types": column_types,
}
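# Illustrative example (hypothetical page): on a 1000 px wide scan with word
# x-positions clustered around 60, 420 and 760 px, gaps larger than 80 px
# (8% of the width) separate three clusters, which are typed "english",
# "german" and "example" from left to right.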
def words_to_vocab_entries(lines: List[List[dict]], columns: List[dict],
column_types: List[str], image_width: int,
image_height: int) -> List[dict]:
"""Convert grouped words into vocabulary entries using column positions.
Args:
lines: Grouped word lines from group_words_into_lines.
columns: Column boundaries from detect_columns.
column_types: Column type assignments.
image_width: Image width in pixels.
image_height: Image height in pixels.
Returns:
List of vocabulary entry dicts with english/german/example fields.
"""
if not columns or not lines:
return []
# Build column boundaries for word assignment
col_boundaries = []
for i, col in enumerate(columns):
start = col['x_start']
if i + 1 < len(columns):
end = columns[i + 1]['x_start']
else:
end = image_width
col_boundaries.append((start, end, column_types[i] if i < len(column_types) else "unknown"))
entries = []
for line in lines:
entry = {"english": "", "german": "", "example": ""}
line_words_by_col: Dict[str, List[str]] = {"english": [], "german": [], "example": []}
line_bbox: Dict[str, Optional[dict]] = {}
for word in line:
word_center_x = word['left'] + word['width'] / 2
assigned_type = "unknown"
for start, end, col_type in col_boundaries:
if start <= word_center_x < end:
assigned_type = col_type
break
if assigned_type in line_words_by_col:
line_words_by_col[assigned_type].append(word['text'])
# Track bounding box for the column
if assigned_type not in line_bbox or line_bbox[assigned_type] is None:
line_bbox[assigned_type] = {
"left": word['left'],
"top": word['top'],
"right": word['left'] + word['width'],
"bottom": word['top'] + word['height'],
}
else:
bb = line_bbox[assigned_type]
bb['left'] = min(bb['left'], word['left'])
bb['top'] = min(bb['top'], word['top'])
bb['right'] = max(bb['right'], word['left'] + word['width'])
bb['bottom'] = max(bb['bottom'], word['top'] + word['height'])
for col_type in ["english", "german", "example"]:
if line_words_by_col[col_type]:
entry[col_type] = " ".join(line_words_by_col[col_type])
if line_bbox.get(col_type):
bb = line_bbox[col_type]
entry[f"{col_type}_bbox"] = {
"x_pct": bb['left'] / image_width * 100,
"y_pct": bb['top'] / image_height * 100,
"w_pct": (bb['right'] - bb['left']) / image_width * 100,
"h_pct": (bb['bottom'] - bb['top']) / image_height * 100,
}
# Only add if at least one column has content
if entry["english"] or entry["german"]:
entries.append(entry)
return entries
def match_positions_to_vocab(tess_words: List[dict], llm_vocab: List[dict],
image_w: int, image_h: int,
threshold: float = 0.6) -> List[dict]:
"""Match Tesseract bounding boxes to LLM vocabulary entries.
For each LLM vocab entry, find the best-matching Tesseract word
and attach its bounding box coordinates.
Args:
tess_words: Word list from Tesseract with pixel coordinates.
llm_vocab: Vocabulary list from Vision LLM.
image_w: Image width in pixels.
image_h: Image height in pixels.
threshold: Minimum similarity ratio for a match.
Returns:
llm_vocab list with bbox_x_pct/bbox_y_pct/bbox_w_pct/bbox_h_pct added.
"""
if not tess_words or not llm_vocab or image_w == 0 or image_h == 0:
return llm_vocab
for entry in llm_vocab:
english = entry.get("english", "").lower().strip()
german = entry.get("german", "").lower().strip()
if not english and not german:
continue
# Try to match English word first, then German
for field in ["english", "german"]:
search_text = entry.get(field, "").lower().strip()
if not search_text:
continue
best_word = None
best_ratio = 0.0
for word in tess_words:
ratio = SequenceMatcher(None, search_text, word['text'].lower()).ratio()
if ratio > best_ratio:
best_ratio = ratio
best_word = word
if best_word and best_ratio >= threshold:
entry[f"bbox_x_pct"] = best_word['left'] / image_w * 100
entry[f"bbox_y_pct"] = best_word['top'] / image_h * 100
entry[f"bbox_w_pct"] = best_word['width'] / image_w * 100
entry[f"bbox_h_pct"] = best_word['height'] / image_h * 100
entry["bbox_match_field"] = field
entry["bbox_match_ratio"] = round(best_ratio, 3)
break # Found a match, no need to try the other field
return llm_vocab
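# Illustrative example (hypothetical entry): for {"english": "awareness"} the
# closest Tesseract word "awarness" reaches a SequenceMatcher ratio of ~0.94,
# above the 0.6 threshold, so the entry receives the bbox_*_pct fields plus
# bbox_match_field="english" and bbox_match_ratio=0.941.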
async def run_tesseract_pipeline(image_bytes: bytes, lang: str = "eng+deu") -> dict:
"""Full Tesseract pipeline: extract words, group lines, detect columns, build vocab.
Args:
image_bytes: PNG/JPEG image as bytes.
lang: Tesseract language string.
Returns:
Dict with 'vocabulary', 'words', 'lines', 'columns', 'image_width', 'image_height'.
"""
# Step 1: Extract bounding boxes
bbox_data = await extract_bounding_boxes(image_bytes, lang=lang)
if bbox_data.get("error"):
return bbox_data
words = bbox_data["words"]
image_w = bbox_data["image_width"]
image_h = bbox_data["image_height"]
# Step 2: Group into lines
lines = group_words_into_lines(words)
# Step 3: Detect columns
col_info = detect_columns(lines, image_w)
# Step 4: Build vocabulary entries
vocab = words_to_vocab_entries(
lines,
col_info["columns"],
col_info["column_types"],
image_w,
image_h,
)
return {
"vocabulary": vocab,
"words": words,
"lines_count": len(lines),
"columns": col_info["columns"],
"column_types": col_info["column_types"],
"image_width": image_w,
"image_height": image_h,
"word_count": len(words),
}
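# Usage sketch (hypothetical caller and file name):
#     import asyncio
#     with open("vocab_page.png", "rb") as fh:
#         result = asyncio.run(run_tesseract_pipeline(fh.read(), lang="eng+deu"))
#     print(result["word_count"], "words in", result["lines_count"], "lines,",
#           len(result["vocabulary"]), "vocab entries")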
# Backward-compat shim -- module moved to ocr/engines/tesseract_extractor.py
import importlib as _importlib
import sys as _sys
_sys.modules[__name__] = _importlib.import_module("ocr.engines.tesseract_extractor")

View File

@@ -1,425 +1,4 @@
"""
Unified Grid Builder — merges multi-zone grid into a single Excel-like grid.
Takes content zone + box zones and produces one unified zone where:
- All content rows use the dominant row height
- Full-width boxes are integrated directly (box rows replace standard rows)
- Partial-width boxes: extra rows inserted if box has more lines than standard
- Box-origin cells carry metadata (bg_color, border) for visual distinction
The result is a single-zone StructuredGrid that can be:
- Rendered in an Excel-like editor
- Exported to Excel/CSV
- Edited with unified row/column numbering
"""
import logging
import math
import statistics
from typing import Any, Dict, List, Optional, Tuple
logger = logging.getLogger(__name__)
def _compute_dominant_row_height(content_zone: Dict) -> float:
"""Median of content row-to-row spacings, excluding box-gap jumps."""
rows = content_zone.get("rows", [])
if len(rows) < 2:
return 47.0
spacings = []
for i in range(len(rows) - 1):
y1 = rows[i].get("y_min_px", rows[i].get("y_min", 0))
y2 = rows[i + 1].get("y_min_px", rows[i + 1].get("y_min", 0))
d = y2 - y1
if 0 < d < 100: # exclude box-gap jumps
spacings.append(d)
if not spacings:
return 47.0
spacings.sort()
return spacings[len(spacings) // 2]
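# Illustrative example (hypothetical row positions): content rows at y = 100,
# 147, 194 and 400 px give spacings [47, 47, 206]; the 206 px box-gap jump is
# discarded (>= 100), so the median of [47, 47] -> 47 px is returned as the
# dominant row height.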
def _classify_boxes(
box_zones: List[Dict],
content_width: float,
) -> List[Dict]:
"""Classify each box as full_width or partial_width."""
result = []
for bz in box_zones:
bb = bz.get("bbox_px", {})
bw = bb.get("w", 0)
bx = bb.get("x", 0)
if bw >= content_width * 0.85:
classification = "full_width"
side = "center"
else:
classification = "partial_width"
# Determine which side of the page the box is on
page_center = content_width / 2
box_center = bx + bw / 2
side = "right" if box_center > page_center else "left"
# Count total text lines in box (including \n within cells)
total_lines = sum(
(c.get("text", "").count("\n") + 1)
for c in bz.get("cells", [])
)
result.append({
"zone": bz,
"classification": classification,
"side": side,
"y_start": bb.get("y", 0),
"y_end": bb.get("y", 0) + bb.get("h", 0),
"total_lines": total_lines,
"bg_hex": bz.get("box_bg_hex", ""),
"bg_color": bz.get("box_bg_color", ""),
})
return result
def build_unified_grid(
zones: List[Dict],
image_width: int,
image_height: int,
layout_metrics: Dict,
) -> Dict[str, Any]:
"""Build a single-zone unified grid from multi-zone grid data.
Returns a StructuredGrid with one zone containing all rows and cells.
"""
content_zone = None
box_zones = []
for z in zones:
if z.get("zone_type") == "content":
content_zone = z
elif z.get("zone_type") == "box":
box_zones.append(z)
if not content_zone:
logger.warning("build_unified_grid: no content zone found")
return {"zones": zones} # fallback: return as-is
box_zones.sort(key=lambda b: b.get("bbox_px", {}).get("y", 0))
dominant_h = _compute_dominant_row_height(content_zone)
content_bbox = content_zone.get("bbox_px", {})
content_width = content_bbox.get("w", image_width)
content_x = content_bbox.get("x", 0)
content_cols = content_zone.get("columns", [])
num_cols = len(content_cols)
box_infos = _classify_boxes(box_zones, content_width)
logger.info(
"build_unified_grid: dominant_h=%.1f, %d content rows, %d boxes (%s)",
dominant_h, len(content_zone.get("rows", [])), len(box_infos),
[b["classification"] for b in box_infos],
)
# --- Build unified row list + cell list ---
unified_rows: List[Dict] = []
unified_cells: List[Dict] = []
unified_row_idx = 0
# Content rows and cells indexed by row_index
content_rows = content_zone.get("rows", [])
content_cells = content_zone.get("cells", [])
content_cells_by_row: Dict[int, List[Dict]] = {}
for c in content_cells:
content_cells_by_row.setdefault(c.get("row_index", -1), []).append(c)
# Track which content rows we've processed
content_row_ptr = 0
for bi, box_info in enumerate(box_infos):
bz = box_info["zone"]
by_start = box_info["y_start"]
by_end = box_info["y_end"]
# --- Add content rows ABOVE this box ---
while content_row_ptr < len(content_rows):
cr = content_rows[content_row_ptr]
cry = cr.get("y_min_px", cr.get("y_min", 0))
if cry >= by_start:
break
# Add this content row
_add_content_row(
unified_rows, unified_cells, unified_row_idx,
cr, content_cells_by_row, dominant_h, image_height,
)
unified_row_idx += 1
content_row_ptr += 1
# --- Add box rows ---
if box_info["classification"] == "full_width":
# Full-width box: integrate box rows directly
_add_full_width_box(
unified_rows, unified_cells, unified_row_idx,
bz, box_info, dominant_h, num_cols, image_height,
)
unified_row_idx += len(bz.get("rows", []))
# Skip content rows that overlap with this box
while content_row_ptr < len(content_rows):
cr = content_rows[content_row_ptr]
cry = cr.get("y_min_px", cr.get("y_min", 0))
if cry > by_end:
break
content_row_ptr += 1
else:
# Partial-width box: merge with adjacent content rows
unified_row_idx = _add_partial_width_box(
unified_rows, unified_cells, unified_row_idx,
bz, box_info, content_rows, content_cells_by_row,
content_row_ptr, dominant_h, num_cols, image_height,
content_x, content_width,
)
# Advance content pointer past box region
while content_row_ptr < len(content_rows):
cr = content_rows[content_row_ptr]
cry = cr.get("y_min_px", cr.get("y_min", 0))
if cry > by_end:
break
content_row_ptr += 1
# --- Add remaining content rows BELOW all boxes ---
while content_row_ptr < len(content_rows):
cr = content_rows[content_row_ptr]
_add_content_row(
unified_rows, unified_cells, unified_row_idx,
cr, content_cells_by_row, dominant_h, image_height,
)
unified_row_idx += 1
content_row_ptr += 1
# --- Build unified zone ---
unified_zone = {
"zone_index": 0,
"zone_type": "unified",
"bbox_px": content_bbox,
"bbox_pct": content_zone.get("bbox_pct", {}),
"border": None,
"word_count": sum(len(c.get("word_boxes", [])) for c in unified_cells),
"columns": content_cols,
"rows": unified_rows,
"cells": unified_cells,
"header_rows": [],
}
logger.info(
"build_unified_grid: %d unified rows, %d cells (from %d content + %d box zones)",
len(unified_rows), len(unified_cells),
len(content_rows), len(box_zones),
)
return {
"zones": [unified_zone],
"image_width": image_width,
"image_height": image_height,
"layout_metrics": layout_metrics,
"summary": {
"total_zones": 1,
"total_columns": num_cols,
"total_rows": len(unified_rows),
"total_cells": len(unified_cells),
},
"is_unified": True,
"dominant_row_h": dominant_h,
}
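# Usage sketch (hypothetical zone data; keys follow the dicts above):
#     grid = build_unified_grid(zones=page_grid["zones"], image_width=2480,
#                               image_height=3508, layout_metrics={})
#     assert grid.get("is_unified") and grid["summary"]["total_zones"] == 1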
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _make_row(idx: int, y: float, h: float, img_h: int, is_header: bool = False) -> Dict:
return {
"index": idx,
"row_index": idx,
"y_min_px": round(y),
"y_max_px": round(y + h),
"y_min_pct": round(y / img_h * 100, 2) if img_h else 0,
"y_max_pct": round((y + h) / img_h * 100, 2) if img_h else 0,
"is_header": is_header,
}
def _remap_cell(cell: Dict, new_row: int, new_col: Optional[int] = None,
source_type: str = "content", box_region: Optional[Dict] = None) -> Dict:
"""Create a new cell dict with remapped indices."""
c = dict(cell)
c["row_index"] = new_row
if new_col is not None:
c["col_index"] = new_col
c["cell_id"] = f"U_R{new_row:02d}_C{c.get('col_index', 0)}"
c["source_zone_type"] = source_type
if box_region:
c["box_region"] = box_region
return c
def _add_content_row(
unified_rows, unified_cells, row_idx,
content_row, cells_by_row, dominant_h, img_h,
):
"""Add a single content row to the unified grid."""
y = content_row.get("y_min_px", content_row.get("y_min", 0))
is_hdr = content_row.get("is_header", False)
unified_rows.append(_make_row(row_idx, y, dominant_h, img_h, is_hdr))
for cell in cells_by_row.get(content_row.get("index", -1), []):
unified_cells.append(_remap_cell(cell, row_idx, source_type="content"))
def _add_full_width_box(
unified_rows, unified_cells, start_row_idx,
box_zone, box_info, dominant_h, num_cols, img_h,
):
"""Add a full-width box's rows to the unified grid."""
box_rows = box_zone.get("rows", [])
box_cells = box_zone.get("cells", [])
box_region = {"bg_hex": box_info["bg_hex"], "bg_color": box_info["bg_color"], "border": True}
# Distribute box height evenly among its rows
box_h = box_info["y_end"] - box_info["y_start"]
row_h = box_h / len(box_rows) if box_rows else dominant_h
for i, br in enumerate(box_rows):
y = box_info["y_start"] + i * row_h
new_idx = start_row_idx + i
is_hdr = br.get("is_header", False)
unified_rows.append(_make_row(new_idx, y, row_h, img_h, is_hdr))
for cell in box_cells:
if cell.get("row_index") == br.get("index", i):
unified_cells.append(
_remap_cell(cell, new_idx, source_type="box", box_region=box_region)
)
def _add_partial_width_box(
unified_rows, unified_cells, start_row_idx,
box_zone, box_info, content_rows, content_cells_by_row,
content_row_ptr, dominant_h, num_cols, img_h,
content_x, content_width,
) -> int:
"""Add a partial-width box merged with content rows.
Returns the next unified_row_idx after processing.
"""
by_start = box_info["y_start"]
by_end = box_info["y_end"]
box_h = by_end - by_start
box_region = {"bg_hex": box_info["bg_hex"], "bg_color": box_info["bg_color"], "border": True}
# Content rows in the box's Y range
overlap_content_rows = []
ptr = content_row_ptr
while ptr < len(content_rows):
cr = content_rows[ptr]
cry = cr.get("y_min_px", cr.get("y_min", 0))
if cry > by_end:
break
if cry >= by_start:
overlap_content_rows.append(cr)
ptr += 1
# How many standard rows fit in the box height
standard_rows = max(1, math.floor(box_h / dominant_h))
# How many text lines the box actually has
box_text_lines = box_info["total_lines"]
# Extra rows needed
extra_rows = max(0, box_text_lines - standard_rows)
total_rows_for_region = standard_rows + extra_rows
logger.info(
"partial box: standard=%d, box_lines=%d, extra=%d, content_overlap=%d",
standard_rows, box_text_lines, extra_rows, len(overlap_content_rows),
)
# Determine which columns the box occupies
box_bb = box_zone.get("bbox_px", {})
box_x = box_bb.get("x", 0)
box_w = box_bb.get("w", 0)
# Map the box onto content columns. A precise x-overlap mapping against the
# content column boundaries is not implemented yet; the current heuristic
# splits the columns at the midpoint:
#   box on right side -> last half of the columns
#   box on left side  -> first half of the columns
if box_info["side"] == "right":
box_col_start = num_cols // 2  # right half
box_col_end = num_cols
else:
box_col_start = 0
box_col_end = num_cols // 2
# Build rows for this region
box_cells = box_zone.get("cells", [])
box_rows = box_zone.get("rows", [])
row_idx = start_row_idx
# Expand box cell texts with \n into individual lines for row mapping
box_lines: List[Tuple[str, Dict]] = [] # (text_line, parent_cell)
for bc in sorted(box_cells, key=lambda c: c.get("row_index", 0)):
text = bc.get("text", "")
for line in text.split("\n"):
box_lines.append((line.strip(), bc))
for i in range(total_rows_for_region):
y = by_start + i * dominant_h
unified_rows.append(_make_row(row_idx, y, dominant_h, img_h))
# Content cells for this row (from overlapping content rows)
if i < len(overlap_content_rows):
cr = overlap_content_rows[i]
for cell in content_cells_by_row.get(cr.get("index", -1), []):
# Only include cells from columns NOT covered by the box
ci = cell.get("col_index", 0)
if ci < box_col_start or ci >= box_col_end:
unified_cells.append(_remap_cell(cell, row_idx, source_type="content"))
# Box cell for this row
if i < len(box_lines):
line_text, parent_cell = box_lines[i]
box_cell = {
"cell_id": f"U_R{row_idx:02d}_C{box_col_start}",
"row_index": row_idx,
"col_index": box_col_start,
"col_type": "spanning_header" if (box_col_end - box_col_start) > 1 else parent_cell.get("col_type", "column_1"),
"colspan": box_col_end - box_col_start,
"text": line_text,
"confidence": parent_cell.get("confidence", 0),
"bbox_px": parent_cell.get("bbox_px", {}),
"bbox_pct": parent_cell.get("bbox_pct", {}),
"word_boxes": [],
"ocr_engine": parent_cell.get("ocr_engine", ""),
"is_bold": parent_cell.get("is_bold", False),
"source_zone_type": "box",
"box_region": box_region,
}
unified_cells.append(box_cell)
row_idx += 1
return row_idx
# Backward-compat shim -- module moved to grid/unified.py
import importlib as _importlib
import sys as _sys
_sys.modules[__name__] = _importlib.import_module("grid.unified")

View File

@@ -0,0 +1,6 @@
"""
Upload package — chunked and mobile upload endpoints.
Moved from backend/ flat modules (upload_api*.py).
Backward-compatible shim files remain at the old locations.
"""

View File

@@ -0,0 +1,29 @@
"""
Mobile Upload API — barrel re-export.
All implementation split into:
upload_api_chunked — chunked upload endpoints (init, chunk, finalize, simple, status, cancel, list)
upload_api_mobile — mobile HTML upload page
DSGVO-konform: Data stays local in WLAN, no external transmission.
"""
from fastapi import APIRouter
from .chunked import ( # noqa: F401
router as _chunked_router,
UPLOAD_DIR,
CHUNK_DIR,
EH_UPLOAD_DIR,
_upload_sessions,
InitUploadRequest,
InitUploadResponse,
ChunkUploadResponse,
FinalizeResponse,
)
from .mobile import router as _mobile_router # noqa: F401
# Composite router that includes both sub-routers
router = APIRouter()
router.include_router(_chunked_router)
router.include_router(_mobile_router)
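# Wiring sketch (hypothetical application module; import path as after the move):
#     from fastapi import FastAPI
#     from upload.api import router as upload_router
#     app = FastAPI()
#     app.include_router(upload_router)  # exposes the /api/v1/upload/* endpoints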

View File

@@ -0,0 +1,320 @@
"""
Chunked Upload API — init, chunk, finalize, simple upload, status, cancel, list.
Extracted from upload_api.py for modularity.
DSGVO-konform: Data stays local in WLAN, no external transmission.
"""
import os
import uuid
import shutil
import hashlib
from pathlib import Path
from datetime import datetime, timezone
from typing import Dict, Optional
from fastapi import APIRouter, HTTPException, UploadFile, File, Form
from pydantic import BaseModel
# Configuration
UPLOAD_DIR = Path(os.getenv("UPLOAD_DIR", "/app/uploads"))
CHUNK_DIR = Path(os.getenv("CHUNK_DIR", "/app/chunks"))
EH_UPLOAD_DIR = Path(os.getenv("EH_UPLOAD_DIR", "/app/eh-uploads"))
# Ensure directories exist
UPLOAD_DIR.mkdir(parents=True, exist_ok=True)
CHUNK_DIR.mkdir(parents=True, exist_ok=True)
EH_UPLOAD_DIR.mkdir(parents=True, exist_ok=True)
# In-memory storage for upload sessions (for simplicity)
# In production, use Redis or database
_upload_sessions: Dict[str, dict] = {}
router = APIRouter(prefix="/api/v1/upload", tags=["Mobile Upload"])
class InitUploadRequest(BaseModel):
filename: str
filesize: int
chunks: int
destination: str = "klausur" # "klausur" or "rag"
class InitUploadResponse(BaseModel):
upload_id: str
chunk_size: int
total_chunks: int
message: str
class ChunkUploadResponse(BaseModel):
upload_id: str
chunk_index: int
received: bool
chunks_received: int
total_chunks: int
class FinalizeResponse(BaseModel):
upload_id: str
filename: str
filepath: str
filesize: int
checksum: str
message: str
@router.post("/init", response_model=InitUploadResponse)
async def init_upload(request: InitUploadRequest):
"""
Initialize a chunked upload session.
Returns an upload_id that must be used for subsequent chunk uploads.
"""
upload_id = str(uuid.uuid4())
# Create session directory
session_dir = CHUNK_DIR / upload_id
session_dir.mkdir(parents=True, exist_ok=True)
# Store session info
_upload_sessions[upload_id] = {
"filename": request.filename,
"filesize": request.filesize,
"total_chunks": request.chunks,
"received_chunks": set(),
"destination": request.destination,
"session_dir": str(session_dir),
"created_at": datetime.now(timezone.utc).isoformat(),
}
return InitUploadResponse(
upload_id=upload_id,
chunk_size=5 * 1024 * 1024, # 5 MB
total_chunks=request.chunks,
message="Upload-Session erstellt"
)
@router.post("/chunk", response_model=ChunkUploadResponse)
async def upload_chunk(
chunk: UploadFile = File(...),
upload_id: str = Form(...),
chunk_index: int = Form(...)
):
"""
Upload a single chunk of a file.
Chunks are stored temporarily until finalize is called.
"""
if upload_id not in _upload_sessions:
raise HTTPException(status_code=404, detail="Upload-Session nicht gefunden")
session = _upload_sessions[upload_id]
if chunk_index < 0 or chunk_index >= session["total_chunks"]:
raise HTTPException(
status_code=400,
detail=f"Ungueltiger Chunk-Index: {chunk_index}"
)
# Save chunk
chunk_path = Path(session["session_dir"]) / f"chunk_{chunk_index:05d}"
with open(chunk_path, "wb") as f:
content = await chunk.read()
f.write(content)
# Track received chunks
session["received_chunks"].add(chunk_index)
return ChunkUploadResponse(
upload_id=upload_id,
chunk_index=chunk_index,
received=True,
chunks_received=len(session["received_chunks"]),
total_chunks=session["total_chunks"]
)
@router.post("/finalize", response_model=FinalizeResponse)
async def finalize_upload(upload_id: str = Form(...)):
"""
Finalize the upload by combining all chunks into a single file.
Validates that all chunks were received and calculates checksum.
"""
if upload_id not in _upload_sessions:
raise HTTPException(status_code=404, detail="Upload-Session nicht gefunden")
session = _upload_sessions[upload_id]
# Check if all chunks received
if len(session["received_chunks"]) != session["total_chunks"]:
missing = session["total_chunks"] - len(session["received_chunks"])
raise HTTPException(
status_code=400,
detail=f"Nicht alle Chunks empfangen. Fehlend: {missing}"
)
# Determine destination directory
if session["destination"] == "rag":
dest_dir = EH_UPLOAD_DIR
else:
dest_dir = UPLOAD_DIR
# Generate unique filename
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
safe_filename = session["filename"].replace(" ", "_")
final_filename = f"{timestamp}_{safe_filename}"
final_path = dest_dir / final_filename
# Combine chunks
hasher = hashlib.sha256()
total_size = 0
with open(final_path, "wb") as outfile:
for i in range(session["total_chunks"]):
chunk_path = Path(session["session_dir"]) / f"chunk_{i:05d}"
if not chunk_path.exists():
raise HTTPException(
status_code=500,
detail=f"Chunk {i} nicht gefunden"
)
with open(chunk_path, "rb") as infile:
data = infile.read()
outfile.write(data)
hasher.update(data)
total_size += len(data)
# Clean up chunks
shutil.rmtree(session["session_dir"], ignore_errors=True)
del _upload_sessions[upload_id]
checksum = hasher.hexdigest()
return FinalizeResponse(
upload_id=upload_id,
filename=final_filename,
filepath=str(final_path),
filesize=total_size,
checksum=checksum,
message="Upload erfolgreich abgeschlossen"
)
@router.post("/simple")
async def simple_upload(
file: UploadFile = File(...),
destination: str = Form("klausur")
):
"""
Simple single-request upload for smaller files (<10MB).
For larger files, use the chunked upload endpoints.
"""
# Determine destination directory
if destination == "rag":
dest_dir = EH_UPLOAD_DIR
else:
dest_dir = UPLOAD_DIR
# Generate unique filename
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
safe_filename = file.filename.replace(" ", "_") if file.filename else "upload.pdf"
final_filename = f"{timestamp}_{safe_filename}"
final_path = dest_dir / final_filename
# Calculate checksum while writing
hasher = hashlib.sha256()
total_size = 0
with open(final_path, "wb") as f:
while True:
chunk = await file.read(1024 * 1024) # Read 1MB at a time
if not chunk:
break
f.write(chunk)
hasher.update(chunk)
total_size += len(chunk)
return {
"filename": final_filename,
"filepath": str(final_path),
"filesize": total_size,
"checksum": hasher.hexdigest(),
"message": "Upload erfolgreich"
}
@router.get("/status/{upload_id}")
async def get_upload_status(upload_id: str):
"""
Get the status of an ongoing upload.
"""
if upload_id not in _upload_sessions:
raise HTTPException(status_code=404, detail="Upload-Session nicht gefunden")
session = _upload_sessions[upload_id]
return {
"upload_id": upload_id,
"filename": session["filename"],
"total_chunks": session["total_chunks"],
"received_chunks": len(session["received_chunks"]),
"progress_percent": round(
len(session["received_chunks"]) / session["total_chunks"] * 100, 1
),
"destination": session["destination"],
"created_at": session["created_at"]
}
@router.delete("/cancel/{upload_id}")
async def cancel_upload(upload_id: str):
"""
Cancel an ongoing upload and clean up temporary files.
"""
if upload_id not in _upload_sessions:
raise HTTPException(status_code=404, detail="Upload-Session nicht gefunden")
session = _upload_sessions[upload_id]
# Clean up chunks
shutil.rmtree(session["session_dir"], ignore_errors=True)
del _upload_sessions[upload_id]
return {"message": "Upload abgebrochen", "upload_id": upload_id}
@router.get("/list")
async def list_uploads(destination: str = "klausur"):
"""
List all uploaded files in the specified destination.
"""
if destination == "rag":
dest_dir = EH_UPLOAD_DIR
else:
dest_dir = UPLOAD_DIR
files = []
for f in dest_dir.iterdir():
if f.is_file() and f.suffix.lower() == ".pdf":
stat = f.stat()
files.append({
"filename": f.name,
"size": stat.st_size,
"modified": datetime.fromtimestamp(stat.st_mtime).isoformat(),
})
files.sort(key=lambda x: x["modified"], reverse=True)
return {
"destination": destination,
"count": len(files),
"files": files[:50] # Limit to 50 most recent
}
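# Client-side flow sketch (hypothetical host, port and file; endpoints as defined above):
#     import math, pathlib, httpx
#     base = "http://klausur-service:8000/api/v1/upload"  # assumed host/port
#     path = pathlib.Path("klausur.pdf")
#     size, chunk_size = path.stat().st_size, 5 * 1024 * 1024
#     total = math.ceil(size / chunk_size)
#     init = httpx.post(f"{base}/init", json={"filename": path.name,
#         "filesize": size, "chunks": total, "destination": "klausur"}).json()
#     with path.open("rb") as fh:
#         for i in range(total):
#             httpx.post(f"{base}/chunk",
#                 data={"upload_id": init["upload_id"], "chunk_index": str(i)},
#                 files={"chunk": fh.read(chunk_size)})
#     httpx.post(f"{base}/finalize", data={"upload_id": init["upload_id"]})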

View File

@@ -0,0 +1,292 @@
"""
Mobile Upload HTML Page — serves the mobile upload UI directly from klausur-service.
Extracted from upload_api.py for modularity.
DSGVO-konform: Data stays local in WLAN, no external transmission.
"""
from fastapi import APIRouter
from fastapi.responses import HTMLResponse
router = APIRouter(prefix="/api/v1/upload", tags=["Mobile Upload"])
@router.get("/mobile", response_class=HTMLResponse)
async def mobile_upload_page():
"""
Serve the mobile upload page directly from the klausur-service.
This allows mobile devices to upload without needing the Next.js website.
"""
html_content = '''<!DOCTYPE html>
<html lang="de">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0, maximum-scale=1.0, user-scalable=no">
<meta name="apple-mobile-web-app-capable" content="yes">
<title>BreakPilot Upload</title>
<style>
* { margin: 0; padding: 0; box-sizing: border-box; }
body {
font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, sans-serif;
background: linear-gradient(135deg, #1e293b 0%, #0f172a 100%);
color: white;
min-height: 100vh;
padding: 16px;
}
.header {
display: flex;
justify-content: space-between;
align-items: center;
padding: 16px;
border-bottom: 1px solid #334155;
margin-bottom: 24px;
}
.header h1 { font-size: 20px; color: #60a5fa; }
.badge { font-size: 10px; background: #1e293b; padding: 4px 8px; border-radius: 4px; color: #94a3b8; }
.destination-selector {
display: flex;
gap: 8px;
margin-bottom: 24px;
}
.dest-btn {
flex: 1;
padding: 14px;
border: none;
border-radius: 10px;
font-size: 14px;
font-weight: 600;
cursor: pointer;
transition: all 0.2s;
}
.dest-btn.active-klausur { background: #2563eb; color: white; box-shadow: 0 4px 15px rgba(37, 99, 235, 0.3); }
.dest-btn.active-rag { background: #7c3aed; color: white; box-shadow: 0 4px 15px rgba(124, 58, 237, 0.3); }
.dest-btn:not(.active-klausur):not(.active-rag) { background: #1e293b; color: #94a3b8; }
.upload-zone {
border: 2px dashed #475569;
border-radius: 16px;
padding: 40px 20px;
text-align: center;
margin-bottom: 24px;
transition: all 0.2s;
position: relative;
}
.upload-zone.dragover { border-color: #60a5fa; background: rgba(96, 165, 250, 0.1); transform: scale(1.02); }
.upload-zone input[type="file"] {
position: absolute;
inset: 0;
opacity: 0;
cursor: pointer;
}
.upload-icon {
width: 64px;
height: 64px;
background: #334155;
border-radius: 50%;
display: flex;
align-items: center;
justify-content: center;
margin: 0 auto 16px;
font-size: 28px;
}
.upload-title { font-size: 18px; margin-bottom: 8px; }
.upload-subtitle { font-size: 14px; color: #94a3b8; margin-bottom: 16px; }
.upload-hint { font-size: 12px; color: #64748b; }
.file-list { margin-bottom: 24px; }
.file-item {
background: #1e293b;
border-radius: 12px;
padding: 16px;
margin-bottom: 12px;
}
.file-item.error { border: 2px solid rgba(239, 68, 68, 0.5); }
.file-item.complete { border: 2px solid rgba(34, 197, 94, 0.3); }
.file-header { display: flex; justify-content: space-between; align-items: flex-start; margin-bottom: 8px; }
.file-name { font-weight: 500; word-break: break-all; }
.file-size { font-size: 14px; color: #94a3b8; }
.remove-btn { background: none; border: none; color: #94a3b8; font-size: 20px; cursor: pointer; padding: 4px; }
.progress-bar { height: 6px; background: #334155; border-radius: 3px; overflow: hidden; margin-top: 12px; }
.progress-fill { height: 100%; background: linear-gradient(90deg, #3b82f6, #60a5fa); transition: width 0.3s; }
.progress-text { font-size: 12px; color: #94a3b8; margin-top: 4px; }
.status-complete { display: flex; align-items: center; gap: 8px; color: #22c55e; font-size: 14px; margin-top: 12px; }
.status-error { display: flex; align-items: center; gap: 8px; color: #ef4444; font-size: 14px; margin-top: 12px; }
.info-box {
background: rgba(30, 41, 59, 0.5);
border-radius: 12px;
padding: 16px;
font-size: 14px;
color: #94a3b8;
}
.info-box h3 { color: #cbd5e1; margin-bottom: 8px; font-size: 14px; }
.info-box ul { padding-left: 20px; }
.info-box li { margin-bottom: 4px; }
.server-info { text-align: center; font-size: 12px; color: #64748b; margin-top: 16px; }
.stats { display: flex; justify-content: space-between; font-size: 14px; color: #94a3b8; padding: 0 8px; margin-bottom: 12px; }
</style>
</head>
<body>
<header class="header">
<h1>BreakPilot Upload</h1>
<span class="badge">DSGVO-konform</span>
</header>
<div class="destination-selector">
<button class="dest-btn active-klausur" id="btn-klausur" onclick="setDestination('klausur')">Klausuren</button>
<button class="dest-btn" id="btn-rag" onclick="setDestination('rag')">Erwartungshorizonte</button>
</div>
<div class="upload-zone" id="upload-zone">
<input type="file" accept=".pdf" multiple onchange="handleFiles(this.files)">
<div class="upload-icon">&#x2601;</div>
<div class="upload-title">PDF-Dateien hochladen</div>
<div class="upload-subtitle">Tippen zum Auswaehlen oder hierher ziehen</div>
<div class="upload-hint">Grosse Dateien bis 200 MB werden automatisch in Teilen hochgeladen</div>
</div>
<div class="stats" id="stats" style="display: none;">
<span id="completed-count">0 von 0 fertig</span>
<span id="total-size">0 B gesamt</span>
</div>
<div class="file-list" id="file-list"></div>
<div class="info-box">
<h3>Hinweise:</h3>
<ul>
<li>Die Dateien werden lokal im WLAN uebertragen</li>
<li>Keine Daten werden ins Internet gesendet</li>
<li>Unterstuetzte Formate: PDF</li>
</ul>
</div>
<div class="server-info" id="server-info">Server: wird ermittelt...</div>
<script>
const CHUNK_SIZE = 5 * 1024 * 1024;
let destination = 'klausur';
let files = [];
const serverUrl = window.location.origin;
document.getElementById('server-info').textContent = 'Server: ' + serverUrl;
function setDestination(dest) {
destination = dest;
document.querySelectorAll('.dest-btn').forEach(btn => {
btn.classList.remove('active-klausur', 'active-rag');
});
if (dest === 'klausur') {
document.getElementById('btn-klausur').classList.add('active-klausur');
} else {
document.getElementById('btn-rag').classList.add('active-rag');
}
}
function formatSize(bytes) {
if (bytes === 0) return '0 B';
const k = 1024;
const sizes = ['B', 'KB', 'MB', 'GB'];
const i = Math.floor(Math.log(bytes) / Math.log(k));
return parseFloat((bytes / Math.pow(k, i)).toFixed(1)) + ' ' + sizes[i];
}
function updateStats() {
const completed = files.filter(f => f.status === 'complete').length;
const total = files.reduce((sum, f) => sum + f.size, 0);
document.getElementById('completed-count').textContent = completed + ' von ' + files.length + ' fertig';
document.getElementById('total-size').textContent = formatSize(total) + ' gesamt';
document.getElementById('stats').style.display = files.length > 0 ? 'flex' : 'none';
}
function renderFiles() {
const list = document.getElementById('file-list');
list.innerHTML = files.map(f => {
let statusHtml = '';
if (f.status === 'uploading' || f.status === 'pending') {
statusHtml = '<div class="progress-bar"><div class="progress-fill" style="width: ' + f.progress + '%"></div></div><div class="progress-text">' + f.progress + '% hochgeladen</div>';
} else if (f.status === 'complete') {
statusHtml = '<div class="status-complete">&#x2713; Erfolgreich hochgeladen</div>';
} else if (f.status === 'error') {
statusHtml = '<div class="status-error">&#x26A0; ' + (f.error || 'Fehler beim Hochladen') + '</div>';
}
return '<div class="file-item ' + f.status + '"><div class="file-header"><div><div class="file-name">' + f.name + '</div><div class="file-size">' + formatSize(f.size) + '</div></div><button class="remove-btn" onclick="removeFile(\\'' + f.id + '\\')">&times;</button></div>' + statusHtml + '</div>';
}).join('');
updateStats();
}
function removeFile(id) {
files = files.filter(f => f.id !== id);
renderFiles();
}
async function uploadFile(file, fileId) {
const updateProgress = (progress) => {
const f = files.find(f => f.id === fileId);
if (f) { f.progress = progress; renderFiles(); }
};
const setStatus = (status, error) => {
const f = files.find(f => f.id === fileId);
if (f) { f.status = status; if (error) f.error = error; renderFiles(); }
};
try {
setStatus('uploading');
if (file.size > 10 * 1024 * 1024) {
// Chunked upload
const totalChunks = Math.ceil(file.size / CHUNK_SIZE);
const initRes = await fetch(serverUrl + '/api/v1/upload/init', {
method: 'POST',
headers: { 'Content-Type': 'application/json' },
body: JSON.stringify({ filename: file.name, filesize: file.size, chunks: totalChunks, destination: destination })
});
if (!initRes.ok) throw new Error('Konnte Upload nicht starten');
const { upload_id } = await initRes.json();
for (let i = 0; i < totalChunks; i++) {
const start = i * CHUNK_SIZE;
const end = Math.min(start + CHUNK_SIZE, file.size);
const chunk = file.slice(start, end);
const formData = new FormData();
formData.append('chunk', chunk);
formData.append('upload_id', upload_id);
formData.append('chunk_index', i.toString());
const chunkRes = await fetch(serverUrl + '/api/v1/upload/chunk', { method: 'POST', body: formData });
if (!chunkRes.ok) throw new Error('Fehler bei Teil ' + (i + 1));
updateProgress(Math.round(((i + 1) / totalChunks) * 100));
}
const finalizeForm = new FormData();
finalizeForm.append('upload_id', upload_id);
const finalRes = await fetch(serverUrl + '/api/v1/upload/finalize', { method: 'POST', body: finalizeForm });
if (!finalRes.ok) throw new Error('Fehler beim Abschliessen');
} else {
// Simple upload
const formData = new FormData();
formData.append('file', file);
formData.append('destination', destination);
const res = await fetch(serverUrl + '/api/v1/upload/simple', { method: 'POST', body: formData });
if (!res.ok) throw new Error('Upload fehlgeschlagen');
updateProgress(100);
}
setStatus('complete');
} catch (e) {
setStatus('error', e.message);
}
}
function handleFiles(fileList) {
const newFiles = Array.from(fileList).filter(f => f.type === 'application/pdf');
newFiles.forEach(file => {
const id = Math.random().toString(36).substr(2, 9);
files.push({ id, name: file.name, size: file.size, progress: 0, status: 'pending', file });
renderFiles();
uploadFile(file, id);
});
}
// Drag & Drop
const zone = document.getElementById('upload-zone');
zone.addEventListener('dragover', e => { e.preventDefault(); zone.classList.add('dragover'); });
zone.addEventListener('dragleave', e => { e.preventDefault(); zone.classList.remove('dragover'); });
zone.addEventListener('drop', e => { e.preventDefault(); zone.classList.remove('dragover'); handleFiles(e.dataTransfer.files); });
</script>
</body>
</html>'''
return HTMLResponse(content=html_content)

View File

@@ -1,29 +1,4 @@
"""
Mobile Upload API — barrel re-export.
All implementation split into:
upload_api_chunked — chunked upload endpoints (init, chunk, finalize, simple, status, cancel, list)
upload_api_mobile — mobile HTML upload page
DSGVO-konform: Data stays local in WLAN, no external transmission.
"""
from fastapi import APIRouter
from upload_api_chunked import ( # noqa: F401
router as _chunked_router,
UPLOAD_DIR,
CHUNK_DIR,
EH_UPLOAD_DIR,
_upload_sessions,
InitUploadRequest,
InitUploadResponse,
ChunkUploadResponse,
FinalizeResponse,
)
from upload_api_mobile import router as _mobile_router # noqa: F401
# Composite router that includes both sub-routers
router = APIRouter()
router.include_router(_chunked_router)
router.include_router(_mobile_router)
# Backward-compat shim -- module moved to upload/api.py
import importlib as _importlib
import sys as _sys
_sys.modules[__name__] = _importlib.import_module("upload.api")

View File

@@ -1,320 +1,4 @@
"""
Chunked Upload API — init, chunk, finalize, simple upload, status, cancel, list.
Extracted from upload_api.py for modularity.
DSGVO-konform: Data stays local in WLAN, no external transmission.
"""
import os
import uuid
import shutil
import hashlib
from pathlib import Path
from datetime import datetime, timezone
from typing import Dict, Optional
from fastapi import APIRouter, HTTPException, UploadFile, File, Form
from pydantic import BaseModel
# Configuration
UPLOAD_DIR = Path(os.getenv("UPLOAD_DIR", "/app/uploads"))
CHUNK_DIR = Path(os.getenv("CHUNK_DIR", "/app/chunks"))
EH_UPLOAD_DIR = Path(os.getenv("EH_UPLOAD_DIR", "/app/eh-uploads"))
# Ensure directories exist
UPLOAD_DIR.mkdir(parents=True, exist_ok=True)
CHUNK_DIR.mkdir(parents=True, exist_ok=True)
EH_UPLOAD_DIR.mkdir(parents=True, exist_ok=True)
# In-memory storage for upload sessions (for simplicity)
# In production, use Redis or database
_upload_sessions: Dict[str, dict] = {}
router = APIRouter(prefix="/api/v1/upload", tags=["Mobile Upload"])
class InitUploadRequest(BaseModel):
filename: str
filesize: int
chunks: int
destination: str = "klausur" # "klausur" or "rag"
class InitUploadResponse(BaseModel):
upload_id: str
chunk_size: int
total_chunks: int
message: str
class ChunkUploadResponse(BaseModel):
upload_id: str
chunk_index: int
received: bool
chunks_received: int
total_chunks: int
class FinalizeResponse(BaseModel):
upload_id: str
filename: str
filepath: str
filesize: int
checksum: str
message: str
@router.post("/init", response_model=InitUploadResponse)
async def init_upload(request: InitUploadRequest):
"""
Initialize a chunked upload session.
Returns an upload_id that must be used for subsequent chunk uploads.
"""
upload_id = str(uuid.uuid4())
# Create session directory
session_dir = CHUNK_DIR / upload_id
session_dir.mkdir(parents=True, exist_ok=True)
# Store session info
_upload_sessions[upload_id] = {
"filename": request.filename,
"filesize": request.filesize,
"total_chunks": request.chunks,
"received_chunks": set(),
"destination": request.destination,
"session_dir": str(session_dir),
"created_at": datetime.now(timezone.utc).isoformat(),
}
return InitUploadResponse(
upload_id=upload_id,
chunk_size=5 * 1024 * 1024, # 5 MB
total_chunks=request.chunks,
message="Upload-Session erstellt"
)
@router.post("/chunk", response_model=ChunkUploadResponse)
async def upload_chunk(
chunk: UploadFile = File(...),
upload_id: str = Form(...),
chunk_index: int = Form(...)
):
"""
Upload a single chunk of a file.
Chunks are stored temporarily until finalize is called.
"""
if upload_id not in _upload_sessions:
raise HTTPException(status_code=404, detail="Upload-Session nicht gefunden")
session = _upload_sessions[upload_id]
if chunk_index < 0 or chunk_index >= session["total_chunks"]:
raise HTTPException(
status_code=400,
detail=f"Ungueltiger Chunk-Index: {chunk_index}"
)
# Save chunk
chunk_path = Path(session["session_dir"]) / f"chunk_{chunk_index:05d}"
with open(chunk_path, "wb") as f:
content = await chunk.read()
f.write(content)
# Track received chunks
session["received_chunks"].add(chunk_index)
return ChunkUploadResponse(
upload_id=upload_id,
chunk_index=chunk_index,
received=True,
chunks_received=len(session["received_chunks"]),
total_chunks=session["total_chunks"]
)
@router.post("/finalize", response_model=FinalizeResponse)
async def finalize_upload(upload_id: str = Form(...)):
"""
Finalize the upload by combining all chunks into a single file.
Validates that all chunks were received and calculates checksum.
"""
if upload_id not in _upload_sessions:
raise HTTPException(status_code=404, detail="Upload-Session nicht gefunden")
session = _upload_sessions[upload_id]
# Check if all chunks received
if len(session["received_chunks"]) != session["total_chunks"]:
missing = session["total_chunks"] - len(session["received_chunks"])
raise HTTPException(
status_code=400,
detail=f"Nicht alle Chunks empfangen. Fehlend: {missing}"
)
# Determine destination directory
if session["destination"] == "rag":
dest_dir = EH_UPLOAD_DIR
else:
dest_dir = UPLOAD_DIR
# Generate unique filename
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
safe_filename = session["filename"].replace(" ", "_")
final_filename = f"{timestamp}_{safe_filename}"
final_path = dest_dir / final_filename
# Combine chunks
hasher = hashlib.sha256()
total_size = 0
with open(final_path, "wb") as outfile:
for i in range(session["total_chunks"]):
chunk_path = Path(session["session_dir"]) / f"chunk_{i:05d}"
if not chunk_path.exists():
raise HTTPException(
status_code=500,
detail=f"Chunk {i} nicht gefunden"
)
with open(chunk_path, "rb") as infile:
data = infile.read()
outfile.write(data)
hasher.update(data)
total_size += len(data)
# Clean up chunks
shutil.rmtree(session["session_dir"], ignore_errors=True)
del _upload_sessions[upload_id]
checksum = hasher.hexdigest()
return FinalizeResponse(
upload_id=upload_id,
filename=final_filename,
filepath=str(final_path),
filesize=total_size,
checksum=checksum,
message="Upload erfolgreich abgeschlossen"
)
@router.post("/simple")
async def simple_upload(
file: UploadFile = File(...),
destination: str = Form("klausur")
):
"""
Simple single-request upload for smaller files (<10MB).
For larger files, use the chunked upload endpoints.
"""
# Determine destination directory
if destination == "rag":
dest_dir = EH_UPLOAD_DIR
else:
dest_dir = UPLOAD_DIR
# Generate unique filename
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
safe_filename = file.filename.replace(" ", "_") if file.filename else "upload.pdf"
final_filename = f"{timestamp}_{safe_filename}"
final_path = dest_dir / final_filename
# Calculate checksum while writing
hasher = hashlib.sha256()
total_size = 0
with open(final_path, "wb") as f:
while True:
chunk = await file.read(1024 * 1024) # Read 1MB at a time
if not chunk:
break
f.write(chunk)
hasher.update(chunk)
total_size += len(chunk)
return {
"filename": final_filename,
"filepath": str(final_path),
"filesize": total_size,
"checksum": hasher.hexdigest(),
"message": "Upload erfolgreich"
}
@router.get("/status/{upload_id}")
async def get_upload_status(upload_id: str):
"""
Get the status of an ongoing upload.
"""
if upload_id not in _upload_sessions:
raise HTTPException(status_code=404, detail="Upload-Session nicht gefunden")
session = _upload_sessions[upload_id]
return {
"upload_id": upload_id,
"filename": session["filename"],
"total_chunks": session["total_chunks"],
"received_chunks": len(session["received_chunks"]),
"progress_percent": round(
len(session["received_chunks"]) / session["total_chunks"] * 100, 1
),
"destination": session["destination"],
"created_at": session["created_at"]
}
@router.delete("/cancel/{upload_id}")
async def cancel_upload(upload_id: str):
"""
Cancel an ongoing upload and clean up temporary files.
"""
if upload_id not in _upload_sessions:
raise HTTPException(status_code=404, detail="Upload-Session nicht gefunden")
session = _upload_sessions[upload_id]
# Clean up chunks
shutil.rmtree(session["session_dir"], ignore_errors=True)
del _upload_sessions[upload_id]
return {"message": "Upload abgebrochen", "upload_id": upload_id}
@router.get("/list")
async def list_uploads(destination: str = "klausur"):
"""
List all uploaded files in the specified destination.
"""
if destination == "rag":
dest_dir = EH_UPLOAD_DIR
else:
dest_dir = UPLOAD_DIR
files = []
for f in dest_dir.iterdir():
if f.is_file() and f.suffix.lower() == ".pdf":
stat = f.stat()
files.append({
"filename": f.name,
"size": stat.st_size,
"modified": datetime.fromtimestamp(stat.st_mtime).isoformat(),
})
files.sort(key=lambda x: x["modified"], reverse=True)
return {
"destination": destination,
"count": len(files),
"files": files[:50] # Limit to 50 most recent
}
# Backward-compat shim -- module moved to upload/chunked.py
import importlib as _importlib
import sys as _sys
_sys.modules[__name__] = _importlib.import_module("upload.chunked")

View File

@@ -1,292 +1,4 @@
"""
Mobile Upload HTML Page — serves the mobile upload UI directly from klausur-service.
Extracted from upload_api.py for modularity.
DSGVO-konform: Data stays local in WLAN, no external transmission.
"""
from fastapi import APIRouter
from fastapi.responses import HTMLResponse
router = APIRouter(prefix="/api/v1/upload", tags=["Mobile Upload"])
@router.get("/mobile", response_class=HTMLResponse)
async def mobile_upload_page():
"""
Serve the mobile upload page directly from the klausur-service.
This allows mobile devices to upload without needing the Next.js website.
"""
html_content = '''<!DOCTYPE html>
<html lang="de">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0, maximum-scale=1.0, user-scalable=no">
<meta name="apple-mobile-web-app-capable" content="yes">
<title>BreakPilot Upload</title>
<style>
* { margin: 0; padding: 0; box-sizing: border-box; }
body {
font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, sans-serif;
background: linear-gradient(135deg, #1e293b 0%, #0f172a 100%);
color: white;
min-height: 100vh;
padding: 16px;
}
.header {
display: flex;
justify-content: space-between;
align-items: center;
padding: 16px;
border-bottom: 1px solid #334155;
margin-bottom: 24px;
}
.header h1 { font-size: 20px; color: #60a5fa; }
.badge { font-size: 10px; background: #1e293b; padding: 4px 8px; border-radius: 4px; color: #94a3b8; }
.destination-selector {
display: flex;
gap: 8px;
margin-bottom: 24px;
}
.dest-btn {
flex: 1;
padding: 14px;
border: none;
border-radius: 10px;
font-size: 14px;
font-weight: 600;
cursor: pointer;
transition: all 0.2s;
}
.dest-btn.active-klausur { background: #2563eb; color: white; box-shadow: 0 4px 15px rgba(37, 99, 235, 0.3); }
.dest-btn.active-rag { background: #7c3aed; color: white; box-shadow: 0 4px 15px rgba(124, 58, 237, 0.3); }
.dest-btn:not(.active-klausur):not(.active-rag) { background: #1e293b; color: #94a3b8; }
.upload-zone {
border: 2px dashed #475569;
border-radius: 16px;
padding: 40px 20px;
text-align: center;
margin-bottom: 24px;
transition: all 0.2s;
position: relative;
}
.upload-zone.dragover { border-color: #60a5fa; background: rgba(96, 165, 250, 0.1); transform: scale(1.02); }
.upload-zone input[type="file"] {
position: absolute;
inset: 0;
opacity: 0;
cursor: pointer;
}
.upload-icon {
width: 64px;
height: 64px;
background: #334155;
border-radius: 50%;
display: flex;
align-items: center;
justify-content: center;
margin: 0 auto 16px;
font-size: 28px;
}
.upload-title { font-size: 18px; margin-bottom: 8px; }
.upload-subtitle { font-size: 14px; color: #94a3b8; margin-bottom: 16px; }
.upload-hint { font-size: 12px; color: #64748b; }
.file-list { margin-bottom: 24px; }
.file-item {
background: #1e293b;
border-radius: 12px;
padding: 16px;
margin-bottom: 12px;
}
.file-item.error { border: 2px solid rgba(239, 68, 68, 0.5); }
.file-item.complete { border: 2px solid rgba(34, 197, 94, 0.3); }
.file-header { display: flex; justify-content: space-between; align-items: flex-start; margin-bottom: 8px; }
.file-name { font-weight: 500; word-break: break-all; }
.file-size { font-size: 14px; color: #94a3b8; }
.remove-btn { background: none; border: none; color: #94a3b8; font-size: 20px; cursor: pointer; padding: 4px; }
.progress-bar { height: 6px; background: #334155; border-radius: 3px; overflow: hidden; margin-top: 12px; }
.progress-fill { height: 100%; background: linear-gradient(90deg, #3b82f6, #60a5fa); transition: width 0.3s; }
.progress-text { font-size: 12px; color: #94a3b8; margin-top: 4px; }
.status-complete { display: flex; align-items: center; gap: 8px; color: #22c55e; font-size: 14px; margin-top: 12px; }
.status-error { display: flex; align-items: center; gap: 8px; color: #ef4444; font-size: 14px; margin-top: 12px; }
.info-box {
background: rgba(30, 41, 59, 0.5);
border-radius: 12px;
padding: 16px;
font-size: 14px;
color: #94a3b8;
}
.info-box h3 { color: #cbd5e1; margin-bottom: 8px; font-size: 14px; }
.info-box ul { padding-left: 20px; }
.info-box li { margin-bottom: 4px; }
.server-info { text-align: center; font-size: 12px; color: #64748b; margin-top: 16px; }
.stats { display: flex; justify-content: space-between; font-size: 14px; color: #94a3b8; padding: 0 8px; margin-bottom: 12px; }
</style>
</head>
<body>
<header class="header">
<h1>BreakPilot Upload</h1>
<span class="badge">DSGVO-konform</span>
</header>
<div class="destination-selector">
<button class="dest-btn active-klausur" id="btn-klausur" onclick="setDestination('klausur')">Klausuren</button>
<button class="dest-btn" id="btn-rag" onclick="setDestination('rag')">Erwartungshorizonte</button>
</div>
<div class="upload-zone" id="upload-zone">
<input type="file" accept=".pdf" multiple onchange="handleFiles(this.files)">
<div class="upload-icon">&#x2601;</div>
<div class="upload-title">PDF-Dateien hochladen</div>
<div class="upload-subtitle">Tippen zum Auswaehlen oder hierher ziehen</div>
<div class="upload-hint">Grosse Dateien bis 200 MB werden automatisch in Teilen hochgeladen</div>
</div>
<div class="stats" id="stats" style="display: none;">
<span id="completed-count">0 von 0 fertig</span>
<span id="total-size">0 B gesamt</span>
</div>
<div class="file-list" id="file-list"></div>
<div class="info-box">
<h3>Hinweise:</h3>
<ul>
<li>Die Dateien werden lokal im WLAN uebertragen</li>
<li>Keine Daten werden ins Internet gesendet</li>
<li>Unterstuetzte Formate: PDF</li>
</ul>
</div>
<div class="server-info" id="server-info">Server: wird ermittelt...</div>
<script>
const CHUNK_SIZE = 5 * 1024 * 1024;
let destination = 'klausur';
let files = [];
const serverUrl = window.location.origin;
document.getElementById('server-info').textContent = 'Server: ' + serverUrl;
function setDestination(dest) {
destination = dest;
document.querySelectorAll('.dest-btn').forEach(btn => {
btn.classList.remove('active-klausur', 'active-rag');
});
if (dest === 'klausur') {
document.getElementById('btn-klausur').classList.add('active-klausur');
} else {
document.getElementById('btn-rag').classList.add('active-rag');
}
}
function formatSize(bytes) {
if (bytes === 0) return '0 B';
const k = 1024;
const sizes = ['B', 'KB', 'MB', 'GB'];
const i = Math.floor(Math.log(bytes) / Math.log(k));
return parseFloat((bytes / Math.pow(k, i)).toFixed(1)) + ' ' + sizes[i];
}
function updateStats() {
const completed = files.filter(f => f.status === 'complete').length;
const total = files.reduce((sum, f) => sum + f.size, 0);
document.getElementById('completed-count').textContent = completed + ' von ' + files.length + ' fertig';
document.getElementById('total-size').textContent = formatSize(total) + ' gesamt';
document.getElementById('stats').style.display = files.length > 0 ? 'flex' : 'none';
}
function renderFiles() {
const list = document.getElementById('file-list');
list.innerHTML = files.map(f => {
let statusHtml = '';
if (f.status === 'uploading' || f.status === 'pending') {
statusHtml = '<div class="progress-bar"><div class="progress-fill" style="width: ' + f.progress + '%"></div></div><div class="progress-text">' + f.progress + '% hochgeladen</div>';
} else if (f.status === 'complete') {
statusHtml = '<div class="status-complete">&#x2713; Erfolgreich hochgeladen</div>';
} else if (f.status === 'error') {
statusHtml = '<div class="status-error">&#x26A0; ' + (f.error || 'Fehler beim Hochladen') + '</div>';
}
return '<div class="file-item ' + f.status + '"><div class="file-header"><div><div class="file-name">' + f.name + '</div><div class="file-size">' + formatSize(f.size) + '</div></div><button class="remove-btn" onclick="removeFile(\\'' + f.id + '\\')">&times;</button></div>' + statusHtml + '</div>';
}).join('');
updateStats();
}
function removeFile(id) {
files = files.filter(f => f.id !== id);
renderFiles();
}
async function uploadFile(file, fileId) {
const updateProgress = (progress) => {
const f = files.find(f => f.id === fileId);
if (f) { f.progress = progress; renderFiles(); }
};
const setStatus = (status, error) => {
const f = files.find(f => f.id === fileId);
if (f) { f.status = status; if (error) f.error = error; renderFiles(); }
};
try {
setStatus('uploading');
if (file.size > 10 * 1024 * 1024) {
// Chunked upload
const totalChunks = Math.ceil(file.size / CHUNK_SIZE);
const initRes = await fetch(serverUrl + '/api/v1/upload/init', {
method: 'POST',
headers: { 'Content-Type': 'application/json' },
body: JSON.stringify({ filename: file.name, filesize: file.size, chunks: totalChunks, destination: destination })
});
if (!initRes.ok) throw new Error('Konnte Upload nicht starten');
const { upload_id } = await initRes.json();
for (let i = 0; i < totalChunks; i++) {
const start = i * CHUNK_SIZE;
const end = Math.min(start + CHUNK_SIZE, file.size);
const chunk = file.slice(start, end);
const formData = new FormData();
formData.append('chunk', chunk);
formData.append('upload_id', upload_id);
formData.append('chunk_index', i.toString());
const chunkRes = await fetch(serverUrl + '/api/v1/upload/chunk', { method: 'POST', body: formData });
if (!chunkRes.ok) throw new Error('Fehler bei Teil ' + (i + 1));
updateProgress(Math.round(((i + 1) / totalChunks) * 100));
}
const finalizeForm = new FormData();
finalizeForm.append('upload_id', upload_id);
const finalRes = await fetch(serverUrl + '/api/v1/upload/finalize', { method: 'POST', body: finalizeForm });
if (!finalRes.ok) throw new Error('Fehler beim Abschliessen');
} else {
// Simple upload
const formData = new FormData();
formData.append('file', file);
formData.append('destination', destination);
const res = await fetch(serverUrl + '/api/v1/upload/simple', { method: 'POST', body: formData });
if (!res.ok) throw new Error('Upload fehlgeschlagen');
updateProgress(100);
}
setStatus('complete');
} catch (e) {
setStatus('error', e.message);
}
}
function handleFiles(fileList) {
const newFiles = Array.from(fileList).filter(f => f.type === 'application/pdf');
newFiles.forEach(file => {
const id = Math.random().toString(36).slice(2, 11);
files.push({ id, name: file.name, size: file.size, progress: 0, status: 'pending', file });
renderFiles();
uploadFile(file, id);
});
}
// Drag & Drop
const zone = document.getElementById('upload-zone');
zone.addEventListener('dragover', e => { e.preventDefault(); zone.classList.add('dragover'); });
zone.addEventListener('dragleave', e => { e.preventDefault(); zone.classList.remove('dragover'); });
zone.addEventListener('drop', e => { e.preventDefault(); zone.classList.remove('dragover'); handleFiles(e.dataTransfer.files); });
</script>
</body>
</html>'''
return HTMLResponse(content=html_content)
# Backward-compat shim -- module moved to upload/mobile.py
import importlib as _importlib
import sys as _sys
_sys.modules[__name__] = _importlib.import_module("upload.mobile")
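The embedded script above defines the client side of the chunked upload protocol: files over 10 MB are registered via /init, streamed in 5 MB parts to /chunk, and assembled with /finalize, while smaller files go through /simple. The same flow can be driven from any HTTP client; the following sketch uses the requests library, assumes the service is reachable at http://localhost:8000 (host and port are assumptions), and mirrors the field names of the fetch() calls above:

# Sketch of a desktop-side client for the chunked upload flow (init -> chunk -> finalize).
import math
from pathlib import Path

import requests

BASE = "http://localhost:8000/api/v1/upload"   # assumed host/port
CHUNK_SIZE = 5 * 1024 * 1024                   # matches CHUNK_SIZE in the mobile page

def upload_pdf(path: Path, destination: str = "klausur") -> None:
    size = path.stat().st_size
    total_chunks = math.ceil(size / CHUNK_SIZE)

    # 1) Register the upload session.
    init = requests.post(f"{BASE}/init", json={
        "filename": path.name,
        "filesize": size,
        "chunks": total_chunks,
        "destination": destination,
    })
    init.raise_for_status()
    upload_id = init.json()["upload_id"]

    # 2) Send each chunk as multipart form data.
    with path.open("rb") as fh:
        for index in range(total_chunks):
            chunk = fh.read(CHUNK_SIZE)
            resp = requests.post(
                f"{BASE}/chunk",
                data={"upload_id": upload_id, "chunk_index": str(index)},
                files={"chunk": (path.name, chunk, "application/pdf")},
            )
            resp.raise_for_status()

    # 3) Assemble the file on the server.
    requests.post(f"{BASE}/finalize", data={"upload_id": upload_id}).raise_for_status()

upload_pdf(Path("klausur_2026.pdf"))

If a chunk fails and the transfer is abandoned, the session can be discarded with DELETE /api/v1/upload/cancel/{upload_id}, which removes the temporary chunk directory on the server.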