05c01ea547
CI / Check (pull_request) Successful in 10m50s
CI / Detect Changes (pull_request) Has been skipped
CI / Deploy Agent (pull_request) Has been skipped
CI / Deploy Dashboard (pull_request) Has been skipped
CI / Deploy Docs (pull_request) Has been skipped
CI / Deploy MCP (pull_request) Has been skipped
Lays the platform-wide multi-tenancy infrastructure on top of the existing Keycloak signature validation.

JWTs now carry tenant_id, tenant_slug, org_roles, products, plan, and tenant_status; the middleware decodes them into a TenantContext and attaches it to the request extensions. A TenantCtx Axum extractor exposes the context to handlers, and a tenant_status middleware enforces the §5c lifecycle (frozen tenants get 402 on writes; archived tenants get 410 on every method).

A db::tenant_filter helper in compliance-core gives every future collection a single grep-able pattern for tenant-scoped queries. Per-collection wiring (adding tenant_id to each model and threading the filter through every find/update/delete call) lands in a follow-up.

Tests: 6 inline unit tests for claims→context mapping, 2 for the extractor, 6 integration tests for the status middleware, 3 for the db filter.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
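The §5c lifecycle above reduces to a pure decision over (status, method). A minimal dependency-free sketch, assuming a local `TenantStatus` enum that mirrors the five states named in the middleware's doc comment, HTTP methods as plain strings, and bare `u16` codes in place of Axum's `StatusCode`; `status_gate` is a hypothetical name, not part of compliance-core:

```rust
/// Hypothetical stand-in for `compliance_core::TenantStatus`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum TenantStatus {
    Active,
    Trial,
    Demo,
    Frozen,
    Archived,
}

/// Same rule as the middleware's `is_write`: anything other than
/// GET/HEAD/OPTIONS counts as a write.
fn is_write(method: &str) -> bool {
    !matches!(method, "GET" | "HEAD" | "OPTIONS")
}

/// Returns `Some(code)` when the request must be rejected, `None` to let it through.
fn status_gate(status: TenantStatus, method: &str) -> Option<u16> {
    match status {
        // Archived: data-retention window closed, every method is refused.
        TenantStatus::Archived => Some(410),
        // Frozen: read-only, so only writes are refused.
        TenantStatus::Frozen if is_write(method) => Some(402),
        // Active / Trial / Demo, and reads on a Frozen tenant, pass through.
        _ => None,
    }
}

fn main() {
    for (status, method) in [
        (TenantStatus::Active, "POST"),
        (TenantStatus::Trial, "PUT"),
        (TenantStatus::Demo, "PATCH"),
        (TenantStatus::Frozen, "GET"),
        (TenantStatus::Frozen, "DELETE"),
        (TenantStatus::Archived, "GET"),
    ] {
        println!("{status:?} {method}: {:?}", status_gate(status, method));
    }
}
```

Ordering matters: the archived check comes first, so an archived tenant gets 410 even on a GET, which is how `require_tenant_status` behaves.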
306 lines
9.7 KiB
Rust
//! M7.1 — JWT validation + tenant context propagation.
//!
//! `require_jwt_auth` validates a Bearer JWT against Keycloak's JWKS and
//! attaches a `TenantContext` to the request extensions. Downstream
//! middleware (`require_tenant_status`) and Axum extractors (`TenantCtx`)
//! read it from there.
//!
//! Skipped paths:
//! * `/api/v1/health` — Kubernetes liveness; never authenticated.
//!
//! Failure modes:
//! * No `JwksState` extension → pass-through (single-tenant dev mode).
//! * Missing / malformed Bearer header → 401.
//! * Signature / expiry invalid → 401.
//! * Claims present but tenant_id missing → 401 (treated as a malformed
//!   token; the realm must always issue tenant_id).

use std::sync::Arc;

use axum::{
    extract::Request,
    http::{Method, StatusCode},
    middleware::Next,
    response::{IntoResponse, Response},
};
use compliance_core::{OrgRole, TenantContext, TenantStatus};
use jsonwebtoken::{decode, decode_header, jwk::JwkSet, DecodingKey, Validation};
use serde::Deserialize;
use tokio::sync::RwLock;

/// Cached JWKS from Keycloak for token validation.
#[derive(Clone)]
pub struct JwksState {
    pub jwks: Arc<RwLock<Option<JwkSet>>>,
    pub jwks_url: String,
}

/// Raw shape of the JWT payload, matching the breakpilot-dev realm's
/// protocol-mapper output. Missing fields default to "" / empty so a
/// realm that hasn't been fully wired yet still deserializes;
/// `claims_to_context` then rejects a token whose `tenant_id` is empty.
#[derive(Debug, Deserialize)]
struct Claims {
    sub: String,
    #[serde(default)]
    name: Option<String>,
    #[serde(default)]
    preferred_username: Option<String>,
    #[serde(default)]
    tenant_id: String,
    #[serde(default)]
    tenant_slug: String,
    #[serde(default)]
    org_roles: Vec<String>,
    #[serde(default)]
    products: Vec<String>,
    #[serde(default)]
    plan: String,
    #[serde(default)]
    tenant_status: Option<TenantStatus>,
}

const PUBLIC_ENDPOINTS: &[&str] = &["/api/v1/health"];

/// Middleware that validates Bearer JWT tokens against Keycloak's JWKS
/// and attaches a `TenantContext` extension on success.
///
/// Skips validation for the health endpoint.
/// If `JwksState` is not present (Keycloak not configured), requests
/// pass through and downstream code must handle the missing context.
pub async fn require_jwt_auth(mut request: Request, next: Next) -> Response {
    let path = request.uri().path();

    if PUBLIC_ENDPOINTS.contains(&path) {
        return next.run(request).await;
    }

    let jwks_state = match request.extensions().get::<JwksState>() {
        Some(s) => s.clone(),
        None => return next.run(request).await,
    };

    let auth_header = match request.headers().get("authorization") {
        Some(h) => h,
        None => return (StatusCode::UNAUTHORIZED, "Missing authorization header").into_response(),
    };

    let token = match auth_header.to_str() {
        Ok(s) if s.starts_with("Bearer ") => &s[7..],
        _ => return (StatusCode::UNAUTHORIZED, "Invalid authorization header").into_response(),
    };

    match validate_token(token, &jwks_state).await {
        Ok(ctx) => {
            request.extensions_mut().insert(ctx);
            next.run(request).await
        }
        Err(e) => {
            tracing::warn!("JWT validation failed: {e}");
            (StatusCode::UNAUTHORIZED, "Invalid token").into_response()
        }
    }
}

/// Middleware that enforces the M7.1 `tenant_status` contract.
///
/// * `Active` / `Trial` / `Demo` — pass through.
/// * `Frozen` — read-only after cancel / non-payment. Writes return 402.
/// * `Archived` — data-retention window closed. Every request returns 410.
///
/// Pass-through when no `TenantContext` is present (single-tenant dev or
/// the upstream JWT middleware ran without `JwksState`).
pub async fn require_tenant_status(request: Request, next: Next) -> Response {
    let ctx = match request.extensions().get::<TenantContext>() {
        Some(c) => c.clone(),
        None => return next.run(request).await,
    };

    if ctx.status.is_archived() {
        return (
            StatusCode::GONE,
            "Tenant archived — data retention window closed",
        )
            .into_response();
    }

    if ctx.status.is_frozen() && is_write(request.method()) {
        return (
            StatusCode::PAYMENT_REQUIRED,
            "Tenant frozen — read-only. Re-activate to resume writes.",
        )
            .into_response();
    }

    next.run(request).await
}

/// Treat anything other than GET/HEAD/OPTIONS as a write. Good enough for
/// REST. The few exceptions (e.g. read-side POSTs) will need a route-level
/// exemption once we have them, since `require_tenant_status` rejects the
/// request before the handler runs.
fn is_write(m: &Method) -> bool {
    !matches!(m, &Method::GET | &Method::HEAD | &Method::OPTIONS)
}

async fn validate_token(token: &str, state: &JwksState) -> Result<TenantContext, String> {
    let header = decode_header(token).map_err(|e| format!("failed to decode JWT header: {e}"))?;

    let kid = header
        .kid
        .ok_or_else(|| "JWT missing kid header".to_string())?;

    let jwks = fetch_or_get_jwks(state).await?;

    let jwk = jwks
        .keys
        .iter()
        .find(|k| k.common.key_id.as_deref() == Some(&kid))
        .ok_or_else(|| "no matching key found in JWKS".to_string())?;

    let decoding_key =
        DecodingKey::from_jwk(jwk).map_err(|e| format!("failed to create decoding key: {e}"))?;

    let mut validation = Validation::new(header.alg);
    validation.validate_exp = true;
    validation.validate_aud = false;

    let data = decode::<Claims>(token, &decoding_key, &validation)
        .map_err(|e| format!("token validation failed: {e}"))?;

    claims_to_context(data.claims)
}

/// Map the decoded JWT payload into the platform-wide `TenantContext`.
/// Pulled out for unit testing — no I/O.
fn claims_to_context(c: Claims) -> Result<TenantContext, String> {
    if c.tenant_id.is_empty() {
        return Err("JWT is missing tenant_id claim".to_string());
    }

    let status = c.tenant_status.unwrap_or_else(|| {
        tracing::warn!(
            "JWT missing tenant_status claim for tenant {} — defaulting to Trial",
            c.tenant_id
        );
        TenantStatus::Trial
    });

    Ok(TenantContext {
        tenant_id: c.tenant_id,
        tenant_slug: c.tenant_slug,
        org_roles: c.org_roles.iter().map(|r| OrgRole::parse(r)).collect(),
        products: c.products,
        plan: c.plan,
        status,
        user_id: c.sub,
        user_name: c.name.or(c.preferred_username),
    })
}

/// Returns the cached JWKS, fetching it from `jwks_url` on first use.
///
/// The cache is never invalidated: after a Keycloak key rotation, tokens
/// signed with a new `kid` fail until the cache is cleared or the process
/// restarts. Concurrent cold-cache requests may each fetch; the last write wins.
async fn fetch_or_get_jwks(state: &JwksState) -> Result<JwkSet, String> {
    {
        let cached = state.jwks.read().await;
        if let Some(ref jwks) = *cached {
            return Ok(jwks.clone());
        }
    }

    let resp = reqwest::get(&state.jwks_url)
        .await
        .map_err(|e| format!("failed to fetch JWKS: {e}"))?;

    let jwks: JwkSet = resp
        .json()
        .await
        .map_err(|e| format!("failed to parse JWKS: {e}"))?;

    let mut cached = state.jwks.write().await;
    *cached = Some(jwks.clone());

    Ok(jwks)
}

#[cfg(test)]
#[allow(clippy::expect_used, clippy::unwrap_used)]
mod tests {
    use super::*;

    fn base_claims() -> Claims {
        Claims {
            sub: "user-123".to_string(),
            name: Some("Alice Acme".to_string()),
            preferred_username: None,
            tenant_id: "00000000-0000-0000-0000-000000000001".to_string(),
            tenant_slug: "acme".to_string(),
            org_roles: vec!["IT_ADMIN".to_string()],
            products: vec!["compliance".to_string()],
            plan: "professional".to_string(),
            tenant_status: Some(TenantStatus::Active),
        }
    }

    #[test]
    fn claims_to_context_happy_path() {
        let ctx = claims_to_context(base_claims()).expect("should map");
        assert_eq!(ctx.tenant_id, "00000000-0000-0000-0000-000000000001");
        assert_eq!(ctx.tenant_slug, "acme");
        assert_eq!(ctx.org_roles, vec![OrgRole::ItAdmin]);
        assert_eq!(ctx.products, vec!["compliance"]);
        assert_eq!(ctx.plan, "professional");
        assert_eq!(ctx.status, TenantStatus::Active);
        assert_eq!(ctx.user_id, "user-123");
        assert_eq!(ctx.user_name.as_deref(), Some("Alice Acme"));
    }

    #[test]
    fn claims_to_context_rejects_missing_tenant_id() {
        let mut c = base_claims();
        c.tenant_id = "".to_string();
        let err = claims_to_context(c).expect_err("should reject");
        assert!(err.contains("tenant_id"));
    }

    #[test]
    fn claims_to_context_defaults_status_when_missing() {
        let mut c = base_claims();
        c.tenant_status = None;
        let ctx = claims_to_context(c).expect("should map");
        assert_eq!(ctx.status, TenantStatus::Trial);
    }

    #[test]
    fn claims_to_context_falls_back_to_preferred_username() {
        let mut c = base_claims();
        c.name = None;
        c.preferred_username = Some("alice@acme.dev".to_string());
        let ctx = claims_to_context(c).expect("should map");
        assert_eq!(ctx.user_name.as_deref(), Some("alice@acme.dev"));
    }

    #[test]
    fn claims_to_context_parses_multiple_roles() {
        let mut c = base_claims();
        c.org_roles = vec![
            "IT_ADMIN".to_string(),
            "CXO".to_string(),
            "GARBAGE".to_string(),
        ];
        let ctx = claims_to_context(c).expect("should map");
        assert_eq!(
            ctx.org_roles,
            vec![OrgRole::ItAdmin, OrgRole::Cxo, OrgRole::Unknown]
        );
    }

    #[test]
    fn is_write_detects_methods() {
        assert!(!is_write(&Method::GET));
        assert!(!is_write(&Method::HEAD));
        assert!(!is_write(&Method::OPTIONS));
        assert!(is_write(&Method::POST));
        assert!(is_write(&Method::PUT));
        assert!(is_write(&Method::PATCH));
        assert!(is_write(&Method::DELETE));
    }
}
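`require_jwt_auth` only accepts the exact prefix `Bearer ` and slices off the first seven bytes. HTTP treats the authentication scheme as a case-insensitive token (RFC 9110, §11.1), and RFC 6750 allows one or more spaces before the token, so a header such as `bearer <jwt>` is currently rejected with 401. A minimal sketch of a more tolerant parser, assuming that accepting any casing of the scheme is wanted; `bearer_token` is a hypothetical helper, not part of this module:

```rust
/// Extracts the credentials from an `Authorization: Bearer <token>` value.
/// The scheme match is ASCII case-insensitive; an empty token is rejected.
fn bearer_token(header_value: &str) -> Option<&str> {
    // Split "scheme credentials" at the first space.
    let (scheme, rest) = header_value.split_once(' ')?;
    if !scheme.eq_ignore_ascii_case("Bearer") {
        return None;
    }
    // RFC 6750 permits one or more spaces between the scheme and the token.
    let token = rest.trim_start_matches(' ');
    if token.is_empty() {
        None
    } else {
        Some(token)
    }
}

fn main() {
    for value in ["Bearer abc.def.ghi", "bearer abc.def.ghi", "Basic dXNlcjpwdw==", "Bearer "] {
        println!("{value:?} -> {:?}", bearer_token(value));
    }
}
```

Wired into the middleware, a `None` result would map to the existing `Invalid authorization header` 401.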
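`fetch_or_get_jwks` fills the cache once and never invalidates it, so after Keycloak rotates its signing keys, tokens carrying a new `kid` fail with "no matching key found in JWKS". One common remedy is to refetch the key set when a `kid` is not in the cache. A minimal synchronous sketch of that lookup, assuming a hypothetical `fetch` closure in place of the `reqwest` call, a `HashMap<String, String>` of `kid` to key material in place of `JwkSet`, and `std::sync::RwLock` instead of Tokio's; `KeyCache`, `KeySet`, and `key_for` are illustrative names only:

```rust
use std::collections::HashMap;
use std::sync::RwLock;

/// Hypothetical stand-in for a JWKS: `kid` -> key material.
type KeySet = HashMap<String, String>;

/// Caches the key set; `fetch` stands in for the HTTP call to the JWKS endpoint.
struct KeyCache<F> {
    keys: RwLock<Option<KeySet>>,
    fetch: F,
}

impl<F: Fn() -> Result<KeySet, String>> KeyCache<F> {
    fn new(fetch: F) -> Self {
        Self { keys: RwLock::new(None), fetch }
    }

    /// Returns the key for `kid`, refetching the whole set (once per call)
    /// when the cache is empty or does not contain `kid`.
    fn key_for(&self, kid: &str) -> Result<String, String> {
        // Fast path; the read guard is dropped at the end of this statement.
        let cached = self
            .keys
            .read()
            .expect("lock poisoned")
            .as_ref()
            .and_then(|keys| keys.get(kid).cloned());
        if let Some(key) = cached {
            return Ok(key);
        }
        // Cold cache or unknown kid (e.g. after key rotation): refetch and replace.
        let fresh = (self.fetch)()?;
        let found = fresh.get(kid).cloned();
        *self.keys.write().expect("lock poisoned") = Some(fresh);
        found.ok_or_else(|| format!("no matching key found in JWKS for kid {kid}"))
    }
}

fn main() {
    use std::cell::Cell;

    // Simulated identity provider: fetch #1 serves kid "k1"; later fetches
    // serve only "k2", as if the signing key had been rotated.
    let calls = Cell::new(0);
    let cache = KeyCache::new(|| -> Result<KeySet, String> {
        calls.set(calls.get() + 1);
        let kid = if calls.get() == 1 { "k1" } else { "k2" };
        Ok(HashMap::from([(kid.to_string(), format!("key-{kid}"))]))
    });

    println!("{:?}", cache.key_for("k1")); // cold cache: fetch #1
    println!("{:?}", cache.key_for("k1")); // served from cache
    println!("{:?}", cache.key_for("k2")); // unknown kid: fetch #2
    println!("fetches: {}", calls.get());
}
```

Each miss still costs a round-trip to the JWKS endpoint, so a token with a made-up `kid` can force one fetch per request; a production version would typically rate-limit refetches and, in this async middleware, keep the Tokio `RwLock` that `JwksState` already uses.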