D2: TrOCR ONNX export script (printed + handwritten, int8 quantization) D3: PP-DocLayout ONNX export script (download or Docker-based conversion) B3: Model Management admin page (PyTorch vs ONNX status, benchmarks, config) A4: TrOCR ONNX service with runtime routing (auto/pytorch/onnx via TROCR_BACKEND) A5: PP-DocLayout ONNX detection with OpenCV fallback (via GRAPHIC_DETECT_BACKEND) B4: Structure Detection UI toggle (OpenCV vs PP-DocLayout) with class color coding C3: TrOCR-ONNX.md documentation C4: OCR-Pipeline.md ONNX section added C5: mkdocs.yml nav updated, optimum added to requirements.txt Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
413 lines
14 KiB
Python
Executable File
413 lines
14 KiB
Python
Executable File
#!/usr/bin/env python3
"""
TrOCR ONNX Export — exports TrOCR models to ONNX with int8 quantization.

Supported models:
- microsoft/trocr-base-printed
- microsoft/trocr-base-handwritten

Steps per model:
1. Load PyTorch model via optimum ORTModelForVision2Seq (export=True)
2. Save ONNX to output directory
3. Quantize to int8 via ORTQuantizer + AutoQuantizationConfig
4. Verify: compare PyTorch vs ONNX outputs (diff < 2%)
5. Report model sizes before/after quantization

Usage:
    python scripts/export-trocr-onnx.py
    python scripts/export-trocr-onnx.py --model printed
    python scripts/export-trocr-onnx.py --model handwritten --skip-verify
    python scripts/export-trocr-onnx.py --output-dir models/onnx --skip-quantize
"""

import argparse
import os
import platform
import sys
import time
from pathlib import Path

# Make the backend package importable when this script runs from scripts/.
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'klausur-service', 'backend'))
# ---------------------------------------------------------------------------
# Constants
# ---------------------------------------------------------------------------

# CLI model aliases mapped to their HuggingFace Hub identifiers.
MODELS = {
    "printed": "microsoft/trocr-base-printed",
    "handwritten": "microsoft/trocr-base-handwritten",
}

# Default destination for exported ONNX artifacts, relative to this script.
DEFAULT_OUTPUT_DIR = os.path.join(os.path.dirname(__file__), '..', 'models', 'onnx')
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------

def dir_size_mb(path: str) -> float:
    """Return the combined size, in megabytes, of every file under *path*."""
    total_bytes = sum(
        os.path.getsize(os.path.join(root, name))
        for root, _subdirs, names in os.walk(path)
        for name in names
    )
    return total_bytes / (1024 * 1024)
def log(msg: str) -> None:
    """Write a prefixed progress message to stderr, flushed immediately."""
    sys.stderr.write(f"[export-onnx] {msg}\n")
    sys.stderr.flush()
def _create_test_image():
    """Build a small synthetic text-line image used for output verification.

    Returns a 384x48 white RGB PIL image containing a single dark rectangle
    that stands in for a line of printed text.
    """
    from PIL import Image

    width, height = 384, 48
    image = Image.new('RGB', (width, height), 'white')
    px = image.load()
    # Fill a rectangular region with near-black pixels to mimic glyphs.
    for col in range(60, 220):
        for row in range(10, 38):
            px[col, row] = (25, 25, 25)
    return image
# ---------------------------------------------------------------------------
# Export
# ---------------------------------------------------------------------------

def export_to_onnx(model_name: str, output_dir: str) -> str:
    """Export a HuggingFace TrOCR model to ONNX via optimum.

    Returns the path to the saved ONNX directory.
    """
    from optimum.onnxruntime import ORTModelForVision2Seq

    target = os.path.join(output_dir, model_name.split("/")[-1])

    log(f"Exporting {model_name} to ONNX ...")
    started = time.monotonic()

    # export=True converts the PyTorch checkpoint on the fly.
    onnx_model = ORTModelForVision2Seq.from_pretrained(model_name, export=True)
    onnx_model.save_pretrained(target)

    elapsed = time.monotonic() - started
    size = dir_size_mb(target)
    log(f" Exported in {elapsed:.1f}s — {size:.1f} MB on disk: {target}")

    return target
