"""SSE-Endpoint für den Agent-Test-Harness. User-Vorgabe 2026-06-08: pro Agent isoliert testen mit z.B. 5 URLs gleichzeitig. Live-Stream der Events ins Frontend. Endpoints: GET /specialist-agent/agents POST /specialist-agent/test/start { agent_id, urls } GET /specialist-agent/test/stream/{run_id} → SSE-Stream GET /specialist-agent/run/{run_id}/artifacts GET /specialist-agent/run/{run_id}/artifact/{relpath} """ from __future__ import annotations import asyncio import html as html_lib import json import logging import uuid from collections.abc import AsyncGenerator from typing import Any from fastapi import APIRouter, HTTPException from fastapi.responses import FileResponse, StreamingResponse from pydantic import BaseModel, Field from compliance.api.agent_check._fetch import _fetch_text as full_fetch_text from compliance.services.specialist_agents import REGISTRY, AgentInput from compliance.services.specialist_agents._evidence_vault import ( EvidenceVault, delete_run as vault_delete_run, list_runs as vault_list_runs, ) logger = logging.getLogger(__name__) router = APIRouter(prefix="/specialist-agent", tags=["specialist-agent"]) # In-memory event-queues pro run_id. Restart-fragil aber für ein # Live-Test-Tool ausreichend (keine Persistenz nötig). 
_run_queues: dict[str, asyncio.Queue] = {}
_run_states: dict[str, dict[str, Any]] = {}
# NOTE(review): _run_states is only cleaned up by DELETE /run/{run_id} or a
# restart, so it grows with every run; consider a TTL/LRU if runs are many.

# Strong references to fire-and-forget orchestrator tasks. The event loop
# keeps only weak references to tasks, so an otherwise unreferenced task
# may be garbage-collected before it finishes.
_background_tasks: set[asyncio.Task] = set()

# Max buffered events per run; when the queue is full, _emit drops the oldest.
_QUEUE_MAXSIZE = 500
# Seconds of silence after which the SSE stream sends a heartbeat.
_HEARTBEAT_INTERVAL_S = 30.0
# Seconds a run's queue is kept after its stream closes or the run ends.
_QUEUE_RETENTION_S = 300
# Max characters of an exception message forwarded to the client.
_ERROR_MSG_MAX_CHARS = 160
# Event types after which a run emits nothing further.
_TERMINAL_EVENT_TYPES = ("run_complete", "run_error")


class TestStartRequest(BaseModel):
    """Request body for POST /test/start.

    Every entry of ``urls`` and every entry of ``raw_texts`` becomes its own
    slot, so a run has ``len(urls) + len(raw_texts)`` slots.
    """

    agent_id: str
    # URLs fetched through the compliance-check fetch pipeline (max 10).
    urls: list[str] = Field(default_factory=list, max_length=10)
    # Already-extracted document texts that skip the fetch step (max 10).
    raw_texts: list[str] = Field(default_factory=list, max_length=10)
    # The following fields are forwarded unchanged into AgentInput.
    business_scope: list[str] = Field(default_factory=list)
    company_name: str = ""
    origin_domain: str = ""


class TestStartResponse(BaseModel):
    """Response of POST /test/start: the new run_id and its slot count."""

    run_id: str
    agent_id: str
    slot_count: int


@router.get("/agents")
async def list_agents() -> dict[str, Any]:
    """Return the registered specialist agents."""
    return {"agents": REGISTRY.list_agents()}


