"""
Middleware Admin API Endpoints for BreakPilot

Provides admin functionality for managing middleware configurations:
- View and update middleware settings
- Rate limiting IP whitelist/blacklist
- View middleware events/statistics
- Inspect SDK protection anomaly scores and tiers
"""

import asyncio
import contextlib
import ipaddress
import json
import logging
import os
from datetime import datetime, timedelta
from typing import Any, Dict, List, Optional

from fastapi import APIRouter, Depends, HTTPException, Query
from pydantic import BaseModel, Field

# Database connection
import asyncpg

# Session middleware for authentication
from session import require_permission, Session

logger = logging.getLogger(__name__)

router = APIRouter(prefix="/api/admin/middleware", tags=["middleware-admin"])

# Database URL
DATABASE_URL = os.getenv(
    "DATABASE_URL",
    "postgresql://breakpilot:breakpilot@localhost:5432/breakpilot_dev",
)

# Middlewares reported by GET /stats, in display order.
_KNOWN_MIDDLEWARES = (
    "request_id",
    "security_headers",
    "cors",
    "rate_limiter",
    "pii_redactor",
    "input_gate",
    "sdk_protection",
)

# Lazy database pool (created on first use by get_db_pool()).
_db_pool: Optional[asyncpg.Pool] = None
# Serializes pool creation; created lazily so it belongs to the running loop.
_db_pool_lock: Optional[asyncio.Lock] = None


async def _init_connection(conn: asyncpg.Connection) -> None:
    """Register JSON codecs on each new pool connection.

    Without these, asyncpg returns json/jsonb values as ``str`` and rejects
    ``dict`` query arguments. This module passes dicts for event ``details``
    and middleware ``config``, and expects decoded objects back for
    ``config``, ``details`` and ``triggered_rules``.
    NOTE(review): assumes those columns are json/jsonb (the reset endpoint
    casts ``'[]'::jsonb``) — confirm against the schema.
    """
    for type_name in ("json", "jsonb"):
        await conn.set_type_codec(
            type_name,
            encoder=json.dumps,
            decoder=json.loads,
            schema="pg_catalog",
        )


async def get_db_pool() -> asyncpg.Pool:
    """Get or create the shared database connection pool.

    Creation is guarded by a lock. Otherwise, concurrent first requests could
    each await ``create_pool`` and all but one of the pools would be leaked.
    """
    global _db_pool, _db_pool_lock
    if _db_pool is not None:
        return _db_pool
    if _db_pool_lock is None:
        # No await between the check and the assignment, so this is safe.
        _db_pool_lock = asyncio.Lock()
    async with _db_pool_lock:
        if _db_pool is None:
            _db_pool = await asyncpg.create_pool(
                DATABASE_URL,
                min_size=2,
                max_size=10,
                init=_init_connection,
            )
    return _db_pool


# ==============================================
# Request/Response Models
# ==============================================

class MiddlewareConfigResponse(BaseModel):
    """Response model for middleware configuration."""
    id: str
    middleware_name: str
    enabled: bool
    config: Dict[str, Any]
    updated_at: Optional[datetime] = None


class MiddlewareConfigUpdateRequest(BaseModel):
    """Request model for updating middleware configuration."""
    enabled: Optional[bool] = None
    config: Optional[Dict[str, Any]] = None


class RateLimitIPRequest(BaseModel):
    """Request model for adding IP to whitelist/blacklist."""
    ip_address: str
    list_type: str = Field(..., pattern="^(whitelist|blacklist)$")
    reason: Optional[str] = None
    expires_at: Optional[datetime] = None


class RateLimitIPResponse(BaseModel):
    """Response model for rate limit IP entry."""
    id: str
    ip_address: str
    list_type: str
    reason: Optional[str] = None
    expires_at: Optional[datetime] = None
    created_at: datetime


class MiddlewareEventResponse(BaseModel):
    """Response model for middleware event."""
    id: str
    middleware_name: str
    event_type: str
    ip_address: Optional[str] = None
    user_id: Optional[str] = None
    request_path: Optional[str] = None
    request_method: Optional[str] = None
    details: Optional[Dict[str, Any]] = None
    created_at: datetime


class MiddlewareStatsResponse(BaseModel):
    """Response model for middleware statistics."""
    middleware_name: str
    total_events: int
    events_last_hour: int
    events_last_24h: int
    top_event_types: List[Dict[str, Any]]
    top_ips: List[Dict[str, Any]]


