Files
breakpilot-lehrer/scripts/export-trocr-onnx.py
Benjamin Admin be7f5f1872 feat: Sprint 2 — TrOCR ONNX, PP-DocLayout, Model Management
D2: TrOCR ONNX export script (printed + handwritten, int8 quantization)
D3: PP-DocLayout ONNX export script (download or Docker-based conversion)
B3: Model Management admin page (PyTorch vs ONNX status, benchmarks, config)
A4: TrOCR ONNX service with runtime routing (auto/pytorch/onnx via TROCR_BACKEND)
A5: PP-DocLayout ONNX detection with OpenCV fallback (via GRAPHIC_DETECT_BACKEND)
B4: Structure Detection UI toggle (OpenCV vs PP-DocLayout) with class color coding
C3: TrOCR-ONNX.md documentation
C4: OCR-Pipeline.md ONNX section added
C5: mkdocs.yml nav updated, optimum added to requirements.txt

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-23 09:53:02 +01:00

413 lines
14 KiB
Python
Executable File

#!/usr/bin/env python3
"""
TrOCR ONNX Export — exports TrOCR models to ONNX with int8 quantization.
Supported models:
- microsoft/trocr-base-printed
- microsoft/trocr-base-handwritten
Steps per model:
1. Load PyTorch model via optimum ORTModelForVision2Seq (export=True)
2. Save ONNX to output directory
3. Quantize to int8 via ORTQuantizer + AutoQuantizationConfig
4. Verify: compare PyTorch vs ONNX outputs (diff < 2%)
5. Report model sizes before/after quantization
Usage:
python scripts/export-trocr-onnx.py
python scripts/export-trocr-onnx.py --model printed
python scripts/export-trocr-onnx.py --model handwritten --skip-verify
python scripts/export-trocr-onnx.py --output-dir models/onnx --skip-quantize
"""
import argparse
import os
import platform
import sys
import time
from pathlib import Path
# Add backend to path for imports
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'klausur-service', 'backend'))
# ---------------------------------------------------------------------------
# Constants
# ---------------------------------------------------------------------------
# CLI short names → HuggingFace model identifiers accepted by --model.
MODELS = {
    "printed": "microsoft/trocr-base-printed",
    "handwritten": "microsoft/trocr-base-handwritten",
}
# Default export target directory, resolved relative to this script:
# <repo-root>/models/onnx (made absolute in main()).
DEFAULT_OUTPUT_DIR = os.path.join(os.path.dirname(__file__), '..', 'models', 'onnx')
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def dir_size_mb(path: str) -> float:
    """Return the combined size of every regular file below *path*, in MB."""
    byte_count = sum(
        entry.stat().st_size
        for entry in Path(path).rglob("*")
        if entry.is_file()
    )
    return byte_count / (1024 * 1024)
def log(msg: str) -> None:
    """Write a prefixed progress message to stderr, flushing immediately."""
    sys.stderr.write(f"[export-onnx] {msg}\n")
    sys.stderr.flush()
def _create_test_image():
    """Build a tiny synthetic text-line image (white background, one dark
    rectangle standing in for printed text) used by the verification step."""
    from PIL import Image
    width, height = 384, 48
    canvas = Image.new('RGB', (width, height), 'white')
    # Dark 160x28 block at (60, 10): same pixels the per-pixel loop used
    # to set, x in [60, 220) and y in [10, 38).
    block = Image.new('RGB', (220 - 60, 38 - 10), (25, 25, 25))
    canvas.paste(block, (60, 10))
    return canvas
# ---------------------------------------------------------------------------
# Export
# ---------------------------------------------------------------------------
def export_to_onnx(model_name: str, output_dir: str) -> str:
    """Convert a HuggingFace TrOCR checkpoint to ONNX via optimum.

    The model is exported with ``ORTModelForVision2Seq(export=True)`` and
    saved under ``output_dir/<short-name>``; that directory path is returned.
    """
    from optimum.onnxruntime import ORTModelForVision2Seq

    target = os.path.join(output_dir, model_name.split("/")[-1])
    log(f"Exporting {model_name} to ONNX ...")
    started = time.monotonic()
    exported = ORTModelForVision2Seq.from_pretrained(model_name, export=True)
    exported.save_pretrained(target)
    elapsed = time.monotonic() - started
    size = dir_size_mb(target)
    log(f" Exported in {elapsed:.1f}s — {size:.1f} MB on disk: {target}")
    return target