@router.post("/test/start", response_model=TestStartResponse)
async def start_test(req: TestStartRequest) -> TestStartResponse:
    """Start a multi-slot test run against one agent and return its run_id.

    The run executes in the background. The frontend then opens the SSE
    stream ``/test/stream/{run_id}`` to receive its events.

    Raises:
        HTTPException(404): ``agent_id`` is not registered.
        HTTPException(400): neither ``urls`` nor ``raw_texts`` were given.
    """
    agent = REGISTRY.get(req.agent_id)
    if agent is None:
        raise HTTPException(404, f"agent '{req.agent_id}' nicht registriert")
    # URL slots and raw-text slots are processed independently (see
    # _run_test_orchestrator), so the slot count is their sum.
    slots = len(req.urls) + len(req.raw_texts)
    if slots == 0:
        raise HTTPException(400, "urls oder raw_texts dürfen nicht leer sein")

    run_id = uuid.uuid4().hex[:16]
    # Build the vault before registering in-memory state, so a failing vault
    # constructor does not leave an orphaned queue/state behind.
    vault = EvidenceVault(agent.agent_id, agent.agent_version, run_id=run_id)
    _run_queues[run_id] = asyncio.Queue(maxsize=_QUEUE_MAXSIZE)
    _run_states[run_id] = {
        "agent_id": req.agent_id,
        "started": False,
        "finished": False,
        "slot_count": slots,
        "results": {},
    }

    task = asyncio.create_task(_run_test_orchestrator(run_id, req, vault))
    _background_tasks.add(task)
    task.add_done_callback(_on_background_task_done)
    return TestStartResponse(
        run_id=run_id,
        agent_id=req.agent_id,
        slot_count=slots,
    )


def _on_background_task_done(task: asyncio.Task) -> None:
    """Drop the strong task reference and log an unexpected crash."""
    _background_tasks.discard(task)
    if task.cancelled():
        return
    exc = task.exception()
    if exc is not None:
        logger.error("specialist-agent orchestrator crashed", exc_info=exc)


@router.get("/test/stream/{run_id}")
async def stream_test(run_id: str) -> StreamingResponse:
    """Open the SSE event stream of a running test.

    Raises:
        HTTPException(404): the run_id has no (or no longer an) event queue.
    """
    if run_id not in _run_queues:
        raise HTTPException(404, "run_id unbekannt")
    return StreamingResponse(
        _event_generator(run_id),
        media_type="text/event-stream",
        headers={
            "Cache-Control": "no-cache",
            "X-Accel-Buffering": "no",  # disable nginx response buffering
            "Connection": "keep-alive",
        },
    )


async def _event_generator(run_id: str) -> AsyncGenerator[str, None]:
    """Yield SSE frames from the run's queue until the run ends.

    Sends ``hello`` first, a ``heartbeat`` after every quiet interval, and
    ``stream_close`` after a terminal event. It also closes once the run is
    finished or has been deleted.
    """
    queue = _run_queues.get(run_id)
    yield _format_sse({"type": "hello", "run_id": run_id})
    if queue is None:
        # Queue released or run deleted between stream_test and here.
        yield _format_sse({"type": "stream_close"})
        return
    try:
        while True:
            try:
                event = await asyncio.wait_for(
                    queue.get(), timeout=_HEARTBEAT_INTERVAL_S,
                )
            except asyncio.TimeoutError:
                yield _format_sse({"type": "heartbeat"})
                state = _run_states.get(run_id)
                # A deleted run (state gone) will never emit again either.
                if state is None or state.get("finished"):
                    yield _format_sse({"type": "stream_close"})
                    return
                continue
            yield _format_sse(event)
            if event.get("type") in _TERMINAL_EVENT_TYPES:
                yield _format_sse({"type": "stream_close"})
                return
    finally:
        # Release the queue after a grace period. The run state stays in
        # _run_states so /run/{run_id}/result keeps working.
        asyncio.get_running_loop().call_later(
            _QUEUE_RETENTION_S, _run_queues.pop, run_id, None,
        )


def _format_sse(payload: dict) -> str:
    """Serialize ``payload`` as one SSE ``data:`` frame (non-JSON via str())."""
    return f"data: {json.dumps(payload, default=str)}\n\n"


async def _emit(run_id: str, event: dict) -> None:
    """Queue ``event`` for the run's SSE stream without ever blocking.

    Does nothing if the run has no queue (released or deleted). If the queue
    is full because nobody is draining it, the oldest event is dropped. That
    way the run never stalls on an absent client, and the terminal event
    still lands.
    """
    q = _run_queues.get(run_id)
    if q is None:
        return
    try:
        q.put_nowait(event)
    except asyncio.QueueFull:
        try:
            q.get_nowait()
        except asyncio.QueueEmpty:
            pass
        q.put_nowait(event)
        logger.debug("event queue full run_id=%s, dropped oldest event", run_id)