class SDKAnomalyScoreResponse(BaseModel):
    """Response model for SDK anomaly score."""
    id: str
    user_id: str
    score: float
    throttle_level: int
    triggered_rules: List[str]
    endpoint_diversity_count: int
    request_count_1h: int
    snapshot_at: datetime


class SDKProtectionStatsResponse(BaseModel):
    """Response model for SDK protection statistics."""
    total_users_tracked: int
    users_level_0: int
    users_level_1: int
    users_level_2: int
    users_level_3: int
    avg_score: float
    max_score: float


class SDKProtectionTierResponse(BaseModel):
    """Response model for SDK protection tier."""
    tier_name: str
    quota_per_minute: int
    quota_per_hour: int
    quota_per_day: int
    quota_per_month: int
    diversity_threshold: int
    burst_threshold: int


class SDKProtectionTierUpdateRequest(BaseModel):
    """Request model for updating SDK protection tier."""
    quota_per_minute: Optional[int] = None
    quota_per_hour: Optional[int] = None
    quota_per_day: Optional[int] = None
    quota_per_month: Optional[int] = None
    diversity_threshold: Optional[int] = None
    burst_threshold: Optional[int] = None


# ==============================================
# Internal helpers
# ==============================================

def _config_from_row(row: asyncpg.Record) -> MiddlewareConfigResponse:
    """Convert a middleware_config row into its response model."""
    return MiddlewareConfigResponse(
        id=str(row["id"]),
        middleware_name=row["middleware_name"],
        enabled=row["enabled"],
        config=row["config"] or {},
        updated_at=row["updated_at"],
    )


def _ip_entry_from_row(row: asyncpg.Record) -> RateLimitIPResponse:
    """Convert a rate_limit_ip_list row into its response model."""
    return RateLimitIPResponse(
        id=str(row["id"]),
        ip_address=row["ip_address"],
        list_type=row["list_type"],
        reason=row["reason"],
        expires_at=row["expires_at"],
        created_at=row["created_at"],
    )


def _tier_from_row(row: asyncpg.Record) -> SDKProtectionTierResponse:
    """Convert an sdk_protection_tiers row into its response model."""
    return SDKProtectionTierResponse(
        tier_name=row["tier_name"],
        quota_per_minute=row["quota_per_minute"],
        quota_per_hour=row["quota_per_hour"],
        quota_per_day=row["quota_per_day"],
        quota_per_month=row["quota_per_month"],
        diversity_threshold=row["diversity_threshold"],
        burst_threshold=row["burst_threshold"],
    )


async def _log_event(
    conn: asyncpg.Connection,
    middleware_name: str,
    event_type: str,
    *,
    user_id: Any = None,
    ip_address: Optional[str] = None,
    details: Optional[Dict[str, Any]] = None,
) -> None:
    """Insert one audit row into middleware_events on the given connection."""
    await conn.execute("""
        INSERT INTO middleware_events
            (middleware_name, event_type, ip_address, user_id, details)
        VALUES ($1, $2, $3::inet, $4, $5)
    """, middleware_name, event_type, ip_address, user_id, details)


def _validate_ip(ip_address: str) -> None:
    """Raise HTTP 400 unless *ip_address* is an address or address/prefix.

    ``ipaddress.ip_interface`` accepts the same forms asyncpg's inet encoder
    parses, so bad input is rejected here instead of surfacing as a 500.
    """
    try:
        ipaddress.ip_interface(ip_address)
    except ValueError:
        raise HTTPException(
            status_code=400, detail=f"Invalid IP address: {ip_address}"
        ) from None


async def _collect_stats(
    pool: asyncpg.Pool, name: str, top_n: int
) -> MiddlewareStatsResponse:
    """Count events for *name* and list its top event types/IPs of the last 24h."""
    counts = await pool.fetchrow("""
        SELECT
            COUNT(*) as total,
            COUNT(*) FILTER (WHERE created_at > NOW() - INTERVAL '1 hour') as last_hour,
            COUNT(*) FILTER (WHERE created_at > NOW() - INTERVAL '24 hours') as last_24h
        FROM middleware_events
        WHERE middleware_name = $1
    """, name)

    top_events = await pool.fetch("""
        SELECT event_type, COUNT(*) as count
        FROM middleware_events
        WHERE middleware_name = $1
          AND created_at > NOW() - INTERVAL '24 hours'
        GROUP BY event_type
        ORDER BY count DESC
        LIMIT $2
    """, name, top_n)

    top_ips = await pool.fetch("""
        SELECT ip_address::text, COUNT(*) as count
        FROM middleware_events
        WHERE middleware_name = $1
          AND ip_address IS NOT NULL
          AND created_at > NOW() - INTERVAL '24 hours'
        GROUP BY ip_address
        ORDER BY count DESC
        LIMIT $2
    """, name, top_n)

    return MiddlewareStatsResponse(
        middleware_name=name,
        total_events=counts["total"] or 0,
        events_last_hour=counts["last_hour"] or 0,
        events_last_24h=counts["last_24h"] or 0,
        top_event_types=[
            {"event_type": r["event_type"], "count": r["count"]}
            for r in top_events
        ],
        top_ips=[
            {"ip_address": r["ip_address"], "count": r["count"]}
            for r in top_ips
        ],
    )