# ---------------------------------------------------------------------------
# Quantization
# ---------------------------------------------------------------------------

def quantize_onnx(onnx_path: str) -> str:
    """Apply int8 dynamic quantization to an ONNX model directory.

    Returns the path to the quantized model directory.
    """
    from optimum.onnxruntime import ORTQuantizer
    from optimum.onnxruntime.configuration import AutoQuantizationConfig

    quantized_path = onnx_path + "-int8"
    log(f"Quantizing to int8 → {quantized_path} ...")
    started = time.monotonic()

    # Pick a quantization config for the current CPU. Apple Silicon / ARM
    # has no AVX-512, so prefer the arm64 config there; otherwise (and on
    # older optimum versions lacking arm64()) fall back to avx512_vnni,
    # which still works for weights-only dynamic quantization.
    arch = platform.machine().lower()
    if "arm" not in arch and "aarch" not in arch:
        qconfig = AutoQuantizationConfig.avx512_vnni(is_static=False, per_channel=False)
        log(" Using avx512_vnni quantization config")
    else:
        try:
            qconfig = AutoQuantizationConfig.arm64(is_static=False, per_channel=False)
            log(" Using arm64 quantization config")
        except AttributeError:
            qconfig = AutoQuantizationConfig.avx512_vnni(is_static=False, per_channel=False)
            log(" arm64 config not available, falling back to avx512_vnni")

    quantizer = ORTQuantizer.from_pretrained(onnx_path)
    quantizer.quantize(save_dir=quantized_path, quantization_config=qconfig)

    elapsed = time.monotonic() - started
    size = dir_size_mb(quantized_path)
    log(f" Quantized in {elapsed:.1f}s — {size:.1f} MB on disk")

    return quantized_path
# ---------------------------------------------------------------------------
# Verification
# ---------------------------------------------------------------------------

def verify_outputs(model_name: str, onnx_path: str) -> dict:
    """Compare PyTorch and ONNX model outputs on a synthetic image.

    Generates with both backends on the same synthetic text-line image and
    returns a dict with the max relative difference of the generated token
    IDs plus the decoded text from each backend.
    """
    import numpy as np
    import torch
    from optimum.onnxruntime import ORTModelForVision2Seq
    from transformers import TrOCRProcessor, VisionEncoderDecoderModel

    log(f"Verifying ONNX output against PyTorch for {model_name} ...")
    test_image = _create_test_image()

    # --- PyTorch inference ---
    processor = TrOCRProcessor.from_pretrained(model_name)
    pt_model = VisionEncoderDecoderModel.from_pretrained(model_name)
    pt_model.eval()

    pixel_values = processor(images=test_image, return_tensors="pt").pixel_values
    with torch.no_grad():
        pt_ids = pt_model.generate(pixel_values, max_new_tokens=50)
    pt_text = processor.batch_decode(pt_ids, skip_special_tokens=True)[0]

    # --- ONNX inference ---
    ort_model = ORTModelForVision2Seq.from_pretrained(onnx_path)
    ort_pixel_values = processor(images=test_image, return_tensors="pt").pixel_values
    ort_ids = ort_model.generate(ort_pixel_values, max_new_tokens=50)
    ort_text = processor.batch_decode(ort_ids, skip_special_tokens=True)[0]

    # --- Compare ---
    pt_tokens = pt_ids[0].numpy()
    ort_tokens = ort_ids[0].numpy()
    exact_match = bool(np.array_equal(pt_tokens, ort_tokens))

    # Zero-pad the shorter token sequence so the arrays align element-wise.
    pt_arr = pt_tokens.astype(np.float64)
    ort_arr = ort_tokens.astype(np.float64)
    width = max(pt_arr.size, ort_arr.size)
    pt_arr = np.pad(pt_arr, (0, width - pt_arr.size))
    ort_arr = np.pad(ort_arr, (0, width - ort_arr.size))

    # Relative diff on token ids; zero-valued ids use a denominator of 1
    # to avoid division by zero.
    denom = np.where(np.abs(pt_arr) > 0, np.abs(pt_arr), 1.0)
    max_diff_pct = float(np.max(np.abs(pt_arr - ort_arr) / denom)) * 100.0

    passed = max_diff_pct < 2.0

    result = {
        "passed": passed,
        "exact_token_match": exact_match,
        "max_relative_diff_pct": round(max_diff_pct, 4),
        "pytorch_text": pt_text,
        "onnx_text": ort_text,
        "text_match": pt_text == ort_text,
    }

    status = "PASS" if passed else "FAIL"
    log(f" Verification {status}: max_diff={max_diff_pct:.4f}% exact_match={exact_match}")
    log(f" PyTorch : '{pt_text}'")
    log(f" ONNX : '{ort_text}'")

    return result
