fix(m7.1): JWKS refresh-on-failure in auth middleware (#84)
CI / Check (push) Has been skipped
CI / Detect Changes (push) Successful in 3s
CI / Deploy Agent (push) Successful in 11m44s
CI / Deploy Dashboard (push) Successful in 13m1s
CI / Deploy Docs (push) Has been skipped
CI / Deploy MCP (push) Successful in 1m53s
CI / Check (push) Has been skipped
CI / Detect Changes (push) Successful in 3s
CI / Deploy Agent (push) Successful in 11m44s
CI / Deploy Dashboard (push) Successful in 13m1s
CI / Deploy Docs (push) Has been skipped
CI / Deploy MCP (push) Successful in 1m53s
This commit was merged in pull request #84.
This commit is contained in:
+96
-12
@@ -148,27 +148,83 @@ async fn validate_token(token: &str, state: &JwksState) -> Result<TenantContext,
|
|||||||
|
|
||||||
let kid = header
|
let kid = header
|
||||||
.kid
|
.kid
|
||||||
|
.clone()
|
||||||
.ok_or_else(|| "JWT missing kid header".to_string())?;
|
.ok_or_else(|| "JWT missing kid header".to_string())?;
|
||||||
|
|
||||||
let jwks = fetch_or_get_jwks(state).await?;
|
// First try against whatever's currently cached. If the kid isn't
|
||||||
|
// there or the signature doesn't verify, the cached JWKS is most
|
||||||
|
// likely stale (Keycloak rotated keys) — refresh once and retry before
|
||||||
|
// giving up. Without this every key rotation produces a silent 401
|
||||||
|
// storm that only goes away when the agent restarts.
|
||||||
|
let jwks = fetch_or_get_jwks(state, false).await?;
|
||||||
|
match try_validate(token, &header, &kid, &jwks) {
|
||||||
|
Ok(ctx) => Ok(ctx),
|
||||||
|
Err(ValidationError::Permanent(e)) => Err(e),
|
||||||
|
Err(ValidationError::Stale(reason)) => {
|
||||||
|
tracing::info!(
|
||||||
|
kid = %kid,
|
||||||
|
reason = %reason,
|
||||||
|
"JWKS appears stale — forcing refresh and retrying"
|
||||||
|
);
|
||||||
|
let jwks = fetch_or_get_jwks(state, true).await?;
|
||||||
|
try_validate(token, &header, &kid, &jwks).map_err(|e| match e {
|
||||||
|
ValidationError::Stale(s) | ValidationError::Permanent(s) => s,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
let jwk = jwks
|
#[derive(Debug)]
|
||||||
|
enum ValidationError {
|
||||||
|
/// Refresh-eligible: cached JWKS may be stale.
|
||||||
|
Stale(String),
|
||||||
|
/// Not refresh-eligible: the token is rejected regardless of JWKS freshness.
|
||||||
|
Permanent(String),
|
||||||
|
}
|
||||||
|
|
||||||
|
fn try_validate(
|
||||||
|
token: &str,
|
||||||
|
header: &jsonwebtoken::Header,
|
||||||
|
kid: &str,
|
||||||
|
jwks: &JwkSet,
|
||||||
|
) -> Result<TenantContext, ValidationError> {
|
||||||
|
let jwk = match jwks
|
||||||
.keys
|
.keys
|
||||||
.iter()
|
.iter()
|
||||||
.find(|k| k.common.key_id.as_deref() == Some(&kid))
|
.find(|k| k.common.key_id.as_deref() == Some(kid))
|
||||||
.ok_or_else(|| "no matching key found in JWKS".to_string())?;
|
{
|
||||||
|
Some(j) => j,
|
||||||
|
None => {
|
||||||
|
return Err(ValidationError::Stale(
|
||||||
|
"no matching key found in JWKS".to_string(),
|
||||||
|
))
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
let decoding_key =
|
let decoding_key = DecodingKey::from_jwk(jwk)
|
||||||
DecodingKey::from_jwk(jwk).map_err(|e| format!("failed to create decoding key: {e}"))?;
|
.map_err(|e| ValidationError::Permanent(format!("failed to create decoding key: {e}")))?;
|
||||||
|
|
||||||
let mut validation = Validation::new(header.alg);
|
let mut validation = Validation::new(header.alg);
|
||||||
validation.validate_exp = true;
|
validation.validate_exp = true;
|
||||||
validation.validate_aud = false;
|
validation.validate_aud = false;
|
||||||
|
|
||||||
let data = decode::<Claims>(token, &decoding_key, &validation)
|
let data = match decode::<Claims>(token, &decoding_key, &validation) {
|
||||||
.map_err(|e| format!("token validation failed: {e}"))?;
|
Ok(d) => d,
|
||||||
|
Err(e) => {
|
||||||
|
// Signature mismatch is the other refresh-eligible failure:
|
||||||
|
// the matching kid is present but the key bytes don't match.
|
||||||
|
// Everything else (expired, malformed, etc.) is permanent.
|
||||||
|
return Err(
|
||||||
|
if matches!(e.kind(), jsonwebtoken::errors::ErrorKind::InvalidSignature) {
|
||||||
|
ValidationError::Stale(format!("token validation failed: {e}"))
|
||||||
|
} else {
|
||||||
|
ValidationError::Permanent(format!("token validation failed: {e}"))
|
||||||
|
},
|
||||||
|
);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
claims_to_context(data.claims)
|
claims_to_context(data.claims).map_err(ValidationError::Permanent)
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Map the decoded JWT payload into the platform-wide `TenantContext`.
|
/// Map the decoded JWT payload into the platform-wide `TenantContext`.
|
||||||
@@ -198,14 +254,25 @@ fn claims_to_context(c: Claims) -> Result<TenantContext, String> {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn fetch_or_get_jwks(state: &JwksState) -> Result<JwkSet, String> {
|
async fn fetch_or_get_jwks(state: &JwksState, force: bool) -> Result<JwkSet, String> {
|
||||||
{
|
if !force {
|
||||||
let cached = state.jwks.read().await;
|
let cached = state.jwks.read().await;
|
||||||
if let Some(ref jwks) = *cached {
|
if let Some(ref jwks) = *cached {
|
||||||
return Ok(jwks.clone());
|
return Ok(jwks.clone());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Hold the write lock across the fetch so concurrent refreshers
|
||||||
|
// don't all hammer Keycloak when keys rotate. If another writer
|
||||||
|
// already populated a fresh JWKS while we were waiting (and we
|
||||||
|
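// weren't asked to force), use theirs. NOTE(review): forced callers skip
// this re-check, so each one still refetches in turn while holding the
// lock; the lock serializes forced refreshes, it does not deduplicate them.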
|
||||||
|
let mut cached = state.jwks.write().await;
|
||||||
|
if !force {
|
||||||
|
if let Some(ref jwks) = *cached {
|
||||||
|
return Ok(jwks.clone());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
let resp = reqwest::get(&state.jwks_url)
|
let resp = reqwest::get(&state.jwks_url)
|
||||||
.await
|
.await
|
||||||
.map_err(|e| format!("failed to fetch JWKS: {e}"))?;
|
.map_err(|e| format!("failed to fetch JWKS: {e}"))?;
|
||||||
@@ -215,7 +282,6 @@ async fn fetch_or_get_jwks(state: &JwksState) -> Result<JwkSet, String> {
|
|||||||
.await
|
.await
|
||||||
.map_err(|e| format!("failed to parse JWKS: {e}"))?;
|
.map_err(|e| format!("failed to parse JWKS: {e}"))?;
|
||||||
|
|
||||||
let mut cached = state.jwks.write().await;
|
|
||||||
*cached = Some(jwks.clone());
|
*cached = Some(jwks.clone());
|
||||||
|
|
||||||
Ok(jwks)
|
Ok(jwks)
|
||||||
@@ -293,6 +359,24 @@ mod tests {
|
|||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn try_validate_returns_stale_when_kid_missing_from_jwks() {
|
||||||
|
// Empty JWKS — the kid we ask for can't possibly match. The error
|
||||||
|
// must classify as Stale so the caller refreshes JWKS and retries.
|
||||||
|
let jwks = JwkSet { keys: vec![] };
|
||||||
|
let header = jsonwebtoken::Header {
|
||||||
|
alg: jsonwebtoken::Algorithm::RS256,
|
||||||
|
kid: Some("kid-rotated-out".to_string()),
|
||||||
|
..Default::default()
|
||||||
|
};
|
||||||
|
let err = try_validate("ignored.token.value", &header, "kid-rotated-out", &jwks)
|
||||||
|
.expect_err("should fail");
|
||||||
|
match err {
|
||||||
|
ValidationError::Stale(s) => assert!(s.contains("no matching key")),
|
||||||
|
ValidationError::Permanent(s) => panic!("must be Stale, got Permanent: {s}"),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn is_write_detects_methods() {
|
fn is_write_detects_methods() {
|
||||||
assert!(!is_write(&Method::GET));
|
assert!(!is_write(&Method::GET));
|
||||||
|
|||||||
Reference in New Issue
Block a user