Compare commits

...

1 Commits

Author SHA1 Message Date
Sharang Parnerkar f474699279 fix(core): JWKS refresh-on-failure in M7.1 auth middleware
CI / Check (pull_request) Successful in 8m17s
CI / Detect Changes (pull_request) Has been skipped
CI / Deploy Agent (pull_request) Has been skipped
CI / Deploy Dashboard (pull_request) Has been skipped
CI / Deploy Docs (pull_request) Has been skipped
CI / Deploy MCP (pull_request) Has been skipped
Without this, every Keycloak signing-key rotation produces a silent
401 storm against every request until the agent restarts — the cached
JWKS is held forever and never reconciled against KC.

Now: when `kid` isn't in the cached JWKS or the matching key fails
signature verification, we classify the failure as Stale, force a JWKS
refresh, and retry once. Anything else (expired, malformed, missing
tenant_id) is Permanent and short-circuits straight to 401.

* Splits the path into a pure `try_validate(token, header, kid, jwks)`
  helper returning a `ValidationError { Stale | Permanent }` enum.
* `fetch_or_get_jwks(state, force)` takes a force flag and holds the
  write lock across the network fetch so concurrent refreshers don't
  all hammer Keycloak when keys rotate (the second writer reuses what
  the first put in cache).
* Adds a unit test for the kid-not-found Stale classification.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
2026-06-04 16:40:55 +02:00
+96 -12
View File
@@ -148,27 +148,83 @@ async fn validate_token(token: &str, state: &JwksState) -> Result<TenantContext,
let kid = header let kid = header
.kid .kid
.clone()
.ok_or_else(|| "JWT missing kid header".to_string())?; .ok_or_else(|| "JWT missing kid header".to_string())?;
let jwks = fetch_or_get_jwks(state).await?; // First try against whatever's currently cached. If the kid isn't
// there or the signature doesn't verify, the cached JWKS is most
// likely stale (KC rotated keys) — refresh once and retry before
// giving up. Without this every key rotation produces a silent 401
// storm that only goes away when the agent restarts.
let jwks = fetch_or_get_jwks(state, false).await?;
match try_validate(token, &header, &kid, &jwks) {
Ok(ctx) => Ok(ctx),
Err(ValidationError::Permanent(e)) => Err(e),
Err(ValidationError::Stale(reason)) => {
tracing::info!(
kid = %kid,
reason = %reason,
"JWKS appears stale — forcing refresh and retrying"
);
let jwks = fetch_or_get_jwks(state, true).await?;
try_validate(token, &header, &kid, &jwks).map_err(|e| match e {
ValidationError::Stale(s) | ValidationError::Permanent(s) => s,
})
}
}
}
let jwk = jwks #[derive(Debug)]
enum ValidationError {
/// Refresh-eligible: cached JWKS may be stale.
Stale(String),
/// Refusing the token regardless of JWKS freshness.
Permanent(String),
}
fn try_validate(
token: &str,
header: &jsonwebtoken::Header,
kid: &str,
jwks: &JwkSet,
) -> Result<TenantContext, ValidationError> {
let jwk = match jwks
.keys .keys
.iter() .iter()
.find(|k| k.common.key_id.as_deref() == Some(&kid)) .find(|k| k.common.key_id.as_deref() == Some(kid))
.ok_or_else(|| "no matching key found in JWKS".to_string())?; {
Some(j) => j,
None => {
return Err(ValidationError::Stale(
"no matching key found in JWKS".to_string(),
))
}
};
let decoding_key = let decoding_key = DecodingKey::from_jwk(jwk)
DecodingKey::from_jwk(jwk).map_err(|e| format!("failed to create decoding key: {e}"))?; .map_err(|e| ValidationError::Permanent(format!("failed to create decoding key: {e}")))?;
let mut validation = Validation::new(header.alg); let mut validation = Validation::new(header.alg);
validation.validate_exp = true; validation.validate_exp = true;
validation.validate_aud = false; validation.validate_aud = false;
let data = decode::<Claims>(token, &decoding_key, &validation) let data = match decode::<Claims>(token, &decoding_key, &validation) {
.map_err(|e| format!("token validation failed: {e}"))?; Ok(d) => d,
Err(e) => {
// Signature mismatch is the other refresh-eligible failure:
// the matching kid is present but the key bytes don't match.
// Everything else (expired, malformed, etc.) is permanent.
return Err(
if matches!(e.kind(), jsonwebtoken::errors::ErrorKind::InvalidSignature) {
ValidationError::Stale(format!("token validation failed: {e}"))
} else {
ValidationError::Permanent(format!("token validation failed: {e}"))
},
);
}
};
claims_to_context(data.claims) claims_to_context(data.claims).map_err(ValidationError::Permanent)
} }
/// Map the decoded JWT payload into the platform-wide `TenantContext`. /// Map the decoded JWT payload into the platform-wide `TenantContext`.
@@ -198,14 +254,25 @@ fn claims_to_context(c: Claims) -> Result<TenantContext, String> {
}) })
} }
async fn fetch_or_get_jwks(state: &JwksState) -> Result<JwkSet, String> { async fn fetch_or_get_jwks(state: &JwksState, force: bool) -> Result<JwkSet, String> {
{ if !force {
let cached = state.jwks.read().await; let cached = state.jwks.read().await;
if let Some(ref jwks) = *cached { if let Some(ref jwks) = *cached {
return Ok(jwks.clone()); return Ok(jwks.clone());
} }
} }
// Hold the write lock across the fetch so concurrent refreshers
// don't all hammer Keycloak when keys rotate. If another writer
// already populated a fresh JWKS while we were waiting (and we
// weren't asked to force), use theirs.
let mut cached = state.jwks.write().await;
if !force {
if let Some(ref jwks) = *cached {
return Ok(jwks.clone());
}
}
let resp = reqwest::get(&state.jwks_url) let resp = reqwest::get(&state.jwks_url)
.await .await
.map_err(|e| format!("failed to fetch JWKS: {e}"))?; .map_err(|e| format!("failed to fetch JWKS: {e}"))?;
@@ -215,7 +282,6 @@ async fn fetch_or_get_jwks(state: &JwksState) -> Result<JwkSet, String> {
.await .await
.map_err(|e| format!("failed to parse JWKS: {e}"))?; .map_err(|e| format!("failed to parse JWKS: {e}"))?;
let mut cached = state.jwks.write().await;
*cached = Some(jwks.clone()); *cached = Some(jwks.clone());
Ok(jwks) Ok(jwks)
@@ -293,6 +359,24 @@ mod tests {
); );
} }
#[test]
fn try_validate_returns_stale_when_kid_missing_from_jwks() {
// Empty JWKS — the kid we ask for can't possibly match. The error
// must classify as Stale so the caller refreshes JWKS and retries.
let jwks = JwkSet { keys: vec![] };
let header = jsonwebtoken::Header {
alg: jsonwebtoken::Algorithm::RS256,
kid: Some("kid-rotated-out".to_string()),
..Default::default()
};
let err = try_validate("ignored.token.value", &header, "kid-rotated-out", &jwks)
.expect_err("should fail");
match err {
ValidationError::Stale(s) => assert!(s.contains("no matching key")),
ValidationError::Permanent(s) => panic!("must be Stale, got Permanent: {s}"),
}
}
#[test] #[test]
fn is_write_detects_methods() { fn is_write_detects_methods() {
assert!(!is_write(&Method::GET)); assert!(!is_write(&Method::GET));