def _update_state(run_id: str, **fields: Any) -> None:
    """Update the run's state dict if it still exists (it may be deleted)."""
    state = _run_states.get(run_id)
    if state is not None:
        state.update(fields)


def _short_error(exc: BaseException) -> str:
    """Format an exception as ``"<Type>: <message>"`` with a capped message."""
    return f"{type(exc).__name__}: {str(exc)[:_ERROR_MSG_MAX_CHARS]}"


async def _run_test_orchestrator(
    run_id: str, req: TestStartRequest, vault: EvidenceVault,
) -> None:
    """Run the agent on every URL / raw-text slot concurrently.

    Emits ``run_started``, the per-slot events from _process_slot, and finally
    exactly one terminal event (``run_complete`` or ``run_error``).
    """
    agent = REGISTRY.get(req.agent_id)
    if agent is None:
        # Mark finished so stream heartbeats stop waiting for this run.
        _update_state(run_id, finished=True)
        await _emit(run_id, {"type": "run_error", "error": "agent gone"})
        return

    _update_state(run_id, started=True)
    await _emit(run_id, {
        "type": "run_started",
        "agent_id": agent.agent_id,
        "agent_version": agent.agent_version,
        "slot_count": len(req.urls) + len(req.raw_texts),
    })

    # URLs first, then raw texts. Slots are named url1, url2, …, text1, …
    slot_specs: list[tuple[str, str, str]] = [
        (f"url{i}", url, "") for i, url in enumerate(req.urls, start=1)
    ] + [
        (f"text{j}", "", raw) for j, raw in enumerate(req.raw_texts, start=1)
    ]
    slot_jobs = [
        asyncio.create_task(
            _process_slot(run_id, slot, agent, url, raw, req, vault),
        )
        for slot, url, raw in slot_specs
    ]
    try:
        outcomes = await asyncio.gather(*slot_jobs, return_exceptions=True)
        # _process_slot handles agent errors itself; anything surfacing here
        # (e.g. a vault write failure) would otherwise vanish silently.
        for (slot, _url, _raw), outcome in zip(slot_specs, outcomes):
            if isinstance(outcome, BaseException):
                logger.error(
                    "slot crashed run_id=%s slot=%s", run_id, slot,
                    exc_info=outcome,
                )
                await _emit(run_id, {
                    "type": "slot_agent_error",
                    "slot": slot,
                    "error": _short_error(outcome),
                })
    finally:
        await _finish_run(run_id, vault)


async def _finish_run(run_id: str, vault: EvidenceVault) -> None:
    """Finalize the vault, mark the run finished and emit its terminal event.

    A failing finalize becomes a ``run_error`` event instead of leaving the
    run unfinished with no terminal event.
    """
    try:
        manifest = vault.finalize()
        final_event: dict[str, Any] = {
            "type": "run_complete",
            "vault_url": vault.url(),
            "manifest_asset_count": len(manifest.get("assets") or []),
        }
    except Exception as exc:
        logger.exception("vault finalize failed run_id=%s", run_id)
        final_event = {
            "type": "run_error",
            "error": f"vault finalize failed: {_short_error(exc)}",
        }
    _update_state(run_id, finished=True)
    await _emit(run_id, final_event)
    # Release the queue even if no client ever opened the stream.
    asyncio.get_running_loop().call_later(
        _QUEUE_RETENTION_S, _run_queues.pop, run_id, None,
    )


