Restructure: Move final 12 root files into packages (klausur-service)
Some checks failed
CI / go-lint (push) Has been skipped
CI / python-lint (push) Has been skipped
CI / nodejs-lint (push) Has been skipped
CI / test-go-school (push) Successful in 28s
CI / test-go-edu-search (push) Successful in 28s
CI / test-python-klausur (push) Failing after 2m23s
CI / test-python-agent-core (push) Successful in 19s
CI / test-nodejs-website (push) Successful in 19s
ocr/spell/ (3): smart_spell, core, text
upload/ (3): api, chunked, mobile
crawler/ (3): github, github_core, github_parsers
+ unified_grid → grid/, tesseract_extractor → ocr/engines/, htr_api → ocr/pipeline/

12 shims added. Only main.py, config.py, storage + RAG files remain at root.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
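The shims mentioned above keep the old flat import paths working: each moved module leaves behind a four-line file that redirects its name in sys.modules to the new package location (see the github_crawler.py hunk below). A minimal usage sketch, assuming klausur-service/backend/ is on the Python path; the path setup itself is not part of this commit:

# Old flat-module import, served by the shim left at backend/github_crawler.py:
from github_crawler import GitHubCrawler, crawl_source
# Equivalent import through the new package:
from crawler.github import GitHubCrawler, crawl_source
# Both forms yield the same objects, because the shim executes
# sys.modules["github_crawler"] = importlib.import_module("crawler.github")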
6
klausur-service/backend/crawler/__init__.py
Normal file
@@ -0,0 +1,6 @@
"""
Crawler package — GitHub repository crawler for legal templates.

Moved from backend/ flat modules (github_crawler*.py).
Backward-compatible shim files remain at the old locations.
"""

35
klausur-service/backend/crawler/github.py
Normal file
@@ -0,0 +1,35 @@
"""
GitHub Repository Crawler — Barrel Re-export

Split into:
- github_crawler_parsers.py — ExtractedDocument, MarkdownParser, HTMLParser, JSONParser
- github_crawler_core.py — GitHubCrawler, RepositoryDownloader, crawl_source

All public names are re-exported here for backward compatibility.
"""

# Parsers
from .github_parsers import ( # noqa: F401
    ExtractedDocument,
    MarkdownParser,
    HTMLParser,
    JSONParser,
)

# Crawler and downloader
from .github_core import ( # noqa: F401
    GITHUB_API_URL,
    GITLAB_API_URL,
    GITHUB_TOKEN,
    MAX_FILE_SIZE,
    REQUEST_TIMEOUT,
    RATE_LIMIT_DELAY,
    GitHubCrawler,
    RepositoryDownloader,
    crawl_source,
    main,
)

if __name__ == "__main__":
    import asyncio
    asyncio.run(main())

411
klausur-service/backend/crawler/github_core.py
Normal file
@@ -0,0 +1,411 @@
"""
GitHub Crawler - Core Crawler and Downloader

GitHubCrawler for API-based repository crawling and RepositoryDownloader
for ZIP-based local extraction.

Extracted from github_crawler.py to keep files under 500 LOC.
"""

import asyncio
|
||||
import logging
|
||||
import os
|
||||
import shutil
|
||||
import tempfile
|
||||
import zipfile
|
||||
from fnmatch import fnmatch
|
||||
from pathlib import Path
|
||||
from typing import Any, AsyncGenerator, Dict, List, Optional, Tuple
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import httpx
|
||||
|
||||
from template_sources import SourceConfig
|
||||
from .github_parsers import (
|
||||
ExtractedDocument,
|
||||
MarkdownParser,
|
||||
HTMLParser,
|
||||
JSONParser,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Configuration
|
||||
GITHUB_API_URL = "https://api.github.com"
|
||||
GITLAB_API_URL = "https://gitlab.com/api/v4"
|
||||
GITHUB_TOKEN = os.getenv("GITHUB_TOKEN", "")
|
||||
MAX_FILE_SIZE = 1024 * 1024 # 1 MB max file size
|
||||
REQUEST_TIMEOUT = 60.0
|
||||
RATE_LIMIT_DELAY = 1.0
|
||||
|
||||
|
||||
class GitHubCrawler:
|
||||
"""Crawl GitHub repositories for legal templates."""
|
||||
|
||||
def __init__(self, token: Optional[str] = None):
|
||||
self.token = token or GITHUB_TOKEN
|
||||
self.headers = {
|
||||
"Accept": "application/vnd.github.v3+json",
|
||||
"User-Agent": "LegalTemplatesCrawler/1.0",
|
||||
}
|
||||
if self.token:
|
||||
self.headers["Authorization"] = f"token {self.token}"
|
||||
|
||||
self.http_client: Optional[httpx.AsyncClient] = None
|
||||
|
||||
async def __aenter__(self):
|
||||
self.http_client = httpx.AsyncClient(
|
||||
timeout=REQUEST_TIMEOUT,
|
||||
headers=self.headers,
|
||||
follow_redirects=True,
|
||||
)
|
||||
return self
|
||||
|
||||
async def __aexit__(self, exc_type, exc_val, exc_tb):
|
||||
if self.http_client:
|
||||
await self.http_client.aclose()
|
||||
|
||||
def _parse_repo_url(self, url: str) -> Tuple[str, str, str]:
|
||||
"""Parse repository URL into owner, repo, and host."""
|
||||
parsed = urlparse(url)
|
||||
path_parts = parsed.path.strip('/').split('/')
|
||||
|
||||
if len(path_parts) < 2:
|
||||
raise ValueError(f"Invalid repository URL: {url}")
|
||||
|
||||
owner = path_parts[0]
|
||||
repo = path_parts[1].replace('.git', '')
|
||||
|
||||
if 'gitlab' in parsed.netloc:
|
||||
host = 'gitlab'
|
||||
else:
|
||||
host = 'github'
|
||||
|
||||
return owner, repo, host
|
||||
|
||||
async def get_default_branch(self, owner: str, repo: str) -> str:
|
||||
"""Get the default branch of a repository."""
|
||||
if not self.http_client:
|
||||
raise RuntimeError("Crawler not initialized. Use 'async with' context.")
|
||||
|
||||
url = f"{GITHUB_API_URL}/repos/{owner}/{repo}"
|
||||
response = await self.http_client.get(url)
|
||||
response.raise_for_status()
|
||||
data = response.json()
|
||||
return data.get("default_branch", "main")
|
||||
|
||||
async def get_latest_commit(self, owner: str, repo: str, branch: str = "main") -> str:
|
||||
"""Get the latest commit SHA for a branch."""
|
||||
if not self.http_client:
|
||||
raise RuntimeError("Crawler not initialized. Use 'async with' context.")
|
||||
|
||||
url = f"{GITHUB_API_URL}/repos/{owner}/{repo}/commits/{branch}"
|
||||
response = await self.http_client.get(url)
|
||||
response.raise_for_status()
|
||||
data = response.json()
|
||||
return data.get("sha", "")
|
||||
|
||||
async def list_files(
|
||||
self,
|
||||
owner: str,
|
||||
repo: str,
|
||||
path: str = "",
|
||||
branch: str = "main",
|
||||
patterns: List[str] = None,
|
||||
exclude_patterns: List[str] = None,
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""List files in a repository matching the given patterns."""
|
||||
if not self.http_client:
|
||||
raise RuntimeError("Crawler not initialized. Use 'async with' context.")
|
||||
|
||||
patterns = patterns or ["*.md", "*.txt", "*.html"]
|
||||
exclude_patterns = exclude_patterns or []
|
||||
|
||||
url = f"{GITHUB_API_URL}/repos/{owner}/{repo}/git/trees/{branch}?recursive=1"
|
||||
response = await self.http_client.get(url)
|
||||
response.raise_for_status()
|
||||
data = response.json()
|
||||
|
||||
files = []
|
||||
for item in data.get("tree", []):
|
||||
if item["type"] != "blob":
|
||||
continue
|
||||
|
||||
file_path = item["path"]
|
||||
|
||||
excluded = any(fnmatch(file_path, pattern) for pattern in exclude_patterns)
|
||||
if excluded:
|
||||
continue
|
||||
|
||||
matched = any(fnmatch(file_path, pattern) for pattern in patterns)
|
||||
if not matched:
|
||||
continue
|
||||
|
||||
if item.get("size", 0) > MAX_FILE_SIZE:
|
||||
logger.warning(f"Skipping large file: {file_path} ({item['size']} bytes)")
|
||||
continue
|
||||
|
||||
files.append({
|
||||
"path": file_path,
|
||||
"sha": item["sha"],
|
||||
"size": item.get("size", 0),
|
||||
"url": item.get("url", ""),
|
||||
})
|
||||
|
||||
return files
|
||||
|
||||
async def get_file_content(self, owner: str, repo: str, path: str, branch: str = "main") -> str:
|
||||
"""Get the content of a file from a repository."""
|
||||
if not self.http_client:
|
||||
raise RuntimeError("Crawler not initialized. Use 'async with' context.")
|
||||
|
||||
url = f"https://raw.githubusercontent.com/{owner}/{repo}/{branch}/{path}"
|
||||
response = await self.http_client.get(url)
|
||||
response.raise_for_status()
|
||||
return response.text
|
||||
|
||||
async def crawl_repository(
|
||||
self,
|
||||
source: SourceConfig,
|
||||
) -> AsyncGenerator[ExtractedDocument, None]:
|
||||
"""Crawl a repository and yield extracted documents."""
|
||||
if not source.repo_url:
|
||||
logger.warning(f"No repo URL for source: {source.name}")
|
||||
return
|
||||
|
||||
try:
|
||||
owner, repo, host = self._parse_repo_url(source.repo_url)
|
||||
except ValueError as e:
|
||||
logger.error(f"Failed to parse repo URL for {source.name}: {e}")
|
||||
return
|
||||
|
||||
if host == "gitlab":
|
||||
logger.info(f"GitLab repos not yet supported: {source.name}")
|
||||
return
|
||||
|
||||
logger.info(f"Crawling repository: {owner}/{repo}")
|
||||
|
||||
try:
|
||||
branch = await self.get_default_branch(owner, repo)
|
||||
commit_sha = await self.get_latest_commit(owner, repo, branch)
|
||||
|
||||
await asyncio.sleep(RATE_LIMIT_DELAY)
|
||||
|
||||
files = await self.list_files(
|
||||
owner, repo,
|
||||
branch=branch,
|
||||
patterns=source.file_patterns,
|
||||
exclude_patterns=source.exclude_patterns,
|
||||
)
|
||||
|
||||
logger.info(f"Found {len(files)} matching files in {source.name}")
|
||||
|
||||
for file_info in files:
|
||||
await asyncio.sleep(RATE_LIMIT_DELAY)
|
||||
|
||||
try:
|
||||
content = await self.get_file_content(
|
||||
owner, repo, file_info["path"], branch
|
||||
)
|
||||
|
||||
file_path = file_info["path"]
|
||||
source_url = f"https://github.com/{owner}/{repo}/blob/{branch}/{file_path}"
|
||||
|
||||
if file_path.endswith('.md'):
|
||||
doc = MarkdownParser.parse(content, file_path)
|
||||
doc.source_url = source_url
|
||||
doc.source_commit = commit_sha
|
||||
yield doc
|
||||
|
||||
elif file_path.endswith('.html') or file_path.endswith('.htm'):
|
||||
doc = HTMLParser.parse(content, file_path)
|
||||
doc.source_url = source_url
|
||||
doc.source_commit = commit_sha
|
||||
yield doc
|
||||
|
||||
elif file_path.endswith('.json'):
|
||||
docs = JSONParser.parse(content, file_path)
|
||||
for doc in docs:
|
||||
doc.source_url = source_url
|
||||
doc.source_commit = commit_sha
|
||||
yield doc
|
||||
|
||||
elif file_path.endswith('.txt'):
|
||||
yield ExtractedDocument(
|
||||
text=content,
|
||||
title=Path(file_path).stem,
|
||||
file_path=file_path,
|
||||
file_type="text",
|
||||
source_url=source_url,
|
||||
source_commit=commit_sha,
|
||||
language=MarkdownParser._detect_language(content),
|
||||
placeholders=MarkdownParser._find_placeholders(content),
|
||||
)
|
||||
|
||||
except httpx.HTTPError as e:
|
||||
logger.warning(f"Failed to fetch {file_path}: {e}")
|
||||
continue
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing {file_path}: {e}")
|
||||
continue
|
||||
|
||||
except httpx.HTTPError as e:
|
||||
logger.error(f"HTTP error crawling {source.name}: {e}")
|
||||
except Exception as e:
|
||||
logger.error(f"Error crawling {source.name}: {e}")
|
||||
|
||||
|
||||
class RepositoryDownloader:
|
||||
"""Download and extract repository archives."""
|
||||
|
||||
def __init__(self):
|
||||
self.http_client: Optional[httpx.AsyncClient] = None
|
||||
|
||||
async def __aenter__(self):
|
||||
self.http_client = httpx.AsyncClient(
|
||||
timeout=120.0,
|
||||
follow_redirects=True,
|
||||
)
|
||||
return self
|
||||
|
||||
async def __aexit__(self, exc_type, exc_val, exc_tb):
|
||||
if self.http_client:
|
||||
await self.http_client.aclose()
|
||||
|
||||
async def download_zip(self, repo_url: str, branch: str = "main") -> Path:
|
||||
"""Download repository as ZIP and extract to temp directory."""
|
||||
if not self.http_client:
|
||||
raise RuntimeError("Downloader not initialized. Use 'async with' context.")
|
||||
|
||||
parsed = urlparse(repo_url)
|
||||
path_parts = parsed.path.strip('/').split('/')
|
||||
owner = path_parts[0]
|
||||
repo = path_parts[1].replace('.git', '')
|
||||
|
||||
zip_url = f"https://github.com/{owner}/{repo}/archive/refs/heads/{branch}.zip"
|
||||
|
||||
logger.info(f"Downloading ZIP from {zip_url}")
|
||||
|
||||
response = await self.http_client.get(zip_url)
|
||||
response.raise_for_status()
|
||||
|
||||
temp_dir = Path(tempfile.mkdtemp())
|
||||
zip_path = temp_dir / f"{repo}.zip"
|
||||
|
||||
with open(zip_path, 'wb') as f:
|
||||
f.write(response.content)
|
||||
|
||||
extract_dir = temp_dir / repo
|
||||
with zipfile.ZipFile(zip_path, 'r') as zip_ref:
|
||||
zip_ref.extractall(temp_dir)
|
||||
|
||||
extracted_dirs = list(temp_dir.glob(f"{repo}-*"))
|
||||
if extracted_dirs:
|
||||
return extracted_dirs[0]
|
||||
|
||||
return extract_dir
|
||||
|
||||
async def crawl_local_directory(
|
||||
self,
|
||||
directory: Path,
|
||||
source: SourceConfig,
|
||||
base_url: str,
|
||||
) -> AsyncGenerator[ExtractedDocument, None]:
|
||||
"""Crawl a local directory for documents."""
|
||||
patterns = source.file_patterns or ["*.md", "*.txt", "*.html"]
|
||||
exclude_patterns = source.exclude_patterns or []
|
||||
|
||||
for pattern in patterns:
|
||||
for file_path in directory.rglob(pattern.replace("**/", "")):
|
||||
if not file_path.is_file():
|
||||
continue
|
||||
|
||||
rel_path = str(file_path.relative_to(directory))
|
||||
|
||||
excluded = any(fnmatch(rel_path, ep) for ep in exclude_patterns)
|
||||
if excluded:
|
||||
continue
|
||||
|
||||
if file_path.stat().st_size > MAX_FILE_SIZE:
|
||||
continue
|
||||
|
||||
try:
|
||||
content = file_path.read_text(encoding='utf-8')
|
||||
except UnicodeDecodeError:
|
||||
try:
|
||||
content = file_path.read_text(encoding='latin-1')
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
source_url = f"{base_url}/{rel_path}"
|
||||
|
||||
if file_path.suffix == '.md':
|
||||
doc = MarkdownParser.parse(content, rel_path)
|
||||
doc.source_url = source_url
|
||||
yield doc
|
||||
|
||||
elif file_path.suffix in ['.html', '.htm']:
|
||||
doc = HTMLParser.parse(content, rel_path)
|
||||
doc.source_url = source_url
|
||||
yield doc
|
||||
|
||||
elif file_path.suffix == '.json':
|
||||
docs = JSONParser.parse(content, rel_path)
|
||||
for doc in docs:
|
||||
doc.source_url = source_url
|
||||
yield doc
|
||||
|
||||
elif file_path.suffix == '.txt':
|
||||
yield ExtractedDocument(
|
||||
text=content,
|
||||
title=file_path.stem,
|
||||
file_path=rel_path,
|
||||
file_type="text",
|
||||
source_url=source_url,
|
||||
language=MarkdownParser._detect_language(content),
|
||||
placeholders=MarkdownParser._find_placeholders(content),
|
||||
)
|
||||
|
||||
def cleanup(self, directory: Path):
|
||||
"""Clean up temporary directory."""
|
||||
if directory.exists():
|
||||
shutil.rmtree(directory, ignore_errors=True)
|
||||
|
||||
|
||||
async def crawl_source(source: SourceConfig) -> List[ExtractedDocument]:
|
||||
"""Crawl a source configuration and return all extracted documents."""
|
||||
documents = []
|
||||
|
||||
if source.repo_url:
|
||||
async with GitHubCrawler() as crawler:
|
||||
async for doc in crawler.crawl_repository(source):
|
||||
documents.append(doc)
|
||||
|
||||
return documents
|
||||
|
||||
|
||||
# CLI for testing
|
||||
async def main():
|
||||
"""Test crawler with a sample source."""
|
||||
from template_sources import TEMPLATE_SOURCES
|
||||
|
||||
source = next(s for s in TEMPLATE_SOURCES if s.name == "github-site-policy")
|
||||
|
||||
async with GitHubCrawler() as crawler:
|
||||
count = 0
|
||||
async for doc in crawler.crawl_repository(source):
|
||||
count += 1
|
||||
print(f"\n{'='*60}")
|
||||
print(f"Title: {doc.title}")
|
||||
print(f"Path: {doc.file_path}")
|
||||
print(f"Type: {doc.file_type}")
|
||||
print(f"Language: {doc.language}")
|
||||
print(f"URL: {doc.source_url}")
|
||||
print(f"Placeholders: {doc.placeholders[:5] if doc.placeholders else 'None'}")
|
||||
print(f"Text preview: {doc.text[:200]}...")
|
||||
|
||||
print(f"\n\nTotal documents: {count}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
303
klausur-service/backend/crawler/github_parsers.py
Normal file
@@ -0,0 +1,303 @@
"""
GitHub Crawler - Document Parsers

Markdown, HTML, and JSON parsers for extracting structured content
from legal template documents.

Extracted from github_crawler.py to keep files under 500 LOC.
"""

import hashlib
|
||||
import json
|
||||
import re
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
|
||||
@dataclass
|
||||
class ExtractedDocument:
|
||||
"""A document extracted from a repository."""
|
||||
text: str
|
||||
title: str
|
||||
file_path: str
|
||||
file_type: str # "markdown", "html", "json", "text"
|
||||
source_url: str
|
||||
source_commit: Optional[str] = None
|
||||
source_hash: str = "" # SHA256 of original content
|
||||
sections: List[Dict[str, Any]] = field(default_factory=list)
|
||||
placeholders: List[str] = field(default_factory=list)
|
||||
language: str = "en"
|
||||
metadata: Dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
def __post_init__(self):
|
||||
if not self.source_hash and self.text:
|
||||
self.source_hash = hashlib.sha256(self.text.encode()).hexdigest()
|
||||
|
||||
|
||||
class MarkdownParser:
|
||||
"""Parse Markdown files into structured content."""
|
||||
|
||||
# Common placeholder patterns
|
||||
PLACEHOLDER_PATTERNS = [
|
||||
r'\[([A-Z_]+)\]', # [COMPANY_NAME]
|
||||
r'\{([a-z_]+)\}', # {company_name}
|
||||
r'\{\{([a-z_]+)\}\}', # {{company_name}}
|
||||
r'__([A-Z_]+)__', # __COMPANY_NAME__
|
||||
r'<([A-Z_]+)>', # <COMPANY_NAME>
|
||||
]
|
||||
|
||||
@classmethod
|
||||
def parse(cls, content: str, filename: str = "") -> ExtractedDocument:
|
||||
"""Parse markdown content into an ExtractedDocument."""
|
||||
title = cls._extract_title(content, filename)
|
||||
sections = cls._extract_sections(content)
|
||||
placeholders = cls._find_placeholders(content)
|
||||
language = cls._detect_language(content)
|
||||
clean_text = cls._clean_for_indexing(content)
|
||||
|
||||
return ExtractedDocument(
|
||||
text=clean_text,
|
||||
title=title,
|
||||
file_path=filename,
|
||||
file_type="markdown",
|
||||
source_url="",
|
||||
sections=sections,
|
||||
placeholders=placeholders,
|
||||
language=language,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _extract_title(cls, content: str, filename: str) -> str:
|
||||
"""Extract title from markdown heading or filename."""
|
||||
h1_match = re.search(r'^#\s+(.+)$', content, re.MULTILINE)
|
||||
if h1_match:
|
||||
return h1_match.group(1).strip()
|
||||
|
||||
frontmatter_match = re.search(
|
||||
r'^---\s*\n.*?title:\s*["\']?(.+?)["\']?\s*\n.*?---',
|
||||
content, re.DOTALL
|
||||
)
|
||||
if frontmatter_match:
|
||||
return frontmatter_match.group(1).strip()
|
||||
|
||||
if filename:
|
||||
name = Path(filename).stem
|
||||
return name.replace('-', ' ').replace('_', ' ').title()
|
||||
|
||||
return "Untitled"
|
||||
|
||||
@classmethod
|
||||
def _extract_sections(cls, content: str) -> List[Dict[str, Any]]:
|
||||
"""Extract sections from markdown content."""
|
||||
sections = []
|
||||
current_section = {"heading": "", "level": 0, "content": "", "start": 0}
|
||||
|
||||
for match in re.finditer(r'^(#{1,6})\s+(.+)$', content, re.MULTILINE):
|
||||
if current_section["heading"] or current_section["content"].strip():
|
||||
current_section["content"] = current_section["content"].strip()
|
||||
sections.append(current_section.copy())
|
||||
|
||||
level = len(match.group(1))
|
||||
heading = match.group(2).strip()
|
||||
current_section = {
|
||||
"heading": heading,
|
||||
"level": level,
|
||||
"content": "",
|
||||
"start": match.end(),
|
||||
}
|
||||
|
||||
if current_section["heading"] or current_section["content"].strip():
|
||||
current_section["content"] = content[current_section["start"]:].strip()
|
||||
sections.append(current_section)
|
||||
|
||||
return sections
|
||||
|
||||
@classmethod
|
||||
def _find_placeholders(cls, content: str) -> List[str]:
|
||||
"""Find placeholder patterns in content."""
|
||||
placeholders = set()
|
||||
for pattern in cls.PLACEHOLDER_PATTERNS:
|
||||
for match in re.finditer(pattern, content):
|
||||
placeholder = match.group(0)
|
||||
placeholders.add(placeholder)
|
||||
return sorted(list(placeholders))
|
||||
|
||||
@classmethod
|
||||
def _detect_language(cls, content: str) -> str:
|
||||
"""Detect language from content."""
|
||||
german_indicators = [
|
||||
'Datenschutz', 'Impressum', 'Nutzungsbedingungen', 'Haftung',
|
||||
'Widerruf', 'Verantwortlicher', 'personenbezogene', 'Verarbeitung',
|
||||
'und', 'der', 'die', 'das', 'ist', 'wird', 'werden', 'sind',
|
||||
]
|
||||
|
||||
lower_content = content.lower()
|
||||
german_count = sum(1 for word in german_indicators if word.lower() in lower_content)
|
||||
|
||||
if german_count >= 3:
|
||||
return "de"
|
||||
return "en"
|
||||
|
||||
@classmethod
|
||||
def _clean_for_indexing(cls, content: str) -> str:
|
||||
"""Clean markdown content for text indexing."""
|
||||
content = re.sub(r'^---\s*\n.*?---\s*\n', '', content, flags=re.DOTALL)
|
||||
content = re.sub(r'<!--.*?-->', '', content, flags=re.DOTALL)
|
||||
content = re.sub(r'<[^>]+>', '', content)
|
||||
content = re.sub(r'\*\*(.+?)\*\*', r'\1', content)
|
||||
content = re.sub(r'\*(.+?)\*', r'\1', content)
|
||||
content = re.sub(r'`(.+?)`', r'\1', content)
|
||||
content = re.sub(r'~~(.+?)~~', r'\1', content)
|
||||
content = re.sub(r'\[([^\]]+)\]\([^)]+\)', r'\1', content)
|
||||
content = re.sub(r'!\[([^\]]*)\]\([^)]+\)', r'\1', content)
|
||||
content = re.sub(r'\n{3,}', '\n\n', content)
|
||||
content = re.sub(r' +', ' ', content)
|
||||
|
||||
return content.strip()
|
||||
|
||||
|
||||
class HTMLParser:
|
||||
"""Parse HTML files into structured content."""
|
||||
|
||||
@classmethod
|
||||
def parse(cls, content: str, filename: str = "") -> ExtractedDocument:
|
||||
"""Parse HTML content into an ExtractedDocument."""
|
||||
title_match = re.search(r'<title>(.+?)</title>', content, re.IGNORECASE)
|
||||
title = title_match.group(1) if title_match else Path(filename).stem
|
||||
|
||||
text = cls._html_to_text(content)
|
||||
placeholders = MarkdownParser._find_placeholders(text)
|
||||
|
||||
lang_match = re.search(r'<html[^>]*lang=["\']([a-z]{2})["\']', content, re.IGNORECASE)
|
||||
language = lang_match.group(1) if lang_match else MarkdownParser._detect_language(text)
|
||||
|
||||
return ExtractedDocument(
|
||||
text=text,
|
||||
title=title,
|
||||
file_path=filename,
|
||||
file_type="html",
|
||||
source_url="",
|
||||
placeholders=placeholders,
|
||||
language=language,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _html_to_text(cls, html: str) -> str:
|
||||
"""Convert HTML to clean text."""
|
||||
html = re.sub(r'<script[^>]*>.*?</script>', '', html, flags=re.DOTALL | re.IGNORECASE)
|
||||
html = re.sub(r'<style[^>]*>.*?</style>', '', html, flags=re.DOTALL | re.IGNORECASE)
|
||||
html = re.sub(r'<!--.*?-->', '', html, flags=re.DOTALL)
|
||||
|
||||
html = html.replace(' ', ' ')
|
||||
html = html.replace('&', '&')
|
||||
html = html.replace('<', '<')
|
||||
html = html.replace('>', '>')
|
||||
html = html.replace('"', '"')
|
||||
html = html.replace(''', "'")
|
||||
|
||||
html = re.sub(r'<br\s*/?>', '\n', html, flags=re.IGNORECASE)
|
||||
html = re.sub(r'</p>', '\n\n', html, flags=re.IGNORECASE)
|
||||
html = re.sub(r'</div>', '\n', html, flags=re.IGNORECASE)
|
||||
html = re.sub(r'</h[1-6]>', '\n\n', html, flags=re.IGNORECASE)
|
||||
html = re.sub(r'</li>', '\n', html, flags=re.IGNORECASE)
|
||||
|
||||
html = re.sub(r'<[^>]+>', '', html)
|
||||
|
||||
html = re.sub(r'[ \t]+', ' ', html)
|
||||
html = re.sub(r'\n[ \t]+', '\n', html)
|
||||
html = re.sub(r'[ \t]+\n', '\n', html)
|
||||
html = re.sub(r'\n{3,}', '\n\n', html)
|
||||
|
||||
return html.strip()
|
||||
|
||||
|
||||
class JSONParser:
|
||||
"""Parse JSON files containing legal template data."""
|
||||
|
||||
@classmethod
|
||||
def parse(cls, content: str, filename: str = "") -> List[ExtractedDocument]:
|
||||
"""Parse JSON content into ExtractedDocuments."""
|
||||
try:
|
||||
data = json.loads(content)
|
||||
except json.JSONDecodeError as e:
|
||||
import logging
|
||||
logging.getLogger(__name__).warning(f"Failed to parse JSON from {filename}: {e}")
|
||||
return []
|
||||
|
||||
documents = []
|
||||
|
||||
if isinstance(data, dict):
|
||||
documents.extend(cls._parse_dict(data, filename))
|
||||
elif isinstance(data, list):
|
||||
for i, item in enumerate(data):
|
||||
if isinstance(item, dict):
|
||||
docs = cls._parse_dict(item, f"{filename}[{i}]")
|
||||
documents.extend(docs)
|
||||
|
||||
return documents
|
||||
|
||||
@classmethod
|
||||
def _parse_dict(cls, data: dict, filename: str) -> List[ExtractedDocument]:
|
||||
"""Parse a dictionary into documents."""
|
||||
documents = []
|
||||
|
||||
text_keys = ['text', 'content', 'body', 'description', 'value']
|
||||
title_keys = ['title', 'name', 'heading', 'label', 'key']
|
||||
|
||||
text = ""
|
||||
for key in text_keys:
|
||||
if key in data and isinstance(data[key], str):
|
||||
text = data[key]
|
||||
break
|
||||
|
||||
if not text:
|
||||
for key, value in data.items():
|
||||
if isinstance(value, dict):
|
||||
nested_docs = cls._parse_dict(value, f"{filename}.{key}")
|
||||
documents.extend(nested_docs)
|
||||
elif isinstance(value, list):
|
||||
for i, item in enumerate(value):
|
||||
if isinstance(item, dict):
|
||||
nested_docs = cls._parse_dict(item, f"{filename}.{key}[{i}]")
|
||||
documents.extend(nested_docs)
|
||||
elif isinstance(item, str) and len(item) > 50:
|
||||
documents.append(ExtractedDocument(
|
||||
text=item,
|
||||
title=f"{key} {i+1}",
|
||||
file_path=filename,
|
||||
file_type="json",
|
||||
source_url="",
|
||||
language=MarkdownParser._detect_language(item),
|
||||
))
|
||||
return documents
|
||||
|
||||
title = ""
|
||||
for key in title_keys:
|
||||
if key in data and isinstance(data[key], str):
|
||||
title = data[key]
|
||||
break
|
||||
|
||||
if not title:
|
||||
title = Path(filename).stem
|
||||
|
||||
metadata = {}
|
||||
for key, value in data.items():
|
||||
if key not in text_keys + title_keys and not isinstance(value, (dict, list)):
|
||||
metadata[key] = value
|
||||
|
||||
placeholders = MarkdownParser._find_placeholders(text)
|
||||
language = data.get('lang', data.get('language', MarkdownParser._detect_language(text)))
|
||||
|
||||
documents.append(ExtractedDocument(
|
||||
text=text,
|
||||
title=title,
|
||||
file_path=filename,
|
||||
file_type="json",
|
||||
source_url="",
|
||||
placeholders=placeholders,
|
||||
language=language,
|
||||
metadata=metadata,
|
||||
))
|
||||
|
||||
return documents
|
||||
@@ -1,35 +1,4 @@
"""
GitHub Repository Crawler — Barrel Re-export

Split into:
- github_crawler_parsers.py — ExtractedDocument, MarkdownParser, HTMLParser, JSONParser
- github_crawler_core.py — GitHubCrawler, RepositoryDownloader, crawl_source

All public names are re-exported here for backward compatibility.
"""

# Parsers
from github_crawler_parsers import ( # noqa: F401
    ExtractedDocument,
    MarkdownParser,
    HTMLParser,
    JSONParser,
)

# Crawler and downloader
from github_crawler_core import ( # noqa: F401
    GITHUB_API_URL,
    GITLAB_API_URL,
    GITHUB_TOKEN,
    MAX_FILE_SIZE,
    REQUEST_TIMEOUT,
    RATE_LIMIT_DELAY,
    GitHubCrawler,
    RepositoryDownloader,
    crawl_source,
    main,
)

if __name__ == "__main__":
    import asyncio
    asyncio.run(main())
# Backward-compat shim -- module moved to crawler/github.py
import importlib as _importlib
import sys as _sys
_sys.modules[__name__] = _importlib.import_module("crawler.github")

@@ -1,411 +1,4 @@
|
||||
"""
|
||||
GitHub Crawler - Core Crawler and Downloader
|
||||
|
||||
GitHubCrawler for API-based repository crawling and RepositoryDownloader
|
||||
for ZIP-based local extraction.
|
||||
|
||||
Extracted from github_crawler.py to keep files under 500 LOC.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
import shutil
|
||||
import tempfile
|
||||
import zipfile
|
||||
from fnmatch import fnmatch
|
||||
from pathlib import Path
|
||||
from typing import Any, AsyncGenerator, Dict, List, Optional, Tuple
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import httpx
|
||||
|
||||
from template_sources import SourceConfig
|
||||
from github_crawler_parsers import (
|
||||
ExtractedDocument,
|
||||
MarkdownParser,
|
||||
HTMLParser,
|
||||
JSONParser,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Configuration
|
||||
GITHUB_API_URL = "https://api.github.com"
|
||||
GITLAB_API_URL = "https://gitlab.com/api/v4"
|
||||
GITHUB_TOKEN = os.getenv("GITHUB_TOKEN", "")
|
||||
MAX_FILE_SIZE = 1024 * 1024 # 1 MB max file size
|
||||
REQUEST_TIMEOUT = 60.0
|
||||
RATE_LIMIT_DELAY = 1.0
|
||||
|
||||
|
||||
class GitHubCrawler:
|
||||
"""Crawl GitHub repositories for legal templates."""
|
||||
|
||||
def __init__(self, token: Optional[str] = None):
|
||||
self.token = token or GITHUB_TOKEN
|
||||
self.headers = {
|
||||
"Accept": "application/vnd.github.v3+json",
|
||||
"User-Agent": "LegalTemplatesCrawler/1.0",
|
||||
}
|
||||
if self.token:
|
||||
self.headers["Authorization"] = f"token {self.token}"
|
||||
|
||||
self.http_client: Optional[httpx.AsyncClient] = None
|
||||
|
||||
async def __aenter__(self):
|
||||
self.http_client = httpx.AsyncClient(
|
||||
timeout=REQUEST_TIMEOUT,
|
||||
headers=self.headers,
|
||||
follow_redirects=True,
|
||||
)
|
||||
return self
|
||||
|
||||
async def __aexit__(self, exc_type, exc_val, exc_tb):
|
||||
if self.http_client:
|
||||
await self.http_client.aclose()
|
||||
|
||||
def _parse_repo_url(self, url: str) -> Tuple[str, str, str]:
|
||||
"""Parse repository URL into owner, repo, and host."""
|
||||
parsed = urlparse(url)
|
||||
path_parts = parsed.path.strip('/').split('/')
|
||||
|
||||
if len(path_parts) < 2:
|
||||
raise ValueError(f"Invalid repository URL: {url}")
|
||||
|
||||
owner = path_parts[0]
|
||||
repo = path_parts[1].replace('.git', '')
|
||||
|
||||
if 'gitlab' in parsed.netloc:
|
||||
host = 'gitlab'
|
||||
else:
|
||||
host = 'github'
|
||||
|
||||
return owner, repo, host
|
||||
|
||||
async def get_default_branch(self, owner: str, repo: str) -> str:
|
||||
"""Get the default branch of a repository."""
|
||||
if not self.http_client:
|
||||
raise RuntimeError("Crawler not initialized. Use 'async with' context.")
|
||||
|
||||
url = f"{GITHUB_API_URL}/repos/{owner}/{repo}"
|
||||
response = await self.http_client.get(url)
|
||||
response.raise_for_status()
|
||||
data = response.json()
|
||||
return data.get("default_branch", "main")
|
||||
|
||||
async def get_latest_commit(self, owner: str, repo: str, branch: str = "main") -> str:
|
||||
"""Get the latest commit SHA for a branch."""
|
||||
if not self.http_client:
|
||||
raise RuntimeError("Crawler not initialized. Use 'async with' context.")
|
||||
|
||||
url = f"{GITHUB_API_URL}/repos/{owner}/{repo}/commits/{branch}"
|
||||
response = await self.http_client.get(url)
|
||||
response.raise_for_status()
|
||||
data = response.json()
|
||||
return data.get("sha", "")
|
||||
|
||||
async def list_files(
|
||||
self,
|
||||
owner: str,
|
||||
repo: str,
|
||||
path: str = "",
|
||||
branch: str = "main",
|
||||
patterns: List[str] = None,
|
||||
exclude_patterns: List[str] = None,
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""List files in a repository matching the given patterns."""
|
||||
if not self.http_client:
|
||||
raise RuntimeError("Crawler not initialized. Use 'async with' context.")
|
||||
|
||||
patterns = patterns or ["*.md", "*.txt", "*.html"]
|
||||
exclude_patterns = exclude_patterns or []
|
||||
|
||||
url = f"{GITHUB_API_URL}/repos/{owner}/{repo}/git/trees/{branch}?recursive=1"
|
||||
response = await self.http_client.get(url)
|
||||
response.raise_for_status()
|
||||
data = response.json()
|
||||
|
||||
files = []
|
||||
for item in data.get("tree", []):
|
||||
if item["type"] != "blob":
|
||||
continue
|
||||
|
||||
file_path = item["path"]
|
||||
|
||||
excluded = any(fnmatch(file_path, pattern) for pattern in exclude_patterns)
|
||||
if excluded:
|
||||
continue
|
||||
|
||||
matched = any(fnmatch(file_path, pattern) for pattern in patterns)
|
||||
if not matched:
|
||||
continue
|
||||
|
||||
if item.get("size", 0) > MAX_FILE_SIZE:
|
||||
logger.warning(f"Skipping large file: {file_path} ({item['size']} bytes)")
|
||||
continue
|
||||
|
||||
files.append({
|
||||
"path": file_path,
|
||||
"sha": item["sha"],
|
||||
"size": item.get("size", 0),
|
||||
"url": item.get("url", ""),
|
||||
})
|
||||
|
||||
return files
|
||||
|
||||
async def get_file_content(self, owner: str, repo: str, path: str, branch: str = "main") -> str:
|
||||
"""Get the content of a file from a repository."""
|
||||
if not self.http_client:
|
||||
raise RuntimeError("Crawler not initialized. Use 'async with' context.")
|
||||
|
||||
url = f"https://raw.githubusercontent.com/{owner}/{repo}/{branch}/{path}"
|
||||
response = await self.http_client.get(url)
|
||||
response.raise_for_status()
|
||||
return response.text
|
||||
|
||||
async def crawl_repository(
|
||||
self,
|
||||
source: SourceConfig,
|
||||
) -> AsyncGenerator[ExtractedDocument, None]:
|
||||
"""Crawl a repository and yield extracted documents."""
|
||||
if not source.repo_url:
|
||||
logger.warning(f"No repo URL for source: {source.name}")
|
||||
return
|
||||
|
||||
try:
|
||||
owner, repo, host = self._parse_repo_url(source.repo_url)
|
||||
except ValueError as e:
|
||||
logger.error(f"Failed to parse repo URL for {source.name}: {e}")
|
||||
return
|
||||
|
||||
if host == "gitlab":
|
||||
logger.info(f"GitLab repos not yet supported: {source.name}")
|
||||
return
|
||||
|
||||
logger.info(f"Crawling repository: {owner}/{repo}")
|
||||
|
||||
try:
|
||||
branch = await self.get_default_branch(owner, repo)
|
||||
commit_sha = await self.get_latest_commit(owner, repo, branch)
|
||||
|
||||
await asyncio.sleep(RATE_LIMIT_DELAY)
|
||||
|
||||
files = await self.list_files(
|
||||
owner, repo,
|
||||
branch=branch,
|
||||
patterns=source.file_patterns,
|
||||
exclude_patterns=source.exclude_patterns,
|
||||
)
|
||||
|
||||
logger.info(f"Found {len(files)} matching files in {source.name}")
|
||||
|
||||
for file_info in files:
|
||||
await asyncio.sleep(RATE_LIMIT_DELAY)
|
||||
|
||||
try:
|
||||
content = await self.get_file_content(
|
||||
owner, repo, file_info["path"], branch
|
||||
)
|
||||
|
||||
file_path = file_info["path"]
|
||||
source_url = f"https://github.com/{owner}/{repo}/blob/{branch}/{file_path}"
|
||||
|
||||
if file_path.endswith('.md'):
|
||||
doc = MarkdownParser.parse(content, file_path)
|
||||
doc.source_url = source_url
|
||||
doc.source_commit = commit_sha
|
||||
yield doc
|
||||
|
||||
elif file_path.endswith('.html') or file_path.endswith('.htm'):
|
||||
doc = HTMLParser.parse(content, file_path)
|
||||
doc.source_url = source_url
|
||||
doc.source_commit = commit_sha
|
||||
yield doc
|
||||
|
||||
elif file_path.endswith('.json'):
|
||||
docs = JSONParser.parse(content, file_path)
|
||||
for doc in docs:
|
||||
doc.source_url = source_url
|
||||
doc.source_commit = commit_sha
|
||||
yield doc
|
||||
|
||||
elif file_path.endswith('.txt'):
|
||||
yield ExtractedDocument(
|
||||
text=content,
|
||||
title=Path(file_path).stem,
|
||||
file_path=file_path,
|
||||
file_type="text",
|
||||
source_url=source_url,
|
||||
source_commit=commit_sha,
|
||||
language=MarkdownParser._detect_language(content),
|
||||
placeholders=MarkdownParser._find_placeholders(content),
|
||||
)
|
||||
|
||||
except httpx.HTTPError as e:
|
||||
logger.warning(f"Failed to fetch {file_path}: {e}")
|
||||
continue
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing {file_path}: {e}")
|
||||
continue
|
||||
|
||||
except httpx.HTTPError as e:
|
||||
logger.error(f"HTTP error crawling {source.name}: {e}")
|
||||
except Exception as e:
|
||||
logger.error(f"Error crawling {source.name}: {e}")
|
||||
|
||||
|
||||
class RepositoryDownloader:
|
||||
"""Download and extract repository archives."""
|
||||
|
||||
def __init__(self):
|
||||
self.http_client: Optional[httpx.AsyncClient] = None
|
||||
|
||||
async def __aenter__(self):
|
||||
self.http_client = httpx.AsyncClient(
|
||||
timeout=120.0,
|
||||
follow_redirects=True,
|
||||
)
|
||||
return self
|
||||
|
||||
async def __aexit__(self, exc_type, exc_val, exc_tb):
|
||||
if self.http_client:
|
||||
await self.http_client.aclose()
|
||||
|
||||
async def download_zip(self, repo_url: str, branch: str = "main") -> Path:
|
||||
"""Download repository as ZIP and extract to temp directory."""
|
||||
if not self.http_client:
|
||||
raise RuntimeError("Downloader not initialized. Use 'async with' context.")
|
||||
|
||||
parsed = urlparse(repo_url)
|
||||
path_parts = parsed.path.strip('/').split('/')
|
||||
owner = path_parts[0]
|
||||
repo = path_parts[1].replace('.git', '')
|
||||
|
||||
zip_url = f"https://github.com/{owner}/{repo}/archive/refs/heads/{branch}.zip"
|
||||
|
||||
logger.info(f"Downloading ZIP from {zip_url}")
|
||||
|
||||
response = await self.http_client.get(zip_url)
|
||||
response.raise_for_status()
|
||||
|
||||
temp_dir = Path(tempfile.mkdtemp())
|
||||
zip_path = temp_dir / f"{repo}.zip"
|
||||
|
||||
with open(zip_path, 'wb') as f:
|
||||
f.write(response.content)
|
||||
|
||||
extract_dir = temp_dir / repo
|
||||
with zipfile.ZipFile(zip_path, 'r') as zip_ref:
|
||||
zip_ref.extractall(temp_dir)
|
||||
|
||||
extracted_dirs = list(temp_dir.glob(f"{repo}-*"))
|
||||
if extracted_dirs:
|
||||
return extracted_dirs[0]
|
||||
|
||||
return extract_dir
|
||||
|
||||
async def crawl_local_directory(
|
||||
self,
|
||||
directory: Path,
|
||||
source: SourceConfig,
|
||||
base_url: str,
|
||||
) -> AsyncGenerator[ExtractedDocument, None]:
|
||||
"""Crawl a local directory for documents."""
|
||||
patterns = source.file_patterns or ["*.md", "*.txt", "*.html"]
|
||||
exclude_patterns = source.exclude_patterns or []
|
||||
|
||||
for pattern in patterns:
|
||||
for file_path in directory.rglob(pattern.replace("**/", "")):
|
||||
if not file_path.is_file():
|
||||
continue
|
||||
|
||||
rel_path = str(file_path.relative_to(directory))
|
||||
|
||||
excluded = any(fnmatch(rel_path, ep) for ep in exclude_patterns)
|
||||
if excluded:
|
||||
continue
|
||||
|
||||
if file_path.stat().st_size > MAX_FILE_SIZE:
|
||||
continue
|
||||
|
||||
try:
|
||||
content = file_path.read_text(encoding='utf-8')
|
||||
except UnicodeDecodeError:
|
||||
try:
|
||||
content = file_path.read_text(encoding='latin-1')
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
source_url = f"{base_url}/{rel_path}"
|
||||
|
||||
if file_path.suffix == '.md':
|
||||
doc = MarkdownParser.parse(content, rel_path)
|
||||
doc.source_url = source_url
|
||||
yield doc
|
||||
|
||||
elif file_path.suffix in ['.html', '.htm']:
|
||||
doc = HTMLParser.parse(content, rel_path)
|
||||
doc.source_url = source_url
|
||||
yield doc
|
||||
|
||||
elif file_path.suffix == '.json':
|
||||
docs = JSONParser.parse(content, rel_path)
|
||||
for doc in docs:
|
||||
doc.source_url = source_url
|
||||
yield doc
|
||||
|
||||
elif file_path.suffix == '.txt':
|
||||
yield ExtractedDocument(
|
||||
text=content,
|
||||
title=file_path.stem,
|
||||
file_path=rel_path,
|
||||
file_type="text",
|
||||
source_url=source_url,
|
||||
language=MarkdownParser._detect_language(content),
|
||||
placeholders=MarkdownParser._find_placeholders(content),
|
||||
)
|
||||
|
||||
def cleanup(self, directory: Path):
|
||||
"""Clean up temporary directory."""
|
||||
if directory.exists():
|
||||
shutil.rmtree(directory, ignore_errors=True)
|
||||
|
||||
|
||||
async def crawl_source(source: SourceConfig) -> List[ExtractedDocument]:
|
||||
"""Crawl a source configuration and return all extracted documents."""
|
||||
documents = []
|
||||
|
||||
if source.repo_url:
|
||||
async with GitHubCrawler() as crawler:
|
||||
async for doc in crawler.crawl_repository(source):
|
||||
documents.append(doc)
|
||||
|
||||
return documents
|
||||
|
||||
|
||||
# CLI for testing
|
||||
async def main():
|
||||
"""Test crawler with a sample source."""
|
||||
from template_sources import TEMPLATE_SOURCES
|
||||
|
||||
source = next(s for s in TEMPLATE_SOURCES if s.name == "github-site-policy")
|
||||
|
||||
async with GitHubCrawler() as crawler:
|
||||
count = 0
|
||||
async for doc in crawler.crawl_repository(source):
|
||||
count += 1
|
||||
print(f"\n{'='*60}")
|
||||
print(f"Title: {doc.title}")
|
||||
print(f"Path: {doc.file_path}")
|
||||
print(f"Type: {doc.file_type}")
|
||||
print(f"Language: {doc.language}")
|
||||
print(f"URL: {doc.source_url}")
|
||||
print(f"Placeholders: {doc.placeholders[:5] if doc.placeholders else 'None'}")
|
||||
print(f"Text preview: {doc.text[:200]}...")
|
||||
|
||||
print(f"\n\nTotal documents: {count}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
# Backward-compat shim -- module moved to crawler/github_core.py
import importlib as _importlib
import sys as _sys
_sys.modules[__name__] = _importlib.import_module("crawler.github_core")

@@ -1,303 +1,4 @@
|
||||
"""
|
||||
GitHub Crawler - Document Parsers
|
||||
|
||||
Markdown, HTML, and JSON parsers for extracting structured content
|
||||
from legal template documents.
|
||||
|
||||
Extracted from github_crawler.py to keep files under 500 LOC.
|
||||
"""
|
||||
|
||||
import hashlib
|
||||
import json
|
||||
import re
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
|
||||
@dataclass
|
||||
class ExtractedDocument:
|
||||
"""A document extracted from a repository."""
|
||||
text: str
|
||||
title: str
|
||||
file_path: str
|
||||
file_type: str # "markdown", "html", "json", "text"
|
||||
source_url: str
|
||||
source_commit: Optional[str] = None
|
||||
source_hash: str = "" # SHA256 of original content
|
||||
sections: List[Dict[str, Any]] = field(default_factory=list)
|
||||
placeholders: List[str] = field(default_factory=list)
|
||||
language: str = "en"
|
||||
metadata: Dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
def __post_init__(self):
|
||||
if not self.source_hash and self.text:
|
||||
self.source_hash = hashlib.sha256(self.text.encode()).hexdigest()
|
||||
|
||||
|
||||
class MarkdownParser:
|
||||
"""Parse Markdown files into structured content."""
|
||||
|
||||
# Common placeholder patterns
|
||||
PLACEHOLDER_PATTERNS = [
|
||||
r'\[([A-Z_]+)\]', # [COMPANY_NAME]
|
||||
r'\{([a-z_]+)\}', # {company_name}
|
||||
r'\{\{([a-z_]+)\}\}', # {{company_name}}
|
||||
r'__([A-Z_]+)__', # __COMPANY_NAME__
|
||||
r'<([A-Z_]+)>', # <COMPANY_NAME>
|
||||
]
|
||||
|
||||
@classmethod
|
||||
def parse(cls, content: str, filename: str = "") -> ExtractedDocument:
|
||||
"""Parse markdown content into an ExtractedDocument."""
|
||||
title = cls._extract_title(content, filename)
|
||||
sections = cls._extract_sections(content)
|
||||
placeholders = cls._find_placeholders(content)
|
||||
language = cls._detect_language(content)
|
||||
clean_text = cls._clean_for_indexing(content)
|
||||
|
||||
return ExtractedDocument(
|
||||
text=clean_text,
|
||||
title=title,
|
||||
file_path=filename,
|
||||
file_type="markdown",
|
||||
source_url="",
|
||||
sections=sections,
|
||||
placeholders=placeholders,
|
||||
language=language,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _extract_title(cls, content: str, filename: str) -> str:
|
||||
"""Extract title from markdown heading or filename."""
|
||||
h1_match = re.search(r'^#\s+(.+)$', content, re.MULTILINE)
|
||||
if h1_match:
|
||||
return h1_match.group(1).strip()
|
||||
|
||||
frontmatter_match = re.search(
|
||||
r'^---\s*\n.*?title:\s*["\']?(.+?)["\']?\s*\n.*?---',
|
||||
content, re.DOTALL
|
||||
)
|
||||
if frontmatter_match:
|
||||
return frontmatter_match.group(1).strip()
|
||||
|
||||
if filename:
|
||||
name = Path(filename).stem
|
||||
return name.replace('-', ' ').replace('_', ' ').title()
|
||||
|
||||
return "Untitled"
|
||||
|
||||
@classmethod
|
||||
def _extract_sections(cls, content: str) -> List[Dict[str, Any]]:
|
||||
"""Extract sections from markdown content."""
|
||||
sections = []
|
||||
current_section = {"heading": "", "level": 0, "content": "", "start": 0}
|
||||
|
||||
for match in re.finditer(r'^(#{1,6})\s+(.+)$', content, re.MULTILINE):
|
||||
if current_section["heading"] or current_section["content"].strip():
|
||||
current_section["content"] = current_section["content"].strip()
|
||||
sections.append(current_section.copy())
|
||||
|
||||
level = len(match.group(1))
|
||||
heading = match.group(2).strip()
|
||||
current_section = {
|
||||
"heading": heading,
|
||||
"level": level,
|
||||
"content": "",
|
||||
"start": match.end(),
|
||||
}
|
||||
|
||||
if current_section["heading"] or current_section["content"].strip():
|
||||
current_section["content"] = content[current_section["start"]:].strip()
|
||||
sections.append(current_section)
|
||||
|
||||
return sections
|
||||
|
||||
@classmethod
|
||||
def _find_placeholders(cls, content: str) -> List[str]:
|
||||
"""Find placeholder patterns in content."""
|
||||
placeholders = set()
|
||||
for pattern in cls.PLACEHOLDER_PATTERNS:
|
||||
for match in re.finditer(pattern, content):
|
||||
placeholder = match.group(0)
|
||||
placeholders.add(placeholder)
|
||||
return sorted(list(placeholders))
|
||||
|
||||
@classmethod
|
||||
def _detect_language(cls, content: str) -> str:
|
||||
"""Detect language from content."""
|
||||
german_indicators = [
|
||||
'Datenschutz', 'Impressum', 'Nutzungsbedingungen', 'Haftung',
|
||||
'Widerruf', 'Verantwortlicher', 'personenbezogene', 'Verarbeitung',
|
||||
'und', 'der', 'die', 'das', 'ist', 'wird', 'werden', 'sind',
|
||||
]
|
||||
|
||||
lower_content = content.lower()
|
||||
german_count = sum(1 for word in german_indicators if word.lower() in lower_content)
|
||||
|
||||
if german_count >= 3:
|
||||
return "de"
|
||||
return "en"
|
||||
|
||||
@classmethod
|
||||
def _clean_for_indexing(cls, content: str) -> str:
|
||||
"""Clean markdown content for text indexing."""
|
||||
content = re.sub(r'^---\s*\n.*?---\s*\n', '', content, flags=re.DOTALL)
|
||||
content = re.sub(r'<!--.*?-->', '', content, flags=re.DOTALL)
|
||||
content = re.sub(r'<[^>]+>', '', content)
|
||||
content = re.sub(r'\*\*(.+?)\*\*', r'\1', content)
|
||||
content = re.sub(r'\*(.+?)\*', r'\1', content)
|
||||
content = re.sub(r'`(.+?)`', r'\1', content)
|
||||
content = re.sub(r'~~(.+?)~~', r'\1', content)
|
||||
content = re.sub(r'\[([^\]]+)\]\([^)]+\)', r'\1', content)
|
||||
content = re.sub(r'!\[([^\]]*)\]\([^)]+\)', r'\1', content)
|
||||
content = re.sub(r'\n{3,}', '\n\n', content)
|
||||
content = re.sub(r' +', ' ', content)
|
||||
|
||||
return content.strip()
|
||||
|
||||
|
||||
class HTMLParser:
|
||||
"""Parse HTML files into structured content."""
|
||||
|
||||
@classmethod
|
||||
def parse(cls, content: str, filename: str = "") -> ExtractedDocument:
|
||||
"""Parse HTML content into an ExtractedDocument."""
|
||||
title_match = re.search(r'<title>(.+?)</title>', content, re.IGNORECASE)
|
||||
title = title_match.group(1) if title_match else Path(filename).stem
|
||||
|
||||
text = cls._html_to_text(content)
|
||||
placeholders = MarkdownParser._find_placeholders(text)
|
||||
|
||||
lang_match = re.search(r'<html[^>]*lang=["\']([a-z]{2})["\']', content, re.IGNORECASE)
|
||||
language = lang_match.group(1) if lang_match else MarkdownParser._detect_language(text)
|
||||
|
||||
return ExtractedDocument(
|
||||
text=text,
|
||||
title=title,
|
||||
file_path=filename,
|
||||
file_type="html",
|
||||
source_url="",
|
||||
placeholders=placeholders,
|
||||
language=language,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _html_to_text(cls, html: str) -> str:
|
||||
"""Convert HTML to clean text."""
|
||||
html = re.sub(r'<script[^>]*>.*?</script>', '', html, flags=re.DOTALL | re.IGNORECASE)
|
||||
html = re.sub(r'<style[^>]*>.*?</style>', '', html, flags=re.DOTALL | re.IGNORECASE)
|
||||
html = re.sub(r'<!--.*?-->', '', html, flags=re.DOTALL)
|
||||
|
||||
html = html.replace(' ', ' ')
|
||||
html = html.replace('&', '&')
|
||||
html = html.replace('<', '<')
|
||||
html = html.replace('>', '>')
|
||||
html = html.replace('"', '"')
|
||||
html = html.replace(''', "'")
|
||||
|
||||
html = re.sub(r'<br\s*/?>', '\n', html, flags=re.IGNORECASE)
|
||||
html = re.sub(r'</p>', '\n\n', html, flags=re.IGNORECASE)
|
||||
html = re.sub(r'</div>', '\n', html, flags=re.IGNORECASE)
|
||||
html = re.sub(r'</h[1-6]>', '\n\n', html, flags=re.IGNORECASE)
|
||||
html = re.sub(r'</li>', '\n', html, flags=re.IGNORECASE)
|
||||
|
||||
html = re.sub(r'<[^>]+>', '', html)
|
||||
|
||||
html = re.sub(r'[ \t]+', ' ', html)
|
||||
html = re.sub(r'\n[ \t]+', '\n', html)
|
||||
html = re.sub(r'[ \t]+\n', '\n', html)
|
||||
html = re.sub(r'\n{3,}', '\n\n', html)
|
||||
|
||||
return html.strip()
|
||||
|
||||
|
||||
class JSONParser:
|
||||
"""Parse JSON files containing legal template data."""
|
||||
|
||||
@classmethod
|
||||
def parse(cls, content: str, filename: str = "") -> List[ExtractedDocument]:
|
||||
"""Parse JSON content into ExtractedDocuments."""
|
||||
try:
|
||||
data = json.loads(content)
|
||||
except json.JSONDecodeError as e:
|
||||
import logging
|
||||
logging.getLogger(__name__).warning(f"Failed to parse JSON from {filename}: {e}")
|
||||
return []
|
||||
|
||||
documents = []
|
||||
|
||||
if isinstance(data, dict):
|
||||
documents.extend(cls._parse_dict(data, filename))
|
||||
elif isinstance(data, list):
|
||||
for i, item in enumerate(data):
|
||||
if isinstance(item, dict):
|
||||
docs = cls._parse_dict(item, f"{filename}[{i}]")
|
||||
documents.extend(docs)
|
||||
|
||||
return documents
|
||||
|
||||
@classmethod
|
||||
def _parse_dict(cls, data: dict, filename: str) -> List[ExtractedDocument]:
|
||||
"""Parse a dictionary into documents."""
|
||||
documents = []
|
||||
|
||||
text_keys = ['text', 'content', 'body', 'description', 'value']
|
||||
title_keys = ['title', 'name', 'heading', 'label', 'key']
|
||||
|
||||
text = ""
|
||||
for key in text_keys:
|
||||
if key in data and isinstance(data[key], str):
|
||||
text = data[key]
|
||||
break
|
||||
|
||||
if not text:
|
||||
for key, value in data.items():
|
||||
if isinstance(value, dict):
|
||||
nested_docs = cls._parse_dict(value, f"{filename}.{key}")
|
||||
documents.extend(nested_docs)
|
||||
elif isinstance(value, list):
|
||||
for i, item in enumerate(value):
|
||||
if isinstance(item, dict):
|
||||
nested_docs = cls._parse_dict(item, f"{filename}.{key}[{i}]")
|
||||
documents.extend(nested_docs)
|
||||
elif isinstance(item, str) and len(item) > 50:
|
||||
documents.append(ExtractedDocument(
|
||||
text=item,
|
||||
title=f"{key} {i+1}",
|
||||
file_path=filename,
|
||||
file_type="json",
|
||||
source_url="",
|
||||
language=MarkdownParser._detect_language(item),
|
||||
))
|
||||
return documents
|
||||
|
||||
title = ""
|
||||
for key in title_keys:
|
||||
if key in data and isinstance(data[key], str):
|
||||
title = data[key]
|
||||
break
|
||||
|
||||
if not title:
|
||||
title = Path(filename).stem
|
||||
|
||||
metadata = {}
|
||||
for key, value in data.items():
|
||||
if key not in text_keys + title_keys and not isinstance(value, (dict, list)):
|
||||
metadata[key] = value
|
||||
|
||||
placeholders = MarkdownParser._find_placeholders(text)
|
||||
language = data.get('lang', data.get('language', MarkdownParser._detect_language(text)))
|
||||
|
||||
documents.append(ExtractedDocument(
|
||||
text=text,
|
||||
title=title,
|
||||
file_path=filename,
|
||||
file_type="json",
|
||||
source_url="",
|
||||
placeholders=placeholders,
|
||||
language=language,
|
||||
metadata=metadata,
|
||||
))
|
||||
|
||||
return documents
|
||||
# Backward-compat shim -- module moved to crawler/github_parsers.py
import importlib as _importlib
import sys as _sys
_sys.modules[__name__] = _importlib.import_module("crawler.github_parsers")

425
klausur-service/backend/grid/unified.py
Normal file
@@ -0,0 +1,425 @@
"""
Unified Grid Builder — merges multi-zone grid into a single Excel-like grid.

Takes content zone + box zones and produces one unified zone where:
- All content rows use the dominant row height
- Full-width boxes are integrated directly (box rows replace standard rows)
- Partial-width boxes: extra rows inserted if box has more lines than standard
- Box-origin cells carry metadata (bg_color, border) for visual distinction

The result is a single-zone StructuredGrid that can be:
- Rendered in an Excel-like editor
- Exported to Excel/CSV
- Edited with unified row/column numbering
"""

import logging
|
||||
import math
|
||||
import statistics
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _compute_dominant_row_height(content_zone: Dict) -> float:
|
||||
"""Median of content row-to-row spacings, excluding box-gap jumps."""
|
||||
rows = content_zone.get("rows", [])
|
||||
if len(rows) < 2:
|
||||
return 47.0
|
||||
|
||||
spacings = []
|
||||
for i in range(len(rows) - 1):
|
||||
y1 = rows[i].get("y_min_px", rows[i].get("y_min", 0))
|
||||
y2 = rows[i + 1].get("y_min_px", rows[i + 1].get("y_min", 0))
|
||||
d = y2 - y1
|
||||
if 0 < d < 100: # exclude box-gap jumps
|
||||
spacings.append(d)
|
||||
|
||||
if not spacings:
|
||||
return 47.0
|
||||
spacings.sort()
|
||||
return spacings[len(spacings) // 2]
|
||||
|
||||
|
||||
def _classify_boxes(
|
||||
box_zones: List[Dict],
|
||||
content_width: float,
|
||||
) -> List[Dict]:
|
||||
"""Classify each box as full_width or partial_width."""
|
||||
result = []
|
||||
for bz in box_zones:
|
||||
bb = bz.get("bbox_px", {})
|
||||
bw = bb.get("w", 0)
|
||||
bx = bb.get("x", 0)
|
||||
|
||||
if bw >= content_width * 0.85:
|
||||
classification = "full_width"
|
||||
side = "center"
|
||||
else:
|
||||
classification = "partial_width"
|
||||
# Determine which side of the page the box is on
|
||||
page_center = content_width / 2
|
||||
box_center = bx + bw / 2
|
||||
side = "right" if box_center > page_center else "left"
|
||||
|
||||
# Count total text lines in box (including \n within cells)
|
||||
total_lines = sum(
|
||||
(c.get("text", "").count("\n") + 1)
|
||||
for c in bz.get("cells", [])
|
||||
)
|
||||
|
||||
result.append({
|
||||
"zone": bz,
|
||||
"classification": classification,
|
||||
"side": side,
|
||||
"y_start": bb.get("y", 0),
|
||||
"y_end": bb.get("y", 0) + bb.get("h", 0),
|
||||
"total_lines": total_lines,
|
||||
"bg_hex": bz.get("box_bg_hex", ""),
|
||||
"bg_color": bz.get("box_bg_color", ""),
|
||||
})
|
||||
return result
|
||||
|
||||
|
||||
def build_unified_grid(
|
||||
zones: List[Dict],
|
||||
image_width: int,
|
||||
image_height: int,
|
||||
layout_metrics: Dict,
|
||||
) -> Dict[str, Any]:
|
||||
"""Build a single-zone unified grid from multi-zone grid data.
|
||||
|
||||
Returns a StructuredGrid with one zone containing all rows and cells.
|
||||
"""
|
||||
content_zone = None
|
||||
box_zones = []
|
||||
for z in zones:
|
||||
if z.get("zone_type") == "content":
|
||||
content_zone = z
|
||||
elif z.get("zone_type") == "box":
|
||||
box_zones.append(z)
|
||||
|
||||
if not content_zone:
|
||||
logger.warning("build_unified_grid: no content zone found")
|
||||
return {"zones": zones} # fallback: return as-is
|
||||
|
||||
box_zones.sort(key=lambda b: b.get("bbox_px", {}).get("y", 0))
|
||||
|
||||
dominant_h = _compute_dominant_row_height(content_zone)
|
||||
content_bbox = content_zone.get("bbox_px", {})
|
||||
content_width = content_bbox.get("w", image_width)
|
||||
content_x = content_bbox.get("x", 0)
|
||||
content_cols = content_zone.get("columns", [])
|
||||
num_cols = len(content_cols)
|
||||
|
||||
box_infos = _classify_boxes(box_zones, content_width)
|
||||
|
||||
logger.info(
|
||||
"build_unified_grid: dominant_h=%.1f, %d content rows, %d boxes (%s)",
|
||||
dominant_h, len(content_zone.get("rows", [])), len(box_infos),
|
||||
[b["classification"] for b in box_infos],
|
||||
)
|
||||
|
||||
# --- Build unified row list + cell list ---
|
||||
unified_rows: List[Dict] = []
|
||||
unified_cells: List[Dict] = []
|
||||
unified_row_idx = 0
|
||||
|
||||
# Content rows and cells indexed by row_index
|
||||
content_rows = content_zone.get("rows", [])
|
||||
content_cells = content_zone.get("cells", [])
|
||||
content_cells_by_row: Dict[int, List[Dict]] = {}
|
||||
for c in content_cells:
|
||||
content_cells_by_row.setdefault(c.get("row_index", -1), []).append(c)
|
||||
|
||||
# Track which content rows we've processed
|
||||
content_row_ptr = 0
|
||||
|
||||
for bi, box_info in enumerate(box_infos):
|
||||
bz = box_info["zone"]
|
||||
by_start = box_info["y_start"]
|
||||
by_end = box_info["y_end"]
|
||||
|
||||
# --- Add content rows ABOVE this box ---
|
||||
while content_row_ptr < len(content_rows):
|
||||
cr = content_rows[content_row_ptr]
|
||||
cry = cr.get("y_min_px", cr.get("y_min", 0))
|
||||
if cry >= by_start:
|
||||
break
|
||||
# Add this content row
|
||||
_add_content_row(
|
||||
unified_rows, unified_cells, unified_row_idx,
|
||||
cr, content_cells_by_row, dominant_h, image_height,
|
||||
)
|
||||
unified_row_idx += 1
|
||||
content_row_ptr += 1
|
||||
|
||||
# --- Add box rows ---
|
||||
if box_info["classification"] == "full_width":
|
||||
# Full-width box: integrate box rows directly
|
||||
_add_full_width_box(
|
||||
unified_rows, unified_cells, unified_row_idx,
|
||||
bz, box_info, dominant_h, num_cols, image_height,
|
||||
)
|
||||
unified_row_idx += len(bz.get("rows", []))
|
||||
# Skip content rows that overlap with this box
|
||||
while content_row_ptr < len(content_rows):
|
||||
cr = content_rows[content_row_ptr]
|
||||
cry = cr.get("y_min_px", cr.get("y_min", 0))
|
||||
if cry > by_end:
|
||||
break
|
||||
content_row_ptr += 1
|
||||
|
||||
else:
|
||||
# Partial-width box: merge with adjacent content rows
|
||||
unified_row_idx = _add_partial_width_box(
|
||||
unified_rows, unified_cells, unified_row_idx,
|
||||
bz, box_info, content_rows, content_cells_by_row,
|
||||
content_row_ptr, dominant_h, num_cols, image_height,
|
||||
content_x, content_width,
|
||||
)
|
||||
# Advance content pointer past box region
|
||||
while content_row_ptr < len(content_rows):
|
||||
cr = content_rows[content_row_ptr]
|
||||
cry = cr.get("y_min_px", cr.get("y_min", 0))
|
||||
if cry > by_end:
|
||||
break
|
||||
content_row_ptr += 1
|
||||
|
||||
# --- Add remaining content rows BELOW all boxes ---
|
||||
while content_row_ptr < len(content_rows):
|
||||
cr = content_rows[content_row_ptr]
|
||||
_add_content_row(
|
||||
unified_rows, unified_cells, unified_row_idx,
|
||||
cr, content_cells_by_row, dominant_h, image_height,
|
||||
)
|
||||
unified_row_idx += 1
|
||||
content_row_ptr += 1
|
||||
|
||||
# --- Build unified zone ---
|
||||
unified_zone = {
|
||||
"zone_index": 0,
|
||||
"zone_type": "unified",
|
||||
"bbox_px": content_bbox,
|
||||
"bbox_pct": content_zone.get("bbox_pct", {}),
|
||||
"border": None,
|
||||
"word_count": sum(len(c.get("word_boxes", [])) for c in unified_cells),
|
||||
"columns": content_cols,
|
||||
"rows": unified_rows,
|
||||
"cells": unified_cells,
|
||||
"header_rows": [],
|
||||
}
|
||||
|
||||
logger.info(
|
||||
"build_unified_grid: %d unified rows, %d cells (from %d content + %d box zones)",
|
||||
len(unified_rows), len(unified_cells),
|
||||
len(content_rows), len(box_zones),
|
||||
)
|
||||
|
||||
return {
|
||||
"zones": [unified_zone],
|
||||
"image_width": image_width,
|
||||
"image_height": image_height,
|
||||
"layout_metrics": layout_metrics,
|
||||
"summary": {
|
||||
"total_zones": 1,
|
||||
"total_columns": num_cols,
|
||||
"total_rows": len(unified_rows),
|
||||
"total_cells": len(unified_cells),
|
||||
},
|
||||
"is_unified": True,
|
||||
"dominant_row_h": dominant_h,
|
||||
}
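# Example usage (a minimal sketch -- the zone dicts are hypothetical and only show
# the keys build_unified_grid() actually reads; real zones come from the grid pipeline):
#
#     zones = [
#         {"zone_type": "content", "bbox_px": {"x": 40, "y": 60, "w": 1200, "h": 1500},
#          "rows": [...], "columns": [...], "cells": [...]},
#         {"zone_type": "box", "bbox_px": {"x": 700, "y": 300, "w": 500, "h": 140},
#          "cells": [...], "box_bg_hex": "#FFF3C4", "box_bg_color": "yellow"},
#     ]
#     grid = build_unified_grid(zones, image_width=1240, image_height=1754, layout_metrics={})
#     assert grid["is_unified"] and grid["summary"]["total_zones"] == 1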
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _make_row(idx: int, y: float, h: float, img_h: int, is_header: bool = False) -> Dict:
|
||||
return {
|
||||
"index": idx,
|
||||
"row_index": idx,
|
||||
"y_min_px": round(y),
|
||||
"y_max_px": round(y + h),
|
||||
"y_min_pct": round(y / img_h * 100, 2) if img_h else 0,
|
||||
"y_max_pct": round((y + h) / img_h * 100, 2) if img_h else 0,
|
||||
"is_header": is_header,
|
||||
}
|
||||
|
||||
|
||||
def _remap_cell(cell: Dict, new_row: int, new_col: Optional[int] = None,
                source_type: str = "content", box_region: Optional[Dict] = None) -> Dict:
|
||||
"""Create a new cell dict with remapped indices."""
|
||||
c = dict(cell)
|
||||
c["row_index"] = new_row
|
||||
if new_col is not None:
|
||||
c["col_index"] = new_col
|
||||
c["cell_id"] = f"U_R{new_row:02d}_C{c.get('col_index', 0)}"
|
||||
c["source_zone_type"] = source_type
|
||||
if box_region:
|
||||
c["box_region"] = box_region
|
||||
return c
|
||||
|
||||
|
||||
def _add_content_row(
|
||||
unified_rows, unified_cells, row_idx,
|
||||
content_row, cells_by_row, dominant_h, img_h,
|
||||
):
|
||||
"""Add a single content row to the unified grid."""
|
||||
y = content_row.get("y_min_px", content_row.get("y_min", 0))
|
||||
is_hdr = content_row.get("is_header", False)
|
||||
unified_rows.append(_make_row(row_idx, y, dominant_h, img_h, is_hdr))
|
||||
|
||||
for cell in cells_by_row.get(content_row.get("index", -1), []):
|
||||
unified_cells.append(_remap_cell(cell, row_idx, source_type="content"))
|
||||
|
||||
|
||||
def _add_full_width_box(
|
||||
unified_rows, unified_cells, start_row_idx,
|
||||
box_zone, box_info, dominant_h, num_cols, img_h,
|
||||
):
|
||||
"""Add a full-width box's rows to the unified grid."""
|
||||
box_rows = box_zone.get("rows", [])
|
||||
box_cells = box_zone.get("cells", [])
|
||||
box_region = {"bg_hex": box_info["bg_hex"], "bg_color": box_info["bg_color"], "border": True}
|
||||
|
||||
# Distribute box height evenly among its rows
|
||||
box_h = box_info["y_end"] - box_info["y_start"]
|
||||
row_h = box_h / len(box_rows) if box_rows else dominant_h
|
||||
|
||||
for i, br in enumerate(box_rows):
|
||||
y = box_info["y_start"] + i * row_h
|
||||
new_idx = start_row_idx + i
|
||||
is_hdr = br.get("is_header", False)
|
||||
unified_rows.append(_make_row(new_idx, y, row_h, img_h, is_hdr))
|
||||
|
||||
for cell in box_cells:
|
||||
if cell.get("row_index") == br.get("index", i):
|
||||
unified_cells.append(
|
||||
_remap_cell(cell, new_idx, source_type="box", box_region=box_region)
|
||||
)
|
||||
|
||||
|
||||
def _add_partial_width_box(
|
||||
unified_rows, unified_cells, start_row_idx,
|
||||
box_zone, box_info, content_rows, content_cells_by_row,
|
||||
content_row_ptr, dominant_h, num_cols, img_h,
|
||||
content_x, content_width,
|
||||
) -> int:
|
||||
"""Add a partial-width box merged with content rows.
|
||||
|
||||
Returns the next unified_row_idx after processing.
|
||||
"""
|
||||
by_start = box_info["y_start"]
|
||||
by_end = box_info["y_end"]
|
||||
box_h = by_end - by_start
|
||||
box_region = {"bg_hex": box_info["bg_hex"], "bg_color": box_info["bg_color"], "border": True}
|
||||
|
||||
# Content rows in the box's Y range
|
||||
overlap_content_rows = []
|
||||
ptr = content_row_ptr
|
||||
while ptr < len(content_rows):
|
||||
cr = content_rows[ptr]
|
||||
cry = cr.get("y_min_px", cr.get("y_min", 0))
|
||||
if cry > by_end:
|
||||
break
|
||||
if cry >= by_start:
|
||||
overlap_content_rows.append(cr)
|
||||
ptr += 1
|
||||
|
||||
# How many standard rows fit in the box height
|
||||
standard_rows = max(1, math.floor(box_h / dominant_h))
|
||||
# How many text lines the box actually has
|
||||
box_text_lines = box_info["total_lines"]
|
||||
# Extra rows needed
|
||||
extra_rows = max(0, box_text_lines - standard_rows)
|
||||
total_rows_for_region = standard_rows + extra_rows
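    # Worked example (hypothetical numbers): with box_h=160px and dominant_h=47px,
    # standard_rows = floor(160 / 47) = 3; a box holding 5 text lines therefore needs
    # extra_rows = 5 - 3 = 2, so the region spans 5 unified rows in total.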
|
||||
|
||||
logger.info(
|
||||
"partial box: standard=%d, box_lines=%d, extra=%d, content_overlap=%d",
|
||||
standard_rows, box_text_lines, extra_rows, len(overlap_content_rows),
|
||||
)
|
||||
|
||||
    # Determine which content columns the box occupies.
    # Heuristic: a right-side box covers the right half of the columns,
    # a left-side box covers the left half.
    if box_info["side"] == "right":
        box_col_start = num_cols // 2
        box_col_end = num_cols
    else:
        box_col_start = 0
        box_col_end = num_cols // 2
|
||||
|
||||
# Build rows for this region
|
||||
box_cells = box_zone.get("cells", [])
|
||||
box_rows = box_zone.get("rows", [])
|
||||
row_idx = start_row_idx
|
||||
|
||||
# Expand box cell texts with \n into individual lines for row mapping
|
||||
box_lines: List[Tuple[str, Dict]] = [] # (text_line, parent_cell)
|
||||
for bc in sorted(box_cells, key=lambda c: c.get("row_index", 0)):
|
||||
text = bc.get("text", "")
|
||||
for line in text.split("\n"):
|
||||
box_lines.append((line.strip(), bc))
|
||||
|
||||
for i in range(total_rows_for_region):
|
||||
y = by_start + i * dominant_h
|
||||
unified_rows.append(_make_row(row_idx, y, dominant_h, img_h))
|
||||
|
||||
# Content cells for this row (from overlapping content rows)
|
||||
if i < len(overlap_content_rows):
|
||||
cr = overlap_content_rows[i]
|
||||
for cell in content_cells_by_row.get(cr.get("index", -1), []):
|
||||
# Only include cells from columns NOT covered by the box
|
||||
ci = cell.get("col_index", 0)
|
||||
if ci < box_col_start or ci >= box_col_end:
|
||||
unified_cells.append(_remap_cell(cell, row_idx, source_type="content"))
|
||||
|
||||
# Box cell for this row
|
||||
if i < len(box_lines):
|
||||
line_text, parent_cell = box_lines[i]
|
||||
box_cell = {
|
||||
"cell_id": f"U_R{row_idx:02d}_C{box_col_start}",
|
||||
"row_index": row_idx,
|
||||
"col_index": box_col_start,
|
||||
"col_type": "spanning_header" if (box_col_end - box_col_start) > 1 else parent_cell.get("col_type", "column_1"),
|
||||
"colspan": box_col_end - box_col_start,
|
||||
"text": line_text,
|
||||
"confidence": parent_cell.get("confidence", 0),
|
||||
"bbox_px": parent_cell.get("bbox_px", {}),
|
||||
"bbox_pct": parent_cell.get("bbox_pct", {}),
|
||||
"word_boxes": [],
|
||||
"ocr_engine": parent_cell.get("ocr_engine", ""),
|
||||
"is_bold": parent_cell.get("is_bold", False),
|
||||
"source_zone_type": "box",
|
||||
"box_region": box_region,
|
||||
}
|
||||
unified_cells.append(box_cell)
|
||||
|
||||
row_idx += 1
|
||||
|
||||
return row_idx
|
||||
@@ -1,276 +1,4 @@
|
||||
"""
|
||||
Handwriting HTR API - Hochwertige Handschriftenerkennung (HTR) fuer Klausurkorrekturen.
|
||||
|
||||
Endpoints:
|
||||
- POST /api/v1/htr/recognize - Bild hochladen → handgeschriebener Text
|
||||
- POST /api/v1/htr/recognize-session - OCR-Pipeline Session als Quelle nutzen
|
||||
|
||||
Modell-Strategie:
|
||||
1. qwen2.5vl:32b via Ollama (primaer, hoechste Qualitaet als VLM)
|
||||
2. microsoft/trocr-large-handwritten (Fallback, offline, kein Ollama)
|
||||
|
||||
DATENSCHUTZ: Alle Verarbeitung erfolgt lokal auf dem Mac Mini.
|
||||
"""
|
||||
|
||||
import io
|
||||
import os
|
||||
import logging
|
||||
import time
|
||||
import base64
|
||||
from typing import Optional
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
from fastapi import APIRouter, HTTPException, Query, UploadFile, File
|
||||
from pydantic import BaseModel
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter(prefix="/api/v1/htr", tags=["HTR"])
|
||||
|
||||
OLLAMA_BASE_URL = os.getenv("OLLAMA_BASE_URL", "http://host.docker.internal:11434")
|
||||
OLLAMA_HTR_MODEL = os.getenv("OLLAMA_HTR_MODEL", "qwen2.5vl:32b")
|
||||
HTR_FALLBACK_MODEL = os.getenv("HTR_FALLBACK_MODEL", "trocr-large")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Pydantic Models
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class HTRSessionRequest(BaseModel):
|
||||
session_id: str
|
||||
model: str = "auto" # "auto" | "qwen2.5vl" | "trocr-large"
|
||||
use_clean: bool = True # Prefer clean_png (after handwriting removal)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Preprocessing
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _preprocess_for_htr(img_bgr: np.ndarray) -> np.ndarray:
|
||||
"""
|
||||
CLAHE contrast enhancement + upscale to improve HTR accuracy.
|
||||
Returns grayscale enhanced image.
|
||||
"""
|
||||
gray = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2GRAY)
|
||||
clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
|
||||
enhanced = clahe.apply(gray)
|
||||
|
||||
# Upscale if image is too small
|
||||
h, w = enhanced.shape
|
||||
if min(h, w) < 800:
|
||||
scale = 800 / min(h, w)
|
||||
enhanced = cv2.resize(
|
||||
enhanced, None, fx=scale, fy=scale,
|
||||
interpolation=cv2.INTER_CUBIC
|
||||
)
|
||||
|
||||
return enhanced
|
||||
|
||||
|
||||
def _bgr_to_png_bytes(img_bgr: np.ndarray) -> bytes:
|
||||
"""Convert BGR ndarray to PNG bytes."""
|
||||
success, buf = cv2.imencode(".png", img_bgr)
|
||||
if not success:
|
||||
raise RuntimeError("Failed to encode image to PNG")
|
||||
return buf.tobytes()
|
||||
|
||||
|
||||
def _preprocess_image_bytes(image_bytes: bytes) -> bytes:
|
||||
"""Load image, apply HTR preprocessing, return PNG bytes."""
|
||||
arr = np.frombuffer(image_bytes, dtype=np.uint8)
|
||||
img_bgr = cv2.imdecode(arr, cv2.IMREAD_COLOR)
|
||||
if img_bgr is None:
|
||||
raise ValueError("Could not decode image")
|
||||
|
||||
enhanced = _preprocess_for_htr(img_bgr)
|
||||
# Convert grayscale back to BGR for encoding
|
||||
enhanced_bgr = cv2.cvtColor(enhanced, cv2.COLOR_GRAY2BGR)
|
||||
return _bgr_to_png_bytes(enhanced_bgr)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Backend: Ollama qwen2.5vl
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
async def _recognize_with_qwen_vl(image_bytes: bytes, language: str) -> Optional[str]:
|
||||
"""
|
||||
Send image to Ollama qwen2.5vl:32b for HTR.
|
||||
Returns extracted text or None on error.
|
||||
"""
|
||||
import httpx
|
||||
|
||||
lang_hint = {
|
||||
"de": "Deutsch",
|
||||
"en": "Englisch",
|
||||
"de+en": "Deutsch und Englisch",
|
||||
}.get(language, "Deutsch")
|
||||
|
||||
prompt = (
|
||||
f"Du bist ein OCR-Experte fuer handgeschriebenen Text auf {lang_hint}. "
|
||||
"Lies den Text im Bild exakt ab — korrigiere KEINE Rechtschreibfehler. "
|
||||
"Antworte NUR mit dem erkannten Text, ohne Erklaerungen."
|
||||
)
|
||||
|
||||
img_b64 = base64.b64encode(image_bytes).decode("utf-8")
|
||||
|
||||
payload = {
|
||||
"model": OLLAMA_HTR_MODEL,
|
||||
"prompt": prompt,
|
||||
"images": [img_b64],
|
||||
"stream": False,
|
||||
}
|
||||
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=120.0) as client:
|
||||
resp = await client.post(f"{OLLAMA_BASE_URL}/api/generate", json=payload)
|
||||
resp.raise_for_status()
|
||||
data = resp.json()
|
||||
return data.get("response", "").strip()
|
||||
except Exception as e:
|
||||
logger.warning(f"Ollama qwen2.5vl HTR failed: {e}")
|
||||
return None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Backend: TrOCR-large fallback
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
async def _recognize_with_trocr_large(image_bytes: bytes) -> Optional[str]:
|
||||
"""
|
||||
Use microsoft/trocr-large-handwritten via trocr_service.py.
|
||||
Returns extracted text or None on error.
|
||||
"""
|
||||
try:
|
||||
from services.trocr_service import run_trocr_ocr, _check_trocr_available
|
||||
if not _check_trocr_available():
|
||||
logger.warning("TrOCR not available for HTR fallback")
|
||||
return None
|
||||
|
||||
text, confidence = await run_trocr_ocr(image_bytes, handwritten=True, size="large")
|
||||
return text.strip() if text else None
|
||||
except Exception as e:
|
||||
logger.warning(f"TrOCR-large HTR failed: {e}")
|
||||
return None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Core recognition logic
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
async def _do_recognize(
|
||||
image_bytes: bytes,
|
||||
model: str = "auto",
|
||||
preprocess: bool = True,
|
||||
language: str = "de",
|
||||
) -> dict:
|
||||
"""
|
||||
Core HTR logic: preprocess → try Ollama → fallback to TrOCR-large.
|
||||
Returns dict with text, model_used, processing_time_ms.
|
||||
"""
|
||||
t0 = time.monotonic()
|
||||
|
||||
if preprocess:
|
||||
try:
|
||||
image_bytes = _preprocess_image_bytes(image_bytes)
|
||||
except Exception as e:
|
||||
logger.warning(f"HTR preprocessing failed, using raw image: {e}")
|
||||
|
||||
text: Optional[str] = None
|
||||
model_used: str = "none"
|
||||
|
||||
    use_qwen = model in ("auto", "qwen2.5vl")
    use_trocr = model in ("auto", "trocr-large")

    if use_qwen:
        text = await _recognize_with_qwen_vl(image_bytes, language)
        if text is not None:
            model_used = f"qwen2.5vl ({OLLAMA_HTR_MODEL})"

    # TrOCR runs when it was requested explicitly, or as a fallback after the primary backend
    if text is None and (use_trocr or use_qwen):
|
||||
text = await _recognize_with_trocr_large(image_bytes)
|
||||
if text is not None:
|
||||
model_used = "trocr-large-handwritten"
|
||||
|
||||
if text is None:
|
||||
text = ""
|
||||
model_used = "none (all backends failed)"
|
||||
|
||||
elapsed_ms = int((time.monotonic() - t0) * 1000)
|
||||
|
||||
return {
|
||||
"text": text,
|
||||
"model_used": model_used,
|
||||
"processing_time_ms": elapsed_ms,
|
||||
"language": language,
|
||||
"preprocessed": preprocess,
|
||||
}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Endpoints
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@router.post("/recognize")
|
||||
async def recognize_handwriting(
|
||||
file: UploadFile = File(...),
|
||||
model: str = Query("auto", description="auto | qwen2.5vl | trocr-large"),
|
||||
preprocess: bool = Query(True, description="Apply CLAHE + upscale before recognition"),
|
||||
language: str = Query("de", description="de | en | de+en"),
|
||||
):
|
||||
"""
|
||||
Upload an image and get back the handwritten text as plain text.
|
||||
|
||||
Tries qwen2.5vl:32b via Ollama first, falls back to TrOCR-large-handwritten.
|
||||
"""
|
||||
if model not in ("auto", "qwen2.5vl", "trocr-large"):
|
||||
raise HTTPException(status_code=400, detail="model must be one of: auto, qwen2.5vl, trocr-large")
|
||||
if language not in ("de", "en", "de+en"):
|
||||
raise HTTPException(status_code=400, detail="language must be one of: de, en, de+en")
|
||||
|
||||
image_bytes = await file.read()
|
||||
if not image_bytes:
|
||||
raise HTTPException(status_code=400, detail="Empty file")
|
||||
|
||||
return await _do_recognize(image_bytes, model=model, preprocess=preprocess, language=language)
|
||||
|
||||
|
||||
@router.post("/recognize-session")
|
||||
async def recognize_from_session(req: HTRSessionRequest):
|
||||
"""
|
||||
Use an OCR-Pipeline session as image source for HTR.
|
||||
|
||||
Set use_clean=true to prefer the clean image (after handwriting removal step).
|
||||
This is useful when you want to do HTR on isolated handwriting regions.
|
||||
"""
|
||||
from ocr_pipeline_session_store import get_session_db, get_session_image
|
||||
|
||||
session = await get_session_db(req.session_id)
|
||||
if not session:
|
||||
raise HTTPException(status_code=404, detail=f"Session {req.session_id} not found")
|
||||
|
||||
# Choose source image
|
||||
image_bytes: Optional[bytes] = None
|
||||
source_used: str = ""
|
||||
|
||||
if req.use_clean:
|
||||
image_bytes = await get_session_image(req.session_id, "clean")
|
||||
if image_bytes:
|
||||
source_used = "clean"
|
||||
|
||||
if not image_bytes:
|
||||
image_bytes = await get_session_image(req.session_id, "deskewed")
|
||||
if image_bytes:
|
||||
source_used = "deskewed"
|
||||
|
||||
if not image_bytes:
|
||||
image_bytes = await get_session_image(req.session_id, "original")
|
||||
source_used = "original"
|
||||
|
||||
if not image_bytes:
|
||||
raise HTTPException(status_code=404, detail="No image available in session")
|
||||
|
||||
result = await _do_recognize(image_bytes, model=req.model)
|
||||
result["session_id"] = req.session_id
|
||||
result["source_image"] = source_used
|
||||
return result
|
||||
# Backward-compat shim -- module moved to ocr/pipeline/htr_api.py
|
||||
import importlib as _importlib
|
||||
import sys as _sys
|
||||
_sys.modules[__name__] = _importlib.import_module("ocr.pipeline.htr_api")
|
||||
|
||||
346
klausur-service/backend/ocr/engines/tesseract_extractor.py
Normal file
@@ -0,0 +1,346 @@
|
||||
"""
|
||||
Tesseract-based OCR extraction with word-level bounding boxes.
|
||||
|
||||
Uses Tesseract for spatial information (WHERE text is) while
|
||||
the Vision LLM handles semantic understanding (WHAT the text means).
|
||||
|
||||
Tesseract runs natively on ARM64 via Debian's apt package.
|
||||
|
||||
License: Apache 2.0 (commercially usable)
|
||||
"""
|
||||
|
||||
import io
|
||||
import logging
|
||||
from typing import List, Dict, Any, Optional
|
||||
from difflib import SequenceMatcher
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
try:
|
||||
import pytesseract
|
||||
from PIL import Image
|
||||
TESSERACT_AVAILABLE = True
|
||||
except ImportError:
|
||||
TESSERACT_AVAILABLE = False
|
||||
logger.warning("pytesseract or Pillow not installed - Tesseract OCR unavailable")
|
||||
|
||||
|
||||
async def extract_bounding_boxes(image_bytes: bytes, lang: str = "eng+deu") -> dict:
|
||||
"""Run Tesseract OCR and return word-level bounding boxes.
|
||||
|
||||
Args:
|
||||
image_bytes: PNG/JPEG image as bytes.
|
||||
lang: Tesseract language string (e.g. "eng+deu").
|
||||
|
||||
Returns:
|
||||
Dict with 'words' list and 'image_width'/'image_height'.
|
||||
"""
|
||||
if not TESSERACT_AVAILABLE:
|
||||
return {"words": [], "image_width": 0, "image_height": 0, "error": "Tesseract not available"}
|
||||
|
||||
image = Image.open(io.BytesIO(image_bytes))
|
||||
data = pytesseract.image_to_data(image, lang=lang, output_type=pytesseract.Output.DICT)
|
||||
|
||||
words = []
|
||||
for i in range(len(data['text'])):
|
||||
text = data['text'][i].strip()
|
||||
conf = int(data['conf'][i])
|
||||
if not text or conf < 20:
|
||||
continue
|
||||
words.append({
|
||||
"text": text,
|
||||
"left": data['left'][i],
|
||||
"top": data['top'][i],
|
||||
"width": data['width'][i],
|
||||
"height": data['height'][i],
|
||||
"conf": conf,
|
||||
"block_num": data['block_num'][i],
|
||||
"par_num": data['par_num'][i],
|
||||
"line_num": data['line_num'][i],
|
||||
"word_num": data['word_num'][i],
|
||||
})
|
||||
|
||||
return {
|
||||
"words": words,
|
||||
"image_width": image.width,
|
||||
"image_height": image.height,
|
||||
}
|
||||
|
||||
|
||||
def group_words_into_lines(words: List[dict], y_tolerance_px: int = 15) -> List[List[dict]]:
|
||||
"""Group words by their Y position into lines.
|
||||
|
||||
Args:
|
||||
words: List of word dicts from extract_bounding_boxes.
|
||||
y_tolerance_px: Max pixel distance to consider words on the same line.
|
||||
|
||||
Returns:
|
||||
List of lines, each line is a list of words sorted by X position.
|
||||
"""
|
||||
if not words:
|
||||
return []
|
||||
|
||||
# Sort by Y then X
|
||||
sorted_words = sorted(words, key=lambda w: (w['top'], w['left']))
|
||||
|
||||
lines: List[List[dict]] = []
|
||||
current_line: List[dict] = [sorted_words[0]]
|
||||
current_y = sorted_words[0]['top']
|
||||
|
||||
for word in sorted_words[1:]:
|
||||
if abs(word['top'] - current_y) <= y_tolerance_px:
|
||||
current_line.append(word)
|
||||
else:
|
||||
current_line.sort(key=lambda w: w['left'])
|
||||
lines.append(current_line)
|
||||
current_line = [word]
|
||||
current_y = word['top']
|
||||
|
||||
if current_line:
|
||||
current_line.sort(key=lambda w: w['left'])
|
||||
lines.append(current_line)
|
||||
|
||||
return lines
|
||||
|
||||
|
||||
def detect_columns(lines: List[List[dict]], image_width: int) -> Dict[str, Any]:
|
||||
"""Detect column boundaries from word positions.
|
||||
|
||||
Typical vocab table: Left=English, Middle=German, Right=Example sentences.
|
||||
|
||||
Returns:
|
||||
Dict with column boundaries and type assignments.
|
||||
"""
|
||||
if not lines or image_width == 0:
|
||||
return {"columns": [], "column_types": []}
|
||||
|
||||
# Collect all word X positions
|
||||
all_x_positions = []
|
||||
for line in lines:
|
||||
for word in line:
|
||||
all_x_positions.append(word['left'])
|
||||
|
||||
if not all_x_positions:
|
||||
return {"columns": [], "column_types": []}
|
||||
|
||||
# Find X-position clusters (column starts)
|
||||
all_x_positions.sort()
|
||||
|
||||
# Simple gap-based column detection
|
||||
min_gap = image_width * 0.08 # 8% of page width = column gap
|
||||
clusters = []
|
||||
current_cluster = [all_x_positions[0]]
|
||||
|
||||
for x in all_x_positions[1:]:
|
||||
if x - current_cluster[-1] > min_gap:
|
||||
clusters.append(current_cluster)
|
||||
current_cluster = [x]
|
||||
else:
|
||||
current_cluster.append(x)
|
||||
|
||||
if current_cluster:
|
||||
clusters.append(current_cluster)
|
||||
|
||||
# Each cluster represents a column start
|
||||
columns = []
|
||||
for cluster in clusters:
|
||||
col_start = min(cluster)
|
||||
columns.append({
|
||||
"x_start": col_start,
|
||||
"x_start_pct": col_start / image_width * 100,
|
||||
"word_count": len(cluster),
|
||||
})
|
||||
|
||||
# Assign column types based on position (left→right: EN, DE, Example)
|
||||
type_map = ["english", "german", "example"]
|
||||
column_types = []
|
||||
for i, col in enumerate(columns):
|
||||
if i < len(type_map):
|
||||
column_types.append(type_map[i])
|
||||
else:
|
||||
column_types.append("unknown")
|
||||
|
||||
return {
|
||||
"columns": columns,
|
||||
"column_types": column_types,
|
||||
}
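# Worked example (hypothetical numbers): with image_width=1000 the gap threshold is
# 1000 * 0.08 = 80px, so word x-positions [55, 60, 410, 415, 760] form three clusters
# starting near x=55, x=410 and x=760, labelled "english", "german" and "example".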
|
||||
|
||||
|
||||
def words_to_vocab_entries(lines: List[List[dict]], columns: List[dict],
|
||||
column_types: List[str], image_width: int,
|
||||
image_height: int) -> List[dict]:
|
||||
"""Convert grouped words into vocabulary entries using column positions.
|
||||
|
||||
Args:
|
||||
lines: Grouped word lines from group_words_into_lines.
|
||||
columns: Column boundaries from detect_columns.
|
||||
column_types: Column type assignments.
|
||||
image_width: Image width in pixels.
|
||||
image_height: Image height in pixels.
|
||||
|
||||
Returns:
|
||||
List of vocabulary entry dicts with english/german/example fields.
|
||||
"""
|
||||
if not columns or not lines:
|
||||
return []
|
||||
|
||||
# Build column boundaries for word assignment
|
||||
col_boundaries = []
|
||||
for i, col in enumerate(columns):
|
||||
start = col['x_start']
|
||||
if i + 1 < len(columns):
|
||||
end = columns[i + 1]['x_start']
|
||||
else:
|
||||
end = image_width
|
||||
col_boundaries.append((start, end, column_types[i] if i < len(column_types) else "unknown"))
|
||||
|
||||
entries = []
|
||||
for line in lines:
|
||||
entry = {"english": "", "german": "", "example": ""}
|
||||
line_words_by_col: Dict[str, List[str]] = {"english": [], "german": [], "example": []}
|
||||
line_bbox: Dict[str, Optional[dict]] = {}
|
||||
|
||||
for word in line:
|
||||
word_center_x = word['left'] + word['width'] / 2
|
||||
assigned_type = "unknown"
|
||||
for start, end, col_type in col_boundaries:
|
||||
if start <= word_center_x < end:
|
||||
assigned_type = col_type
|
||||
break
|
||||
|
||||
if assigned_type in line_words_by_col:
|
||||
line_words_by_col[assigned_type].append(word['text'])
|
||||
# Track bounding box for the column
|
||||
if assigned_type not in line_bbox or line_bbox[assigned_type] is None:
|
||||
line_bbox[assigned_type] = {
|
||||
"left": word['left'],
|
||||
"top": word['top'],
|
||||
"right": word['left'] + word['width'],
|
||||
"bottom": word['top'] + word['height'],
|
||||
}
|
||||
else:
|
||||
bb = line_bbox[assigned_type]
|
||||
bb['left'] = min(bb['left'], word['left'])
|
||||
bb['top'] = min(bb['top'], word['top'])
|
||||
bb['right'] = max(bb['right'], word['left'] + word['width'])
|
||||
bb['bottom'] = max(bb['bottom'], word['top'] + word['height'])
|
||||
|
||||
for col_type in ["english", "german", "example"]:
|
||||
if line_words_by_col[col_type]:
|
||||
entry[col_type] = " ".join(line_words_by_col[col_type])
|
||||
if line_bbox.get(col_type):
|
||||
bb = line_bbox[col_type]
|
||||
entry[f"{col_type}_bbox"] = {
|
||||
"x_pct": bb['left'] / image_width * 100,
|
||||
"y_pct": bb['top'] / image_height * 100,
|
||||
"w_pct": (bb['right'] - bb['left']) / image_width * 100,
|
||||
"h_pct": (bb['bottom'] - bb['top']) / image_height * 100,
|
||||
}
|
||||
|
||||
# Only add if at least one column has content
|
||||
if entry["english"] or entry["german"]:
|
||||
entries.append(entry)
|
||||
|
||||
return entries
|
||||
|
||||
|
||||
def match_positions_to_vocab(tess_words: List[dict], llm_vocab: List[dict],
|
||||
image_w: int, image_h: int,
|
||||
threshold: float = 0.6) -> List[dict]:
|
||||
"""Match Tesseract bounding boxes to LLM vocabulary entries.
|
||||
|
||||
For each LLM vocab entry, find the best-matching Tesseract word
|
||||
and attach its bounding box coordinates.
|
||||
|
||||
Args:
|
||||
tess_words: Word list from Tesseract with pixel coordinates.
|
||||
llm_vocab: Vocabulary list from Vision LLM.
|
||||
image_w: Image width in pixels.
|
||||
image_h: Image height in pixels.
|
||||
threshold: Minimum similarity ratio for a match.
|
||||
|
||||
Returns:
|
||||
llm_vocab list with bbox_x_pct/bbox_y_pct/bbox_w_pct/bbox_h_pct added.
|
||||
"""
|
||||
if not tess_words or not llm_vocab or image_w == 0 or image_h == 0:
|
||||
return llm_vocab
|
||||
|
||||
for entry in llm_vocab:
|
||||
english = entry.get("english", "").lower().strip()
|
||||
german = entry.get("german", "").lower().strip()
|
||||
|
||||
if not english and not german:
|
||||
continue
|
||||
|
||||
# Try to match English word first, then German
|
||||
for field in ["english", "german"]:
|
||||
search_text = entry.get(field, "").lower().strip()
|
||||
if not search_text:
|
||||
continue
|
||||
|
||||
best_word = None
|
||||
best_ratio = 0.0
|
||||
|
||||
for word in tess_words:
|
||||
ratio = SequenceMatcher(None, search_text, word['text'].lower()).ratio()
|
||||
if ratio > best_ratio:
|
||||
best_ratio = ratio
|
||||
best_word = word
|
||||
|
||||
if best_word and best_ratio >= threshold:
|
||||
entry[f"bbox_x_pct"] = best_word['left'] / image_w * 100
|
||||
entry[f"bbox_y_pct"] = best_word['top'] / image_h * 100
|
||||
entry[f"bbox_w_pct"] = best_word['width'] / image_w * 100
|
||||
entry[f"bbox_h_pct"] = best_word['height'] / image_h * 100
|
||||
entry["bbox_match_field"] = field
|
||||
entry["bbox_match_ratio"] = round(best_ratio, 3)
|
||||
break # Found a match, no need to try the other field
|
||||
|
||||
return llm_vocab
|
||||
|
||||
|
||||
async def run_tesseract_pipeline(image_bytes: bytes, lang: str = "eng+deu") -> dict:
|
||||
"""Full Tesseract pipeline: extract words, group lines, detect columns, build vocab.
|
||||
|
||||
Args:
|
||||
image_bytes: PNG/JPEG image as bytes.
|
||||
lang: Tesseract language string.
|
||||
|
||||
Returns:
|
||||
Dict with 'vocabulary', 'words', 'lines', 'columns', 'image_width', 'image_height'.
|
||||
"""
|
||||
# Step 1: Extract bounding boxes
|
||||
bbox_data = await extract_bounding_boxes(image_bytes, lang=lang)
|
||||
|
||||
if bbox_data.get("error"):
|
||||
return bbox_data
|
||||
|
||||
words = bbox_data["words"]
|
||||
image_w = bbox_data["image_width"]
|
||||
image_h = bbox_data["image_height"]
|
||||
|
||||
# Step 2: Group into lines
|
||||
lines = group_words_into_lines(words)
|
||||
|
||||
# Step 3: Detect columns
|
||||
col_info = detect_columns(lines, image_w)
|
||||
|
||||
# Step 4: Build vocabulary entries
|
||||
vocab = words_to_vocab_entries(
|
||||
lines,
|
||||
col_info["columns"],
|
||||
col_info["column_types"],
|
||||
image_w,
|
||||
image_h,
|
||||
)
|
||||
|
||||
return {
|
||||
"vocabulary": vocab,
|
||||
"words": words,
|
||||
"lines_count": len(lines),
|
||||
"columns": col_info["columns"],
|
||||
"column_types": col_info["column_types"],
|
||||
"image_width": image_w,
|
||||
"image_height": image_h,
|
||||
"word_count": len(words),
|
||||
}
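# Example usage (a minimal sketch -- the file name and the asyncio wrapper are
# illustrative only):
#
#     import asyncio
#
#     async def _demo():
#         with open("vocab_page.png", "rb") as f:
#             result = await run_tesseract_pipeline(f.read(), lang="eng+deu")
#         print(result["word_count"], "words in", result["lines_count"], "lines")
#         for entry in result["vocabulary"][:5]:
#             print(entry.get("english"), "->", entry.get("german"))
#
#     asyncio.run(_demo())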
|
||||
276
klausur-service/backend/ocr/pipeline/htr_api.py
Normal file
@@ -0,0 +1,276 @@
|
||||
"""
|
||||
Handwriting HTR API - Hochwertige Handschriftenerkennung (HTR) fuer Klausurkorrekturen.
|
||||
|
||||
Endpoints:
|
||||
- POST /api/v1/htr/recognize - Bild hochladen → handgeschriebener Text
|
||||
- POST /api/v1/htr/recognize-session - OCR-Pipeline Session als Quelle nutzen
|
||||
|
||||
Modell-Strategie:
|
||||
1. qwen2.5vl:32b via Ollama (primaer, hoechste Qualitaet als VLM)
|
||||
2. microsoft/trocr-large-handwritten (Fallback, offline, kein Ollama)
|
||||
|
||||
DATENSCHUTZ: Alle Verarbeitung erfolgt lokal auf dem Mac Mini.
|
||||
"""
|
||||
|
||||
import io
|
||||
import os
|
||||
import logging
|
||||
import time
|
||||
import base64
|
||||
from typing import Optional
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
from fastapi import APIRouter, HTTPException, Query, UploadFile, File
|
||||
from pydantic import BaseModel
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter(prefix="/api/v1/htr", tags=["HTR"])
|
||||
|
||||
OLLAMA_BASE_URL = os.getenv("OLLAMA_BASE_URL", "http://host.docker.internal:11434")
|
||||
OLLAMA_HTR_MODEL = os.getenv("OLLAMA_HTR_MODEL", "qwen2.5vl:32b")
|
||||
HTR_FALLBACK_MODEL = os.getenv("HTR_FALLBACK_MODEL", "trocr-large")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Pydantic Models
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class HTRSessionRequest(BaseModel):
|
||||
session_id: str
|
||||
model: str = "auto" # "auto" | "qwen2.5vl" | "trocr-large"
|
||||
use_clean: bool = True # Prefer clean_png (after handwriting removal)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Preprocessing
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _preprocess_for_htr(img_bgr: np.ndarray) -> np.ndarray:
|
||||
"""
|
||||
CLAHE contrast enhancement + upscale to improve HTR accuracy.
|
||||
Returns grayscale enhanced image.
|
||||
"""
|
||||
gray = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2GRAY)
|
||||
clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
|
||||
enhanced = clahe.apply(gray)
|
||||
|
||||
# Upscale if image is too small
|
||||
h, w = enhanced.shape
|
||||
if min(h, w) < 800:
|
||||
scale = 800 / min(h, w)
|
||||
enhanced = cv2.resize(
|
||||
enhanced, None, fx=scale, fy=scale,
|
||||
interpolation=cv2.INTER_CUBIC
|
||||
)
|
||||
|
||||
return enhanced
|
||||
|
||||
|
||||
def _bgr_to_png_bytes(img_bgr: np.ndarray) -> bytes:
|
||||
"""Convert BGR ndarray to PNG bytes."""
|
||||
success, buf = cv2.imencode(".png", img_bgr)
|
||||
if not success:
|
||||
raise RuntimeError("Failed to encode image to PNG")
|
||||
return buf.tobytes()
|
||||
|
||||
|
||||
def _preprocess_image_bytes(image_bytes: bytes) -> bytes:
|
||||
"""Load image, apply HTR preprocessing, return PNG bytes."""
|
||||
arr = np.frombuffer(image_bytes, dtype=np.uint8)
|
||||
img_bgr = cv2.imdecode(arr, cv2.IMREAD_COLOR)
|
||||
if img_bgr is None:
|
||||
raise ValueError("Could not decode image")
|
||||
|
||||
enhanced = _preprocess_for_htr(img_bgr)
|
||||
# Convert grayscale back to BGR for encoding
|
||||
enhanced_bgr = cv2.cvtColor(enhanced, cv2.COLOR_GRAY2BGR)
|
||||
return _bgr_to_png_bytes(enhanced_bgr)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Backend: Ollama qwen2.5vl
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
async def _recognize_with_qwen_vl(image_bytes: bytes, language: str) -> Optional[str]:
|
||||
"""
|
||||
Send image to Ollama qwen2.5vl:32b for HTR.
|
||||
Returns extracted text or None on error.
|
||||
"""
|
||||
import httpx
|
||||
|
||||
lang_hint = {
|
||||
"de": "Deutsch",
|
||||
"en": "Englisch",
|
||||
"de+en": "Deutsch und Englisch",
|
||||
}.get(language, "Deutsch")
|
||||
|
||||
prompt = (
|
||||
f"Du bist ein OCR-Experte fuer handgeschriebenen Text auf {lang_hint}. "
|
||||
"Lies den Text im Bild exakt ab — korrigiere KEINE Rechtschreibfehler. "
|
||||
"Antworte NUR mit dem erkannten Text, ohne Erklaerungen."
|
||||
)
|
||||
|
||||
img_b64 = base64.b64encode(image_bytes).decode("utf-8")
|
||||
|
||||
payload = {
|
||||
"model": OLLAMA_HTR_MODEL,
|
||||
"prompt": prompt,
|
||||
"images": [img_b64],
|
||||
"stream": False,
|
||||
}
|
||||
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=120.0) as client:
|
||||
resp = await client.post(f"{OLLAMA_BASE_URL}/api/generate", json=payload)
|
||||
resp.raise_for_status()
|
||||
data = resp.json()
|
||||
return data.get("response", "").strip()
|
||||
except Exception as e:
|
||||
logger.warning(f"Ollama qwen2.5vl HTR failed: {e}")
|
||||
return None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Backend: TrOCR-large fallback
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
async def _recognize_with_trocr_large(image_bytes: bytes) -> Optional[str]:
|
||||
"""
|
||||
Use microsoft/trocr-large-handwritten via trocr_service.py.
|
||||
Returns extracted text or None on error.
|
||||
"""
|
||||
try:
|
||||
from services.trocr_service import run_trocr_ocr, _check_trocr_available
|
||||
if not _check_trocr_available():
|
||||
logger.warning("TrOCR not available for HTR fallback")
|
||||
return None
|
||||
|
||||
text, confidence = await run_trocr_ocr(image_bytes, handwritten=True, size="large")
|
||||
return text.strip() if text else None
|
||||
except Exception as e:
|
||||
logger.warning(f"TrOCR-large HTR failed: {e}")
|
||||
return None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Core recognition logic
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
async def _do_recognize(
|
||||
image_bytes: bytes,
|
||||
model: str = "auto",
|
||||
preprocess: bool = True,
|
||||
language: str = "de",
|
||||
) -> dict:
|
||||
"""
|
||||
Core HTR logic: preprocess → try Ollama → fallback to TrOCR-large.
|
||||
Returns dict with text, model_used, processing_time_ms.
|
||||
"""
|
||||
t0 = time.monotonic()
|
||||
|
||||
if preprocess:
|
||||
try:
|
||||
image_bytes = _preprocess_image_bytes(image_bytes)
|
||||
except Exception as e:
|
||||
logger.warning(f"HTR preprocessing failed, using raw image: {e}")
|
||||
|
||||
text: Optional[str] = None
|
||||
model_used: str = "none"
|
||||
|
||||
    use_qwen = model in ("auto", "qwen2.5vl")
    use_trocr = model in ("auto", "trocr-large")

    if use_qwen:
        text = await _recognize_with_qwen_vl(image_bytes, language)
        if text is not None:
            model_used = f"qwen2.5vl ({OLLAMA_HTR_MODEL})"

    # TrOCR runs when it was requested explicitly, or as a fallback after the primary backend
    if text is None and (use_trocr or use_qwen):
|
||||
text = await _recognize_with_trocr_large(image_bytes)
|
||||
if text is not None:
|
||||
model_used = "trocr-large-handwritten"
|
||||
|
||||
if text is None:
|
||||
text = ""
|
||||
model_used = "none (all backends failed)"
|
||||
|
||||
elapsed_ms = int((time.monotonic() - t0) * 1000)
|
||||
|
||||
return {
|
||||
"text": text,
|
||||
"model_used": model_used,
|
||||
"processing_time_ms": elapsed_ms,
|
||||
"language": language,
|
||||
"preprocessed": preprocess,
|
||||
}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Endpoints
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@router.post("/recognize")
|
||||
async def recognize_handwriting(
|
||||
file: UploadFile = File(...),
|
||||
model: str = Query("auto", description="auto | qwen2.5vl | trocr-large"),
|
||||
preprocess: bool = Query(True, description="Apply CLAHE + upscale before recognition"),
|
||||
language: str = Query("de", description="de | en | de+en"),
|
||||
):
|
||||
"""
|
||||
Upload an image and get back the handwritten text as plain text.
|
||||
|
||||
Tries qwen2.5vl:32b via Ollama first, falls back to TrOCR-large-handwritten.
|
||||
"""
|
||||
if model not in ("auto", "qwen2.5vl", "trocr-large"):
|
||||
raise HTTPException(status_code=400, detail="model must be one of: auto, qwen2.5vl, trocr-large")
|
||||
if language not in ("de", "en", "de+en"):
|
||||
raise HTTPException(status_code=400, detail="language must be one of: de, en, de+en")
|
||||
|
||||
image_bytes = await file.read()
|
||||
if not image_bytes:
|
||||
raise HTTPException(status_code=400, detail="Empty file")
|
||||
|
||||
return await _do_recognize(image_bytes, model=model, preprocess=preprocess, language=language)
|
||||
|
||||
|
||||
@router.post("/recognize-session")
|
||||
async def recognize_from_session(req: HTRSessionRequest):
|
||||
"""
|
||||
Use an OCR-Pipeline session as image source for HTR.
|
||||
|
||||
Set use_clean=true to prefer the clean image (after handwriting removal step).
|
||||
This is useful when you want to do HTR on isolated handwriting regions.
|
||||
"""
|
||||
from ocr_pipeline_session_store import get_session_db, get_session_image
|
||||
|
||||
session = await get_session_db(req.session_id)
|
||||
if not session:
|
||||
raise HTTPException(status_code=404, detail=f"Session {req.session_id} not found")
|
||||
|
||||
# Choose source image
|
||||
image_bytes: Optional[bytes] = None
|
||||
source_used: str = ""
|
||||
|
||||
if req.use_clean:
|
||||
image_bytes = await get_session_image(req.session_id, "clean")
|
||||
if image_bytes:
|
||||
source_used = "clean"
|
||||
|
||||
if not image_bytes:
|
||||
image_bytes = await get_session_image(req.session_id, "deskewed")
|
||||
if image_bytes:
|
||||
source_used = "deskewed"
|
||||
|
||||
if not image_bytes:
|
||||
image_bytes = await get_session_image(req.session_id, "original")
|
||||
source_used = "original"
|
||||
|
||||
if not image_bytes:
|
||||
raise HTTPException(status_code=404, detail="No image available in session")
|
||||
|
||||
result = await _do_recognize(image_bytes, model=req.model)
|
||||
result["session_id"] = req.session_id
|
||||
result["source_image"] = source_used
|
||||
return result
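# Example client call (a sketch -- host, port and file name are assumptions, not part
# of this module):
#
#     import httpx
#
#     with open("answer_sheet.jpg", "rb") as f:
#         resp = httpx.post(
#             "http://localhost:8000/api/v1/htr/recognize",
#             params={"model": "auto", "language": "de"},
#             files={"file": ("answer_sheet.jpg", f, "image/jpeg")},
#             timeout=180.0,
#         )
#     print(resp.json()["text"])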
|
||||
7
klausur-service/backend/ocr/spell/__init__.py
Normal file
@@ -0,0 +1,7 @@
|
||||
"""
|
||||
OCR spell-checking sub-package — language-aware OCR correction.
|
||||
|
||||
Moved from backend/ flat modules (smart_spell*.py).
|
||||
Backward-compatible shim files remain at the old locations.
|
||||
"""
|
||||
from .smart_spell import * # noqa: F401,F403
|
||||
298
klausur-service/backend/ocr/spell/core.py
Normal file
@@ -0,0 +1,298 @@
|
||||
"""
|
||||
SmartSpellChecker Core — init, data types, language detection, word correction.
|
||||
|
||||
Extracted from smart_spell.py for modularity.
|
||||
|
||||
License: Apache 2.0 (commercially usable)
|
||||
"""
|
||||
|
||||
import logging
|
||||
import re
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Dict, List, Literal, Optional, Set, Tuple
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Init
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
try:
|
||||
from spellchecker import SpellChecker as _SpellChecker
|
||||
_en_spell = _SpellChecker(language='en', distance=1)
|
||||
_de_spell = _SpellChecker(language='de', distance=1)
|
||||
_AVAILABLE = True
|
||||
except ImportError:
|
||||
_AVAILABLE = False
|
||||
logger.warning("pyspellchecker not installed — SmartSpellChecker disabled")
|
||||
|
||||
Lang = Literal["en", "de", "both", "unknown"]
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Bigram context for a/I disambiguation
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
# Words that commonly follow "I" (subject pronoun -> verb/modal)
|
||||
_I_FOLLOWERS: frozenset = frozenset({
|
||||
"am", "was", "have", "had", "do", "did", "will", "would", "can",
|
||||
"could", "should", "shall", "may", "might", "must",
|
||||
"think", "know", "see", "want", "need", "like", "love", "hate",
|
||||
"go", "went", "come", "came", "say", "said", "get", "got",
|
||||
"make", "made", "take", "took", "give", "gave", "tell", "told",
|
||||
"feel", "felt", "find", "found", "believe", "hope", "wish",
|
||||
"remember", "forget", "understand", "mean", "meant",
|
||||
"don't", "didn't", "can't", "won't", "couldn't", "wouldn't",
|
||||
"shouldn't", "haven't", "hadn't", "isn't", "wasn't",
|
||||
"really", "just", "also", "always", "never", "often", "sometimes",
|
||||
})
|
||||
|
||||
# Words that commonly follow "a" (article -> noun/adjective)
|
||||
_A_FOLLOWERS: frozenset = frozenset({
|
||||
"lot", "few", "little", "bit", "good", "bad", "great", "new", "old",
|
||||
"long", "short", "big", "small", "large", "huge", "tiny",
|
||||
"nice", "beautiful", "wonderful", "terrible", "horrible",
|
||||
"man", "woman", "boy", "girl", "child", "dog", "cat", "bird",
|
||||
"book", "car", "house", "room", "school", "teacher", "student",
|
||||
"day", "week", "month", "year", "time", "place", "way",
|
||||
"friend", "family", "person", "problem", "question", "story",
|
||||
"very", "really", "quite", "rather", "pretty", "single",
|
||||
})
|
||||
|
||||
# Digit->letter substitutions (OCR confusion)
|
||||
_DIGIT_SUBS: Dict[str, List[str]] = {
|
||||
'0': ['o', 'O'],
|
||||
'1': ['l', 'I'],
|
||||
'5': ['s', 'S'],
|
||||
'6': ['g', 'G'],
|
||||
'8': ['b', 'B'],
|
||||
'|': ['I', 'l'],
|
||||
'/': ['l'], # italic 'l' misread as slash (e.g. "p/" -> "pl")
|
||||
}
|
||||
_SUSPICIOUS_CHARS = frozenset(_DIGIT_SUBS.keys())
|
||||
|
||||
# Umlaut confusion: OCR drops the dots (ü -> u, ä -> a, ö -> o)
|
||||
_UMLAUT_MAP = {
|
||||
'a': '\u00e4', 'o': '\u00f6', 'u': '\u00fc', 'i': '\u00fc',
|
||||
'A': '\u00c4', 'O': '\u00d6', 'U': '\u00dc', 'I': '\u00dc',
|
||||
}
|
||||
|
||||
# Tokenizer -- includes | and / so OCR artifacts like "p/" are treated as words
|
||||
_TOKEN_RE = re.compile(r"([A-Za-z\u00c4\u00d6\u00dc\u00e4\u00f6\u00fc\u00df'|/]+)([^A-Za-z\u00c4\u00d6\u00dc\u00e4\u00f6\u00fc\u00df'|/]*)")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Data types
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@dataclass
|
||||
class CorrectionResult:
|
||||
original: str
|
||||
corrected: str
|
||||
lang_detected: Lang
|
||||
changed: bool
|
||||
changes: List[str] = field(default_factory=list)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Core class — language detection and word-level correction
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class _SmartSpellCoreBase:
|
||||
"""Base class with language detection and single-word correction.
|
||||
|
||||
Not intended for direct use — SmartSpellChecker inherits from this.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
if not _AVAILABLE:
|
||||
raise RuntimeError("pyspellchecker not installed")
|
||||
self.en = _en_spell
|
||||
self.de = _de_spell
|
||||
|
||||
# --- Language detection ---
|
||||
|
||||
def detect_word_lang(self, word: str) -> Lang:
|
||||
"""Detect language of a single word using dual-dict heuristic."""
|
||||
w = word.lower().strip(".,;:!?\"'()")
|
||||
if not w:
|
||||
return "unknown"
|
||||
in_en = bool(self.en.known([w]))
|
||||
in_de = bool(self.de.known([w]))
|
||||
if in_en and in_de:
|
||||
return "both"
|
||||
if in_en:
|
||||
return "en"
|
||||
if in_de:
|
||||
return "de"
|
||||
return "unknown"
|
||||
|
||||
def detect_text_lang(self, text: str) -> Lang:
|
||||
"""Detect dominant language of a text string (sentence/phrase)."""
|
||||
words = re.findall(r"[A-Za-z\u00c4\u00d6\u00dc\u00e4\u00f6\u00fc\u00df]+", text)
|
||||
if not words:
|
||||
return "unknown"
|
||||
|
||||
en_count = 0
|
||||
de_count = 0
|
||||
for w in words:
|
||||
lang = self.detect_word_lang(w)
|
||||
if lang == "en":
|
||||
en_count += 1
|
||||
elif lang == "de":
|
||||
de_count += 1
|
||||
# "both" doesn't count for either
|
||||
|
||||
if en_count > de_count:
|
||||
return "en"
|
||||
if de_count > en_count:
|
||||
return "de"
|
||||
if en_count == de_count and en_count > 0:
|
||||
return "both"
|
||||
return "unknown"
|
||||
|
||||
# --- Single-word correction ---
|
||||
|
||||
def _known(self, word: str) -> bool:
|
||||
"""True if word is known in EN or DE dictionary, or is a known abbreviation."""
|
||||
w = word.lower()
|
||||
if bool(self.en.known([w])) or bool(self.de.known([w])):
|
||||
return True
|
||||
# Also accept known abbreviations (sth, sb, adj, etc.)
|
||||
try:
|
||||
from cv_ocr_engines import _KNOWN_ABBREVIATIONS
|
||||
if w in _KNOWN_ABBREVIATIONS:
|
||||
return True
|
||||
except ImportError:
|
||||
pass
|
||||
return False
|
||||
|
||||
def _word_freq(self, word: str) -> float:
|
||||
"""Get word frequency (max of EN and DE)."""
|
||||
w = word.lower()
|
||||
return max(self.en.word_usage_frequency(w), self.de.word_usage_frequency(w))
|
||||
|
||||
def _known_in(self, word: str, lang: str) -> bool:
|
||||
"""True if word is known in a specific language dictionary."""
|
||||
w = word.lower()
|
||||
spell = self.en if lang == "en" else self.de
|
||||
return bool(spell.known([w]))
|
||||
|
||||
def correct_word(self, word: str, lang: str = "en",
|
||||
prev_word: str = "", next_word: str = "") -> Optional[str]:
|
||||
"""Correct a single word for the given language.
|
||||
|
||||
Returns None if no correction needed, or the corrected string.
|
||||
"""
|
||||
if not word or not word.strip():
|
||||
return None
|
||||
|
||||
# Skip numbers, abbreviations with dots, very short tokens
|
||||
if word.isdigit() or '.' in word:
|
||||
return None
|
||||
|
||||
# Skip IPA/phonetic content in brackets
|
||||
if '[' in word or ']' in word:
|
||||
return None
|
||||
|
||||
has_suspicious = any(ch in _SUSPICIOUS_CHARS for ch in word)
|
||||
|
||||
# 1. Already known -> no fix
|
||||
if self._known(word):
|
||||
# But check a/I disambiguation for single-char words
|
||||
if word.lower() in ('l', '|') and next_word:
|
||||
return self._disambiguate_a_I(word, next_word)
|
||||
return None
|
||||
|
||||
# 2. Digit/pipe substitution
|
||||
if has_suspicious:
|
||||
if word == '|':
|
||||
return 'I'
|
||||
# Try single-char substitutions
|
||||
for i, ch in enumerate(word):
|
||||
if ch not in _DIGIT_SUBS:
|
||||
continue
|
||||
for replacement in _DIGIT_SUBS[ch]:
|
||||
candidate = word[:i] + replacement + word[i + 1:]
|
||||
if self._known(candidate):
|
||||
return candidate
|
||||
# Try multi-char substitution (e.g., "sch00l" -> "school")
|
||||
multi = self._try_multi_digit_sub(word)
|
||||
if multi:
|
||||
return multi
|
||||
|
||||
# 3. Umlaut correction (German)
|
||||
if lang == "de" and len(word) >= 3 and word.isalpha():
|
||||
umlaut_fix = self._try_umlaut_fix(word)
|
||||
if umlaut_fix:
|
||||
return umlaut_fix
|
||||
|
||||
# 4. General spell correction
|
||||
if not has_suspicious and len(word) >= 3 and word.isalpha():
|
||||
# Safety: don't correct if the word is valid in the OTHER language
|
||||
other_lang = "de" if lang == "en" else "en"
|
||||
if self._known_in(word, other_lang):
|
||||
return None
|
||||
if other_lang == "de" and self._try_umlaut_fix(word):
|
||||
return None # has a valid DE umlaut variant -> don't touch
|
||||
|
||||
spell = self.en if lang == "en" else self.de
|
||||
correction = spell.correction(word.lower())
|
||||
if correction and correction != word.lower():
|
||||
if word[0].isupper():
|
||||
correction = correction[0].upper() + correction[1:]
|
||||
if self._known(correction):
|
||||
return correction
|
||||
|
||||
return None
|
||||
|
||||
# --- Multi-digit substitution ---
|
||||
|
||||
def _try_multi_digit_sub(self, word: str) -> Optional[str]:
|
||||
"""Try replacing multiple digits simultaneously using BFS."""
|
||||
positions = [(i, ch) for i, ch in enumerate(word) if ch in _DIGIT_SUBS]
|
||||
if not positions or len(positions) > 4:
|
||||
return None
|
||||
|
||||
# BFS over substitution combinations
|
||||
queue = [list(word)]
|
||||
for pos, ch in positions:
|
||||
next_queue = []
|
||||
for current in queue:
|
||||
# Keep original
|
||||
next_queue.append(current[:])
|
||||
# Try each substitution
|
||||
for repl in _DIGIT_SUBS[ch]:
|
||||
variant = current[:]
|
||||
variant[pos] = repl
|
||||
next_queue.append(variant)
|
||||
queue = next_queue
|
||||
|
||||
# Check which combinations produce known words
|
||||
for combo in queue:
|
||||
candidate = "".join(combo)
|
||||
if candidate != word and self._known(candidate):
|
||||
return candidate
|
||||
|
||||
return None
|
||||
|
||||
# --- Umlaut fix ---
|
||||
|
||||
def _try_umlaut_fix(self, word: str) -> Optional[str]:
|
||||
"""Try single-char umlaut substitutions for German words."""
|
||||
for i, ch in enumerate(word):
|
||||
if ch in _UMLAUT_MAP:
|
||||
candidate = word[:i] + _UMLAUT_MAP[ch] + word[i + 1:]
|
||||
if self._known(candidate):
|
||||
return candidate
|
||||
return None
|
||||
|
||||
# --- a/I disambiguation ---
|
||||
|
||||
def _disambiguate_a_I(self, token: str, next_word: str) -> Optional[str]:
|
||||
"""Disambiguate 'a' vs 'I' (and OCR variants like 'l', '|')."""
|
||||
nw = next_word.lower().strip(".,;:!?")
|
||||
if nw in _I_FOLLOWERS:
|
||||
return "I"
|
||||
if nw in _A_FOLLOWERS:
|
||||
return "a"
|
||||
return None # uncertain, don't change
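    # Example: in the OCR output "l think so" the stray "l" is followed by "think",
    # which is in _I_FOLLOWERS, so it is read as "I"; in "l lot of fun" the follower
    # "lot" is in _A_FOLLOWERS and the token is read as "a" instead.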
|
||||
25
klausur-service/backend/ocr/spell/smart_spell.py
Normal file
@@ -0,0 +1,25 @@
|
||||
"""
|
||||
SmartSpellChecker — barrel re-export.
|
||||
|
||||
All implementation split into:
|
||||
smart_spell_core — init, data types, language detection, word correction
|
||||
smart_spell_text — full text correction, boundary repair, context split
|
||||
|
||||
Lizenz: Apache 2.0 (kommerziell nutzbar)
|
||||
"""
|
||||
|
||||
# Core: data types, lang detection (re-exported for tests)
|
||||
from .core import ( # noqa: F401
|
||||
_AVAILABLE,
|
||||
_DIGIT_SUBS,
|
||||
_SUSPICIOUS_CHARS,
|
||||
_UMLAUT_MAP,
|
||||
_TOKEN_RE,
|
||||
_I_FOLLOWERS,
|
||||
_A_FOLLOWERS,
|
||||
CorrectionResult,
|
||||
Lang,
|
||||
)
|
||||
|
||||
# Text: SmartSpellChecker class (the main public API)
|
||||
from .text import SmartSpellChecker # noqa: F401
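# Example usage (a minimal sketch -- assumes the backend package root is on sys.path
# and the pyspellchecker EN/DE dictionaries are installed):
#
#     from ocr.spell import SmartSpellChecker
#
#     checker = SmartSpellChecker()
#     checker.correct_word("sch00l", lang="en")        # digit substitution -> "school"
#     checker.detect_text_lang("the quick brown fox")  # -> "en"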
289
klausur-service/backend/ocr/spell/text.py
Normal file
@@ -0,0 +1,289 @@
"""
|
||||
SmartSpellChecker Text — full text correction, boundary repair, context split.
|
||||
|
||||
Extracted from smart_spell.py for modularity.
|
||||
|
||||
Lizenz: Apache 2.0 (kommerziell nutzbar)
|
||||
"""
|
||||
|
||||
import re
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
from .core import (
|
||||
_SmartSpellCoreBase,
|
||||
_TOKEN_RE,
|
||||
CorrectionResult,
|
||||
Lang,
|
||||
)
|
||||
|
||||
|
||||
class SmartSpellChecker(_SmartSpellCoreBase):
|
||||
"""Language-aware OCR spell checker using pyspellchecker (no LLM).
|
||||
|
||||
Inherits single-word correction from _SmartSpellCoreBase.
|
||||
Adds text-level passes: boundary repair, context split, full correction.
|
||||
"""
|
||||
|
||||
# --- Boundary repair (shifted word boundaries) ---
|
||||
|
||||
def _try_boundary_repair(self, word1: str, word2: str) -> Optional[Tuple[str, str]]:
|
||||
"""Fix shifted word boundaries between adjacent tokens.
|
||||
|
||||
OCR sometimes shifts the boundary: "at sth." -> "ats th."
|
||||
Try moving 1-2 chars from end of word1 to start of word2 and vice versa.
|
||||
Returns (fixed_word1, fixed_word2) or None.
|
||||
"""
|
||||
# Import known abbreviations for vocabulary context
|
||||
try:
|
||||
from cv_ocr_engines import _KNOWN_ABBREVIATIONS
|
||||
except ImportError:
|
||||
_KNOWN_ABBREVIATIONS = set()
|
||||
|
||||
# Strip trailing punctuation for checking, preserve for result
|
||||
w2_stripped = word2.rstrip(".,;:!?")
|
||||
w2_punct = word2[len(w2_stripped):]
|
||||
|
||||
# Try shifting 1-2 chars from word1 -> word2
|
||||
for shift in (1, 2):
|
||||
if len(word1) <= shift:
|
||||
continue
|
||||
new_w1 = word1[:-shift]
|
||||
new_w2_base = word1[-shift:] + w2_stripped
|
||||
|
||||
w1_ok = self._known(new_w1) or new_w1.lower() in _KNOWN_ABBREVIATIONS
|
||||
w2_ok = self._known(new_w2_base) or new_w2_base.lower() in _KNOWN_ABBREVIATIONS
|
||||
|
||||
if w1_ok and w2_ok:
|
||||
return (new_w1, new_w2_base + w2_punct)
|
||||
|
||||
# Try shifting 1-2 chars from word2 -> word1
|
||||
for shift in (1, 2):
|
||||
if len(w2_stripped) <= shift:
|
||||
continue
|
||||
new_w1 = word1 + w2_stripped[:shift]
|
||||
new_w2_base = w2_stripped[shift:]
|
||||
|
||||
w1_ok = self._known(new_w1) or new_w1.lower() in _KNOWN_ABBREVIATIONS
|
||||
w2_ok = self._known(new_w2_base) or new_w2_base.lower() in _KNOWN_ABBREVIATIONS
|
||||
|
||||
if w1_ok and w2_ok:
|
||||
return (new_w1, new_w2_base + w2_punct)
|
||||
|
||||
return None
|
||||
|
||||
# --- Context-based word split for ambiguous merges ---
|
||||
|
||||
# Patterns where a valid word is actually "a" + adjective/noun
|
||||
_ARTICLE_SPLIT_CANDIDATES = {
|
||||
# word -> (article, remainder) -- only when followed by a compatible word
|
||||
"anew": ("a", "new"),
|
||||
"areal": ("a", "real"),
|
||||
"alive": None, # genuinely one word, never split
|
||||
"alone": None,
|
||||
"aware": None,
|
||||
"alike": None,
|
||||
"apart": None,
|
||||
"aside": None,
|
||||
"above": None,
|
||||
"about": None,
|
||||
"among": None,
|
||||
"along": None,
|
||||
}
|
||||
|
||||
def _try_context_split(self, word: str, next_word: str,
|
||||
prev_word: str) -> Optional[str]:
|
||||
"""Split words like 'anew' -> 'a new' when context indicates a merge.
|
||||
|
||||
Only splits when:
|
||||
- The word is in the split candidates list
|
||||
- The following word makes sense as a noun (for "a + adj + noun" pattern)
|
||||
- OR the word is unknown and can be split into article + known word
|
||||
"""
|
||||
w_lower = word.lower()
|
||||
|
||||
# Check explicit candidates
|
||||
if w_lower in self._ARTICLE_SPLIT_CANDIDATES:
|
||||
split = self._ARTICLE_SPLIT_CANDIDATES[w_lower]
|
||||
if split is None:
|
||||
return None # explicitly marked as "don't split"
|
||||
article, remainder = split
|
||||
# Only split if followed by a word (noun pattern)
|
||||
if next_word and next_word[0].islower():
|
||||
return f"{article} {remainder}"
|
||||
# Also split if remainder + next_word makes a common phrase
|
||||
if next_word and self._known(next_word):
|
||||
return f"{article} {remainder}"
|
||||
|
||||
# Generic: if word starts with 'a' and rest is a known adjective/word
|
||||
if (len(word) >= 4 and word[0].lower() == 'a'
|
||||
and not self._known(word) # only for UNKNOWN words
|
||||
and self._known(word[1:])):
|
||||
return f"a {word[1:]}"
|
||||
|
||||
return None
|
||||
|
||||
# --- Full text correction ---
|
||||
|
||||
def correct_text(self, text: str, lang: str = "en") -> CorrectionResult:
|
||||
"""Correct a full text string (field value).
|
||||
|
||||
Three passes:
|
||||
1. Boundary repair -- fix shifted word boundaries between adjacent tokens
|
||||
2. Context split -- split ambiguous merges (anew -> a new)
|
||||
3. Per-word correction -- spell check individual words
|
||||
"""
|
||||
if not text or not text.strip():
|
||||
return CorrectionResult(text, text, "unknown", False)
|
||||
|
||||
detected = self.detect_text_lang(text) if lang == "auto" else lang
|
||||
effective_lang = detected if detected in ("en", "de") else "en"
|
||||
|
||||
changes: List[str] = []
|
||||
tokens = list(_TOKEN_RE.finditer(text))
|
||||
|
||||
# Extract token list: [(word, separator), ...]
|
||||
token_list: List[List[str]] = [] # [[word, sep], ...]
|
||||
for m in tokens:
|
||||
token_list.append([m.group(1), m.group(2)])
|
||||
|
||||
# --- Pass 1: Boundary repair between adjacent unknown words ---
|
||||
# Import abbreviations for the heuristic below
|
||||
try:
|
||||
from cv_ocr_engines import _KNOWN_ABBREVIATIONS as _ABBREVS
|
||||
except ImportError:
|
||||
_ABBREVS = set()
|
||||
|
||||
for i in range(len(token_list) - 1):
|
||||
w1 = token_list[i][0]
|
||||
w2_raw = token_list[i + 1][0]
|
||||
|
||||
# Skip boundary repair for IPA/bracket content
|
||||
# Brackets may be in the token OR in the adjacent separators
|
||||
sep_before_w1 = token_list[i - 1][1] if i > 0 else ""
|
||||
sep_after_w1 = token_list[i][1]
|
||||
sep_after_w2 = token_list[i + 1][1]
|
||||
has_bracket = (
|
||||
'[' in w1 or ']' in w1 or '[' in w2_raw or ']' in w2_raw
|
||||
or ']' in sep_after_w1 # w1 text was inside [brackets]
|
||||
or '[' in sep_after_w1 # w2 starts a bracket
|
||||
or ']' in sep_after_w2 # w2 text was inside [brackets]
|
||||
or '[' in sep_before_w1 # w1 starts a bracket
|
||||
)
|
||||
if has_bracket:
|
||||
continue
|
||||
|
||||
# Include trailing punct from separator in w2 for abbreviation matching
|
||||
w2_with_punct = w2_raw + token_list[i + 1][1].rstrip(" ")
|
||||
|
||||
# Try boundary repair -- always, even if both words are valid.
|
||||
# Use word-frequency scoring to decide if repair is better.
|
||||
repair = self._try_boundary_repair(w1, w2_with_punct)
|
||||
if not repair and w2_with_punct != w2_raw:
|
||||
repair = self._try_boundary_repair(w1, w2_raw)
|
||||
if repair:
|
||||
new_w1, new_w2_full = repair
|
||||
new_w2_base = new_w2_full.rstrip(".,;:!?")
|
||||
|
||||
# Frequency-based scoring: product of word frequencies
|
||||
# Higher product = more common word pair = better
|
||||
old_freq = self._word_freq(w1) * self._word_freq(w2_raw)
|
||||
new_freq = self._word_freq(new_w1) * self._word_freq(new_w2_base)
|
||||
|
||||
# Abbreviation bonus: if repair produces a known abbreviation
|
||||
has_abbrev = new_w1.lower() in _ABBREVS or new_w2_base.lower() in _ABBREVS
|
||||
if has_abbrev:
|
||||
# Accept abbreviation repair ONLY if at least one of the
|
||||
# original words is rare/unknown (prevents "Can I" -> "Ca nI"
|
||||
# where both original words are common and correct).
|
||||
RARE_THRESHOLD = 1e-6
|
||||
orig_both_common = (
|
||||
self._word_freq(w1) > RARE_THRESHOLD
|
||||
and self._word_freq(w2_raw) > RARE_THRESHOLD
|
||||
)
|
||||
if not orig_both_common:
|
||||
new_freq = max(new_freq, old_freq * 10)
|
||||
else:
|
||||
has_abbrev = False # both originals common -> don't trust
|
||||
|
||||
# Accept if repair produces a more frequent word pair
|
||||
# (threshold: at least 5x more frequent to avoid false positives)
|
||||
if new_freq > old_freq * 5:
|
||||
new_w2_punct = new_w2_full[len(new_w2_base):]
|
||||
changes.append(f"{w1} {w2_raw}\u2192{new_w1} {new_w2_base}")
|
||||
token_list[i][0] = new_w1
|
||||
token_list[i + 1][0] = new_w2_base
|
||||
if new_w2_punct:
|
||||
token_list[i + 1][1] = new_w2_punct + token_list[i + 1][1].lstrip(".,;:!?")
|
||||
|
||||
# --- Pass 2: Context split (anew -> a new) ---
|
||||
expanded: List[List[str]] = []
|
||||
for i, (word, sep) in enumerate(token_list):
|
||||
next_word = token_list[i + 1][0] if i + 1 < len(token_list) else ""
|
||||
prev_word = token_list[i - 1][0] if i > 0 else ""
|
||||
split = self._try_context_split(word, next_word, prev_word)
|
||||
if split and split != word:
|
||||
changes.append(f"{word}\u2192{split}")
|
||||
expanded.append([split, sep])
|
||||
else:
|
||||
expanded.append([word, sep])
|
||||
token_list = expanded
|
||||
|
||||
# --- Pass 3: Per-word correction ---
|
||||
parts: List[str] = []
|
||||
|
||||
# Preserve any leading text before the first token match
|
||||
first_start = tokens[0].start() if tokens else 0
|
||||
if first_start > 0:
|
||||
parts.append(text[:first_start])
|
||||
|
||||
for i, (word, sep) in enumerate(token_list):
|
||||
# Skip words inside IPA brackets (brackets land in separators)
|
||||
prev_sep = token_list[i - 1][1] if i > 0 else ""
|
||||
if '[' in prev_sep or ']' in sep:
|
||||
parts.append(word)
|
||||
parts.append(sep)
|
||||
continue
|
||||
|
||||
next_word = token_list[i + 1][0] if i + 1 < len(token_list) else ""
|
||||
prev_word = token_list[i - 1][0] if i > 0 else ""
|
||||
|
||||
correction = self.correct_word(
|
||||
word, lang=effective_lang,
|
||||
prev_word=prev_word, next_word=next_word,
|
||||
)
|
||||
if correction and correction != word:
|
||||
changes.append(f"{word}\u2192{correction}")
|
||||
parts.append(correction)
|
||||
else:
|
||||
parts.append(word)
|
||||
parts.append(sep)
|
||||
|
||||
# Append any trailing text
|
||||
last_end = tokens[-1].end() if tokens else 0
|
||||
if last_end < len(text):
|
||||
parts.append(text[last_end:])
|
||||
|
||||
corrected = "".join(parts)
|
||||
return CorrectionResult(
|
||||
original=text,
|
||||
corrected=corrected,
|
||||
lang_detected=detected,
|
||||
changed=corrected != text,
|
||||
changes=changes,
|
||||
)
|
||||
|
||||
# --- Vocabulary entry correction ---
|
||||
|
||||
def correct_vocab_entry(self, english: str, german: str,
|
||||
example: str = "") -> Dict[str, CorrectionResult]:
|
||||
"""Correct a full vocabulary entry (EN + DE + example).
|
||||
|
||||
Uses column position to determine language -- the most reliable signal.
|
||||
"""
|
||||
results = {}
|
||||
results["english"] = self.correct_text(english, lang="en")
|
||||
results["german"] = self.correct_text(german, lang="de")
|
||||
if example:
|
||||
# For examples, auto-detect language
|
||||
results["example"] = self.correct_text(example, lang="auto")
|
||||
return results
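A minimal sketch of the text-level API defined above (illustrative; assumes pyspellchecker is installed, and the exact corrections depend on its dictionaries):

    from ocr.spell.smart_spell import SmartSpellChecker

    checker = SmartSpellChecker()
    result = checker.correct_text("| think this is niice", lang="en")
    print(result.corrected)   # likely "I think this is nice"
    print(result.changes)     # e.g. ["|→I", "niice→nice"]

    entry = checker.correct_vocab_entry(english="to struggle", german="kampfen",
                                        example="They struggle a lot.")
    print(entry["german"].corrected)  # likely "kämpfen" via the umlaut fix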
klausur-service/backend/smart_spell.py
@@ -1,25 +1,4 @@
"""
|
||||
SmartSpellChecker — barrel re-export.
|
||||
|
||||
All implementation split into:
|
||||
smart_spell_core — init, data types, language detection, word correction
|
||||
smart_spell_text — full text correction, boundary repair, context split
|
||||
|
||||
Lizenz: Apache 2.0 (kommerziell nutzbar)
|
||||
"""
|
||||
|
||||
# Core: data types, lang detection (re-exported for tests)
|
||||
from smart_spell_core import ( # noqa: F401
|
||||
_AVAILABLE,
|
||||
_DIGIT_SUBS,
|
||||
_SUSPICIOUS_CHARS,
|
||||
_UMLAUT_MAP,
|
||||
_TOKEN_RE,
|
||||
_I_FOLLOWERS,
|
||||
_A_FOLLOWERS,
|
||||
CorrectionResult,
|
||||
Lang,
|
||||
)
|
||||
|
||||
# Text: SmartSpellChecker class (the main public API)
|
||||
from smart_spell_text import SmartSpellChecker # noqa: F401
# Backward-compat shim -- module moved to ocr/spell/smart_spell.py
import importlib as _importlib
import sys as _sys
_sys.modules[__name__] = _importlib.import_module("ocr.spell.smart_spell")
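The shim works by aliasing the legacy module name to the relocated module at import time; a rough sketch of the effect (illustrative; assumes the backend root is on sys.path so both smart_spell and ocr.spell.smart_spell are importable):

    import sys
    import smart_spell  # executing the shim swaps the sys.modules entry

    assert sys.modules["smart_spell"] is sys.modules["ocr.spell.smart_spell"]
    from smart_spell import SmartSpellChecker  # legacy import keeps working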
|
||||
klausur-service/backend/smart_spell_core.py
@@ -1,298 +1,4 @@
"""
|
||||
SmartSpellChecker Core — init, data types, language detection, word correction.
|
||||
|
||||
Extracted from smart_spell.py for modularity.
|
||||
|
||||
Lizenz: Apache 2.0 (kommerziell nutzbar)
|
||||
"""
|
||||
|
||||
import logging
|
||||
import re
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Dict, List, Literal, Optional, Set, Tuple
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Init
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
try:
|
||||
from spellchecker import SpellChecker as _SpellChecker
|
||||
_en_spell = _SpellChecker(language='en', distance=1)
|
||||
_de_spell = _SpellChecker(language='de', distance=1)
|
||||
_AVAILABLE = True
|
||||
except ImportError:
|
||||
_AVAILABLE = False
|
||||
logger.warning("pyspellchecker not installed — SmartSpellChecker disabled")
|
||||
|
||||
Lang = Literal["en", "de", "both", "unknown"]
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Bigram context for a/I disambiguation
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
# Words that commonly follow "I" (subject pronoun -> verb/modal)
|
||||
_I_FOLLOWERS: frozenset = frozenset({
|
||||
"am", "was", "have", "had", "do", "did", "will", "would", "can",
|
||||
"could", "should", "shall", "may", "might", "must",
|
||||
"think", "know", "see", "want", "need", "like", "love", "hate",
|
||||
"go", "went", "come", "came", "say", "said", "get", "got",
|
||||
"make", "made", "take", "took", "give", "gave", "tell", "told",
|
||||
"feel", "felt", "find", "found", "believe", "hope", "wish",
|
||||
"remember", "forget", "understand", "mean", "meant",
|
||||
"don't", "didn't", "can't", "won't", "couldn't", "wouldn't",
|
||||
"shouldn't", "haven't", "hadn't", "isn't", "wasn't",
|
||||
"really", "just", "also", "always", "never", "often", "sometimes",
|
||||
})
|
||||
|
||||
# Words that commonly follow "a" (article -> noun/adjective)
|
||||
_A_FOLLOWERS: frozenset = frozenset({
|
||||
"lot", "few", "little", "bit", "good", "bad", "great", "new", "old",
|
||||
"long", "short", "big", "small", "large", "huge", "tiny",
|
||||
"nice", "beautiful", "wonderful", "terrible", "horrible",
|
||||
"man", "woman", "boy", "girl", "child", "dog", "cat", "bird",
|
||||
"book", "car", "house", "room", "school", "teacher", "student",
|
||||
"day", "week", "month", "year", "time", "place", "way",
|
||||
"friend", "family", "person", "problem", "question", "story",
|
||||
"very", "really", "quite", "rather", "pretty", "single",
|
||||
})
|
||||
|
||||
# Digit->letter substitutions (OCR confusion)
|
||||
_DIGIT_SUBS: Dict[str, List[str]] = {
|
||||
'0': ['o', 'O'],
|
||||
'1': ['l', 'I'],
|
||||
'5': ['s', 'S'],
|
||||
'6': ['g', 'G'],
|
||||
'8': ['b', 'B'],
|
||||
'|': ['I', 'l'],
|
||||
'/': ['l'], # italic 'l' misread as slash (e.g. "p/" -> "pl")
|
||||
}
|
||||
_SUSPICIOUS_CHARS = frozenset(_DIGIT_SUBS.keys())
|
||||
|
||||
# Umlaut confusion: OCR drops dots (ä->a, ö->o, ü->u)
|
||||
_UMLAUT_MAP = {
|
||||
'a': '\u00e4', 'o': '\u00f6', 'u': '\u00fc', 'i': '\u00fc',
|
||||
'A': '\u00c4', 'O': '\u00d6', 'U': '\u00dc', 'I': '\u00dc',
|
||||
}
|
||||
|
||||
# Tokenizer -- includes | and / so OCR artifacts like "p/" are treated as words
|
||||
_TOKEN_RE = re.compile(r"([A-Za-z\u00c4\u00d6\u00dc\u00e4\u00f6\u00fc\u00df'|/]+)([^A-Za-z\u00c4\u00d6\u00dc\u00e4\u00f6\u00fc\u00df'|/]*)")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Data types
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@dataclass
|
||||
class CorrectionResult:
|
||||
original: str
|
||||
corrected: str
|
||||
lang_detected: Lang
|
||||
changed: bool
|
||||
changes: List[str] = field(default_factory=list)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Core class — language detection and word-level correction
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class _SmartSpellCoreBase:
|
||||
"""Base class with language detection and single-word correction.
|
||||
|
||||
Not intended for direct use — SmartSpellChecker inherits from this.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
if not _AVAILABLE:
|
||||
raise RuntimeError("pyspellchecker not installed")
|
||||
self.en = _en_spell
|
||||
self.de = _de_spell
|
||||
|
||||
# --- Language detection ---
|
||||
|
||||
def detect_word_lang(self, word: str) -> Lang:
|
||||
"""Detect language of a single word using dual-dict heuristic."""
|
||||
w = word.lower().strip(".,;:!?\"'()")
|
||||
if not w:
|
||||
return "unknown"
|
||||
in_en = bool(self.en.known([w]))
|
||||
in_de = bool(self.de.known([w]))
|
||||
if in_en and in_de:
|
||||
return "both"
|
||||
if in_en:
|
||||
return "en"
|
||||
if in_de:
|
||||
return "de"
|
||||
return "unknown"
|
||||
|
||||
def detect_text_lang(self, text: str) -> Lang:
|
||||
"""Detect dominant language of a text string (sentence/phrase)."""
|
||||
words = re.findall(r"[A-Za-z\u00c4\u00d6\u00dc\u00e4\u00f6\u00fc\u00df]+", text)
|
||||
if not words:
|
||||
return "unknown"
|
||||
|
||||
en_count = 0
|
||||
de_count = 0
|
||||
for w in words:
|
||||
lang = self.detect_word_lang(w)
|
||||
if lang == "en":
|
||||
en_count += 1
|
||||
elif lang == "de":
|
||||
de_count += 1
|
||||
# "both" doesn't count for either
|
||||
|
||||
if en_count > de_count:
|
||||
return "en"
|
||||
if de_count > en_count:
|
||||
return "de"
|
||||
if en_count == de_count and en_count > 0:
|
||||
return "both"
|
||||
return "unknown"
|
||||
|
||||
# --- Single-word correction ---
|
||||
|
||||
def _known(self, word: str) -> bool:
|
||||
"""True if word is known in EN or DE dictionary, or is a known abbreviation."""
|
||||
w = word.lower()
|
||||
if bool(self.en.known([w])) or bool(self.de.known([w])):
|
||||
return True
|
||||
# Also accept known abbreviations (sth, sb, adj, etc.)
|
||||
try:
|
||||
from cv_ocr_engines import _KNOWN_ABBREVIATIONS
|
||||
if w in _KNOWN_ABBREVIATIONS:
|
||||
return True
|
||||
except ImportError:
|
||||
pass
|
||||
return False
|
||||
|
||||
def _word_freq(self, word: str) -> float:
|
||||
"""Get word frequency (max of EN and DE)."""
|
||||
w = word.lower()
|
||||
return max(self.en.word_usage_frequency(w), self.de.word_usage_frequency(w))
|
||||
|
||||
def _known_in(self, word: str, lang: str) -> bool:
|
||||
"""True if word is known in a specific language dictionary."""
|
||||
w = word.lower()
|
||||
spell = self.en if lang == "en" else self.de
|
||||
return bool(spell.known([w]))
|
||||
|
||||
def correct_word(self, word: str, lang: str = "en",
|
||||
prev_word: str = "", next_word: str = "") -> Optional[str]:
|
||||
"""Correct a single word for the given language.
|
||||
|
||||
Returns None if no correction needed, or the corrected string.
|
||||
"""
|
||||
if not word or not word.strip():
|
||||
return None
|
||||
|
||||
# Skip numbers, abbreviations with dots, very short tokens
|
||||
if word.isdigit() or '.' in word:
|
||||
return None
|
||||
|
||||
# Skip IPA/phonetic content in brackets
|
||||
if '[' in word or ']' in word:
|
||||
return None
|
||||
|
||||
has_suspicious = any(ch in _SUSPICIOUS_CHARS for ch in word)
|
||||
|
||||
# 1. Already known -> no fix
|
||||
if self._known(word):
|
||||
# But check a/I disambiguation for single-char words
|
||||
if word.lower() in ('l', '|') and next_word:
|
||||
return self._disambiguate_a_I(word, next_word)
|
||||
return None
|
||||
|
||||
# 2. Digit/pipe substitution
|
||||
if has_suspicious:
|
||||
if word == '|':
|
||||
return 'I'
|
||||
# Try single-char substitutions
|
||||
for i, ch in enumerate(word):
|
||||
if ch not in _DIGIT_SUBS:
|
||||
continue
|
||||
for replacement in _DIGIT_SUBS[ch]:
|
||||
candidate = word[:i] + replacement + word[i + 1:]
|
||||
if self._known(candidate):
|
||||
return candidate
|
||||
# Try multi-char substitution (e.g., "sch00l" -> "school")
|
||||
multi = self._try_multi_digit_sub(word)
|
||||
if multi:
|
||||
return multi
|
||||
|
||||
# 3. Umlaut correction (German)
|
||||
if lang == "de" and len(word) >= 3 and word.isalpha():
|
||||
umlaut_fix = self._try_umlaut_fix(word)
|
||||
if umlaut_fix:
|
||||
return umlaut_fix
|
||||
|
||||
# 4. General spell correction
|
||||
if not has_suspicious and len(word) >= 3 and word.isalpha():
|
||||
# Safety: don't correct if the word is valid in the OTHER language
|
||||
other_lang = "de" if lang == "en" else "en"
|
||||
if self._known_in(word, other_lang):
|
||||
return None
|
||||
if other_lang == "de" and self._try_umlaut_fix(word):
|
||||
return None # has a valid DE umlaut variant -> don't touch
|
||||
|
||||
spell = self.en if lang == "en" else self.de
|
||||
correction = spell.correction(word.lower())
|
||||
if correction and correction != word.lower():
|
||||
if word[0].isupper():
|
||||
correction = correction[0].upper() + correction[1:]
|
||||
if self._known(correction):
|
||||
return correction
|
||||
|
||||
return None
|
||||
|
||||
# --- Multi-digit substitution ---
|
||||
|
||||
def _try_multi_digit_sub(self, word: str) -> Optional[str]:
|
||||
"""Try replacing multiple digits simultaneously using BFS."""
|
||||
positions = [(i, ch) for i, ch in enumerate(word) if ch in _DIGIT_SUBS]
|
||||
if not positions or len(positions) > 4:
|
||||
return None
|
||||
|
||||
# BFS over substitution combinations
|
||||
queue = [list(word)]
|
||||
for pos, ch in positions:
|
||||
next_queue = []
|
||||
for current in queue:
|
||||
# Keep original
|
||||
next_queue.append(current[:])
|
||||
# Try each substitution
|
||||
for repl in _DIGIT_SUBS[ch]:
|
||||
variant = current[:]
|
||||
variant[pos] = repl
|
||||
next_queue.append(variant)
|
||||
queue = next_queue
|
||||
|
||||
# Check which combinations produce known words
|
||||
for combo in queue:
|
||||
candidate = "".join(combo)
|
||||
if candidate != word and self._known(candidate):
|
||||
return candidate
|
||||
|
||||
return None
|
||||
|
||||
# --- Umlaut fix ---
|
||||
|
||||
def _try_umlaut_fix(self, word: str) -> Optional[str]:
|
||||
"""Try single-char umlaut substitutions for German words."""
|
||||
for i, ch in enumerate(word):
|
||||
if ch in _UMLAUT_MAP:
|
||||
candidate = word[:i] + _UMLAUT_MAP[ch] + word[i + 1:]
|
||||
if self._known(candidate):
|
||||
return candidate
|
||||
return None
|
||||
|
||||
# --- a/I disambiguation ---
|
||||
|
||||
def _disambiguate_a_I(self, token: str, next_word: str) -> Optional[str]:
|
||||
"""Disambiguate 'a' vs 'I' (and OCR variants like 'l', '|')."""
|
||||
nw = next_word.lower().strip(".,;:!?")
|
||||
if nw in _I_FOLLOWERS:
|
||||
return "I"
|
||||
if nw in _A_FOLLOWERS:
|
||||
return "a"
|
||||
return None # uncertain, don't change
|
||||
# Backward-compat shim -- module moved to ocr/spell/core.py
import importlib as _importlib
import sys as _sys
_sys.modules[__name__] = _importlib.import_module("ocr.spell.core")
|
||||
klausur-service/backend/smart_spell_text.py
@@ -1,289 +1,4 @@
"""
|
||||
SmartSpellChecker Text — full text correction, boundary repair, context split.
|
||||
|
||||
Extracted from smart_spell.py for modularity.
|
||||
|
||||
Lizenz: Apache 2.0 (kommerziell nutzbar)
|
||||
"""
|
||||
|
||||
import re
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
from smart_spell_core import (
|
||||
_SmartSpellCoreBase,
|
||||
_TOKEN_RE,
|
||||
CorrectionResult,
|
||||
Lang,
|
||||
)
|
||||
|
||||
|
||||
class SmartSpellChecker(_SmartSpellCoreBase):
|
||||
"""Language-aware OCR spell checker using pyspellchecker (no LLM).
|
||||
|
||||
Inherits single-word correction from _SmartSpellCoreBase.
|
||||
Adds text-level passes: boundary repair, context split, full correction.
|
||||
"""
|
||||
|
||||
# --- Boundary repair (shifted word boundaries) ---
|
||||
|
||||
def _try_boundary_repair(self, word1: str, word2: str) -> Optional[Tuple[str, str]]:
|
||||
"""Fix shifted word boundaries between adjacent tokens.
|
||||
|
||||
OCR sometimes shifts the boundary: "at sth." -> "ats th."
|
||||
Try moving 1-2 chars from end of word1 to start of word2 and vice versa.
|
||||
Returns (fixed_word1, fixed_word2) or None.
|
||||
"""
|
||||
# Import known abbreviations for vocabulary context
|
||||
try:
|
||||
from cv_ocr_engines import _KNOWN_ABBREVIATIONS
|
||||
except ImportError:
|
||||
_KNOWN_ABBREVIATIONS = set()
|
||||
|
||||
# Strip trailing punctuation for checking, preserve for result
|
||||
w2_stripped = word2.rstrip(".,;:!?")
|
||||
w2_punct = word2[len(w2_stripped):]
|
||||
|
||||
# Try shifting 1-2 chars from word1 -> word2
|
||||
for shift in (1, 2):
|
||||
if len(word1) <= shift:
|
||||
continue
|
||||
new_w1 = word1[:-shift]
|
||||
new_w2_base = word1[-shift:] + w2_stripped
|
||||
|
||||
w1_ok = self._known(new_w1) or new_w1.lower() in _KNOWN_ABBREVIATIONS
|
||||
w2_ok = self._known(new_w2_base) or new_w2_base.lower() in _KNOWN_ABBREVIATIONS
|
||||
|
||||
if w1_ok and w2_ok:
|
||||
return (new_w1, new_w2_base + w2_punct)
|
||||
|
||||
# Try shifting 1-2 chars from word2 -> word1
|
||||
for shift in (1, 2):
|
||||
if len(w2_stripped) <= shift:
|
||||
continue
|
||||
new_w1 = word1 + w2_stripped[:shift]
|
||||
new_w2_base = w2_stripped[shift:]
|
||||
|
||||
w1_ok = self._known(new_w1) or new_w1.lower() in _KNOWN_ABBREVIATIONS
|
||||
w2_ok = self._known(new_w2_base) or new_w2_base.lower() in _KNOWN_ABBREVIATIONS
|
||||
|
||||
if w1_ok and w2_ok:
|
||||
return (new_w1, new_w2_base + w2_punct)
|
||||
|
||||
return None
|
||||
|
||||
# --- Context-based word split for ambiguous merges ---
|
||||
|
||||
# Patterns where a valid word is actually "a" + adjective/noun
|
||||
_ARTICLE_SPLIT_CANDIDATES = {
|
||||
# word -> (article, remainder) -- only when followed by a compatible word
|
||||
"anew": ("a", "new"),
|
||||
"areal": ("a", "real"),
|
||||
"alive": None, # genuinely one word, never split
|
||||
"alone": None,
|
||||
"aware": None,
|
||||
"alike": None,
|
||||
"apart": None,
|
||||
"aside": None,
|
||||
"above": None,
|
||||
"about": None,
|
||||
"among": None,
|
||||
"along": None,
|
||||
}
|
||||
|
||||
def _try_context_split(self, word: str, next_word: str,
|
||||
prev_word: str) -> Optional[str]:
|
||||
"""Split words like 'anew' -> 'a new' when context indicates a merge.
|
||||
|
||||
Only splits when:
|
||||
- The word is in the split candidates list
|
||||
- The following word makes sense as a noun (for "a + adj + noun" pattern)
|
||||
- OR the word is unknown and can be split into article + known word
|
||||
"""
|
||||
w_lower = word.lower()
|
||||
|
||||
# Check explicit candidates
|
||||
if w_lower in self._ARTICLE_SPLIT_CANDIDATES:
|
||||
split = self._ARTICLE_SPLIT_CANDIDATES[w_lower]
|
||||
if split is None:
|
||||
return None # explicitly marked as "don't split"
|
||||
article, remainder = split
|
||||
# Only split if followed by a word (noun pattern)
|
||||
if next_word and next_word[0].islower():
|
||||
return f"{article} {remainder}"
|
||||
# Also split if remainder + next_word makes a common phrase
|
||||
if next_word and self._known(next_word):
|
||||
return f"{article} {remainder}"
|
||||
|
||||
# Generic: if word starts with 'a' and rest is a known adjective/word
|
||||
if (len(word) >= 4 and word[0].lower() == 'a'
|
||||
and not self._known(word) # only for UNKNOWN words
|
||||
and self._known(word[1:])):
|
||||
return f"a {word[1:]}"
|
||||
|
||||
return None
|
||||
|
||||
# --- Full text correction ---
|
||||
|
||||
def correct_text(self, text: str, lang: str = "en") -> CorrectionResult:
|
||||
"""Correct a full text string (field value).
|
||||
|
||||
Three passes:
|
||||
1. Boundary repair -- fix shifted word boundaries between adjacent tokens
|
||||
2. Context split -- split ambiguous merges (anew -> a new)
|
||||
3. Per-word correction -- spell check individual words
|
||||
"""
|
||||
if not text or not text.strip():
|
||||
return CorrectionResult(text, text, "unknown", False)
|
||||
|
||||
detected = self.detect_text_lang(text) if lang == "auto" else lang
|
||||
effective_lang = detected if detected in ("en", "de") else "en"
|
||||
|
||||
changes: List[str] = []
|
||||
tokens = list(_TOKEN_RE.finditer(text))
|
||||
|
||||
# Extract token list: [(word, separator), ...]
|
||||
token_list: List[List[str]] = [] # [[word, sep], ...]
|
||||
for m in tokens:
|
||||
token_list.append([m.group(1), m.group(2)])
|
||||
|
||||
# --- Pass 1: Boundary repair between adjacent unknown words ---
|
||||
# Import abbreviations for the heuristic below
|
||||
try:
|
||||
from cv_ocr_engines import _KNOWN_ABBREVIATIONS as _ABBREVS
|
||||
except ImportError:
|
||||
_ABBREVS = set()
|
||||
|
||||
for i in range(len(token_list) - 1):
|
||||
w1 = token_list[i][0]
|
||||
w2_raw = token_list[i + 1][0]
|
||||
|
||||
# Skip boundary repair for IPA/bracket content
|
||||
# Brackets may be in the token OR in the adjacent separators
|
||||
sep_before_w1 = token_list[i - 1][1] if i > 0 else ""
|
||||
sep_after_w1 = token_list[i][1]
|
||||
sep_after_w2 = token_list[i + 1][1]
|
||||
has_bracket = (
|
||||
'[' in w1 or ']' in w1 or '[' in w2_raw or ']' in w2_raw
|
||||
or ']' in sep_after_w1 # w1 text was inside [brackets]
|
||||
or '[' in sep_after_w1 # w2 starts a bracket
|
||||
or ']' in sep_after_w2 # w2 text was inside [brackets]
|
||||
or '[' in sep_before_w1 # w1 starts a bracket
|
||||
)
|
||||
if has_bracket:
|
||||
continue
|
||||
|
||||
# Include trailing punct from separator in w2 for abbreviation matching
|
||||
w2_with_punct = w2_raw + token_list[i + 1][1].rstrip(" ")
|
||||
|
||||
# Try boundary repair -- always, even if both words are valid.
|
||||
# Use word-frequency scoring to decide if repair is better.
|
||||
repair = self._try_boundary_repair(w1, w2_with_punct)
|
||||
if not repair and w2_with_punct != w2_raw:
|
||||
repair = self._try_boundary_repair(w1, w2_raw)
|
||||
if repair:
|
||||
new_w1, new_w2_full = repair
|
||||
new_w2_base = new_w2_full.rstrip(".,;:!?")
|
||||
|
||||
# Frequency-based scoring: product of word frequencies
|
||||
# Higher product = more common word pair = better
|
||||
old_freq = self._word_freq(w1) * self._word_freq(w2_raw)
|
||||
new_freq = self._word_freq(new_w1) * self._word_freq(new_w2_base)
|
||||
|
||||
# Abbreviation bonus: if repair produces a known abbreviation
|
||||
has_abbrev = new_w1.lower() in _ABBREVS or new_w2_base.lower() in _ABBREVS
|
||||
if has_abbrev:
|
||||
# Accept abbreviation repair ONLY if at least one of the
|
||||
# original words is rare/unknown (prevents "Can I" -> "Ca nI"
|
||||
# where both original words are common and correct).
|
||||
RARE_THRESHOLD = 1e-6
|
||||
orig_both_common = (
|
||||
self._word_freq(w1) > RARE_THRESHOLD
|
||||
and self._word_freq(w2_raw) > RARE_THRESHOLD
|
||||
)
|
||||
if not orig_both_common:
|
||||
new_freq = max(new_freq, old_freq * 10)
|
||||
else:
|
||||
has_abbrev = False # both originals common -> don't trust
|
||||
|
||||
# Accept if repair produces a more frequent word pair
|
||||
# (threshold: at least 5x more frequent to avoid false positives)
|
||||
if new_freq > old_freq * 5:
|
||||
new_w2_punct = new_w2_full[len(new_w2_base):]
|
||||
changes.append(f"{w1} {w2_raw}\u2192{new_w1} {new_w2_base}")
|
||||
token_list[i][0] = new_w1
|
||||
token_list[i + 1][0] = new_w2_base
|
||||
if new_w2_punct:
|
||||
token_list[i + 1][1] = new_w2_punct + token_list[i + 1][1].lstrip(".,;:!?")
|
||||
|
||||
# --- Pass 2: Context split (anew -> a new) ---
|
||||
expanded: List[List[str]] = []
|
||||
for i, (word, sep) in enumerate(token_list):
|
||||
next_word = token_list[i + 1][0] if i + 1 < len(token_list) else ""
|
||||
prev_word = token_list[i - 1][0] if i > 0 else ""
|
||||
split = self._try_context_split(word, next_word, prev_word)
|
||||
if split and split != word:
|
||||
changes.append(f"{word}\u2192{split}")
|
||||
expanded.append([split, sep])
|
||||
else:
|
||||
expanded.append([word, sep])
|
||||
token_list = expanded
|
||||
|
||||
# --- Pass 3: Per-word correction ---
|
||||
parts: List[str] = []
|
||||
|
||||
# Preserve any leading text before the first token match
|
||||
first_start = tokens[0].start() if tokens else 0
|
||||
if first_start > 0:
|
||||
parts.append(text[:first_start])
|
||||
|
||||
for i, (word, sep) in enumerate(token_list):
|
||||
# Skip words inside IPA brackets (brackets land in separators)
|
||||
prev_sep = token_list[i - 1][1] if i > 0 else ""
|
||||
if '[' in prev_sep or ']' in sep:
|
||||
parts.append(word)
|
||||
parts.append(sep)
|
||||
continue
|
||||
|
||||
next_word = token_list[i + 1][0] if i + 1 < len(token_list) else ""
|
||||
prev_word = token_list[i - 1][0] if i > 0 else ""
|
||||
|
||||
correction = self.correct_word(
|
||||
word, lang=effective_lang,
|
||||
prev_word=prev_word, next_word=next_word,
|
||||
)
|
||||
if correction and correction != word:
|
||||
changes.append(f"{word}\u2192{correction}")
|
||||
parts.append(correction)
|
||||
else:
|
||||
parts.append(word)
|
||||
parts.append(sep)
|
||||
|
||||
# Append any trailing text
|
||||
last_end = tokens[-1].end() if tokens else 0
|
||||
if last_end < len(text):
|
||||
parts.append(text[last_end:])
|
||||
|
||||
corrected = "".join(parts)
|
||||
return CorrectionResult(
|
||||
original=text,
|
||||
corrected=corrected,
|
||||
lang_detected=detected,
|
||||
changed=corrected != text,
|
||||
changes=changes,
|
||||
)
|
||||
|
||||
# --- Vocabulary entry correction ---
|
||||
|
||||
def correct_vocab_entry(self, english: str, german: str,
|
||||
example: str = "") -> Dict[str, CorrectionResult]:
|
||||
"""Correct a full vocabulary entry (EN + DE + example).
|
||||
|
||||
Uses column position to determine language -- the most reliable signal.
|
||||
"""
|
||||
results = {}
|
||||
results["english"] = self.correct_text(english, lang="en")
|
||||
results["german"] = self.correct_text(german, lang="de")
|
||||
if example:
|
||||
# For examples, auto-detect language
|
||||
results["example"] = self.correct_text(example, lang="auto")
|
||||
return results
|
||||
# Backward-compat shim -- module moved to ocr/spell/text.py
import importlib as _importlib
import sys as _sys
_sys.modules[__name__] = _importlib.import_module("ocr.spell.text")
|
||||
klausur-service/backend/tesseract_extractor.py
@@ -1,346 +1,4 @@
"""
|
||||
Tesseract-based OCR extraction with word-level bounding boxes.
|
||||
|
||||
Uses Tesseract for spatial information (WHERE text is) while
|
||||
the Vision LLM handles semantic understanding (WHAT the text means).
|
||||
|
||||
Tesseract runs natively on ARM64 via Debian's apt package.
|
||||
|
||||
Lizenz: Apache 2.0 (kommerziell nutzbar)
|
||||
"""
|
||||
|
||||
import io
|
||||
import logging
|
||||
from typing import List, Dict, Any, Optional
|
||||
from difflib import SequenceMatcher
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
try:
|
||||
import pytesseract
|
||||
from PIL import Image
|
||||
TESSERACT_AVAILABLE = True
|
||||
except ImportError:
|
||||
TESSERACT_AVAILABLE = False
|
||||
logger.warning("pytesseract or Pillow not installed - Tesseract OCR unavailable")
|
||||
|
||||
|
||||
async def extract_bounding_boxes(image_bytes: bytes, lang: str = "eng+deu") -> dict:
|
||||
"""Run Tesseract OCR and return word-level bounding boxes.
|
||||
|
||||
Args:
|
||||
image_bytes: PNG/JPEG image as bytes.
|
||||
lang: Tesseract language string (e.g. "eng+deu").
|
||||
|
||||
Returns:
|
||||
Dict with 'words' list and 'image_width'/'image_height'.
|
||||
"""
|
||||
if not TESSERACT_AVAILABLE:
|
||||
return {"words": [], "image_width": 0, "image_height": 0, "error": "Tesseract not available"}
|
||||
|
||||
image = Image.open(io.BytesIO(image_bytes))
|
||||
data = pytesseract.image_to_data(image, lang=lang, output_type=pytesseract.Output.DICT)
|
||||
|
||||
words = []
|
||||
for i in range(len(data['text'])):
|
||||
text = data['text'][i].strip()
|
||||
conf = int(data['conf'][i])
|
||||
if not text or conf < 20:
|
||||
continue
|
||||
words.append({
|
||||
"text": text,
|
||||
"left": data['left'][i],
|
||||
"top": data['top'][i],
|
||||
"width": data['width'][i],
|
||||
"height": data['height'][i],
|
||||
"conf": conf,
|
||||
"block_num": data['block_num'][i],
|
||||
"par_num": data['par_num'][i],
|
||||
"line_num": data['line_num'][i],
|
||||
"word_num": data['word_num'][i],
|
||||
})
|
||||
|
||||
return {
|
||||
"words": words,
|
||||
"image_width": image.width,
|
||||
"image_height": image.height,
|
||||
}
|
||||
|
||||
|
||||
def group_words_into_lines(words: List[dict], y_tolerance_px: int = 15) -> List[List[dict]]:
|
||||
"""Group words by their Y position into lines.
|
||||
|
||||
Args:
|
||||
words: List of word dicts from extract_bounding_boxes.
|
||||
y_tolerance_px: Max pixel distance to consider words on the same line.
|
||||
|
||||
Returns:
|
||||
List of lines, each line is a list of words sorted by X position.
|
||||
"""
|
||||
if not words:
|
||||
return []
|
||||
|
||||
# Sort by Y then X
|
||||
sorted_words = sorted(words, key=lambda w: (w['top'], w['left']))
|
||||
|
||||
lines: List[List[dict]] = []
|
||||
current_line: List[dict] = [sorted_words[0]]
|
||||
current_y = sorted_words[0]['top']
|
||||
|
||||
for word in sorted_words[1:]:
|
||||
if abs(word['top'] - current_y) <= y_tolerance_px:
|
||||
current_line.append(word)
|
||||
else:
|
||||
current_line.sort(key=lambda w: w['left'])
|
||||
lines.append(current_line)
|
||||
current_line = [word]
|
||||
current_y = word['top']
|
||||
|
||||
if current_line:
|
||||
current_line.sort(key=lambda w: w['left'])
|
||||
lines.append(current_line)
|
||||
|
||||
return lines
|
||||
|
||||
|
||||
def detect_columns(lines: List[List[dict]], image_width: int) -> Dict[str, Any]:
|
||||
"""Detect column boundaries from word positions.
|
||||
|
||||
Typical vocab table: Left=English, Middle=German, Right=Example sentences.
|
||||
|
||||
Returns:
|
||||
Dict with column boundaries and type assignments.
|
||||
"""
|
||||
if not lines or image_width == 0:
|
||||
return {"columns": [], "column_types": []}
|
||||
|
||||
# Collect all word X positions
|
||||
all_x_positions = []
|
||||
for line in lines:
|
||||
for word in line:
|
||||
all_x_positions.append(word['left'])
|
||||
|
||||
if not all_x_positions:
|
||||
return {"columns": [], "column_types": []}
|
||||
|
||||
# Find X-position clusters (column starts)
|
||||
all_x_positions.sort()
|
||||
|
||||
# Simple gap-based column detection
|
||||
min_gap = image_width * 0.08 # 8% of page width = column gap
|
||||
clusters = []
|
||||
current_cluster = [all_x_positions[0]]
|
||||
|
||||
for x in all_x_positions[1:]:
|
||||
if x - current_cluster[-1] > min_gap:
|
||||
clusters.append(current_cluster)
|
||||
current_cluster = [x]
|
||||
else:
|
||||
current_cluster.append(x)
|
||||
|
||||
if current_cluster:
|
||||
clusters.append(current_cluster)
|
||||
|
||||
# Each cluster represents a column start
|
||||
columns = []
|
||||
for cluster in clusters:
|
||||
col_start = min(cluster)
|
||||
columns.append({
|
||||
"x_start": col_start,
|
||||
"x_start_pct": col_start / image_width * 100,
|
||||
"word_count": len(cluster),
|
||||
})
|
||||
|
||||
# Assign column types based on position (left→right: EN, DE, Example)
|
||||
type_map = ["english", "german", "example"]
|
||||
column_types = []
|
||||
for i, col in enumerate(columns):
|
||||
if i < len(type_map):
|
||||
column_types.append(type_map[i])
|
||||
else:
|
||||
column_types.append("unknown")
|
||||
|
||||
return {
|
||||
"columns": columns,
|
||||
"column_types": column_types,
|
||||
}
|
||||
|
||||
|
||||
def words_to_vocab_entries(lines: List[List[dict]], columns: List[dict],
|
||||
column_types: List[str], image_width: int,
|
||||
image_height: int) -> List[dict]:
|
||||
"""Convert grouped words into vocabulary entries using column positions.
|
||||
|
||||
Args:
|
||||
lines: Grouped word lines from group_words_into_lines.
|
||||
columns: Column boundaries from detect_columns.
|
||||
column_types: Column type assignments.
|
||||
image_width: Image width in pixels.
|
||||
image_height: Image height in pixels.
|
||||
|
||||
Returns:
|
||||
List of vocabulary entry dicts with english/german/example fields.
|
||||
"""
|
||||
if not columns or not lines:
|
||||
return []
|
||||
|
||||
# Build column boundaries for word assignment
|
||||
col_boundaries = []
|
||||
for i, col in enumerate(columns):
|
||||
start = col['x_start']
|
||||
if i + 1 < len(columns):
|
||||
end = columns[i + 1]['x_start']
|
||||
else:
|
||||
end = image_width
|
||||
col_boundaries.append((start, end, column_types[i] if i < len(column_types) else "unknown"))
|
||||
|
||||
entries = []
|
||||
for line in lines:
|
||||
entry = {"english": "", "german": "", "example": ""}
|
||||
line_words_by_col: Dict[str, List[str]] = {"english": [], "german": [], "example": []}
|
||||
line_bbox: Dict[str, Optional[dict]] = {}
|
||||
|
||||
for word in line:
|
||||
word_center_x = word['left'] + word['width'] / 2
|
||||
assigned_type = "unknown"
|
||||
for start, end, col_type in col_boundaries:
|
||||
if start <= word_center_x < end:
|
||||
assigned_type = col_type
|
||||
break
|
||||
|
||||
if assigned_type in line_words_by_col:
|
||||
line_words_by_col[assigned_type].append(word['text'])
|
||||
# Track bounding box for the column
|
||||
if assigned_type not in line_bbox or line_bbox[assigned_type] is None:
|
||||
line_bbox[assigned_type] = {
|
||||
"left": word['left'],
|
||||
"top": word['top'],
|
||||
"right": word['left'] + word['width'],
|
||||
"bottom": word['top'] + word['height'],
|
||||
}
|
||||
else:
|
||||
bb = line_bbox[assigned_type]
|
||||
bb['left'] = min(bb['left'], word['left'])
|
||||
bb['top'] = min(bb['top'], word['top'])
|
||||
bb['right'] = max(bb['right'], word['left'] + word['width'])
|
||||
bb['bottom'] = max(bb['bottom'], word['top'] + word['height'])
|
||||
|
||||
for col_type in ["english", "german", "example"]:
|
||||
if line_words_by_col[col_type]:
|
||||
entry[col_type] = " ".join(line_words_by_col[col_type])
|
||||
if line_bbox.get(col_type):
|
||||
bb = line_bbox[col_type]
|
||||
entry[f"{col_type}_bbox"] = {
|
||||
"x_pct": bb['left'] / image_width * 100,
|
||||
"y_pct": bb['top'] / image_height * 100,
|
||||
"w_pct": (bb['right'] - bb['left']) / image_width * 100,
|
||||
"h_pct": (bb['bottom'] - bb['top']) / image_height * 100,
|
||||
}
|
||||
|
||||
# Only add if at least one column has content
|
||||
if entry["english"] or entry["german"]:
|
||||
entries.append(entry)
|
||||
|
||||
return entries
|
||||
|
||||
|
||||
def match_positions_to_vocab(tess_words: List[dict], llm_vocab: List[dict],
|
||||
image_w: int, image_h: int,
|
||||
threshold: float = 0.6) -> List[dict]:
|
||||
"""Match Tesseract bounding boxes to LLM vocabulary entries.
|
||||
|
||||
For each LLM vocab entry, find the best-matching Tesseract word
|
||||
and attach its bounding box coordinates.
|
||||
|
||||
Args:
|
||||
tess_words: Word list from Tesseract with pixel coordinates.
|
||||
llm_vocab: Vocabulary list from Vision LLM.
|
||||
image_w: Image width in pixels.
|
||||
image_h: Image height in pixels.
|
||||
threshold: Minimum similarity ratio for a match.
|
||||
|
||||
Returns:
|
||||
llm_vocab list with bbox_x_pct/bbox_y_pct/bbox_w_pct/bbox_h_pct added.
|
||||
"""
|
||||
if not tess_words or not llm_vocab or image_w == 0 or image_h == 0:
|
||||
return llm_vocab
|
||||
|
||||
for entry in llm_vocab:
|
||||
english = entry.get("english", "").lower().strip()
|
||||
german = entry.get("german", "").lower().strip()
|
||||
|
||||
if not english and not german:
|
||||
continue
|
||||
|
||||
# Try to match English word first, then German
|
||||
for field in ["english", "german"]:
|
||||
search_text = entry.get(field, "").lower().strip()
|
||||
if not search_text:
|
||||
continue
|
||||
|
||||
best_word = None
|
||||
best_ratio = 0.0
|
||||
|
||||
for word in tess_words:
|
||||
ratio = SequenceMatcher(None, search_text, word['text'].lower()).ratio()
|
||||
if ratio > best_ratio:
|
||||
best_ratio = ratio
|
||||
best_word = word
|
||||
|
||||
if best_word and best_ratio >= threshold:
|
||||
entry[f"bbox_x_pct"] = best_word['left'] / image_w * 100
|
||||
entry[f"bbox_y_pct"] = best_word['top'] / image_h * 100
|
||||
entry[f"bbox_w_pct"] = best_word['width'] / image_w * 100
|
||||
entry[f"bbox_h_pct"] = best_word['height'] / image_h * 100
|
||||
entry["bbox_match_field"] = field
|
||||
entry["bbox_match_ratio"] = round(best_ratio, 3)
|
||||
break # Found a match, no need to try the other field
|
||||
|
||||
return llm_vocab
|
||||
|
||||
|
||||
async def run_tesseract_pipeline(image_bytes: bytes, lang: str = "eng+deu") -> dict:
|
||||
"""Full Tesseract pipeline: extract words, group lines, detect columns, build vocab.
|
||||
|
||||
Args:
|
||||
image_bytes: PNG/JPEG image as bytes.
|
||||
lang: Tesseract language string.
|
||||
|
||||
Returns:
|
||||
Dict with 'vocabulary', 'words', 'lines', 'columns', 'image_width', 'image_height'.
|
||||
"""
|
||||
# Step 1: Extract bounding boxes
|
||||
bbox_data = await extract_bounding_boxes(image_bytes, lang=lang)
|
||||
|
||||
if bbox_data.get("error"):
|
||||
return bbox_data
|
||||
|
||||
words = bbox_data["words"]
|
||||
image_w = bbox_data["image_width"]
|
||||
image_h = bbox_data["image_height"]
|
||||
|
||||
# Step 2: Group into lines
|
||||
lines = group_words_into_lines(words)
|
||||
|
||||
# Step 3: Detect columns
|
||||
col_info = detect_columns(lines, image_w)
|
||||
|
||||
# Step 4: Build vocabulary entries
|
||||
vocab = words_to_vocab_entries(
|
||||
lines,
|
||||
col_info["columns"],
|
||||
col_info["column_types"],
|
||||
image_w,
|
||||
image_h,
|
||||
)
|
||||
|
||||
return {
|
||||
"vocabulary": vocab,
|
||||
"words": words,
|
||||
"lines_count": len(lines),
|
||||
"columns": col_info["columns"],
|
||||
"column_types": col_info["column_types"],
|
||||
"image_width": image_w,
|
||||
"image_height": image_h,
|
||||
"word_count": len(words),
|
||||
}
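A rough usage sketch for the pipeline above (illustrative; assumes pytesseract, Pillow and a Tesseract install with eng+deu language data; "vocab_page.png" is a placeholder path):

    import asyncio
    from ocr.engines.tesseract_extractor import run_tesseract_pipeline

    async def demo() -> None:
        with open("vocab_page.png", "rb") as f:
            result = await run_tesseract_pipeline(f.read(), lang="eng+deu")
        print(result["word_count"], "words in", result["lines_count"], "lines")
        for row in result["vocabulary"][:3]:
            print(row["english"], "|", row["german"])

    asyncio.run(demo())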
|
||||
# Backward-compat shim -- module moved to ocr/engines/tesseract_extractor.py
import importlib as _importlib
import sys as _sys
_sys.modules[__name__] = _importlib.import_module("ocr.engines.tesseract_extractor")
|
||||
klausur-service/backend/unified_grid.py
@@ -1,425 +1,4 @@
"""
|
||||
Unified Grid Builder — merges multi-zone grid into a single Excel-like grid.
|
||||
|
||||
Takes content zone + box zones and produces one unified zone where:
|
||||
- All content rows use the dominant row height
|
||||
- Full-width boxes are integrated directly (box rows replace standard rows)
|
||||
- Partial-width boxes: extra rows inserted if box has more lines than standard
|
||||
- Box-origin cells carry metadata (bg_color, border) for visual distinction
|
||||
|
||||
The result is a single-zone StructuredGrid that can be:
|
||||
- Rendered in an Excel-like editor
|
||||
- Exported to Excel/CSV
|
||||
- Edited with unified row/column numbering
|
||||
"""
|
||||
|
||||
import logging
|
||||
import math
|
||||
import statistics
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _compute_dominant_row_height(content_zone: Dict) -> float:
|
||||
"""Median of content row-to-row spacings, excluding box-gap jumps."""
|
||||
rows = content_zone.get("rows", [])
|
||||
if len(rows) < 2:
|
||||
return 47.0
|
||||
|
||||
spacings = []
|
||||
for i in range(len(rows) - 1):
|
||||
y1 = rows[i].get("y_min_px", rows[i].get("y_min", 0))
|
||||
y2 = rows[i + 1].get("y_min_px", rows[i + 1].get("y_min", 0))
|
||||
d = y2 - y1
|
||||
if 0 < d < 100: # exclude box-gap jumps
|
||||
spacings.append(d)
|
||||
|
||||
if not spacings:
|
||||
return 47.0
|
||||
spacings.sort()
|
||||
return spacings[len(spacings) // 2]
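A tiny worked example of this heuristic with made-up spacings:

    # Consecutive row spacings in px; the 180 px box gap is filtered out by
    # 0 < d < 100, leaving [44, 46, 47, 49], whose upper-median element is 47.
    spacings = [46, 44, 180, 47, 49]
    kept = sorted(d for d in spacings if 0 < d < 100)
    print(kept[len(kept) // 2])  # 47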
|
||||
|
||||
|
||||
def _classify_boxes(
|
||||
box_zones: List[Dict],
|
||||
content_width: float,
|
||||
) -> List[Dict]:
|
||||
"""Classify each box as full_width or partial_width."""
|
||||
result = []
|
||||
for bz in box_zones:
|
||||
bb = bz.get("bbox_px", {})
|
||||
bw = bb.get("w", 0)
|
||||
bx = bb.get("x", 0)
|
||||
|
||||
if bw >= content_width * 0.85:
|
||||
classification = "full_width"
|
||||
side = "center"
|
||||
else:
|
||||
classification = "partial_width"
|
||||
# Determine which side of the page the box is on
|
||||
page_center = content_width / 2
|
||||
box_center = bx + bw / 2
|
||||
side = "right" if box_center > page_center else "left"
|
||||
|
||||
# Count total text lines in box (including \n within cells)
|
||||
total_lines = sum(
|
||||
(c.get("text", "").count("\n") + 1)
|
||||
for c in bz.get("cells", [])
|
||||
)
|
||||
|
||||
result.append({
|
||||
"zone": bz,
|
||||
"classification": classification,
|
||||
"side": side,
|
||||
"y_start": bb.get("y", 0),
|
||||
"y_end": bb.get("y", 0) + bb.get("h", 0),
|
||||
"total_lines": total_lines,
|
||||
"bg_hex": bz.get("box_bg_hex", ""),
|
||||
"bg_color": bz.get("box_bg_color", ""),
|
||||
})
|
||||
return result
|
||||
|
||||
|
||||
def build_unified_grid(
|
||||
zones: List[Dict],
|
||||
image_width: int,
|
||||
image_height: int,
|
||||
layout_metrics: Dict,
|
||||
) -> Dict[str, Any]:
|
||||
"""Build a single-zone unified grid from multi-zone grid data.
|
||||
|
||||
Returns a StructuredGrid with one zone containing all rows and cells.
|
||||
"""
|
||||
content_zone = None
|
||||
box_zones = []
|
||||
for z in zones:
|
||||
if z.get("zone_type") == "content":
|
||||
content_zone = z
|
||||
elif z.get("zone_type") == "box":
|
||||
box_zones.append(z)
|
||||
|
||||
if not content_zone:
|
||||
logger.warning("build_unified_grid: no content zone found")
|
||||
return {"zones": zones} # fallback: return as-is
|
||||
|
||||
box_zones.sort(key=lambda b: b.get("bbox_px", {}).get("y", 0))
|
||||
|
||||
dominant_h = _compute_dominant_row_height(content_zone)
|
||||
content_bbox = content_zone.get("bbox_px", {})
|
||||
content_width = content_bbox.get("w", image_width)
|
||||
content_x = content_bbox.get("x", 0)
|
||||
content_cols = content_zone.get("columns", [])
|
||||
num_cols = len(content_cols)
|
||||
|
||||
box_infos = _classify_boxes(box_zones, content_width)
|
||||
|
||||
logger.info(
|
||||
"build_unified_grid: dominant_h=%.1f, %d content rows, %d boxes (%s)",
|
||||
dominant_h, len(content_zone.get("rows", [])), len(box_infos),
|
||||
[b["classification"] for b in box_infos],
|
||||
)
|
||||
|
||||
# --- Build unified row list + cell list ---
|
||||
unified_rows: List[Dict] = []
|
||||
unified_cells: List[Dict] = []
|
||||
unified_row_idx = 0
|
||||
|
||||
# Content rows and cells indexed by row_index
|
||||
content_rows = content_zone.get("rows", [])
|
||||
content_cells = content_zone.get("cells", [])
|
||||
content_cells_by_row: Dict[int, List[Dict]] = {}
|
||||
for c in content_cells:
|
||||
content_cells_by_row.setdefault(c.get("row_index", -1), []).append(c)
|
||||
|
||||
# Track which content rows we've processed
|
||||
content_row_ptr = 0
|
||||
|
||||
for bi, box_info in enumerate(box_infos):
|
||||
bz = box_info["zone"]
|
||||
by_start = box_info["y_start"]
|
||||
by_end = box_info["y_end"]
|
||||
|
||||
# --- Add content rows ABOVE this box ---
|
||||
while content_row_ptr < len(content_rows):
|
||||
cr = content_rows[content_row_ptr]
|
||||
cry = cr.get("y_min_px", cr.get("y_min", 0))
|
||||
if cry >= by_start:
|
||||
break
|
||||
# Add this content row
|
||||
_add_content_row(
|
||||
unified_rows, unified_cells, unified_row_idx,
|
||||
cr, content_cells_by_row, dominant_h, image_height,
|
||||
)
|
||||
unified_row_idx += 1
|
||||
content_row_ptr += 1
|
||||
|
||||
# --- Add box rows ---
|
||||
if box_info["classification"] == "full_width":
|
||||
# Full-width box: integrate box rows directly
|
||||
_add_full_width_box(
|
||||
unified_rows, unified_cells, unified_row_idx,
|
||||
bz, box_info, dominant_h, num_cols, image_height,
|
||||
)
|
||||
unified_row_idx += len(bz.get("rows", []))
|
||||
# Skip content rows that overlap with this box
|
||||
while content_row_ptr < len(content_rows):
|
||||
cr = content_rows[content_row_ptr]
|
||||
cry = cr.get("y_min_px", cr.get("y_min", 0))
|
||||
if cry > by_end:
|
||||
break
|
||||
content_row_ptr += 1
|
||||
|
||||
else:
|
||||
# Partial-width box: merge with adjacent content rows
|
||||
unified_row_idx = _add_partial_width_box(
|
||||
unified_rows, unified_cells, unified_row_idx,
|
||||
bz, box_info, content_rows, content_cells_by_row,
|
||||
content_row_ptr, dominant_h, num_cols, image_height,
|
||||
content_x, content_width,
|
||||
)
|
||||
# Advance content pointer past box region
|
||||
while content_row_ptr < len(content_rows):
|
||||
cr = content_rows[content_row_ptr]
|
||||
cry = cr.get("y_min_px", cr.get("y_min", 0))
|
||||
if cry > by_end:
|
||||
break
|
||||
content_row_ptr += 1
|
||||
|
||||
# --- Add remaining content rows BELOW all boxes ---
|
||||
while content_row_ptr < len(content_rows):
|
||||
cr = content_rows[content_row_ptr]
|
||||
_add_content_row(
|
||||
unified_rows, unified_cells, unified_row_idx,
|
||||
cr, content_cells_by_row, dominant_h, image_height,
|
||||
)
|
||||
unified_row_idx += 1
|
||||
content_row_ptr += 1
|
||||
|
||||
# --- Build unified zone ---
|
||||
unified_zone = {
|
||||
"zone_index": 0,
|
||||
"zone_type": "unified",
|
||||
"bbox_px": content_bbox,
|
||||
"bbox_pct": content_zone.get("bbox_pct", {}),
|
||||
"border": None,
|
||||
"word_count": sum(len(c.get("word_boxes", [])) for c in unified_cells),
|
||||
"columns": content_cols,
|
||||
"rows": unified_rows,
|
||||
"cells": unified_cells,
|
||||
"header_rows": [],
|
||||
}
|
||||
|
||||
logger.info(
|
||||
"build_unified_grid: %d unified rows, %d cells (from %d content + %d box zones)",
|
||||
len(unified_rows), len(unified_cells),
|
||||
len(content_rows), len(box_zones),
|
||||
)
|
||||
|
||||
return {
|
||||
"zones": [unified_zone],
|
||||
"image_width": image_width,
|
||||
"image_height": image_height,
|
||||
"layout_metrics": layout_metrics,
|
||||
"summary": {
|
||||
"total_zones": 1,
|
||||
"total_columns": num_cols,
|
||||
"total_rows": len(unified_rows),
|
||||
"total_cells": len(unified_cells),
|
||||
},
|
||||
"is_unified": True,
|
||||
"dominant_row_h": dominant_h,
|
||||
}
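A sketch of how the returned structure is meant to be consumed (illustrative; zones, the image dimensions and layout_metrics are placeholders from an earlier grid-building step):

    grid = build_unified_grid(zones, image_width=2480, image_height=3508,
                              layout_metrics={})
    zone = grid["zones"][0]  # single unified zone
    for cell in zone["cells"]:
        origin = cell.get("source_zone_type", "content")
        print(cell["row_index"], cell.get("col_index"), origin, cell.get("text", ""))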
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _make_row(idx: int, y: float, h: float, img_h: int, is_header: bool = False) -> Dict:
|
||||
return {
|
||||
"index": idx,
|
||||
"row_index": idx,
|
||||
"y_min_px": round(y),
|
||||
"y_max_px": round(y + h),
|
||||
"y_min_pct": round(y / img_h * 100, 2) if img_h else 0,
|
||||
"y_max_pct": round((y + h) / img_h * 100, 2) if img_h else 0,
|
||||
"is_header": is_header,
|
||||
}
|
||||
|
||||
|
||||
def _remap_cell(cell: Dict, new_row: int, new_col: int = None,
|
||||
source_type: str = "content", box_region: Dict = None) -> Dict:
|
||||
"""Create a new cell dict with remapped indices."""
|
||||
c = dict(cell)
|
||||
c["row_index"] = new_row
|
||||
if new_col is not None:
|
||||
c["col_index"] = new_col
|
||||
c["cell_id"] = f"U_R{new_row:02d}_C{c.get('col_index', 0)}"
|
||||
c["source_zone_type"] = source_type
|
||||
if box_region:
|
||||
c["box_region"] = box_region
|
||||
return c
|
||||
|
||||
|
||||
def _add_content_row(
|
||||
unified_rows, unified_cells, row_idx,
|
||||
content_row, cells_by_row, dominant_h, img_h,
|
||||
):
|
||||
"""Add a single content row to the unified grid."""
|
||||
y = content_row.get("y_min_px", content_row.get("y_min", 0))
|
||||
is_hdr = content_row.get("is_header", False)
|
||||
unified_rows.append(_make_row(row_idx, y, dominant_h, img_h, is_hdr))
|
||||
|
||||
for cell in cells_by_row.get(content_row.get("index", -1), []):
|
||||
unified_cells.append(_remap_cell(cell, row_idx, source_type="content"))
|
||||
|
||||
|
||||
def _add_full_width_box(
|
||||
unified_rows, unified_cells, start_row_idx,
|
||||
box_zone, box_info, dominant_h, num_cols, img_h,
|
||||
):
|
||||
"""Add a full-width box's rows to the unified grid."""
|
||||
box_rows = box_zone.get("rows", [])
|
||||
box_cells = box_zone.get("cells", [])
|
||||
box_region = {"bg_hex": box_info["bg_hex"], "bg_color": box_info["bg_color"], "border": True}
|
||||
|
||||
# Distribute box height evenly among its rows
|
||||
box_h = box_info["y_end"] - box_info["y_start"]
|
||||
row_h = box_h / len(box_rows) if box_rows else dominant_h
|
||||
|
||||
for i, br in enumerate(box_rows):
|
||||
y = box_info["y_start"] + i * row_h
|
||||
new_idx = start_row_idx + i
|
||||
is_hdr = br.get("is_header", False)
|
||||
unified_rows.append(_make_row(new_idx, y, row_h, img_h, is_hdr))
|
||||
|
||||
for cell in box_cells:
|
||||
if cell.get("row_index") == br.get("index", i):
|
||||
unified_cells.append(
|
||||
_remap_cell(cell, new_idx, source_type="box", box_region=box_region)
|
||||
)
|
||||
|
||||
|
||||
def _add_partial_width_box(
|
||||
unified_rows, unified_cells, start_row_idx,
|
||||
box_zone, box_info, content_rows, content_cells_by_row,
|
||||
content_row_ptr, dominant_h, num_cols, img_h,
|
||||
content_x, content_width,
|
||||
) -> int:
|
||||
"""Add a partial-width box merged with content rows.
|
||||
|
||||
Returns the next unified_row_idx after processing.
|
||||
"""
|
||||
by_start = box_info["y_start"]
|
||||
by_end = box_info["y_end"]
|
||||
box_h = by_end - by_start
|
||||
box_region = {"bg_hex": box_info["bg_hex"], "bg_color": box_info["bg_color"], "border": True}
|
||||
|
||||
# Content rows in the box's Y range
|
||||
overlap_content_rows = []
|
||||
ptr = content_row_ptr
|
||||
while ptr < len(content_rows):
|
||||
cr = content_rows[ptr]
|
||||
cry = cr.get("y_min_px", cr.get("y_min", 0))
|
||||
if cry > by_end:
|
||||
break
|
||||
if cry >= by_start:
|
||||
overlap_content_rows.append(cr)
|
||||
ptr += 1
|
||||
|
||||
# How many standard rows fit in the box height
|
||||
standard_rows = max(1, math.floor(box_h / dominant_h))
|
||||
# How many text lines the box actually has
|
||||
box_text_lines = box_info["total_lines"]
|
||||
# Extra rows needed
|
||||
extra_rows = max(0, box_text_lines - standard_rows)
|
||||
total_rows_for_region = standard_rows + extra_rows
|
||||
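# Worked example (illustrative numbers, not taken from the source): a box
# 120 px tall with a dominant row height of 28 px gives
# standard_rows = floor(120 / 28) = 4. If the box actually contains 6 text
# lines, extra_rows = 6 - 4 = 2, so the region is laid out as 4 + 2 = 6
# unified rows.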
|
||||
logger.info(
|
||||
"partial box: standard=%d, box_lines=%d, extra=%d, content_overlap=%d",
|
||||
standard_rows, box_text_lines, extra_rows, len(overlap_content_rows),
|
||||
)
|
||||
|
||||
# Determine which columns the box occupies.
# Simple midpoint heuristic: a partial-width box on the right side is assigned
# the right half of the content columns, a box on the left side the left half.
# (Per-column overlap against the box bbox_px could refine this later.)
if box_info["side"] == "right":
    box_col_start = num_cols // 2
    box_col_end = num_cols
else:
    box_col_start = 0
    box_col_end = num_cols // 2
|
||||
|
||||
# Build rows for this region
|
||||
box_cells = box_zone.get("cells", [])
|
||||
box_rows = box_zone.get("rows", [])
|
||||
row_idx = start_row_idx
|
||||
|
||||
# Expand box cell texts with \n into individual lines for row mapping
|
||||
box_lines: List[Tuple[str, Dict]] = [] # (text_line, parent_cell)
|
||||
for bc in sorted(box_cells, key=lambda c: c.get("row_index", 0)):
|
||||
text = bc.get("text", "")
|
||||
for line in text.split("\n"):
|
||||
box_lines.append((line.strip(), bc))
|
||||
|
||||
for i in range(total_rows_for_region):
|
||||
y = by_start + i * dominant_h
|
||||
unified_rows.append(_make_row(row_idx, y, dominant_h, img_h))
|
||||
|
||||
# Content cells for this row (from overlapping content rows)
|
||||
if i < len(overlap_content_rows):
|
||||
cr = overlap_content_rows[i]
|
||||
for cell in content_cells_by_row.get(cr.get("index", -1), []):
|
||||
# Only include cells from columns NOT covered by the box
|
||||
ci = cell.get("col_index", 0)
|
||||
if ci < box_col_start or ci >= box_col_end:
|
||||
unified_cells.append(_remap_cell(cell, row_idx, source_type="content"))
|
||||
|
||||
# Box cell for this row
|
||||
if i < len(box_lines):
|
||||
line_text, parent_cell = box_lines[i]
|
||||
box_cell = {
|
||||
"cell_id": f"U_R{row_idx:02d}_C{box_col_start}",
|
||||
"row_index": row_idx,
|
||||
"col_index": box_col_start,
|
||||
"col_type": "spanning_header" if (box_col_end - box_col_start) > 1 else parent_cell.get("col_type", "column_1"),
|
||||
"colspan": box_col_end - box_col_start,
|
||||
"text": line_text,
|
||||
"confidence": parent_cell.get("confidence", 0),
|
||||
"bbox_px": parent_cell.get("bbox_px", {}),
|
||||
"bbox_pct": parent_cell.get("bbox_pct", {}),
|
||||
"word_boxes": [],
|
||||
"ocr_engine": parent_cell.get("ocr_engine", ""),
|
||||
"is_bold": parent_cell.get("is_bold", False),
|
||||
"source_zone_type": "box",
|
||||
"box_region": box_region,
|
||||
}
|
||||
unified_cells.append(box_cell)
|
||||
|
||||
row_idx += 1
|
||||
|
||||
return row_idx
|
||||
# Backward-compat shim -- module moved to grid/unified.py
|
||||
import importlib as _importlib
|
||||
import sys as _sys
|
||||
_sys.modules[__name__] = _importlib.import_module("grid.unified")
|
||||
|
||||
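A note on the shim pattern used throughout this commit: on first import of the
old flat module, the shim replaces its own sys.modules entry with the relocated
package module, so existing call sites (e.g. unified_grid.build_unified_grid)
keep working without changes. A minimal sketch, assuming the old flat module
was named unified_grid:

# Hypothetical caller written before the restructuring (assumption).
import sys

import unified_grid      # executes the shim, which redirects to grid.unified
import grid.unified

# Both names now refer to the same module object.
assert sys.modules["unified_grid"] is sys.modules["grid.unified"]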
6
klausur-service/backend/upload/__init__.py
Normal file
@@ -0,0 +1,6 @@
|
||||
"""
|
||||
Upload package — chunked and mobile upload endpoints.
|
||||
|
||||
Moved from backend/ flat modules (upload_api*.py).
|
||||
Backward-compatible shim files remain at the old locations.
|
||||
"""
|
||||
29
klausur-service/backend/upload/api.py
Normal file
29
klausur-service/backend/upload/api.py
Normal file
@@ -0,0 +1,29 @@
|
||||
"""
|
||||
Mobile Upload API — barrel re-export.
|
||||
|
||||
All implementation split into:
|
||||
upload_api_chunked — chunked upload endpoints (init, chunk, finalize, simple, status, cancel, list)
|
||||
upload_api_mobile — mobile HTML upload page
|
||||
|
||||
DSGVO-konform: Data stays local in WLAN, no external transmission.
|
||||
"""
|
||||
|
||||
from fastapi import APIRouter
|
||||
|
||||
from .chunked import ( # noqa: F401
|
||||
router as _chunked_router,
|
||||
UPLOAD_DIR,
|
||||
CHUNK_DIR,
|
||||
EH_UPLOAD_DIR,
|
||||
_upload_sessions,
|
||||
InitUploadRequest,
|
||||
InitUploadResponse,
|
||||
ChunkUploadResponse,
|
||||
FinalizeResponse,
|
||||
)
|
||||
from .mobile import router as _mobile_router # noqa: F401
|
||||
|
||||
# Composite router that includes both sub-routers
|
||||
router = APIRouter()
|
||||
router.include_router(_chunked_router)
|
||||
router.include_router(_mobile_router)
|
||||
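How the composite router gets mounted is not part of this diff; a minimal
sketch of the wiring in main.py, assuming the backend package root is on the
import path (the app and variable names are assumptions):

from fastapi import FastAPI

from upload.api import router as upload_router

app = FastAPI()
# The sub-routers already carry the /api/v1/upload prefix, so no extra
# prefix is needed here.
app.include_router(upload_router)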
320
klausur-service/backend/upload/chunked.py
Normal file
@@ -0,0 +1,320 @@
|
||||
"""
|
||||
Chunked Upload API — init, chunk, finalize, simple upload, status, cancel, list.
|
||||
|
||||
Extracted from upload_api.py for modularity.
|
||||
|
||||
DSGVO-konform: Data stays local in WLAN, no external transmission.
|
||||
"""
|
||||
|
||||
import os
|
||||
import uuid
|
||||
import shutil
|
||||
import hashlib
|
||||
from pathlib import Path
|
||||
from datetime import datetime, timezone
|
||||
from typing import Dict, Optional
|
||||
|
||||
from fastapi import APIRouter, HTTPException, UploadFile, File, Form
|
||||
from pydantic import BaseModel
|
||||
|
||||
# Configuration
|
||||
UPLOAD_DIR = Path(os.getenv("UPLOAD_DIR", "/app/uploads"))
|
||||
CHUNK_DIR = Path(os.getenv("CHUNK_DIR", "/app/chunks"))
|
||||
EH_UPLOAD_DIR = Path(os.getenv("EH_UPLOAD_DIR", "/app/eh-uploads"))
|
||||
|
||||
# Ensure directories exist
|
||||
UPLOAD_DIR.mkdir(parents=True, exist_ok=True)
|
||||
CHUNK_DIR.mkdir(parents=True, exist_ok=True)
|
||||
EH_UPLOAD_DIR.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# In-memory storage for upload sessions (for simplicity)
|
||||
# In production, use Redis or database
|
||||
_upload_sessions: Dict[str, dict] = {}
|
||||
|
||||
router = APIRouter(prefix="/api/v1/upload", tags=["Mobile Upload"])
|
||||
|
||||
|
||||
class InitUploadRequest(BaseModel):
|
||||
filename: str
|
||||
filesize: int
|
||||
chunks: int
|
||||
destination: str = "klausur" # "klausur" or "rag"
|
||||
|
||||
|
||||
class InitUploadResponse(BaseModel):
|
||||
upload_id: str
|
||||
chunk_size: int
|
||||
total_chunks: int
|
||||
message: str
|
||||
|
||||
|
||||
class ChunkUploadResponse(BaseModel):
|
||||
upload_id: str
|
||||
chunk_index: int
|
||||
received: bool
|
||||
chunks_received: int
|
||||
total_chunks: int
|
||||
|
||||
|
||||
class FinalizeResponse(BaseModel):
|
||||
upload_id: str
|
||||
filename: str
|
||||
filepath: str
|
||||
filesize: int
|
||||
checksum: str
|
||||
message: str
|
||||
|
||||
|
||||
@router.post("/init", response_model=InitUploadResponse)
|
||||
async def init_upload(request: InitUploadRequest):
|
||||
"""
|
||||
Initialize a chunked upload session.
|
||||
|
||||
Returns an upload_id that must be used for subsequent chunk uploads.
|
||||
"""
|
||||
upload_id = str(uuid.uuid4())
|
||||
|
||||
# Create session directory
|
||||
session_dir = CHUNK_DIR / upload_id
|
||||
session_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Store session info
|
||||
_upload_sessions[upload_id] = {
|
||||
"filename": request.filename,
|
||||
"filesize": request.filesize,
|
||||
"total_chunks": request.chunks,
|
||||
"received_chunks": set(),
|
||||
"destination": request.destination,
|
||||
"session_dir": str(session_dir),
|
||||
"created_at": datetime.now(timezone.utc).isoformat(),
|
||||
}
|
||||
|
||||
return InitUploadResponse(
|
||||
upload_id=upload_id,
|
||||
chunk_size=5 * 1024 * 1024, # 5 MB
|
||||
total_chunks=request.chunks,
|
||||
message="Upload-Session erstellt"
|
||||
)
|
||||
|
||||
|
||||
@router.post("/chunk", response_model=ChunkUploadResponse)
|
||||
async def upload_chunk(
|
||||
chunk: UploadFile = File(...),
|
||||
upload_id: str = Form(...),
|
||||
chunk_index: int = Form(...)
|
||||
):
|
||||
"""
|
||||
Upload a single chunk of a file.
|
||||
|
||||
Chunks are stored temporarily until finalize is called.
|
||||
"""
|
||||
if upload_id not in _upload_sessions:
|
||||
raise HTTPException(status_code=404, detail="Upload-Session nicht gefunden")
|
||||
|
||||
session = _upload_sessions[upload_id]
|
||||
|
||||
if chunk_index < 0 or chunk_index >= session["total_chunks"]:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Ungueltiger Chunk-Index: {chunk_index}"
|
||||
)
|
||||
|
||||
# Save chunk
|
||||
chunk_path = Path(session["session_dir"]) / f"chunk_{chunk_index:05d}"
|
||||
|
||||
with open(chunk_path, "wb") as f:
|
||||
content = await chunk.read()
|
||||
f.write(content)
|
||||
|
||||
# Track received chunks
|
||||
session["received_chunks"].add(chunk_index)
|
||||
|
||||
return ChunkUploadResponse(
|
||||
upload_id=upload_id,
|
||||
chunk_index=chunk_index,
|
||||
received=True,
|
||||
chunks_received=len(session["received_chunks"]),
|
||||
total_chunks=session["total_chunks"]
|
||||
)
|
||||
|
||||
|
||||
@router.post("/finalize", response_model=FinalizeResponse)
|
||||
async def finalize_upload(upload_id: str = Form(...)):
|
||||
"""
|
||||
Finalize the upload by combining all chunks into a single file.
|
||||
|
||||
Validates that all chunks were received and calculates checksum.
|
||||
"""
|
||||
if upload_id not in _upload_sessions:
|
||||
raise HTTPException(status_code=404, detail="Upload-Session nicht gefunden")
|
||||
|
||||
session = _upload_sessions[upload_id]
|
||||
|
||||
# Check if all chunks received
|
||||
if len(session["received_chunks"]) != session["total_chunks"]:
|
||||
missing = session["total_chunks"] - len(session["received_chunks"])
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Nicht alle Chunks empfangen. Fehlend: {missing}"
|
||||
)
|
||||
|
||||
# Determine destination directory
|
||||
if session["destination"] == "rag":
|
||||
dest_dir = EH_UPLOAD_DIR
|
||||
else:
|
||||
dest_dir = UPLOAD_DIR
|
||||
|
||||
# Generate unique filename
|
||||
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||
safe_filename = session["filename"].replace(" ", "_")
|
||||
final_filename = f"{timestamp}_{safe_filename}"
|
||||
final_path = dest_dir / final_filename
|
||||
|
||||
# Combine chunks
|
||||
hasher = hashlib.sha256()
|
||||
total_size = 0
|
||||
|
||||
with open(final_path, "wb") as outfile:
|
||||
for i in range(session["total_chunks"]):
|
||||
chunk_path = Path(session["session_dir"]) / f"chunk_{i:05d}"
|
||||
|
||||
if not chunk_path.exists():
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"Chunk {i} nicht gefunden"
|
||||
)
|
||||
|
||||
with open(chunk_path, "rb") as infile:
|
||||
data = infile.read()
|
||||
outfile.write(data)
|
||||
hasher.update(data)
|
||||
total_size += len(data)
|
||||
|
||||
# Clean up chunks
|
||||
shutil.rmtree(session["session_dir"], ignore_errors=True)
|
||||
del _upload_sessions[upload_id]
|
||||
|
||||
checksum = hasher.hexdigest()
|
||||
|
||||
return FinalizeResponse(
|
||||
upload_id=upload_id,
|
||||
filename=final_filename,
|
||||
filepath=str(final_path),
|
||||
filesize=total_size,
|
||||
checksum=checksum,
|
||||
message="Upload erfolgreich abgeschlossen"
|
||||
)
|
||||
|
||||
|
||||
@router.post("/simple")
|
||||
async def simple_upload(
|
||||
file: UploadFile = File(...),
|
||||
destination: str = Form("klausur")
|
||||
):
|
||||
"""
|
||||
Simple single-request upload for smaller files (<10MB).
|
||||
|
||||
For larger files, use the chunked upload endpoints.
|
||||
"""
|
||||
# Determine destination directory
|
||||
if destination == "rag":
|
||||
dest_dir = EH_UPLOAD_DIR
|
||||
else:
|
||||
dest_dir = UPLOAD_DIR
|
||||
|
||||
# Generate unique filename
|
||||
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||
safe_filename = file.filename.replace(" ", "_") if file.filename else "upload.pdf"
|
||||
final_filename = f"{timestamp}_{safe_filename}"
|
||||
final_path = dest_dir / final_filename
|
||||
|
||||
# Calculate checksum while writing
|
||||
hasher = hashlib.sha256()
|
||||
total_size = 0
|
||||
|
||||
with open(final_path, "wb") as f:
|
||||
while True:
|
||||
chunk = await file.read(1024 * 1024) # Read 1MB at a time
|
||||
if not chunk:
|
||||
break
|
||||
f.write(chunk)
|
||||
hasher.update(chunk)
|
||||
total_size += len(chunk)
|
||||
|
||||
return {
|
||||
"filename": final_filename,
|
||||
"filepath": str(final_path),
|
||||
"filesize": total_size,
|
||||
"checksum": hasher.hexdigest(),
|
||||
"message": "Upload erfolgreich"
|
||||
}
|
||||
|
||||
|
||||
@router.get("/status/{upload_id}")
|
||||
async def get_upload_status(upload_id: str):
|
||||
"""
|
||||
Get the status of an ongoing upload.
|
||||
"""
|
||||
if upload_id not in _upload_sessions:
|
||||
raise HTTPException(status_code=404, detail="Upload-Session nicht gefunden")
|
||||
|
||||
session = _upload_sessions[upload_id]
|
||||
|
||||
return {
|
||||
"upload_id": upload_id,
|
||||
"filename": session["filename"],
|
||||
"total_chunks": session["total_chunks"],
|
||||
"received_chunks": len(session["received_chunks"]),
|
||||
"progress_percent": round(
|
||||
len(session["received_chunks"]) / session["total_chunks"] * 100, 1
|
||||
),
|
||||
"destination": session["destination"],
|
||||
"created_at": session["created_at"]
|
||||
}
|
||||
|
||||
|
||||
@router.delete("/cancel/{upload_id}")
|
||||
async def cancel_upload(upload_id: str):
|
||||
"""
|
||||
Cancel an ongoing upload and clean up temporary files.
|
||||
"""
|
||||
if upload_id not in _upload_sessions:
|
||||
raise HTTPException(status_code=404, detail="Upload-Session nicht gefunden")
|
||||
|
||||
session = _upload_sessions[upload_id]
|
||||
|
||||
# Clean up chunks
|
||||
shutil.rmtree(session["session_dir"], ignore_errors=True)
|
||||
del _upload_sessions[upload_id]
|
||||
|
||||
return {"message": "Upload abgebrochen", "upload_id": upload_id}
|
||||
|
||||
|
||||
@router.get("/list")
|
||||
async def list_uploads(destination: str = "klausur"):
|
||||
"""
|
||||
List all uploaded files in the specified destination.
|
||||
"""
|
||||
if destination == "rag":
|
||||
dest_dir = EH_UPLOAD_DIR
|
||||
else:
|
||||
dest_dir = UPLOAD_DIR
|
||||
|
||||
files = []
|
||||
|
||||
for f in dest_dir.iterdir():
|
||||
if f.is_file() and f.suffix.lower() == ".pdf":
|
||||
stat = f.stat()
|
||||
files.append({
|
||||
"filename": f.name,
|
||||
"size": stat.st_size,
|
||||
"modified": datetime.fromtimestamp(stat.st_mtime).isoformat(),
|
||||
})
|
||||
|
||||
files.sort(key=lambda x: x["modified"], reverse=True)
|
||||
|
||||
return {
|
||||
"destination": destination,
|
||||
"count": len(files),
|
||||
"files": files[:50] # Limit to 50 most recent
|
||||
}
|
||||
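For reference, a minimal client-side sketch of the chunked protocol
(init -> chunk -> finalize) as exercised by the endpoints above. The base URL
and the use of httpx are assumptions for illustration; the mobile page below
does the same thing in JavaScript.

import math
from pathlib import Path

import httpx

BASE = "http://localhost:8000/api/v1/upload"  # assumed host/port
CHUNK_SIZE = 5 * 1024 * 1024  # matches the server's advertised chunk size

def upload_pdf(path: Path, destination: str = "klausur") -> dict:
    size = path.stat().st_size
    total_chunks = math.ceil(size / CHUNK_SIZE)
    with httpx.Client() as client:
        init = client.post(f"{BASE}/init", json={
            "filename": path.name,
            "filesize": size,
            "chunks": total_chunks,
            "destination": destination,
        }).json()
        upload_id = init["upload_id"]
        with path.open("rb") as fh:
            for i in range(total_chunks):
                data = fh.read(CHUNK_SIZE)
                client.post(
                    f"{BASE}/chunk",
                    data={"upload_id": upload_id, "chunk_index": str(i)},
                    files={"chunk": (path.name, data, "application/pdf")},
                )
        return client.post(f"{BASE}/finalize", data={"upload_id": upload_id}).json()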
292
klausur-service/backend/upload/mobile.py
Normal file
@@ -0,0 +1,292 @@
|
||||
"""
|
||||
Mobile Upload HTML Page — serves the mobile upload UI directly from klausur-service.
|
||||
|
||||
Extracted from upload_api.py for modularity.
|
||||
|
||||
DSGVO-konform: Data stays local in WLAN, no external transmission.
|
||||
"""
|
||||
|
||||
from fastapi import APIRouter
|
||||
from fastapi.responses import HTMLResponse
|
||||
|
||||
router = APIRouter(prefix="/api/v1/upload", tags=["Mobile Upload"])
|
||||
|
||||
|
||||
@router.get("/mobile", response_class=HTMLResponse)
|
||||
async def mobile_upload_page():
|
||||
"""
|
||||
Serve the mobile upload page directly from the klausur-service.
|
||||
This allows mobile devices to upload without needing the Next.js website.
|
||||
"""
|
||||
html_content = '''<!DOCTYPE html>
|
||||
<html lang="de">
|
||||
<head>
|
||||
<meta charset="UTF-8">
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1.0, maximum-scale=1.0, user-scalable=no">
|
||||
<meta name="apple-mobile-web-app-capable" content="yes">
|
||||
<title>BreakPilot Upload</title>
|
||||
<style>
|
||||
* { margin: 0; padding: 0; box-sizing: border-box; }
|
||||
body {
|
||||
font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, sans-serif;
|
||||
background: linear-gradient(135deg, #1e293b 0%, #0f172a 100%);
|
||||
color: white;
|
||||
min-height: 100vh;
|
||||
padding: 16px;
|
||||
}
|
||||
.header {
|
||||
display: flex;
|
||||
justify-content: space-between;
|
||||
align-items: center;
|
||||
padding: 16px;
|
||||
border-bottom: 1px solid #334155;
|
||||
margin-bottom: 24px;
|
||||
}
|
||||
.header h1 { font-size: 20px; color: #60a5fa; }
|
||||
.badge { font-size: 10px; background: #1e293b; padding: 4px 8px; border-radius: 4px; color: #94a3b8; }
|
||||
.destination-selector {
|
||||
display: flex;
|
||||
gap: 8px;
|
||||
margin-bottom: 24px;
|
||||
}
|
||||
.dest-btn {
|
||||
flex: 1;
|
||||
padding: 14px;
|
||||
border: none;
|
||||
border-radius: 10px;
|
||||
font-size: 14px;
|
||||
font-weight: 600;
|
||||
cursor: pointer;
|
||||
transition: all 0.2s;
|
||||
}
|
||||
.dest-btn.active-klausur { background: #2563eb; color: white; box-shadow: 0 4px 15px rgba(37, 99, 235, 0.3); }
|
||||
.dest-btn.active-rag { background: #7c3aed; color: white; box-shadow: 0 4px 15px rgba(124, 58, 237, 0.3); }
|
||||
.dest-btn:not(.active-klausur):not(.active-rag) { background: #1e293b; color: #94a3b8; }
|
||||
.upload-zone {
|
||||
border: 2px dashed #475569;
|
||||
border-radius: 16px;
|
||||
padding: 40px 20px;
|
||||
text-align: center;
|
||||
margin-bottom: 24px;
|
||||
transition: all 0.2s;
|
||||
position: relative;
|
||||
}
|
||||
.upload-zone.dragover { border-color: #60a5fa; background: rgba(96, 165, 250, 0.1); transform: scale(1.02); }
|
||||
.upload-zone input[type="file"] {
|
||||
position: absolute;
|
||||
inset: 0;
|
||||
opacity: 0;
|
||||
cursor: pointer;
|
||||
}
|
||||
.upload-icon {
|
||||
width: 64px;
|
||||
height: 64px;
|
||||
background: #334155;
|
||||
border-radius: 50%;
|
||||
display: flex;
|
||||
align-items: center;
|
||||
justify-content: center;
|
||||
margin: 0 auto 16px;
|
||||
font-size: 28px;
|
||||
}
|
||||
.upload-title { font-size: 18px; margin-bottom: 8px; }
|
||||
.upload-subtitle { font-size: 14px; color: #94a3b8; margin-bottom: 16px; }
|
||||
.upload-hint { font-size: 12px; color: #64748b; }
|
||||
.file-list { margin-bottom: 24px; }
|
||||
.file-item {
|
||||
background: #1e293b;
|
||||
border-radius: 12px;
|
||||
padding: 16px;
|
||||
margin-bottom: 12px;
|
||||
}
|
||||
.file-item.error { border: 2px solid rgba(239, 68, 68, 0.5); }
|
||||
.file-item.complete { border: 2px solid rgba(34, 197, 94, 0.3); }
|
||||
.file-header { display: flex; justify-content: space-between; align-items: flex-start; margin-bottom: 8px; }
|
||||
.file-name { font-weight: 500; word-break: break-all; }
|
||||
.file-size { font-size: 14px; color: #94a3b8; }
|
||||
.remove-btn { background: none; border: none; color: #94a3b8; font-size: 20px; cursor: pointer; padding: 4px; }
|
||||
.progress-bar { height: 6px; background: #334155; border-radius: 3px; overflow: hidden; margin-top: 12px; }
|
||||
.progress-fill { height: 100%; background: linear-gradient(90deg, #3b82f6, #60a5fa); transition: width 0.3s; }
|
||||
.progress-text { font-size: 12px; color: #94a3b8; margin-top: 4px; }
|
||||
.status-complete { display: flex; align-items: center; gap: 8px; color: #22c55e; font-size: 14px; margin-top: 12px; }
|
||||
.status-error { display: flex; align-items: center; gap: 8px; color: #ef4444; font-size: 14px; margin-top: 12px; }
|
||||
.info-box {
|
||||
background: rgba(30, 41, 59, 0.5);
|
||||
border-radius: 12px;
|
||||
padding: 16px;
|
||||
font-size: 14px;
|
||||
color: #94a3b8;
|
||||
}
|
||||
.info-box h3 { color: #cbd5e1; margin-bottom: 8px; font-size: 14px; }
|
||||
.info-box ul { padding-left: 20px; }
|
||||
.info-box li { margin-bottom: 4px; }
|
||||
.server-info { text-align: center; font-size: 12px; color: #64748b; margin-top: 16px; }
|
||||
.stats { display: flex; justify-content: space-between; font-size: 14px; color: #94a3b8; padding: 0 8px; margin-bottom: 12px; }
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
<header class="header">
|
||||
<h1>BreakPilot Upload</h1>
|
||||
<span class="badge">DSGVO-konform</span>
|
||||
</header>
|
||||
|
||||
<div class="destination-selector">
|
||||
<button class="dest-btn active-klausur" id="btn-klausur" onclick="setDestination('klausur')">Klausuren</button>
|
||||
<button class="dest-btn" id="btn-rag" onclick="setDestination('rag')">Erwartungshorizonte</button>
|
||||
</div>
|
||||
|
||||
<div class="upload-zone" id="upload-zone">
|
||||
<input type="file" accept=".pdf" multiple onchange="handleFiles(this.files)">
|
||||
<div class="upload-icon">☁</div>
|
||||
<div class="upload-title">PDF-Dateien hochladen</div>
|
||||
<div class="upload-subtitle">Tippen zum Auswaehlen oder hierher ziehen</div>
|
||||
<div class="upload-hint">Grosse Dateien bis 200 MB werden automatisch in Teilen hochgeladen</div>
|
||||
</div>
|
||||
|
||||
<div class="stats" id="stats" style="display: none;">
|
||||
<span id="completed-count">0 von 0 fertig</span>
|
||||
<span id="total-size">0 B gesamt</span>
|
||||
</div>
|
||||
|
||||
<div class="file-list" id="file-list"></div>
|
||||
|
||||
<div class="info-box">
|
||||
<h3>Hinweise:</h3>
|
||||
<ul>
|
||||
<li>Die Dateien werden lokal im WLAN uebertragen</li>
|
||||
<li>Keine Daten werden ins Internet gesendet</li>
|
||||
<li>Unterstuetzte Formate: PDF</li>
|
||||
</ul>
|
||||
</div>
|
||||
|
||||
<div class="server-info" id="server-info">Server: wird ermittelt...</div>
|
||||
|
||||
<script>
|
||||
const CHUNK_SIZE = 5 * 1024 * 1024;
|
||||
let destination = 'klausur';
|
||||
let files = [];
|
||||
const serverUrl = window.location.origin;
|
||||
document.getElementById('server-info').textContent = 'Server: ' + serverUrl;
|
||||
|
||||
function setDestination(dest) {
|
||||
destination = dest;
|
||||
document.querySelectorAll('.dest-btn').forEach(btn => {
|
||||
btn.classList.remove('active-klausur', 'active-rag');
|
||||
});
|
||||
if (dest === 'klausur') {
|
||||
document.getElementById('btn-klausur').classList.add('active-klausur');
|
||||
} else {
|
||||
document.getElementById('btn-rag').classList.add('active-rag');
|
||||
}
|
||||
}
|
||||
|
||||
function formatSize(bytes) {
|
||||
if (bytes === 0) return '0 B';
|
||||
const k = 1024;
|
||||
const sizes = ['B', 'KB', 'MB', 'GB'];
|
||||
const i = Math.floor(Math.log(bytes) / Math.log(k));
|
||||
return parseFloat((bytes / Math.pow(k, i)).toFixed(1)) + ' ' + sizes[i];
|
||||
}
|
||||
|
||||
function updateStats() {
|
||||
const completed = files.filter(f => f.status === 'complete').length;
|
||||
const total = files.reduce((sum, f) => sum + f.size, 0);
|
||||
document.getElementById('completed-count').textContent = completed + ' von ' + files.length + ' fertig';
|
||||
document.getElementById('total-size').textContent = formatSize(total) + ' gesamt';
|
||||
document.getElementById('stats').style.display = files.length > 0 ? 'flex' : 'none';
|
||||
}
|
||||
|
||||
function renderFiles() {
|
||||
const list = document.getElementById('file-list');
|
||||
list.innerHTML = files.map(f => {
|
||||
let statusHtml = '';
|
||||
if (f.status === 'uploading' || f.status === 'pending') {
|
||||
statusHtml = '<div class="progress-bar"><div class="progress-fill" style="width: ' + f.progress + '%"></div></div><div class="progress-text">' + f.progress + '% hochgeladen</div>';
|
||||
} else if (f.status === 'complete') {
|
||||
statusHtml = '<div class="status-complete">✓ Erfolgreich hochgeladen</div>';
|
||||
} else if (f.status === 'error') {
|
||||
statusHtml = '<div class="status-error">⚠ ' + (f.error || 'Fehler beim Hochladen') + '</div>';
|
||||
}
|
||||
return '<div class="file-item ' + f.status + '"><div class="file-header"><div><div class="file-name">' + f.name + '</div><div class="file-size">' + formatSize(f.size) + '</div></div><button class="remove-btn" onclick="removeFile(\\'' + f.id + '\\')">×</button></div>' + statusHtml + '</div>';
|
||||
}).join('');
|
||||
updateStats();
|
||||
}
|
||||
|
||||
function removeFile(id) {
|
||||
files = files.filter(f => f.id !== id);
|
||||
renderFiles();
|
||||
}
|
||||
|
||||
async function uploadFile(file, fileId) {
|
||||
const updateProgress = (progress) => {
|
||||
const f = files.find(f => f.id === fileId);
|
||||
if (f) { f.progress = progress; renderFiles(); }
|
||||
};
|
||||
const setStatus = (status, error) => {
|
||||
const f = files.find(f => f.id === fileId);
|
||||
if (f) { f.status = status; if (error) f.error = error; renderFiles(); }
|
||||
};
|
||||
|
||||
try {
|
||||
setStatus('uploading');
|
||||
if (file.size > 10 * 1024 * 1024) {
|
||||
// Chunked upload
|
||||
const totalChunks = Math.ceil(file.size / CHUNK_SIZE);
|
||||
const initRes = await fetch(serverUrl + '/api/v1/upload/init', {
|
||||
method: 'POST',
|
||||
headers: { 'Content-Type': 'application/json' },
|
||||
body: JSON.stringify({ filename: file.name, filesize: file.size, chunks: totalChunks, destination: destination })
|
||||
});
|
||||
if (!initRes.ok) throw new Error('Konnte Upload nicht starten');
|
||||
const { upload_id } = await initRes.json();
|
||||
|
||||
for (let i = 0; i < totalChunks; i++) {
|
||||
const start = i * CHUNK_SIZE;
|
||||
const end = Math.min(start + CHUNK_SIZE, file.size);
|
||||
const chunk = file.slice(start, end);
|
||||
const formData = new FormData();
|
||||
formData.append('chunk', chunk);
|
||||
formData.append('upload_id', upload_id);
|
||||
formData.append('chunk_index', i.toString());
|
||||
const chunkRes = await fetch(serverUrl + '/api/v1/upload/chunk', { method: 'POST', body: formData });
|
||||
if (!chunkRes.ok) throw new Error('Fehler bei Teil ' + (i + 1));
|
||||
updateProgress(Math.round(((i + 1) / totalChunks) * 100));
|
||||
}
|
||||
const finalizeForm = new FormData();
|
||||
finalizeForm.append('upload_id', upload_id);
|
||||
const finalRes = await fetch(serverUrl + '/api/v1/upload/finalize', { method: 'POST', body: finalizeForm });
|
||||
if (!finalRes.ok) throw new Error('Fehler beim Abschliessen');
|
||||
} else {
|
||||
// Simple upload
|
||||
const formData = new FormData();
|
||||
formData.append('file', file);
|
||||
formData.append('destination', destination);
|
||||
const res = await fetch(serverUrl + '/api/v1/upload/simple', { method: 'POST', body: formData });
|
||||
if (!res.ok) throw new Error('Upload fehlgeschlagen');
|
||||
updateProgress(100);
|
||||
}
|
||||
setStatus('complete');
|
||||
} catch (e) {
|
||||
setStatus('error', e.message);
|
||||
}
|
||||
}
|
||||
|
||||
function handleFiles(fileList) {
|
||||
const newFiles = Array.from(fileList).filter(f => f.type === 'application/pdf');
|
||||
newFiles.forEach(file => {
|
||||
const id = Math.random().toString(36).substr(2, 9);
|
||||
files.push({ id, name: file.name, size: file.size, progress: 0, status: 'pending', file });
|
||||
renderFiles();
|
||||
uploadFile(file, id);
|
||||
});
|
||||
}
|
||||
|
||||
// Drag & Drop
|
||||
const zone = document.getElementById('upload-zone');
|
||||
zone.addEventListener('dragover', e => { e.preventDefault(); zone.classList.add('dragover'); });
|
||||
zone.addEventListener('dragleave', e => { e.preventDefault(); zone.classList.remove('dragover'); });
|
||||
zone.addEventListener('drop', e => { e.preventDefault(); zone.classList.remove('dragover'); handleFiles(e.dataTransfer.files); });
|
||||
</script>
|
||||
</body>
|
||||
</html>'''
|
||||
return HTMLResponse(content=html_content)
|
||||
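For rough scale: with the shared 5 MB chunk size, the 200 MB upper bound
mentioned on the page corresponds to ceil(200 / 5) = 40 chunk requests per
file, while files of 10 MB or less skip chunking and go through /simple in a
single request.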
@@ -1,29 +1,4 @@
|
||||
"""
|
||||
Mobile Upload API — barrel re-export.
|
||||
|
||||
All implementation split into:
|
||||
upload_api_chunked — chunked upload endpoints (init, chunk, finalize, simple, status, cancel, list)
|
||||
upload_api_mobile — mobile HTML upload page
|
||||
|
||||
DSGVO-konform: Data stays local in WLAN, no external transmission.
|
||||
"""
|
||||
|
||||
from fastapi import APIRouter
|
||||
|
||||
from upload_api_chunked import ( # noqa: F401
|
||||
router as _chunked_router,
|
||||
UPLOAD_DIR,
|
||||
CHUNK_DIR,
|
||||
EH_UPLOAD_DIR,
|
||||
_upload_sessions,
|
||||
InitUploadRequest,
|
||||
InitUploadResponse,
|
||||
ChunkUploadResponse,
|
||||
FinalizeResponse,
|
||||
)
|
||||
from upload_api_mobile import router as _mobile_router # noqa: F401
|
||||
|
||||
# Composite router that includes both sub-routers
|
||||
router = APIRouter()
|
||||
router.include_router(_chunked_router)
|
||||
router.include_router(_mobile_router)
|
||||
# Backward-compat shim -- module moved to upload/api.py
|
||||
import importlib as _importlib
|
||||
import sys as _sys
|
||||
_sys.modules[__name__] = _importlib.import_module("upload.api")
|
||||
|
||||
@@ -1,320 +1,4 @@
|
||||
"""
|
||||
Chunked Upload API — init, chunk, finalize, simple upload, status, cancel, list.
|
||||
|
||||
Extracted from upload_api.py for modularity.
|
||||
|
||||
DSGVO-konform: Data stays local in WLAN, no external transmission.
|
||||
"""
|
||||
|
||||
import os
|
||||
import uuid
|
||||
import shutil
|
||||
import hashlib
|
||||
from pathlib import Path
|
||||
from datetime import datetime, timezone
|
||||
from typing import Dict, Optional
|
||||
|
||||
from fastapi import APIRouter, HTTPException, UploadFile, File, Form
|
||||
from pydantic import BaseModel
|
||||
|
||||
# Configuration
|
||||
UPLOAD_DIR = Path(os.getenv("UPLOAD_DIR", "/app/uploads"))
|
||||
CHUNK_DIR = Path(os.getenv("CHUNK_DIR", "/app/chunks"))
|
||||
EH_UPLOAD_DIR = Path(os.getenv("EH_UPLOAD_DIR", "/app/eh-uploads"))
|
||||
|
||||
# Ensure directories exist
|
||||
UPLOAD_DIR.mkdir(parents=True, exist_ok=True)
|
||||
CHUNK_DIR.mkdir(parents=True, exist_ok=True)
|
||||
EH_UPLOAD_DIR.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# In-memory storage for upload sessions (for simplicity)
|
||||
# In production, use Redis or database
|
||||
_upload_sessions: Dict[str, dict] = {}
|
||||
|
||||
router = APIRouter(prefix="/api/v1/upload", tags=["Mobile Upload"])
|
||||
|
||||
|
||||
class InitUploadRequest(BaseModel):
|
||||
filename: str
|
||||
filesize: int
|
||||
chunks: int
|
||||
destination: str = "klausur" # "klausur" or "rag"
|
||||
|
||||
|
||||
class InitUploadResponse(BaseModel):
|
||||
upload_id: str
|
||||
chunk_size: int
|
||||
total_chunks: int
|
||||
message: str
|
||||
|
||||
|
||||
class ChunkUploadResponse(BaseModel):
|
||||
upload_id: str
|
||||
chunk_index: int
|
||||
received: bool
|
||||
chunks_received: int
|
||||
total_chunks: int
|
||||
|
||||
|
||||
class FinalizeResponse(BaseModel):
|
||||
upload_id: str
|
||||
filename: str
|
||||
filepath: str
|
||||
filesize: int
|
||||
checksum: str
|
||||
message: str
|
||||
|
||||
|
||||
@router.post("/init", response_model=InitUploadResponse)
|
||||
async def init_upload(request: InitUploadRequest):
|
||||
"""
|
||||
Initialize a chunked upload session.
|
||||
|
||||
Returns an upload_id that must be used for subsequent chunk uploads.
|
||||
"""
|
||||
upload_id = str(uuid.uuid4())
|
||||
|
||||
# Create session directory
|
||||
session_dir = CHUNK_DIR / upload_id
|
||||
session_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Store session info
|
||||
_upload_sessions[upload_id] = {
|
||||
"filename": request.filename,
|
||||
"filesize": request.filesize,
|
||||
"total_chunks": request.chunks,
|
||||
"received_chunks": set(),
|
||||
"destination": request.destination,
|
||||
"session_dir": str(session_dir),
|
||||
"created_at": datetime.now(timezone.utc).isoformat(),
|
||||
}
|
||||
|
||||
return InitUploadResponse(
|
||||
upload_id=upload_id,
|
||||
chunk_size=5 * 1024 * 1024, # 5 MB
|
||||
total_chunks=request.chunks,
|
||||
message="Upload-Session erstellt"
|
||||
)
|
||||
|
||||
|
||||
@router.post("/chunk", response_model=ChunkUploadResponse)
|
||||
async def upload_chunk(
|
||||
chunk: UploadFile = File(...),
|
||||
upload_id: str = Form(...),
|
||||
chunk_index: int = Form(...)
|
||||
):
|
||||
"""
|
||||
Upload a single chunk of a file.
|
||||
|
||||
Chunks are stored temporarily until finalize is called.
|
||||
"""
|
||||
if upload_id not in _upload_sessions:
|
||||
raise HTTPException(status_code=404, detail="Upload-Session nicht gefunden")
|
||||
|
||||
session = _upload_sessions[upload_id]
|
||||
|
||||
if chunk_index < 0 or chunk_index >= session["total_chunks"]:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Ungueltiger Chunk-Index: {chunk_index}"
|
||||
)
|
||||
|
||||
# Save chunk
|
||||
chunk_path = Path(session["session_dir"]) / f"chunk_{chunk_index:05d}"
|
||||
|
||||
with open(chunk_path, "wb") as f:
|
||||
content = await chunk.read()
|
||||
f.write(content)
|
||||
|
||||
# Track received chunks
|
||||
session["received_chunks"].add(chunk_index)
|
||||
|
||||
return ChunkUploadResponse(
|
||||
upload_id=upload_id,
|
||||
chunk_index=chunk_index,
|
||||
received=True,
|
||||
chunks_received=len(session["received_chunks"]),
|
||||
total_chunks=session["total_chunks"]
|
||||
)
|
||||
|
||||
|
||||
@router.post("/finalize", response_model=FinalizeResponse)
|
||||
async def finalize_upload(upload_id: str = Form(...)):
|
||||
"""
|
||||
Finalize the upload by combining all chunks into a single file.
|
||||
|
||||
Validates that all chunks were received and calculates checksum.
|
||||
"""
|
||||
if upload_id not in _upload_sessions:
|
||||
raise HTTPException(status_code=404, detail="Upload-Session nicht gefunden")
|
||||
|
||||
session = _upload_sessions[upload_id]
|
||||
|
||||
# Check if all chunks received
|
||||
if len(session["received_chunks"]) != session["total_chunks"]:
|
||||
missing = session["total_chunks"] - len(session["received_chunks"])
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Nicht alle Chunks empfangen. Fehlend: {missing}"
|
||||
)
|
||||
|
||||
# Determine destination directory
|
||||
if session["destination"] == "rag":
|
||||
dest_dir = EH_UPLOAD_DIR
|
||||
else:
|
||||
dest_dir = UPLOAD_DIR
|
||||
|
||||
# Generate unique filename
|
||||
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||
safe_filename = session["filename"].replace(" ", "_")
|
||||
final_filename = f"{timestamp}_{safe_filename}"
|
||||
final_path = dest_dir / final_filename
|
||||
|
||||
# Combine chunks
|
||||
hasher = hashlib.sha256()
|
||||
total_size = 0
|
||||
|
||||
with open(final_path, "wb") as outfile:
|
||||
for i in range(session["total_chunks"]):
|
||||
chunk_path = Path(session["session_dir"]) / f"chunk_{i:05d}"
|
||||
|
||||
if not chunk_path.exists():
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"Chunk {i} nicht gefunden"
|
||||
)
|
||||
|
||||
with open(chunk_path, "rb") as infile:
|
||||
data = infile.read()
|
||||
outfile.write(data)
|
||||
hasher.update(data)
|
||||
total_size += len(data)
|
||||
|
||||
# Clean up chunks
|
||||
shutil.rmtree(session["session_dir"], ignore_errors=True)
|
||||
del _upload_sessions[upload_id]
|
||||
|
||||
checksum = hasher.hexdigest()
|
||||
|
||||
return FinalizeResponse(
|
||||
upload_id=upload_id,
|
||||
filename=final_filename,
|
||||
filepath=str(final_path),
|
||||
filesize=total_size,
|
||||
checksum=checksum,
|
||||
message="Upload erfolgreich abgeschlossen"
|
||||
)
|
||||
|
||||
|
||||
@router.post("/simple")
|
||||
async def simple_upload(
|
||||
file: UploadFile = File(...),
|
||||
destination: str = Form("klausur")
|
||||
):
|
||||
"""
|
||||
Simple single-request upload for smaller files (<10MB).
|
||||
|
||||
For larger files, use the chunked upload endpoints.
|
||||
"""
|
||||
# Determine destination directory
|
||||
if destination == "rag":
|
||||
dest_dir = EH_UPLOAD_DIR
|
||||
else:
|
||||
dest_dir = UPLOAD_DIR
|
||||
|
||||
# Generate unique filename
|
||||
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||
safe_filename = file.filename.replace(" ", "_") if file.filename else "upload.pdf"
|
||||
final_filename = f"{timestamp}_{safe_filename}"
|
||||
final_path = dest_dir / final_filename
|
||||
|
||||
# Calculate checksum while writing
|
||||
hasher = hashlib.sha256()
|
||||
total_size = 0
|
||||
|
||||
with open(final_path, "wb") as f:
|
||||
while True:
|
||||
chunk = await file.read(1024 * 1024) # Read 1MB at a time
|
||||
if not chunk:
|
||||
break
|
||||
f.write(chunk)
|
||||
hasher.update(chunk)
|
||||
total_size += len(chunk)
|
||||
|
||||
return {
|
||||
"filename": final_filename,
|
||||
"filepath": str(final_path),
|
||||
"filesize": total_size,
|
||||
"checksum": hasher.hexdigest(),
|
||||
"message": "Upload erfolgreich"
|
||||
}
|
||||
|
||||
|
||||
@router.get("/status/{upload_id}")
|
||||
async def get_upload_status(upload_id: str):
|
||||
"""
|
||||
Get the status of an ongoing upload.
|
||||
"""
|
||||
if upload_id not in _upload_sessions:
|
||||
raise HTTPException(status_code=404, detail="Upload-Session nicht gefunden")
|
||||
|
||||
session = _upload_sessions[upload_id]
|
||||
|
||||
return {
|
||||
"upload_id": upload_id,
|
||||
"filename": session["filename"],
|
||||
"total_chunks": session["total_chunks"],
|
||||
"received_chunks": len(session["received_chunks"]),
|
||||
"progress_percent": round(
|
||||
len(session["received_chunks"]) / session["total_chunks"] * 100, 1
|
||||
),
|
||||
"destination": session["destination"],
|
||||
"created_at": session["created_at"]
|
||||
}
|
||||
|
||||
|
||||
@router.delete("/cancel/{upload_id}")
|
||||
async def cancel_upload(upload_id: str):
|
||||
"""
|
||||
Cancel an ongoing upload and clean up temporary files.
|
||||
"""
|
||||
if upload_id not in _upload_sessions:
|
||||
raise HTTPException(status_code=404, detail="Upload-Session nicht gefunden")
|
||||
|
||||
session = _upload_sessions[upload_id]
|
||||
|
||||
# Clean up chunks
|
||||
shutil.rmtree(session["session_dir"], ignore_errors=True)
|
||||
del _upload_sessions[upload_id]
|
||||
|
||||
return {"message": "Upload abgebrochen", "upload_id": upload_id}
|
||||
|
||||
|
||||
@router.get("/list")
|
||||
async def list_uploads(destination: str = "klausur"):
|
||||
"""
|
||||
List all uploaded files in the specified destination.
|
||||
"""
|
||||
if destination == "rag":
|
||||
dest_dir = EH_UPLOAD_DIR
|
||||
else:
|
||||
dest_dir = UPLOAD_DIR
|
||||
|
||||
files = []
|
||||
|
||||
for f in dest_dir.iterdir():
|
||||
if f.is_file() and f.suffix.lower() == ".pdf":
|
||||
stat = f.stat()
|
||||
files.append({
|
||||
"filename": f.name,
|
||||
"size": stat.st_size,
|
||||
"modified": datetime.fromtimestamp(stat.st_mtime).isoformat(),
|
||||
})
|
||||
|
||||
files.sort(key=lambda x: x["modified"], reverse=True)
|
||||
|
||||
return {
|
||||
"destination": destination,
|
||||
"count": len(files),
|
||||
"files": files[:50] # Limit to 50 most recent
|
||||
}
|
||||
# Backward-compat shim -- module moved to upload/chunked.py
|
||||
import importlib as _importlib
|
||||
import sys as _sys
|
||||
_sys.modules[__name__] = _importlib.import_module("upload.chunked")
|
||||
|
||||
@@ -1,292 +1,4 @@
|
||||
"""
|
||||
Mobile Upload HTML Page — serves the mobile upload UI directly from klausur-service.
|
||||
|
||||
Extracted from upload_api.py for modularity.
|
||||
|
||||
DSGVO-konform: Data stays local in WLAN, no external transmission.
|
||||
"""
|
||||
|
||||
from fastapi import APIRouter
|
||||
from fastapi.responses import HTMLResponse
|
||||
|
||||
router = APIRouter(prefix="/api/v1/upload", tags=["Mobile Upload"])
|
||||
|
||||
|
||||
@router.get("/mobile", response_class=HTMLResponse)
|
||||
async def mobile_upload_page():
|
||||
"""
|
||||
Serve the mobile upload page directly from the klausur-service.
|
||||
This allows mobile devices to upload without needing the Next.js website.
|
||||
"""
|
||||
html_content = '''<!DOCTYPE html>
|
||||
<html lang="de">
|
||||
<head>
|
||||
<meta charset="UTF-8">
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1.0, maximum-scale=1.0, user-scalable=no">
|
||||
<meta name="apple-mobile-web-app-capable" content="yes">
|
||||
<title>BreakPilot Upload</title>
|
||||
<style>
|
||||
* { margin: 0; padding: 0; box-sizing: border-box; }
|
||||
body {
|
||||
font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, sans-serif;
|
||||
background: linear-gradient(135deg, #1e293b 0%, #0f172a 100%);
|
||||
color: white;
|
||||
min-height: 100vh;
|
||||
padding: 16px;
|
||||
}
|
||||
.header {
|
||||
display: flex;
|
||||
justify-content: space-between;
|
||||
align-items: center;
|
||||
padding: 16px;
|
||||
border-bottom: 1px solid #334155;
|
||||
margin-bottom: 24px;
|
||||
}
|
||||
.header h1 { font-size: 20px; color: #60a5fa; }
|
||||
.badge { font-size: 10px; background: #1e293b; padding: 4px 8px; border-radius: 4px; color: #94a3b8; }
|
||||
.destination-selector {
|
||||
display: flex;
|
||||
gap: 8px;
|
||||
margin-bottom: 24px;
|
||||
}
|
||||
.dest-btn {
|
||||
flex: 1;
|
||||
padding: 14px;
|
||||
border: none;
|
||||
border-radius: 10px;
|
||||
font-size: 14px;
|
||||
font-weight: 600;
|
||||
cursor: pointer;
|
||||
transition: all 0.2s;
|
||||
}
|
||||
.dest-btn.active-klausur { background: #2563eb; color: white; box-shadow: 0 4px 15px rgba(37, 99, 235, 0.3); }
|
||||
.dest-btn.active-rag { background: #7c3aed; color: white; box-shadow: 0 4px 15px rgba(124, 58, 237, 0.3); }
|
||||
.dest-btn:not(.active-klausur):not(.active-rag) { background: #1e293b; color: #94a3b8; }
|
||||
.upload-zone {
|
||||
border: 2px dashed #475569;
|
||||
border-radius: 16px;
|
||||
padding: 40px 20px;
|
||||
text-align: center;
|
||||
margin-bottom: 24px;
|
||||
transition: all 0.2s;
|
||||
position: relative;
|
||||
}
|
||||
.upload-zone.dragover { border-color: #60a5fa; background: rgba(96, 165, 250, 0.1); transform: scale(1.02); }
|
||||
.upload-zone input[type="file"] {
|
||||
position: absolute;
|
||||
inset: 0;
|
||||
opacity: 0;
|
||||
cursor: pointer;
|
||||
}
|
||||
.upload-icon {
|
||||
width: 64px;
|
||||
height: 64px;
|
||||
background: #334155;
|
||||
border-radius: 50%;
|
||||
display: flex;
|
||||
align-items: center;
|
||||
justify-content: center;
|
||||
margin: 0 auto 16px;
|
||||
font-size: 28px;
|
||||
}
|
||||
.upload-title { font-size: 18px; margin-bottom: 8px; }
|
||||
.upload-subtitle { font-size: 14px; color: #94a3b8; margin-bottom: 16px; }
|
||||
.upload-hint { font-size: 12px; color: #64748b; }
|
||||
.file-list { margin-bottom: 24px; }
|
||||
.file-item {
|
||||
background: #1e293b;
|
||||
border-radius: 12px;
|
||||
padding: 16px;
|
||||
margin-bottom: 12px;
|
||||
}
|
||||
.file-item.error { border: 2px solid rgba(239, 68, 68, 0.5); }
|
||||
.file-item.complete { border: 2px solid rgba(34, 197, 94, 0.3); }
|
||||
.file-header { display: flex; justify-content: space-between; align-items: flex-start; margin-bottom: 8px; }
|
||||
.file-name { font-weight: 500; word-break: break-all; }
|
||||
.file-size { font-size: 14px; color: #94a3b8; }
|
||||
.remove-btn { background: none; border: none; color: #94a3b8; font-size: 20px; cursor: pointer; padding: 4px; }
|
||||
.progress-bar { height: 6px; background: #334155; border-radius: 3px; overflow: hidden; margin-top: 12px; }
|
||||
.progress-fill { height: 100%; background: linear-gradient(90deg, #3b82f6, #60a5fa); transition: width 0.3s; }
|
||||
.progress-text { font-size: 12px; color: #94a3b8; margin-top: 4px; }
|
||||
.status-complete { display: flex; align-items: center; gap: 8px; color: #22c55e; font-size: 14px; margin-top: 12px; }
|
||||
.status-error { display: flex; align-items: center; gap: 8px; color: #ef4444; font-size: 14px; margin-top: 12px; }
|
||||
.info-box {
|
||||
background: rgba(30, 41, 59, 0.5);
|
||||
border-radius: 12px;
|
||||
padding: 16px;
|
||||
font-size: 14px;
|
||||
color: #94a3b8;
|
||||
}
|
||||
.info-box h3 { color: #cbd5e1; margin-bottom: 8px; font-size: 14px; }
|
||||
.info-box ul { padding-left: 20px; }
|
||||
.info-box li { margin-bottom: 4px; }
|
||||
.server-info { text-align: center; font-size: 12px; color: #64748b; margin-top: 16px; }
|
||||
.stats { display: flex; justify-content: space-between; font-size: 14px; color: #94a3b8; padding: 0 8px; margin-bottom: 12px; }
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
<header class="header">
|
||||
<h1>BreakPilot Upload</h1>
|
||||
<span class="badge">DSGVO-konform</span>
|
||||
</header>
|
||||
|
||||
<div class="destination-selector">
|
||||
<button class="dest-btn active-klausur" id="btn-klausur" onclick="setDestination('klausur')">Klausuren</button>
|
||||
<button class="dest-btn" id="btn-rag" onclick="setDestination('rag')">Erwartungshorizonte</button>
|
||||
</div>
|
||||
|
||||
<div class="upload-zone" id="upload-zone">
|
||||
<input type="file" accept=".pdf" multiple onchange="handleFiles(this.files)">
|
||||
<div class="upload-icon">☁</div>
|
||||
<div class="upload-title">PDF-Dateien hochladen</div>
|
||||
<div class="upload-subtitle">Tippen zum Auswaehlen oder hierher ziehen</div>
|
||||
<div class="upload-hint">Grosse Dateien bis 200 MB werden automatisch in Teilen hochgeladen</div>
|
||||
</div>
|
||||
|
||||
<div class="stats" id="stats" style="display: none;">
|
||||
<span id="completed-count">0 von 0 fertig</span>
|
||||
<span id="total-size">0 B gesamt</span>
|
||||
</div>
|
||||
|
||||
<div class="file-list" id="file-list"></div>
|
||||
|
||||
<div class="info-box">
|
||||
<h3>Hinweise:</h3>
|
||||
<ul>
|
||||
<li>Die Dateien werden lokal im WLAN uebertragen</li>
|
||||
<li>Keine Daten werden ins Internet gesendet</li>
|
||||
<li>Unterstuetzte Formate: PDF</li>
|
||||
</ul>
|
||||
</div>
|
||||
|
||||
<div class="server-info" id="server-info">Server: wird ermittelt...</div>
|
||||
|
||||
<script>
|
||||
const CHUNK_SIZE = 5 * 1024 * 1024;
|
||||
let destination = 'klausur';
|
||||
let files = [];
|
||||
const serverUrl = window.location.origin;
|
||||
document.getElementById('server-info').textContent = 'Server: ' + serverUrl;
|
||||
|
||||
function setDestination(dest) {
|
||||
destination = dest;
|
||||
document.querySelectorAll('.dest-btn').forEach(btn => {
|
||||
btn.classList.remove('active-klausur', 'active-rag');
|
||||
});
|
||||
if (dest === 'klausur') {
|
||||
document.getElementById('btn-klausur').classList.add('active-klausur');
|
||||
} else {
|
||||
document.getElementById('btn-rag').classList.add('active-rag');
|
||||
}
|
||||
}
|
||||
|
||||
function formatSize(bytes) {
|
||||
if (bytes === 0) return '0 B';
|
||||
const k = 1024;
|
||||
const sizes = ['B', 'KB', 'MB', 'GB'];
|
||||
const i = Math.floor(Math.log(bytes) / Math.log(k));
|
||||
return parseFloat((bytes / Math.pow(k, i)).toFixed(1)) + ' ' + sizes[i];
|
||||
}
|
||||
|
||||
function updateStats() {
|
||||
const completed = files.filter(f => f.status === 'complete').length;
|
||||
const total = files.reduce((sum, f) => sum + f.size, 0);
|
||||
document.getElementById('completed-count').textContent = completed + ' von ' + files.length + ' fertig';
|
||||
document.getElementById('total-size').textContent = formatSize(total) + ' gesamt';
|
||||
document.getElementById('stats').style.display = files.length > 0 ? 'flex' : 'none';
|
||||
}
|
||||
|
||||
function renderFiles() {
|
||||
const list = document.getElementById('file-list');
|
||||
list.innerHTML = files.map(f => {
|
||||
let statusHtml = '';
|
||||
if (f.status === 'uploading' || f.status === 'pending') {
|
||||
statusHtml = '<div class="progress-bar"><div class="progress-fill" style="width: ' + f.progress + '%"></div></div><div class="progress-text">' + f.progress + '% hochgeladen</div>';
|
||||
} else if (f.status === 'complete') {
|
||||
statusHtml = '<div class="status-complete">✓ Erfolgreich hochgeladen</div>';
|
||||
} else if (f.status === 'error') {
|
||||
statusHtml = '<div class="status-error">⚠ ' + (f.error || 'Fehler beim Hochladen') + '</div>';
|
||||
}
|
||||
return '<div class="file-item ' + f.status + '"><div class="file-header"><div><div class="file-name">' + f.name + '</div><div class="file-size">' + formatSize(f.size) + '</div></div><button class="remove-btn" onclick="removeFile(\\'' + f.id + '\\')">×</button></div>' + statusHtml + '</div>';
|
||||
}).join('');
|
||||
updateStats();
|
||||
}
|
||||
|
||||
function removeFile(id) {
|
||||
files = files.filter(f => f.id !== id);
|
||||
renderFiles();
|
||||
}
|
||||
|
||||
async function uploadFile(file, fileId) {
|
||||
const updateProgress = (progress) => {
|
||||
const f = files.find(f => f.id === fileId);
|
||||
if (f) { f.progress = progress; renderFiles(); }
|
||||
};
|
||||
const setStatus = (status, error) => {
|
||||
const f = files.find(f => f.id === fileId);
|
||||
if (f) { f.status = status; if (error) f.error = error; renderFiles(); }
|
||||
};
|
||||
|
||||
try {
|
||||
setStatus('uploading');
|
||||
if (file.size > 10 * 1024 * 1024) {
|
||||
// Chunked upload
|
||||
const totalChunks = Math.ceil(file.size / CHUNK_SIZE);
|
||||
const initRes = await fetch(serverUrl + '/api/v1/upload/init', {
|
||||
method: 'POST',
|
||||
headers: { 'Content-Type': 'application/json' },
|
||||
body: JSON.stringify({ filename: file.name, filesize: file.size, chunks: totalChunks, destination: destination })
|
||||
});
|
||||
if (!initRes.ok) throw new Error('Konnte Upload nicht starten');
|
||||
const { upload_id } = await initRes.json();
|
||||
|
||||
for (let i = 0; i < totalChunks; i++) {
|
||||
const start = i * CHUNK_SIZE;
|
||||
const end = Math.min(start + CHUNK_SIZE, file.size);
|
||||
const chunk = file.slice(start, end);
|
||||
const formData = new FormData();
|
||||
formData.append('chunk', chunk);
|
||||
formData.append('upload_id', upload_id);
|
||||
formData.append('chunk_index', i.toString());
|
||||
const chunkRes = await fetch(serverUrl + '/api/v1/upload/chunk', { method: 'POST', body: formData });
|
||||
if (!chunkRes.ok) throw new Error('Fehler bei Teil ' + (i + 1));
|
||||
updateProgress(Math.round(((i + 1) / totalChunks) * 100));
|
||||
}
|
||||
const finalizeForm = new FormData();
|
||||
finalizeForm.append('upload_id', upload_id);
|
||||
const finalRes = await fetch(serverUrl + '/api/v1/upload/finalize', { method: 'POST', body: finalizeForm });
|
||||
if (!finalRes.ok) throw new Error('Fehler beim Abschliessen');
|
||||
} else {
|
||||
// Simple upload
|
||||
const formData = new FormData();
|
||||
formData.append('file', file);
|
||||
formData.append('destination', destination);
|
||||
const res = await fetch(serverUrl + '/api/v1/upload/simple', { method: 'POST', body: formData });
|
||||
if (!res.ok) throw new Error('Upload fehlgeschlagen');
|
||||
updateProgress(100);
|
||||
}
|
||||
setStatus('complete');
|
||||
} catch (e) {
|
||||
setStatus('error', e.message);
|
||||
}
|
||||
}
|
||||
|
||||
function handleFiles(fileList) {
|
||||
const newFiles = Array.from(fileList).filter(f => f.type === 'application/pdf');
|
||||
newFiles.forEach(file => {
|
||||
const id = Math.random().toString(36).substr(2, 9);
|
||||
files.push({ id, name: file.name, size: file.size, progress: 0, status: 'pending', file });
|
||||
renderFiles();
|
||||
uploadFile(file, id);
|
||||
});
|
||||
}
|
||||
|
||||
// Drag & Drop
|
||||
const zone = document.getElementById('upload-zone');
|
||||
zone.addEventListener('dragover', e => { e.preventDefault(); zone.classList.add('dragover'); });
|
||||
zone.addEventListener('dragleave', e => { e.preventDefault(); zone.classList.remove('dragover'); });
|
||||
zone.addEventListener('drop', e => { e.preventDefault(); zone.classList.remove('dragover'); handleFiles(e.dataTransfer.files); });
|
||||
</script>
|
||||
</body>
|
||||
</html>'''
|
||||
return HTMLResponse(content=html_content)
|
||||
# Backward-compat shim -- module moved to upload/mobile.py
|
||||
import importlib as _importlib
|
||||
import sys as _sys
|
||||
_sys.modules[__name__] = _importlib.import_module("upload.mobile")
|
||||
|
||||