feat: add Keycloak authentication for dashboard and API endpoints (#2)
Dashboard: OAuth2/OIDC login flow with PKCE, session-based auth middleware protecting all server function endpoints, check-auth server function for frontend auth state, login page gate in AppShell, user info in sidebar. Agent API: JWT validation middleware using Keycloak JWKS endpoint, conditionally enabled when KEYCLOAK_URL and KEYCLOAK_REALM are set. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com> Co-authored-by: Sharang Parnerkar <parnerkarsharang@gmail.com> Reviewed-on: #2
This commit was merged in pull request #2.
This commit is contained in:
@@ -0,0 +1,113 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
use axum::{
|
||||
extract::Request,
|
||||
middleware::Next,
|
||||
response::{IntoResponse, Response},
|
||||
};
|
||||
use jsonwebtoken::{decode, decode_header, jwk::JwkSet, DecodingKey, Validation};
|
||||
use reqwest::StatusCode;
|
||||
use serde::Deserialize;
|
||||
use tokio::sync::RwLock;
|
||||
|
||||
/// Cached JWKS from Keycloak for token validation.
|
||||
#[derive(Clone)]
|
||||
pub struct JwksState {
|
||||
pub jwks: Arc<RwLock<Option<JwkSet>>>,
|
||||
pub jwks_url: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct Claims {
|
||||
#[allow(dead_code)]
|
||||
sub: String,
|
||||
}
|
||||
|
||||
const PUBLIC_ENDPOINTS: &[&str] = &["/api/v1/health"];
|
||||
|
||||
/// Middleware that validates Bearer JWT tokens against Keycloak's JWKS.
|
||||
///
|
||||
/// Skips validation for health check endpoints.
|
||||
/// If `JwksState` is not present as an extension (keycloak not configured),
|
||||
/// all requests pass through.
|
||||
pub async fn require_jwt_auth(request: Request, next: Next) -> Response {
|
||||
let path = request.uri().path();
|
||||
|
||||
if PUBLIC_ENDPOINTS.contains(&path) {
|
||||
return next.run(request).await;
|
||||
}
|
||||
|
||||
let jwks_state = match request.extensions().get::<JwksState>() {
|
||||
Some(s) => s.clone(),
|
||||
None => return next.run(request).await,
|
||||
};
|
||||
|
||||
let auth_header = match request.headers().get("authorization") {
|
||||
Some(h) => h,
|
||||
None => return (StatusCode::UNAUTHORIZED, "Missing authorization header").into_response(),
|
||||
};
|
||||
|
||||
let token = match auth_header.to_str() {
|
||||
Ok(s) if s.starts_with("Bearer ") => &s[7..],
|
||||
_ => return (StatusCode::UNAUTHORIZED, "Invalid authorization header").into_response(),
|
||||
};
|
||||
|
||||
match validate_token(token, &jwks_state).await {
|
||||
Ok(()) => next.run(request).await,
|
||||
Err(e) => {
|
||||
tracing::warn!("JWT validation failed: {e}");
|
||||
(StatusCode::UNAUTHORIZED, "Invalid token").into_response()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn validate_token(token: &str, state: &JwksState) -> Result<(), String> {
|
||||
let header = decode_header(token).map_err(|e| format!("failed to decode JWT header: {e}"))?;
|
||||
|
||||
let kid = header
|
||||
.kid
|
||||
.ok_or_else(|| "JWT missing kid header".to_string())?;
|
||||
|
||||
let jwks = fetch_or_get_jwks(state).await?;
|
||||
|
||||
let jwk = jwks
|
||||
.keys
|
||||
.iter()
|
||||
.find(|k| k.common.key_id.as_deref() == Some(&kid))
|
||||
.ok_or_else(|| "no matching key found in JWKS".to_string())?;
|
||||
|
||||
let decoding_key =
|
||||
DecodingKey::from_jwk(jwk).map_err(|e| format!("failed to create decoding key: {e}"))?;
|
||||
|
||||
let mut validation = Validation::new(header.alg);
|
||||
validation.validate_exp = true;
|
||||
validation.validate_aud = false;
|
||||
|
||||
decode::<Claims>(token, &decoding_key, &validation)
|
||||
.map_err(|e| format!("token validation failed: {e}"))?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn fetch_or_get_jwks(state: &JwksState) -> Result<JwkSet, String> {
|
||||
{
|
||||
let cached = state.jwks.read().await;
|
||||
if let Some(ref jwks) = *cached {
|
||||
return Ok(jwks.clone());
|
||||
}
|
||||
}
|
||||
|
||||
let resp = reqwest::get(&state.jwks_url)
|
||||
.await
|
||||
.map_err(|e| format!("failed to fetch JWKS: {e}"))?;
|
||||
|
||||
let jwks: JwkSet = resp
|
||||
.json()
|
||||
.await
|
||||
.map_err(|e| format!("failed to parse JWKS: {e}"))?;
|
||||
|
||||
let mut cached = state.jwks.write().await;
|
||||
*cached = Some(jwks.clone());
|
||||
|
||||
Ok(jwks)
|
||||
}
|
||||
Reference in New Issue
Block a user