async def _process_slot(
    run_id: str,
    slot: str,
    agent,
    url: str,
    raw_text: str,
    req: TestStartRequest,
    vault: EvidenceVault,
) -> None:
    """Obtain the slot's text (fetched URL or raw), run the agent, store output.

    The raw text, CMP payloads and agent output are written to the vault. The
    output is also kept in ``_run_states`` and streamed as events.
    """
    label = url or f"text-slot-{slot}"
    await _emit(run_id, {"type": "slot_started", "slot": slot, "label": label})

    text = raw_text
    fetch_err = ""
    cmp_payloads: list[dict] = []
    if url and not raw_text:
        text, cmp_payloads, fetch_err = await _fetch_slot_text(
            run_id, slot, agent, url,
        )

    if text:
        # Fetched text can contain HTML entities (e.g. &nbsp;, &amp;, &#...;)
        # as literal strings, which would trip the agent's regex patterns.
        # Decode before the vault dump so the stored source stays readable.
        text = html_lib.unescape(text)
        vault.put_bytes("raw", slot, "source.txt", text.encode("utf-8"),
                        mime="text/plain")
    if cmp_payloads:
        vault.put_json("raw", slot, "cmp_payloads.json", cmp_payloads)
    await _emit(run_id, {
        "type": "slot_text_ready",
        "slot": slot,
        "char_count": len(text),
        "word_count": len(text.split()) if text else 0,
        "cmp_payloads": len(cmp_payloads),
    })

    agent_input = AgentInput(
        doc_type=agent.doc_type,
        text=text,
        url=url,
        business_scope=req.business_scope,
        company_name=req.company_name,
        origin_domain=req.origin_domain,
        context={"cmp_payloads": cmp_payloads} if cmp_payloads else {},
    )
    await _emit(run_id, {"type": "slot_agent_running", "slot": slot})
    try:
        output = await agent.evaluate(agent_input)
    except Exception as e:
        logger.exception("agent crashed slot=%s", slot)
        await _emit(run_id, {
            "type": "slot_agent_error",
            "slot": slot,
            "error": _short_error(e),
        })
        return

    # If the fetch failed, surface its error in the output notes.
    if fetch_err and not text:
        output.notes = (
            (output.notes + " · " if output.notes else "")
            + f"fetch_error: {fetch_err}"
        )

    # Serialize once; parse twice so vault and state get independent dicts.
    output_json = output.model_dump_json()
    vault.put_json("finding", slot, "output.json", json.loads(output_json))
    state = _run_states.get(run_id)
    if state is not None:  # the run may have been deleted meanwhile
        state["results"][slot] = json.loads(output_json)

    await _stream_slot_output(run_id, slot, output)


async def _fetch_slot_text(
    run_id: str, slot: str, agent, url: str,
) -> tuple[str, list[dict], str]:
    """Fetch ``url`` with the full compliance-check fetch pipeline.

    Per the original author, ``full_fetch_text`` (_fetch.py) does Playwright
    discovery with an HTTP fallback, multi-page merge and CMP capture.

    Returns:
        ``(text, cmp_payloads, fetch_err)``. ``fetch_err`` is empty on
        success. Any fetch error is also emitted as ``slot_fetch_error``.
    """
    await _emit(run_id, {"type": "slot_fetching", "slot": slot, "url": url,
                         "doc_type": agent.doc_type})
    text = ""
    cmp_payloads: list[dict] = []
    fetch_err = ""
    try:
        fetched_text, fetched_payloads = await full_fetch_text(
            url, doc_type=agent.doc_type,
        )
    except Exception as e:
        fetch_err = _short_error(e)
    else:
        # Normalize None results so later len()/split() calls are safe.
        text = fetched_text or ""
        cmp_payloads = fetched_payloads or []

    if not text and not fetch_err:
        fetch_err = (
            "Fetch lieferte 0 Zeichen — Site möglicherweise "
            "Cloudflare-/Anti-Bot-geschützt oder JS-only-Rendering. "
            "Tipp: Text manuell ins raw_text-Feld einfügen."
        )
    if fetch_err:
        await _emit(run_id, {
            "type": "slot_fetch_error",
            "slot": slot,
            "error": fetch_err,
        })
    return text, cmp_payloads, fetch_err


