"""
Training API — simulation helper and SSE generators.
"""

import asyncio
import copy
import json
import random
import uuid
from datetime import datetime, timedelta
from typing import AsyncIterator

from training_models import TrainingStatus, _state

# Number of simulated steps per training epoch.
_STEPS_PER_EPOCH = 100
# Maximum number of loss / val-loss points retained in a job's history.
_HISTORY_LIMIT = 50
# Pause between simulated training steps and between SSE updates (seconds).
_UPDATE_INTERVAL_S = 0.5
# Nominal per-image OCR time used for the simulated timing fields (ms).
_OCR_NOMINAL_MS_PER_IMAGE = 500

# Job statuses on which the metrics stream stops after sending a final update.
_TERMINAL_STATUSES = (
    TrainingStatus.COMPLETED.value,
    TrainingStatus.FAILED.value,
    TrainingStatus.CANCELLED.value,
)


def _sse_event(payload: dict) -> str:
    """Encode ``payload`` as a single Server-Sent Events ``data:`` frame."""
    return f"data: {json.dumps(payload)}\n\n"


def _apply_simulated_step(job: dict, step: int, total_steps: int) -> None:
    """Write the simulated state for 0-based ``step`` of ``total_steps`` into ``job``."""
    progress = (step / total_steps) * 100
    current_epoch = step // _STEPS_PER_EPOCH + 1

    # Loss decays linearly from 0.9 towards 0.1 over the run, plus a small
    # ripple (roughly +/-0.025) that repeats every epoch.
    base_loss = 0.8 * (1 - progress / 100) + 0.1
    ripple = 0.05 * (0.5 - (step % _STEPS_PER_EPOCH) / _STEPS_PER_EPOCH)
    loss = base_loss + ripple
    val_loss = loss * 1.1

    job["progress"] = progress
    job["current_epoch"] = min(current_epoch, job["total_epochs"])
    job["loss"] = round(loss, 4)
    job["val_loss"] = round(val_loss, 4)
    job["documents_processed"] = int((progress / 100) * job["total_documents"])

    metrics = job["metrics"]
    metrics["loss_history"].append(round(loss, 4))
    metrics["val_loss_history"].append(round(val_loss, 4))
    metrics["precision"] = round(0.5 + (progress / 200), 3)
    metrics["recall"] = round(0.45 + (progress / 200), 3)
    metrics["f1_score"] = round(0.47 + (progress / 200), 3)
    metrics["accuracy"] = round(0.6 + (progress / 250), 3)

    # Both histories are appended together above, so (assuming they start out
    # the same length) trimming on one length keeps them index-aligned.
    if len(metrics["loss_history"]) > _HISTORY_LIMIT:
        metrics["loss_history"] = metrics["loss_history"][-_HISTORY_LIMIT:]
        metrics["val_loss_history"] = metrics["val_loss_history"][-_HISTORY_LIMIT:]

    # Linear extrapolation of the remaining time from the time spent so far.
    if progress > 0:
        now = datetime.now()
        elapsed = (now - datetime.fromisoformat(job["started_at"])).total_seconds()
        remaining = (elapsed / progress) * (100 - progress)
        job["estimated_completion"] = (now + timedelta(seconds=remaining)).isoformat()


def _mark_completed(job: dict) -> None:
    """Move ``job`` to COMPLETED and bring its counters to their final values."""
    finished_at = datetime.now().isoformat()
    job["status"] = TrainingStatus.COMPLETED.value
    job["progress"] = 100
    # The step loop's last update is at (total_steps - 1) / total_steps, i.e.
    # just short of 100%, so finalise the counters explicitly.
    job["current_epoch"] = job["total_epochs"]
    job["documents_processed"] = job["total_documents"]
    # Keeps the SSE stream from reporting a leftover remaining-time estimate.
    job["estimated_completion"] = finished_at
    job["completed_at"] = finished_at


def _register_model_version(job_id: str, job: dict) -> None:
    """Record a new model version in ``_state.model_versions`` for finished ``job``."""
    version_id = str(uuid.uuid4())
    _state.model_versions[version_id] = {
        "id": version_id,
        "job_id": job_id,
        # Evaluated before the insert: numbering is (existing versions + 1).
        "version": f"v{len(_state.model_versions) + 1}.0",
        "model_type": job["model_type"],
        "created_at": datetime.now().isoformat(),
        # Snapshot, so the version is not tied to the live job dict.
        "metrics": copy.deepcopy(job["metrics"]),
        # NOTE(review): previously registered versions are not deactivated
        # here — confirm whether more than one active version is intended.
        "is_active": True,
        "size_mb": 245.7,  # hard-coded size for the simulated model
        "bundeslaender": job["config"]["bundeslaender"],
    }