# ---------------------------------------------------------------------------
# Quantization
# ---------------------------------------------------------------------------
def quantize_onnx(onnx_path: str) -> str:
    """Write an int8 dynamically-quantized copy of an ONNX model directory.

    The quantized model lands next to the input directory with an ``-int8``
    suffix; that path is returned.
    """
    from optimum.onnxruntime import ORTQuantizer
    from optimum.onnxruntime.configuration import AutoQuantizationConfig

    quantized_path = onnx_path + "-int8"
    log(f"Quantizing to int8 → {quantized_path} ...")
    started = time.monotonic()

    # Weights-only (dynamic) quantization config, chosen per host CPU.
    # ARM hosts (e.g. Apple Silicon) lack AVX-512, so they get the dedicated
    # arm64 config when the installed optimum provides one; everything else
    # — including older optimum builds without arm64() — uses avx512_vnni,
    # which still works for weights-only quantization.
    on_arm = any(tag in platform.machine().lower() for tag in ("arm", "aarch"))
    if on_arm:
        try:
            qconfig = AutoQuantizationConfig.arm64(is_static=False, per_channel=False)
            log(" Using arm64 quantization config")
        except AttributeError:
            qconfig = AutoQuantizationConfig.avx512_vnni(is_static=False, per_channel=False)
            log(" arm64 config not available, falling back to avx512_vnni")
    else:
        qconfig = AutoQuantizationConfig.avx512_vnni(is_static=False, per_channel=False)
        log(" Using avx512_vnni quantization config")

    ORTQuantizer.from_pretrained(onnx_path).quantize(
        save_dir=quantized_path, quantization_config=qconfig
    )
    elapsed = time.monotonic() - started
    size = dir_size_mb(quantized_path)
    log(f" Quantized in {elapsed:.1f}s — {size:.1f} MB on disk")
    return quantized_path
# ---------------------------------------------------------------------------
# Verification
# ---------------------------------------------------------------------------
def verify_outputs(model_name: str, onnx_path: str) -> dict:
    """Check that the ONNX export reproduces the PyTorch model's output.

    Runs both backends on the same synthetic line image and compares the
    generated token IDs (max relative difference must stay below 2%) as
    well as the decoded texts. Returns a dict with the diff, both texts,
    and pass/fail flags.
    """
    import numpy as np
    import torch
    from optimum.onnxruntime import ORTModelForVision2Seq
    from transformers import TrOCRProcessor, VisionEncoderDecoderModel

    log(f"Verifying ONNX output against PyTorch for {model_name} ...")
    image = _create_test_image()
    processor = TrOCRProcessor.from_pretrained(model_name)

    def _decode(ids):
        # First (only) sequence in the batch, special tokens stripped.
        return processor.batch_decode(ids, skip_special_tokens=True)[0]

    # Reference run: the original PyTorch weights.
    reference = VisionEncoderDecoderModel.from_pretrained(model_name)
    reference.eval()
    pt_inputs = processor(images=image, return_tensors="pt").pixel_values
    with torch.no_grad():
        pt_ids = reference.generate(pt_inputs, max_new_tokens=50)
    pt_text = _decode(pt_ids)

    # Candidate run: the ONNX export through onnxruntime.
    candidate = ORTModelForVision2Seq.from_pretrained(onnx_path)
    ort_inputs = processor(images=image, return_tensors="pt").pixel_values
    ort_ids = candidate.generate(ort_inputs, max_new_tokens=50)
    ort_text = _decode(ort_ids)

    # Token-level comparison: zero-pad the shorter sequence to equal length,
    # then take the maximum relative difference (zeros in the denominator
    # are clamped to 1 to avoid division by zero).
    pt_arr = pt_ids[0].numpy().astype(np.float64)
    ort_arr = ort_ids[0].numpy().astype(np.float64)
    width = max(len(pt_arr), len(ort_arr))
    pt_arr = np.pad(pt_arr, (0, width - len(pt_arr)))
    ort_arr = np.pad(ort_arr, (0, width - len(ort_arr)))
    denom = np.where(np.abs(pt_arr) > 0, np.abs(pt_arr), 1.0)
    max_diff_pct = float(np.max(np.abs(pt_arr - ort_arr) / denom)) * 100.0
    exact_match = bool(np.array_equal(pt_ids[0].numpy(), ort_ids[0].numpy()))
    passed = max_diff_pct < 2.0

    status = "PASS" if passed else "FAIL"
    log(f" Verification {status}: max_diff={max_diff_pct:.4f}% exact_match={exact_match}")
    log(f" PyTorch : '{pt_text}'")
    log(f" ONNX : '{ort_text}'")
    return {
        "passed": passed,
        "exact_token_match": exact_match,
        "max_relative_diff_pct": round(max_diff_pct, 4),
        "pytorch_text": pt_text,
        "onnx_text": ort_text,
        "text_match": pt_text == ort_text,
    }
