221 lines
6.8 KiB
Python
221 lines
6.8 KiB
Python
"""
|
|
Session Management API

Handles the voice session lifecycle.

Endpoints:
- POST   /api/v1/sessions              # Create a session
- GET    /api/v1/sessions/{id}         # Session status
- DELETE /api/v1/sessions/{id}         # Close a session
- GET    /api/v1/sessions/{id}/tasks   # Session tasks (optional ?state= filter)
- GET    /api/v1/sessions/{id}/stats   # Session statistics (debugging/monitoring)
|
|
"""
|
|
import structlog
|
|
from fastapi import APIRouter, HTTPException, Request, Depends
|
|
from typing import List, Optional
|
|
from datetime import datetime, timedelta
|
|
|
|
from config import settings
|
|
from models.session import (
|
|
VoiceSession,
|
|
SessionCreate,
|
|
SessionResponse,
|
|
SessionStatus,
|
|
)
|
|
from models.task import TaskResponse, TaskState
|
|
|
|
# Module-level structured logger.
logger = structlog.get_logger(__name__)

# Router holding the session endpoints defined below. The /api/v1/sessions
# prefix is presumably applied where this router is mounted (not visible here).
router = APIRouter()

# In-memory session store keyed by session ID (will be replaced with Valkey in
# production). This is transient - sessions are never persisted to disk.
# Within this module, entries are added by create_session and removed only by
# close_session; sessions that expire are marked CLOSED but remain in the dict.
_sessions: dict[str, VoiceSession] = {}
|
|
|
|
|
|
async def get_session(session_id: str) -> VoiceSession:
    """
    Look up a session in the in-memory store.

    Args:
        session_id: ID of the session to fetch.

    Returns:
        The stored ``VoiceSession``.

    Raises:
        HTTPException: 404 if no session with this ID is stored.
    """
    session = _sessions.get(session_id)
    # Test for absence explicitly: a truthiness check would also report a
    # stored session as missing if the model ever evaluates as falsy
    # (e.g. by defining __len__ or __bool__).
    if session is None:
        raise HTTPException(status_code=404, detail="Session not found")
    return session
|
|
|
|
|
|
@router.post("", response_model=SessionResponse)
async def create_session(request: Request, session_data: SessionCreate):
    """
    Create a new voice session.

    Verifies the key hash (when encryption is enabled), enforces the
    per-namespace concurrent-session limit, stores the session in the
    in-memory ``_sessions`` dict, and returns the session ID together with
    the WebSocket URL for audio streaming.

    The client is expected to connect to the WebSocket within 30 seconds.
    NOTE(review): that timeout is not enforced in this handler - confirm
    where it is implemented.

    Raises:
        HTTPException: 401 if encryption is enabled and the key hash is
            rejected; 429 if the namespace already has
            ``settings.max_sessions_per_user`` live sessions.
    """
    logger.info(
        "Creating voice session",
        namespace_id=session_data.namespace_id[:8] + "...",
        device_type=session_data.device_type,
    )

    # Verify the namespace key hash. The encryption service is only needed
    # (and only looked up) when encryption is enabled.
    if settings.encryption_enabled:
        encryption = request.app.state.encryption
        if not encryption.verify_key_hash(session_data.key_hash):
            logger.warning("Invalid key hash", namespace_id=session_data.namespace_id[:8])
            raise HTTPException(status_code=401, detail="Invalid encryption key hash")

    # Rate limit: count this namespace's live sessions. Sessions older than
    # the TTL are excluded too. Otherwise, clients that disconnect without
    # calling DELETE would keep counting against the limit indefinitely.
    # Expiry uses the same rule as get_session_status (age > TTL).
    now = datetime.utcnow()
    ttl = timedelta(hours=settings.session_ttl_hours)
    namespace_sessions = [
        s for s in _sessions.values()
        if s.namespace_id == session_data.namespace_id
        and s.status not in (SessionStatus.CLOSED, SessionStatus.ERROR)
        and now - s.created_at <= ttl
    ]
    if len(namespace_sessions) >= settings.max_sessions_per_user:
        raise HTTPException(
            status_code=429,
            detail=f"Maximum {settings.max_sessions_per_user} concurrent sessions allowed"
        )

    # Create the session.
    session = VoiceSession(
        namespace_id=session_data.namespace_id,
        key_hash=session_data.key_hash,
        device_type=session_data.device_type,
        client_version=session_data.client_version,
    )

    # Store the session (in RAM only).
    _sessions[session.id] = session

    logger.info(
        "Voice session created",
        session_id=session.id[:8],
        namespace_id=session_data.namespace_id[:8],
    )

    # Build the WebSocket URL. Behind a reverse proxy (nginx), the original
    # scheme arrives in X-Forwarded-Proto. With chained proxies the header can
    # be a comma-separated list whose first entry is, by convention, the
    # client-facing scheme, so normalize it before comparing.
    # NOTE(review): the Host header is client-controlled; consider a
    # configured public host if this URL must be trustworthy.
    forwarded_proto = request.headers.get("x-forwarded-proto", request.url.scheme)
    client_proto = forwarded_proto.split(",")[0].strip().lower()
    host = request.headers.get("host", f"localhost:{settings.port}")
    ws_scheme = "wss" if client_proto == "https" else "ws"
    ws_url = f"{ws_scheme}://{host}/ws/voice?session_id={session.id}"

    return SessionResponse(
        id=session.id,
        namespace_id=session.namespace_id,
        status=session.status,
        created_at=session.created_at,
        websocket_url=ws_url,
    )
