From f47469927938b9ddae26f934bea1631ea10f399c Mon Sep 17 00:00:00 2001 From: Sharang Parnerkar <30073382+mighty840@users.noreply.github.com> Date: Thu, 4 Jun 2026 16:40:55 +0200 Subject: [PATCH] fix(core): JWKS refresh-on-failure in M7.1 auth middleware MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Without this, every Keycloak signing-key rotation produces a silent 401 storm against every request until the agent restarts — the cached JWKS is held forever and never reconciled against KC. Now: when `kid` isn't in the cached JWKS or the matching key fails signature verification, we classify the failure as Stale, force a JWKS refresh, and retry once. Anything else (expired, malformed, missing tenant_id) is Permanent and short-circuits straight to 401. * Splits the path into a pure `try_validate(token, header, kid, jwks)` helper returning a `ValidationError { Stale | Permanent }` enum. * `fetch_or_get_jwks(state, force)` takes a force flag and holds the write lock across the network fetch so concurrent refreshers don't all hammer Keycloak when keys rotate (the second writer reuses what the first put in cache). * Adds a unit test for the kid-not-found Stale classification. Co-Authored-By: Claude Opus 4.7 --- compliance-core/src/auth.rs | 108 ++++++++++++++++++++++++++++++++---- 1 file changed, 96 insertions(+), 12 deletions(-) diff --git a/compliance-core/src/auth.rs b/compliance-core/src/auth.rs index 3d54c43..4b422e7 100644 --- a/compliance-core/src/auth.rs +++ b/compliance-core/src/auth.rs @@ -148,27 +148,83 @@ async fn validate_token(token: &str, state: &JwksState) -> Result Ok(ctx), + Err(ValidationError::Permanent(e)) => Err(e), + Err(ValidationError::Stale(reason)) => { + tracing::info!( + kid = %kid, + reason = %reason, + "JWKS appears stale — forcing refresh and retrying" + ); + let jwks = fetch_or_get_jwks(state, true).await?; + try_validate(token, &header, &kid, &jwks).map_err(|e| match e { + ValidationError::Stale(s) | ValidationError::Permanent(s) => s, + }) + } + } +} - let jwk = jwks +#[derive(Debug)] +enum ValidationError { + /// Refresh-eligible: cached JWKS may be stale. + Stale(String), + /// Refusing the token regardless of JWKS freshness. + Permanent(String), +} + +fn try_validate( + token: &str, + header: &jsonwebtoken::Header, + kid: &str, + jwks: &JwkSet, +) -> Result { + let jwk = match jwks .keys .iter() - .find(|k| k.common.key_id.as_deref() == Some(&kid)) - .ok_or_else(|| "no matching key found in JWKS".to_string())?; + .find(|k| k.common.key_id.as_deref() == Some(kid)) + { + Some(j) => j, + None => { + return Err(ValidationError::Stale( + "no matching key found in JWKS".to_string(), + )) + } + }; - let decoding_key = - DecodingKey::from_jwk(jwk).map_err(|e| format!("failed to create decoding key: {e}"))?; + let decoding_key = DecodingKey::from_jwk(jwk) + .map_err(|e| ValidationError::Permanent(format!("failed to create decoding key: {e}")))?; let mut validation = Validation::new(header.alg); validation.validate_exp = true; validation.validate_aud = false; - let data = decode::(token, &decoding_key, &validation) - .map_err(|e| format!("token validation failed: {e}"))?; + let data = match decode::(token, &decoding_key, &validation) { + Ok(d) => d, + Err(e) => { + // Signature mismatch is the other refresh-eligible failure: + // the matching kid is present but the key bytes don't match. + // Everything else (expired, malformed, etc.) is permanent. + return Err( + if matches!(e.kind(), jsonwebtoken::errors::ErrorKind::InvalidSignature) { + ValidationError::Stale(format!("token validation failed: {e}")) + } else { + ValidationError::Permanent(format!("token validation failed: {e}")) + }, + ); + } + }; - claims_to_context(data.claims) + claims_to_context(data.claims).map_err(ValidationError::Permanent) } /// Map the decoded JWT payload into the platform-wide `TenantContext`. @@ -198,14 +254,25 @@ fn claims_to_context(c: Claims) -> Result { }) } -async fn fetch_or_get_jwks(state: &JwksState) -> Result { - { +async fn fetch_or_get_jwks(state: &JwksState, force: bool) -> Result { + if !force { let cached = state.jwks.read().await; if let Some(ref jwks) = *cached { return Ok(jwks.clone()); } } + // Hold the write lock across the fetch so concurrent refreshers + // don't all hammer Keycloak when keys rotate. If another writer + // already populated a fresh JWKS while we were waiting (and we + // weren't asked to force), use theirs. + let mut cached = state.jwks.write().await; + if !force { + if let Some(ref jwks) = *cached { + return Ok(jwks.clone()); + } + } + let resp = reqwest::get(&state.jwks_url) .await .map_err(|e| format!("failed to fetch JWKS: {e}"))?; @@ -215,7 +282,6 @@ async fn fetch_or_get_jwks(state: &JwksState) -> Result { .await .map_err(|e| format!("failed to parse JWKS: {e}"))?; - let mut cached = state.jwks.write().await; *cached = Some(jwks.clone()); Ok(jwks) @@ -293,6 +359,24 @@ mod tests { ); } + #[test] + fn try_validate_returns_stale_when_kid_missing_from_jwks() { + // Empty JWKS — the kid we ask for can't possibly match. The error + // must classify as Stale so the caller refreshes JWKS and retries. + let jwks = JwkSet { keys: vec![] }; + let header = jsonwebtoken::Header { + alg: jsonwebtoken::Algorithm::RS256, + kid: Some("kid-rotated-out".to_string()), + ..Default::default() + }; + let err = try_validate("ignored.token.value", &header, "kid-rotated-out", &jwks) + .expect_err("should fail"); + match err { + ValidationError::Stale(s) => assert!(s.contains("no matching key")), + ValidationError::Permanent(s) => panic!("must be Stale, got Permanent: {s}"), + } + } + #[test] fn is_write_detects_methods() { assert!(!is_write(&Method::GET));