async def _clear_valkey_score(user_id: str) -> None:
    """Best-effort removal of the cached anomaly score for *user_id* in Valkey.

    Failures (missing redis package, connection errors) are logged and not
    raised. The client is always closed, even when ``delete`` fails.
    """
    client = None
    try:
        import redis.asyncio as redis_asyncio

        valkey_url = os.getenv("VALKEY_URL", "redis://localhost:6379")
        client = redis_asyncio.from_url(
            valkey_url, decode_responses=True, socket_timeout=1.0
        )
        await client.delete(f"sdk_protect:score:{user_id}")
    except Exception:
        logger.warning(
            "Could not clear Valkey anomaly score for user %s",
            user_id,
            exc_info=True,
        )
    finally:
        if client is not None:
            with contextlib.suppress(Exception):
                await client.aclose()


# ==============================================
# Middleware Configuration Endpoints
# ==============================================
# NOTE: the parameterized GET/PUT "/{name}" routes are registered at the very
# end of this module. Routes match in registration order, so registering them
# first would make "/{name}" swallow static paths such as "/events" and
# "/stats".

@router.get("", response_model=List[MiddlewareConfigResponse])
async def list_middleware_configs(
    session: Session = Depends(require_permission("settings:read")),
):
    """
    List all middleware configurations.

    Requires: settings:read permission
    """
    pool = await get_db_pool()
    rows = await pool.fetch("""
        SELECT id, middleware_name, enabled, config, updated_at
        FROM middleware_config
        ORDER BY middleware_name
    """)
    return [_config_from_row(row) for row in rows]


# ==============================================
# Rate Limiting IP Management
# ==============================================

@router.get("/rate-limit/ip-list", response_model=List[RateLimitIPResponse])
async def list_rate_limit_ips(
    list_type: Optional[str] = Query(None, pattern="^(whitelist|blacklist)$"),
    session: Session = Depends(require_permission("settings:read")),
):
    """
    List all IPs in whitelist/blacklist.

    Requires: settings:read permission
    """
    pool = await get_db_pool()
    if list_type:
        rows = await pool.fetch("""
            SELECT id, ip_address::text, list_type, reason, expires_at, created_at
            FROM rate_limit_ip_list
            WHERE list_type = $1
            ORDER BY created_at DESC
        """, list_type)
    else:
        rows = await pool.fetch("""
            SELECT id, ip_address::text, list_type, reason, expires_at, created_at
            FROM rate_limit_ip_list
            ORDER BY list_type, created_at DESC
        """)
    return [_ip_entry_from_row(row) for row in rows]


@router.post("/rate-limit/ip-list", response_model=RateLimitIPResponse, status_code=201)
async def add_rate_limit_ip(
    data: RateLimitIPRequest,
    session: Session = Depends(require_permission("settings:write")),
):
    """
    Add IP to whitelist or blacklist.

    Requires: settings:write permission

    Raises 400 for a malformed IP and 409 if the IP is already on that list.
    The entry and its audit event are written in one transaction.
    """
    _validate_ip(data.ip_address)

    pool = await get_db_pool()
    async with pool.acquire() as conn:
        async with conn.transaction():
            try:
                row = await conn.fetchrow("""
                    INSERT INTO rate_limit_ip_list (ip_address, list_type, reason, expires_at, created_by)
                    VALUES ($1::inet, $2, $3, $4, $5)
                    RETURNING id, ip_address::text, list_type, reason, expires_at, created_at
                """, data.ip_address, data.list_type, data.reason,
                    data.expires_at, session.user_id)
            except asyncpg.UniqueViolationError:
                raise HTTPException(
                    status_code=409,
                    detail=f"IP {data.ip_address} already exists in {data.list_type}",
                ) from None

            await _log_event(
                conn,
                "rate_limiter",
                f"ip_{data.list_type}_add",
                user_id=session.user_id,
                ip_address=data.ip_address,
                details={"reason": data.reason},
            )

    return _ip_entry_from_row(row)


