Compare commits
2 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 079f913024 | |||
| f583d0788c |
+12
-96
@@ -148,83 +148,27 @@ async fn validate_token(token: &str, state: &JwksState) -> Result<TenantContext,
|
|||||||
|
|
||||||
let kid = header
|
let kid = header
|
||||||
.kid
|
.kid
|
||||||
.clone()
|
|
||||||
.ok_or_else(|| "JWT missing kid header".to_string())?;
|
.ok_or_else(|| "JWT missing kid header".to_string())?;
|
||||||
|
|
||||||
// First try against whatever's currently cached. If the kid isn't
|
let jwks = fetch_or_get_jwks(state).await?;
|
||||||
// there or the signature doesn't verify, the cached JWKS is most
|
|
||||||
// likely stale (KC rotated keys) — refresh once and retry before
|
|
||||||
// giving up. Without this every key rotation produces a silent 401
|
|
||||||
// storm that only goes away when the agent restarts.
|
|
||||||
let jwks = fetch_or_get_jwks(state, false).await?;
|
|
||||||
match try_validate(token, &header, &kid, &jwks) {
|
|
||||||
Ok(ctx) => Ok(ctx),
|
|
||||||
Err(ValidationError::Permanent(e)) => Err(e),
|
|
||||||
Err(ValidationError::Stale(reason)) => {
|
|
||||||
tracing::info!(
|
|
||||||
kid = %kid,
|
|
||||||
reason = %reason,
|
|
||||||
"JWKS appears stale — forcing refresh and retrying"
|
|
||||||
);
|
|
||||||
let jwks = fetch_or_get_jwks(state, true).await?;
|
|
||||||
try_validate(token, &header, &kid, &jwks).map_err(|e| match e {
|
|
||||||
ValidationError::Stale(s) | ValidationError::Permanent(s) => s,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug)]
|
let jwk = jwks
|
||||||
enum ValidationError {
|
|
||||||
/// Refresh-eligible: cached JWKS may be stale.
|
|
||||||
Stale(String),
|
|
||||||
/// Refusing the token regardless of JWKS freshness.
|
|
||||||
Permanent(String),
|
|
||||||
}
|
|
||||||
|
|
||||||
fn try_validate(
|
|
||||||
token: &str,
|
|
||||||
header: &jsonwebtoken::Header,
|
|
||||||
kid: &str,
|
|
||||||
jwks: &JwkSet,
|
|
||||||
) -> Result<TenantContext, ValidationError> {
|
|
||||||
let jwk = match jwks
|
|
||||||
.keys
|
.keys
|
||||||
.iter()
|
.iter()
|
||||||
.find(|k| k.common.key_id.as_deref() == Some(kid))
|
.find(|k| k.common.key_id.as_deref() == Some(&kid))
|
||||||
{
|
.ok_or_else(|| "no matching key found in JWKS".to_string())?;
|
||||||
Some(j) => j,
|
|
||||||
None => {
|
|
||||||
return Err(ValidationError::Stale(
|
|
||||||
"no matching key found in JWKS".to_string(),
|
|
||||||
))
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
let decoding_key = DecodingKey::from_jwk(jwk)
|
let decoding_key =
|
||||||
.map_err(|e| ValidationError::Permanent(format!("failed to create decoding key: {e}")))?;
|
DecodingKey::from_jwk(jwk).map_err(|e| format!("failed to create decoding key: {e}"))?;
|
||||||
|
|
||||||
let mut validation = Validation::new(header.alg);
|
let mut validation = Validation::new(header.alg);
|
||||||
validation.validate_exp = true;
|
validation.validate_exp = true;
|
||||||
validation.validate_aud = false;
|
validation.validate_aud = false;
|
||||||
|
|
||||||
let data = match decode::<Claims>(token, &decoding_key, &validation) {
|
let data = decode::<Claims>(token, &decoding_key, &validation)
|
||||||
Ok(d) => d,
|
.map_err(|e| format!("token validation failed: {e}"))?;
|
||||||
Err(e) => {
|
|
||||||
// Signature mismatch is the other refresh-eligible failure:
|
|
||||||
// the matching kid is present but the key bytes don't match.
|
|
||||||
// Everything else (expired, malformed, etc.) is permanent.
|
|
||||||
return Err(
|
|
||||||
if matches!(e.kind(), jsonwebtoken::errors::ErrorKind::InvalidSignature) {
|
|
||||||
ValidationError::Stale(format!("token validation failed: {e}"))
|
|
||||||
} else {
|
|
||||||
ValidationError::Permanent(format!("token validation failed: {e}"))
|
|
||||||
},
|
|
||||||
);
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
claims_to_context(data.claims).map_err(ValidationError::Permanent)
|
claims_to_context(data.claims)
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Map the decoded JWT payload into the platform-wide `TenantContext`.
|
/// Map the decoded JWT payload into the platform-wide `TenantContext`.
|
||||||
@@ -254,25 +198,14 @@ fn claims_to_context(c: Claims) -> Result<TenantContext, String> {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn fetch_or_get_jwks(state: &JwksState, force: bool) -> Result<JwkSet, String> {
|
async fn fetch_or_get_jwks(state: &JwksState) -> Result<JwkSet, String> {
|
||||||
if !force {
|
{
|
||||||
let cached = state.jwks.read().await;
|
let cached = state.jwks.read().await;
|
||||||
if let Some(ref jwks) = *cached {
|
if let Some(ref jwks) = *cached {
|
||||||
return Ok(jwks.clone());
|
return Ok(jwks.clone());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Hold the write lock across the fetch so concurrent refreshers
|
|
||||||
// don't all hammer Keycloak when keys rotate. If another writer
|
|
||||||
// already populated a fresh JWKS while we were waiting (and we
|
|
||||||
// weren't asked to force), use theirs.
|
|
||||||
let mut cached = state.jwks.write().await;
|
|
||||||
if !force {
|
|
||||||
if let Some(ref jwks) = *cached {
|
|
||||||
return Ok(jwks.clone());
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
let resp = reqwest::get(&state.jwks_url)
|
let resp = reqwest::get(&state.jwks_url)
|
||||||
.await
|
.await
|
||||||
.map_err(|e| format!("failed to fetch JWKS: {e}"))?;
|
.map_err(|e| format!("failed to fetch JWKS: {e}"))?;
|
||||||
@@ -282,6 +215,7 @@ async fn fetch_or_get_jwks(state: &JwksState, force: bool) -> Result<JwkSet, Str
|
|||||||
.await
|
.await
|
||||||
.map_err(|e| format!("failed to parse JWKS: {e}"))?;
|
.map_err(|e| format!("failed to parse JWKS: {e}"))?;
|
||||||
|
|
||||||
|
let mut cached = state.jwks.write().await;
|
||||||
*cached = Some(jwks.clone());
|
*cached = Some(jwks.clone());
|
||||||
|
|
||||||
Ok(jwks)
|
Ok(jwks)
|
||||||
@@ -359,24 +293,6 @@ mod tests {
|
|||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn try_validate_returns_stale_when_kid_missing_from_jwks() {
|
|
||||||
// Empty JWKS — the kid we ask for can't possibly match. The error
|
|
||||||
// must classify as Stale so the caller refreshes JWKS and retries.
|
|
||||||
let jwks = JwkSet { keys: vec![] };
|
|
||||||
let header = jsonwebtoken::Header {
|
|
||||||
alg: jsonwebtoken::Algorithm::RS256,
|
|
||||||
kid: Some("kid-rotated-out".to_string()),
|
|
||||||
..Default::default()
|
|
||||||
};
|
|
||||||
let err = try_validate("ignored.token.value", &header, "kid-rotated-out", &jwks)
|
|
||||||
.expect_err("should fail");
|
|
||||||
match err {
|
|
||||||
ValidationError::Stale(s) => assert!(s.contains("no matching key")),
|
|
||||||
ValidationError::Permanent(s) => panic!("must be Stale, got Permanent: {s}"),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn is_write_detects_methods() {
|
fn is_write_detects_methods() {
|
||||||
assert!(!is_write(&Method::GET));
|
assert!(!is_write(&Method::GET));
|
||||||
|
|||||||
Reference in New Issue
Block a user