""" Vast.ai Power Control API. Stellt Endpoints bereit fuer: - Start/Stop von vast.ai Instanzen - Status-Abfrage - Auto-Shutdown bei Inaktivitaet - Kosten-Tracking Sicherheit: Alle Endpoints erfordern CONTROL_API_KEY. """ import asyncio import json import logging import os import time from datetime import datetime, timezone from pathlib import Path from typing import Optional, Dict, Any, List from fastapi import APIRouter, Depends, HTTPException, Header, BackgroundTasks from pydantic import BaseModel, Field from .vast_client import VastAIClient, InstanceInfo, InstanceStatus, AccountInfo logger = logging.getLogger(__name__) router = APIRouter(prefix="/infra/vast", tags=["Infrastructure"]) # ------------------------- # Configuration (ENV) # ------------------------- VAST_API_KEY = os.getenv("VAST_API_KEY") VAST_INSTANCE_ID = os.getenv("VAST_INSTANCE_ID") # Numeric instance ID CONTROL_API_KEY = os.getenv("CONTROL_API_KEY") # Admin key for these endpoints # Health check configuration VAST_HEALTH_PORT = int(os.getenv("VAST_HEALTH_PORT", "8001")) VAST_HEALTH_PATH = os.getenv("VAST_HEALTH_PATH", "/health") VAST_WAIT_TIMEOUT_S = int(os.getenv("VAST_WAIT_TIMEOUT_S", "600")) # 10 min # Auto-shutdown configuration AUTO_SHUTDOWN_ENABLED = os.getenv("VAST_AUTO_SHUTDOWN", "true").lower() == "true" AUTO_SHUTDOWN_MINUTES = int(os.getenv("VAST_AUTO_SHUTDOWN_MINUTES", "30")) # State persistence (in /tmp for container compatibility) STATE_PATH = Path(os.getenv("VAST_STATE_PATH", "/tmp/vast_state.json")) AUDIT_PATH = Path(os.getenv("VAST_AUDIT_PATH", "/tmp/vast_audit.log")) # ------------------------- # State Management # ------------------------- class VastState: """ Persistenter State fuer vast.ai Kontrolle. Speichert: - Aktueller Endpunkt (weil IP sich aendern kann) - Letzte Aktivitaet (fuer Auto-Shutdown) - Kosten-Tracking """ def __init__(self, path: Path = STATE_PATH): self.path = path self._state: Dict[str, Any] = self._load() def _load(self) -> Dict[str, Any]: """Laedt State von Disk.""" if not self.path.exists(): return { "desired_state": None, "endpoint_base_url": None, "last_activity": None, "last_start": None, "last_stop": None, "total_runtime_seconds": 0, "total_cost_usd": 0.0, } try: return json.loads(self.path.read_text(encoding="utf-8")) except Exception: return {} def _save(self) -> None: """Speichert State auf Disk.""" self.path.parent.mkdir(parents=True, exist_ok=True) self.path.write_text( json.dumps(self._state, ensure_ascii=False, indent=2), encoding="utf-8", ) def get(self, key: str, default: Any = None) -> Any: return self._state.get(key, default) def set(self, key: str, value: Any) -> None: self._state[key] = value self._save() def update(self, data: Dict[str, Any]) -> None: self._state.update(data) self._save() def record_activity(self) -> None: """Zeichnet letzte Aktivitaet auf (fuer Auto-Shutdown).""" self._state["last_activity"] = datetime.now(timezone.utc).isoformat() self._save() def get_last_activity(self) -> Optional[datetime]: """Gibt letzte Aktivitaet als datetime.""" ts = self._state.get("last_activity") if ts: return datetime.fromisoformat(ts) return None def record_start(self) -> None: """Zeichnet Start-Zeit auf.""" self._state["last_start"] = datetime.now(timezone.utc).isoformat() self._state["desired_state"] = "RUNNING" self._save() def record_stop(self, dph_total: Optional[float] = None) -> None: """Zeichnet Stop-Zeit auf und berechnet Kosten.""" now = datetime.now(timezone.utc) self._state["last_stop"] = now.isoformat() self._state["desired_state"] = "STOPPED" # Berechne Runtime und Kosten last_start = self._state.get("last_start") if last_start: start_dt = datetime.fromisoformat(last_start) runtime_seconds = (now - start_dt).total_seconds() self._state["total_runtime_seconds"] = ( self._state.get("total_runtime_seconds", 0) + runtime_seconds ) if dph_total: hours = runtime_seconds / 3600 cost = hours * dph_total self._state["total_cost_usd"] = ( self._state.get("total_cost_usd", 0.0) + cost ) logger.info( f"Session cost: ${cost:.3f} ({runtime_seconds/60:.1f} min @ ${dph_total}/h)" ) self._save() # Global state instance _state = VastState() # ------------------------- # Audit Logging # ------------------------- def audit_log(event: str, actor: str = "system", meta: Optional[Dict[str, Any]] = None) -> None: """Schreibt Audit-Log Eintrag.""" meta = meta or {} line = json.dumps( { "ts": datetime.now(timezone.utc).isoformat(), "event": event, "actor": actor, "meta": meta, }, ensure_ascii=False, ) AUDIT_PATH.parent.mkdir(parents=True, exist_ok=True) with AUDIT_PATH.open("a", encoding="utf-8") as f: f.write(line + "\n") logger.info(f"AUDIT: {event} by {actor}") # ------------------------- # Request/Response Models # ------------------------- class PowerOnRequest(BaseModel): wait_for_health: bool = Field(default=True, description="Warten bis LLM bereit") health_path: str = Field(default=VAST_HEALTH_PATH) health_port: int = Field(default=VAST_HEALTH_PORT) class PowerOnResponse(BaseModel): status: str instance_id: Optional[int] = None endpoint_base_url: Optional[str] = None health_url: Optional[str] = None message: Optional[str] = None class PowerOffRequest(BaseModel): pass # Keine Parameter noetig class PowerOffResponse(BaseModel): status: str session_runtime_minutes: Optional[float] = None session_cost_usd: Optional[float] = None message: Optional[str] = None class VastStatusResponse(BaseModel): instance_id: Optional[int] = None status: str gpu_name: Optional[str] = None dph_total: Optional[float] = None endpoint_base_url: Optional[str] = None last_activity: Optional[str] = None auto_shutdown_in_minutes: Optional[int] = None total_runtime_hours: Optional[float] = None total_cost_usd: Optional[float] = None # Budget / Credit Informationen account_credit: Optional[float] = None # Verbleibendes Guthaben in USD account_total_spend: Optional[float] = None # Gesamtausgaben auf vast.ai # Session-Kosten (seit letztem Start) session_runtime_minutes: Optional[float] = None session_cost_usd: Optional[float] = None message: Optional[str] = None class CostStatsResponse(BaseModel): total_runtime_hours: float total_cost_usd: float sessions_count: int avg_session_minutes: float # ------------------------- # Security Dependency # ------------------------- def require_control_key(x_api_key: Optional[str] = Header(default=None)) -> None: """ Admin-Schutz fuer Control-Endpoints. Header: X-API-Key: """ if not CONTROL_API_KEY: raise HTTPException( status_code=500, detail="CONTROL_API_KEY not configured on server", ) if x_api_key != CONTROL_API_KEY: raise HTTPException(status_code=401, detail="Unauthorized") # ------------------------- # Auto-Shutdown Background Task # ------------------------- _shutdown_task: Optional[asyncio.Task] = None async def auto_shutdown_monitor() -> None: """ Hintergrund-Task der bei Inaktivitaet die Instanz stoppt. Laeuft permanent wenn Instanz an ist und prueft alle 60s ob Aktivitaet stattfand. Stoppt Instanz wenn keine Aktivitaet seit AUTO_SHUTDOWN_MINUTES. """ if not VAST_API_KEY or not VAST_INSTANCE_ID: return client = VastAIClient(VAST_API_KEY) try: while True: await asyncio.sleep(60) # Check every minute if not AUTO_SHUTDOWN_ENABLED: continue last_activity = _state.get_last_activity() if not last_activity: continue # Berechne Inaktivitaet now = datetime.now(timezone.utc) inactive_minutes = (now - last_activity).total_seconds() / 60 if inactive_minutes >= AUTO_SHUTDOWN_MINUTES: logger.info( f"Auto-shutdown triggered: {inactive_minutes:.1f} min inactive" ) audit_log( "auto_shutdown", actor="system", meta={"inactive_minutes": inactive_minutes}, ) # Hole aktuelle Instanz-Info fuer Kosten instance = await client.get_instance(int(VAST_INSTANCE_ID)) dph = instance.dph_total if instance else None # Stop await client.stop_instance(int(VAST_INSTANCE_ID)) _state.record_stop(dph_total=dph) audit_log("auto_shutdown_complete", actor="system") except asyncio.CancelledError: pass except Exception as e: logger.error(f"Auto-shutdown monitor error: {e}") finally: await client.close() def start_auto_shutdown_monitor() -> None: """Startet den Auto-Shutdown Monitor.""" global _shutdown_task if _shutdown_task is None or _shutdown_task.done(): _shutdown_task = asyncio.create_task(auto_shutdown_monitor()) logger.info("Auto-shutdown monitor started") def stop_auto_shutdown_monitor() -> None: """Stoppt den Auto-Shutdown Monitor.""" global _shutdown_task if _shutdown_task and not _shutdown_task.done(): _shutdown_task.cancel() logger.info("Auto-shutdown monitor stopped") # ------------------------- # API Endpoints # ------------------------- @router.get("/status", response_model=VastStatusResponse, dependencies=[Depends(require_control_key)]) async def get_status() -> VastStatusResponse: """ Gibt Status der vast.ai Instanz zurueck. Inkludiert: - Aktueller Status (running/stopped/etc) - GPU Info und Kosten pro Stunde - Endpoint URL - Auto-Shutdown Timer - Gesamtkosten - Account Credit (verbleibendes Budget) - Session-Kosten (seit letztem Start) """ if not VAST_API_KEY or not VAST_INSTANCE_ID: return VastStatusResponse( status="unconfigured", message="VAST_API_KEY or VAST_INSTANCE_ID not set", ) client = VastAIClient(VAST_API_KEY) try: instance = await client.get_instance(int(VAST_INSTANCE_ID)) if not instance: return VastStatusResponse( instance_id=int(VAST_INSTANCE_ID), status="not_found", message=f"Instance {VAST_INSTANCE_ID} not found", ) # Hole Account-Info fuer Budget/Credit account_info = await client.get_account_info() account_credit = account_info.credit if account_info else None account_total_spend = account_info.total_spend if account_info else None # Update endpoint if running endpoint = None if instance.status == InstanceStatus.RUNNING: endpoint = instance.get_endpoint_url(VAST_HEALTH_PORT) if endpoint: _state.set("endpoint_base_url", endpoint) # Calculate auto-shutdown timer auto_shutdown_minutes = None if AUTO_SHUTDOWN_ENABLED and instance.status == InstanceStatus.RUNNING: last_activity = _state.get_last_activity() if last_activity: inactive = (datetime.now(timezone.utc) - last_activity).total_seconds() / 60 auto_shutdown_minutes = max(0, int(AUTO_SHUTDOWN_MINUTES - inactive)) # Berechne aktuelle Session-Kosten (wenn Instanz laeuft) session_runtime_minutes = None session_cost_usd = None last_start = _state.get("last_start") # Falls Instanz laeuft aber kein last_start gesetzt (z.B. nach Container-Neustart), # nutze start_date aus der vast.ai API falls vorhanden, sonst jetzt if instance.status == InstanceStatus.RUNNING and not last_start: if instance.started_at: _state.set("last_start", instance.started_at.isoformat()) last_start = instance.started_at.isoformat() else: _state.record_start() last_start = _state.get("last_start") if last_start and instance.status == InstanceStatus.RUNNING: start_dt = datetime.fromisoformat(last_start) session_runtime_minutes = (datetime.now(timezone.utc) - start_dt).total_seconds() / 60 if instance.dph_total: session_cost_usd = (session_runtime_minutes / 60) * instance.dph_total return VastStatusResponse( instance_id=instance.id, status=instance.status.value, gpu_name=instance.gpu_name, dph_total=instance.dph_total, endpoint_base_url=endpoint or _state.get("endpoint_base_url"), last_activity=_state.get("last_activity"), auto_shutdown_in_minutes=auto_shutdown_minutes, total_runtime_hours=_state.get("total_runtime_seconds", 0) / 3600, total_cost_usd=_state.get("total_cost_usd", 0.0), account_credit=account_credit, account_total_spend=account_total_spend, session_runtime_minutes=session_runtime_minutes, session_cost_usd=session_cost_usd, ) finally: await client.close() @router.post("/power/on", response_model=PowerOnResponse, dependencies=[Depends(require_control_key)]) async def power_on( payload: PowerOnRequest, background_tasks: BackgroundTasks, ) -> PowerOnResponse: """ Startet die vast.ai Instanz. 1. Startet Instanz via API 2. Wartet auf Status RUNNING 3. Optional: Wartet auf Health-Endpoint 4. Startet Auto-Shutdown Monitor """ if not VAST_API_KEY or not VAST_INSTANCE_ID: raise HTTPException( status_code=500, detail="VAST_API_KEY or VAST_INSTANCE_ID not configured", ) instance_id = int(VAST_INSTANCE_ID) audit_log("power_on_requested", meta={"instance_id": instance_id}) client = VastAIClient(VAST_API_KEY) try: # Start instance success = await client.start_instance(instance_id) if not success: raise HTTPException(status_code=502, detail="Failed to start instance") _state.record_start() _state.record_activity() # Wait for running status instance = await client.wait_for_status( instance_id, InstanceStatus.RUNNING, timeout_seconds=300, ) if not instance: return PowerOnResponse( status="starting", instance_id=instance_id, message="Instance start requested but not yet running. Check status.", ) # Get endpoint endpoint = instance.get_endpoint_url(payload.health_port) if endpoint: _state.set("endpoint_base_url", endpoint) # Wait for health if requested if payload.wait_for_health: health_ok = await client.wait_for_health( instance, health_path=payload.health_path, internal_port=payload.health_port, timeout_seconds=VAST_WAIT_TIMEOUT_S, ) if not health_ok: audit_log("power_on_health_timeout", meta={"instance_id": instance_id}) return PowerOnResponse( status="running_unhealthy", instance_id=instance_id, endpoint_base_url=endpoint, message=f"Instance running but health check failed at {endpoint}{payload.health_path}", ) # Start auto-shutdown monitor start_auto_shutdown_monitor() audit_log("power_on_complete", meta={ "instance_id": instance_id, "endpoint": endpoint, }) return PowerOnResponse( status="running", instance_id=instance_id, endpoint_base_url=endpoint, health_url=f"{endpoint}{payload.health_path}" if endpoint else None, message="Instance running and healthy", ) finally: await client.close() @router.post("/power/off", response_model=PowerOffResponse, dependencies=[Depends(require_control_key)]) async def power_off(payload: PowerOffRequest) -> PowerOffResponse: """ Stoppt die vast.ai Instanz (behaelt Disk). Berechnet Session-Kosten und -Laufzeit. """ if not VAST_API_KEY or not VAST_INSTANCE_ID: raise HTTPException( status_code=500, detail="VAST_API_KEY or VAST_INSTANCE_ID not configured", ) instance_id = int(VAST_INSTANCE_ID) audit_log("power_off_requested", meta={"instance_id": instance_id}) # Stop auto-shutdown monitor stop_auto_shutdown_monitor() client = VastAIClient(VAST_API_KEY) try: # Get current info for cost calculation instance = await client.get_instance(instance_id) dph = instance.dph_total if instance else None # Calculate session stats before updating state session_runtime = 0.0 session_cost = 0.0 last_start = _state.get("last_start") if last_start: start_dt = datetime.fromisoformat(last_start) session_runtime = (datetime.now(timezone.utc) - start_dt).total_seconds() / 60 if dph: session_cost = (session_runtime / 60) * dph # Stop instance success = await client.stop_instance(instance_id) if not success: raise HTTPException(status_code=502, detail="Failed to stop instance") _state.record_stop(dph_total=dph) audit_log("power_off_complete", meta={ "instance_id": instance_id, "session_minutes": session_runtime, "session_cost": session_cost, }) return PowerOffResponse( status="stopped", session_runtime_minutes=session_runtime, session_cost_usd=session_cost, message=f"Instance stopped. Session: {session_runtime:.1f} min, ${session_cost:.3f}", ) finally: await client.close() @router.post("/activity", dependencies=[Depends(require_control_key)]) async def record_activity() -> Dict[str, str]: """ Zeichnet Aktivitaet auf (verzoegert Auto-Shutdown). Sollte von LLM Gateway aufgerufen werden bei jedem Request. """ _state.record_activity() return {"status": "recorded", "last_activity": _state.get("last_activity")} @router.get("/costs", response_model=CostStatsResponse, dependencies=[Depends(require_control_key)]) async def get_costs() -> CostStatsResponse: """ Gibt Kosten-Statistiken zurueck. """ total_seconds = _state.get("total_runtime_seconds", 0) total_cost = _state.get("total_cost_usd", 0.0) # TODO: Sessions count from audit log sessions = 1 if total_seconds > 0 else 0 avg_minutes = (total_seconds / 60 / sessions) if sessions > 0 else 0 return CostStatsResponse( total_runtime_hours=total_seconds / 3600, total_cost_usd=total_cost, sessions_count=sessions, avg_session_minutes=avg_minutes, ) @router.get("/audit", dependencies=[Depends(require_control_key)]) async def get_audit_log(limit: int = 50) -> List[Dict[str, Any]]: """ Gibt letzte Audit-Log Eintraege zurueck. """ if not AUDIT_PATH.exists(): return [] lines = AUDIT_PATH.read_text(encoding="utf-8").strip().split("\n") entries = [] for line in lines[-limit:]: try: entries.append(json.loads(line)) except json.JSONDecodeError: continue return list(reversed(entries)) # Neueste zuerst