@router.delete("/rate-limit/ip-list/{ip_id}")
async def remove_rate_limit_ip(
    ip_id: str,
    session: Session = Depends(require_permission("settings:write")),
):
    """
    Remove IP from whitelist/blacklist.

    Requires: settings:write permission

    Uses DELETE ... RETURNING so the removed entry is read and deleted
    atomically. The audit event shares the same transaction.
    """
    pool = await get_db_pool()
    async with pool.acquire() as conn:
        async with conn.transaction():
            row = await conn.fetchrow("""
                DELETE FROM rate_limit_ip_list
                WHERE id = $1
                RETURNING ip_address::text, list_type
            """, ip_id)
            if not row:
                raise HTTPException(status_code=404, detail="IP entry not found")

            await _log_event(
                conn,
                "rate_limiter",
                f"ip_{row['list_type']}_remove",
                user_id=session.user_id,
                ip_address=row["ip_address"],
                details={},
            )

    return {"message": "IP removed successfully"}


# ==============================================
# Middleware Events & Statistics
# ==============================================

@router.get("/events", response_model=List[MiddlewareEventResponse])
async def list_middleware_events(
    middleware_name: Optional[str] = None,
    event_type: Optional[str] = None,
    limit: int = Query(100, ge=1, le=1000),
    offset: int = Query(0, ge=0),
    session: Session = Depends(require_permission("audit:read")),
):
    """
    List middleware events (rate limit triggers, config changes, etc.).

    Requires: audit:read permission

    ``limit``/``offset`` are bounded so negative values are rejected with 422
    instead of reaching Postgres (which would fail with a 500).
    """
    pool = await get_db_pool()

    # Column names are fixed here; only placeholders are interpolated, and all
    # values travel as query parameters.
    conditions: List[str] = []
    params: List[Any] = []
    if middleware_name:
        params.append(middleware_name)
        conditions.append(f"middleware_name = ${len(params)}")
    if event_type:
        params.append(event_type)
        conditions.append(f"event_type = ${len(params)}")

    where_clause = f"WHERE {' AND '.join(conditions)}" if conditions else ""
    limit_idx = len(params) + 1
    params.extend([limit, offset])

    query = f"""
        SELECT id, middleware_name, event_type, ip_address::text, user_id::text,
               request_path, request_method, details, created_at
        FROM middleware_events
        {where_clause}
        ORDER BY created_at DESC
        LIMIT ${limit_idx} OFFSET ${limit_idx + 1}
    """

    rows = await pool.fetch(query, *params)

    return [
        MiddlewareEventResponse(
            id=str(row["id"]),
            middleware_name=row["middleware_name"],
            event_type=row["event_type"],
            ip_address=row["ip_address"],
            user_id=row["user_id"],
            request_path=row["request_path"],
            request_method=row["request_method"],
            details=row["details"],
            created_at=row["created_at"],
        )
        for row in rows
    ]


@router.get("/stats", response_model=List[MiddlewareStatsResponse])
async def get_middleware_stats(
    session: Session = Depends(require_permission("settings:read")),
):
    """
    Get statistics for all middlewares (top 5 event types/IPs each).

    Requires: settings:read permission
    """
    pool = await get_db_pool()
    stats: List[MiddlewareStatsResponse] = []
    for mw in _KNOWN_MIDDLEWARES:
        stats.append(await _collect_stats(pool, mw, top_n=5))
    return stats


@router.get("/stats/{name}", response_model=MiddlewareStatsResponse)
async def get_middleware_stats_by_name(
    name: str,
    session: Session = Depends(require_permission("settings:read")),
):
    """
    Get statistics for a specific middleware (top 10 event types/IPs).

    Requires: settings:read permission
    """
    pool = await get_db_pool()
    return await _collect_stats(pool, name, top_n=10)


# ==============================================
# SDK Protection Endpoints
# ==============================================

