#!/usr/bin/env python3
"""
TrOCR ONNX Export — exports TrOCR models to ONNX with int8 quantization.

Supported models:
  - microsoft/trocr-base-printed
  - microsoft/trocr-base-handwritten

Steps per model:
  1. Load PyTorch model via optimum ORTModelForVision2Seq (export=True)
  2. Save ONNX to output directory
  3. Quantize to int8 via ORTQuantizer + AutoQuantizationConfig
  4. Verify: compare PyTorch vs ONNX outputs (diff < 2%)
  5. Report model sizes before/after quantization

Usage:
    python scripts/export-trocr-onnx.py
    python scripts/export-trocr-onnx.py --model printed
    python scripts/export-trocr-onnx.py --model handwritten --skip-verify
    python scripts/export-trocr-onnx.py --output-dir models/onnx --skip-quantize
"""

import argparse
import os
import platform
import sys
import time
from pathlib import Path

# Add backend to path for imports
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'klausur-service', 'backend'))

# ---------------------------------------------------------------------------
# Constants
# ---------------------------------------------------------------------------

# CLI short name -> HuggingFace model id.
MODELS = {
    "printed": "microsoft/trocr-base-printed",
    "handwritten": "microsoft/trocr-base-handwritten",
}

DEFAULT_OUTPUT_DIR = os.path.join(os.path.dirname(__file__), '..', 'models', 'onnx')


# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------

def dir_size_mb(path: str) -> float:
    """Return total size of all files under *path* in MB (1 MB = 1024 * 1024 bytes).

    Walks *path* recursively. A missing or empty directory yields 0.0 because
    ``os.walk`` simply produces nothing for it.
    """
    total = sum(
        os.path.getsize(os.path.join(root, name))
        for root, _dirs, files in os.walk(path)
        for name in files
    )
    return total / (1024 * 1024)


def log(msg: str) -> None:
    """Print *msg* to stderr with the ``[export-onnx]`` prefix, flushing immediately."""
    print(f"[export-onnx] {msg}", file=sys.stderr, flush=True)


def _create_test_image():
    """Create a simple synthetic text-line image for verification.

    Returns a 384x48 white RGB PIL image with a dark rectangle standing in
    for a line of printed text.
    """
    from PIL import Image

    w, h = 384, 48
    img = Image.new('RGB', (w, h), 'white')
    pixels = img.load()
    # Draw a dark region to simulate printed text
    for x in range(60, 220):
        for y in range(10, 38):
            pixels[x, y] = (25, 25, 25)
    return img


# ---------------------------------------------------------------------------
# Export
# ---------------------------------------------------------------------------

def export_to_onnx(model_name: str, output_dir: str) -> str:
    """Export a HuggingFace TrOCR model to ONNX via optimum.

    Args:
        model_name: HuggingFace model id (e.g. ``microsoft/trocr-base-printed``).
        output_dir: Directory under which a sub-directory named after the
            last path component of *model_name* is written.

    Returns:
        The path to the saved ONNX directory.
    """
    from optimum.onnxruntime import ORTModelForVision2Seq

    short_name = model_name.split("/")[-1]
    onnx_path = os.path.join(output_dir, short_name)

    log(f"Exporting {model_name} to ONNX ...")
    t0 = time.monotonic()

    model = ORTModelForVision2Seq.from_pretrained(model_name, export=True)
    model.save_pretrained(onnx_path)

    elapsed = time.monotonic() - t0
    size = dir_size_mb(onnx_path)
    log(f" Exported in {elapsed:.1f}s — {size:.1f} MB on disk: {onnx_path}")

    return onnx_path


# ---------------------------------------------------------------------------
# Quantization
# ---------------------------------------------------------------------------

def _select_quantization_config():
    """Return a dynamic (weights-only) int8 quantization config for this machine."""
    from optimum.onnxruntime.configuration import AutoQuantizationConfig

    # arm64 (Apple Silicon) does not have AVX-512; use arm64 config when
    # available, otherwise fall back to avx512_vnni which still works for
    # dynamic quantisation (weights-only).
    machine = platform.machine().lower()
    if "arm" in machine or "aarch" in machine:
        try:
            qconfig = AutoQuantizationConfig.arm64(is_static=False, per_channel=False)
            log(" Using arm64 quantization config")
        except AttributeError:
            # Older optimum versions may lack arm64(); fall back.
            qconfig = AutoQuantizationConfig.avx512_vnni(is_static=False, per_channel=False)
            log(" arm64 config not available, falling back to avx512_vnni")
    else:
        qconfig = AutoQuantizationConfig.avx512_vnni(is_static=False, per_channel=False)
        log(" Using avx512_vnni quantization config")
    return qconfig