async def simulate_training_progress(job_id: str) -> None:
    """Simulate training progress (replace with actual training logic).

    Runs ``total_epochs * _STEPS_PER_EPOCH`` simulated steps, one every
    ``_UPDATE_INTERVAL_S`` seconds, updating progress, losses, metrics and the
    estimated completion time in ``_state.jobs[job_id]``. The run stops early
    if the job's status is changed away from TRAINING elsewhere (presumably by
    a cancel request).

    If the run finishes normally, the job is marked COMPLETED and a model
    version is added to ``_state.model_versions``. If the simulation raises,
    the job is marked FAILED and the exception propagates. If this task is
    cancelled while still training, the job is marked CANCELLED and the
    cancellation propagates. In all of these cases ``_state.active_job_id``
    is reset to ``None``.

    Unknown job ids are ignored.
    """
    if job_id not in _state.jobs:
        return

    job = _state.jobs[job_id]
    job["status"] = TrainingStatus.TRAINING.value
    job["started_at"] = datetime.now().isoformat()
    total_steps = job["total_epochs"] * _STEPS_PER_EPOCH

    try:
        for step in range(total_steps):
            # Re-checked every step so an external status change stops the run.
            if job["status"] != TrainingStatus.TRAINING.value:
                break
            _apply_simulated_step(job, step, total_steps)
            await asyncio.sleep(_UPDATE_INTERVAL_S)  # Simulate work

        if job["status"] == TrainingStatus.TRAINING.value:
            _mark_completed(job)
            _register_model_version(job_id, job)
    except asyncio.CancelledError:
        # Without this the job would stay TRAINING forever and any metrics
        # stream watching it would never reach a terminal status.
        if job["status"] == TrainingStatus.TRAINING.value:
            job["status"] = TrainingStatus.CANCELLED.value
        raise
    except Exception:
        job["status"] = TrainingStatus.FAILED.value
        raise
    finally:
        _state.active_job_id = None


def _build_metrics_payload(job: dict) -> dict:
    """Build the JSON-serialisable metrics snapshot streamed for ``job``."""
    now = datetime.now()

    elapsed_ms = 0
    if job["started_at"]:
        started = datetime.fromisoformat(job["started_at"])
        elapsed_ms = int((now - started).total_seconds() * 1000)

    remaining_ms = 0
    if job["estimated_completion"]:
        estimated = datetime.fromisoformat(job["estimated_completion"])
        remaining_ms = max(0, int((estimated - now).total_seconds() * 1000))

    loss_history = job["metrics"]["loss_history"][-_HISTORY_LIMIT:]
    val_loss_history = job["metrics"]["val_loss_history"]
    # NOTE(review): "epoch" and "step" below are positions within the retained
    # history window, not the job's real epoch/step numbers — confirm with the
    # consumer what it expects.
    history = [
        {
            "epoch": i + 1,
            "step": (i + 1) * 10,
            "loss": loss,
            "val_loss": val_loss_history[i] if i < len(val_loss_history) else None,
            "learning_rate": job["learning_rate"],
            "timestamp": 0,
        }
        for i, loss in enumerate(loss_history)
    ]

    return {
        "job_id": job["id"],
        "status": job["status"],
        "progress": job["progress"],
        "current_epoch": job["current_epoch"],
        "total_epochs": job["total_epochs"],
        # The simulation sets progress = step / (total_epochs * 100) * 100, so
        # progress * total_epochs recovers the step. round() rather than int():
        # float error (e.g. 0.29 * 100 == 28.999999999999996) would otherwise
        # truncate to the previous step.
        "current_step": round(job["progress"] * job["total_epochs"]),
        "total_steps": job["total_epochs"] * _STEPS_PER_EPOCH,
        "elapsed_time_ms": elapsed_ms,
        "estimated_remaining_ms": remaining_ms,
        "metrics": {
            "loss": job["loss"],
            "val_loss": job["val_loss"],
            "accuracy": job["metrics"]["accuracy"],
            "learning_rate": job["learning_rate"],
        },
        "history": history,
    }


async def training_metrics_generator(job_id: str, request) -> AsyncIterator[str]:
    """
    SSE generator for streaming training metrics.

    Every ``_UPDATE_INTERVAL_S`` seconds (500 ms) yields one ``data:`` event
    with the job's status, progress, timing and recent loss history. The
    stream ends when:

    * the client disconnects;
    * the job is unknown, after one ``{"error": "Job not found"}`` event;
    * right after the update that shows a terminal status
      (completed, failed or cancelled).

    ``request`` must provide an awaitable ``is_disconnected()`` — presumably a
    Starlette/FastAPI ``Request``.
    """
    while not await request.is_disconnected():
        if job_id not in _state.jobs:
            yield _sse_event({"error": "Job not found"})
            return

        job = _state.jobs[job_id]
        yield _sse_event(_build_metrics_payload(job))

        if job["status"] in _TERMINAL_STATUSES:
            return

        await asyncio.sleep(_UPDATE_INTERVAL_S)


async def batch_ocr_progress_generator(images_count: int, request) -> AsyncIterator[str]:
    """
    SSE generator for batch OCR progress simulation.

    In production, this would integrate with actual OCR processing.

    Yields one ``"progress"`` event per image, with randomised placeholder
    results, then a final ``"complete"`` event. If the client disconnects, the
    stream ends immediately without a completion event; nothing would receive
    it, and its ``processed_count`` would overstate the work done. Timing
    fields use a nominal ``_OCR_NOMINAL_MS_PER_IMAGE`` per image rather than
    measured time.
    """
    processed = 0
    for i in range(images_count):
        if await request.is_disconnected():
            return

        # Simulate processing time
        await asyncio.sleep(random.uniform(0.3, 0.8))
        processed = i + 1

        progress_data = {
            "type": "progress",
            "current": processed,
            "total": images_count,
            "progress_percent": (processed / images_count) * 100,
            "elapsed_ms": processed * _OCR_NOMINAL_MS_PER_IMAGE,
            "estimated_remaining_ms": (images_count - processed) * _OCR_NOMINAL_MS_PER_IMAGE,
            "result": {
                "text": f"Sample recognized text for image {processed}",
                "confidence": round(random.uniform(0.7, 0.98), 2),
                "processing_time_ms": random.randint(200, 600),
                "from_cache": random.random() < 0.2,
            },
        }
        yield _sse_event(progress_data)

    # Send completion event
    yield _sse_event({
        "type": "complete",
        "total_time_ms": processed * _OCR_NOMINAL_MS_PER_IMAGE,
        "processed_count": processed,
    })