@router.get("/sdk-protection/scores", response_model=List[SDKAnomalyScoreResponse])
async def list_sdk_anomaly_scores(
    min_score: float = Query(0, ge=0),
    limit: int = Query(50, ge=1, le=500),
    session: Session = Depends(require_permission("settings:read")),
):
    """
    List current SDK anomaly scores (latest snapshot per user).

    Requires: settings:read permission

    The latest snapshot per user is selected first and only then filtered by
    ``min_score``. Filtering first would surface an older high snapshot for a
    user whose score has since dropped or been reset. Results are ordered by
    score (highest first), so ``limit`` keeps the most anomalous users.
    """
    pool = await get_db_pool()
    rows = await pool.fetch("""
        WITH latest AS (
            SELECT DISTINCT ON (user_id)
                id, user_id, score, throttle_level, triggered_rules,
                endpoint_diversity_count, request_count_1h, snapshot_at
            FROM sdk_anomaly_scores
            ORDER BY user_id, snapshot_at DESC
        )
        SELECT id, user_id, score, throttle_level, triggered_rules,
               endpoint_diversity_count, request_count_1h, snapshot_at
        FROM latest
        WHERE score >= $1
        ORDER BY score DESC, user_id
        LIMIT $2
    """, min_score, limit)

    return [
        SDKAnomalyScoreResponse(
            id=str(row["id"]),
            # str() guards against a uuid-typed column; no-op for text.
            user_id=str(row["user_id"]),
            score=float(row["score"]),
            throttle_level=row["throttle_level"],
            triggered_rules=row["triggered_rules"] or [],
            endpoint_diversity_count=row["endpoint_diversity_count"] or 0,
            request_count_1h=row["request_count_1h"] or 0,
            snapshot_at=row["snapshot_at"],
        )
        for row in rows
    ]


@router.get("/sdk-protection/stats", response_model=SDKProtectionStatsResponse)
async def get_sdk_protection_stats(
    session: Session = Depends(require_permission("settings:read")),
):
    """
    Get SDK protection statistics (users per throttle level).

    Only users with a snapshot in the last 24 hours are counted, using each
    user's most recent snapshot.

    Requires: settings:read permission
    """
    pool = await get_db_pool()
    row = await pool.fetchrow("""
        WITH latest_scores AS (
            SELECT DISTINCT ON (user_id) user_id, score, throttle_level
            FROM sdk_anomaly_scores
            WHERE snapshot_at > NOW() - INTERVAL '24 hours'
            ORDER BY user_id, snapshot_at DESC
        )
        SELECT
            COUNT(*) as total,
            COUNT(*) FILTER (WHERE throttle_level = 0) as level_0,
            COUNT(*) FILTER (WHERE throttle_level = 1) as level_1,
            COUNT(*) FILTER (WHERE throttle_level = 2) as level_2,
            COUNT(*) FILTER (WHERE throttle_level = 3) as level_3,
            COALESCE(AVG(score), 0) as avg_score,
            COALESCE(MAX(score), 0) as max_score
        FROM latest_scores
    """)

    return SDKProtectionStatsResponse(
        total_users_tracked=row["total"] or 0,
        users_level_0=row["level_0"] or 0,
        users_level_1=row["level_1"] or 0,
        users_level_2=row["level_2"] or 0,
        users_level_3=row["level_3"] or 0,
        avg_score=float(row["avg_score"]),
        max_score=float(row["max_score"]),
    )


@router.post("/sdk-protection/reset-score/{user_id}")
async def reset_sdk_anomaly_score(
    user_id: str,
    session: Session = Depends(require_permission("settings:write")),
):
    """
    Reset anomaly score for a specific user.

    Requires: settings:write permission

    Inserts a zero-score snapshot and its audit event in one transaction,
    then makes a best-effort attempt to drop the cached score from Valkey.
    """
    pool = await get_db_pool()
    async with pool.acquire() as conn:
        async with conn.transaction():
            await conn.execute("""
                INSERT INTO sdk_anomaly_scores
                    (user_id, score, throttle_level, triggered_rules,
                     endpoint_diversity_count, request_count_1h)
                VALUES ($1, 0, 0, '[]'::jsonb, 0, 0)
            """, user_id)

            await _log_event(
                conn,
                "sdk_protection",
                "score_reset",
                user_id=session.user_id,
                details={"target_user": user_id},
            )

    await _clear_valkey_score(user_id)

    return {"message": f"Anomaly score reset for user {user_id}"}


@router.get("/sdk-protection/tiers", response_model=List[SDKProtectionTierResponse])
async def list_sdk_protection_tiers(
    session: Session = Depends(require_permission("settings:read")),
):
    """
    List SDK protection tier configurations.

    Requires: settings:read permission
    """
    pool = await get_db_pool()
    rows = await pool.fetch("""
        SELECT tier_name, quota_per_minute, quota_per_hour, quota_per_day,
               quota_per_month, diversity_threshold, burst_threshold
        FROM sdk_protection_tiers
        ORDER BY quota_per_minute
    """)
    return [_tier_from_row(row) for row in rows]