def quantize_onnx(onnx_path: str) -> str:
    """Apply int8 dynamic quantization to an ONNX model directory.

    A Vision2Seq export normally consists of several ONNX graphs (encoder and
    decoder(s)). ``ORTQuantizer`` handles one graph per instance, so every
    ``*.onnx`` file in *onnx_path* is quantized separately into the same
    output directory.

    Args:
        onnx_path: Directory containing the exported ONNX model file(s).

    Returns:
        The path to the quantized model directory (``<onnx_path>-int8``).
    """
    from optimum.onnxruntime import ORTQuantizer

    quantized_path = onnx_path + "-int8"
    log(f"Quantizing to int8 → {quantized_path} ...")
    t0 = time.monotonic()

    qconfig = _select_quantization_config()

    onnx_files = sorted(p.name for p in Path(onnx_path).glob("*.onnx"))
    if len(onnx_files) <= 1:
        # Single-graph (or empty) directory: let optimum locate the file
        # itself, or raise its own "no ONNX file found" error.
        quantizer = ORTQuantizer.from_pretrained(onnx_path)
        quantizer.quantize(save_dir=quantized_path, quantization_config=qconfig)
    else:
        # Multi-graph directory: one quantizer per file, created lazily so
        # only one graph is held in memory at a time.
        # NOTE(review): output files keep optimum's default naming
        # (presumably a "_quantized" suffix) — confirm the loader used
        # downstream resolves these names.
        for file_name in onnx_files:
            log(f" Quantizing {file_name} ...")
            quantizer = ORTQuantizer.from_pretrained(onnx_path, file_name=file_name)
            quantizer.quantize(save_dir=quantized_path, quantization_config=qconfig)

    elapsed = time.monotonic() - t0
    size = dir_size_mb(quantized_path)
    log(f" Quantized in {elapsed:.1f}s — {size:.1f} MB on disk")

    return quantized_path


# ---------------------------------------------------------------------------
# Verification
# ---------------------------------------------------------------------------

def verify_outputs(model_name: str, onnx_path: str) -> dict:
    """Compare PyTorch and ONNX model outputs on a synthetic image.

    Args:
        model_name: HuggingFace model id used for the PyTorch reference and
            for the processor.
        onnx_path: Directory of the ONNX model to check.

    Returns:
        A dict with the keys ``passed``, ``exact_token_match``,
        ``max_relative_diff_pct``, ``pytorch_text``, ``onnx_text`` and
        ``text_match``. ``passed`` is True when the max relative difference
        of the generated token IDs is below 2%.

    NOTE(review): token IDs are categorical, so a small relative difference
    does not imply similar text; ``exact_token_match`` / ``text_match`` are
    the stricter signals. The 2% threshold is kept as the pass criterion.
    """
    import numpy as np
    import torch
    from optimum.onnxruntime import ORTModelForVision2Seq
    from transformers import TrOCRProcessor, VisionEncoderDecoderModel

    log(f"Verifying ONNX output against PyTorch for {model_name} ...")
    test_image = _create_test_image()

    # --- PyTorch inference ---
    processor = TrOCRProcessor.from_pretrained(model_name)
    pt_model = VisionEncoderDecoderModel.from_pretrained(model_name)
    pt_model.eval()

    # Both backends receive the same preprocessed tensor.
    pixel_values = processor(images=test_image, return_tensors="pt").pixel_values
    with torch.no_grad():
        pt_ids = pt_model.generate(pixel_values, max_new_tokens=50)
    pt_text = processor.batch_decode(pt_ids, skip_special_tokens=True)[0]

    # --- ONNX inference ---
    ort_model = ORTModelForVision2Seq.from_pretrained(onnx_path)
    ort_ids = ort_model.generate(pixel_values, max_new_tokens=50)
    ort_text = processor.batch_decode(ort_ids, skip_special_tokens=True)[0]

    # --- Compare ---
    pt_tokens = pt_ids[0].numpy()
    ort_tokens = ort_ids[0].numpy()
    pt_arr = pt_tokens.astype(np.float64)
    ort_arr = ort_tokens.astype(np.float64)

    # Pad to equal length for comparison (np.pad fills with zeros by default)
    max_len = max(len(pt_arr), len(ort_arr))
    if len(pt_arr) < max_len:
        pt_arr = np.pad(pt_arr, (0, max_len - len(pt_arr)))
    if len(ort_arr) < max_len:
        ort_arr = np.pad(ort_arr, (0, max_len - len(ort_arr)))

    # Relative diff on token ids (treat 0-ids as 1 to avoid div-by-zero)
    denom = np.where(np.abs(pt_arr) > 0, np.abs(pt_arr), 1.0)
    rel_diff = np.abs(pt_arr - ort_arr) / denom
    max_diff_pct = float(np.max(rel_diff)) * 100.0
    exact_match = bool(np.array_equal(pt_tokens, ort_tokens))

    passed = max_diff_pct < 2.0

    result = {
        "passed": passed,
        "exact_token_match": exact_match,
        "max_relative_diff_pct": round(max_diff_pct, 4),
        "pytorch_text": pt_text,
        "onnx_text": ort_text,
        "text_match": pt_text == ort_text,
    }

    status = "PASS" if passed else "FAIL"
    log(f" Verification {status}: max_diff={max_diff_pct:.4f}% exact_match={exact_match}")
    log(f" PyTorch : '{pt_text}'")
    log(f" ONNX : '{ort_text}'")

    return result


