"""
Enhanced Task Orchestrator - Multi-Agent Integration

Extends the existing TaskOrchestrator with Multi-Agent support:
- Session management with checkpoints
- Message bus integration for inter-agent communication
- Quality judge integration via BQAS
- Heartbeat-based liveness
"""

import asyncio
import os
import sys
from datetime import datetime, timezone
from typing import Any, Dict, Optional

import structlog

from services.task_orchestrator import TaskOrchestrator, Intent
from models.task import Task, TaskState, TaskType

# Import agent-core components.
# agent-core is imported from a source checkout (not an installed package), so
# its directory is prepended to sys.path before the imports below. Set the
# AGENT_CORE_PATH environment variable to use a different checkout; the default
# is the path this module previously hard-coded.
_AGENT_CORE_PATH = os.environ.get(
    "AGENT_CORE_PATH",
    "/Users/benjaminadmin/Projekte/breakpilot-pwa/agent-core",
)
sys.path.insert(0, _AGENT_CORE_PATH)

from sessions.session_manager import SessionManager, AgentSession, SessionState  # noqa: E402
from sessions.heartbeat import HeartbeatMonitor, HeartbeatClient  # noqa: E402
from brain.memory_store import MemoryStore  # noqa: E402
from brain.context_manager import ContextManager, MessageRole  # noqa: E402
from orchestrator.message_bus import MessageBus, AgentMessage, MessagePriority  # noqa: E402
from orchestrator.task_router import TaskRouter, RoutingStrategy  # noqa: E402

logger = structlog.get_logger(__name__)