@router.put("/sdk-protection/tiers/{name}", response_model=SDKProtectionTierResponse)
async def update_sdk_protection_tier(
    name: str,
    data: SDKProtectionTierUpdateRequest,
    session: Session = Depends(require_permission("settings:write")),
):
    """
    Update an SDK protection tier.

    Requires: settings:write permission

    Only fields present (non-None) in the request are updated. Raises 400 if
    no field is given and 404 if the tier does not exist.
    """
    # Column names come from this fixed list, never from user input; values
    # are always passed as query parameters.
    updates: List[str] = []
    params: List[Any] = [name]
    for field_name in (
        "quota_per_minute", "quota_per_hour", "quota_per_day",
        "quota_per_month", "diversity_threshold", "burst_threshold",
    ):
        value = getattr(data, field_name)
        if value is not None:
            params.append(value)
            updates.append(f"{field_name} = ${len(params)}")

    if not updates:
        raise HTTPException(status_code=400, detail="No fields to update")

    updates.append("updated_at = NOW()")
    set_clause = ", ".join(updates)

    query = f"""
        UPDATE sdk_protection_tiers
        SET {set_clause}
        WHERE tier_name = $1
        RETURNING tier_name, quota_per_minute, quota_per_hour, quota_per_day,
                  quota_per_month, diversity_threshold, burst_threshold
    """

    pool = await get_db_pool()
    async with pool.acquire() as conn:
        async with conn.transaction():
            row = await conn.fetchrow(query, *params)
            if not row:
                raise HTTPException(status_code=404, detail=f"Tier '{name}' not found")

            await _log_event(
                conn,
                "sdk_protection",
                "tier_updated",
                user_id=session.user_id,
                details={"tier": name, "changes": data.dict(exclude_none=True)},
            )

    return _tier_from_row(row)


# ==============================================
# Per-middleware configuration (parameterized routes)
# ==============================================
# Registered last on purpose: "/{name}" matches any single path segment, so
# every static route above must be registered before it.

@router.get("/{name}", response_model=MiddlewareConfigResponse)
async def get_middleware_config(
    name: str,
    session: Session = Depends(require_permission("settings:read")),
):
    """
    Get configuration for a specific middleware.

    Requires: settings:read permission
    """
    pool = await get_db_pool()
    row = await pool.fetchrow("""
        SELECT id, middleware_name, enabled, config, updated_at
        FROM middleware_config
        WHERE middleware_name = $1
    """, name)

    if not row:
        raise HTTPException(status_code=404, detail=f"Middleware '{name}' not found")

    return _config_from_row(row)


@router.put("/{name}", response_model=MiddlewareConfigResponse)
async def update_middleware_config(
    name: str,
    data: MiddlewareConfigUpdateRequest,
    session: Session = Depends(require_permission("settings:write")),
):
    """
    Update configuration for a specific middleware.

    Requires: settings:write permission

    Only fields present (non-None) in the request are updated; ``updated_at``
    and ``updated_by`` are always set. The update and its audit event are
    written in one transaction. Raises 400 if no field is given and 404 if
    the middleware does not exist.
    """
    # Column names are fixed here; only placeholders are interpolated.
    updates: List[str] = []
    params: List[Any] = [name]

    if data.enabled is not None:
        params.append(data.enabled)
        updates.append(f"enabled = ${len(params)}")

    if data.config is not None:
        params.append(data.config)
        updates.append(f"config = ${len(params)}")

    if not updates:
        raise HTTPException(status_code=400, detail="No fields to update")

    updates.append("updated_at = NOW()")
    params.append(session.user_id)
    updates.append(f"updated_by = ${len(params)}")
    set_clause = ", ".join(updates)

    query = f"""
        UPDATE middleware_config
        SET {set_clause}
        WHERE middleware_name = $1
        RETURNING id, middleware_name, enabled, config, updated_at
    """

    pool = await get_db_pool()
    async with pool.acquire() as conn:
        async with conn.transaction():
            row = await conn.fetchrow(query, *params)
            if not row:
                raise HTTPException(
                    status_code=404, detail=f"Middleware '{name}' not found"
                )

            await _log_event(
                conn,
                name,
                "config_changed",
                user_id=session.user_id,
                details={"changes": data.dict(exclude_none=True)},
            )

    return _config_from_row(row)