# ---------------------------------------------------------------------------
# Per-model pipeline
# ---------------------------------------------------------------------------

def process_model(
    model_name: str,
    output_dir: str,
    skip_verify: bool = False,
    skip_quantize: bool = False,
) -> dict:
    """Run the full export pipeline for one model.

    Pipeline: export to ONNX → verify fp32 → quantize to int8 → verify int8.
    Verification and quantization can each be skipped via the flags.

    Args:
        model_name: HuggingFace model ID, e.g. "microsoft/trocr-base-printed".
        output_dir: Directory that receives the exported model directories.
        skip_verify: Skip the PyTorch-vs-ONNX output comparisons.
        skip_quantize: Skip int8 quantization (and therefore int8 verification).

    Returns:
        Summary dict with paths, sizes (MB), verification results, and an
        "error" message when a hard failure aborted the pipeline early.
    """
    short_name = model_name.split("/")[-1]
    summary: dict = {
        "model": model_name,
        "short_name": short_name,
        "onnx_path": None,
        "onnx_size_mb": None,
        "quantized_path": None,
        "quantized_size_mb": None,
        "size_reduction_pct": None,
        "verification_fp32": None,
        "verification_int8": None,
        "error": None,
    }

    log(f"{'='*60}")
    log(f"Processing: {model_name}")
    log(f"{'='*60}")

    # Step 1 + 2: Export to ONNX
    try:
        onnx_path = export_to_onnx(model_name, output_dir)
        onnx_size = dir_size_mb(onnx_path)
        summary["onnx_path"] = onnx_path
        summary["onnx_size_mb"] = round(onnx_size, 1)
    except Exception as e:
        summary["error"] = f"ONNX export failed: {e}"
        log(f" ERROR: {summary['error']}")
        return summary

    # Step 3: Verify fp32 ONNX
    if not skip_verify:
        try:
            summary["verification_fp32"] = verify_outputs(model_name, onnx_path)
        except Exception as e:
            log(f" WARNING: fp32 verification failed: {e}")
            summary["verification_fp32"] = {"passed": False, "error": str(e)}

    # Without quantization there is no int8 model to verify, so we are done.
    # (Bug fix: the previous revision fell through to Step 5 and referenced
    # the never-assigned `quantized_path` when --skip-quantize was combined
    # with verification, raising UnboundLocalError.)
    if skip_quantize:
        return summary

    # Step 4: Quantize to int8
    try:
        quantized_path = quantize_onnx(onnx_path)
        summary["quantized_path"] = quantized_path
        q_size = dir_size_mb(quantized_path)
        summary["quantized_size_mb"] = round(q_size, 1)

        # Reuse the fp32 size measured above instead of re-walking the
        # export directory a second time.
        if onnx_size > 0:
            summary["size_reduction_pct"] = round((1 - q_size / onnx_size) * 100, 1)
    except Exception as e:
        summary["error"] = f"Quantization failed: {e}"
        log(f" ERROR: {summary['error']}")
        return summary

    # Step 5: Verify int8 ONNX
    if not skip_verify:
        try:
            summary["verification_int8"] = verify_outputs(model_name, quantized_path)
        except Exception as e:
            log(f" WARNING: int8 verification failed: {e}")
            summary["verification_int8"] = {"passed": False, "error": str(e)}

    return summary
# ---------------------------------------------------------------------------
# Summary printing
# ---------------------------------------------------------------------------