class EnhancedTaskOrchestrator(TaskOrchestrator):
    """
    Enhanced TaskOrchestrator with Multi-Agent support.

    Extends the existing TaskOrchestrator to integrate with:
    - Session management for persistence and recovery
    - Message bus for inter-agent communication
    - Quality judge for response validation
    - Memory store for long-term learning
    """

    # Identity of this orchestrator on the message bus, in the session
    # manager (agent_type) and in the heartbeat monitor.
    AGENT_ID = "voice-orchestrator"
    # Receiver name of the quality judge agent on the message bus.
    QUALITY_JUDGE_ID = "quality-judge"
    # interval_seconds passed to each session's HeartbeatClient.
    HEARTBEAT_INTERVAL_SECONDS = 10
    # How long to wait for a specialized agent to answer a routed task.
    AGENT_REQUEST_TIMEOUT_SECONDS = 30.0
    # How long to wait for the quality judge to answer an evaluation request.
    QUALITY_CHECK_TIMEOUT_SECONDS = 10.0
    # Composite scores below this value mark the task for review
    # (task.error_message is set; the task state is left unchanged).
    QUALITY_THRESHOLD = 60
    # Passed as max_messages to ContextManager.create_context().
    MAX_CONTEXT_MESSAGES = 50

    # Task types that benefit from routing to a specialized agent.
    _SPECIALIZED_TASK_TYPES = frozenset({
        TaskType.PARENT_LETTER,     # Could use grader for tone
        TaskType.FEEDBACK_SUGGEST,  # Quality judge for appropriateness
    })

    # Task types that generate content and therefore get a quality check.
    _QUALITY_CHECKED_TASK_TYPES = frozenset({
        TaskType.PARENT_LETTER,
        TaskType.CLASS_MESSAGE,
        TaskType.FEEDBACK_SUGGEST,
        TaskType.WORKSHEET_GENERATE,
    })

    def __init__(
        self,
        redis_client=None,
        db_pool=None,
        namespace: str = "breakpilot"
    ):
        """
        Initialize the enhanced orchestrator.

        Args:
            redis_client: Async Redis/Valkey client
            db_pool: Async PostgreSQL connection pool
            namespace: Namespace for isolation
        """
        super().__init__()

        # Initialize agent-core components; all storage-backed components
        # share the same clients and namespace.
        self.session_manager = SessionManager(
            redis_client=redis_client,
            db_pool=db_pool,
            namespace=namespace
        )
        self.memory_store = MemoryStore(
            redis_client=redis_client,
            db_pool=db_pool,
            namespace=namespace
        )
        self.context_manager = ContextManager(
            redis_client=redis_client,
            db_pool=db_pool,
            namespace=namespace
        )
        self.message_bus = MessageBus(
            redis_client=redis_client,
            db_pool=db_pool,
            namespace=namespace
        )
        self.heartbeat = HeartbeatMonitor(
            timeout_seconds=30,
            check_interval_seconds=5,
            max_missed_beats=3
        )
        self.task_router = TaskRouter()

        # Active agent sessions keyed by voice session ID.
        self._voice_sessions: Dict[str, AgentSession] = {}
        # Running heartbeat clients keyed by agent session ID.
        self._heartbeat_clients: Dict[str, HeartbeatClient] = {}

        logger.info("Enhanced TaskOrchestrator initialized with agent-core")

    async def start(self) -> None:
        """Start the message bus and heartbeat monitoring, then subscribe to messages for this agent."""
        await self.message_bus.start()
        await self.heartbeat.start_monitoring()

        # Subscribe to messages directed at this orchestrator
        await self.message_bus.subscribe(
            self.AGENT_ID,
            self._handle_agent_message
        )

        logger.info("Enhanced TaskOrchestrator started")

    async def stop(self) -> None:
        """
        Stop the enhanced orchestrator.

        Stops every heartbeat client (best effort: a failing client is logged
        and does not prevent the rest of shutdown), then heartbeat monitoring
        and the message bus.
        """
        # Snapshot and clear first: client.stop() yields to the event loop,
        # and end_session() may mutate the dict while we are awaiting.
        clients = list(self._heartbeat_clients.items())
        self._heartbeat_clients.clear()
        for session_id, client in clients:
            try:
                await client.stop()
            except Exception:
                logger.exception(
                    "Failed to stop heartbeat client",
                    session_id=session_id[:8]
                )

        await self.heartbeat.stop_monitoring()
        await self.message_bus.stop()

        logger.info("Enhanced TaskOrchestrator stopped")

    # Heartbeat / session bookkeeping helpers

    async def _stop_heartbeat_client(self, session_id: str) -> None:
        """Stop and forget the heartbeat client for an agent session, if one is running."""
        client = self._heartbeat_clients.pop(session_id, None)
        if client is not None:
            await client.stop()

    async def _release_heartbeat(self, session_id: str) -> None:
        """Stop the session's heartbeat client and unregister the session from the monitor."""
        await self._stop_heartbeat_client(session_id)
        self.heartbeat.unregister(session_id)

    async def _start_heartbeat(self, session_id: str) -> None:
        """
        Start a heartbeat client for an agent session and register it for monitoring.

        Any client already running for the same session is stopped first so
        repeated create/recover calls do not leak background heartbeats.
        """
        await self._stop_heartbeat_client(session_id)

        heartbeat_client = HeartbeatClient(
            session_id=session_id,
            monitor=self.heartbeat,
            interval_seconds=self.HEARTBEAT_INTERVAL_SECONDS
        )
        await heartbeat_client.start()
        self.heartbeat.register(session_id, self.AGENT_ID)

        self._heartbeat_clients[session_id] = heartbeat_client

    async def _bind_voice_session(
        self,
        voice_session_id: str,
        session: AgentSession
    ) -> None:
        """
        Map a voice session ID to an agent session.

        If the voice session was already bound to a different agent session,
        that session's heartbeat is released so it does not keep running
        untracked.
        """
        previous = self._voice_sessions.get(voice_session_id)
        if previous is not None and previous.session_id != session.session_id:
            logger.warning(
                "Replacing agent session for voice session",
                voice_session_id=voice_session_id,
                previous_session_id=previous.session_id[:8],
                session_id=session.session_id[:8]
            )
            await self._release_heartbeat(previous.session_id)

        self._voice_sessions[voice_session_id] = session

    # Session lifecycle

    async def create_session(
        self,
        voice_session_id: str,
        user_id: str = "",
        metadata: Optional[Dict[str, Any]] = None
    ) -> AgentSession:
        """
        Create a new agent session for a voice session.

        Creates the session via the session manager, a conversation context
        with the assistant system prompt, and a heartbeat for liveness.

        Args:
            voice_session_id: The voice session ID
            user_id: Optional user ID
            metadata: Additional metadata

        Returns:
            The created AgentSession
        """
        # Create session via session manager
        session = await self.session_manager.create_session(
            agent_type=self.AGENT_ID,
            user_id=user_id,
            context={"voice_session_id": voice_session_id},
            metadata=metadata
        )

        # Create conversation context
        self.context_manager.create_context(
            session_id=session.session_id,
            system_prompt=self._get_system_prompt(),
            max_messages=self.MAX_CONTEXT_MESSAGES
        )

        # Start and register heartbeat for this session
        await self._start_heartbeat(session.session_id)

        # Store references
        await self._bind_voice_session(voice_session_id, session)

        logger.info(
            "Created agent session",
            session_id=session.session_id[:8],
            voice_session_id=voice_session_id
        )

        return session

    async def get_session(
        self,
        voice_session_id: str
    ) -> Optional[AgentSession]:
        """Return the agent session bound to a voice session, or None."""
        return self._voice_sessions.get(voice_session_id)

    async def end_session(self, voice_session_id: str) -> None:
        """
        End an agent session.

        Safe to call repeatedly or concurrently: unknown or already-ended
        voice sessions are ignored.

        Args:
            voice_session_id: The voice session ID
        """
        # Pop up front so concurrent/repeated calls become no-ops instead of
        # raising KeyError after the awaits below.
        session = self._voice_sessions.pop(voice_session_id, None)
        if session is None:
            return

        # Stop heartbeat and unregister from heartbeat monitor
        await self._release_heartbeat(session.session_id)

        # Mark session as completed and persist
        session.complete()
        await self.session_manager.update_session(session)

        logger.info(
            "Ended agent session",
            session_id=session.session_id[:8],
            duration_seconds=session.get_duration().total_seconds()
        )

    # Task processing

    async def queue_task(self, task: Task) -> None:
        """
        Queue a task with session checkpointing.

        Extends parent to add a "task_queued" checkpoint for recovery.
        """
        # NOTE(review): the lookup assumes task.session_id is the *voice*
        # session ID (the key of _voice_sessions) — confirm against Task.
        session = self._voice_sessions.get(task.session_id)

        if session:
            # Checkpoint before queueing
            session.checkpoint("task_queued", {
                "task_id": task.id,
                "task_type": task.type.value,
                "parameters": task.parameters
            })
            await self.session_manager.update_session(session)

        # Call parent implementation
        await super().queue_task(task)

    async def process_task(self, task: Task) -> None:
        """
        Process a task with enhanced routing and quality checks.

        Extends parent to:
        - Route complex tasks to specialized agents
        - Run quality checks via BQAS
        - Store results in memory for learning
        """
        session = self._voice_sessions.get(task.session_id)

        if session:
            session.checkpoint("task_processing", {
                "task_id": task.id
            })

        # Check if this task should be routed to a specialized agent
        if self._needs_specialized_agent(task):
            await self._route_to_agent(task, session)
        else:
            # Use parent implementation for simple tasks
            await super().process_task(task)

        # Run quality check on result
        if task.result_ref and self._needs_quality_check(task):
            await self._run_quality_check(task, session)

        # Store in memory for learning
        if task.state == TaskState.READY and task.result_ref:
            await self._store_task_result(task)

        if session:
            session.checkpoint("task_completed", {
                "task_id": task.id,
                "state": task.state.value
            })
            await self.session_manager.update_session(session)

    def _needs_specialized_agent(self, task: Task) -> bool:
        """Return True if the task type should be routed to a specialized agent."""
        return task.type in self._SPECIALIZED_TASK_TYPES

    def _needs_quality_check(self, task: Task) -> bool:
        """Return True if the task type generates content that must be quality-checked."""
        return task.type in self._QUALITY_CHECKED_TASK_TYPES

    async def _route_to_agent(
        self,
        task: Task,
        session: Optional[AgentSession]
    ) -> None:
        """
        Route a task to a specialized agent via the message bus.

        Falls back to the parent's local processing when no agent is
        available or the agent does not answer within the timeout.
        """
        # Determine target agent
        intent = f"task_{task.type.value}"
        routing_result = await self.task_router.route(
            intent=intent,
            context={"task": task.parameters},
            strategy=RoutingStrategy.LEAST_LOADED
        )

        if not routing_result.success:
            # Fall back to local processing
            logger.warning(
                "No agent available for task, using local processing",
                task_id=task.id[:8],
                reason=routing_result.reason
            )
            await super().process_task(task)
            return

        # Send to agent via message bus
        try:
            response = await self.message_bus.request(
                AgentMessage(
                    sender=self.AGENT_ID,
                    receiver=routing_result.agent_id,
                    message_type=f"process_{task.type.value}",
                    payload={
                        "task_id": task.id,
                        "task_type": task.type.value,
                        "parameters": task.parameters,
                        "session_id": session.session_id if session else None
                    },
                    priority=MessagePriority.NORMAL
                ),
                timeout=self.AGENT_REQUEST_TIMEOUT_SECONDS
            )
        except asyncio.TimeoutError:
            logger.error(
                "Agent timeout, falling back to local",
                task_id=task.id[:8],
                agent=routing_result.agent_id
            )
            await super().process_task(task)
            return

        task.result_ref = response.get("result", "")
        task.transition_to(TaskState.READY, "agent_processed")

    async def _run_quality_check(
        self,
        task: Task,
        session: Optional[AgentSession]
    ) -> None:
        """
        Run a quality check on the task result via the quality judge.

        A score below QUALITY_THRESHOLD sets task.error_message; a timeout is
        logged and otherwise ignored (non-fatal).
        """
        try:
            response = await self.message_bus.request(
                AgentMessage(
                    sender=self.AGENT_ID,
                    receiver=self.QUALITY_JUDGE_ID,
                    message_type="evaluate_response",
                    payload={
                        "task_id": task.id,
                        "task_type": task.type.value,
                        "response": task.result_ref,
                        "context": task.parameters
                    },
                    priority=MessagePriority.NORMAL
                ),
                timeout=self.QUALITY_CHECK_TIMEOUT_SECONDS
            )
        except asyncio.TimeoutError:
            # Quality check timeout is non-fatal
            logger.warning(
                "Quality check timeout",
                task_id=task.id[:8]
            )
            return

        quality_score = response.get("composite_score", 0)

        if quality_score < self.QUALITY_THRESHOLD:
            # Mark for review
            task.error_message = f"Quality check failed: {quality_score}"
            logger.warning(
                "Task failed quality check",
                task_id=task.id[:8],
                score=quality_score
            )

    async def _store_task_result(self, task: Task) -> None:
        """Store the task result in the memory store (30-day TTL) for learning."""
        await self.memory_store.remember(
            key=f"task:{task.type.value}:{task.id}",
            value={
                "result": task.result_ref,
                "parameters": task.parameters,
                # Timezone-aware UTC timestamp (datetime.utcnow() is deprecated).
                "completed_at": datetime.now(timezone.utc).isoformat()
            },
            agent_id=self.AGENT_ID,
            ttl_days=30
        )

    async def _handle_agent_message(
        self,
        message: AgentMessage
    ) -> Optional[Dict[str, Any]]:
        """
        Handle incoming messages from other agents.

        Currently only "task_status_update" is acted on: the referenced task
        (if known) is transitioned to the reported state. Unknown state values
        are logged and ignored instead of crashing the handler.
        """
        logger.debug(
            "Received agent message",
            sender=message.sender,
            type=message.message_type
        )

        if message.message_type == "task_status_update":
            # Handle task status updates
            task_id = message.payload.get("task_id")
            new_state = message.payload.get("state")
            if task_id in self._tasks and new_state:
                try:
                    state = TaskState(new_state)
                except ValueError:
                    logger.warning(
                        "Ignoring unknown task state from agent",
                        sender=message.sender,
                        task_id=str(task_id)[:8],
                        state=new_state
                    )
                    return None
                self._tasks[task_id].transition_to(state, "agent_update")

        return None

    def _get_system_prompt(self) -> str:
        """Return the system prompt for the voice assistant."""
        return """Du bist ein hilfreicher Assistent für Lehrer in der Breakpilot-App.

Deine Aufgaben:
- Hilf beim Erstellen von Arbeitsblättern
- Unterstütze bei der Korrektur
- Erstelle Elternbriefe und Klassennachrichten
- Dokumentiere Beobachtungen und Erinnerungen

Halte dich kurz und präzise. Nutze einfache, klare Sprache.
Bei Unklarheiten frage nach."""

    # Recovery methods

    async def recover_session(
        self,
        voice_session_id: str,
        session_id: str
    ) -> Optional[AgentSession]:
        """
        Recover a session from checkpoint.

        Only sessions whose state is SessionState.ACTIVE are recovered. The
        heartbeat is restarted (replacing any client already running for the
        session) and queued tasks recorded in checkpoints are re-processed.

        Args:
            voice_session_id: The voice session ID
            session_id: The agent session ID to recover

        Returns:
            The recovered session or None
        """
        session = await self.session_manager.get_session(session_id)

        if not session:
            logger.warning(
                "Session not found for recovery",
                session_id=session_id
            )
            return None

        if session.state != SessionState.ACTIVE:
            logger.warning(
                "Session not active for recovery",
                session_id=session_id,
                state=session.state.value
            )
            return None

        # Resume session
        session.resume()

        # Restore heartbeat
        await self._start_heartbeat(session.session_id)

        # Store references
        await self._bind_voice_session(voice_session_id, session)

        # Recover pending tasks from checkpoints
        await self._recover_pending_tasks(session)

        logger.info(
            "Recovered session",
            session_id=session.session_id[:8],
            checkpoints=len(session.checkpoints)
        )

        return session

    async def _recover_pending_tasks(self, session: AgentSession) -> None:
        """
        Re-process tasks that were queued but never started.

        Walks "task_queued" checkpoints newest-first and re-processes each
        known task that is still QUEUED, at most once per task ID.
        """
        # Iterate a snapshot: process_task() appends new checkpoints to
        # session.checkpoints while we are walking it.
        snapshot = list(session.checkpoints)
        seen_task_ids = set()

        for checkpoint in reversed(snapshot):
            if checkpoint.name != "task_queued":
                continue

            task_id = checkpoint.data.get("task_id")
            if not task_id or task_id in seen_task_ids:
                continue
            seen_task_ids.add(task_id)

            if task_id not in self._tasks:
                continue

            task = self._tasks[task_id]
            if task.state == TaskState.QUEUED:
                # Re-process queued task
                await self.process_task(task)
                logger.info(
                    "Recovered pending task",
                    task_id=task_id[:8]
                )