# ---------------------------------------------------------------------------
# Per-model pipeline
# ---------------------------------------------------------------------------

def process_model(
    model_name: str,
    output_dir: str,
    skip_verify: bool = False,
    skip_quantize: bool = False,
) -> dict:
    """Run the full export pipeline for one model.

    Args:
        model_name: HuggingFace model id to export.
        output_dir: Directory that receives the ONNX model directories.
        skip_verify: When True, no PyTorch-vs-ONNX verification is run.
        skip_quantize: When True, the int8 quantization step (and therefore
            the int8 verification) is skipped.

    Returns:
        A summary dict with paths, sizes, and verification results. Steps
        that did not run leave their entries as ``None``; a failed export or
        quantization is reported in ``error`` and ends the pipeline early.
    """
    short_name = model_name.split("/")[-1]
    summary: dict = {
        "model": model_name,
        "short_name": short_name,
        "onnx_path": None,
        "onnx_size_mb": None,
        "quantized_path": None,
        "quantized_size_mb": None,
        "size_reduction_pct": None,
        "verification_fp32": None,
        "verification_int8": None,
        "error": None,
    }

    log(f"{'='*60}")
    log(f"Processing: {model_name}")
    log(f"{'='*60}")

    # Step 1 + 2: Export to ONNX
    try:
        onnx_path = export_to_onnx(model_name, output_dir)
        fp32_size = dir_size_mb(onnx_path)
        summary["onnx_path"] = onnx_path
        summary["onnx_size_mb"] = round(fp32_size, 1)
    except Exception as e:
        summary["error"] = f"ONNX export failed: {e}"
        log(f" ERROR: {summary['error']}")
        return summary

    # Step 3: Verify fp32 ONNX
    if not skip_verify:
        try:
            summary["verification_fp32"] = verify_outputs(model_name, onnx_path)
        except Exception as e:
            log(f" WARNING: fp32 verification failed: {e}")
            summary["verification_fp32"] = {"passed": False, "error": str(e)}

    # Step 4: Quantize to int8
    if not skip_quantize:
        try:
            quantized_path = quantize_onnx(onnx_path)
            summary["quantized_path"] = quantized_path
            q_size = dir_size_mb(quantized_path)
            summary["quantized_size_mb"] = round(q_size, 1)

            # Guard against dividing by an empty export directory.
            if fp32_size > 0:
                reduction = (1 - q_size / fp32_size) * 100
                summary["size_reduction_pct"] = round(reduction, 1)
        except Exception as e:
            summary["error"] = f"Quantization failed: {e}"
            log(f" ERROR: {summary['error']}")
            return summary

    # Step 5: Verify int8 ONNX — only when a quantized model actually exists;
    # otherwise the entry stays None and is reported as "skipped".
    if not skip_verify and summary["quantized_path"] is not None:
        try:
            summary["verification_int8"] = verify_outputs(
                model_name, summary["quantized_path"]
            )
        except Exception as e:
            log(f" WARNING: int8 verification failed: {e}")
            summary["verification_int8"] = {"passed": False, "error": str(e)}

    return summary


