""" FastAPI routes for the Control Generator Pipeline. Endpoints: POST /v1/canonical/generate — Start generation run GET /v1/canonical/generate/status/{job_id} — Job status GET /v1/canonical/generate/jobs — All jobs GET /v1/canonical/generate/review-queue — Controls needing review POST /v1/canonical/generate/review/{control_id} — Complete review GET /v1/canonical/generate/processed-stats — Processing stats per collection GET /v1/canonical/blocked-sources — Blocked sources list POST /v1/canonical/blocked-sources/cleanup — Start cleanup workflow """ import json import logging from typing import Optional, List from fastapi import APIRouter, HTTPException, Query from pydantic import BaseModel from sqlalchemy import text from database import SessionLocal from compliance.services.control_generator import ( ControlGeneratorPipeline, GeneratorConfig, ALL_COLLECTIONS, ) from compliance.services.rag_client import get_rag_client logger = logging.getLogger(__name__) router = APIRouter(prefix="/v1/canonical", tags=["control-generator"]) # ============================================================================= # REQUEST / RESPONSE MODELS # ============================================================================= class GenerateRequest(BaseModel): domain: Optional[str] = None collections: Optional[List[str]] = None max_controls: int = 50 batch_size: int = 5 skip_web_search: bool = False dry_run: bool = False class GenerateResponse(BaseModel): job_id: str status: str message: str total_chunks_scanned: int = 0 controls_generated: int = 0 controls_verified: int = 0 controls_needs_review: int = 0 controls_too_close: int = 0 controls_duplicates_found: int = 0 errors: list = [] controls: list = [] class ReviewRequest(BaseModel): action: str # "approve", "reject", "needs_rework" release_state: Optional[str] = None # Override release_state notes: Optional[str] = None class ProcessedStats(BaseModel): collection: str total_chunks_estimated: int processed_chunks: int pending_chunks: int direct_adopted: int llm_reformed: int skipped: int class BlockedSourceResponse(BaseModel): id: str regulation_code: str document_title: str reason: str deletion_status: str qdrant_collection: Optional[str] = None marked_at: str # ============================================================================= # ENDPOINTS # ============================================================================= @router.post("/generate", response_model=GenerateResponse) async def start_generation(req: GenerateRequest): """Start a control generation run.""" config = GeneratorConfig( collections=req.collections, domain=req.domain, batch_size=req.batch_size, max_controls=req.max_controls, skip_web_search=req.skip_web_search, dry_run=req.dry_run, ) db = SessionLocal() try: pipeline = ControlGeneratorPipeline(db=db, rag_client=get_rag_client()) result = await pipeline.run(config) return GenerateResponse( job_id=result.job_id, status=result.status, message=f"Generated {result.controls_generated} controls from {result.total_chunks_scanned} chunks", total_chunks_scanned=result.total_chunks_scanned, controls_generated=result.controls_generated, controls_verified=result.controls_verified, controls_needs_review=result.controls_needs_review, controls_too_close=result.controls_too_close, controls_duplicates_found=result.controls_duplicates_found, errors=result.errors, controls=result.controls if req.dry_run else [], ) except Exception as e: logger.error("Generation failed: %s", e) raise HTTPException(status_code=500, detail=str(e)) finally: db.close() @router.get("/generate/status/{job_id}") async def get_job_status(job_id: str): """Get status of a generation job.""" db = SessionLocal() try: result = db.execute( text("SELECT * FROM canonical_generation_jobs WHERE id = :id::uuid"), {"id": job_id}, ) row = result.fetchone() if not row: raise HTTPException(status_code=404, detail="Job not found") cols = result.keys() job = dict(zip(cols, row)) # Serialize datetime fields for key in ("started_at", "completed_at", "created_at"): if job.get(key): job[key] = str(job[key]) job["id"] = str(job["id"]) return job finally: db.close() @router.get("/generate/jobs") async def list_jobs( limit: int = Query(20, ge=1, le=100), offset: int = Query(0, ge=0), ): """List all generation jobs.""" db = SessionLocal() try: result = db.execute( text(""" SELECT id, status, total_chunks_scanned, controls_generated, controls_verified, controls_needs_review, controls_too_close, controls_duplicates_found, created_at, completed_at FROM canonical_generation_jobs ORDER BY created_at DESC LIMIT :limit OFFSET :offset """), {"limit": limit, "offset": offset}, ) jobs = [] cols = result.keys() for row in result: job = dict(zip(cols, row)) job["id"] = str(job["id"]) for key in ("created_at", "completed_at"): if job.get(key): job[key] = str(job[key]) jobs.append(job) return {"jobs": jobs, "total": len(jobs)} finally: db.close() @router.get("/generate/review-queue") async def get_review_queue( release_state: str = Query("needs_review", pattern="^(needs_review|too_close|duplicate)$"), limit: int = Query(50, ge=1, le=200), ): """Get controls that need manual review.""" db = SessionLocal() try: result = db.execute( text(""" SELECT c.id, c.control_id, c.title, c.objective, c.severity, c.release_state, c.license_rule, c.customer_visible, c.generation_metadata, c.open_anchors, c.tags, c.created_at FROM canonical_controls c WHERE c.release_state = :state ORDER BY c.created_at DESC LIMIT :limit """), {"state": release_state, "limit": limit}, ) controls = [] cols = result.keys() for row in result: ctrl = dict(zip(cols, row)) ctrl["id"] = str(ctrl["id"]) ctrl["created_at"] = str(ctrl["created_at"]) # Parse JSON fields for jf in ("generation_metadata", "open_anchors", "tags"): if isinstance(ctrl.get(jf), str): try: ctrl[jf] = json.loads(ctrl[jf]) except (json.JSONDecodeError, TypeError): pass controls.append(ctrl) return {"controls": controls, "total": len(controls)} finally: db.close() @router.post("/generate/review/{control_id}") async def review_control(control_id: str, req: ReviewRequest): """Complete review of a generated control.""" db = SessionLocal() try: # Validate control exists and is in reviewable state result = db.execute( text("SELECT id, release_state FROM canonical_controls WHERE control_id = :cid"), {"cid": control_id}, ) row = result.fetchone() if not row: raise HTTPException(status_code=404, detail="Control not found") current_state = row[1] if current_state not in ("needs_review", "too_close", "duplicate"): raise HTTPException(status_code=400, detail=f"Control is in state '{current_state}', not reviewable") # Determine new state if req.action == "approve": new_state = req.release_state or "draft" elif req.action == "reject": new_state = "deprecated" elif req.action == "needs_rework": new_state = "needs_review" else: raise HTTPException(status_code=400, detail=f"Unknown action: {req.action}") if new_state not in ("draft", "review", "approved", "deprecated", "needs_review", "too_close", "duplicate"): raise HTTPException(status_code=400, detail=f"Invalid release_state: {new_state}") db.execute( text(""" UPDATE canonical_controls SET release_state = :state, updated_at = NOW() WHERE control_id = :cid """), {"state": new_state, "cid": control_id}, ) db.commit() return {"control_id": control_id, "release_state": new_state, "action": req.action} finally: db.close() @router.get("/generate/processed-stats") async def get_processed_stats(): """Get processing statistics per collection.""" db = SessionLocal() try: result = db.execute( text(""" SELECT collection, COUNT(*) as processed_chunks, COUNT(*) FILTER (WHERE processing_path = 'structured') as direct_adopted, COUNT(*) FILTER (WHERE processing_path = 'llm_reform') as llm_reformed, COUNT(*) FILTER (WHERE processing_path = 'skipped') as skipped FROM canonical_processed_chunks GROUP BY collection ORDER BY collection """) ) stats = [] cols = result.keys() for row in result: stat = dict(zip(cols, row)) stat["total_chunks_estimated"] = 0 # Would need Qdrant API to get total stat["pending_chunks"] = 0 stats.append(stat) return {"stats": stats} finally: db.close() # ============================================================================= # BLOCKED SOURCES # ============================================================================= @router.get("/blocked-sources") async def list_blocked_sources(): """List all blocked (Rule 3) sources.""" db = SessionLocal() try: result = db.execute( text(""" SELECT id, regulation_code, document_title, reason, deletion_status, qdrant_collection, marked_at FROM canonical_blocked_sources ORDER BY marked_at DESC """) ) sources = [] cols = result.keys() for row in result: src = dict(zip(cols, row)) src["id"] = str(src["id"]) src["marked_at"] = str(src["marked_at"]) sources.append(src) return {"sources": sources} finally: db.close() @router.post("/blocked-sources/cleanup") async def start_cleanup(): """Start cleanup workflow for blocked sources. This marks all pending blocked sources for deletion. Actual RAG chunk deletion and file removal is a separate manual step. """ db = SessionLocal() try: result = db.execute( text(""" UPDATE canonical_blocked_sources SET deletion_status = 'marked_for_deletion' WHERE deletion_status = 'pending' RETURNING regulation_code """) ) marked = [row[0] for row in result] db.commit() return { "status": "marked_for_deletion", "marked_count": len(marked), "regulation_codes": marked, "message": "Sources marked for deletion. Run manual cleanup to remove RAG chunks and files.", } finally: db.close() # ============================================================================= # CUSTOMER VIEW FILTER # ============================================================================= @router.get("/controls-customer") async def get_controls_customer_view( severity: Optional[str] = Query(None), domain: Optional[str] = Query(None), ): """Get controls filtered for customer visibility. Rule 3 controls have source_citation and source_original_text hidden. generation_metadata is NEVER shown to customers. """ db = SessionLocal() try: query = """ SELECT c.id, c.control_id, c.title, c.objective, c.rationale, c.scope, c.requirements, c.test_procedure, c.evidence, c.severity, c.risk_score, c.implementation_effort, c.open_anchors, c.release_state, c.tags, c.license_rule, c.customer_visible, c.source_original_text, c.source_citation, c.created_at, c.updated_at FROM canonical_controls c WHERE c.release_state IN ('draft', 'approved') """ params: dict = {} if severity: query += " AND c.severity = :severity" params["severity"] = severity if domain: query += " AND c.control_id LIKE :domain" params["domain"] = f"{domain.upper()}-%" query += " ORDER BY c.control_id" result = db.execute(text(query), params) controls = [] cols = result.keys() for row in result: ctrl = dict(zip(cols, row)) ctrl["id"] = str(ctrl["id"]) for key in ("created_at", "updated_at"): if ctrl.get(key): ctrl[key] = str(ctrl[key]) # Parse JSON fields for jf in ("scope", "requirements", "test_procedure", "evidence", "open_anchors", "tags", "source_citation"): if isinstance(ctrl.get(jf), str): try: ctrl[jf] = json.loads(ctrl[jf]) except (json.JSONDecodeError, TypeError): pass # Customer visibility rules: # - NEVER show generation_metadata # - Rule 3: NEVER show source_citation or source_original_text ctrl.pop("generation_metadata", None) if not ctrl.get("customer_visible", True): ctrl["source_citation"] = None ctrl["source_original_text"] = None controls.append(ctrl) return {"controls": controls, "total": len(controls)} finally: db.close()