feat: add Keycloak authentication for dashboard and API endpoints (#2)
Dashboard: OAuth2/OIDC login flow with PKCE, session-based auth middleware protecting all server function endpoints, check-auth server function for frontend auth state, login page gate in AppShell, user info in sidebar. Agent API: JWT validation middleware using Keycloak JWKS endpoint, conditionally enabled when KEYCLOAK_URL and KEYCLOAK_REALM are set. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com> Co-authored-by: Sharang Parnerkar <parnerkarsharang@gmail.com> Reviewed-on: #2
This commit was merged in pull request #2.
This commit is contained in:
@@ -35,3 +35,4 @@ walkdir = "2"
|
||||
base64 = "0.22"
|
||||
urlencoding = "2"
|
||||
futures-util = "0.3"
|
||||
jsonwebtoken = "9"
|
||||
|
||||
113
compliance-agent/src/api/auth_middleware.rs
Normal file
113
compliance-agent/src/api/auth_middleware.rs
Normal file
@@ -0,0 +1,113 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
use axum::{
|
||||
extract::Request,
|
||||
middleware::Next,
|
||||
response::{IntoResponse, Response},
|
||||
};
|
||||
use jsonwebtoken::{decode, decode_header, jwk::JwkSet, DecodingKey, Validation};
|
||||
use reqwest::StatusCode;
|
||||
use serde::Deserialize;
|
||||
use tokio::sync::RwLock;
|
||||
|
||||
/// Cached JWKS from Keycloak for token validation.
|
||||
#[derive(Clone)]
|
||||
pub struct JwksState {
|
||||
pub jwks: Arc<RwLock<Option<JwkSet>>>,
|
||||
pub jwks_url: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct Claims {
|
||||
#[allow(dead_code)]
|
||||
sub: String,
|
||||
}
|
||||
|
||||
const PUBLIC_ENDPOINTS: &[&str] = &["/api/v1/health"];
|
||||
|
||||
/// Middleware that validates Bearer JWT tokens against Keycloak's JWKS.
|
||||
///
|
||||
/// Skips validation for health check endpoints.
|
||||
/// If `JwksState` is not present as an extension (keycloak not configured),
|
||||
/// all requests pass through.
|
||||
pub async fn require_jwt_auth(request: Request, next: Next) -> Response {
|
||||
let path = request.uri().path();
|
||||
|
||||
if PUBLIC_ENDPOINTS.contains(&path) {
|
||||
return next.run(request).await;
|
||||
}
|
||||
|
||||
let jwks_state = match request.extensions().get::<JwksState>() {
|
||||
Some(s) => s.clone(),
|
||||
None => return next.run(request).await,
|
||||
};
|
||||
|
||||
let auth_header = match request.headers().get("authorization") {
|
||||
Some(h) => h,
|
||||
None => return (StatusCode::UNAUTHORIZED, "Missing authorization header").into_response(),
|
||||
};
|
||||
|
||||
let token = match auth_header.to_str() {
|
||||
Ok(s) if s.starts_with("Bearer ") => &s[7..],
|
||||
_ => return (StatusCode::UNAUTHORIZED, "Invalid authorization header").into_response(),
|
||||
};
|
||||
|
||||
match validate_token(token, &jwks_state).await {
|
||||
Ok(()) => next.run(request).await,
|
||||
Err(e) => {
|
||||
tracing::warn!("JWT validation failed: {e}");
|
||||
(StatusCode::UNAUTHORIZED, "Invalid token").into_response()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn validate_token(token: &str, state: &JwksState) -> Result<(), String> {
|
||||
let header = decode_header(token).map_err(|e| format!("failed to decode JWT header: {e}"))?;
|
||||
|
||||
let kid = header
|
||||
.kid
|
||||
.ok_or_else(|| "JWT missing kid header".to_string())?;
|
||||
|
||||
let jwks = fetch_or_get_jwks(state).await?;
|
||||
|
||||
let jwk = jwks
|
||||
.keys
|
||||
.iter()
|
||||
.find(|k| k.common.key_id.as_deref() == Some(&kid))
|
||||
.ok_or_else(|| "no matching key found in JWKS".to_string())?;
|
||||
|
||||
let decoding_key =
|
||||
DecodingKey::from_jwk(jwk).map_err(|e| format!("failed to create decoding key: {e}"))?;
|
||||
|
||||
let mut validation = Validation::new(header.alg);
|
||||
validation.validate_exp = true;
|
||||
validation.validate_aud = false;
|
||||
|
||||
decode::<Claims>(token, &decoding_key, &validation)
|
||||
.map_err(|e| format!("token validation failed: {e}"))?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn fetch_or_get_jwks(state: &JwksState) -> Result<JwkSet, String> {
|
||||
{
|
||||
let cached = state.jwks.read().await;
|
||||
if let Some(ref jwks) = *cached {
|
||||
return Ok(jwks.clone());
|
||||
}
|
||||
}
|
||||
|
||||
let resp = reqwest::get(&state.jwks_url)
|
||||
.await
|
||||
.map_err(|e| format!("failed to fetch JWKS: {e}"))?;
|
||||
|
||||
let jwks: JwkSet = resp
|
||||
.json()
|
||||
.await
|
||||
.map_err(|e| format!("failed to parse JWKS: {e}"))?;
|
||||
|
||||
let mut cached = state.jwks.write().await;
|
||||
*cached = Some(jwks.clone());
|
||||
|
||||
Ok(jwks)
|
||||
}
|
||||
@@ -1,3 +1,4 @@
|
||||
pub mod auth_middleware;
|
||||
pub mod handlers;
|
||||
pub mod routes;
|
||||
pub mod server;
|
||||
|
||||
@@ -1,19 +1,37 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
use axum::Extension;
|
||||
use axum::{middleware, Extension};
|
||||
use tokio::sync::RwLock;
|
||||
use tower_http::cors::CorsLayer;
|
||||
use tower_http::trace::TraceLayer;
|
||||
|
||||
use crate::agent::ComplianceAgent;
|
||||
use crate::api::auth_middleware::{require_jwt_auth, JwksState};
|
||||
use crate::api::routes;
|
||||
use crate::error::AgentError;
|
||||
|
||||
pub async fn start_api_server(agent: ComplianceAgent, port: u16) -> Result<(), AgentError> {
|
||||
let app = routes::build_router()
|
||||
.layer(Extension(Arc::new(agent)))
|
||||
let mut app = routes::build_router()
|
||||
.layer(Extension(Arc::new(agent.clone())))
|
||||
.layer(CorsLayer::permissive())
|
||||
.layer(TraceLayer::new_for_http());
|
||||
|
||||
if let (Some(kc_url), Some(kc_realm)) =
|
||||
(&agent.config.keycloak_url, &agent.config.keycloak_realm)
|
||||
{
|
||||
let jwks_url = format!("{kc_url}/realms/{kc_realm}/protocol/openid-connect/certs");
|
||||
let jwks_state = JwksState {
|
||||
jwks: Arc::new(RwLock::new(None)),
|
||||
jwks_url,
|
||||
};
|
||||
tracing::info!("Keycloak JWT auth enabled for realm '{kc_realm}'");
|
||||
app = app
|
||||
.layer(Extension(jwks_state))
|
||||
.layer(middleware::from_fn(require_jwt_auth));
|
||||
} else {
|
||||
tracing::warn!("Keycloak not configured - API endpoints are unprotected");
|
||||
}
|
||||
|
||||
let addr = format!("0.0.0.0:{port}");
|
||||
let listener = tokio::net::TcpListener::bind(&addr)
|
||||
.await
|
||||
|
||||
@@ -45,5 +45,7 @@ pub fn load_config() -> Result<AgentConfig, AgentError> {
|
||||
.unwrap_or_else(|| "0 0 0 * * *".to_string()),
|
||||
git_clone_base_path: env_var_opt("GIT_CLONE_BASE_PATH")
|
||||
.unwrap_or_else(|| "/tmp/compliance-scanner/repos".to_string()),
|
||||
keycloak_url: env_var_opt("KEYCLOAK_URL"),
|
||||
keycloak_realm: env_var_opt("KEYCLOAK_REALM"),
|
||||
})
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user