From 0cb06d3d6dacffc7898da354c4b6165343e051c8 Mon Sep 17 00:00:00 2001 From: Sharang Parnerkar Date: Sat, 7 Mar 2026 23:50:56 +0000 Subject: [PATCH] feat: add Keycloak authentication for dashboard and API endpoints (#2) Dashboard: OAuth2/OIDC login flow with PKCE, session-based auth middleware protecting all server function endpoints, check-auth server function for frontend auth state, login page gate in AppShell, user info in sidebar. Agent API: JWT validation middleware using Keycloak JWKS endpoint, conditionally enabled when KEYCLOAK_URL and KEYCLOAK_REALM are set. Co-Authored-By: Claude Opus 4.6 Co-authored-by: Sharang Parnerkar Reviewed-on: https://gitea.meghsakha.com/sharang/compliance-scanner-agent/pulls/2 --- Cargo.lock | 79 ++++++ compliance-agent/Cargo.toml | 1 + compliance-agent/src/api/auth_middleware.rs | 113 +++++++++ compliance-agent/src/api/mod.rs | 1 + compliance-agent/src/api/server.rs | 24 +- compliance-agent/src/config.rs | 2 + compliance-core/src/config.rs | 2 + compliance-core/src/models/auth.rs | 14 ++ compliance-core/src/models/mod.rs | 2 + compliance-dashboard/Cargo.toml | 12 + .../src/components/app_shell.rs | 52 +++- .../src/components/sidebar.rs | 29 ++- .../src/infrastructure/auth.rs | 228 ++++++++++++++++++ .../src/infrastructure/auth_check.rs | 32 +++ .../src/infrastructure/auth_middleware.rs | 33 +++ .../src/infrastructure/error.rs | 11 + .../src/infrastructure/keycloak_config.rs | 56 +++++ .../src/infrastructure/mod.rs | 15 +- .../src/infrastructure/server.rs | 28 ++- .../src/infrastructure/server_state.rs | 2 + .../src/infrastructure/user_state.rs | 18 ++ 21 files changed, 741 insertions(+), 13 deletions(-) create mode 100644 compliance-agent/src/api/auth_middleware.rs create mode 100644 compliance-core/src/models/auth.rs create mode 100644 compliance-dashboard/src/infrastructure/auth.rs create mode 100644 compliance-dashboard/src/infrastructure/auth_check.rs create mode 100644 compliance-dashboard/src/infrastructure/auth_middleware.rs create mode 100644 compliance-dashboard/src/infrastructure/keycloak_config.rs create mode 100644 compliance-dashboard/src/infrastructure/user_state.rs diff --git a/Cargo.lock b/Cargo.lock index a7a9f22..098d800 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -555,6 +555,7 @@ dependencies = [ "git2", "hex", "hmac", + "jsonwebtoken", "mongodb", "octocrab", "regex", @@ -595,6 +596,7 @@ name = "compliance-dashboard" version = "0.1.0" dependencies = [ "axum", + "base64", "chrono", "compliance-core", "dioxus", @@ -605,14 +607,19 @@ dependencies = [ "dotenvy", "gloo-timers", "mongodb", + "rand 0.9.2", "reqwest", "secrecy", "serde", "serde_json", + "sha2", "thiserror 2.0.18", + "time", "tokio", "tower-http", + "tower-sessions", "tracing", + "url", "web-sys", ] @@ -792,7 +799,12 @@ version = "0.18.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "4ddef33a339a91ea89fb53151bd0a4689cfce27055c291dfa69945475d22c747" dependencies = [ + "base64", + "hmac", "percent-encoding", + "rand 0.8.5", + "sha2", + "subtle", "time", "version_check", ] @@ -5228,6 +5240,22 @@ dependencies = [ "tracing", ] +[[package]] +name = "tower-cookies" +version = "0.11.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "151b5a3e3c45df17466454bb74e9ecedecc955269bdedbf4d150dfa393b55a36" +dependencies = [ + "axum-core", + "cookie", + "futures-util", + "http", + "parking_lot", + "pin-project-lite", + "tower-layer", + "tower-service", +] + [[package]] name = "tower-http" version = "0.6.8" @@ -5268,6 +5296,57 @@ version = "0.3.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8df9b6e13f2d32c91b9bd719c00d1958837bc7dec474d94952798cc8e69eeec3" +[[package]] +name = "tower-sessions" +version = "0.15.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "518dca34b74a17cadfcee06e616a09d2bd0c3984eff1769e1e76d58df978fc78" +dependencies = [ + "async-trait", + "http", + "time", + "tokio", + "tower-cookies", + "tower-layer", + "tower-service", + "tower-sessions-core", + "tower-sessions-memory-store", + "tracing", +] + +[[package]] +name = "tower-sessions-core" +version = "0.15.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "568531ec3dfcf3ffe493de1958ae5662a0284ac5d767476ecdb6a34ff8c6b06c" +dependencies = [ + "async-trait", + "axum-core", + "base64", + "futures", + "http", + "parking_lot", + "rand 0.9.2", + "serde", + "serde_json", + "thiserror 2.0.18", + "time", + "tokio", + "tracing", +] + +[[package]] +name = "tower-sessions-memory-store" +version = "0.15.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "713fabf882b6560a831e2bbed6204048b35bdd60e50bbb722902c74f8df33460" +dependencies = [ + "async-trait", + "time", + "tokio", + "tower-sessions-core", +] + [[package]] name = "tracing" version = "0.1.44" diff --git a/compliance-agent/Cargo.toml b/compliance-agent/Cargo.toml index 7349248..0c6af63 100644 --- a/compliance-agent/Cargo.toml +++ b/compliance-agent/Cargo.toml @@ -35,3 +35,4 @@ walkdir = "2" base64 = "0.22" urlencoding = "2" futures-util = "0.3" +jsonwebtoken = "9" diff --git a/compliance-agent/src/api/auth_middleware.rs b/compliance-agent/src/api/auth_middleware.rs new file mode 100644 index 0000000..90e3a49 --- /dev/null +++ b/compliance-agent/src/api/auth_middleware.rs @@ -0,0 +1,113 @@ +use std::sync::Arc; + +use axum::{ + extract::Request, + middleware::Next, + response::{IntoResponse, Response}, +}; +use jsonwebtoken::{decode, decode_header, jwk::JwkSet, DecodingKey, Validation}; +use reqwest::StatusCode; +use serde::Deserialize; +use tokio::sync::RwLock; + +/// Cached JWKS from Keycloak for token validation. +#[derive(Clone)] +pub struct JwksState { + pub jwks: Arc>>, + pub jwks_url: String, +} + +#[derive(Debug, Deserialize)] +struct Claims { + #[allow(dead_code)] + sub: String, +} + +const PUBLIC_ENDPOINTS: &[&str] = &["/api/v1/health"]; + +/// Middleware that validates Bearer JWT tokens against Keycloak's JWKS. +/// +/// Skips validation for health check endpoints. +/// If `JwksState` is not present as an extension (keycloak not configured), +/// all requests pass through. +pub async fn require_jwt_auth(request: Request, next: Next) -> Response { + let path = request.uri().path(); + + if PUBLIC_ENDPOINTS.contains(&path) { + return next.run(request).await; + } + + let jwks_state = match request.extensions().get::() { + Some(s) => s.clone(), + None => return next.run(request).await, + }; + + let auth_header = match request.headers().get("authorization") { + Some(h) => h, + None => return (StatusCode::UNAUTHORIZED, "Missing authorization header").into_response(), + }; + + let token = match auth_header.to_str() { + Ok(s) if s.starts_with("Bearer ") => &s[7..], + _ => return (StatusCode::UNAUTHORIZED, "Invalid authorization header").into_response(), + }; + + match validate_token(token, &jwks_state).await { + Ok(()) => next.run(request).await, + Err(e) => { + tracing::warn!("JWT validation failed: {e}"); + (StatusCode::UNAUTHORIZED, "Invalid token").into_response() + } + } +} + +async fn validate_token(token: &str, state: &JwksState) -> Result<(), String> { + let header = decode_header(token).map_err(|e| format!("failed to decode JWT header: {e}"))?; + + let kid = header + .kid + .ok_or_else(|| "JWT missing kid header".to_string())?; + + let jwks = fetch_or_get_jwks(state).await?; + + let jwk = jwks + .keys + .iter() + .find(|k| k.common.key_id.as_deref() == Some(&kid)) + .ok_or_else(|| "no matching key found in JWKS".to_string())?; + + let decoding_key = + DecodingKey::from_jwk(jwk).map_err(|e| format!("failed to create decoding key: {e}"))?; + + let mut validation = Validation::new(header.alg); + validation.validate_exp = true; + validation.validate_aud = false; + + decode::(token, &decoding_key, &validation) + .map_err(|e| format!("token validation failed: {e}"))?; + + Ok(()) +} + +async fn fetch_or_get_jwks(state: &JwksState) -> Result { + { + let cached = state.jwks.read().await; + if let Some(ref jwks) = *cached { + return Ok(jwks.clone()); + } + } + + let resp = reqwest::get(&state.jwks_url) + .await + .map_err(|e| format!("failed to fetch JWKS: {e}"))?; + + let jwks: JwkSet = resp + .json() + .await + .map_err(|e| format!("failed to parse JWKS: {e}"))?; + + let mut cached = state.jwks.write().await; + *cached = Some(jwks.clone()); + + Ok(jwks) +} diff --git a/compliance-agent/src/api/mod.rs b/compliance-agent/src/api/mod.rs index 43a142c..f969c3b 100644 --- a/compliance-agent/src/api/mod.rs +++ b/compliance-agent/src/api/mod.rs @@ -1,3 +1,4 @@ +pub mod auth_middleware; pub mod handlers; pub mod routes; pub mod server; diff --git a/compliance-agent/src/api/server.rs b/compliance-agent/src/api/server.rs index f6c1979..3038083 100644 --- a/compliance-agent/src/api/server.rs +++ b/compliance-agent/src/api/server.rs @@ -1,19 +1,37 @@ use std::sync::Arc; -use axum::Extension; +use axum::{middleware, Extension}; +use tokio::sync::RwLock; use tower_http::cors::CorsLayer; use tower_http::trace::TraceLayer; use crate::agent::ComplianceAgent; +use crate::api::auth_middleware::{require_jwt_auth, JwksState}; use crate::api::routes; use crate::error::AgentError; pub async fn start_api_server(agent: ComplianceAgent, port: u16) -> Result<(), AgentError> { - let app = routes::build_router() - .layer(Extension(Arc::new(agent))) + let mut app = routes::build_router() + .layer(Extension(Arc::new(agent.clone()))) .layer(CorsLayer::permissive()) .layer(TraceLayer::new_for_http()); + if let (Some(kc_url), Some(kc_realm)) = + (&agent.config.keycloak_url, &agent.config.keycloak_realm) + { + let jwks_url = format!("{kc_url}/realms/{kc_realm}/protocol/openid-connect/certs"); + let jwks_state = JwksState { + jwks: Arc::new(RwLock::new(None)), + jwks_url, + }; + tracing::info!("Keycloak JWT auth enabled for realm '{kc_realm}'"); + app = app + .layer(Extension(jwks_state)) + .layer(middleware::from_fn(require_jwt_auth)); + } else { + tracing::warn!("Keycloak not configured - API endpoints are unprotected"); + } + let addr = format!("0.0.0.0:{port}"); let listener = tokio::net::TcpListener::bind(&addr) .await diff --git a/compliance-agent/src/config.rs b/compliance-agent/src/config.rs index 06bf03d..612fc7d 100644 --- a/compliance-agent/src/config.rs +++ b/compliance-agent/src/config.rs @@ -45,5 +45,7 @@ pub fn load_config() -> Result { .unwrap_or_else(|| "0 0 0 * * *".to_string()), git_clone_base_path: env_var_opt("GIT_CLONE_BASE_PATH") .unwrap_or_else(|| "/tmp/compliance-scanner/repos".to_string()), + keycloak_url: env_var_opt("KEYCLOAK_URL"), + keycloak_realm: env_var_opt("KEYCLOAK_REALM"), }) } diff --git a/compliance-core/src/config.rs b/compliance-core/src/config.rs index 3f38740..aba5725 100644 --- a/compliance-core/src/config.rs +++ b/compliance-core/src/config.rs @@ -24,6 +24,8 @@ pub struct AgentConfig { pub scan_schedule: String, pub cve_monitor_schedule: String, pub git_clone_base_path: String, + pub keycloak_url: Option, + pub keycloak_realm: Option, } #[derive(Clone, Debug, Serialize, Deserialize)] diff --git a/compliance-core/src/models/auth.rs b/compliance-core/src/models/auth.rs new file mode 100644 index 0000000..b0c935b --- /dev/null +++ b/compliance-core/src/models/auth.rs @@ -0,0 +1,14 @@ +use serde::{Deserialize, Serialize}; + +/// Authentication state returned by the `check_auth` server function. +/// +/// When no valid session exists, `authenticated` is `false` and all +/// other fields are empty strings. +#[derive(Debug, Clone, Serialize, Deserialize, Default, PartialEq)] +pub struct AuthInfo { + pub authenticated: bool, + pub sub: String, + pub email: String, + pub name: String, + pub avatar_url: String, +} diff --git a/compliance-core/src/models/mod.rs b/compliance-core/src/models/mod.rs index c695d78..7718a57 100644 --- a/compliance-core/src/models/mod.rs +++ b/compliance-core/src/models/mod.rs @@ -1,3 +1,4 @@ +pub mod auth; pub mod chat; pub mod cve; pub mod dast; @@ -9,6 +10,7 @@ pub mod repository; pub mod sbom; pub mod scan; +pub use auth::AuthInfo; pub use chat::{ChatMessage, ChatRequest, ChatResponse, SourceReference}; pub use cve::{CveAlert, CveSource}; pub use dast::{ diff --git a/compliance-dashboard/Cargo.toml b/compliance-dashboard/Cargo.toml index 9f1366c..4825d66 100644 --- a/compliance-dashboard/Cargo.toml +++ b/compliance-dashboard/Cargo.toml @@ -27,6 +27,12 @@ server = [ "dep:dioxus-cli-config", "dep:dioxus-fullstack", "dep:tokio", + "dep:tower-sessions", + "dep:time", + "dep:rand", + "dep:url", + "dep:sha2", + "dep:base64", ] [dependencies] @@ -54,3 +60,9 @@ dotenvy = { version = "0.15", optional = true } tokio = { workspace = true, optional = true } dioxus-cli-config = { version = "=0.7.3", optional = true } dioxus-fullstack = { version = "=0.7.3", optional = true } +tower-sessions = { version = "0.15", default-features = false, features = ["axum-core", "memory-store", "signed"], optional = true } +time = { version = "0.3", default-features = false, optional = true } +rand = { version = "0.9", optional = true } +url = { version = "2", optional = true } +sha2 = { workspace = true, optional = true } +base64 = { version = "0.22", optional = true } diff --git a/compliance-dashboard/src/components/app_shell.rs b/compliance-dashboard/src/components/app_shell.rs index f165d5d..1f003ec 100644 --- a/compliance-dashboard/src/components/app_shell.rs +++ b/compliance-dashboard/src/components/app_shell.rs @@ -3,17 +3,57 @@ use dioxus::prelude::*; use crate::app::Route; use crate::components::sidebar::Sidebar; use crate::components::toast::{ToastContainer, Toasts}; +use crate::infrastructure::auth_check::check_auth; #[component] pub fn AppShell() -> Element { use_context_provider(Toasts::new); - rsx! { - div { class: "app-shell", - Sidebar {} - main { class: "main-content", - Outlet:: {} + + let auth = use_server_future(check_auth)?; + + match auth() { + Some(Ok(info)) if info.authenticated => { + use_context_provider(|| Signal::new(info.clone())); + rsx! { + div { class: "app-shell", + Sidebar {} + main { class: "main-content", + Outlet:: {} + } + ToastContainer {} + } + } + } + Some(Ok(_)) => { + rsx! { LoginPage {} } + } + Some(Err(e)) => { + tracing::error!("Auth check failed: {e}"); + rsx! { LoginPage {} } + } + None => { + rsx! { + div { class: "flex items-center justify-center h-screen bg-gray-950", + p { class: "text-gray-400", "Loading..." } + } + } + } + } +} + +#[component] +fn LoginPage() -> Element { + rsx! { + div { class: "flex items-center justify-center h-screen bg-gray-950", + div { class: "text-center", + h1 { class: "text-3xl font-bold text-white mb-4", "Compliance Scanner" } + p { class: "text-gray-400 mb-8", "Sign in to access the dashboard" } + a { + href: "/auth", + class: "px-6 py-3 bg-blue-600 text-white rounded-lg hover:bg-blue-500 transition-colors font-medium", + "Sign in with Keycloak" + } } - ToastContainer {} } } } diff --git a/compliance-dashboard/src/components/sidebar.rs b/compliance-dashboard/src/components/sidebar.rs index bb9ab69..42ba63e 100644 --- a/compliance-dashboard/src/components/sidebar.rs +++ b/compliance-dashboard/src/components/sidebar.rs @@ -1,3 +1,4 @@ +use compliance_core::models::auth::AuthInfo; use dioxus::prelude::*; use dioxus_free_icons::icons::bs_icons::*; use dioxus_free_icons::Icon; @@ -114,8 +115,32 @@ pub fn Sidebar() -> Element { Icon { icon: BsChevronLeft, width: 14, height: 14 } } } - if !collapsed() { - div { class: "sidebar-footer", "v0.1.0" } + { + let auth_info = use_context::>(); + let info = auth_info(); + let initials = info.name.chars().next().unwrap_or('U').to_uppercase().to_string(); + rsx! { + div { class: "sidebar-user", + div { class: "user-avatar", + if info.avatar_url.is_empty() { + span { class: "avatar-initials", "{initials}" } + } else { + img { src: "{info.avatar_url}", alt: "avatar", class: "avatar-img" } + } + } + if !collapsed() { + div { class: "user-info", + span { class: "user-name", "{info.name}" } + a { + href: "/logout", + class: "logout-link", + Icon { icon: BsBoxArrowRight, width: 14, height: 14 } + " Logout" + } + } + } + } + } } } } diff --git a/compliance-dashboard/src/infrastructure/auth.rs b/compliance-dashboard/src/infrastructure/auth.rs new file mode 100644 index 0000000..eac180a --- /dev/null +++ b/compliance-dashboard/src/infrastructure/auth.rs @@ -0,0 +1,228 @@ +use std::{ + collections::HashMap, + sync::{Arc, RwLock}, +}; + +use axum::{ + extract::Query, + response::{IntoResponse, Redirect}, + Extension, +}; +use rand::Rng; +use tower_sessions::Session; +use url::Url; + +use super::{ + error::DashboardError, + server_state::ServerState, + user_state::{User, UserStateInner}, +}; + +pub const LOGGED_IN_USER_SESS_KEY: &str = "logged-in-user"; + +#[derive(Debug, Clone)] +pub(crate) struct PendingOAuthEntry { + pub(crate) redirect_url: Option, + pub(crate) code_verifier: String, +} + +#[derive(Debug, Clone, Default)] +pub struct PendingOAuthStore(Arc>>); + +impl PendingOAuthStore { + pub(crate) fn insert(&self, state: String, entry: PendingOAuthEntry) { + #[allow(clippy::expect_used)] + self.0 + .write() + .expect("pending oauth store lock poisoned") + .insert(state, entry); + } + + pub(crate) fn take(&self, state: &str) -> Option { + #[allow(clippy::expect_used)] + self.0 + .write() + .expect("pending oauth store lock poisoned") + .remove(state) + } +} + +pub(crate) fn generate_state() -> String { + let bytes: [u8; 32] = rand::rng().random(); + bytes.iter().fold(String::with_capacity(64), |mut acc, b| { + use std::fmt::Write; + let _ = write!(acc, "{b:02x}"); + acc + }) +} + +pub(crate) fn generate_code_verifier() -> String { + use base64::{engine::general_purpose::URL_SAFE_NO_PAD, Engine}; + let bytes: [u8; 32] = rand::rng().random(); + URL_SAFE_NO_PAD.encode(bytes) +} + +pub(crate) fn derive_code_challenge(verifier: &str) -> String { + use base64::{engine::general_purpose::URL_SAFE_NO_PAD, Engine}; + use sha2::{Digest, Sha256}; + let digest = Sha256::digest(verifier.as_bytes()); + URL_SAFE_NO_PAD.encode(digest) +} + +#[axum::debug_handler] +pub async fn auth_login( + Extension(state): Extension, + Extension(pending): Extension, + Query(params): Query>, +) -> Result { + let kc = state.keycloak; + let csrf_state = generate_state(); + let code_verifier = generate_code_verifier(); + let code_challenge = derive_code_challenge(&code_verifier); + + let redirect_url = params.get("redirect_url").cloned(); + pending.insert( + csrf_state.clone(), + PendingOAuthEntry { + redirect_url, + code_verifier, + }, + ); + + let mut url = Url::parse(&kc.auth_endpoint()) + .map_err(|e| DashboardError::Other(format!("invalid auth endpoint URL: {e}")))?; + + url.query_pairs_mut() + .append_pair("client_id", &kc.client_id) + .append_pair("redirect_uri", &kc.redirect_uri) + .append_pair("response_type", "code") + .append_pair("scope", "openid profile email") + .append_pair("state", &csrf_state) + .append_pair("code_challenge", &code_challenge) + .append_pair("code_challenge_method", "S256"); + + Ok(Redirect::temporary(url.as_str())) +} + +#[derive(serde::Deserialize)] +struct TokenResponse { + access_token: String, + refresh_token: Option, +} + +#[derive(serde::Deserialize)] +struct UserinfoResponse { + sub: String, + email: Option, + preferred_username: Option, + name: Option, + picture: Option, +} + +#[axum::debug_handler] +pub async fn auth_callback( + session: Session, + Extension(state): Extension, + Extension(pending): Extension, + Query(params): Query>, +) -> Result { + let kc = state.keycloak; + + let returned_state = params + .get("state") + .ok_or_else(|| DashboardError::Other("missing state parameter".into()))?; + + let entry = pending + .take(returned_state) + .ok_or_else(|| DashboardError::Other("unknown or expired oauth state".into()))?; + + let code = params + .get("code") + .ok_or_else(|| DashboardError::Other("missing code parameter".into()))?; + + let client = reqwest::Client::new(); + let token_resp = client + .post(kc.token_endpoint()) + .form(&[ + ("grant_type", "authorization_code"), + ("client_id", kc.client_id.as_str()), + ("redirect_uri", kc.redirect_uri.as_str()), + ("code", code), + ("code_verifier", &entry.code_verifier), + ]) + .send() + .await + .map_err(|e| DashboardError::Other(format!("token request failed: {e}")))?; + + if !token_resp.status().is_success() { + let body = token_resp.text().await.unwrap_or_default(); + return Err(DashboardError::Other(format!( + "token exchange failed: {body}" + ))); + } + + let tokens: TokenResponse = token_resp + .json() + .await + .map_err(|e| DashboardError::Other(format!("token parse failed: {e}")))?; + + let userinfo: UserinfoResponse = client + .get(kc.userinfo_endpoint()) + .bearer_auth(&tokens.access_token) + .send() + .await + .map_err(|e| DashboardError::Other(format!("userinfo request failed: {e}")))? + .json() + .await + .map_err(|e| DashboardError::Other(format!("userinfo parse failed: {e}")))?; + + let display_name = userinfo + .name + .or(userinfo.preferred_username) + .unwrap_or_default(); + + let user_state = UserStateInner { + sub: userinfo.sub, + access_token: tokens.access_token, + refresh_token: tokens.refresh_token.unwrap_or_default(), + user: User { + email: userinfo.email.unwrap_or_default(), + name: display_name, + avatar_url: userinfo.picture.unwrap_or_default(), + }, + }; + + session + .insert(LOGGED_IN_USER_SESS_KEY, user_state) + .await + .map_err(|e| DashboardError::Other(format!("session insert failed: {e}")))?; + + let target = entry + .redirect_url + .filter(|u| !u.is_empty()) + .unwrap_or_else(|| "/".into()); + + Ok(Redirect::temporary(&target)) +} + +#[axum::debug_handler] +pub async fn logout( + session: Session, + Extension(state): Extension, +) -> Result { + let kc = state.keycloak; + + session + .flush() + .await + .map_err(|e| DashboardError::Other(format!("session flush failed: {e}")))?; + + let mut url = Url::parse(&kc.logout_endpoint()) + .map_err(|e| DashboardError::Other(format!("invalid logout endpoint URL: {e}")))?; + + url.query_pairs_mut() + .append_pair("client_id", &kc.client_id) + .append_pair("post_logout_redirect_uri", &kc.app_url); + + Ok(Redirect::temporary(url.as_str())) +} diff --git a/compliance-dashboard/src/infrastructure/auth_check.rs b/compliance-dashboard/src/infrastructure/auth_check.rs new file mode 100644 index 0000000..52ea5d3 --- /dev/null +++ b/compliance-dashboard/src/infrastructure/auth_check.rs @@ -0,0 +1,32 @@ +use compliance_core::models::auth::AuthInfo; +use dioxus::prelude::*; + +/// Check the current user's authentication state. +/// +/// Reads the tower-sessions session on the server and returns an +/// [`AuthInfo`] describing the logged-in user. When no valid session +/// exists, `authenticated` is `false` and all other fields are empty. +#[server(endpoint = "check-auth")] +pub async fn check_auth() -> Result { + use super::auth::LOGGED_IN_USER_SESS_KEY; + use super::user_state::UserStateInner; + use dioxus_fullstack::FullstackContext; + + let session: tower_sessions::Session = FullstackContext::extract().await?; + + let user_state: Option = session + .get(LOGGED_IN_USER_SESS_KEY) + .await + .map_err(|e| ServerFnError::new(format!("session read failed: {e}")))?; + + match user_state { + Some(u) => Ok(AuthInfo { + authenticated: true, + sub: u.sub, + email: u.user.email, + name: u.user.name, + avatar_url: u.user.avatar_url, + }), + None => Ok(AuthInfo::default()), + } +} diff --git a/compliance-dashboard/src/infrastructure/auth_middleware.rs b/compliance-dashboard/src/infrastructure/auth_middleware.rs new file mode 100644 index 0000000..01a1cb6 --- /dev/null +++ b/compliance-dashboard/src/infrastructure/auth_middleware.rs @@ -0,0 +1,33 @@ +use axum::{ + extract::Request, + middleware::Next, + response::{IntoResponse, Response}, +}; +use reqwest::StatusCode; +use tower_sessions::Session; + +use super::auth::LOGGED_IN_USER_SESS_KEY; +use super::user_state::UserStateInner; + +const PUBLIC_API_ENDPOINTS: &[&str] = &["/api/check-auth"]; + +/// Axum middleware that enforces authentication on `/api/` server +/// function endpoints. +pub async fn require_auth(session: Session, request: Request, next: Next) -> Response { + let path = request.uri().path(); + + if path.starts_with("/api/") && !PUBLIC_API_ENDPOINTS.contains(&path) { + let is_authed = session + .get::(LOGGED_IN_USER_SESS_KEY) + .await + .ok() + .flatten() + .is_some(); + + if !is_authed { + return (StatusCode::UNAUTHORIZED, "Authentication required").into_response(); + } + } + + next.run(request).await +} diff --git a/compliance-dashboard/src/infrastructure/error.rs b/compliance-dashboard/src/infrastructure/error.rs index 8b1c15d..9fa16d4 100644 --- a/compliance-dashboard/src/infrastructure/error.rs +++ b/compliance-dashboard/src/infrastructure/error.rs @@ -24,3 +24,14 @@ impl From for ServerFnError { ServerFnError::new(err.to_string()) } } + +#[cfg(feature = "server")] +impl axum::response::IntoResponse for DashboardError { + fn into_response(self) -> axum::response::Response { + ( + axum::http::StatusCode::INTERNAL_SERVER_ERROR, + self.to_string(), + ) + .into_response() + } +} diff --git a/compliance-dashboard/src/infrastructure/keycloak_config.rs b/compliance-dashboard/src/infrastructure/keycloak_config.rs new file mode 100644 index 0000000..561d681 --- /dev/null +++ b/compliance-dashboard/src/infrastructure/keycloak_config.rs @@ -0,0 +1,56 @@ +use super::error::DashboardError; + +/// Keycloak OpenID Connect settings. +#[derive(Debug)] +pub struct KeycloakConfig { + pub url: String, + pub realm: String, + pub client_id: String, + pub redirect_uri: String, + pub app_url: String, +} + +impl KeycloakConfig { + pub fn from_env() -> Result { + Ok(Self { + url: required_env("KEYCLOAK_URL")?, + realm: required_env("KEYCLOAK_REALM")?, + client_id: required_env("KEYCLOAK_CLIENT_ID")?, + redirect_uri: required_env("REDIRECT_URI")?, + app_url: required_env("APP_URL")?, + }) + } + + pub fn auth_endpoint(&self) -> String { + format!( + "{}/realms/{}/protocol/openid-connect/auth", + self.url, self.realm + ) + } + + pub fn token_endpoint(&self) -> String { + format!( + "{}/realms/{}/protocol/openid-connect/token", + self.url, self.realm + ) + } + + pub fn userinfo_endpoint(&self) -> String { + format!( + "{}/realms/{}/protocol/openid-connect/userinfo", + self.url, self.realm + ) + } + + pub fn logout_endpoint(&self) -> String { + format!( + "{}/realms/{}/protocol/openid-connect/logout", + self.url, self.realm + ) + } +} + +fn required_env(name: &str) -> Result { + std::env::var(name) + .map_err(|_| DashboardError::Config(format!("{name} is required but not set"))) +} diff --git a/compliance-dashboard/src/infrastructure/mod.rs b/compliance-dashboard/src/infrastructure/mod.rs index 9ee2706..2b8831d 100644 --- a/compliance-dashboard/src/infrastructure/mod.rs +++ b/compliance-dashboard/src/infrastructure/mod.rs @@ -1,5 +1,6 @@ // Server function modules (compiled for both web and server; // the #[server] macro generates client stubs for the web target) +pub mod auth_check; pub mod chat; pub mod dast; pub mod findings; @@ -12,15 +13,27 @@ pub mod stats; // Server-only modules #[cfg(feature = "server")] +mod auth; +#[cfg(feature = "server")] +mod auth_middleware; +#[cfg(feature = "server")] pub mod config; #[cfg(feature = "server")] pub mod database; #[cfg(feature = "server")] pub mod error; #[cfg(feature = "server")] -pub mod server; +pub mod keycloak_config; +#[cfg(feature = "server")] +mod server; #[cfg(feature = "server")] pub mod server_state; +#[cfg(feature = "server")] +mod user_state; +#[cfg(feature = "server")] +pub use auth::{auth_callback, auth_login, logout, PendingOAuthStore}; +#[cfg(feature = "server")] +pub use auth_middleware::require_auth; #[cfg(feature = "server")] pub use server::server_start; diff --git a/compliance-dashboard/src/infrastructure/server.rs b/compliance-dashboard/src/infrastructure/server.rs index 0df8a46..0fa6e52 100644 --- a/compliance-dashboard/src/infrastructure/server.rs +++ b/compliance-dashboard/src/infrastructure/server.rs @@ -1,9 +1,15 @@ +use axum::routing::get; +use axum::{middleware, Extension}; use dioxus::prelude::*; +use time::Duration; +use tower_sessions::{cookie::Key, MemoryStore, SessionManagerLayer}; use super::config; use super::database::Database; use super::error::DashboardError; +use super::keycloak_config::KeycloakConfig; use super::server_state::{ServerState, ServerStateInner}; +use super::{auth_callback, auth_login, logout, require_auth, PendingOAuthStore}; pub fn server_start(app: fn() -> Element) -> Result<(), DashboardError> { tokio::runtime::Runtime::new() @@ -12,15 +18,29 @@ pub fn server_start(app: fn() -> Element) -> Result<(), DashboardError> { dotenvy::dotenv().ok(); let config = config::load_config()?; + let keycloak: &'static KeycloakConfig = + Box::leak(Box::new(KeycloakConfig::from_env()?)); let db = Database::connect(&config.mongodb_uri, &config.mongodb_database).await?; + tracing::info!("Keycloak configured for realm '{}'", keycloak.realm); + let server_state: ServerState = ServerStateInner { agent_api_url: config.agent_api_url.clone(), db, config, + keycloak, } .into(); + // Session layer + let key = Key::generate(); + let store = MemoryStore::default(); + let session = SessionManagerLayer::new(store) + .with_secure(false) + .with_same_site(tower_sessions::cookie::SameSite::Lax) + .with_expiry(tower_sessions::Expiry::OnInactivity(Duration::hours(24))) + .with_signed(key); + let addr = dioxus_cli_config::fullstack_address_or_localhost(); let listener = tokio::net::TcpListener::bind(addr) .await @@ -29,8 +49,14 @@ pub fn server_start(app: fn() -> Element) -> Result<(), DashboardError> { tracing::info!("Dashboard server listening on {addr}"); let router = axum::Router::new() + .route("/auth", get(auth_login)) + .route("/auth/callback", get(auth_callback)) + .route("/logout", get(logout)) .serve_dioxus_application(ServeConfig::new(), app) - .layer(axum::Extension(server_state)); + .layer(Extension(PendingOAuthStore::default())) + .layer(Extension(server_state)) + .layer(middleware::from_fn(require_auth)) + .layer(session); axum::serve(listener, router.into_make_service()) .await diff --git a/compliance-dashboard/src/infrastructure/server_state.rs b/compliance-dashboard/src/infrastructure/server_state.rs index 9f6cec2..2130784 100644 --- a/compliance-dashboard/src/infrastructure/server_state.rs +++ b/compliance-dashboard/src/infrastructure/server_state.rs @@ -4,6 +4,7 @@ use std::sync::Arc; use compliance_core::DashboardConfig; use super::database::Database; +use super::keycloak_config::KeycloakConfig; #[derive(Clone)] pub struct ServerState(Arc); @@ -19,6 +20,7 @@ pub struct ServerStateInner { pub db: Database, pub config: DashboardConfig, pub agent_api_url: String, + pub keycloak: &'static KeycloakConfig, } impl From for ServerState { diff --git a/compliance-dashboard/src/infrastructure/user_state.rs b/compliance-dashboard/src/infrastructure/user_state.rs new file mode 100644 index 0000000..320f8b1 --- /dev/null +++ b/compliance-dashboard/src/infrastructure/user_state.rs @@ -0,0 +1,18 @@ +use serde::{Deserialize, Serialize}; + +/// Per-session user data stored in the tower-sessions session store. +#[derive(Debug, Clone, Serialize, Deserialize, Default)] +pub struct UserStateInner { + pub sub: String, + pub access_token: String, + pub refresh_token: String, + pub user: User, +} + +/// Basic user profile stored alongside the session. +#[derive(Debug, Clone, Serialize, Deserialize, Default)] +pub struct User { + pub email: String, + pub name: String, + pub avatar_url: String, +}