def print_summary(results: list[dict]) -> None:
    """Print a human-readable summary table to stderr.

    Handles partially-filled summaries: when quantization or verification
    was skipped, the corresponding values in the summary dict are ``None``
    (the keys exist), so we must check for ``None`` explicitly — the
    previous revision used ``r.get(key, "?")``, which never fell back and
    printed literal "None MB" / "None%" lines for skipped steps.
    """
    print("\n" + "=" * 72, file=sys.stderr)
    print(" EXPORT SUMMARY", file=sys.stderr)
    print("=" * 72, file=sys.stderr)

    for r in results:
        print(f"\n Model: {r['model']}", file=sys.stderr)

        if r.get("error"):
            print(f" ERROR: {r['error']}", file=sys.stderr)
            continue

        # Sizes — values are None when the corresponding step was skipped.
        onnx_mb = r.get("onnx_size_mb")
        q_mb = r.get("quantized_size_mb")
        reduction = r.get("size_reduction_pct")

        shown_onnx = onnx_mb if onnx_mb is not None else "?"
        print(f" ONNX fp32 : {shown_onnx} MB ({r.get('onnx_path', '?')})", file=sys.stderr)
        if q_mb is not None:
            print(f" ONNX int8 : {q_mb} MB ({r.get('quantized_path', '?')})", file=sys.stderr)
            if reduction is not None:
                print(f" Reduction : {reduction}%", file=sys.stderr)

        # Verification results (fp32 and int8 passes).
        for label, key in [("fp32 verify", "verification_fp32"), ("int8 verify", "verification_int8")]:
            v = r.get(key)
            if v is None:
                print(f" {label}: skipped", file=sys.stderr)
            elif v.get("error"):
                print(f" {label}: ERROR — {v['error']}", file=sys.stderr)
            else:
                status = "PASS" if v.get("passed") else "FAIL"
                diff = v.get("max_relative_diff_pct", "?")
                match = v.get("text_match", "?")
                print(f" {label}: {status} (max_diff={diff}%, text_match={match})", file=sys.stderr)

    print("\n" + "=" * 72 + "\n", file=sys.stderr)
# ---------------------------------------------------------------------------
# CLI
# ---------------------------------------------------------------------------

def main():
    """CLI entry point: parse arguments, run the pipeline, report, exit."""
    parser = argparse.ArgumentParser(
        description="Export TrOCR models to ONNX with int8 quantization",
    )
    parser.add_argument(
        "--model",
        choices=["printed", "handwritten", "both"],
        default="both",
        help="Which model to export (default: both)",
    )
    parser.add_argument(
        "--output-dir",
        default=DEFAULT_OUTPUT_DIR,
        help=f"Output directory for ONNX models (default: {DEFAULT_OUTPUT_DIR})",
    )
    parser.add_argument(
        "--skip-verify",
        action="store_true",
        help="Skip output verification against PyTorch",
    )
    parser.add_argument(
        "--skip-quantize",
        action="store_true",
        help="Skip int8 quantization step",
    )
    args = parser.parse_args()

    # Resolve and create the output directory.
    output_dir = os.path.abspath(args.output_dir)
    os.makedirs(output_dir, exist_ok=True)
    log(f"Output directory: {output_dir}")

    # "both" expands to every known model; otherwise a single alias lookup.
    selected = list(MODELS.values()) if args.model == "both" else [MODELS[args.model]]

    # Run the full pipeline for each selected model.
    results = [
        process_model(
            model_name=name,
            output_dir=output_dir,
            skip_verify=args.skip_verify,
            skip_quantize=args.skip_quantize,
        )
        for name in selected
    ]

    print_summary(results)

    # Exit non-zero when any model errored or failed a verification pass.
    def _failed(r: dict) -> bool:
        if r.get("error"):
            return True
        return any(
            (v := r.get(key)) is not None and not v.get("passed", True)
            for key in ("verification_fp32", "verification_int8")
        )

    if any(_failed(r) for r in results):
        log("One or more steps failed or did not pass verification.")
        sys.exit(1)
    log("All exports completed successfully.")
    sys.exit(0)
# Allow importing this module (e.g. for tests) without running the export.
if __name__ == "__main__":
    main()