From dbadff0aac402f4452d9218e60f8687571a2bc62 Mon Sep 17 00:00:00 2001 From: Sharang Parnerkar Date: Thu, 4 Jun 2026 14:46:14 +0000 Subject: [PATCH] fix(m7.1): JWKS refresh-on-failure in auth middleware (#84) --- compliance-core/src/auth.rs | 108 ++++++++++++++++++++++++++++++++---- 1 file changed, 96 insertions(+), 12 deletions(-) diff --git a/compliance-core/src/auth.rs b/compliance-core/src/auth.rs index 3d54c43..4b422e7 100644 --- a/compliance-core/src/auth.rs +++ b/compliance-core/src/auth.rs @@ -148,27 +148,83 @@ async fn validate_token(token: &str, state: &JwksState) -> Result Ok(ctx), + Err(ValidationError::Permanent(e)) => Err(e), + Err(ValidationError::Stale(reason)) => { + tracing::info!( + kid = %kid, + reason = %reason, + "JWKS appears stale — forcing refresh and retrying" + ); + let jwks = fetch_or_get_jwks(state, true).await?; + try_validate(token, &header, &kid, &jwks).map_err(|e| match e { + ValidationError::Stale(s) | ValidationError::Permanent(s) => s, + }) + } + } +} - let jwk = jwks +#[derive(Debug)] +enum ValidationError { + /// Refresh-eligible: cached JWKS may be stale. + Stale(String), + /// Refusing the token regardless of JWKS freshness. + Permanent(String), +} + +fn try_validate( + token: &str, + header: &jsonwebtoken::Header, + kid: &str, + jwks: &JwkSet, +) -> Result { + let jwk = match jwks .keys .iter() - .find(|k| k.common.key_id.as_deref() == Some(&kid)) - .ok_or_else(|| "no matching key found in JWKS".to_string())?; + .find(|k| k.common.key_id.as_deref() == Some(kid)) + { + Some(j) => j, + None => { + return Err(ValidationError::Stale( + "no matching key found in JWKS".to_string(), + )) + } + }; - let decoding_key = - DecodingKey::from_jwk(jwk).map_err(|e| format!("failed to create decoding key: {e}"))?; + let decoding_key = DecodingKey::from_jwk(jwk) + .map_err(|e| ValidationError::Permanent(format!("failed to create decoding key: {e}")))?; let mut validation = Validation::new(header.alg); validation.validate_exp = true; validation.validate_aud = false; - let data = decode::(token, &decoding_key, &validation) - .map_err(|e| format!("token validation failed: {e}"))?; + let data = match decode::(token, &decoding_key, &validation) { + Ok(d) => d, + Err(e) => { + // Signature mismatch is the other refresh-eligible failure: + // the matching kid is present but the key bytes don't match. + // Everything else (expired, malformed, etc.) is permanent. + return Err( + if matches!(e.kind(), jsonwebtoken::errors::ErrorKind::InvalidSignature) { + ValidationError::Stale(format!("token validation failed: {e}")) + } else { + ValidationError::Permanent(format!("token validation failed: {e}")) + }, + ); + } + }; - claims_to_context(data.claims) + claims_to_context(data.claims).map_err(ValidationError::Permanent) } /// Map the decoded JWT payload into the platform-wide `TenantContext`. @@ -198,14 +254,25 @@ fn claims_to_context(c: Claims) -> Result { }) } -async fn fetch_or_get_jwks(state: &JwksState) -> Result { - { +async fn fetch_or_get_jwks(state: &JwksState, force: bool) -> Result { + if !force { let cached = state.jwks.read().await; if let Some(ref jwks) = *cached { return Ok(jwks.clone()); } } + // Hold the write lock across the fetch so concurrent refreshers + // don't all hammer Keycloak when keys rotate. If another writer + // already populated a fresh JWKS while we were waiting (and we + // weren't asked to force), use theirs. + let mut cached = state.jwks.write().await; + if !force { + if let Some(ref jwks) = *cached { + return Ok(jwks.clone()); + } + } + let resp = reqwest::get(&state.jwks_url) .await .map_err(|e| format!("failed to fetch JWKS: {e}"))?; @@ -215,7 +282,6 @@ async fn fetch_or_get_jwks(state: &JwksState) -> Result { .await .map_err(|e| format!("failed to parse JWKS: {e}"))?; - let mut cached = state.jwks.write().await; *cached = Some(jwks.clone()); Ok(jwks) @@ -293,6 +359,24 @@ mod tests { ); } + #[test] + fn try_validate_returns_stale_when_kid_missing_from_jwks() { + // Empty JWKS — the kid we ask for can't possibly match. The error + // must classify as Stale so the caller refreshes JWKS and retries. + let jwks = JwkSet { keys: vec![] }; + let header = jsonwebtoken::Header { + alg: jsonwebtoken::Algorithm::RS256, + kid: Some("kid-rotated-out".to_string()), + ..Default::default() + }; + let err = try_validate("ignored.token.value", &header, "kid-rotated-out", &jwks) + .expect_err("should fail"); + match err { + ValidationError::Stale(s) => assert!(s.contains("no matching key")), + ValidationError::Permanent(s) => panic!("must be Stale, got Permanent: {s}"), + } + } + #[test] fn is_write_detects_methods() { assert!(!is_write(&Method::GET));