# ---------------------------------------------------------------------------
# Summary printing
# ---------------------------------------------------------------------------

def _or_unknown(value):
    """Return *value*, or ``"?"`` when it is ``None`` (i.e. not recorded)."""
    return "?" if value is None else value


def print_summary(results: list[dict]) -> None:
    """Print a human-readable summary table to stderr.

    Args:
        results: Summary dicts as produced by ``process_model``. Entries that
            are missing or ``None`` are shown as ``?`` / ``skipped``; the
            int8 lines are omitted entirely when no quantized size exists.
    """
    print("\n" + "=" * 72, file=sys.stderr)
    print(" EXPORT SUMMARY", file=sys.stderr)
    print("=" * 72, file=sys.stderr)

    for r in results:
        print(f"\n Model: {r['model']}", file=sys.stderr)

        if r.get("error"):
            print(f" ERROR: {r['error']}", file=sys.stderr)
            continue

        # Sizes
        onnx_mb = _or_unknown(r.get("onnx_size_mb"))
        onnx_dir = _or_unknown(r.get("onnx_path"))
        print(f" ONNX fp32 : {onnx_mb} MB ({onnx_dir})", file=sys.stderr)

        q_mb = r.get("quantized_size_mb")
        if q_mb is not None:
            q_dir = _or_unknown(r.get("quantized_path"))
            reduction = _or_unknown(r.get("size_reduction_pct"))
            print(f" ONNX int8 : {q_mb} MB ({q_dir})", file=sys.stderr)
            print(f" Reduction : {reduction}%", file=sys.stderr)

        # Verification
        for label, key in [("fp32 verify", "verification_fp32"), ("int8 verify", "verification_int8")]:
            v = r.get(key)
            if v is None:
                print(f" {label}: skipped", file=sys.stderr)
            elif v.get("error"):
                print(f" {label}: ERROR — {v['error']}", file=sys.stderr)
            else:
                status = "PASS" if v["passed"] else "FAIL"
                diff = v.get("max_relative_diff_pct", "?")
                match = v.get("text_match", "?")
                print(f" {label}: {status} (max_diff={diff}%, text_match={match})", file=sys.stderr)

    print("\n" + "=" * 72 + "\n", file=sys.stderr)


# ---------------------------------------------------------------------------
# CLI
# ---------------------------------------------------------------------------

def _has_failure(result: dict) -> bool:
    """Return True when *result* records an error or a failed verification."""
    if result.get("error"):
        return True
    verifications = (result.get(key) for key in ("verification_fp32", "verification_int8"))
    # A skipped verification is None and does not count as a failure.
    return any(v and not v.get("passed", True) for v in verifications)


def main() -> None:
    """Parse CLI arguments, run the pipeline per model, and exit 0/1.

    Exits with status 1 when any model reported an error or failed a
    verification step, otherwise with status 0.
    """
    parser = argparse.ArgumentParser(
        description="Export TrOCR models to ONNX with int8 quantization",
    )
    parser.add_argument(
        "--model",
        choices=["printed", "handwritten", "both"],
        default="both",
        help="Which model to export (default: both)",
    )
    parser.add_argument(
        "--output-dir",
        default=DEFAULT_OUTPUT_DIR,
        help=f"Output directory for ONNX models (default: {DEFAULT_OUTPUT_DIR})",
    )
    parser.add_argument(
        "--skip-verify",
        action="store_true",
        help="Skip output verification against PyTorch",
    )
    parser.add_argument(
        "--skip-quantize",
        action="store_true",
        help="Skip int8 quantization step",
    )
    args = parser.parse_args()

    # Resolve output directory
    output_dir = os.path.abspath(args.output_dir)
    os.makedirs(output_dir, exist_ok=True)
    log(f"Output directory: {output_dir}")

    # Determine which models to process
    if args.model == "both":
        model_names = list(MODELS.values())
    else:
        model_names = [MODELS[args.model]]

    # Process each model
    results = [
        process_model(
            model_name=model_name,
            output_dir=output_dir,
            skip_verify=args.skip_verify,
            skip_quantize=args.skip_quantize,
        )
        for model_name in model_names
    ]

    # Print summary
    print_summary(results)

    # Exit with error code if any model failed verification
    if any(_has_failure(r) for r in results):
        log("One or more steps failed or did not pass verification.")
        sys.exit(1)
    else:
        log("All exports completed successfully.")
        sys.exit(0)


if __name__ == "__main__":
    main()