async def _stream_slot_output(run_id: str, slot: str, output) -> None:
    """Emit one event per finding and escalation, then ``slot_complete``."""
    for f in output.findings:
        await _emit(run_id, {
            "type": "finding",
            "slot": slot,
            "check_id": f.check_id,
            "severity": f.severity,
            "title": f.title,
            "field_id": f.field_id,
        })
    for esc in output.escalation_log:
        await _emit(run_id, {
            "type": "escalation",
            "slot": slot,
            "stage": esc.stage,
            "model": esc.model,
            "success": esc.success,
            "duration_ms": esc.duration_ms,
        })
    await _emit(run_id, {
        "type": "slot_complete",
        "slot": slot,
        "duration_ms": output.duration_ms,
        "mc_total": output.mc_total,
        "mc_ok": output.mc_ok,
        "mc_na": output.mc_na,
        "mc_high": output.mc_high,
        "mc_medium": output.mc_medium,
        "mc_low": output.mc_low,
        "findings_count": len(output.findings),
        "recommendations_count": len(output.recommendations),
        "confidence": output.confidence,
    })


# ── Run / Vault Queries ──────────────────────────────────────────────

def _require_safe_run_id(run_id: str) -> None:
    """Reject run_ids that could address a path outside the vault.

    Raises:
        HTTPException(400): empty, dot-leading (e.g. ``..``), or containing
        a path separator / NUL byte.
    """
    if (
        not run_id
        or run_id.startswith(".")
        or any(ch in run_id for ch in "/\\\x00")
    ):
        raise HTTPException(400, "ungültige run_id")


def _require_safe_relpath(path: str) -> None:
    """Reject artifact paths that are absolute or climb out via ``..``.

    NOTE(review): EvidenceVault.asset_path may already guard this. The check
    is defence in depth against untrusted URL input.

    Raises:
        HTTPException(404): the path is not a safe relative path.
    """
    normalized = path.replace("\\", "/")
    if (
        not normalized
        or normalized.startswith("/")
        or "\x00" in normalized
        or ".." in normalized.split("/")
    ):
        raise HTTPException(404, "asset not found")


@router.get("/run/{run_id}/result")
async def get_run_result(run_id: str) -> dict[str, Any]:
    """Return a run's collected results (e.g. for a frontend refresh)."""
    state = _run_states.get(run_id)
    if state is None:
        raise HTTPException(404, "run unbekannt")
    return {
        "run_id": run_id,
        "agent_id": state["agent_id"],
        "finished": state["finished"],
        "results": state["results"],
        "vault_url": f"/api/v1/specialist-agent/run/{run_id}/artifacts",
    }


@router.get("/run/{run_id}/artifacts")
async def list_run_artifacts(run_id: str) -> dict[str, Any]:
    """List the assets of a run from its vault manifest."""
    _require_safe_run_id(run_id)
    # NOTE(review): reads the private `_manifest` attribute; presumably the
    # constructor loads an existing manifest for run_id — confirm it does not
    # create a new run directory for unknown ids.
    vault = EvidenceVault("?", "?", run_id=run_id)
    return {
        "run_id": run_id,
        "manifest": vault._manifest,
    }


@router.get("/run/{run_id}/artifact/{path:path}")
async def get_run_artifact(run_id: str, path: str):
    """Serve a single artifact file from the run's vault."""
    _require_safe_run_id(run_id)
    _require_safe_relpath(path)
    vault = EvidenceVault("?", "?", run_id=run_id)
    p = vault.asset_path(path)
    if p is None:
        raise HTTPException(404, "asset not found")
    return FileResponse(str(p))


@router.delete("/run/{run_id}")
async def delete_run(run_id: str) -> dict[str, bool]:
    """GDPR Art. 17 (erasure): delete the whole run including its vault.

    NOTE(review): a still-running orchestrator is not cancelled here; it
    keeps running but no longer finds the run's queue/state.
    """
    _require_safe_run_id(run_id)
    deleted_vault = vault_delete_run(run_id)
    _run_queues.pop(run_id, None)
    _run_states.pop(run_id, None)
    return {"deleted": deleted_vault}


@router.get("/runs")
async def list_runs(limit: int = 20) -> dict[str, Any]:
    """List the most recent runs stored in the vault."""
    return {"runs": vault_list_runs(limit)}