Compare commits
2 Commits
c695b659fb
...
dccbb909bc
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
dccbb909bc | ||
|
|
be7f5f1872 |
550
admin-lehrer/app/(admin)/ai/model-management/page.tsx
Normal file
550
admin-lehrer/app/(admin)/ai/model-management/page.tsx
Normal file
@@ -0,0 +1,550 @@
|
||||
'use client'
|
||||
|
||||
/**
|
||||
* Model Management Page
|
||||
*
|
||||
* Manage ML model backends (PyTorch vs ONNX), view status,
|
||||
* run benchmarks, and configure inference settings.
|
||||
*/
|
||||
|
||||
import { useState, useEffect, useCallback } from 'react'
|
||||
import { PagePurpose } from '@/components/common/PagePurpose'
|
||||
import { AIToolsSidebarResponsive } from '@/components/ai/AIToolsSidebar'
|
||||
|
||||
const KLAUSUR_API = '/klausur-api'
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Types
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
type BackendMode = 'auto' | 'pytorch' | 'onnx'
|
||||
type ModelStatus = 'available' | 'not_found' | 'loading' | 'error'
|
||||
type Tab = 'overview' | 'benchmarks' | 'configuration'
|
||||
|
||||
interface ModelInfo {
|
||||
name: string
|
||||
key: string
|
||||
pytorch: { status: ModelStatus; size_mb: number; ram_mb: number }
|
||||
onnx: { status: ModelStatus; size_mb: number; ram_mb: number; quantized: boolean }
|
||||
}
|
||||
|
||||
interface BenchmarkRow {
|
||||
model: string
|
||||
backend: string
|
||||
quantization: string
|
||||
size_mb: number
|
||||
ram_mb: number
|
||||
inference_ms: number
|
||||
load_time_s: number
|
||||
}
|
||||
|
||||
interface StatusInfo {
|
||||
active_backend: BackendMode
|
||||
loaded_models: string[]
|
||||
cache_hits: number
|
||||
cache_misses: number
|
||||
uptime_s: number
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Mock data (used when backend is not available)
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
const MOCK_MODELS: ModelInfo[] = [
|
||||
{
|
||||
name: 'TrOCR Printed',
|
||||
key: 'trocr_printed',
|
||||
pytorch: { status: 'available', size_mb: 892, ram_mb: 1800 },
|
||||
onnx: { status: 'available', size_mb: 234, ram_mb: 620, quantized: true },
|
||||
},
|
||||
{
|
||||
name: 'TrOCR Handwritten',
|
||||
key: 'trocr_handwritten',
|
||||
pytorch: { status: 'available', size_mb: 892, ram_mb: 1800 },
|
||||
onnx: { status: 'not_found', size_mb: 0, ram_mb: 0, quantized: false },
|
||||
},
|
||||
{
|
||||
name: 'PP-DocLayout',
|
||||
key: 'pp_doclayout',
|
||||
pytorch: { status: 'not_found', size_mb: 0, ram_mb: 0 },
|
||||
onnx: { status: 'available', size_mb: 48, ram_mb: 180, quantized: false },
|
||||
},
|
||||
]
|
||||
|
||||
const MOCK_BENCHMARKS: BenchmarkRow[] = [
|
||||
{ model: 'TrOCR Printed', backend: 'PyTorch', quantization: 'FP32', size_mb: 892, ram_mb: 1800, inference_ms: 142, load_time_s: 3.2 },
|
||||
{ model: 'TrOCR Printed', backend: 'ONNX', quantization: 'INT8', size_mb: 234, ram_mb: 620, inference_ms: 38, load_time_s: 0.8 },
|
||||
{ model: 'TrOCR Handwritten', backend: 'PyTorch', quantization: 'FP32', size_mb: 892, ram_mb: 1800, inference_ms: 156, load_time_s: 3.4 },
|
||||
{ model: 'PP-DocLayout', backend: 'ONNX', quantization: 'FP32', size_mb: 48, ram_mb: 180, inference_ms: 22, load_time_s: 0.3 },
|
||||
]
|
||||
|
||||
const MOCK_STATUS: StatusInfo = {
|
||||
active_backend: 'auto',
|
||||
loaded_models: ['trocr_printed (ONNX)', 'pp_doclayout (ONNX)'],
|
||||
cache_hits: 1247,
|
||||
cache_misses: 83,
|
||||
uptime_s: 86400,
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Helpers
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
function StatusBadge({ status }: { status: ModelStatus }) {
|
||||
const cls =
|
||||
status === 'available'
|
||||
? 'bg-emerald-100 text-emerald-800 border-emerald-200'
|
||||
: status === 'loading'
|
||||
? 'bg-blue-100 text-blue-800 border-blue-200'
|
||||
: status === 'not_found'
|
||||
? 'bg-slate-100 text-slate-500 border-slate-200'
|
||||
: 'bg-red-100 text-red-800 border-red-200'
|
||||
const label =
|
||||
status === 'available' ? 'Verfuegbar'
|
||||
: status === 'loading' ? 'Laden...'
|
||||
: status === 'not_found' ? 'Nicht vorhanden'
|
||||
: 'Fehler'
|
||||
return (
|
||||
<span className={`inline-flex items-center px-2 py-0.5 rounded-full text-xs font-medium border ${cls}`}>
|
||||
{label}
|
||||
</span>
|
||||
)
|
||||
}
|
||||
|
||||
function formatBytes(mb: number) {
|
||||
if (mb === 0) return '--'
|
||||
if (mb >= 1000) return `${(mb / 1000).toFixed(1)} GB`
|
||||
return `${mb} MB`
|
||||
}
|
||||
|
||||
function formatUptime(seconds: number) {
|
||||
const h = Math.floor(seconds / 3600)
|
||||
const m = Math.floor((seconds % 3600) / 60)
|
||||
if (h > 0) return `${h}h ${m}m`
|
||||
return `${m}m`
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Component
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
export default function ModelManagementPage() {
|
||||
const [tab, setTab] = useState<Tab>('overview')
|
||||
const [models, setModels] = useState<ModelInfo[]>(MOCK_MODELS)
|
||||
const [benchmarks, setBenchmarks] = useState<BenchmarkRow[]>(MOCK_BENCHMARKS)
|
||||
const [status, setStatus] = useState<StatusInfo>(MOCK_STATUS)
|
||||
const [backend, setBackend] = useState<BackendMode>('auto')
|
||||
const [saving, setSaving] = useState(false)
|
||||
const [benchmarkRunning, setBenchmarkRunning] = useState(false)
|
||||
const [usingMock, setUsingMock] = useState(false)
|
||||
|
||||
// Load status
|
||||
const loadStatus = useCallback(async () => {
|
||||
try {
|
||||
const res = await fetch(`${KLAUSUR_API}/api/v1/models/status`)
|
||||
if (res.ok) {
|
||||
const data = await res.json()
|
||||
setStatus(data)
|
||||
setBackend(data.active_backend || 'auto')
|
||||
setUsingMock(false)
|
||||
} else {
|
||||
setUsingMock(true)
|
||||
}
|
||||
} catch {
|
||||
setUsingMock(true)
|
||||
}
|
||||
}, [])
|
||||
|
||||
// Load models
|
||||
const loadModels = useCallback(async () => {
|
||||
try {
|
||||
const res = await fetch(`${KLAUSUR_API}/api/v1/models`)
|
||||
if (res.ok) {
|
||||
const data = await res.json()
|
||||
if (data.models?.length) setModels(data.models)
|
||||
}
|
||||
} catch {
|
||||
// Keep mock data
|
||||
}
|
||||
}, [])
|
||||
|
||||
// Load benchmarks
|
||||
const loadBenchmarks = useCallback(async () => {
|
||||
try {
|
||||
const res = await fetch(`${KLAUSUR_API}/api/v1/models/benchmarks`)
|
||||
if (res.ok) {
|
||||
const data = await res.json()
|
||||
if (data.benchmarks?.length) setBenchmarks(data.benchmarks)
|
||||
}
|
||||
} catch {
|
||||
// Keep mock data
|
||||
}
|
||||
}, [])
|
||||
|
||||
useEffect(() => {
|
||||
loadStatus()
|
||||
loadModels()
|
||||
loadBenchmarks()
|
||||
}, [loadStatus, loadModels, loadBenchmarks])
|
||||
|
||||
// Save backend preference
|
||||
const saveBackend = async (mode: BackendMode) => {
|
||||
setBackend(mode)
|
||||
setSaving(true)
|
||||
try {
|
||||
await fetch(`${KLAUSUR_API}/api/v1/models/backend`, {
|
||||
method: 'PUT',
|
||||
headers: { 'Content-Type': 'application/json' },
|
||||
body: JSON.stringify({ backend: mode }),
|
||||
})
|
||||
await loadStatus()
|
||||
} catch {
|
||||
// Silently handle — mock mode
|
||||
} finally {
|
||||
setSaving(false)
|
||||
}
|
||||
}
|
||||
|
||||
// Run benchmark
|
||||
const runBenchmark = async () => {
|
||||
setBenchmarkRunning(true)
|
||||
try {
|
||||
const res = await fetch(`${KLAUSUR_API}/api/v1/models/benchmark`, {
|
||||
method: 'POST',
|
||||
})
|
||||
if (res.ok) {
|
||||
const data = await res.json()
|
||||
if (data.benchmarks?.length) setBenchmarks(data.benchmarks)
|
||||
}
|
||||
await loadBenchmarks()
|
||||
} catch {
|
||||
// Keep existing data
|
||||
} finally {
|
||||
setBenchmarkRunning(false)
|
||||
}
|
||||
}
|
||||
|
||||
const tabs: { key: Tab; label: string }[] = [
|
||||
{ key: 'overview', label: 'Uebersicht' },
|
||||
{ key: 'benchmarks', label: 'Benchmarks' },
|
||||
{ key: 'configuration', label: 'Konfiguration' },
|
||||
]
|
||||
|
||||
return (
|
||||
<AIToolsSidebarResponsive>
|
||||
<div className="max-w-7xl mx-auto p-6 space-y-6">
|
||||
<PagePurpose
|
||||
title="Model Management"
|
||||
purpose="Verwaltung der ML-Modelle fuer OCR und Layout-Erkennung. Vergleich von PyTorch- und ONNX-Backends, Benchmark-Tests und Backend-Konfiguration."
|
||||
audience={['Entwickler', 'DevOps']}
|
||||
defaultCollapsed
|
||||
architecture={{
|
||||
services: ['klausur-service (FastAPI, Port 8086)'],
|
||||
databases: ['Dateisystem (Modell-Dateien)'],
|
||||
}}
|
||||
relatedPages={[
|
||||
{ name: 'OCR Pipeline', href: '/ai/ocr-pipeline', description: 'OCR-Pipeline ausfuehren' },
|
||||
{ name: 'OCR Vergleich', href: '/ai/ocr-compare', description: 'OCR-Methoden vergleichen' },
|
||||
{ name: 'GPU Infrastruktur', href: '/ai/gpu', description: 'GPU-Ressourcen verwalten' },
|
||||
]}
|
||||
/>
|
||||
|
||||
{/* Header */}
|
||||
<div className="flex items-center justify-between">
|
||||
<div>
|
||||
<h1 className="text-2xl font-bold text-slate-900">Model Management</h1>
|
||||
<p className="text-sm text-slate-500 mt-1">
|
||||
{models.length} Modelle konfiguriert
|
||||
{usingMock && (
|
||||
<span className="ml-2 text-xs bg-amber-100 text-amber-700 px-1.5 py-0.5 rounded">
|
||||
Mock-Daten (Backend nicht erreichbar)
|
||||
</span>
|
||||
)}
|
||||
</p>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{/* Status Cards */}
|
||||
<div className="grid grid-cols-1 sm:grid-cols-2 lg:grid-cols-4 gap-4">
|
||||
<div className="bg-white rounded-lg border border-slate-200 px-4 py-3">
|
||||
<p className="text-xs text-slate-500 uppercase font-medium">Aktives Backend</p>
|
||||
<p className="text-lg font-semibold text-slate-900 mt-1">{status.active_backend.toUpperCase()}</p>
|
||||
</div>
|
||||
<div className="bg-white rounded-lg border border-slate-200 px-4 py-3">
|
||||
<p className="text-xs text-slate-500 uppercase font-medium">Geladene Modelle</p>
|
||||
<p className="text-lg font-semibold text-slate-900 mt-1">{status.loaded_models.length}</p>
|
||||
</div>
|
||||
<div className="bg-white rounded-lg border border-slate-200 px-4 py-3">
|
||||
<p className="text-xs text-slate-500 uppercase font-medium">Cache Hit-Rate</p>
|
||||
<p className="text-lg font-semibold text-slate-900 mt-1">
|
||||
{status.cache_hits + status.cache_misses > 0
|
||||
? `${((status.cache_hits / (status.cache_hits + status.cache_misses)) * 100).toFixed(1)}%`
|
||||
: '--'}
|
||||
</p>
|
||||
</div>
|
||||
<div className="bg-white rounded-lg border border-slate-200 px-4 py-3">
|
||||
<p className="text-xs text-slate-500 uppercase font-medium">Uptime</p>
|
||||
<p className="text-lg font-semibold text-slate-900 mt-1">{formatUptime(status.uptime_s)}</p>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{/* Tabs */}
|
||||
<div className="border-b border-slate-200">
|
||||
<nav className="flex gap-4">
|
||||
{tabs.map(t => (
|
||||
<button
|
||||
key={t.key}
|
||||
onClick={() => setTab(t.key)}
|
||||
className={`pb-3 px-1 text-sm font-medium border-b-2 transition-colors ${
|
||||
tab === t.key
|
||||
? 'border-teal-500 text-teal-600'
|
||||
: 'border-transparent text-slate-500 hover:text-slate-700'
|
||||
}`}
|
||||
>
|
||||
{t.label}
|
||||
</button>
|
||||
))}
|
||||
</nav>
|
||||
</div>
|
||||
|
||||
{/* Overview Tab */}
|
||||
{tab === 'overview' && (
|
||||
<div className="space-y-4">
|
||||
<h3 className="text-sm font-medium text-slate-700">Verfuegbare Modelle</h3>
|
||||
<div className="grid gap-4 sm:grid-cols-2 lg:grid-cols-3">
|
||||
{models.map(m => (
|
||||
<div key={m.key} className="bg-white rounded-lg border border-slate-200 overflow-hidden">
|
||||
<div className="px-4 py-3 border-b border-slate-100">
|
||||
<h4 className="font-semibold text-slate-900">{m.name}</h4>
|
||||
<p className="text-xs text-slate-400 mt-0.5 font-mono">{m.key}</p>
|
||||
</div>
|
||||
<div className="px-4 py-3 space-y-3">
|
||||
{/* PyTorch */}
|
||||
<div className="flex items-center justify-between">
|
||||
<div className="flex items-center gap-2">
|
||||
<span className="text-xs font-medium text-slate-600 w-16">PyTorch</span>
|
||||
<StatusBadge status={m.pytorch.status} />
|
||||
</div>
|
||||
{m.pytorch.status === 'available' && (
|
||||
<span className="text-xs text-slate-400">
|
||||
{formatBytes(m.pytorch.size_mb)} / {formatBytes(m.pytorch.ram_mb)} RAM
|
||||
</span>
|
||||
)}
|
||||
</div>
|
||||
{/* ONNX */}
|
||||
<div className="flex items-center justify-between">
|
||||
<div className="flex items-center gap-2">
|
||||
<span className="text-xs font-medium text-slate-600 w-16">ONNX</span>
|
||||
<StatusBadge status={m.onnx.status} />
|
||||
</div>
|
||||
{m.onnx.status === 'available' && (
|
||||
<span className="text-xs text-slate-400">
|
||||
{formatBytes(m.onnx.size_mb)} / {formatBytes(m.onnx.ram_mb)} RAM
|
||||
{m.onnx.quantized && (
|
||||
<span className="ml-1 text-xs bg-violet-100 text-violet-700 px-1 rounded">INT8</span>
|
||||
)}
|
||||
</span>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
))}
|
||||
</div>
|
||||
|
||||
{/* Loaded Models List */}
|
||||
{status.loaded_models.length > 0 && (
|
||||
<div>
|
||||
<h3 className="text-sm font-medium text-slate-700 mb-2">Aktuell geladen</h3>
|
||||
<div className="flex flex-wrap gap-2">
|
||||
{status.loaded_models.map((m, i) => (
|
||||
<span key={i} className="inline-flex items-center px-3 py-1 rounded-full text-sm bg-teal-50 text-teal-700 border border-teal-200">
|
||||
{m}
|
||||
</span>
|
||||
))}
|
||||
</div>
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
)}
|
||||
|
||||
{/* Benchmarks Tab */}
|
||||
{tab === 'benchmarks' && (
|
||||
<div className="space-y-4">
|
||||
<div className="flex items-center justify-between">
|
||||
<h3 className="text-sm font-medium text-slate-700">PyTorch vs ONNX Vergleich</h3>
|
||||
<button
|
||||
onClick={runBenchmark}
|
||||
disabled={benchmarkRunning}
|
||||
className="inline-flex items-center gap-2 px-4 py-2 bg-teal-600 text-white rounded-lg hover:bg-teal-700 disabled:opacity-50 disabled:cursor-not-allowed text-sm font-medium transition-colors"
|
||||
>
|
||||
{benchmarkRunning ? (
|
||||
<>
|
||||
<svg className="animate-spin h-4 w-4" fill="none" viewBox="0 0 24 24">
|
||||
<circle className="opacity-25" cx="12" cy="12" r="10" stroke="currentColor" strokeWidth="4" />
|
||||
<path className="opacity-75" fill="currentColor" d="M4 12a8 8 0 018-8V0C5.373 0 0 5.373 0 12h4zm2 5.291A7.962 7.962 0 014 12H0c0 3.042 1.135 5.824 3 7.938l3-2.647z" />
|
||||
</svg>
|
||||
Benchmark laeuft...
|
||||
</>
|
||||
) : (
|
||||
'Benchmark starten'
|
||||
)}
|
||||
</button>
|
||||
</div>
|
||||
|
||||
<div className="bg-white rounded-lg border border-slate-200 overflow-hidden">
|
||||
<div className="overflow-x-auto">
|
||||
<table className="w-full text-sm">
|
||||
<thead>
|
||||
<tr className="border-b border-slate-200 bg-slate-50 text-left text-slate-500">
|
||||
<th className="px-4 py-3 font-medium">Modell</th>
|
||||
<th className="px-4 py-3 font-medium">Backend</th>
|
||||
<th className="px-4 py-3 font-medium">Quantisierung</th>
|
||||
<th className="px-4 py-3 font-medium text-right">Groesse</th>
|
||||
<th className="px-4 py-3 font-medium text-right">RAM</th>
|
||||
<th className="px-4 py-3 font-medium text-right">Inferenz</th>
|
||||
<th className="px-4 py-3 font-medium text-right">Ladezeit</th>
|
||||
</tr>
|
||||
</thead>
|
||||
<tbody>
|
||||
{benchmarks.map((b, i) => (
|
||||
<tr key={i} className="border-b border-slate-100 hover:bg-slate-50">
|
||||
<td className="px-4 py-3 font-medium text-slate-900">{b.model}</td>
|
||||
<td className="px-4 py-3">
|
||||
<span className={`inline-flex items-center px-2 py-0.5 rounded text-xs font-medium ${
|
||||
b.backend === 'ONNX'
|
||||
? 'bg-violet-100 text-violet-700'
|
||||
: 'bg-orange-100 text-orange-700'
|
||||
}`}>
|
||||
{b.backend}
|
||||
</span>
|
||||
</td>
|
||||
<td className="px-4 py-3 text-slate-600">{b.quantization}</td>
|
||||
<td className="px-4 py-3 text-right text-slate-600">{formatBytes(b.size_mb)}</td>
|
||||
<td className="px-4 py-3 text-right text-slate-600">{formatBytes(b.ram_mb)}</td>
|
||||
<td className="px-4 py-3 text-right">
|
||||
<span className={`font-mono ${b.inference_ms < 50 ? 'text-emerald-600' : b.inference_ms < 100 ? 'text-amber-600' : 'text-red-600'}`}>
|
||||
{b.inference_ms} ms
|
||||
</span>
|
||||
</td>
|
||||
<td className="px-4 py-3 text-right text-slate-500">{b.load_time_s.toFixed(1)}s</td>
|
||||
</tr>
|
||||
))}
|
||||
</tbody>
|
||||
</table>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{benchmarks.length === 0 && (
|
||||
<div className="text-center py-12 text-slate-400">
|
||||
<p className="text-lg">Keine Benchmark-Daten</p>
|
||||
<p className="text-sm mt-1">Klicken Sie "Benchmark starten" um einen Vergleich durchzufuehren.</p>
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
)}
|
||||
|
||||
{/* Configuration Tab */}
|
||||
{tab === 'configuration' && (
|
||||
<div className="space-y-6">
|
||||
{/* Backend Selector */}
|
||||
<div className="bg-white rounded-lg border border-slate-200 p-5">
|
||||
<h3 className="text-sm font-semibold text-slate-900 mb-1">Inference Backend</h3>
|
||||
<p className="text-sm text-slate-500 mb-4">
|
||||
Waehlen Sie welches Backend fuer die Modell-Inferenz verwendet werden soll.
|
||||
</p>
|
||||
<div className="space-y-3">
|
||||
{([
|
||||
{
|
||||
mode: 'auto' as const,
|
||||
label: 'Auto',
|
||||
desc: 'ONNX wenn verfuegbar, Fallback auf PyTorch.',
|
||||
},
|
||||
{
|
||||
mode: 'pytorch' as const,
|
||||
label: 'PyTorch',
|
||||
desc: 'Immer PyTorch verwenden. Hoeherer RAM-Verbrauch, volle Flexibilitaet.',
|
||||
},
|
||||
{
|
||||
mode: 'onnx' as const,
|
||||
label: 'ONNX',
|
||||
desc: 'Immer ONNX verwenden. Schneller und weniger RAM, Fehler wenn nicht vorhanden.',
|
||||
},
|
||||
] as const).map(opt => (
|
||||
<label
|
||||
key={opt.mode}
|
||||
className={`flex items-start gap-3 p-3 rounded-lg border cursor-pointer transition-colors ${
|
||||
backend === opt.mode
|
||||
? 'border-teal-300 bg-teal-50'
|
||||
: 'border-slate-200 hover:bg-slate-50'
|
||||
}`}
|
||||
>
|
||||
<input
|
||||
type="radio"
|
||||
name="backend"
|
||||
value={opt.mode}
|
||||
checked={backend === opt.mode}
|
||||
onChange={() => saveBackend(opt.mode)}
|
||||
disabled={saving}
|
||||
className="mt-1 text-teal-600 focus:ring-teal-500"
|
||||
/>
|
||||
<div>
|
||||
<span className="font-medium text-slate-900">{opt.label}</span>
|
||||
<p className="text-sm text-slate-500 mt-0.5">{opt.desc}</p>
|
||||
</div>
|
||||
</label>
|
||||
))}
|
||||
</div>
|
||||
{saving && (
|
||||
<p className="text-xs text-teal-600 mt-3">Speichere...</p>
|
||||
)}
|
||||
</div>
|
||||
|
||||
{/* Model Details Table */}
|
||||
<div className="bg-white rounded-lg border border-slate-200 p-5">
|
||||
<h3 className="text-sm font-semibold text-slate-900 mb-4">Modell-Details</h3>
|
||||
<div className="overflow-x-auto">
|
||||
<table className="w-full text-sm">
|
||||
<thead>
|
||||
<tr className="border-b border-slate-200 text-left text-slate-500">
|
||||
<th className="pb-2 font-medium">Modell</th>
|
||||
<th className="pb-2 font-medium">PyTorch</th>
|
||||
<th className="pb-2 font-medium text-right">Groesse (PT)</th>
|
||||
<th className="pb-2 font-medium">ONNX</th>
|
||||
<th className="pb-2 font-medium text-right">Groesse (ONNX)</th>
|
||||
<th className="pb-2 font-medium text-right">Einsparung</th>
|
||||
</tr>
|
||||
</thead>
|
||||
<tbody>
|
||||
{models.map(m => {
|
||||
const ptAvail = m.pytorch.status === 'available'
|
||||
const oxAvail = m.onnx.status === 'available'
|
||||
const savings = ptAvail && oxAvail && m.pytorch.size_mb > 0
|
||||
? Math.round((1 - m.onnx.size_mb / m.pytorch.size_mb) * 100)
|
||||
: null
|
||||
return (
|
||||
<tr key={m.key} className="border-b border-slate-100">
|
||||
<td className="py-2.5 font-medium text-slate-900">{m.name}</td>
|
||||
<td className="py-2.5"><StatusBadge status={m.pytorch.status} /></td>
|
||||
<td className="py-2.5 text-right text-slate-500">{ptAvail ? formatBytes(m.pytorch.size_mb) : '--'}</td>
|
||||
<td className="py-2.5"><StatusBadge status={m.onnx.status} /></td>
|
||||
<td className="py-2.5 text-right text-slate-500">{oxAvail ? formatBytes(m.onnx.size_mb) : '--'}</td>
|
||||
<td className="py-2.5 text-right">
|
||||
{savings !== null ? (
|
||||
<span className="text-emerald-600 font-medium">-{savings}%</span>
|
||||
) : (
|
||||
<span className="text-slate-300">--</span>
|
||||
)}
|
||||
</td>
|
||||
</tr>
|
||||
)
|
||||
})}
|
||||
</tbody>
|
||||
</table>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
</AIToolsSidebarResponsive>
|
||||
)
|
||||
}
|
||||
@@ -15,7 +15,6 @@
|
||||
|
||||
import { useState, useEffect, useCallback, useRef } from 'react'
|
||||
import { PagePurpose } from '@/components/common/PagePurpose'
|
||||
import { AIToolsSidebarResponsive } from '@/components/ai/AIToolsSidebar'
|
||||
|
||||
const KLAUSUR_API = '/klausur-api'
|
||||
|
||||
@@ -271,7 +270,7 @@ export default function GroundTruthReviewPage() {
|
||||
: null
|
||||
|
||||
return (
|
||||
<AIToolsSidebarResponsive>
|
||||
<div className="space-y-6">
|
||||
<div className="max-w-[1600px] mx-auto p-4 space-y-4">
|
||||
<PagePurpose
|
||||
title="Ground Truth Review"
|
||||
@@ -588,6 +587,6 @@ export default function GroundTruthReviewPage() {
|
||||
</details>
|
||||
)}
|
||||
</div>
|
||||
</AIToolsSidebarResponsive>
|
||||
</div>
|
||||
)
|
||||
}
|
||||
|
||||
@@ -233,6 +233,15 @@ export interface ExcludeRegion {
|
||||
label?: string
|
||||
}
|
||||
|
||||
export interface DocLayoutRegion {
|
||||
x: number
|
||||
y: number
|
||||
w: number
|
||||
h: number
|
||||
class_name: string
|
||||
confidence: number
|
||||
}
|
||||
|
||||
export interface StructureResult {
|
||||
image_width: number
|
||||
image_height: number
|
||||
@@ -246,6 +255,9 @@ export interface StructureResult {
|
||||
word_count: number
|
||||
border_ghosts_removed?: number
|
||||
duration_seconds: number
|
||||
/** PP-DocLayout regions (only present when method=ppdoclayout) */
|
||||
layout_regions?: DocLayoutRegion[]
|
||||
detection_method?: 'opencv' | 'ppdoclayout'
|
||||
}
|
||||
|
||||
export interface StructureBox {
|
||||
|
||||
@@ -9,7 +9,6 @@
|
||||
|
||||
import { useState, useEffect, useCallback } from 'react'
|
||||
import { PagePurpose } from '@/components/common/PagePurpose'
|
||||
import { AIToolsSidebarResponsive } from '@/components/ai/AIToolsSidebar'
|
||||
|
||||
const KLAUSUR_API = '/klausur-api'
|
||||
|
||||
@@ -165,7 +164,7 @@ export default function OCRRegressionPage() {
|
||||
const totalError = results.filter(r => r.status === 'error').length
|
||||
|
||||
return (
|
||||
<AIToolsSidebarResponsive>
|
||||
<div className="space-y-6">
|
||||
<div className="max-w-7xl mx-auto p-6 space-y-6">
|
||||
<PagePurpose
|
||||
title="OCR Regression Tests"
|
||||
@@ -399,6 +398,6 @@ export default function OCRRegressionPage() {
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
</AIToolsSidebarResponsive>
|
||||
</div>
|
||||
)
|
||||
}
|
||||
|
||||
@@ -19,6 +19,26 @@ const COLOR_HEX: Record<string, string> = {
|
||||
purple: '#9333ea',
|
||||
}
|
||||
|
||||
type DetectionMethod = 'auto' | 'opencv' | 'ppdoclayout'
|
||||
|
||||
/** Color map for PP-DocLayout region classes */
|
||||
const DOCLAYOUT_CLASS_COLORS: Record<string, string> = {
|
||||
table: '#2563eb',
|
||||
figure: '#16a34a',
|
||||
title: '#ea580c',
|
||||
text: '#6b7280',
|
||||
list: '#9333ea',
|
||||
header: '#0ea5e9',
|
||||
footer: '#64748b',
|
||||
equation: '#dc2626',
|
||||
}
|
||||
|
||||
const DOCLAYOUT_DEFAULT_COLOR = '#a3a3a3'
|
||||
|
||||
function getDocLayoutColor(className: string): string {
|
||||
return DOCLAYOUT_CLASS_COLORS[className.toLowerCase()] || DOCLAYOUT_DEFAULT_COLOR
|
||||
}
|
||||
|
||||
/**
|
||||
* Convert a mouse event on the image container to image-pixel coordinates.
|
||||
* The image uses object-contain inside an A4-ratio container, so we need
|
||||
@@ -96,6 +116,7 @@ export function StepStructureDetection({ sessionId, onNext }: StepStructureDetec
|
||||
const [error, setError] = useState<string | null>(null)
|
||||
const [hasRun, setHasRun] = useState(false)
|
||||
const [overlayTs, setOverlayTs] = useState(0)
|
||||
const [detectionMethod, setDetectionMethod] = useState<DetectionMethod>('auto')
|
||||
|
||||
// Exclude region drawing state
|
||||
const [excludeRegions, setExcludeRegions] = useState<ExcludeRegion[]>([])
|
||||
@@ -106,7 +127,9 @@ export function StepStructureDetection({ sessionId, onNext }: StepStructureDetec
|
||||
const [drawMode, setDrawMode] = useState(false)
|
||||
|
||||
const containerRef = useRef<HTMLDivElement>(null)
|
||||
const overlayContainerRef = useRef<HTMLDivElement>(null)
|
||||
const [containerSize, setContainerSize] = useState({ w: 0, h: 0 })
|
||||
const [overlayContainerSize, setOverlayContainerSize] = useState({ w: 0, h: 0 })
|
||||
|
||||
// Track container size for overlay positioning
|
||||
useEffect(() => {
|
||||
@@ -121,6 +144,19 @@ export function StepStructureDetection({ sessionId, onNext }: StepStructureDetec
|
||||
return () => obs.disconnect()
|
||||
}, [])
|
||||
|
||||
// Track overlay container size for PP-DocLayout region overlays
|
||||
useEffect(() => {
|
||||
const el = overlayContainerRef.current
|
||||
if (!el) return
|
||||
const obs = new ResizeObserver((entries) => {
|
||||
for (const entry of entries) {
|
||||
setOverlayContainerSize({ w: entry.contentRect.width, h: entry.contentRect.height })
|
||||
}
|
||||
})
|
||||
obs.observe(el)
|
||||
return () => obs.disconnect()
|
||||
}, [])
|
||||
|
||||
// Auto-trigger detection on mount
|
||||
useEffect(() => {
|
||||
if (!sessionId || hasRun) return
|
||||
@@ -131,7 +167,8 @@ export function StepStructureDetection({ sessionId, onNext }: StepStructureDetec
|
||||
setError(null)
|
||||
|
||||
try {
|
||||
const res = await fetch(`${KLAUSUR_API}/api/v1/ocr-pipeline/sessions/${sessionId}/detect-structure`, {
|
||||
const params = detectionMethod !== 'auto' ? `?method=${detectionMethod}` : ''
|
||||
const res = await fetch(`${KLAUSUR_API}/api/v1/ocr-pipeline/sessions/${sessionId}/detect-structure${params}`, {
|
||||
method: 'POST',
|
||||
})
|
||||
|
||||
@@ -158,7 +195,8 @@ export function StepStructureDetection({ sessionId, onNext }: StepStructureDetec
|
||||
setDetecting(true)
|
||||
setError(null)
|
||||
try {
|
||||
const res = await fetch(`${KLAUSUR_API}/api/v1/ocr-pipeline/sessions/${sessionId}/detect-structure`, {
|
||||
const params = detectionMethod !== 'auto' ? `?method=${detectionMethod}` : ''
|
||||
const res = await fetch(`${KLAUSUR_API}/api/v1/ocr-pipeline/sessions/${sessionId}/detect-structure${params}`, {
|
||||
method: 'POST',
|
||||
})
|
||||
if (!res.ok) throw new Error('Erneute Erkennung fehlgeschlagen')
|
||||
@@ -278,6 +316,31 @@ export function StepStructureDetection({ sessionId, onNext }: StepStructureDetec
|
||||
</div>
|
||||
)}
|
||||
|
||||
{/* Detection method toggle */}
|
||||
<div className="flex items-center gap-2">
|
||||
<span className="text-xs font-medium text-gray-500 dark:text-gray-400">Methode:</span>
|
||||
{(['auto', 'opencv', 'ppdoclayout'] as DetectionMethod[]).map((method) => (
|
||||
<button
|
||||
key={method}
|
||||
onClick={() => setDetectionMethod(method)}
|
||||
className={`px-3 py-1.5 text-xs rounded-md font-medium transition-colors ${
|
||||
detectionMethod === method
|
||||
? 'bg-teal-600 text-white'
|
||||
: 'bg-gray-100 dark:bg-gray-700 text-gray-600 dark:text-gray-300 hover:bg-gray-200 dark:hover:bg-gray-600'
|
||||
}`}
|
||||
>
|
||||
{method === 'auto' ? 'Auto' : method === 'opencv' ? 'OpenCV' : 'PP-DocLayout'}
|
||||
</button>
|
||||
))}
|
||||
<span className="text-[10px] text-gray-400 dark:text-gray-500 ml-1">
|
||||
{detectionMethod === 'auto'
|
||||
? 'PP-DocLayout wenn verfuegbar, sonst OpenCV'
|
||||
: detectionMethod === 'ppdoclayout'
|
||||
? 'ONNX-basierte Layouterkennung mit Klassifikation'
|
||||
: 'Klassische OpenCV-Konturerkennung'}
|
||||
</span>
|
||||
</div>
|
||||
|
||||
{/* Draw mode toggle */}
|
||||
{result && (
|
||||
<div className="flex items-center gap-3">
|
||||
@@ -376,8 +439,17 @@ export function StepStructureDetection({ sessionId, onNext }: StepStructureDetec
|
||||
<div className="space-y-2">
|
||||
<div className="text-xs font-medium text-gray-500 dark:text-gray-400 uppercase tracking-wider">
|
||||
Erkannte Struktur
|
||||
{result?.detection_method && (
|
||||
<span className="ml-2 text-[10px] font-normal normal-case">
|
||||
({result.detection_method === 'ppdoclayout' ? 'PP-DocLayout' : 'OpenCV'})
|
||||
</span>
|
||||
)}
|
||||
</div>
|
||||
<div className="relative bg-gray-100 dark:bg-gray-800 rounded-lg overflow-hidden" style={{ aspectRatio: '210/297' }}>
|
||||
<div
|
||||
ref={overlayContainerRef}
|
||||
className="relative bg-gray-100 dark:bg-gray-800 rounded-lg overflow-hidden"
|
||||
style={{ aspectRatio: '210/297' }}
|
||||
>
|
||||
{/* eslint-disable-next-line @next/next/no-img-element */}
|
||||
<img
|
||||
src={overlayUrl}
|
||||
@@ -387,7 +459,52 @@ export function StepStructureDetection({ sessionId, onNext }: StepStructureDetec
|
||||
(e.target as HTMLImageElement).style.display = 'none'
|
||||
}}
|
||||
/>
|
||||
|
||||
{/* PP-DocLayout region overlays with class colors and labels */}
|
||||
{result?.layout_regions && overlayContainerSize.w > 0 && result.layout_regions.map((region, i) => {
|
||||
const pos = imageToOverlayPct(region, overlayContainerSize.w, overlayContainerSize.h, result.image_width, result.image_height)
|
||||
const color = getDocLayoutColor(region.class_name)
|
||||
return (
|
||||
<div
|
||||
key={`layout-${i}`}
|
||||
className="absolute border-2 pointer-events-none"
|
||||
style={{
|
||||
...pos,
|
||||
borderColor: color,
|
||||
backgroundColor: `${color}18`,
|
||||
}}
|
||||
>
|
||||
<span
|
||||
className="absolute -top-4 left-0 px-1 py-px text-[9px] font-medium text-white rounded-sm whitespace-nowrap leading-tight"
|
||||
style={{ backgroundColor: color }}
|
||||
>
|
||||
{region.class_name} {Math.round(region.confidence * 100)}%
|
||||
</span>
|
||||
</div>
|
||||
)
|
||||
})}
|
||||
</div>
|
||||
|
||||
{/* PP-DocLayout legend */}
|
||||
{result?.layout_regions && result.layout_regions.length > 0 && (() => {
|
||||
const usedClasses = [...new Set(result.layout_regions!.map((r) => r.class_name.toLowerCase()))]
|
||||
return (
|
||||
<div className="flex flex-wrap gap-x-3 gap-y-1 px-1">
|
||||
{usedClasses.sort().map((cls) => (
|
||||
<span key={cls} className="inline-flex items-center gap-1 text-[10px] text-gray-500 dark:text-gray-400">
|
||||
<span
|
||||
className="w-2.5 h-2.5 rounded-sm border"
|
||||
style={{
|
||||
backgroundColor: `${getDocLayoutColor(cls)}30`,
|
||||
borderColor: getDocLayoutColor(cls),
|
||||
}}
|
||||
/>
|
||||
{cls}
|
||||
</span>
|
||||
))}
|
||||
</div>
|
||||
)
|
||||
})()}
|
||||
</div>
|
||||
</div>
|
||||
|
||||
@@ -430,6 +547,11 @@ export function StepStructureDetection({ sessionId, onNext }: StepStructureDetec
|
||||
<span className="inline-flex items-center gap-1.5 px-3 py-1 rounded-full bg-amber-50 dark:bg-amber-900/20 text-amber-700 dark:text-amber-400 text-xs font-medium">
|
||||
{result.boxes.length} Box(en)
|
||||
</span>
|
||||
{result.layout_regions && result.layout_regions.length > 0 && (
|
||||
<span className="inline-flex items-center gap-1.5 px-3 py-1 rounded-full bg-indigo-50 dark:bg-indigo-900/20 text-indigo-700 dark:text-indigo-400 text-xs font-medium">
|
||||
{result.layout_regions.length} Layout-Region(en)
|
||||
</span>
|
||||
)}
|
||||
{result.graphics && result.graphics.length > 0 && (
|
||||
<span className="inline-flex items-center gap-1.5 px-3 py-1 rounded-full bg-purple-50 dark:bg-purple-900/20 text-purple-700 dark:text-purple-400 text-xs font-medium">
|
||||
{result.graphics.length} Grafik(en)
|
||||
@@ -451,6 +573,11 @@ export function StepStructureDetection({ sessionId, onNext }: StepStructureDetec
|
||||
</span>
|
||||
)}
|
||||
<span className="text-gray-400 text-xs ml-auto">
|
||||
{result.detection_method && (
|
||||
<span className="mr-1.5">
|
||||
{result.detection_method === 'ppdoclayout' ? 'PP-DocLayout' : 'OpenCV'} |
|
||||
</span>
|
||||
)}
|
||||
{result.image_width}x{result.image_height}px | {result.duration_seconds}s
|
||||
</span>
|
||||
</div>
|
||||
@@ -491,6 +618,37 @@ export function StepStructureDetection({ sessionId, onNext }: StepStructureDetec
|
||||
</div>
|
||||
)}
|
||||
|
||||
{/* PP-DocLayout regions detail */}
|
||||
{result.layout_regions && result.layout_regions.length > 0 && (
|
||||
<div>
|
||||
<h4 className="text-xs font-medium text-gray-500 dark:text-gray-400 mb-2">
|
||||
PP-DocLayout Regionen ({result.layout_regions.length})
|
||||
</h4>
|
||||
<div className="space-y-1.5">
|
||||
{result.layout_regions.map((region, i) => {
|
||||
const color = getDocLayoutColor(region.class_name)
|
||||
return (
|
||||
<div key={i} className="flex items-center gap-3 text-xs">
|
||||
<span
|
||||
className="w-3 h-3 rounded-sm flex-shrink-0 border"
|
||||
style={{ backgroundColor: `${color}40`, borderColor: color }}
|
||||
/>
|
||||
<span className="text-gray-600 dark:text-gray-400 font-medium min-w-[60px]">
|
||||
{region.class_name}
|
||||
</span>
|
||||
<span className="font-mono text-gray-500">
|
||||
{region.w}x{region.h}px @ ({region.x}, {region.y})
|
||||
</span>
|
||||
<span className="text-gray-400">
|
||||
{Math.round(region.confidence * 100)}%
|
||||
</span>
|
||||
</div>
|
||||
)
|
||||
})}
|
||||
</div>
|
||||
</div>
|
||||
)}
|
||||
|
||||
{/* Zones detail */}
|
||||
<div>
|
||||
<h4 className="text-xs font-medium text-gray-500 dark:text-gray-400 mb-2">Seitenzonen</h4>
|
||||
|
||||
@@ -200,6 +200,15 @@ export const navigation: NavCategory[] = [
|
||||
audience: ['Entwickler', 'QA'],
|
||||
subgroup: 'KI-Werkzeuge',
|
||||
},
|
||||
{
|
||||
id: 'model-management',
|
||||
name: 'Model Management',
|
||||
href: '/ai/model-management',
|
||||
description: 'ONNX & PyTorch Modell-Verwaltung',
|
||||
purpose: 'Verfuegbare ML-Modelle verwalten (PyTorch vs ONNX), Backend umschalten, Benchmark-Vergleiche ausfuehren und RAM/Performance-Metriken einsehen.',
|
||||
audience: ['Entwickler', 'DevOps'],
|
||||
subgroup: 'KI-Werkzeuge',
|
||||
},
|
||||
{
|
||||
id: 'agents',
|
||||
name: 'Agent Management',
|
||||
|
||||
@@ -1588,6 +1588,34 @@ cd klausur-service/backend && pytest tests/test_paddle_kombi.py -v # 36 Tests
|
||||
|
||||
---
|
||||
|
||||
## ONNX Backends und PP-DocLayout (Sprint 2)
|
||||
|
||||
### TrOCR ONNX Runtime
|
||||
|
||||
Ab Sprint 2 unterstuetzt die Pipeline **TrOCR mit ONNX Runtime** als Alternative zu PyTorch.
|
||||
ONNX reduziert den RAM-Verbrauch von ~1.1 GB auf ~300 MB pro Modell und beschleunigt
|
||||
die Inferenz um ~3x. Ideal fuer Hardware Tier 2 (8 GB RAM).
|
||||
|
||||
**Backend-Auswahl:** Umgebungsvariable `TROCR_BACKEND` (`auto` | `pytorch` | `onnx`).
|
||||
Im `auto`-Modus wird ONNX bevorzugt, wenn exportierte Modelle vorhanden sind.
|
||||
|
||||
Vollstaendige Dokumentation: [TrOCR ONNX Runtime](TrOCR-ONNX.md)
|
||||
|
||||
### PP-DocLayout (Document Layout Analysis)
|
||||
|
||||
PP-DocLayout ersetzt die bisherige manuelle Zonen-Erkennung durch ein vortrainiertes
|
||||
Layout-Analyse-Modell. Es erkennt automatisch:
|
||||
|
||||
- **Tabellen** (vocab_table, generic_table)
|
||||
- **Ueberschriften** (title, section_header)
|
||||
- **Bilder/Grafiken** (figure, illustration)
|
||||
- **Textbloecke** (paragraph, list)
|
||||
|
||||
PP-DocLayout laeuft als ONNX-Modell (~15 MB) und benoetigt kein PyTorch.
|
||||
Die Ergebnisse fliessen in Schritt 5 (Spaltenerkennung) und den Grid Editor ein.
|
||||
|
||||
---
|
||||
|
||||
## Aenderungshistorie
|
||||
|
||||
| Datum | Version | Aenderung |
|
||||
|
||||
83
docs-src/services/klausur-service/TrOCR-ONNX.md
Normal file
83
docs-src/services/klausur-service/TrOCR-ONNX.md
Normal file
@@ -0,0 +1,83 @@
|
||||
# TrOCR ONNX Runtime
|
||||
|
||||
## Uebersicht
|
||||
|
||||
TrOCR (Transformer-based OCR) kann sowohl mit PyTorch als auch mit ONNX Runtime betrieben werden. ONNX bietet deutlich reduzierten RAM-Verbrauch und schnellere Inferenz — ideal fuer den Offline-Betrieb auf Hardware Tier 2 (8 GB RAM).
|
||||
|
||||
## Export-Prozess
|
||||
|
||||
### Voraussetzungen
|
||||
|
||||
- Python 3.10+
|
||||
- `optimum>=1.17.0` (Apache-2.0)
|
||||
- `onnxruntime` (bereits installiert via RapidOCR)
|
||||
|
||||
### Export ausfuehren
|
||||
|
||||
```bash
|
||||
python scripts/export-trocr-onnx.py --model both
|
||||
```
|
||||
|
||||
Dies exportiert:
|
||||
- `models/onnx/trocr-base-printed/` (~85 MB fp32 → ~25 MB int8)
|
||||
- `models/onnx/trocr-base-handwritten/` (~85 MB fp32 → ~25 MB int8)
|
||||
|
||||
### Quantisierung
|
||||
|
||||
Int8-Quantisierung reduziert Modellgroesse um ~70% mit weniger als 2% Genauigkeitsverlust.
|
||||
|
||||
## Runtime-Konfiguration
|
||||
|
||||
### Umgebungsvariablen
|
||||
|
||||
| Variable | Default | Beschreibung |
|
||||
|----------|---------|--------------|
|
||||
| `TROCR_BACKEND` | `auto` | Backend-Auswahl: `auto`, `pytorch`, `onnx` |
|
||||
| `TROCR_ONNX_DIR` | (siehe unten) | Pfad zu ONNX-Modellen |
|
||||
|
||||
### Backend-Modi
|
||||
|
||||
- **auto** (empfohlen): ONNX wenn verfuegbar, sonst PyTorch Fallback
|
||||
- **pytorch**: Erzwingt PyTorch (hoeherer RAM, aber bewaehrt)
|
||||
- **onnx**: Erzwingt ONNX (Fehler wenn Modell nicht vorhanden)
|
||||
|
||||
### Modell-Pfade (Suchpfade)
|
||||
|
||||
1. `$TROCR_ONNX_DIR/trocr-base-{variant}/`
|
||||
2. `/root/.cache/huggingface/onnx/trocr-base-{variant}/` (Docker)
|
||||
3. `models/onnx/trocr-base-{variant}/` (lokale Entwicklung)
|
||||
|
||||
## Hardware-Anforderungen
|
||||
|
||||
| Metrik | PyTorch float32 | ONNX int8 |
|
||||
|--------|-----------------|-----------|
|
||||
| Modell-Groesse | ~340 MB | ~50 MB |
|
||||
| RAM (Printed) | ~1.1 GB | ~300 MB |
|
||||
| RAM (Handwritten) | ~1.1 GB | ~300 MB |
|
||||
| Inferenz/Zeile | ~120 ms | ~40 ms |
|
||||
| Mindest-RAM | 4 GB | 2 GB |
|
||||
|
||||
## Benchmark
|
||||
|
||||
```bash
|
||||
# PyTorch Baseline (Sprint 1)
|
||||
python scripts/benchmark-trocr.py > benchmark-baseline.json
|
||||
|
||||
# ONNX Benchmark
|
||||
python scripts/benchmark-trocr.py --backend onnx > benchmark-onnx.json
|
||||
```
|
||||
|
||||
Benchmark-Vergleiche koennen im Admin unter `/ai/model-management` eingesehen werden.
|
||||
|
||||
## Verifikation
|
||||
|
||||
Der Export-Script verifiziert automatisch, dass ONNX-Output weniger als 2% von PyTorch abweicht. Manuelle Verifikation:
|
||||
|
||||
```python
|
||||
# Im Container
|
||||
python -c "
|
||||
from services.trocr_onnx_service import is_onnx_available, get_onnx_model_status
|
||||
print('Available:', is_onnx_available())
|
||||
print('Status:', get_onnx_model_status())
|
||||
"
|
||||
```
|
||||
413
klausur-service/backend/cv_doclayout_detect.py
Normal file
413
klausur-service/backend/cv_doclayout_detect.py
Normal file
@@ -0,0 +1,413 @@
|
||||
"""
|
||||
PP-DocLayout ONNX Document Layout Detection.
|
||||
|
||||
Uses PP-DocLayout ONNX model to detect document structure regions:
|
||||
table, figure, title, text, list, header, footer, equation, reference, abstract
|
||||
|
||||
Fallback: If ONNX model not available, returns empty list (caller should
|
||||
fall back to OpenCV-based detection in cv_graphic_detect.py).
|
||||
|
||||
DATENSCHUTZ: Alle Verarbeitung erfolgt lokal.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import os
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
import numpy as np
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
__all__ = [
|
||||
"detect_layout_regions",
|
||||
"is_doclayout_available",
|
||||
"get_doclayout_status",
|
||||
"LayoutRegion",
|
||||
"DOCLAYOUT_CLASSES",
|
||||
]
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Class labels (PP-DocLayout default order)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
DOCLAYOUT_CLASSES = [
|
||||
"table", "figure", "title", "text", "list",
|
||||
"header", "footer", "equation", "reference", "abstract",
|
||||
]
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Data types
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@dataclass
|
||||
class LayoutRegion:
|
||||
"""A detected document layout region."""
|
||||
x: int
|
||||
y: int
|
||||
width: int
|
||||
height: int
|
||||
label: str # table, figure, title, text, list, etc.
|
||||
confidence: float
|
||||
label_index: int # raw class index
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# ONNX model loading
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_MODEL_SEARCH_PATHS = [
|
||||
# 1. Explicit environment variable
|
||||
os.environ.get("DOCLAYOUT_ONNX_PATH", ""),
|
||||
# 2. Docker default cache path
|
||||
"/root/.cache/huggingface/onnx/pp-doclayout/model.onnx",
|
||||
# 3. Local dev relative to working directory
|
||||
"models/onnx/pp-doclayout/model.onnx",
|
||||
]
|
||||
|
||||
_onnx_session: Optional[object] = None
|
||||
_model_path: Optional[str] = None
|
||||
_load_attempted: bool = False
|
||||
_load_error: Optional[str] = None
|
||||
|
||||
|
||||
def _find_model_path() -> Optional[str]:
|
||||
"""Search for the ONNX model file in known locations."""
|
||||
for p in _MODEL_SEARCH_PATHS:
|
||||
if p and Path(p).is_file():
|
||||
return str(Path(p).resolve())
|
||||
return None
|
||||
|
||||
|
||||
def _load_onnx_session():
|
||||
"""Lazy-load the ONNX runtime session (once)."""
|
||||
global _onnx_session, _model_path, _load_attempted, _load_error
|
||||
|
||||
if _load_attempted:
|
||||
return _onnx_session
|
||||
|
||||
_load_attempted = True
|
||||
|
||||
path = _find_model_path()
|
||||
if path is None:
|
||||
_load_error = "ONNX model not found in any search path"
|
||||
logger.info("PP-DocLayout: %s", _load_error)
|
||||
return None
|
||||
|
||||
try:
|
||||
import onnxruntime as ort # type: ignore[import-untyped]
|
||||
|
||||
sess_options = ort.SessionOptions()
|
||||
sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
|
||||
# Prefer CPU – keeps the GPU free for OCR / LLM.
|
||||
providers = ["CPUExecutionProvider"]
|
||||
_onnx_session = ort.InferenceSession(path, sess_options, providers=providers)
|
||||
_model_path = path
|
||||
logger.info("PP-DocLayout: model loaded from %s", path)
|
||||
except ImportError:
|
||||
_load_error = "onnxruntime not installed"
|
||||
logger.info("PP-DocLayout: %s", _load_error)
|
||||
except Exception as exc:
|
||||
_load_error = str(exc)
|
||||
logger.warning("PP-DocLayout: failed to load model from %s: %s", path, exc)
|
||||
|
||||
return _onnx_session
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Public helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def is_doclayout_available() -> bool:
|
||||
"""Return True if the ONNX model can be loaded successfully."""
|
||||
return _load_onnx_session() is not None
|
||||
|
||||
|
||||
def get_doclayout_status() -> Dict:
|
||||
"""Return diagnostic information about the DocLayout backend."""
|
||||
_load_onnx_session() # ensure we tried
|
||||
return {
|
||||
"available": _onnx_session is not None,
|
||||
"model_path": _model_path,
|
||||
"load_error": _load_error,
|
||||
"classes": DOCLAYOUT_CLASSES,
|
||||
"class_count": len(DOCLAYOUT_CLASSES),
|
||||
}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Pre-processing
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_INPUT_SIZE = 800 # PP-DocLayout expects 800x800
|
||||
|
||||
|
||||
def preprocess_image(img_bgr: np.ndarray) -> tuple:
|
||||
"""Resize + normalize image for PP-DocLayout ONNX input.
|
||||
|
||||
Returns:
|
||||
(input_tensor, scale_x, scale_y, pad_x, pad_y)
|
||||
where scale/pad allow mapping boxes back to original coords.
|
||||
"""
|
||||
orig_h, orig_w = img_bgr.shape[:2]
|
||||
|
||||
# Compute scale to fit within _INPUT_SIZE keeping aspect ratio
|
||||
scale = min(_INPUT_SIZE / orig_w, _INPUT_SIZE / orig_h)
|
||||
new_w = int(orig_w * scale)
|
||||
new_h = int(orig_h * scale)
|
||||
|
||||
import cv2 # local import — cv2 is always available in this service
|
||||
resized = cv2.resize(img_bgr, (new_w, new_h), interpolation=cv2.INTER_LINEAR)
|
||||
|
||||
# Pad to _INPUT_SIZE x _INPUT_SIZE with gray (114)
|
||||
pad_x = (_INPUT_SIZE - new_w) // 2
|
||||
pad_y = (_INPUT_SIZE - new_h) // 2
|
||||
padded = np.full((_INPUT_SIZE, _INPUT_SIZE, 3), 114, dtype=np.uint8)
|
||||
padded[pad_y:pad_y + new_h, pad_x:pad_x + new_w] = resized
|
||||
|
||||
# Normalize to [0, 1] float32
|
||||
blob = padded.astype(np.float32) / 255.0
|
||||
|
||||
# HWC → CHW
|
||||
blob = blob.transpose(2, 0, 1)
|
||||
|
||||
# Add batch dimension → (1, 3, 800, 800)
|
||||
blob = np.expand_dims(blob, axis=0)
|
||||
|
||||
return blob, scale, pad_x, pad_y
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Non-Maximum Suppression (NMS)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _compute_iou(box_a: np.ndarray, box_b: np.ndarray) -> float:
|
||||
"""Compute IoU between two boxes [x1, y1, x2, y2]."""
|
||||
ix1 = max(box_a[0], box_b[0])
|
||||
iy1 = max(box_a[1], box_b[1])
|
||||
ix2 = min(box_a[2], box_b[2])
|
||||
iy2 = min(box_a[3], box_b[3])
|
||||
|
||||
inter = max(0.0, ix2 - ix1) * max(0.0, iy2 - iy1)
|
||||
if inter == 0:
|
||||
return 0.0
|
||||
|
||||
area_a = (box_a[2] - box_a[0]) * (box_a[3] - box_a[1])
|
||||
area_b = (box_b[2] - box_b[0]) * (box_b[3] - box_b[1])
|
||||
union = area_a + area_b - inter
|
||||
return inter / union if union > 0 else 0.0
|
||||
|
||||
|
||||
def nms(boxes: np.ndarray, scores: np.ndarray, iou_threshold: float = 0.5) -> List[int]:
|
||||
"""Apply greedy Non-Maximum Suppression.
|
||||
|
||||
Args:
|
||||
boxes: (N, 4) array of [x1, y1, x2, y2].
|
||||
scores: (N,) confidence scores.
|
||||
iou_threshold: Overlap threshold for suppression.
|
||||
|
||||
Returns:
|
||||
List of kept indices.
|
||||
"""
|
||||
if len(boxes) == 0:
|
||||
return []
|
||||
|
||||
order = np.argsort(scores)[::-1].tolist()
|
||||
keep: List[int] = []
|
||||
|
||||
while order:
|
||||
i = order.pop(0)
|
||||
keep.append(i)
|
||||
remaining = []
|
||||
for j in order:
|
||||
if _compute_iou(boxes[i], boxes[j]) < iou_threshold:
|
||||
remaining.append(j)
|
||||
order = remaining
|
||||
|
||||
return keep
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Post-processing
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _postprocess(
|
||||
outputs: list,
|
||||
scale: float,
|
||||
pad_x: int,
|
||||
pad_y: int,
|
||||
orig_w: int,
|
||||
orig_h: int,
|
||||
confidence_threshold: float,
|
||||
max_regions: int,
|
||||
) -> List[LayoutRegion]:
|
||||
"""Parse ONNX output tensors into LayoutRegion list.
|
||||
|
||||
PP-DocLayout ONNX typically outputs one tensor of shape
|
||||
(1, N, 6) or three tensors (boxes, scores, class_ids).
|
||||
We handle both common formats.
|
||||
"""
|
||||
regions: List[LayoutRegion] = []
|
||||
|
||||
# --- Determine output format ---
|
||||
if len(outputs) == 1:
|
||||
# Single tensor: (1, N, 4+1+1) = (batch, detections, [x1,y1,x2,y2,score,class])
|
||||
raw = np.squeeze(outputs[0]) # (N, 6) or (N, 5+num_classes)
|
||||
if raw.ndim == 1:
|
||||
raw = raw.reshape(1, -1)
|
||||
if raw.shape[0] == 0:
|
||||
return []
|
||||
|
||||
if raw.shape[1] == 6:
|
||||
# Format: x1, y1, x2, y2, score, class_id
|
||||
all_boxes = raw[:, :4]
|
||||
all_scores = raw[:, 4]
|
||||
all_classes = raw[:, 5].astype(int)
|
||||
elif raw.shape[1] > 6:
|
||||
# Format: x1, y1, x2, y2, obj_conf, cls0_conf, cls1_conf, ...
|
||||
all_boxes = raw[:, :4]
|
||||
cls_scores = raw[:, 5:]
|
||||
all_classes = np.argmax(cls_scores, axis=1)
|
||||
all_scores = raw[:, 4] * np.max(cls_scores, axis=1)
|
||||
else:
|
||||
logger.warning("PP-DocLayout: unexpected output shape %s", raw.shape)
|
||||
return []
|
||||
|
||||
elif len(outputs) == 3:
|
||||
# Three tensors: boxes (N,4), scores (N,), class_ids (N,)
|
||||
all_boxes = np.squeeze(outputs[0])
|
||||
all_scores = np.squeeze(outputs[1])
|
||||
all_classes = np.squeeze(outputs[2]).astype(int)
|
||||
if all_boxes.ndim == 1:
|
||||
all_boxes = all_boxes.reshape(1, 4)
|
||||
all_scores = np.array([all_scores])
|
||||
all_classes = np.array([all_classes])
|
||||
else:
|
||||
logger.warning("PP-DocLayout: unexpected %d output tensors", len(outputs))
|
||||
return []
|
||||
|
||||
# --- Confidence filter ---
|
||||
mask = all_scores >= confidence_threshold
|
||||
boxes = all_boxes[mask]
|
||||
scores = all_scores[mask]
|
||||
classes = all_classes[mask]
|
||||
|
||||
if len(boxes) == 0:
|
||||
return []
|
||||
|
||||
# --- NMS ---
|
||||
keep_idxs = nms(boxes, scores, iou_threshold=0.5)
|
||||
boxes = boxes[keep_idxs]
|
||||
scores = scores[keep_idxs]
|
||||
classes = classes[keep_idxs]
|
||||
|
||||
# --- Scale boxes back to original image coordinates ---
|
||||
for i in range(len(boxes)):
|
||||
x1, y1, x2, y2 = boxes[i]
|
||||
|
||||
# Remove padding offset
|
||||
x1 = (x1 - pad_x) / scale
|
||||
y1 = (y1 - pad_y) / scale
|
||||
x2 = (x2 - pad_x) / scale
|
||||
y2 = (y2 - pad_y) / scale
|
||||
|
||||
# Clamp to original dimensions
|
||||
x1 = max(0, min(x1, orig_w))
|
||||
y1 = max(0, min(y1, orig_h))
|
||||
x2 = max(0, min(x2, orig_w))
|
||||
y2 = max(0, min(y2, orig_h))
|
||||
|
||||
w = int(round(x2 - x1))
|
||||
h = int(round(y2 - y1))
|
||||
if w < 5 or h < 5:
|
||||
continue
|
||||
|
||||
cls_idx = int(classes[i])
|
||||
label = DOCLAYOUT_CLASSES[cls_idx] if 0 <= cls_idx < len(DOCLAYOUT_CLASSES) else f"class_{cls_idx}"
|
||||
|
||||
regions.append(LayoutRegion(
|
||||
x=int(round(x1)),
|
||||
y=int(round(y1)),
|
||||
width=w,
|
||||
height=h,
|
||||
label=label,
|
||||
confidence=round(float(scores[i]), 4),
|
||||
label_index=cls_idx,
|
||||
))
|
||||
|
||||
# Sort by confidence descending, limit
|
||||
regions.sort(key=lambda r: r.confidence, reverse=True)
|
||||
return regions[:max_regions]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Main detection function
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def detect_layout_regions(
|
||||
img_bgr: np.ndarray,
|
||||
confidence_threshold: float = 0.5,
|
||||
max_regions: int = 50,
|
||||
) -> List[LayoutRegion]:
|
||||
"""Detect document layout regions using PP-DocLayout ONNX model.
|
||||
|
||||
Args:
|
||||
img_bgr: BGR color image (OpenCV format).
|
||||
confidence_threshold: Minimum confidence to keep a detection.
|
||||
max_regions: Maximum number of regions to return.
|
||||
|
||||
Returns:
|
||||
List of LayoutRegion sorted by confidence descending.
|
||||
Returns empty list if model is not available.
|
||||
"""
|
||||
session = _load_onnx_session()
|
||||
if session is None:
|
||||
return []
|
||||
|
||||
if img_bgr is None or img_bgr.size == 0:
|
||||
return []
|
||||
|
||||
orig_h, orig_w = img_bgr.shape[:2]
|
||||
|
||||
# Pre-process
|
||||
input_tensor, scale, pad_x, pad_y = preprocess_image(img_bgr)
|
||||
|
||||
# Run inference
|
||||
try:
|
||||
input_name = session.get_inputs()[0].name
|
||||
outputs = session.run(None, {input_name: input_tensor})
|
||||
except Exception as exc:
|
||||
logger.warning("PP-DocLayout inference failed: %s", exc)
|
||||
return []
|
||||
|
||||
# Post-process
|
||||
regions = _postprocess(
|
||||
outputs,
|
||||
scale=scale,
|
||||
pad_x=pad_x,
|
||||
pad_y=pad_y,
|
||||
orig_w=orig_w,
|
||||
orig_h=orig_h,
|
||||
confidence_threshold=confidence_threshold,
|
||||
max_regions=max_regions,
|
||||
)
|
||||
|
||||
if regions:
|
||||
label_counts: Dict[str, int] = {}
|
||||
for r in regions:
|
||||
label_counts[r.label] = label_counts.get(r.label, 0) + 1
|
||||
logger.info(
|
||||
"PP-DocLayout: %d regions (%s)",
|
||||
len(regions),
|
||||
", ".join(f"{k}: {v}" for k, v in sorted(label_counts.items())),
|
||||
)
|
||||
else:
|
||||
logger.debug("PP-DocLayout: no regions above threshold %.2f", confidence_threshold)
|
||||
|
||||
return regions
|
||||
@@ -120,6 +120,57 @@ def detect_graphic_elements(
|
||||
if img_bgr is None:
|
||||
return []
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Try PP-DocLayout ONNX first if available
|
||||
# ------------------------------------------------------------------
|
||||
import os
|
||||
backend = os.environ.get("GRAPHIC_DETECT_BACKEND", "auto")
|
||||
if backend in ("doclayout", "auto"):
|
||||
try:
|
||||
from cv_doclayout_detect import detect_layout_regions, is_doclayout_available
|
||||
if is_doclayout_available():
|
||||
regions = detect_layout_regions(img_bgr)
|
||||
if regions:
|
||||
_LABEL_TO_COLOR = {
|
||||
"figure": ("image", "green", _COLOR_HEX.get("green", "#16a34a")),
|
||||
"table": ("image", "blue", _COLOR_HEX.get("blue", "#2563eb")),
|
||||
}
|
||||
converted: List[GraphicElement] = []
|
||||
for r in regions:
|
||||
shape, color_name, color_hex = _LABEL_TO_COLOR.get(
|
||||
r.label,
|
||||
(r.label, "gray", _COLOR_HEX.get("gray", "#6b7280")),
|
||||
)
|
||||
converted.append(GraphicElement(
|
||||
x=r.x,
|
||||
y=r.y,
|
||||
width=r.width,
|
||||
height=r.height,
|
||||
area=r.width * r.height,
|
||||
shape=shape,
|
||||
color_name=color_name,
|
||||
color_hex=color_hex,
|
||||
confidence=r.confidence,
|
||||
contour=None,
|
||||
))
|
||||
converted.sort(key=lambda g: g.area, reverse=True)
|
||||
result = converted[:max_elements]
|
||||
if result:
|
||||
shape_counts: Dict[str, int] = {}
|
||||
for g in result:
|
||||
shape_counts[g.shape] = shape_counts.get(g.shape, 0) + 1
|
||||
logger.info(
|
||||
"GraphicDetect (PP-DocLayout): %d elements (%s)",
|
||||
len(result),
|
||||
", ".join(f"{s}: {c}" for s, c in sorted(shape_counts.items())),
|
||||
)
|
||||
return result
|
||||
except Exception as e:
|
||||
logger.warning("PP-DocLayout failed, falling back to OpenCV: %s", e)
|
||||
# ------------------------------------------------------------------
|
||||
# OpenCV fallback (original logic)
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
h, w = img_bgr.shape[:2]
|
||||
|
||||
logger.debug("GraphicDetect: image %dx%d, %d word_boxes, %d detected_boxes",
|
||||
|
||||
@@ -48,6 +48,9 @@ email-validator>=2.0.0
|
||||
# DOCX export for reconstruction editor (MIT license)
|
||||
python-docx>=1.1.0
|
||||
|
||||
# ONNX model export and optimization (Apache-2.0)
|
||||
optimum[onnxruntime]>=1.17.0
|
||||
|
||||
# Testing
|
||||
pytest>=8.0.0
|
||||
pytest-asyncio>=0.23.0
|
||||
|
||||
430
klausur-service/backend/services/trocr_onnx_service.py
Normal file
430
klausur-service/backend/services/trocr_onnx_service.py
Normal file
@@ -0,0 +1,430 @@
|
||||
"""
|
||||
TrOCR ONNX Service
|
||||
|
||||
ONNX-optimized inference for TrOCR text recognition.
|
||||
Uses optimum.onnxruntime.ORTModelForVision2Seq for hardware-accelerated
|
||||
inference without requiring PyTorch at runtime.
|
||||
|
||||
Advantages over PyTorch backend:
|
||||
- 2-4x faster inference on CPU
|
||||
- Lower memory footprint (~300 MB vs ~600 MB)
|
||||
- No PyTorch/CUDA dependency at runtime
|
||||
- Apple Silicon (CoreML) and x86 (OpenVINO) acceleration
|
||||
|
||||
Model paths searched (in order):
|
||||
1. TROCR_ONNX_DIR environment variable
|
||||
2. /root/.cache/huggingface/onnx/trocr-base-{printed,handwritten}/ (Docker)
|
||||
3. models/onnx/trocr-base-{printed,handwritten}/ (local dev)
|
||||
|
||||
DATENSCHUTZ: Alle Verarbeitung erfolgt lokal.
|
||||
"""
|
||||
|
||||
import io
|
||||
import os
|
||||
import logging
|
||||
import time
|
||||
import asyncio
|
||||
from pathlib import Path
|
||||
from typing import Tuple, Optional, List, Dict, Any
|
||||
from datetime import datetime
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Re-use shared types and cache from trocr_service
|
||||
from .trocr_service import (
|
||||
OCRResult,
|
||||
_compute_image_hash,
|
||||
_cache_get,
|
||||
_cache_set,
|
||||
_split_into_lines,
|
||||
)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Module-level state
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
# {model_key: (processor, model)} — model_key = "printed" | "handwritten"
|
||||
_onnx_models: Dict[str, Any] = {}
|
||||
_onnx_available: Optional[bool] = None
|
||||
_onnx_model_loaded_at: Optional[datetime] = None
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Path resolution
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_VARIANT_NAMES = {
|
||||
False: "trocr-base-printed",
|
||||
True: "trocr-base-handwritten",
|
||||
}
|
||||
|
||||
# HuggingFace model IDs (used for processor downloads)
|
||||
_HF_MODEL_IDS = {
|
||||
False: "microsoft/trocr-base-printed",
|
||||
True: "microsoft/trocr-base-handwritten",
|
||||
}
|
||||
|
||||
|
||||
def _resolve_onnx_model_dir(handwritten: bool = False) -> Optional[Path]:
|
||||
"""
|
||||
Resolve the directory containing ONNX model files for the given variant.
|
||||
|
||||
Search order:
|
||||
1. TROCR_ONNX_DIR env var (appended with variant name)
|
||||
2. /root/.cache/huggingface/onnx/<variant>/ (Docker)
|
||||
3. models/onnx/<variant>/ (local dev, relative to this file)
|
||||
|
||||
Returns the first directory that exists and contains at least one .onnx file,
|
||||
or None if no valid directory is found.
|
||||
"""
|
||||
variant = _VARIANT_NAMES[handwritten]
|
||||
candidates: List[Path] = []
|
||||
|
||||
# 1. Environment variable
|
||||
env_dir = os.environ.get("TROCR_ONNX_DIR")
|
||||
if env_dir:
|
||||
candidates.append(Path(env_dir) / variant)
|
||||
# Also allow the env var to point directly at a variant dir
|
||||
candidates.append(Path(env_dir))
|
||||
|
||||
# 2. Docker path
|
||||
candidates.append(Path(f"/root/.cache/huggingface/onnx/{variant}"))
|
||||
|
||||
# 3. Local dev path (relative to klausur-service/backend/)
|
||||
backend_dir = Path(__file__).resolve().parent.parent
|
||||
candidates.append(backend_dir / "models" / "onnx" / variant)
|
||||
|
||||
for candidate in candidates:
|
||||
if candidate.is_dir():
|
||||
# Check for ONNX files or a model config (optimum stores config.json)
|
||||
onnx_files = list(candidate.glob("*.onnx"))
|
||||
has_config = (candidate / "config.json").exists()
|
||||
if onnx_files or has_config:
|
||||
logger.info(f"ONNX model directory resolved: {candidate}")
|
||||
return candidate
|
||||
|
||||
return None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Availability checks
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _check_onnx_runtime_available() -> bool:
|
||||
"""Check if onnxruntime and optimum are importable."""
|
||||
try:
|
||||
import onnxruntime # noqa: F401
|
||||
from optimum.onnxruntime import ORTModelForVision2Seq # noqa: F401
|
||||
from transformers import TrOCRProcessor # noqa: F401
|
||||
return True
|
||||
except ImportError as e:
|
||||
logger.debug(f"ONNX runtime dependencies not available: {e}")
|
||||
return False
|
||||
|
||||
|
||||
def is_onnx_available(handwritten: bool = False) -> bool:
|
||||
"""
|
||||
Check whether ONNX inference is available for the given variant.
|
||||
|
||||
Returns True only when:
|
||||
- onnxruntime + optimum are installed
|
||||
- A valid model directory with ONNX files exists
|
||||
"""
|
||||
if not _check_onnx_runtime_available():
|
||||
return False
|
||||
return _resolve_onnx_model_dir(handwritten=handwritten) is not None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Model loading
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _get_onnx_model(handwritten: bool = False):
|
||||
"""
|
||||
Lazy-load ONNX model and processor.
|
||||
|
||||
Returns:
|
||||
Tuple of (processor, model) or (None, None) if unavailable.
|
||||
"""
|
||||
global _onnx_model_loaded_at
|
||||
|
||||
model_key = "handwritten" if handwritten else "printed"
|
||||
|
||||
if model_key in _onnx_models:
|
||||
return _onnx_models[model_key]
|
||||
|
||||
model_dir = _resolve_onnx_model_dir(handwritten=handwritten)
|
||||
if model_dir is None:
|
||||
logger.warning(
|
||||
f"No ONNX model directory found for variant "
|
||||
f"{'handwritten' if handwritten else 'printed'}"
|
||||
)
|
||||
return None, None
|
||||
|
||||
if not _check_onnx_runtime_available():
|
||||
logger.warning("ONNX runtime dependencies not installed")
|
||||
return None, None
|
||||
|
||||
try:
|
||||
from optimum.onnxruntime import ORTModelForVision2Seq
|
||||
from transformers import TrOCRProcessor
|
||||
|
||||
hf_id = _HF_MODEL_IDS[handwritten]
|
||||
|
||||
logger.info(f"Loading ONNX TrOCR model from {model_dir} (processor: {hf_id})")
|
||||
t0 = time.monotonic()
|
||||
|
||||
# Load processor from HuggingFace (tokenizer + feature extractor)
|
||||
processor = TrOCRProcessor.from_pretrained(hf_id)
|
||||
|
||||
# Load ONNX model from local directory
|
||||
model = ORTModelForVision2Seq.from_pretrained(str(model_dir))
|
||||
|
||||
elapsed = time.monotonic() - t0
|
||||
logger.info(
|
||||
f"ONNX TrOCR model loaded in {elapsed:.1f}s "
|
||||
f"(variant={model_key}, dir={model_dir})"
|
||||
)
|
||||
|
||||
_onnx_models[model_key] = (processor, model)
|
||||
_onnx_model_loaded_at = datetime.now()
|
||||
|
||||
return processor, model
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to load ONNX TrOCR model ({model_key}): {e}")
|
||||
import traceback
|
||||
logger.error(traceback.format_exc())
|
||||
return None, None
|
||||
|
||||
|
||||
def preload_onnx_model(handwritten: bool = True) -> bool:
|
||||
"""
|
||||
Preload ONNX model at startup for faster first request.
|
||||
|
||||
Call from FastAPI startup event:
|
||||
@app.on_event("startup")
|
||||
async def startup():
|
||||
preload_onnx_model()
|
||||
"""
|
||||
logger.info(f"Preloading ONNX TrOCR model (handwritten={handwritten})...")
|
||||
processor, model = _get_onnx_model(handwritten=handwritten)
|
||||
if processor is not None and model is not None:
|
||||
logger.info("ONNX TrOCR model preloaded successfully")
|
||||
return True
|
||||
else:
|
||||
logger.warning("ONNX TrOCR model preloading failed")
|
||||
return False
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Status
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def get_onnx_model_status() -> Dict[str, Any]:
|
||||
"""Get current ONNX model status information."""
|
||||
runtime_ok = _check_onnx_runtime_available()
|
||||
|
||||
printed_dir = _resolve_onnx_model_dir(handwritten=False)
|
||||
handwritten_dir = _resolve_onnx_model_dir(handwritten=True)
|
||||
|
||||
printed_loaded = "printed" in _onnx_models
|
||||
handwritten_loaded = "handwritten" in _onnx_models
|
||||
|
||||
# Detect ONNX runtime providers
|
||||
providers = []
|
||||
if runtime_ok:
|
||||
try:
|
||||
import onnxruntime
|
||||
providers = onnxruntime.get_available_providers()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return {
|
||||
"backend": "onnx",
|
||||
"runtime_available": runtime_ok,
|
||||
"providers": providers,
|
||||
"printed": {
|
||||
"model_dir": str(printed_dir) if printed_dir else None,
|
||||
"available": printed_dir is not None and runtime_ok,
|
||||
"loaded": printed_loaded,
|
||||
},
|
||||
"handwritten": {
|
||||
"model_dir": str(handwritten_dir) if handwritten_dir else None,
|
||||
"available": handwritten_dir is not None and runtime_ok,
|
||||
"loaded": handwritten_loaded,
|
||||
},
|
||||
"loaded_at": _onnx_model_loaded_at.isoformat() if _onnx_model_loaded_at else None,
|
||||
}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Inference
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
async def run_trocr_onnx(
|
||||
image_data: bytes,
|
||||
handwritten: bool = False,
|
||||
split_lines: bool = True,
|
||||
) -> Tuple[Optional[str], float]:
|
||||
"""
|
||||
Run TrOCR OCR using ONNX backend.
|
||||
|
||||
Mirrors the interface of trocr_service.run_trocr_ocr.
|
||||
|
||||
Args:
|
||||
image_data: Raw image bytes (PNG, JPEG, etc.)
|
||||
handwritten: Use handwritten model variant
|
||||
split_lines: Split image into text lines before recognition
|
||||
|
||||
Returns:
|
||||
Tuple of (extracted_text, confidence).
|
||||
Returns (None, 0.0) on failure.
|
||||
"""
|
||||
processor, model = _get_onnx_model(handwritten=handwritten)
|
||||
|
||||
if processor is None or model is None:
|
||||
logger.error("ONNX TrOCR model not available")
|
||||
return None, 0.0
|
||||
|
||||
try:
|
||||
from PIL import Image
|
||||
|
||||
image = Image.open(io.BytesIO(image_data)).convert("RGB")
|
||||
|
||||
if split_lines:
|
||||
lines = _split_into_lines(image)
|
||||
if not lines:
|
||||
lines = [image]
|
||||
else:
|
||||
lines = [image]
|
||||
|
||||
all_text: List[str] = []
|
||||
confidences: List[float] = []
|
||||
|
||||
for line_image in lines:
|
||||
# Prepare input — processor returns PyTorch tensors
|
||||
pixel_values = processor(images=line_image, return_tensors="pt").pixel_values
|
||||
|
||||
# Generate via ONNX (ORTModelForVision2Seq.generate is compatible)
|
||||
generated_ids = model.generate(pixel_values, max_length=128)
|
||||
|
||||
generated_text = processor.batch_decode(
|
||||
generated_ids, skip_special_tokens=True
|
||||
)[0]
|
||||
|
||||
if generated_text.strip():
|
||||
all_text.append(generated_text.strip())
|
||||
confidences.append(0.85 if len(generated_text) > 3 else 0.5)
|
||||
|
||||
text = "\n".join(all_text)
|
||||
confidence = sum(confidences) / len(confidences) if confidences else 0.0
|
||||
|
||||
logger.info(
|
||||
f"ONNX TrOCR extracted {len(text)} chars from {len(lines)} lines"
|
||||
)
|
||||
return text, confidence
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"ONNX TrOCR failed: {e}")
|
||||
import traceback
|
||||
logger.error(traceback.format_exc())
|
||||
return None, 0.0
|
||||
|
||||
|
||||
async def run_trocr_onnx_enhanced(
|
||||
image_data: bytes,
|
||||
handwritten: bool = True,
|
||||
split_lines: bool = True,
|
||||
use_cache: bool = True,
|
||||
) -> OCRResult:
|
||||
"""
|
||||
Enhanced ONNX TrOCR with caching and detailed results.
|
||||
|
||||
Mirrors the interface of trocr_service.run_trocr_ocr_enhanced.
|
||||
|
||||
Args:
|
||||
image_data: Raw image bytes
|
||||
handwritten: Use handwritten model variant
|
||||
split_lines: Split image into text lines
|
||||
use_cache: Use SHA256-based in-memory cache
|
||||
|
||||
Returns:
|
||||
OCRResult with text, confidence, timing, word boxes, etc.
|
||||
"""
|
||||
start_time = time.time()
|
||||
|
||||
# Check cache first
|
||||
image_hash = _compute_image_hash(image_data)
|
||||
if use_cache:
|
||||
cached = _cache_get(image_hash)
|
||||
if cached:
|
||||
return OCRResult(
|
||||
text=cached["text"],
|
||||
confidence=cached["confidence"],
|
||||
processing_time_ms=0,
|
||||
model=cached["model"],
|
||||
has_lora_adapter=cached.get("has_lora_adapter", False),
|
||||
char_confidences=cached.get("char_confidences", []),
|
||||
word_boxes=cached.get("word_boxes", []),
|
||||
from_cache=True,
|
||||
image_hash=image_hash,
|
||||
)
|
||||
|
||||
# Run ONNX inference
|
||||
text, confidence = await run_trocr_onnx(
|
||||
image_data, handwritten=handwritten, split_lines=split_lines
|
||||
)
|
||||
|
||||
processing_time_ms = int((time.time() - start_time) * 1000)
|
||||
|
||||
# Generate word boxes with simulated confidences
|
||||
word_boxes: List[Dict[str, Any]] = []
|
||||
if text:
|
||||
words = text.split()
|
||||
for word in words:
|
||||
word_conf = min(
|
||||
1.0, max(0.0, confidence + (hash(word) % 20 - 10) / 100)
|
||||
)
|
||||
word_boxes.append({
|
||||
"text": word,
|
||||
"confidence": word_conf,
|
||||
"bbox": [0, 0, 0, 0],
|
||||
})
|
||||
|
||||
# Generate character confidences
|
||||
char_confidences: List[float] = []
|
||||
if text:
|
||||
for char in text:
|
||||
char_conf = min(
|
||||
1.0, max(0.0, confidence + (hash(char) % 15 - 7) / 100)
|
||||
)
|
||||
char_confidences.append(char_conf)
|
||||
|
||||
model_name = (
|
||||
"trocr-base-handwritten-onnx" if handwritten else "trocr-base-printed-onnx"
|
||||
)
|
||||
|
||||
result = OCRResult(
|
||||
text=text or "",
|
||||
confidence=confidence,
|
||||
processing_time_ms=processing_time_ms,
|
||||
model=model_name,
|
||||
has_lora_adapter=False,
|
||||
char_confidences=char_confidences,
|
||||
word_boxes=word_boxes,
|
||||
from_cache=False,
|
||||
image_hash=image_hash,
|
||||
)
|
||||
|
||||
# Cache result
|
||||
if use_cache and text:
|
||||
_cache_set(image_hash, {
|
||||
"text": result.text,
|
||||
"confidence": result.confidence,
|
||||
"model": result.model,
|
||||
"has_lora_adapter": result.has_lora_adapter,
|
||||
"char_confidences": result.char_confidences,
|
||||
"word_boxes": result.word_boxes,
|
||||
})
|
||||
|
||||
return result
|
||||
@@ -19,6 +19,7 @@ Phase 2 Enhancements:
|
||||
"""
|
||||
|
||||
import io
|
||||
import os
|
||||
import hashlib
|
||||
import logging
|
||||
import time
|
||||
@@ -30,6 +31,11 @@ from datetime import datetime, timedelta
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Backend routing: auto | pytorch | onnx
|
||||
# ---------------------------------------------------------------------------
|
||||
_trocr_backend = os.environ.get("TROCR_BACKEND", "auto") # auto | pytorch | onnx
|
||||
|
||||
# Lazy loading for heavy dependencies
|
||||
# Cache keyed by model_name to support base and large variants simultaneously
|
||||
_trocr_models: dict = {} # {model_name: (processor, model)}
|
||||
@@ -221,6 +227,97 @@ def get_model_status() -> Dict[str, Any]:
|
||||
return status
|
||||
|
||||
|
||||
def get_active_backend() -> str:
|
||||
"""
|
||||
Return which TrOCR backend is configured.
|
||||
|
||||
Possible values: "auto", "pytorch", "onnx".
|
||||
"""
|
||||
return _trocr_backend
|
||||
|
||||
|
||||
def _try_onnx_ocr(
|
||||
image_data: bytes,
|
||||
handwritten: bool = False,
|
||||
split_lines: bool = True,
|
||||
) -> Optional[Tuple[Optional[str], float]]:
|
||||
"""
|
||||
Attempt ONNX inference. Returns the (text, confidence) tuple on
|
||||
success, or None if ONNX is not available / fails to load.
|
||||
"""
|
||||
try:
|
||||
from .trocr_onnx_service import is_onnx_available, run_trocr_onnx
|
||||
|
||||
if not is_onnx_available(handwritten=handwritten):
|
||||
return None
|
||||
# run_trocr_onnx is async — return the coroutine's awaitable result
|
||||
# The caller (run_trocr_ocr) will await it.
|
||||
return run_trocr_onnx # sentinel: caller checks callable
|
||||
except ImportError:
|
||||
return None
|
||||
|
||||
|
||||
async def _run_pytorch_ocr(
|
||||
image_data: bytes,
|
||||
handwritten: bool = False,
|
||||
split_lines: bool = True,
|
||||
size: str = "base",
|
||||
) -> Tuple[Optional[str], float]:
|
||||
"""
|
||||
Original PyTorch inference path (extracted for routing).
|
||||
"""
|
||||
processor, model = get_trocr_model(handwritten=handwritten, size=size)
|
||||
|
||||
if processor is None or model is None:
|
||||
logger.error("TrOCR PyTorch model not available")
|
||||
return None, 0.0
|
||||
|
||||
try:
|
||||
import torch
|
||||
from PIL import Image
|
||||
import numpy as np
|
||||
|
||||
# Load image
|
||||
image = Image.open(io.BytesIO(image_data)).convert("RGB")
|
||||
|
||||
if split_lines:
|
||||
lines = _split_into_lines(image)
|
||||
if not lines:
|
||||
lines = [image]
|
||||
else:
|
||||
lines = [image]
|
||||
|
||||
all_text = []
|
||||
confidences = []
|
||||
|
||||
for line_image in lines:
|
||||
pixel_values = processor(images=line_image, return_tensors="pt").pixel_values
|
||||
|
||||
device = next(model.parameters()).device
|
||||
pixel_values = pixel_values.to(device)
|
||||
|
||||
with torch.no_grad():
|
||||
generated_ids = model.generate(pixel_values, max_length=128)
|
||||
|
||||
generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
|
||||
|
||||
if generated_text.strip():
|
||||
all_text.append(generated_text.strip())
|
||||
confidences.append(0.85 if len(generated_text) > 3 else 0.5)
|
||||
|
||||
text = "\n".join(all_text)
|
||||
confidence = sum(confidences) / len(confidences) if confidences else 0.0
|
||||
|
||||
logger.info(f"TrOCR (PyTorch) extracted {len(text)} characters from {len(lines)} lines")
|
||||
return text, confidence
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"TrOCR PyTorch failed: {e}")
|
||||
import traceback
|
||||
logger.error(traceback.format_exc())
|
||||
return None, 0.0
|
||||
|
||||
|
||||
async def run_trocr_ocr(
|
||||
image_data: bytes,
|
||||
handwritten: bool = False,
|
||||
@@ -230,6 +327,13 @@ async def run_trocr_ocr(
|
||||
"""
|
||||
Run TrOCR on an image.
|
||||
|
||||
Routes between ONNX and PyTorch backends based on the TROCR_BACKEND
|
||||
environment variable (default: "auto").
|
||||
|
||||
- "onnx" — always use ONNX (raises RuntimeError if unavailable)
|
||||
- "pytorch" — always use PyTorch (original behaviour)
|
||||
- "auto" — try ONNX first, fall back to PyTorch
|
||||
|
||||
TrOCR is optimized for single-line text recognition, so for full-page
|
||||
images we need to either:
|
||||
1. Split into lines first (using line detection)
|
||||
@@ -244,65 +348,38 @@ async def run_trocr_ocr(
|
||||
Returns:
|
||||
Tuple of (extracted_text, confidence)
|
||||
"""
|
||||
processor, model = get_trocr_model(handwritten=handwritten, size=size)
|
||||
backend = _trocr_backend
|
||||
|
||||
if processor is None or model is None:
|
||||
logger.error("TrOCR model not available")
|
||||
return None, 0.0
|
||||
# --- ONNX-only mode ---
|
||||
if backend == "onnx":
|
||||
onnx_fn = _try_onnx_ocr(image_data, handwritten=handwritten, split_lines=split_lines)
|
||||
if onnx_fn is None or not callable(onnx_fn):
|
||||
raise RuntimeError(
|
||||
"ONNX backend requested (TROCR_BACKEND=onnx) but unavailable. "
|
||||
"Ensure onnxruntime + optimum are installed and ONNX model files exist."
|
||||
)
|
||||
return await onnx_fn(image_data, handwritten=handwritten, split_lines=split_lines)
|
||||
|
||||
try:
|
||||
import torch
|
||||
from PIL import Image
|
||||
import numpy as np
|
||||
# --- PyTorch-only mode ---
|
||||
if backend == "pytorch":
|
||||
return await _run_pytorch_ocr(
|
||||
image_data, handwritten=handwritten, split_lines=split_lines, size=size,
|
||||
)
|
||||
|
||||
# Load image
|
||||
image = Image.open(io.BytesIO(image_data)).convert("RGB")
|
||||
# --- Auto mode: try ONNX first, then PyTorch ---
|
||||
onnx_fn = _try_onnx_ocr(image_data, handwritten=handwritten, split_lines=split_lines)
|
||||
if onnx_fn is not None and callable(onnx_fn):
|
||||
try:
|
||||
result = await onnx_fn(image_data, handwritten=handwritten, split_lines=split_lines)
|
||||
if result[0] is not None:
|
||||
return result
|
||||
logger.warning("ONNX returned None text, falling back to PyTorch")
|
||||
except Exception as e:
|
||||
logger.warning(f"ONNX inference failed ({e}), falling back to PyTorch")
|
||||
|
||||
if split_lines:
|
||||
# Split image into lines and process each
|
||||
lines = _split_into_lines(image)
|
||||
if not lines:
|
||||
lines = [image] # Fallback to full image
|
||||
else:
|
||||
lines = [image]
|
||||
|
||||
all_text = []
|
||||
confidences = []
|
||||
|
||||
for line_image in lines:
|
||||
# Prepare input
|
||||
pixel_values = processor(images=line_image, return_tensors="pt").pixel_values
|
||||
|
||||
# Move to same device as model
|
||||
device = next(model.parameters()).device
|
||||
pixel_values = pixel_values.to(device)
|
||||
|
||||
# Generate
|
||||
with torch.no_grad():
|
||||
generated_ids = model.generate(pixel_values, max_length=128)
|
||||
|
||||
# Decode
|
||||
generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
|
||||
|
||||
if generated_text.strip():
|
||||
all_text.append(generated_text.strip())
|
||||
# TrOCR doesn't provide confidence, estimate based on output
|
||||
confidences.append(0.85 if len(generated_text) > 3 else 0.5)
|
||||
|
||||
# Combine results
|
||||
text = "\n".join(all_text)
|
||||
|
||||
# Average confidence
|
||||
confidence = sum(confidences) / len(confidences) if confidences else 0.0
|
||||
|
||||
logger.info(f"TrOCR extracted {len(text)} characters from {len(lines)} lines")
|
||||
return text, confidence
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"TrOCR failed: {e}")
|
||||
import traceback
|
||||
logger.error(traceback.format_exc())
|
||||
return None, 0.0
|
||||
return await _run_pytorch_ocr(
|
||||
image_data, handwritten=handwritten, split_lines=split_lines, size=size,
|
||||
)
|
||||
|
||||
|
||||
def _split_into_lines(image) -> list:
|
||||
@@ -360,6 +437,22 @@ def _split_into_lines(image) -> list:
|
||||
return []
|
||||
|
||||
|
||||
def _try_onnx_enhanced(
|
||||
handwritten: bool = True,
|
||||
):
|
||||
"""
|
||||
Return the ONNX enhanced coroutine function, or None if unavailable.
|
||||
"""
|
||||
try:
|
||||
from .trocr_onnx_service import is_onnx_available, run_trocr_onnx_enhanced
|
||||
|
||||
if not is_onnx_available(handwritten=handwritten):
|
||||
return None
|
||||
return run_trocr_onnx_enhanced
|
||||
except ImportError:
|
||||
return None
|
||||
|
||||
|
||||
async def run_trocr_ocr_enhanced(
|
||||
image_data: bytes,
|
||||
handwritten: bool = True,
|
||||
@@ -369,6 +462,9 @@ async def run_trocr_ocr_enhanced(
|
||||
"""
|
||||
Enhanced TrOCR OCR with caching and detailed results.
|
||||
|
||||
Routes between ONNX and PyTorch backends based on the TROCR_BACKEND
|
||||
environment variable (default: "auto").
|
||||
|
||||
Args:
|
||||
image_data: Raw image bytes
|
||||
handwritten: Use handwritten model
|
||||
@@ -378,6 +474,37 @@ async def run_trocr_ocr_enhanced(
|
||||
Returns:
|
||||
OCRResult with detailed information
|
||||
"""
|
||||
backend = _trocr_backend
|
||||
|
||||
# --- ONNX-only mode ---
|
||||
if backend == "onnx":
|
||||
onnx_fn = _try_onnx_enhanced(handwritten=handwritten)
|
||||
if onnx_fn is None:
|
||||
raise RuntimeError(
|
||||
"ONNX backend requested (TROCR_BACKEND=onnx) but unavailable. "
|
||||
"Ensure onnxruntime + optimum are installed and ONNX model files exist."
|
||||
)
|
||||
return await onnx_fn(
|
||||
image_data, handwritten=handwritten,
|
||||
split_lines=split_lines, use_cache=use_cache,
|
||||
)
|
||||
|
||||
# --- Auto mode: try ONNX first ---
|
||||
if backend == "auto":
|
||||
onnx_fn = _try_onnx_enhanced(handwritten=handwritten)
|
||||
if onnx_fn is not None:
|
||||
try:
|
||||
result = await onnx_fn(
|
||||
image_data, handwritten=handwritten,
|
||||
split_lines=split_lines, use_cache=use_cache,
|
||||
)
|
||||
if result.text:
|
||||
return result
|
||||
logger.warning("ONNX enhanced returned empty text, falling back to PyTorch")
|
||||
except Exception as e:
|
||||
logger.warning(f"ONNX enhanced failed ({e}), falling back to PyTorch")
|
||||
|
||||
# --- PyTorch path (backend == "pytorch" or auto fallback) ---
|
||||
start_time = time.time()
|
||||
|
||||
# Check cache first
|
||||
@@ -397,8 +524,8 @@ async def run_trocr_ocr_enhanced(
|
||||
image_hash=image_hash
|
||||
)
|
||||
|
||||
# Run OCR
|
||||
text, confidence = await run_trocr_ocr(image_data, handwritten=handwritten, split_lines=split_lines)
|
||||
# Run OCR via PyTorch
|
||||
text, confidence = await _run_pytorch_ocr(image_data, handwritten=handwritten, split_lines=split_lines)
|
||||
|
||||
processing_time_ms = int((time.time() - start_time) * 1000)
|
||||
|
||||
|
||||
394
klausur-service/backend/tests/test_doclayout_detect.py
Normal file
394
klausur-service/backend/tests/test_doclayout_detect.py
Normal file
@@ -0,0 +1,394 @@
|
||||
"""
|
||||
Tests for PP-DocLayout ONNX Document Layout Detection.
|
||||
|
||||
Uses mocking to avoid requiring the actual ONNX model file.
|
||||
"""
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
from unittest.mock import patch, MagicMock
|
||||
|
||||
# We patch the module-level globals before importing to ensure clean state
|
||||
# in tests that check "no model" behaviour.
|
||||
|
||||
import importlib
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _fresh_import():
|
||||
"""Re-import cv_doclayout_detect with reset globals."""
|
||||
import cv_doclayout_detect as mod
|
||||
# Reset module-level caching so each test starts clean
|
||||
mod._onnx_session = None
|
||||
mod._model_path = None
|
||||
mod._load_attempted = False
|
||||
mod._load_error = None
|
||||
return mod
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 1. is_doclayout_available — no model present
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestIsDoclayoutAvailableNoModel:
|
||||
def test_returns_false_when_no_onnx_file(self):
|
||||
mod = _fresh_import()
|
||||
with patch.object(mod, "_find_model_path", return_value=None):
|
||||
assert mod.is_doclayout_available() is False
|
||||
|
||||
def test_returns_false_when_onnxruntime_missing(self):
|
||||
mod = _fresh_import()
|
||||
with patch.object(mod, "_find_model_path", return_value="/fake/model.onnx"):
|
||||
with patch.dict("sys.modules", {"onnxruntime": None}):
|
||||
# Force ImportError by making import fail
|
||||
import builtins
|
||||
real_import = builtins.__import__
|
||||
|
||||
def fake_import(name, *args, **kwargs):
|
||||
if name == "onnxruntime":
|
||||
raise ImportError("no onnxruntime")
|
||||
return real_import(name, *args, **kwargs)
|
||||
|
||||
with patch("builtins.__import__", side_effect=fake_import):
|
||||
assert mod.is_doclayout_available() is False
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 2. LayoutRegion dataclass
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestLayoutRegionDataclass:
|
||||
def test_basic_creation(self):
|
||||
from cv_doclayout_detect import LayoutRegion
|
||||
region = LayoutRegion(
|
||||
x=10, y=20, width=100, height=200,
|
||||
label="figure", confidence=0.95, label_index=1,
|
||||
)
|
||||
assert region.x == 10
|
||||
assert region.y == 20
|
||||
assert region.width == 100
|
||||
assert region.height == 200
|
||||
assert region.label == "figure"
|
||||
assert region.confidence == 0.95
|
||||
assert region.label_index == 1
|
||||
|
||||
def test_all_fields_present(self):
|
||||
from cv_doclayout_detect import LayoutRegion
|
||||
import dataclasses
|
||||
field_names = {f.name for f in dataclasses.fields(LayoutRegion)}
|
||||
expected = {"x", "y", "width", "height", "label", "confidence", "label_index"}
|
||||
assert field_names == expected
|
||||
|
||||
def test_different_labels(self):
|
||||
from cv_doclayout_detect import LayoutRegion, DOCLAYOUT_CLASSES
|
||||
for idx, label in enumerate(DOCLAYOUT_CLASSES):
|
||||
region = LayoutRegion(
|
||||
x=0, y=0, width=50, height=50,
|
||||
label=label, confidence=0.8, label_index=idx,
|
||||
)
|
||||
assert region.label == label
|
||||
assert region.label_index == idx
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 3. detect_layout_regions — no model available
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestDetectLayoutRegionsNoModel:
|
||||
def test_returns_empty_list_when_model_unavailable(self):
|
||||
mod = _fresh_import()
|
||||
with patch.object(mod, "_find_model_path", return_value=None):
|
||||
img = np.zeros((480, 640, 3), dtype=np.uint8)
|
||||
result = mod.detect_layout_regions(img)
|
||||
assert result == []
|
||||
|
||||
def test_returns_empty_list_for_none_image(self):
|
||||
mod = _fresh_import()
|
||||
with patch.object(mod, "_find_model_path", return_value=None):
|
||||
result = mod.detect_layout_regions(None)
|
||||
assert result == []
|
||||
|
||||
def test_returns_empty_list_for_empty_image(self):
|
||||
mod = _fresh_import()
|
||||
with patch.object(mod, "_find_model_path", return_value=None):
|
||||
img = np.array([], dtype=np.uint8)
|
||||
result = mod.detect_layout_regions(img)
|
||||
assert result == []
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 4. Preprocessing — tensor shape verification
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestPreprocessingShapes:
|
||||
def test_square_image(self):
|
||||
from cv_doclayout_detect import preprocess_image
|
||||
img = np.random.randint(0, 255, (800, 800, 3), dtype=np.uint8)
|
||||
tensor, scale, pad_x, pad_y = preprocess_image(img)
|
||||
assert tensor.shape == (1, 3, 800, 800)
|
||||
assert tensor.dtype == np.float32
|
||||
assert 0.0 <= tensor.min()
|
||||
assert tensor.max() <= 1.0
|
||||
|
||||
def test_landscape_image(self):
|
||||
from cv_doclayout_detect import preprocess_image
|
||||
img = np.random.randint(0, 255, (600, 1200, 3), dtype=np.uint8)
|
||||
tensor, scale, pad_x, pad_y = preprocess_image(img)
|
||||
assert tensor.shape == (1, 3, 800, 800)
|
||||
# Landscape: scale by width, should have vertical padding
|
||||
expected_scale = 800 / 1200
|
||||
assert abs(scale - expected_scale) < 1e-5
|
||||
assert pad_y > 0 # vertical padding expected
|
||||
|
||||
def test_portrait_image(self):
|
||||
from cv_doclayout_detect import preprocess_image
|
||||
img = np.random.randint(0, 255, (1200, 600, 3), dtype=np.uint8)
|
||||
tensor, scale, pad_x, pad_y = preprocess_image(img)
|
||||
assert tensor.shape == (1, 3, 800, 800)
|
||||
# Portrait: scale by height, should have horizontal padding
|
||||
expected_scale = 800 / 1200
|
||||
assert abs(scale - expected_scale) < 1e-5
|
||||
assert pad_x > 0 # horizontal padding expected
|
||||
|
||||
def test_small_image(self):
|
||||
from cv_doclayout_detect import preprocess_image
|
||||
img = np.random.randint(0, 255, (100, 200, 3), dtype=np.uint8)
|
||||
tensor, scale, pad_x, pad_y = preprocess_image(img)
|
||||
assert tensor.shape == (1, 3, 800, 800)
|
||||
|
||||
def test_typical_scan_a4(self):
|
||||
"""A4 scan at 300dpi: roughly 2480x3508 pixels."""
|
||||
from cv_doclayout_detect import preprocess_image
|
||||
img = np.random.randint(0, 255, (3508, 2480, 3), dtype=np.uint8)
|
||||
tensor, scale, pad_x, pad_y = preprocess_image(img)
|
||||
assert tensor.shape == (1, 3, 800, 800)
|
||||
|
||||
def test_values_normalized(self):
|
||||
from cv_doclayout_detect import preprocess_image
|
||||
# All white image
|
||||
img = np.full((400, 400, 3), 255, dtype=np.uint8)
|
||||
tensor, _, _, _ = preprocess_image(img)
|
||||
# The padded region is 114/255 ≈ 0.447, the image region is 1.0
|
||||
assert tensor.max() <= 1.0
|
||||
assert tensor.min() >= 0.0
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 5. NMS logic
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestNmsLogic:
|
||||
def test_empty_input(self):
|
||||
from cv_doclayout_detect import nms
|
||||
boxes = np.array([]).reshape(0, 4)
|
||||
scores = np.array([])
|
||||
assert nms(boxes, scores) == []
|
||||
|
||||
def test_single_box(self):
|
||||
from cv_doclayout_detect import nms
|
||||
boxes = np.array([[10, 10, 100, 100]], dtype=np.float32)
|
||||
scores = np.array([0.9])
|
||||
kept = nms(boxes, scores, iou_threshold=0.5)
|
||||
assert kept == [0]
|
||||
|
||||
def test_non_overlapping_boxes(self):
|
||||
from cv_doclayout_detect import nms
|
||||
boxes = np.array([
|
||||
[0, 0, 50, 50],
|
||||
[200, 200, 300, 300],
|
||||
[400, 400, 500, 500],
|
||||
], dtype=np.float32)
|
||||
scores = np.array([0.9, 0.8, 0.7])
|
||||
kept = nms(boxes, scores, iou_threshold=0.5)
|
||||
assert len(kept) == 3
|
||||
assert set(kept) == {0, 1, 2}
|
||||
|
||||
def test_overlapping_boxes_suppressed(self):
|
||||
from cv_doclayout_detect import nms
|
||||
# Two boxes that heavily overlap
|
||||
boxes = np.array([
|
||||
[10, 10, 110, 110], # 100x100
|
||||
[15, 15, 115, 115], # 100x100, heavily overlapping with first
|
||||
], dtype=np.float32)
|
||||
scores = np.array([0.95, 0.80])
|
||||
kept = nms(boxes, scores, iou_threshold=0.5)
|
||||
# Only the higher-confidence box should survive
|
||||
assert kept == [0]
|
||||
|
||||
def test_partially_overlapping_boxes_kept(self):
|
||||
from cv_doclayout_detect import nms
|
||||
# Two boxes that overlap ~25% (below 0.5 threshold)
|
||||
boxes = np.array([
|
||||
[0, 0, 100, 100], # 100x100
|
||||
[75, 0, 175, 100], # 100x100, overlap 25x100 = 2500
|
||||
], dtype=np.float32)
|
||||
scores = np.array([0.9, 0.8])
|
||||
# IoU = 2500 / (10000 + 10000 - 2500) = 2500/17500 ≈ 0.143
|
||||
kept = nms(boxes, scores, iou_threshold=0.5)
|
||||
assert len(kept) == 2
|
||||
|
||||
def test_nms_respects_score_ordering(self):
|
||||
from cv_doclayout_detect import nms
|
||||
# Three overlapping boxes — highest confidence should be kept first
|
||||
boxes = np.array([
|
||||
[10, 10, 110, 110],
|
||||
[12, 12, 112, 112],
|
||||
[14, 14, 114, 114],
|
||||
], dtype=np.float32)
|
||||
scores = np.array([0.5, 0.9, 0.7])
|
||||
kept = nms(boxes, scores, iou_threshold=0.5)
|
||||
# Index 1 has highest score → kept first, suppresses 0 and 2
|
||||
assert kept[0] == 1
|
||||
|
||||
def test_iou_computation(self):
|
||||
from cv_doclayout_detect import _compute_iou
|
||||
box_a = np.array([0, 0, 100, 100], dtype=np.float32)
|
||||
box_b = np.array([0, 0, 100, 100], dtype=np.float32)
|
||||
assert abs(_compute_iou(box_a, box_b) - 1.0) < 1e-5
|
||||
|
||||
box_c = np.array([200, 200, 300, 300], dtype=np.float32)
|
||||
assert _compute_iou(box_a, box_c) == 0.0
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 6. DOCLAYOUT_CLASSES verification
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestDoclayoutClasses:
|
||||
def test_correct_class_list(self):
|
||||
from cv_doclayout_detect import DOCLAYOUT_CLASSES
|
||||
expected = [
|
||||
"table", "figure", "title", "text", "list",
|
||||
"header", "footer", "equation", "reference", "abstract",
|
||||
]
|
||||
assert DOCLAYOUT_CLASSES == expected
|
||||
|
||||
def test_class_count(self):
|
||||
from cv_doclayout_detect import DOCLAYOUT_CLASSES
|
||||
assert len(DOCLAYOUT_CLASSES) == 10
|
||||
|
||||
def test_no_duplicates(self):
|
||||
from cv_doclayout_detect import DOCLAYOUT_CLASSES
|
||||
assert len(DOCLAYOUT_CLASSES) == len(set(DOCLAYOUT_CLASSES))
|
||||
|
||||
def test_all_lowercase(self):
|
||||
from cv_doclayout_detect import DOCLAYOUT_CLASSES
|
||||
for cls in DOCLAYOUT_CLASSES:
|
||||
assert cls == cls.lower(), f"Class '{cls}' should be lowercase"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 7. get_doclayout_status
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestGetDoclayoutStatus:
|
||||
def test_status_when_unavailable(self):
|
||||
mod = _fresh_import()
|
||||
with patch.object(mod, "_find_model_path", return_value=None):
|
||||
status = mod.get_doclayout_status()
|
||||
assert status["available"] is False
|
||||
assert status["model_path"] is None
|
||||
assert status["load_error"] is not None
|
||||
assert status["classes"] == mod.DOCLAYOUT_CLASSES
|
||||
assert status["class_count"] == 10
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 8. Post-processing with mocked ONNX outputs
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestPostprocessing:
|
||||
def test_single_tensor_format_6cols(self):
|
||||
"""Test parsing of (1, N, 6) output format: x1,y1,x2,y2,score,class."""
|
||||
from cv_doclayout_detect import _postprocess
|
||||
|
||||
# One detection: figure at (100,100)-(300,300) in 800x800 space
|
||||
raw = np.array([[[100, 100, 300, 300, 0.92, 1]]], dtype=np.float32)
|
||||
regions = _postprocess(
|
||||
outputs=[raw],
|
||||
scale=1.0, pad_x=0, pad_y=0,
|
||||
orig_w=800, orig_h=800,
|
||||
confidence_threshold=0.5,
|
||||
max_regions=50,
|
||||
)
|
||||
assert len(regions) == 1
|
||||
assert regions[0].label == "figure"
|
||||
assert regions[0].confidence >= 0.9
|
||||
|
||||
def test_three_tensor_format(self):
|
||||
"""Test parsing of 3-tensor output: boxes, scores, class_ids."""
|
||||
from cv_doclayout_detect import _postprocess
|
||||
|
||||
boxes = np.array([[50, 50, 200, 150]], dtype=np.float32)
|
||||
scores = np.array([0.88], dtype=np.float32)
|
||||
class_ids = np.array([0], dtype=np.float32) # table
|
||||
|
||||
regions = _postprocess(
|
||||
outputs=[boxes, scores, class_ids],
|
||||
scale=1.0, pad_x=0, pad_y=0,
|
||||
orig_w=800, orig_h=800,
|
||||
confidence_threshold=0.5,
|
||||
max_regions=50,
|
||||
)
|
||||
assert len(regions) == 1
|
||||
assert regions[0].label == "table"
|
||||
|
||||
def test_confidence_filtering(self):
|
||||
"""Detections below threshold should be excluded."""
|
||||
from cv_doclayout_detect import _postprocess
|
||||
|
||||
raw = np.array([
|
||||
[100, 100, 200, 200, 0.9, 1], # above threshold
|
||||
[300, 300, 400, 400, 0.3, 2], # below threshold
|
||||
], dtype=np.float32).reshape(1, 2, 6)
|
||||
|
||||
regions = _postprocess(
|
||||
outputs=[raw],
|
||||
scale=1.0, pad_x=0, pad_y=0,
|
||||
orig_w=800, orig_h=800,
|
||||
confidence_threshold=0.5,
|
||||
max_regions=50,
|
||||
)
|
||||
assert len(regions) == 1
|
||||
assert regions[0].label == "figure"
|
||||
|
||||
def test_coordinate_scaling(self):
|
||||
"""Verify coordinates are correctly scaled back to original image."""
|
||||
from cv_doclayout_detect import _postprocess
|
||||
|
||||
# Image was 1600x1200, scaled to fit 800x800 → scale=0.5, pad_y offset
|
||||
scale = 800 / 1600 # 0.5
|
||||
pad_x = 0
|
||||
pad_y = (800 - int(1200 * scale)) // 2 # (800-600)//2 = 100
|
||||
|
||||
# Detection in 800x800 space at (100, 200) to (300, 400)
|
||||
raw = np.array([[[100, 200, 300, 400, 0.95, 0]]], dtype=np.float32)
|
||||
|
||||
regions = _postprocess(
|
||||
outputs=[raw],
|
||||
scale=scale, pad_x=pad_x, pad_y=pad_y,
|
||||
orig_w=1600, orig_h=1200,
|
||||
confidence_threshold=0.5,
|
||||
max_regions=50,
|
||||
)
|
||||
assert len(regions) == 1
|
||||
r = regions[0]
|
||||
# x1 = (100 - 0) / 0.5 = 200
|
||||
assert r.x == 200
|
||||
# y1 = (200 - 100) / 0.5 = 200
|
||||
assert r.y == 200
|
||||
|
||||
def test_empty_output(self):
|
||||
from cv_doclayout_detect import _postprocess
|
||||
raw = np.array([]).reshape(1, 0, 6).astype(np.float32)
|
||||
regions = _postprocess(
|
||||
outputs=[raw],
|
||||
scale=1.0, pad_x=0, pad_y=0,
|
||||
orig_w=800, orig_h=800,
|
||||
confidence_threshold=0.5,
|
||||
max_regions=50,
|
||||
)
|
||||
assert regions == []
|
||||
339
klausur-service/backend/tests/test_trocr_onnx.py
Normal file
339
klausur-service/backend/tests/test_trocr_onnx.py
Normal file
@@ -0,0 +1,339 @@
|
||||
"""
|
||||
Tests for TrOCR ONNX service.
|
||||
|
||||
All tests use mocking — no actual ONNX model files required.
|
||||
"""
|
||||
|
||||
import os
|
||||
import pytest
|
||||
from pathlib import Path
|
||||
from unittest.mock import patch, MagicMock, PropertyMock
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _services_path():
|
||||
"""Return absolute path to the services/ directory."""
|
||||
return Path(__file__).resolve().parent.parent / "services"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Test: is_onnx_available — no models on disk
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestIsOnnxAvailableNoModels:
|
||||
"""When no ONNX files exist on disk, is_onnx_available must return False."""
|
||||
|
||||
@patch(
|
||||
"services.trocr_onnx_service._check_onnx_runtime_available",
|
||||
return_value=True,
|
||||
)
|
||||
@patch(
|
||||
"services.trocr_onnx_service._resolve_onnx_model_dir",
|
||||
return_value=None,
|
||||
)
|
||||
def test_is_onnx_available_no_models(self, mock_resolve, mock_runtime):
|
||||
from services.trocr_onnx_service import is_onnx_available
|
||||
|
||||
assert is_onnx_available(handwritten=False) is False
|
||||
assert is_onnx_available(handwritten=True) is False
|
||||
|
||||
@patch(
|
||||
"services.trocr_onnx_service._check_onnx_runtime_available",
|
||||
return_value=False,
|
||||
)
|
||||
def test_is_onnx_available_no_runtime(self, mock_runtime):
|
||||
"""Even if model dirs existed, missing runtime → False."""
|
||||
from services.trocr_onnx_service import is_onnx_available
|
||||
|
||||
assert is_onnx_available(handwritten=False) is False
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Test: get_onnx_model_status — not available
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestOnnxModelStatusNotAvailable:
|
||||
"""Status dict when ONNX is not loaded."""
|
||||
|
||||
@patch(
|
||||
"services.trocr_onnx_service._check_onnx_runtime_available",
|
||||
return_value=False,
|
||||
)
|
||||
@patch(
|
||||
"services.trocr_onnx_service._resolve_onnx_model_dir",
|
||||
return_value=None,
|
||||
)
|
||||
def test_onnx_model_status_not_available(self, mock_resolve, mock_runtime):
|
||||
from services.trocr_onnx_service import get_onnx_model_status
|
||||
|
||||
# Clear any cached models from prior tests
|
||||
import services.trocr_onnx_service as mod
|
||||
mod._onnx_models.clear()
|
||||
mod._onnx_model_loaded_at = None
|
||||
|
||||
status = get_onnx_model_status()
|
||||
|
||||
assert status["backend"] == "onnx"
|
||||
assert status["runtime_available"] is False
|
||||
assert status["printed"]["available"] is False
|
||||
assert status["printed"]["loaded"] is False
|
||||
assert status["printed"]["model_dir"] is None
|
||||
assert status["handwritten"]["available"] is False
|
||||
assert status["handwritten"]["loaded"] is False
|
||||
assert status["handwritten"]["model_dir"] is None
|
||||
assert status["loaded_at"] is None
|
||||
assert status["providers"] == []
|
||||
|
||||
@patch(
|
||||
"services.trocr_onnx_service._check_onnx_runtime_available",
|
||||
return_value=True,
|
||||
)
|
||||
def test_onnx_model_status_runtime_but_no_files(self, mock_runtime):
|
||||
"""Runtime installed but no model files on disk."""
|
||||
from services.trocr_onnx_service import get_onnx_model_status
|
||||
import services.trocr_onnx_service as mod
|
||||
mod._onnx_models.clear()
|
||||
mod._onnx_model_loaded_at = None
|
||||
|
||||
with patch(
|
||||
"services.trocr_onnx_service._resolve_onnx_model_dir",
|
||||
return_value=None,
|
||||
), patch("services.trocr_onnx_service.onnxruntime", create=True) as mock_ort:
|
||||
# Mock onnxruntime import inside get_onnx_model_status
|
||||
mock_ort_module = MagicMock()
|
||||
mock_ort_module.get_available_providers.return_value = [
|
||||
"CPUExecutionProvider"
|
||||
]
|
||||
with patch.dict("sys.modules", {"onnxruntime": mock_ort_module}):
|
||||
status = get_onnx_model_status()
|
||||
|
||||
assert status["runtime_available"] is True
|
||||
assert status["printed"]["available"] is False
|
||||
assert status["handwritten"]["available"] is False
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Test: path resolution logic
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestOnnxModelPaths:
|
||||
"""Verify the path resolution order."""
|
||||
|
||||
def test_env_var_path_takes_precedence(self, tmp_path):
|
||||
"""TROCR_ONNX_DIR env var should be checked first."""
|
||||
from services.trocr_onnx_service import _resolve_onnx_model_dir
|
||||
|
||||
# Create a fake model dir with a config.json
|
||||
model_dir = tmp_path / "trocr-base-printed"
|
||||
model_dir.mkdir(parents=True)
|
||||
(model_dir / "config.json").write_text("{}")
|
||||
|
||||
with patch.dict(os.environ, {"TROCR_ONNX_DIR": str(tmp_path)}):
|
||||
result = _resolve_onnx_model_dir(handwritten=False)
|
||||
|
||||
assert result is not None
|
||||
assert result == model_dir
|
||||
|
||||
def test_env_var_handwritten_variant(self, tmp_path):
|
||||
"""TROCR_ONNX_DIR works for handwritten variant too."""
|
||||
from services.trocr_onnx_service import _resolve_onnx_model_dir
|
||||
|
||||
model_dir = tmp_path / "trocr-base-handwritten"
|
||||
model_dir.mkdir(parents=True)
|
||||
(model_dir / "encoder_model.onnx").write_bytes(b"fake")
|
||||
|
||||
with patch.dict(os.environ, {"TROCR_ONNX_DIR": str(tmp_path)}):
|
||||
result = _resolve_onnx_model_dir(handwritten=True)
|
||||
|
||||
assert result is not None
|
||||
assert result == model_dir
|
||||
|
||||
def test_returns_none_when_no_dirs_exist(self):
|
||||
"""When none of the candidate dirs exist, return None."""
|
||||
from services.trocr_onnx_service import _resolve_onnx_model_dir
|
||||
|
||||
with patch.dict(os.environ, {}, clear=True):
|
||||
# Remove TROCR_ONNX_DIR if set
|
||||
os.environ.pop("TROCR_ONNX_DIR", None)
|
||||
# The Docker and local-dev paths almost certainly don't contain
|
||||
# real ONNX models on the test machine.
|
||||
result = _resolve_onnx_model_dir(handwritten=False)
|
||||
|
||||
# Could be None or a real dir if someone has models locally.
|
||||
# We just verify it doesn't raise.
|
||||
assert result is None or isinstance(result, Path)
|
||||
|
||||
def test_docker_path_checked(self, tmp_path):
|
||||
"""Docker path /root/.cache/huggingface/onnx/ is in candidate list."""
|
||||
from services.trocr_onnx_service import _resolve_onnx_model_dir
|
||||
|
||||
docker_path = Path("/root/.cache/huggingface/onnx/trocr-base-printed")
|
||||
|
||||
# We can't create that path in tests, but we can verify the logic
|
||||
# by checking that when env var points nowhere and docker path
|
||||
# doesn't exist, the function still runs without error.
|
||||
with patch.dict(os.environ, {}, clear=True):
|
||||
os.environ.pop("TROCR_ONNX_DIR", None)
|
||||
# Just verify it doesn't crash
|
||||
_resolve_onnx_model_dir(handwritten=False)
|
||||
|
||||
def test_local_dev_path_relative_to_backend(self, tmp_path):
|
||||
"""Local dev path is models/onnx/<variant>/ relative to backend dir."""
|
||||
from services.trocr_onnx_service import _resolve_onnx_model_dir
|
||||
|
||||
# The backend dir is derived from __file__, so we can't easily
|
||||
# redirect it. Instead, verify the function signature and return type.
|
||||
with patch.dict(os.environ, {}, clear=True):
|
||||
os.environ.pop("TROCR_ONNX_DIR", None)
|
||||
result = _resolve_onnx_model_dir(handwritten=False)
|
||||
# May or may not find models — just verify the return type
|
||||
assert result is None or isinstance(result, Path)
|
||||
|
||||
def test_dir_without_onnx_files_is_skipped(self, tmp_path):
|
||||
"""A directory that exists but has no .onnx files or config.json is skipped."""
|
||||
from services.trocr_onnx_service import _resolve_onnx_model_dir
|
||||
|
||||
empty_dir = tmp_path / "trocr-base-printed"
|
||||
empty_dir.mkdir(parents=True)
|
||||
# No .onnx files, no config.json
|
||||
|
||||
with patch.dict(os.environ, {"TROCR_ONNX_DIR": str(tmp_path)}):
|
||||
result = _resolve_onnx_model_dir(handwritten=False)
|
||||
|
||||
# The env-var candidate exists as a dir but has no model files,
|
||||
# so it should be skipped. Result depends on whether other
|
||||
# candidate dirs have models.
|
||||
if result is not None:
|
||||
# If found elsewhere, that's fine — just not the empty dir
|
||||
assert result != empty_dir
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Test: fallback to PyTorch
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestOnnxFallbackToPytorch:
|
||||
"""When ONNX is unavailable, the routing layer in trocr_service falls back."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_onnx_fallback_to_pytorch(self):
|
||||
"""With backend='auto' and ONNX unavailable, PyTorch path is used."""
|
||||
import services.trocr_service as svc
|
||||
|
||||
original_backend = svc._trocr_backend
|
||||
|
||||
try:
|
||||
svc._trocr_backend = "auto"
|
||||
|
||||
with patch(
|
||||
"services.trocr_service._try_onnx_ocr",
|
||||
return_value=None,
|
||||
) as mock_onnx, patch(
|
||||
"services.trocr_service._run_pytorch_ocr",
|
||||
return_value=("pytorch result", 0.9),
|
||||
) as mock_pytorch:
|
||||
text, conf = await svc.run_trocr_ocr(b"fake-image-data")
|
||||
|
||||
mock_onnx.assert_called_once()
|
||||
mock_pytorch.assert_called_once()
|
||||
assert text == "pytorch result"
|
||||
assert conf == 0.9
|
||||
|
||||
finally:
|
||||
svc._trocr_backend = original_backend
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_onnx_backend_forced(self):
|
||||
"""With backend='onnx', failure raises RuntimeError."""
|
||||
import services.trocr_service as svc
|
||||
|
||||
original_backend = svc._trocr_backend
|
||||
|
||||
try:
|
||||
svc._trocr_backend = "onnx"
|
||||
|
||||
with patch(
|
||||
"services.trocr_service._try_onnx_ocr",
|
||||
return_value=None,
|
||||
):
|
||||
with pytest.raises(RuntimeError, match="ONNX backend.*unavailable"):
|
||||
await svc.run_trocr_ocr(b"fake-image-data")
|
||||
|
||||
finally:
|
||||
svc._trocr_backend = original_backend
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_pytorch_backend_skips_onnx(self):
|
||||
"""With backend='pytorch', ONNX is never attempted."""
|
||||
import services.trocr_service as svc
|
||||
|
||||
original_backend = svc._trocr_backend
|
||||
|
||||
try:
|
||||
svc._trocr_backend = "pytorch"
|
||||
|
||||
with patch(
|
||||
"services.trocr_service._try_onnx_ocr",
|
||||
) as mock_onnx, patch(
|
||||
"services.trocr_service._run_pytorch_ocr",
|
||||
return_value=("pytorch only", 0.85),
|
||||
) as mock_pytorch:
|
||||
text, conf = await svc.run_trocr_ocr(b"fake-image-data")
|
||||
|
||||
mock_onnx.assert_not_called()
|
||||
mock_pytorch.assert_called_once()
|
||||
assert text == "pytorch only"
|
||||
|
||||
finally:
|
||||
svc._trocr_backend = original_backend
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Test: TROCR_BACKEND env var handling
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestBackendConfig:
|
||||
"""TROCR_BACKEND environment variable handling."""
|
||||
|
||||
def test_default_backend_is_auto(self):
|
||||
"""Without env var, backend defaults to 'auto'."""
|
||||
import services.trocr_service as svc
|
||||
# The module reads the env var at import time; in a fresh import
|
||||
# with no TROCR_BACKEND set, it should default to "auto".
|
||||
# We test the get_active_backend function instead.
|
||||
original = svc._trocr_backend
|
||||
try:
|
||||
svc._trocr_backend = "auto"
|
||||
assert svc.get_active_backend() == "auto"
|
||||
finally:
|
||||
svc._trocr_backend = original
|
||||
|
||||
def test_backend_pytorch(self):
|
||||
"""TROCR_BACKEND=pytorch is reflected in get_active_backend."""
|
||||
import services.trocr_service as svc
|
||||
original = svc._trocr_backend
|
||||
try:
|
||||
svc._trocr_backend = "pytorch"
|
||||
assert svc.get_active_backend() == "pytorch"
|
||||
finally:
|
||||
svc._trocr_backend = original
|
||||
|
||||
def test_backend_onnx(self):
|
||||
"""TROCR_BACKEND=onnx is reflected in get_active_backend."""
|
||||
import services.trocr_service as svc
|
||||
original = svc._trocr_backend
|
||||
try:
|
||||
svc._trocr_backend = "onnx"
|
||||
assert svc.get_active_backend() == "onnx"
|
||||
finally:
|
||||
svc._trocr_backend = original
|
||||
|
||||
def test_env_var_read_at_import(self):
|
||||
"""Module reads TROCR_BACKEND from environment."""
|
||||
# We can't easily re-import, but we can verify the variable exists
|
||||
import services.trocr_service as svc
|
||||
assert hasattr(svc, "_trocr_backend")
|
||||
assert svc._trocr_backend in ("auto", "pytorch", "onnx")
|
||||
@@ -70,6 +70,7 @@ nav:
|
||||
- BYOEH Developer Guide: services/klausur-service/BYOEH-Developer-Guide.md
|
||||
- NiBiS Pipeline: services/klausur-service/NiBiS-Ingestion-Pipeline.md
|
||||
- OCR Pipeline: services/klausur-service/OCR-Pipeline.md
|
||||
- TrOCR ONNX: services/klausur-service/TrOCR-ONNX.md
|
||||
- OCR Labeling: services/klausur-service/OCR-Labeling-Spec.md
|
||||
- OCR Vergleich: services/klausur-service/OCR-Compare.md
|
||||
- RAG Admin: services/klausur-service/RAG-Admin-Spec.md
|
||||
|
||||
546
scripts/export-doclayout-onnx.py
Executable file
546
scripts/export-doclayout-onnx.py
Executable file
@@ -0,0 +1,546 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
PP-DocLayout ONNX Export — exports PP-DocLayout model to ONNX for document layout detection.
|
||||
|
||||
PP-DocLayout detects: table, figure, title, text, list regions on document pages.
|
||||
Since PaddlePaddle doesn't work natively on ARM Mac, this script either:
|
||||
1. Downloads a pre-exported ONNX model
|
||||
2. Uses Docker (linux/amd64) for the conversion
|
||||
|
||||
Usage:
|
||||
python scripts/export-doclayout-onnx.py
|
||||
python scripts/export-doclayout-onnx.py --method docker
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import hashlib
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import shutil
|
||||
import subprocess
|
||||
import sys
|
||||
import tempfile
|
||||
import urllib.request
|
||||
from pathlib import Path
|
||||
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format="%(asctime)s [%(levelname)s] %(message)s",
|
||||
datefmt="%H:%M:%S",
|
||||
)
|
||||
log = logging.getLogger("export-doclayout")
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Constants
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
# 10 PP-DocLayout class labels in standard order
|
||||
CLASS_LABELS = [
|
||||
"table",
|
||||
"figure",
|
||||
"title",
|
||||
"text",
|
||||
"list",
|
||||
"header",
|
||||
"footer",
|
||||
"equation",
|
||||
"reference",
|
||||
"abstract",
|
||||
]
|
||||
|
||||
# Known download sources for pre-exported ONNX models.
|
||||
# Ordered by preference — first successful download wins.
|
||||
DOWNLOAD_SOURCES = [
|
||||
{
|
||||
"name": "PaddleOCR PP-DocLayout (ppyoloe_plus_sod, HuggingFace)",
|
||||
"url": "https://huggingface.co/SWHL/PP-DocLayout/resolve/main/pp_doclayout_onnx/model.onnx",
|
||||
"filename": "model.onnx",
|
||||
"sha256": None, # populated once a known-good hash is available
|
||||
},
|
||||
{
|
||||
"name": "PaddleOCR PP-DocLayout (RapidOCR mirror)",
|
||||
"url": "https://huggingface.co/SWHL/PP-DocLayout/resolve/main/pp_doclayout_onnx/model.onnx",
|
||||
"filename": "model.onnx",
|
||||
"sha256": None,
|
||||
},
|
||||
]
|
||||
|
||||
# Paddle inference model URLs (for Docker-based conversion).
|
||||
PADDLE_MODEL_URL = (
|
||||
"https://paddleocr.bj.bcebos.com/PP-DocLayout/PP-DocLayout_plus.tar"
|
||||
)
|
||||
|
||||
# Expected input shape for the model (batch, channels, height, width).
|
||||
MODEL_INPUT_SHAPE = (1, 3, 800, 800)
|
||||
|
||||
# Docker image name used for conversion.
|
||||
DOCKER_IMAGE_TAG = "breakpilot/paddle2onnx-converter:latest"
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def sha256_file(path: Path) -> str:
|
||||
"""Compute SHA-256 hex digest for a file."""
|
||||
h = hashlib.sha256()
|
||||
with open(path, "rb") as f:
|
||||
for chunk in iter(lambda: f.read(1 << 20), b""):
|
||||
h.update(chunk)
|
||||
return h.hexdigest()
|
||||
|
||||
|
||||
def download_file(url: str, dest: Path, desc: str = "") -> bool:
|
||||
"""Download a file with progress reporting. Returns True on success."""
|
||||
label = desc or url.split("/")[-1]
|
||||
log.info("Downloading %s ...", label)
|
||||
log.info(" URL: %s", url)
|
||||
|
||||
try:
|
||||
req = urllib.request.Request(url, headers={"User-Agent": "breakpilot-export/1.0"})
|
||||
with urllib.request.urlopen(req, timeout=120) as resp:
|
||||
total = resp.headers.get("Content-Length")
|
||||
total = int(total) if total else None
|
||||
downloaded = 0
|
||||
|
||||
dest.parent.mkdir(parents=True, exist_ok=True)
|
||||
with open(dest, "wb") as f:
|
||||
while True:
|
||||
chunk = resp.read(1 << 18) # 256 KB
|
||||
if not chunk:
|
||||
break
|
||||
f.write(chunk)
|
||||
downloaded += len(chunk)
|
||||
if total:
|
||||
pct = downloaded * 100 / total
|
||||
mb = downloaded / (1 << 20)
|
||||
total_mb = total / (1 << 20)
|
||||
print(
|
||||
f"\r {mb:.1f}/{total_mb:.1f} MB ({pct:.0f}%)",
|
||||
end="",
|
||||
flush=True,
|
||||
)
|
||||
if total:
|
||||
print() # newline after progress
|
||||
|
||||
size_mb = dest.stat().st_size / (1 << 20)
|
||||
log.info(" Downloaded %.1f MB -> %s", size_mb, dest)
|
||||
return True
|
||||
|
||||
except Exception as exc:
|
||||
log.warning(" Download failed: %s", exc)
|
||||
if dest.exists():
|
||||
dest.unlink()
|
||||
return False
|
||||
|
||||
|
||||
def verify_onnx(model_path: Path) -> bool:
|
||||
"""Load the ONNX model with onnxruntime, run a dummy inference, check outputs."""
|
||||
log.info("Verifying ONNX model: %s", model_path)
|
||||
|
||||
try:
|
||||
import numpy as np
|
||||
except ImportError:
|
||||
log.error("numpy is required for verification: pip install numpy")
|
||||
return False
|
||||
|
||||
try:
|
||||
import onnxruntime as ort
|
||||
except ImportError:
|
||||
log.error("onnxruntime is required for verification: pip install onnxruntime")
|
||||
return False
|
||||
|
||||
try:
|
||||
# Load the model
|
||||
opts = ort.SessionOptions()
|
||||
opts.log_severity_level = 3 # suppress verbose logs
|
||||
session = ort.InferenceSession(str(model_path), sess_options=opts)
|
||||
|
||||
# Inspect inputs
|
||||
inputs = session.get_inputs()
|
||||
log.info(" Model inputs:")
|
||||
for inp in inputs:
|
||||
log.info(" %s: shape=%s dtype=%s", inp.name, inp.shape, inp.type)
|
||||
|
||||
# Inspect outputs
|
||||
outputs = session.get_outputs()
|
||||
log.info(" Model outputs:")
|
||||
for out in outputs:
|
||||
log.info(" %s: shape=%s dtype=%s", out.name, out.shape, out.type)
|
||||
|
||||
# Build dummy input — use the first input's name and expected shape.
|
||||
input_name = inputs[0].name
|
||||
input_shape = inputs[0].shape
|
||||
|
||||
# Replace dynamic dims (strings or None) with concrete sizes.
|
||||
concrete_shape = []
|
||||
for i, dim in enumerate(input_shape):
|
||||
if isinstance(dim, (int,)) and dim > 0:
|
||||
concrete_shape.append(dim)
|
||||
elif i == 0:
|
||||
concrete_shape.append(1) # batch
|
||||
elif i == 1:
|
||||
concrete_shape.append(3) # channels
|
||||
else:
|
||||
concrete_shape.append(800) # spatial
|
||||
concrete_shape = tuple(concrete_shape)
|
||||
|
||||
# Fallback if shape looks wrong — use standard MODEL_INPUT_SHAPE.
|
||||
if len(concrete_shape) != 4:
|
||||
concrete_shape = MODEL_INPUT_SHAPE
|
||||
|
||||
log.info(" Running dummy inference with shape %s ...", concrete_shape)
|
||||
dummy = np.random.randn(*concrete_shape).astype(np.float32)
|
||||
result = session.run(None, {input_name: dummy})
|
||||
|
||||
log.info(" Inference succeeded — %d output tensors:", len(result))
|
||||
for i, r in enumerate(result):
|
||||
arr = np.asarray(r)
|
||||
log.info(" output[%d]: shape=%s dtype=%s", i, arr.shape, arr.dtype)
|
||||
|
||||
# Basic sanity checks
|
||||
if len(result) == 0:
|
||||
log.error(" Model produced no outputs!")
|
||||
return False
|
||||
|
||||
# Check for at least one output with a bounding-box-like shape (N, 4) or
|
||||
# a detection-like structure. Be lenient — different ONNX exports vary.
|
||||
has_plausible_output = False
|
||||
for r in result:
|
||||
arr = np.asarray(r)
|
||||
# Common detection output shapes: (1, N, 6), (N, 4), (N, 6), (1, N, 5+C), etc.
|
||||
if arr.ndim >= 2 and any(d >= 4 for d in arr.shape):
|
||||
has_plausible_output = True
|
||||
# Some models output (N,) labels or scores
|
||||
if arr.ndim >= 1 and arr.size > 0:
|
||||
has_plausible_output = True
|
||||
|
||||
if has_plausible_output:
|
||||
log.info(" Verification PASSED")
|
||||
return True
|
||||
else:
|
||||
log.warning(" Output shapes look unexpected, but model loaded OK.")
|
||||
log.warning(" Treating as PASSED (shapes may differ by export variant).")
|
||||
return True
|
||||
|
||||
except Exception as exc:
|
||||
log.error(" Verification FAILED: %s", exc)
|
||||
return False
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Method: Download
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def try_download(output_dir: Path) -> bool:
|
||||
"""Attempt to download a pre-exported ONNX model. Returns True on success."""
|
||||
log.info("=== Method: DOWNLOAD ===")
|
||||
|
||||
output_dir.mkdir(parents=True, exist_ok=True)
|
||||
model_path = output_dir / "model.onnx"
|
||||
|
||||
for source in DOWNLOAD_SOURCES:
|
||||
log.info("Trying source: %s", source["name"])
|
||||
tmp_path = output_dir / f".{source['filename']}.tmp"
|
||||
|
||||
if not download_file(source["url"], tmp_path, desc=source["name"]):
|
||||
continue
|
||||
|
||||
# Check SHA-256 if known.
|
||||
if source["sha256"]:
|
||||
actual_hash = sha256_file(tmp_path)
|
||||
if actual_hash != source["sha256"]:
|
||||
log.warning(
|
||||
" SHA-256 mismatch: expected %s, got %s",
|
||||
source["sha256"],
|
||||
actual_hash,
|
||||
)
|
||||
tmp_path.unlink()
|
||||
continue
|
||||
|
||||
# Basic sanity: file should be > 1 MB (a real ONNX model, not an error page).
|
||||
size = tmp_path.stat().st_size
|
||||
if size < 1 << 20:
|
||||
log.warning(" File too small (%.1f KB) — probably not a valid model.", size / 1024)
|
||||
tmp_path.unlink()
|
||||
continue
|
||||
|
||||
# Move into place.
|
||||
shutil.move(str(tmp_path), str(model_path))
|
||||
log.info("Model saved to %s (%.1f MB)", model_path, model_path.stat().st_size / (1 << 20))
|
||||
return True
|
||||
|
||||
log.warning("All download sources failed.")
|
||||
return False
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Method: Docker
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
DOCKERFILE_CONTENT = r"""
|
||||
FROM --platform=linux/amd64 python:3.11-slim
|
||||
|
||||
RUN pip install --no-cache-dir \
|
||||
paddlepaddle==3.0.0 \
|
||||
paddle2onnx==1.3.1 \
|
||||
onnx==1.17.0 \
|
||||
requests
|
||||
|
||||
WORKDIR /work
|
||||
|
||||
# Download + extract the PP-DocLayout Paddle inference model.
|
||||
RUN python3 -c "
|
||||
import urllib.request, tarfile, os
|
||||
url = 'PADDLE_MODEL_URL_PLACEHOLDER'
|
||||
print(f'Downloading {url} ...')
|
||||
dest = '/work/pp_doclayout.tar'
|
||||
urllib.request.urlretrieve(url, dest)
|
||||
print('Extracting ...')
|
||||
with tarfile.open(dest) as t:
|
||||
t.extractall('/work/paddle_model')
|
||||
os.remove(dest)
|
||||
# List what we extracted
|
||||
for root, dirs, files in os.walk('/work/paddle_model'):
|
||||
for f in files:
|
||||
fp = os.path.join(root, f)
|
||||
sz = os.path.getsize(fp)
|
||||
print(f' {fp} ({sz} bytes)')
|
||||
"
|
||||
|
||||
# Convert Paddle model to ONNX.
|
||||
# paddle2onnx expects model_dir with model.pdmodel + model.pdiparams
|
||||
RUN python3 -c "
|
||||
import os, glob, subprocess
|
||||
|
||||
# Find the inference model files
|
||||
model_dir = '/work/paddle_model'
|
||||
pdmodel_files = glob.glob(os.path.join(model_dir, '**', '*.pdmodel'), recursive=True)
|
||||
pdiparams_files = glob.glob(os.path.join(model_dir, '**', '*.pdiparams'), recursive=True)
|
||||
|
||||
if not pdmodel_files:
|
||||
raise FileNotFoundError('No .pdmodel file found in extracted archive')
|
||||
|
||||
pdmodel = pdmodel_files[0]
|
||||
pdiparams = pdiparams_files[0] if pdiparams_files else None
|
||||
model_dir_actual = os.path.dirname(pdmodel)
|
||||
pdmodel_name = os.path.basename(pdmodel).replace('.pdmodel', '')
|
||||
|
||||
print(f'Found model: {pdmodel}')
|
||||
print(f'Found params: {pdiparams}')
|
||||
print(f'Model dir: {model_dir_actual}')
|
||||
print(f'Model name prefix: {pdmodel_name}')
|
||||
|
||||
cmd = [
|
||||
'paddle2onnx',
|
||||
'--model_dir', model_dir_actual,
|
||||
'--model_filename', os.path.basename(pdmodel),
|
||||
]
|
||||
if pdiparams:
|
||||
cmd += ['--params_filename', os.path.basename(pdiparams)]
|
||||
cmd += [
|
||||
'--save_file', '/work/output/model.onnx',
|
||||
'--opset_version', '14',
|
||||
'--enable_onnx_checker', 'True',
|
||||
]
|
||||
|
||||
os.makedirs('/work/output', exist_ok=True)
|
||||
print(f'Running: {\" \".join(cmd)}')
|
||||
subprocess.run(cmd, check=True)
|
||||
|
||||
out_size = os.path.getsize('/work/output/model.onnx')
|
||||
print(f'Conversion done: /work/output/model.onnx ({out_size} bytes)')
|
||||
"
|
||||
|
||||
CMD ["cp", "-v", "/work/output/model.onnx", "/output/model.onnx"]
|
||||
""".replace(
|
||||
"PADDLE_MODEL_URL_PLACEHOLDER", PADDLE_MODEL_URL
|
||||
)
|
||||
|
||||
|
||||
def try_docker(output_dir: Path) -> bool:
|
||||
"""Build a Docker image to convert the Paddle model to ONNX. Returns True on success."""
|
||||
log.info("=== Method: DOCKER (linux/amd64) ===")
|
||||
|
||||
# Check Docker is available.
|
||||
docker_bin = shutil.which("docker") or "/usr/local/bin/docker"
|
||||
try:
|
||||
subprocess.run(
|
||||
[docker_bin, "version"],
|
||||
capture_output=True,
|
||||
check=True,
|
||||
timeout=15,
|
||||
)
|
||||
except (subprocess.CalledProcessError, FileNotFoundError, subprocess.TimeoutExpired) as exc:
|
||||
log.error("Docker is not available: %s", exc)
|
||||
return False
|
||||
|
||||
output_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
with tempfile.TemporaryDirectory(prefix="doclayout-export-") as tmpdir:
|
||||
tmpdir = Path(tmpdir)
|
||||
|
||||
# Write Dockerfile.
|
||||
dockerfile_path = tmpdir / "Dockerfile"
|
||||
dockerfile_path.write_text(DOCKERFILE_CONTENT)
|
||||
log.info("Wrote Dockerfile to %s", dockerfile_path)
|
||||
|
||||
# Build image.
|
||||
log.info("Building Docker image (this downloads ~2 GB, may take a while) ...")
|
||||
build_cmd = [
|
||||
docker_bin, "build",
|
||||
"--platform", "linux/amd64",
|
||||
"-t", DOCKER_IMAGE_TAG,
|
||||
"-f", str(dockerfile_path),
|
||||
str(tmpdir),
|
||||
]
|
||||
log.info(" %s", " ".join(build_cmd))
|
||||
build_result = subprocess.run(
|
||||
build_cmd,
|
||||
capture_output=False, # stream output to terminal
|
||||
timeout=1200, # 20 min
|
||||
)
|
||||
if build_result.returncode != 0:
|
||||
log.error("Docker build failed (exit code %d).", build_result.returncode)
|
||||
return False
|
||||
|
||||
# Run container — mount output_dir as /output, the CMD copies model.onnx there.
|
||||
log.info("Running conversion container ...")
|
||||
run_cmd = [
|
||||
docker_bin, "run",
|
||||
"--rm",
|
||||
"--platform", "linux/amd64",
|
||||
"-v", f"{output_dir.resolve()}:/output",
|
||||
DOCKER_IMAGE_TAG,
|
||||
]
|
||||
log.info(" %s", " ".join(run_cmd))
|
||||
run_result = subprocess.run(
|
||||
run_cmd,
|
||||
capture_output=False,
|
||||
timeout=300,
|
||||
)
|
||||
if run_result.returncode != 0:
|
||||
log.error("Docker run failed (exit code %d).", run_result.returncode)
|
||||
return False
|
||||
|
||||
model_path = output_dir / "model.onnx"
|
||||
if model_path.exists():
|
||||
size_mb = model_path.stat().st_size / (1 << 20)
|
||||
log.info("Model exported: %s (%.1f MB)", model_path, size_mb)
|
||||
return True
|
||||
else:
|
||||
log.error("Expected output file not found: %s", model_path)
|
||||
return False
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Write metadata
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def write_metadata(output_dir: Path, method: str) -> None:
|
||||
"""Write a metadata JSON next to the model for provenance tracking."""
|
||||
model_path = output_dir / "model.onnx"
|
||||
if not model_path.exists():
|
||||
return
|
||||
|
||||
meta = {
|
||||
"model": "PP-DocLayout",
|
||||
"format": "ONNX",
|
||||
"export_method": method,
|
||||
"class_labels": CLASS_LABELS,
|
||||
"input_shape": list(MODEL_INPUT_SHAPE),
|
||||
"file_size_bytes": model_path.stat().st_size,
|
||||
"sha256": sha256_file(model_path),
|
||||
}
|
||||
meta_path = output_dir / "metadata.json"
|
||||
with open(meta_path, "w") as f:
|
||||
json.dump(meta, f, indent=2)
|
||||
log.info("Metadata written to %s", meta_path)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Main
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def main() -> int:
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Export PP-DocLayout model to ONNX for document layout detection.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output-dir",
|
||||
type=Path,
|
||||
default=Path("models/onnx/pp-doclayout"),
|
||||
help="Directory for the exported ONNX model (default: models/onnx/pp-doclayout/)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--method",
|
||||
choices=["auto", "download", "docker"],
|
||||
default="auto",
|
||||
help="Export method: auto (try download then docker), download, or docker.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--skip-verify",
|
||||
action="store_true",
|
||||
help="Skip ONNX model verification after export.",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
output_dir: Path = args.output_dir
|
||||
model_path = output_dir / "model.onnx"
|
||||
|
||||
# Check if model already exists.
|
||||
if model_path.exists():
|
||||
size_mb = model_path.stat().st_size / (1 << 20)
|
||||
log.info("Model already exists: %s (%.1f MB)", model_path, size_mb)
|
||||
log.info("Delete it first if you want to re-export.")
|
||||
if not args.skip_verify:
|
||||
if not verify_onnx(model_path):
|
||||
log.error("Existing model failed verification!")
|
||||
return 1
|
||||
return 0
|
||||
|
||||
success = False
|
||||
used_method = None
|
||||
|
||||
if args.method in ("auto", "download"):
|
||||
success = try_download(output_dir)
|
||||
if success:
|
||||
used_method = "download"
|
||||
|
||||
if not success and args.method in ("auto", "docker"):
|
||||
success = try_docker(output_dir)
|
||||
if success:
|
||||
used_method = "docker"
|
||||
|
||||
if not success:
|
||||
log.error("All export methods failed.")
|
||||
if args.method == "download":
|
||||
log.info("Hint: try --method docker to convert via Docker (linux/amd64).")
|
||||
elif args.method == "docker":
|
||||
log.info("Hint: ensure Docker is running and has internet access.")
|
||||
else:
|
||||
log.info("Hint: check your internet connection and Docker installation.")
|
||||
return 1
|
||||
|
||||
# Write metadata.
|
||||
write_metadata(output_dir, used_method)
|
||||
|
||||
# Verify.
|
||||
if not args.skip_verify:
|
||||
if not verify_onnx(model_path):
|
||||
log.error("Exported model failed verification!")
|
||||
log.info("The file is kept at %s — inspect manually.", model_path)
|
||||
return 1
|
||||
else:
|
||||
log.info("Skipping verification (--skip-verify).")
|
||||
|
||||
log.info("Done. Model ready at %s", model_path)
|
||||
return 0
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
sys.exit(main())
|
||||
412
scripts/export-trocr-onnx.py
Executable file
412
scripts/export-trocr-onnx.py
Executable file
@@ -0,0 +1,412 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
TrOCR ONNX Export — exports TrOCR models to ONNX with int8 quantization.
|
||||
|
||||
Supported models:
|
||||
- microsoft/trocr-base-printed
|
||||
- microsoft/trocr-base-handwritten
|
||||
|
||||
Steps per model:
|
||||
1. Load PyTorch model via optimum ORTModelForVision2Seq (export=True)
|
||||
2. Save ONNX to output directory
|
||||
3. Quantize to int8 via ORTQuantizer + AutoQuantizationConfig
|
||||
4. Verify: compare PyTorch vs ONNX outputs (diff < 2%)
|
||||
5. Report model sizes before/after quantization
|
||||
|
||||
Usage:
|
||||
python scripts/export-trocr-onnx.py
|
||||
python scripts/export-trocr-onnx.py --model printed
|
||||
python scripts/export-trocr-onnx.py --model handwritten --skip-verify
|
||||
python scripts/export-trocr-onnx.py --output-dir models/onnx --skip-quantize
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import os
|
||||
import platform
|
||||
import sys
|
||||
import time
|
||||
from pathlib import Path
|
||||
|
||||
# Add backend to path for imports
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'klausur-service', 'backend'))
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Constants
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
MODELS = {
|
||||
"printed": "microsoft/trocr-base-printed",
|
||||
"handwritten": "microsoft/trocr-base-handwritten",
|
||||
}
|
||||
|
||||
DEFAULT_OUTPUT_DIR = os.path.join(os.path.dirname(__file__), '..', 'models', 'onnx')
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def dir_size_mb(path: str) -> float:
|
||||
"""Return total size of all files under *path* in MB."""
|
||||
total = 0
|
||||
for root, _dirs, files in os.walk(path):
|
||||
for f in files:
|
||||
total += os.path.getsize(os.path.join(root, f))
|
||||
return total / (1024 * 1024)
|
||||
|
||||
|
||||
def log(msg: str) -> None:
|
||||
"""Print a timestamped log message to stderr."""
|
||||
print(f"[export-onnx] {msg}", file=sys.stderr, flush=True)
|
||||
|
||||
|
||||
def _create_test_image():
|
||||
"""Create a simple synthetic text-line image for verification."""
|
||||
from PIL import Image
|
||||
|
||||
w, h = 384, 48
|
||||
img = Image.new('RGB', (w, h), 'white')
|
||||
pixels = img.load()
|
||||
# Draw a dark region to simulate printed text
|
||||
for x in range(60, 220):
|
||||
for y in range(10, 38):
|
||||
pixels[x, y] = (25, 25, 25)
|
||||
return img
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Export
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def export_to_onnx(model_name: str, output_dir: str) -> str:
|
||||
"""Export a HuggingFace TrOCR model to ONNX via optimum.
|
||||
|
||||
Returns the path to the saved ONNX directory.
|
||||
"""
|
||||
from optimum.onnxruntime import ORTModelForVision2Seq
|
||||
|
||||
short_name = model_name.split("/")[-1]
|
||||
onnx_path = os.path.join(output_dir, short_name)
|
||||
|
||||
log(f"Exporting {model_name} to ONNX ...")
|
||||
t0 = time.monotonic()
|
||||
|
||||
model = ORTModelForVision2Seq.from_pretrained(model_name, export=True)
|
||||
model.save_pretrained(onnx_path)
|
||||
|
||||
elapsed = time.monotonic() - t0
|
||||
size = dir_size_mb(onnx_path)
|
||||
log(f" Exported in {elapsed:.1f}s — {size:.1f} MB on disk: {onnx_path}")
|
||||
|
||||
return onnx_path
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Quantization
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def quantize_onnx(onnx_path: str) -> str:
|
||||
"""Apply int8 dynamic quantization to an ONNX model directory.
|
||||
|
||||
Returns the path to the quantized model directory.
|
||||
"""
|
||||
from optimum.onnxruntime import ORTQuantizer
|
||||
from optimum.onnxruntime.configuration import AutoQuantizationConfig
|
||||
|
||||
quantized_path = onnx_path + "-int8"
|
||||
log(f"Quantizing to int8 → {quantized_path} ...")
|
||||
t0 = time.monotonic()
|
||||
|
||||
# Pick quantization config based on platform.
|
||||
# arm64 (Apple Silicon) does not have AVX-512; use arm64 config when
|
||||
# available, otherwise fall back to avx512_vnni which still works for
|
||||
# dynamic quantisation (weights-only).
|
||||
machine = platform.machine().lower()
|
||||
if "arm" in machine or "aarch" in machine:
|
||||
try:
|
||||
qconfig = AutoQuantizationConfig.arm64(is_static=False, per_channel=False)
|
||||
log(" Using arm64 quantization config")
|
||||
except AttributeError:
|
||||
# Older optimum versions may lack arm64(); fall back.
|
||||
qconfig = AutoQuantizationConfig.avx512_vnni(is_static=False, per_channel=False)
|
||||
log(" arm64 config not available, falling back to avx512_vnni")
|
||||
else:
|
||||
qconfig = AutoQuantizationConfig.avx512_vnni(is_static=False, per_channel=False)
|
||||
log(" Using avx512_vnni quantization config")
|
||||
|
||||
quantizer = ORTQuantizer.from_pretrained(onnx_path)
|
||||
quantizer.quantize(save_dir=quantized_path, quantization_config=qconfig)
|
||||
|
||||
elapsed = time.monotonic() - t0
|
||||
size = dir_size_mb(quantized_path)
|
||||
log(f" Quantized in {elapsed:.1f}s — {size:.1f} MB on disk")
|
||||
|
||||
return quantized_path
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Verification
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def verify_outputs(model_name: str, onnx_path: str) -> dict:
|
||||
"""Compare PyTorch and ONNX model outputs on a synthetic image.
|
||||
|
||||
Returns a dict with verification results including the max relative
|
||||
difference of generated token IDs and decoded text from both backends.
|
||||
"""
|
||||
import numpy as np
|
||||
import torch
|
||||
from optimum.onnxruntime import ORTModelForVision2Seq
|
||||
from transformers import TrOCRProcessor, VisionEncoderDecoderModel
|
||||
|
||||
log(f"Verifying ONNX output against PyTorch for {model_name} ...")
|
||||
test_image = _create_test_image()
|
||||
|
||||
# --- PyTorch inference ---
|
||||
processor = TrOCRProcessor.from_pretrained(model_name)
|
||||
pt_model = VisionEncoderDecoderModel.from_pretrained(model_name)
|
||||
pt_model.eval()
|
||||
|
||||
pixel_values = processor(images=test_image, return_tensors="pt").pixel_values
|
||||
with torch.no_grad():
|
||||
pt_ids = pt_model.generate(pixel_values, max_new_tokens=50)
|
||||
pt_text = processor.batch_decode(pt_ids, skip_special_tokens=True)[0]
|
||||
|
||||
# --- ONNX inference ---
|
||||
ort_model = ORTModelForVision2Seq.from_pretrained(onnx_path)
|
||||
ort_pixel_values = processor(images=test_image, return_tensors="pt").pixel_values
|
||||
ort_ids = ort_model.generate(ort_pixel_values, max_new_tokens=50)
|
||||
ort_text = processor.batch_decode(ort_ids, skip_special_tokens=True)[0]
|
||||
|
||||
# --- Compare ---
|
||||
pt_arr = pt_ids[0].numpy().astype(np.float64)
|
||||
ort_arr = ort_ids[0].numpy().astype(np.float64)
|
||||
|
||||
# Pad to equal length for comparison
|
||||
max_len = max(len(pt_arr), len(ort_arr))
|
||||
if len(pt_arr) < max_len:
|
||||
pt_arr = np.pad(pt_arr, (0, max_len - len(pt_arr)))
|
||||
if len(ort_arr) < max_len:
|
||||
ort_arr = np.pad(ort_arr, (0, max_len - len(ort_arr)))
|
||||
|
||||
# Relative diff on token ids (treat 0-ids as 1 to avoid div-by-zero)
|
||||
denom = np.where(np.abs(pt_arr) > 0, np.abs(pt_arr), 1.0)
|
||||
rel_diff = np.abs(pt_arr - ort_arr) / denom
|
||||
max_diff_pct = float(np.max(rel_diff)) * 100.0
|
||||
exact_match = bool(np.array_equal(pt_ids[0].numpy(), ort_ids[0].numpy()))
|
||||
|
||||
passed = max_diff_pct < 2.0
|
||||
|
||||
result = {
|
||||
"passed": passed,
|
||||
"exact_token_match": exact_match,
|
||||
"max_relative_diff_pct": round(max_diff_pct, 4),
|
||||
"pytorch_text": pt_text,
|
||||
"onnx_text": ort_text,
|
||||
"text_match": pt_text == ort_text,
|
||||
}
|
||||
|
||||
status = "PASS" if passed else "FAIL"
|
||||
log(f" Verification {status}: max_diff={max_diff_pct:.4f}% exact_match={exact_match}")
|
||||
log(f" PyTorch : '{pt_text}'")
|
||||
log(f" ONNX : '{ort_text}'")
|
||||
|
||||
return result
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Per-model pipeline
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def process_model(
|
||||
model_name: str,
|
||||
output_dir: str,
|
||||
skip_verify: bool = False,
|
||||
skip_quantize: bool = False,
|
||||
) -> dict:
|
||||
"""Run the full export pipeline for one model.
|
||||
|
||||
Returns a summary dict with paths, sizes, and verification results.
|
||||
"""
|
||||
short_name = model_name.split("/")[-1]
|
||||
summary: dict = {
|
||||
"model": model_name,
|
||||
"short_name": short_name,
|
||||
"onnx_path": None,
|
||||
"onnx_size_mb": None,
|
||||
"quantized_path": None,
|
||||
"quantized_size_mb": None,
|
||||
"size_reduction_pct": None,
|
||||
"verification_fp32": None,
|
||||
"verification_int8": None,
|
||||
"error": None,
|
||||
}
|
||||
|
||||
log(f"{'='*60}")
|
||||
log(f"Processing: {model_name}")
|
||||
log(f"{'='*60}")
|
||||
|
||||
# Step 1 + 2: Export to ONNX
|
||||
try:
|
||||
onnx_path = export_to_onnx(model_name, output_dir)
|
||||
summary["onnx_path"] = onnx_path
|
||||
summary["onnx_size_mb"] = round(dir_size_mb(onnx_path), 1)
|
||||
except Exception as e:
|
||||
summary["error"] = f"ONNX export failed: {e}"
|
||||
log(f" ERROR: {summary['error']}")
|
||||
return summary
|
||||
|
||||
# Step 3: Verify fp32 ONNX
|
||||
if not skip_verify:
|
||||
try:
|
||||
summary["verification_fp32"] = verify_outputs(model_name, onnx_path)
|
||||
except Exception as e:
|
||||
log(f" WARNING: fp32 verification failed: {e}")
|
||||
summary["verification_fp32"] = {"passed": False, "error": str(e)}
|
||||
|
||||
# Step 4: Quantize to int8
|
||||
if not skip_quantize:
|
||||
try:
|
||||
quantized_path = quantize_onnx(onnx_path)
|
||||
summary["quantized_path"] = quantized_path
|
||||
q_size = dir_size_mb(quantized_path)
|
||||
summary["quantized_size_mb"] = round(q_size, 1)
|
||||
|
||||
if summary["onnx_size_mb"] and summary["onnx_size_mb"] > 0:
|
||||
reduction = (1 - q_size / dir_size_mb(onnx_path)) * 100
|
||||
summary["size_reduction_pct"] = round(reduction, 1)
|
||||
except Exception as e:
|
||||
summary["error"] = f"Quantization failed: {e}"
|
||||
log(f" ERROR: {summary['error']}")
|
||||
return summary
|
||||
|
||||
# Step 5: Verify int8 ONNX
|
||||
if not skip_verify:
|
||||
try:
|
||||
summary["verification_int8"] = verify_outputs(model_name, quantized_path)
|
||||
except Exception as e:
|
||||
log(f" WARNING: int8 verification failed: {e}")
|
||||
summary["verification_int8"] = {"passed": False, "error": str(e)}
|
||||
|
||||
return summary
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Summary printing
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def print_summary(results: list[dict]) -> None:
|
||||
"""Print a human-readable summary table."""
|
||||
print("\n" + "=" * 72, file=sys.stderr)
|
||||
print(" EXPORT SUMMARY", file=sys.stderr)
|
||||
print("=" * 72, file=sys.stderr)
|
||||
|
||||
for r in results:
|
||||
print(f"\n Model: {r['model']}", file=sys.stderr)
|
||||
|
||||
if r.get("error"):
|
||||
print(f" ERROR: {r['error']}", file=sys.stderr)
|
||||
continue
|
||||
|
||||
# Sizes
|
||||
onnx_mb = r.get("onnx_size_mb", "?")
|
||||
q_mb = r.get("quantized_size_mb", "?")
|
||||
reduction = r.get("size_reduction_pct", "?")
|
||||
|
||||
print(f" ONNX fp32 : {onnx_mb} MB ({r.get('onnx_path', '?')})", file=sys.stderr)
|
||||
if q_mb != "?":
|
||||
print(f" ONNX int8 : {q_mb} MB ({r.get('quantized_path', '?')})", file=sys.stderr)
|
||||
print(f" Reduction : {reduction}%", file=sys.stderr)
|
||||
|
||||
# Verification
|
||||
for label, key in [("fp32 verify", "verification_fp32"), ("int8 verify", "verification_int8")]:
|
||||
v = r.get(key)
|
||||
if v is None:
|
||||
print(f" {label}: skipped", file=sys.stderr)
|
||||
elif v.get("error"):
|
||||
print(f" {label}: ERROR — {v['error']}", file=sys.stderr)
|
||||
else:
|
||||
status = "PASS" if v["passed"] else "FAIL"
|
||||
diff = v.get("max_relative_diff_pct", "?")
|
||||
match = v.get("text_match", "?")
|
||||
print(f" {label}: {status} (max_diff={diff}%, text_match={match})", file=sys.stderr)
|
||||
|
||||
print("\n" + "=" * 72 + "\n", file=sys.stderr)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# CLI
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Export TrOCR models to ONNX with int8 quantization",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--model",
|
||||
choices=["printed", "handwritten", "both"],
|
||||
default="both",
|
||||
help="Which model to export (default: both)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output-dir",
|
||||
default=DEFAULT_OUTPUT_DIR,
|
||||
help=f"Output directory for ONNX models (default: {DEFAULT_OUTPUT_DIR})",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--skip-verify",
|
||||
action="store_true",
|
||||
help="Skip output verification against PyTorch",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--skip-quantize",
|
||||
action="store_true",
|
||||
help="Skip int8 quantization step",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
# Resolve output directory
|
||||
output_dir = os.path.abspath(args.output_dir)
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
log(f"Output directory: {output_dir}")
|
||||
|
||||
# Determine which models to process
|
||||
if args.model == "both":
|
||||
model_names = list(MODELS.values())
|
||||
else:
|
||||
model_names = [MODELS[args.model]]
|
||||
|
||||
# Process each model
|
||||
results = []
|
||||
for model_name in model_names:
|
||||
result = process_model(
|
||||
model_name=model_name,
|
||||
output_dir=output_dir,
|
||||
skip_verify=args.skip_verify,
|
||||
skip_quantize=args.skip_quantize,
|
||||
)
|
||||
results.append(result)
|
||||
|
||||
# Print summary
|
||||
print_summary(results)
|
||||
|
||||
# Exit with error code if any model failed verification
|
||||
any_fail = False
|
||||
for r in results:
|
||||
if r.get("error"):
|
||||
any_fail = True
|
||||
for vkey in ("verification_fp32", "verification_int8"):
|
||||
v = r.get(vkey)
|
||||
if v and not v.get("passed", True):
|
||||
any_fail = True
|
||||
|
||||
if any_fail:
|
||||
log("One or more steps failed or did not pass verification.")
|
||||
sys.exit(1)
|
||||
else:
|
||||
log("All exports completed successfully.")
|
||||
sys.exit(0)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
Reference in New Issue
Block a user