use std::sync::Arc; use axum::http::HeaderValue; use axum::{middleware, Extension}; use tokio::sync::RwLock; use tower_http::cors::CorsLayer; use tower_http::set_header::SetResponseHeaderLayer; use tower_http::trace::TraceLayer; use crate::agent::ComplianceAgent; use crate::api::auth_middleware::{require_jwt_auth, require_tenant_status, JwksState}; use crate::api::routes; use crate::error::AgentError; pub async fn start_api_server(agent: ComplianceAgent, port: u16) -> Result<(), AgentError> { let mut app = routes::build_router() .layer(Extension(Arc::new(agent.clone()))) .layer(CorsLayer::permissive()) .layer(TraceLayer::new_for_http()) // Security headers (defense-in-depth, primary enforcement via Traefik) .layer(SetResponseHeaderLayer::overriding( axum::http::header::STRICT_TRANSPORT_SECURITY, HeaderValue::from_static("max-age=31536000; includeSubDomains"), )) .layer(SetResponseHeaderLayer::overriding( axum::http::header::X_FRAME_OPTIONS, HeaderValue::from_static("DENY"), )) .layer(SetResponseHeaderLayer::overriding( axum::http::header::X_CONTENT_TYPE_OPTIONS, HeaderValue::from_static("nosniff"), )) .layer(SetResponseHeaderLayer::overriding( axum::http::header::REFERRER_POLICY, HeaderValue::from_static("strict-origin-when-cross-origin"), )); if let (Some(kc_url), Some(kc_realm)) = (&agent.config.keycloak_url, &agent.config.keycloak_realm) { let jwks_url = format!("{kc_url}/realms/{kc_realm}/protocol/openid-connect/certs"); let jwks_state = JwksState { jwks: Arc::new(RwLock::new(None)), jwks_url, }; tracing::info!("Keycloak JWT auth enabled for realm '{kc_realm}'"); // Layers execute outermost-first. The Extension must run before // require_jwt_auth so that middleware can read JwksState from // request extensions, and the status gate must run after the // JWT auth so TenantContext is in extensions. app = app .layer(middleware::from_fn(require_tenant_status)) .layer(middleware::from_fn(require_jwt_auth)) .layer(Extension(jwks_state)); } else { tracing::warn!("Keycloak not configured - API endpoints are unprotected"); } let addr = format!("0.0.0.0:{port}"); let listener = tokio::net::TcpListener::bind(&addr) .await .map_err(|e| AgentError::Other(format!("Failed to bind to {addr}: {e}")))?; tracing::info!("REST API listening on {addr}"); axum::serve(listener, app) .await .map_err(|e| AgentError::Other(format!("API server error: {e}")))?; Ok(()) }