use std::{ collections::HashMap, sync::{Arc, RwLock}, }; use axum::{ extract::Query, response::{IntoResponse, Redirect}, Extension, }; use rand::RngExt; use tower_sessions::Session; use url::Url; use crate::infrastructure::{ server_state::ServerState, state::{User, UserStateInner}, Error, }; pub const LOGGED_IN_USER_SESS_KEY: &str = "logged-in-user"; /// Data stored alongside each pending OAuth state. Holds the optional /// post-login redirect URL and the PKCE code verifier needed for the /// token exchange. #[derive(Debug, Clone)] pub(crate) struct PendingOAuthEntry { pub(crate) redirect_url: Option, pub(crate) code_verifier: String, } /// In-memory store for pending OAuth states. Keyed by the random state /// string. This avoids dependence on the session cookie surviving the /// Keycloak redirect round-trip (the `dx serve` proxy can drop /// `Set-Cookie` headers on 307 responses). #[derive(Debug, Clone, Default)] pub struct PendingOAuthStore(Arc>>); impl PendingOAuthStore { /// Insert a pending state with an optional redirect URL and PKCE verifier. pub(crate) fn insert(&self, state: String, entry: PendingOAuthEntry) { // RwLock::write only panics if the lock is poisoned, which // indicates a prior panic -- propagating is acceptable here. #[allow(clippy::expect_used)] self.0 .write() .expect("pending oauth store lock poisoned") .insert(state, entry); } /// Remove and return the entry if the state was pending. /// Returns `None` if the state was never stored (CSRF failure). pub(crate) fn take(&self, state: &str) -> Option { #[allow(clippy::expect_used)] self.0 .write() .expect("pending oauth store lock poisoned") .remove(state) } } /// Generate a cryptographically random state string for CSRF protection. #[cfg_attr(test, allow(dead_code))] pub(crate) fn generate_state() -> String { let bytes: [u8; 32] = rand::rng().random(); // Encode as hex to produce a URL-safe string without padding. bytes.iter().fold(String::with_capacity(64), |mut acc, b| { use std::fmt::Write; // write! on a String is infallible, safe to ignore the result. let _ = write!(acc, "{b:02x}"); acc }) } /// Generate a PKCE code verifier (43-128 char URL-safe random string). /// /// Uses 32 random bytes encoded as base64url (no padding) to produce /// a 43-character verifier per RFC 7636. pub(crate) fn generate_code_verifier() -> String { use base64::{engine::general_purpose::URL_SAFE_NO_PAD, Engine}; let bytes: [u8; 32] = rand::rng().random(); URL_SAFE_NO_PAD.encode(bytes) } /// Derive the S256 code challenge from a code verifier per RFC 7636. /// /// `code_challenge = BASE64URL(SHA256(code_verifier))` pub(crate) fn derive_code_challenge(verifier: &str) -> String { use base64::{engine::general_purpose::URL_SAFE_NO_PAD, Engine}; use sha2::{Digest, Sha256}; let digest = Sha256::digest(verifier.as_bytes()); URL_SAFE_NO_PAD.encode(digest) } /// Redirect the user to Keycloak's authorization page. /// /// Generates a random CSRF state, stores it (along with the optional /// redirect URL) in the server-side `PendingOAuthStore`, and redirects /// the browser to Keycloak. /// /// # Query Parameters /// /// * `redirect_url` - Optional URL to redirect to after successful login. /// /// # Errors /// /// Returns `Error` if the Keycloak config is missing or the URL is malformed. #[axum::debug_handler] pub async fn auth_login( Extension(state): Extension, Extension(pending): Extension, Query(params): Query>, ) -> Result { let kc = state.keycloak; let csrf_state = generate_state(); let code_verifier = generate_code_verifier(); let code_challenge = derive_code_challenge(&code_verifier); let redirect_url = params.get("redirect_url").cloned(); pending.insert( csrf_state.clone(), PendingOAuthEntry { redirect_url, code_verifier, }, ); let mut url = Url::parse(&kc.auth_endpoint()) .map_err(|e| Error::StateError(format!("invalid auth endpoint URL: {e}")))?; url.query_pairs_mut() .append_pair("client_id", &kc.client_id) .append_pair("redirect_uri", &kc.redirect_uri) .append_pair("response_type", "code") .append_pair("scope", "openid profile email") .append_pair("state", &csrf_state) .append_pair("code_challenge", &code_challenge) .append_pair("code_challenge_method", "S256"); Ok(Redirect::temporary(url.as_str())) } /// Token endpoint response from Keycloak. #[derive(serde::Deserialize)] struct TokenResponse { access_token: String, refresh_token: Option, } /// Userinfo endpoint response from Keycloak. #[derive(serde::Deserialize)] struct UserinfoResponse { /// The subject identifier (unique user ID in Keycloak). sub: String, email: Option, /// Keycloak `preferred_username` claim. preferred_username: Option, /// Full name from the Keycloak profile. name: Option, /// Keycloak may include a picture/avatar URL via protocol mappers. picture: Option, } /// Handle the OAuth callback from Keycloak after the user authenticates. /// /// Validates the CSRF state against the `PendingOAuthStore`, exchanges /// the authorization code for tokens, fetches user info, stores the /// logged-in user in the tower-sessions session, and redirects to the app. /// /// # Query Parameters /// /// * `code` - The authorization code from Keycloak. /// * `state` - The CSRF state to verify against the pending store. /// /// # Errors /// /// Returns `Error` on CSRF mismatch, token exchange failure, or session issues. #[axum::debug_handler] pub async fn auth_callback( session: Session, Extension(state): Extension, Extension(pending): Extension, Query(params): Query>, ) -> Result { let kc = state.keycloak; // --- CSRF validation via the in-memory pending store --- let returned_state = params .get("state") .ok_or_else(|| Error::StateError("missing state parameter".into()))?; let entry = pending .take(returned_state) .ok_or_else(|| Error::StateError("unknown or expired oauth state".into()))?; // --- Exchange code for tokens (with PKCE code_verifier) --- let code = params .get("code") .ok_or_else(|| Error::StateError("missing code parameter".into()))?; let client = reqwest::Client::new(); let token_resp = client .post(kc.token_endpoint()) .form(&[ ("grant_type", "authorization_code"), ("client_id", kc.client_id.as_str()), ("redirect_uri", kc.redirect_uri.as_str()), ("code", code), ("code_verifier", &entry.code_verifier), ]) .send() .await .map_err(|e| Error::StateError(format!("token request failed: {e}")))?; if !token_resp.status().is_success() { let body = token_resp.text().await.unwrap_or_default(); return Err(Error::StateError(format!("token exchange failed: {body}"))); } let tokens: TokenResponse = token_resp .json() .await .map_err(|e| Error::StateError(format!("token parse failed: {e}")))?; // --- Fetch userinfo --- let userinfo: UserinfoResponse = client .get(kc.userinfo_endpoint()) .bearer_auth(&tokens.access_token) .send() .await .map_err(|e| Error::StateError(format!("userinfo request failed: {e}")))? .json() .await .map_err(|e| Error::StateError(format!("userinfo parse failed: {e}")))?; // Prefer `name`, fall back to `preferred_username`, then empty. let display_name = userinfo .name .or(userinfo.preferred_username) .unwrap_or_default(); // --- Build user state and persist in session --- let user_state = UserStateInner { sub: userinfo.sub, access_token: tokens.access_token, refresh_token: tokens.refresh_token.unwrap_or_default(), user: User { email: userinfo.email.unwrap_or_default(), name: display_name, avatar_url: userinfo.picture.unwrap_or_default(), }, }; set_login_session(session, user_state).await?; let target = entry .redirect_url .filter(|u| !u.is_empty()) .unwrap_or_else(|| "/".into()); Ok(Redirect::temporary(&target)) } /// Clear the user session and redirect to Keycloak's logout endpoint. /// /// After Keycloak finishes its own logout flow it will redirect /// back to the application root. /// /// # Errors /// /// Returns `Error` if the session cannot be flushed or the URL is malformed. #[axum::debug_handler] pub async fn logout( session: Session, Extension(state): Extension, ) -> Result { let kc = state.keycloak; // Flush all session data. session .flush() .await .map_err(|e| Error::StateError(format!("session flush failed: {e}")))?; let mut url = Url::parse(&kc.logout_endpoint()) .map_err(|e| Error::StateError(format!("invalid logout endpoint URL: {e}")))?; url.query_pairs_mut() .append_pair("client_id", &kc.client_id) .append_pair("post_logout_redirect_uri", &kc.app_url); Ok(Redirect::temporary(url.as_str())) } /// Persist user data into the session. /// /// # Errors /// /// Returns `Error` if the session store write fails. pub async fn set_login_session(session: Session, data: UserStateInner) -> Result<(), Error> { session .insert(LOGGED_IN_USER_SESS_KEY, data) .await .map_err(|e| Error::StateError(format!("session insert failed: {e}"))) } #[cfg(test)] mod tests { #![allow(clippy::unwrap_used, clippy::expect_used)] use super::*; use pretty_assertions::assert_eq; // ----------------------------------------------------------------------- // generate_state() // ----------------------------------------------------------------------- #[test] fn generate_state_length_is_64() { let state = generate_state(); assert_eq!(state.len(), 64); } #[test] fn generate_state_chars_are_hex() { let state = generate_state(); assert!(state.chars().all(|c| c.is_ascii_hexdigit())); } #[test] fn generate_state_two_calls_differ() { let a = generate_state(); let b = generate_state(); assert_ne!(a, b); } // ----------------------------------------------------------------------- // generate_code_verifier() // ----------------------------------------------------------------------- #[test] fn code_verifier_length_is_43() { let verifier = generate_code_verifier(); assert_eq!(verifier.len(), 43); } #[test] fn code_verifier_chars_are_url_safe_base64() { let verifier = generate_code_verifier(); // URL-safe base64 without padding uses [A-Za-z0-9_-] assert!(verifier .chars() .all(|c| c.is_ascii_alphanumeric() || c == '-' || c == '_')); } // ----------------------------------------------------------------------- // derive_code_challenge() // ----------------------------------------------------------------------- #[test] fn code_challenge_deterministic() { let verifier = "dBjftJeZ4CVP-mB92K27uhbUJU1p1r_wW1gFWFOEjXk"; let a = derive_code_challenge(verifier); let b = derive_code_challenge(verifier); assert_eq!(a, b); } #[test] fn code_challenge_rfc7636_test_vector() { // RFC 7636 Appendix B test vector: // verifier = "dBjftJeZ4CVP-mB92K27uhbUJU1p1r_wW1gFWFOEjXk" // expected challenge = "E9Melhoa2OwvFrEMTJguCHaoeK1t8URWbuGJSstw-cM" let verifier = "dBjftJeZ4CVP-mB92K27uhbUJU1p1r_wW1gFWFOEjXk"; let challenge = derive_code_challenge(verifier); assert_eq!(challenge, "E9Melhoa2OwvFrEMTJguCHaoeK1t8URWbuGJSstw-cM"); } // ----------------------------------------------------------------------- // PendingOAuthStore // ----------------------------------------------------------------------- #[test] fn pending_store_insert_and_take() { let store = PendingOAuthStore::default(); store.insert( "state-1".into(), PendingOAuthEntry { redirect_url: Some("/dashboard".into()), code_verifier: "verifier-1".into(), }, ); let entry = store.take("state-1"); assert!(entry.is_some()); let entry = entry.unwrap(); assert_eq!(entry.redirect_url, Some("/dashboard".into())); assert_eq!(entry.code_verifier, "verifier-1"); } #[test] fn pending_store_take_removes_entry() { let store = PendingOAuthStore::default(); store.insert( "state-2".into(), PendingOAuthEntry { redirect_url: None, code_verifier: "v2".into(), }, ); let _ = store.take("state-2"); // Second take should return None since the entry was removed. assert!(store.take("state-2").is_none()); } #[test] fn pending_store_take_unknown_returns_none() { let store = PendingOAuthStore::default(); assert!(store.take("nonexistent").is_none()); } }