|
|
|
|
|
|
@router.get("/{session_id}", response_model=SessionResponse)
async def get_session_status(session_id: str, request: Request):
    """
    Get session status.

    Returns the session's ID, namespace ID, status, creation time and
    WebSocket URL. As a side effect, a session older than
    ``settings.session_ttl_hours`` is marked CLOSED. It stays in the
    in-memory store.

    Raises:
        HTTPException: 404 if the session does not exist.
    """
    session = await get_session(session_id)

    # Lazily expire the session. Transition (and log) only once, so repeated
    # polling of an already-closed session does not log on every request.
    session_age = datetime.utcnow() - session.created_at
    if (
        session_age > timedelta(hours=settings.session_ttl_hours)
        and session.status != SessionStatus.CLOSED
    ):
        session.status = SessionStatus.CLOSED
        logger.info("Session expired", session_id=session_id[:8])

    # Build the WebSocket URL. Behind a reverse proxy (nginx), the original
    # scheme arrives in X-Forwarded-Proto. With chained proxies the header can
    # be a comma-separated list whose first entry is, by convention, the
    # client-facing scheme, so normalize it before comparing.
    forwarded_proto = request.headers.get("x-forwarded-proto", request.url.scheme)
    client_proto = forwarded_proto.split(",")[0].strip().lower()
    host = request.headers.get("host", f"localhost:{settings.port}")
    ws_scheme = "wss" if client_proto == "https" else "ws"
    ws_url = f"{ws_scheme}://{host}/ws/voice?session_id={session.id}"

    return SessionResponse(
        id=session.id,
        namespace_id=session.namespace_id,
        status=session.status,
        created_at=session.created_at,
        websocket_url=ws_url,
    )
|
|
|
|
|
|
@router.delete("/{session_id}")
async def close_session(session_id: str):
    """
    Close a session and remove it from the in-memory store.

    The session is marked CLOSED and its entry is deleted from ``_sessions``,
    so its messages and audio state are no longer reachable through this
    module. This is the expected cleanup path.

    Raises:
        HTTPException: 404 if the session does not exist.
    """
    doomed = await get_session(session_id)
    short_id = session_id[:8]

    logger.info(
        "Closing session",
        session_id=short_id,
        messages_count=len(doomed.messages),
        tasks_count=len(doomed.pending_tasks),
    )

    # Flip the status before dropping the entry, so any other reference to
    # this object (if one exists) observes the CLOSED state.
    doomed.status = SessionStatus.CLOSED
    del _sessions[session_id]

    return dict(status="closed", session_id=session_id)
|
|
|
|
|
|
@router.get("/{session_id}/tasks", response_model=List[TaskResponse])
async def get_session_tasks(session_id: str, request: Request, state: Optional[TaskState] = None):
    """
    List the tasks that belong to a session.

    Args:
        session_id: Session whose tasks are listed; unknown IDs yield 404.
        request: Not used by this handler; part of the endpoint signature.
        state: When given, only tasks currently in this state are returned.

    Returns:
        One ``TaskResponse`` per matching task. ``result_available`` is True
        when the task has a ``result_ref``.
    """
    # Existence check only: raises 404 for unknown sessions.
    await get_session(session_id)

    # Deferred import of the task module's in-memory store
    # (presumably to avoid a circular import with api.tasks - verify).
    from api.tasks import _tasks

    def _is_selected(task) -> bool:
        # Keep only this session's tasks, optionally narrowed to one state.
        if task.session_id != session_id:
            return False
        return state is None or task.state == state

    def _as_response(task) -> TaskResponse:
        # Project the stored task onto the public response model.
        return TaskResponse(
            id=task.id,
            session_id=task.session_id,
            type=task.type,
            state=task.state,
            created_at=task.created_at,
            updated_at=task.updated_at,
            result_available=task.result_ref is not None,
            error_message=task.error_message,
        )

    selected = list(filter(_is_selected, _tasks.values()))
    return list(map(_as_response, selected))
|
|
|
|
|
|
@router.get("/{session_id}/stats")
async def get_session_stats(session_id: str):
    """
    Return session statistics (for debugging/monitoring).

    The session ID is echoed back truncated to 8 characters. The remaining
    fields are status, age, device type and counters; no message contents
    are returned.

    Raises:
        HTTPException: 404 if the session does not exist.
    """
    current = await get_session(session_id)
    age = datetime.utcnow() - current.created_at

    # Insertion order defines the key order of the JSON response.
    stats = {}
    stats["session_id_truncated"] = session_id[:8]
    stats["status"] = current.status.value
    stats["age_seconds"] = age.total_seconds()
    stats["message_count"] = len(current.messages)
    stats["pending_tasks_count"] = len(current.pending_tasks)
    stats["audio_chunks_received"] = current.audio_chunks_received
    stats["audio_chunks_processed"] = current.audio_chunks_processed
    stats["device_type"] = current.device_type
    return stats
|