# ---------------------------------------------------------------------------
# Per-model pipeline
# ---------------------------------------------------------------------------
def process_model(
    model_name: str,
    output_dir: str,
    skip_verify: bool = False,
    skip_quantize: bool = False,
) -> dict:
    """Run the full export pipeline for one model.

    Steps: export to ONNX, optionally verify fp32 against PyTorch, optionally
    quantize to int8, and — only when quantization actually ran — optionally
    verify the int8 model.

    Returns a summary dict with paths, sizes, verification results, and an
    ``error`` entry set when a hard step (export/quantization) failed.
    Verification failures are soft: recorded in the summary, pipeline goes on.
    """
    short_name = model_name.split("/")[-1]
    summary: dict = {
        "model": model_name,
        "short_name": short_name,
        "onnx_path": None,
        "onnx_size_mb": None,
        "quantized_path": None,
        "quantized_size_mb": None,
        "size_reduction_pct": None,
        "verification_fp32": None,
        "verification_int8": None,
        "error": None,
    }
    log(f"{'='*60}")
    log(f"Processing: {model_name}")
    log(f"{'='*60}")

    # Step 1 + 2: Export to ONNX (hard failure — nothing else can run).
    try:
        onnx_path = export_to_onnx(model_name, output_dir)
        summary["onnx_path"] = onnx_path
        onnx_size = dir_size_mb(onnx_path)
        summary["onnx_size_mb"] = round(onnx_size, 1)
    except Exception as e:
        summary["error"] = f"ONNX export failed: {e}"
        log(f" ERROR: {summary['error']}")
        return summary

    # Step 3: Verify fp32 ONNX (soft failure — recorded, pipeline continues).
    if not skip_verify:
        try:
            summary["verification_fp32"] = verify_outputs(model_name, onnx_path)
        except Exception as e:
            log(f" WARNING: fp32 verification failed: {e}")
            summary["verification_fp32"] = {"passed": False, "error": str(e)}

    # Step 4: Quantize to int8.
    if not skip_quantize:
        try:
            quantized_path = quantize_onnx(onnx_path)
            summary["quantized_path"] = quantized_path
            q_size = dir_size_mb(quantized_path)
            summary["quantized_size_mb"] = round(q_size, 1)
            if onnx_size > 0:
                # Reuse the fp32 size measured in step 2 instead of walking
                # the directory tree a second time.
                summary["size_reduction_pct"] = round((1 - q_size / onnx_size) * 100, 1)
        except Exception as e:
            summary["error"] = f"Quantization failed: {e}"
            log(f" ERROR: {summary['error']}")
            return summary

        # Step 5: Verify int8 ONNX — only when quantization actually ran.
        # BUG FIX: this used to execute even with --skip-quantize, hitting an
        # unbound `quantized_path` (NameError) and recording a spurious
        # verification failure that made the script exit non-zero.
        if not skip_verify:
            try:
                summary["verification_int8"] = verify_outputs(model_name, quantized_path)
            except Exception as e:
                log(f" WARNING: int8 verification failed: {e}")
                summary["verification_int8"] = {"passed": False, "error": str(e)}
    return summary
