diff --git a/control-pipeline/api/control_generator_routes.py b/control-pipeline/api/control_generator_routes.py index 13d420d..e1a0df4 100644 --- a/control-pipeline/api/control_generator_routes.py +++ b/control-pipeline/api/control_generator_routes.py @@ -2293,6 +2293,67 @@ async def get_batch_process_status(job_id: str): return status +class RunPass0aRequest(BaseModel): + limit: int = 0 # 0 = no limit + batch_size: int = 5 + use_anthropic: bool = True + category_filter: Optional[str] = None + source_filter: Optional[str] = None + + +_pass0a_status: dict = {} + + +async def _run_pass0a_background(req: RunPass0aRequest, job_id: str): + """Run Pass 0a in background with own DB session.""" + from services.decomposition_pass import DecompositionPass + db = SessionLocal() + try: + _pass0a_status[job_id] = {"status": "running"} + dp = DecompositionPass(db) + result = await dp.run_pass0a( + limit=req.limit, + batch_size=req.batch_size, + use_anthropic=req.use_anthropic, + category_filter=req.category_filter, + source_filter=req.source_filter, + ) + _pass0a_status[job_id] = {"status": "completed", **result} + logger.info("Pass 0a job %s completed: %s", job_id, result) + except Exception as e: + logger.error("Pass 0a job %s failed: %s", job_id, e) + _pass0a_status[job_id] = {"status": "failed", "error": str(e)} + finally: + db.close() + + +@router.post("/generate/run-pass0a") +async def run_pass0a(req: RunPass0aRequest): + """Run Pass 0a (Obligation Extraction) on undecomposed controls. + + Extracts individual normative obligations from rich controls using LLM. + Runs in background — poll status via GET /generate/pass0a-status/{job_id}. + """ + import uuid + job_id = str(uuid.uuid4())[:8] + _pass0a_status[job_id] = {"status": "starting"} + asyncio.create_task(_run_pass0a_background(req, job_id)) + return { + "status": "running", + "job_id": job_id, + "message": f"Pass 0a started. Poll /generate/pass0a-status/{job_id}", + } + + +@router.get("/generate/pass0a-status/{job_id}") +async def get_pass0a_status(job_id: str): + """Get status of a Pass 0a job.""" + status = _pass0a_status.get(job_id) + if not status: + raise HTTPException(status_code=404, detail="Pass 0a job not found") + return status + + class SubmitPass0bRequest(BaseModel): limit: int = 10 batch_size: int = 5