//! M7.1 — JWT validation + tenant context propagation. //! //! `require_jwt_auth` validates a Bearer JWT against Keycloak's JWKS and //! attaches a `TenantContext` to the request extensions. Downstream //! middleware (`require_tenant_status`) and Axum extractors (`TenantCtx`) //! read it from there. //! //! Skipped paths: //! * `/api/v1/health` — Kubernetes liveness; never authenticated. //! //! Failure modes: //! * No `JwksState` extension → pass-through (single-tenant dev mode). //! * Missing / malformed Bearer header → 401. //! * Signature / expiry invalid → 401. //! * Claims present but tenant_id missing → 401 (treated as a malformed //! token; the realm must always issue tenant_id). use std::sync::Arc; use axum::{ extract::Request, http::Method, middleware::Next, response::{IntoResponse, Response}, }; use compliance_core::{OrgRole, TenantContext, TenantStatus}; use jsonwebtoken::{decode, decode_header, jwk::JwkSet, DecodingKey, Validation}; use reqwest::StatusCode; use serde::Deserialize; use tokio::sync::RwLock; /// Cached JWKS from Keycloak for token validation. #[derive(Clone)] pub struct JwksState { pub jwks: Arc>>, pub jwks_url: String, } /// Raw shape of the JWT payload — matches the breakpilot-dev realm's /// protocol-mapper output. Missing fields default to "" / empty so a /// realm that hasn't been fully wired yet still validates. #[derive(Debug, Deserialize)] struct Claims { sub: String, #[serde(default)] name: Option, #[serde(default)] preferred_username: Option, #[serde(default)] tenant_id: String, #[serde(default)] tenant_slug: String, #[serde(default)] org_roles: Vec, #[serde(default)] products: Vec, #[serde(default)] plan: String, #[serde(default)] tenant_status: Option, } const PUBLIC_ENDPOINTS: &[&str] = &["/api/v1/health"]; /// Middleware that validates Bearer JWT tokens against Keycloak's JWKS /// and attaches a `TenantContext` extension on success. /// /// Skips validation for the health endpoint. /// If `JwksState` is not present (Keycloak not configured), requests /// pass through and downstream code must handle the missing context. pub async fn require_jwt_auth(mut request: Request, next: Next) -> Response { let path = request.uri().path(); if PUBLIC_ENDPOINTS.contains(&path) { return next.run(request).await; } let jwks_state = match request.extensions().get::() { Some(s) => s.clone(), None => return next.run(request).await, }; let auth_header = match request.headers().get("authorization") { Some(h) => h, None => return (StatusCode::UNAUTHORIZED, "Missing authorization header").into_response(), }; let token = match auth_header.to_str() { Ok(s) if s.starts_with("Bearer ") => &s[7..], _ => return (StatusCode::UNAUTHORIZED, "Invalid authorization header").into_response(), }; match validate_token(token, &jwks_state).await { Ok(ctx) => { request.extensions_mut().insert(ctx); next.run(request).await } Err(e) => { tracing::warn!("JWT validation failed: {e}"); (StatusCode::UNAUTHORIZED, "Invalid token").into_response() } } } /// Middleware that enforces the M7.1 `tenant_status` contract. /// /// * `Active` / `Trial` / `Demo` — pass through. /// * `Frozen` — read-only after cancel / non-payment. Writes return 402. /// * `Archived` — data-retention window closed. Every request returns 410. /// /// Pass-through when no `TenantContext` is present (single-tenant dev or /// the upstream JWT middleware ran without `JwksState`). pub async fn require_tenant_status(request: Request, next: Next) -> Response { let ctx = match request.extensions().get::() { Some(c) => c.clone(), None => return next.run(request).await, }; if ctx.status.is_archived() { return ( StatusCode::GONE, "Tenant archived — data retention window closed", ) .into_response(); } if ctx.status.is_frozen() && is_write(request.method()) { return ( StatusCode::PAYMENT_REQUIRED, "Tenant frozen — read-only. Re-activate to resume writes.", ) .into_response(); } next.run(request).await } /// Treat anything other than GET/HEAD/OPTIONS as a write. Good enough for /// REST. The few exceptions (e.g. read-side POSTs) can opt out at the /// handler level once we have them. fn is_write(m: &Method) -> bool { !matches!(m, &Method::GET | &Method::HEAD | &Method::OPTIONS) } async fn validate_token(token: &str, state: &JwksState) -> Result { let header = decode_header(token).map_err(|e| format!("failed to decode JWT header: {e}"))?; let kid = header .kid .ok_or_else(|| "JWT missing kid header".to_string())?; let jwks = fetch_or_get_jwks(state).await?; let jwk = jwks .keys .iter() .find(|k| k.common.key_id.as_deref() == Some(&kid)) .ok_or_else(|| "no matching key found in JWKS".to_string())?; let decoding_key = DecodingKey::from_jwk(jwk).map_err(|e| format!("failed to create decoding key: {e}"))?; let mut validation = Validation::new(header.alg); validation.validate_exp = true; validation.validate_aud = false; let data = decode::(token, &decoding_key, &validation) .map_err(|e| format!("token validation failed: {e}"))?; claims_to_context(data.claims) } /// Map the decoded JWT payload into the platform-wide `TenantContext`. /// Pulled out for unit testing — no I/O. fn claims_to_context(c: Claims) -> Result { if c.tenant_id.is_empty() { return Err("JWT is missing tenant_id claim".to_string()); } let status = c.tenant_status.unwrap_or_else(|| { tracing::warn!( "JWT missing tenant_status claim for tenant {} — defaulting to Trial", c.tenant_id ); TenantStatus::Trial }); Ok(TenantContext { tenant_id: c.tenant_id, tenant_slug: c.tenant_slug, org_roles: c.org_roles.iter().map(|r| OrgRole::parse(r)).collect(), products: c.products, plan: c.plan, status, user_id: c.sub, user_name: c.name.or(c.preferred_username), }) } async fn fetch_or_get_jwks(state: &JwksState) -> Result { { let cached = state.jwks.read().await; if let Some(ref jwks) = *cached { return Ok(jwks.clone()); } } let resp = reqwest::get(&state.jwks_url) .await .map_err(|e| format!("failed to fetch JWKS: {e}"))?; let jwks: JwkSet = resp .json() .await .map_err(|e| format!("failed to parse JWKS: {e}"))?; let mut cached = state.jwks.write().await; *cached = Some(jwks.clone()); Ok(jwks) } #[cfg(test)] #[allow(clippy::expect_used, clippy::unwrap_used)] mod tests { use super::*; fn base_claims() -> Claims { Claims { sub: "user-123".to_string(), name: Some("Alice Acme".to_string()), preferred_username: None, tenant_id: "00000000-0000-0000-0000-000000000001".to_string(), tenant_slug: "acme".to_string(), org_roles: vec!["IT_ADMIN".to_string()], products: vec!["compliance".to_string()], plan: "professional".to_string(), tenant_status: Some(TenantStatus::Active), } } #[test] fn claims_to_context_happy_path() { let ctx = claims_to_context(base_claims()).expect("should map"); assert_eq!(ctx.tenant_id, "00000000-0000-0000-0000-000000000001"); assert_eq!(ctx.tenant_slug, "acme"); assert_eq!(ctx.org_roles, vec![OrgRole::ItAdmin]); assert_eq!(ctx.products, vec!["compliance"]); assert_eq!(ctx.plan, "professional"); assert_eq!(ctx.status, TenantStatus::Active); assert_eq!(ctx.user_id, "user-123"); assert_eq!(ctx.user_name.as_deref(), Some("Alice Acme")); } #[test] fn claims_to_context_rejects_missing_tenant_id() { let mut c = base_claims(); c.tenant_id = "".to_string(); let err = claims_to_context(c).expect_err("should reject"); assert!(err.contains("tenant_id")); } #[test] fn claims_to_context_defaults_status_when_missing() { let mut c = base_claims(); c.tenant_status = None; let ctx = claims_to_context(c).expect("should map"); assert_eq!(ctx.status, TenantStatus::Trial); } #[test] fn claims_to_context_falls_back_to_preferred_username() { let mut c = base_claims(); c.name = None; c.preferred_username = Some("alice@acme.dev".to_string()); let ctx = claims_to_context(c).expect("should map"); assert_eq!(ctx.user_name.as_deref(), Some("alice@acme.dev")); } #[test] fn claims_to_context_parses_multiple_roles() { let mut c = base_claims(); c.org_roles = vec![ "IT_ADMIN".to_string(), "CXO".to_string(), "GARBAGE".to_string(), ]; let ctx = claims_to_context(c).expect("should map"); assert_eq!( ctx.org_roles, vec![OrgRole::ItAdmin, OrgRole::Cxo, OrgRole::Unknown] ); } #[test] fn is_write_detects_methods() { assert!(!is_write(&Method::GET)); assert!(!is_write(&Method::HEAD)); assert!(!is_write(&Method::OPTIONS)); assert!(is_write(&Method::POST)); assert!(is_write(&Method::PUT)); assert!(is_write(&Method::PATCH)); assert!(is_write(&Method::DELETE)); } }