# ---------------------------------------------------------------------------
# Summary printing
# ---------------------------------------------------------------------------
def print_summary(results: list[dict]) -> None:
    """Print a human-readable summary table to stderr.

    Each entry shows the model's ONNX sizes and the fp32/int8 verification
    outcomes. Skipped steps are shown as "skipped" or omitted — never as
    ``None`` placeholders.
    """
    print("\n" + "=" * 72, file=sys.stderr)
    print(" EXPORT SUMMARY", file=sys.stderr)
    print("=" * 72, file=sys.stderr)
    for r in results:
        print(f"\n Model: {r['model']}", file=sys.stderr)
        if r.get("error"):
            print(f" ERROR: {r['error']}", file=sys.stderr)
            continue
        # Sizes
        onnx_mb = r.get("onnx_size_mb", "?")
        q_mb = r.get("quantized_size_mb", "?")
        reduction = r.get("size_reduction_pct", "?")
        print(f" ONNX fp32 : {onnx_mb} MB ({r.get('onnx_path', '?')})", file=sys.stderr)
        # BUG FIX: process_model pre-populates "quantized_size_mb" with None
        # when quantization is skipped, so the key is present and the old
        # `q_mb != "?"` check printed "None MB (None)" / "Reduction : None%".
        # Skip the int8 lines for both the missing-key and the None case.
        if q_mb not in ("?", None):
            print(f" ONNX int8 : {q_mb} MB ({r.get('quantized_path', '?')})", file=sys.stderr)
            print(f" Reduction : {reduction}%", file=sys.stderr)
        # Verification
        for label, key in [("fp32 verify", "verification_fp32"), ("int8 verify", "verification_int8")]:
            v = r.get(key)
            if v is None:
                print(f" {label}: skipped", file=sys.stderr)
            elif v.get("error"):
                print(f" {label}: ERROR — {v['error']}", file=sys.stderr)
            else:
                status = "PASS" if v["passed"] else "FAIL"
                diff = v.get("max_relative_diff_pct", "?")
                match = v.get("text_match", "?")
                print(f" {label}: {status} (max_diff={diff}%, text_match={match})", file=sys.stderr)
    print("\n" + "=" * 72 + "\n", file=sys.stderr)
# ---------------------------------------------------------------------------
# CLI
# ---------------------------------------------------------------------------
def main():
    """CLI entry point: parse arguments, export the selected model(s), print
    a summary, and exit 1 if any step failed or a verification did not pass."""
    parser = argparse.ArgumentParser(
        description="Export TrOCR models to ONNX with int8 quantization",
    )
    parser.add_argument(
        "--model",
        choices=["printed", "handwritten", "both"],
        default="both",
        help="Which model to export (default: both)",
    )
    parser.add_argument(
        "--output-dir",
        default=DEFAULT_OUTPUT_DIR,
        help=f"Output directory for ONNX models (default: {DEFAULT_OUTPUT_DIR})",
    )
    parser.add_argument(
        "--skip-verify",
        action="store_true",
        help="Skip output verification against PyTorch",
    )
    parser.add_argument(
        "--skip-quantize",
        action="store_true",
        help="Skip int8 quantization step",
    )
    args = parser.parse_args()

    # Resolve and create the output directory up front.
    output_dir = os.path.abspath(args.output_dir)
    os.makedirs(output_dir, exist_ok=True)
    log(f"Output directory: {output_dir}")

    # Pick the model set requested on the command line.
    targets = list(MODELS.values()) if args.model == "both" else [MODELS[args.model]]

    results = [
        process_model(
            model_name=name,
            output_dir=output_dir,
            skip_verify=args.skip_verify,
            skip_quantize=args.skip_quantize,
        )
        for name in targets
    ]

    print_summary(results)

    def _failed(r: dict) -> bool:
        # A model failed if a hard step errored or any recorded
        # verification did not pass (absent/None verifications don't count).
        if r.get("error"):
            return True
        return any(
            (v := r.get(key)) is not None and not v.get("passed", True)
            for key in ("verification_fp32", "verification_int8")
        )

    if any(_failed(r) for r in results):
        log("One or more steps failed or did not pass verification.")
        sys.exit(1)
    log("All exports completed successfully.")
    sys.exit(0)
# Standard script entry guard — run the exporter only when executed directly.
if __name__ == "__main__":
    main()