use std::{
    collections::HashMap,
    sync::{Arc, Once, RwLock},
    time::{Duration, Instant},
};

use axum::{
    extract::Query,
    response::{IntoResponse, Redirect},
    Extension,
};
use rand::RngExt;
use tower_sessions::Session;
use url::Url;

use crate::infrastructure::{state::User, Error, UserStateInner};

pub const LOGGED_IN_USER_SESS_KEY: &str = "logged-in-user";

/// How long a pending OAuth state stays redeemable after it was stored.
///
/// Older entries are rejected by [`PendingOAuthStore::take`] and pruned on
/// the next [`PendingOAuthStore::insert`], so abandoned logins cannot
/// accumulate without bound.
const PENDING_OAUTH_TTL: Duration = Duration::from_secs(10 * 60);

/// Data stored alongside each pending OAuth state. Holds the optional
/// post-login redirect URL and the PKCE code verifier needed for the
/// token exchange.
#[derive(Debug, Clone)]
struct PendingOAuthEntry {
    redirect_url: Option<String>,
    code_verifier: String,
}

/// A [`PendingOAuthEntry`] together with the instant it was stored, used to
/// enforce [`PENDING_OAUTH_TTL`].
#[derive(Debug, Clone)]
struct TimedOAuthEntry {
    created_at: Instant,
    entry: PendingOAuthEntry,
}

/// In-memory store for pending OAuth states. Keyed by the random state
/// string. This avoids dependence on the session cookie surviving the
/// Keycloak redirect round-trip (the `dx serve` proxy can drop
/// `Set-Cookie` headers on 307 responses).
///
/// NOTE(review): because entries are keyed only by the state string and are
/// not tied to the browser that started the flow, this check alone does not
/// bind a callback to its initiating browser.
#[derive(Debug, Clone, Default)]
pub struct PendingOAuthStore(Arc<RwLock<HashMap<String, TimedOAuthEntry>>>);

impl PendingOAuthStore {
    /// Insert a pending state with an optional redirect URL and PKCE verifier.
    ///
    /// Entries older than [`PENDING_OAUTH_TTL`] are pruned first. The login
    /// endpoint is unauthenticated, so without pruning every request would
    /// leave an entry behind forever.
    fn insert(&self, state: String, entry: PendingOAuthEntry) {
        let now = Instant::now();
        // RwLock::write only panics if the lock is poisoned, which
        // indicates a prior panic -- propagating is acceptable here.
        #[allow(clippy::expect_used)]
        let mut map = self.0.write().expect("pending oauth store lock poisoned");
        map.retain(|_, pending| now.duration_since(pending.created_at) < PENDING_OAUTH_TTL);
        map.insert(
            state,
            TimedOAuthEntry {
                created_at: now,
                entry,
            },
        );
    }

    /// Remove and return the entry if the state is pending and not expired.
    ///
    /// Returns `None` if the state was never stored (CSRF failure) or is
    /// older than [`PENDING_OAUTH_TTL`]. In both cases the state is consumed.
    fn take(&self, state: &str) -> Option<PendingOAuthEntry> {
        // See `insert` for why panicking on a poisoned lock is acceptable.
        #[allow(clippy::expect_used)]
        let removed = self
            .0
            .write()
            .expect("pending oauth store lock poisoned")
            .remove(state);
        removed
            .filter(|pending| pending.created_at.elapsed() < PENDING_OAUTH_TTL)
            .map(|pending| pending.entry)
    }
}

/// Configuration loaded from environment variables for Keycloak OAuth.
struct OAuthConfig {
    keycloak_url: String,
    realm: String,
    client_id: String,
    redirect_uri: String,
    app_url: String,
}

impl OAuthConfig {
    /// Load OAuth configuration from environment variables.
    ///
    /// A `.env` file, if present, is loaded once per process rather than on
    /// every request. Any trailing `/` on `KEYCLOAK_URL` is stripped so the
    /// endpoint builders never produce `//realms/...`.
    ///
    /// # Errors
    ///
    /// Returns `Error::StateError` if any required env var is missing.
    fn from_env() -> Result<Self, Error> {
        static LOAD_DOTENV: Once = Once::new();
        // A missing `.env` file is not an error; real env vars may be set.
        LOAD_DOTENV.call_once(|| {
            dotenvy::dotenv().ok();
        });

        Ok(Self {
            keycloak_url: required_env("KEYCLOAK_URL")?
                .trim_end_matches('/')
                .to_owned(),
            realm: required_env("KEYCLOAK_REALM")?,
            client_id: required_env("KEYCLOAK_CLIENT_ID")?,
            redirect_uri: required_env("REDIRECT_URI")?,
            app_url: required_env("APP_URL")?,
        })
    }

    /// Build the URL of the realm's OpenID Connect endpoint named `endpoint`.
    fn oidc_endpoint(&self, endpoint: &str) -> String {
        format!(
            "{}/realms/{}/protocol/openid-connect/{endpoint}",
            self.keycloak_url, self.realm
        )
    }

    /// Build the Keycloak OpenID Connect authorization endpoint URL.
    fn auth_endpoint(&self) -> String {
        self.oidc_endpoint("auth")
    }

    /// Build the Keycloak OpenID Connect token endpoint URL.
    fn token_endpoint(&self) -> String {
        self.oidc_endpoint("token")
    }

    /// Build the Keycloak OpenID Connect userinfo endpoint URL.
    fn userinfo_endpoint(&self) -> String {
        self.oidc_endpoint("userinfo")
    }

    /// Build the Keycloak OpenID Connect end-session (logout) endpoint URL.
    fn logout_endpoint(&self) -> String {
        self.oidc_endpoint("logout")
    }
}

/// Read a required environment variable.
///
/// # Errors
///
/// Returns `Error::StateError("<NAME> not set")` if the variable is absent
/// (or not valid Unicode).
fn required_env(name: &str) -> Result<String, Error> {
    std::env::var(name).map_err(|_| Error::StateError(format!("{name} not set")))
}

/// Generate a cryptographically random state string for CSRF protection.
///
/// Produces 32 random bytes as 64 lowercase hex characters.
fn generate_state() -> String {
    use std::fmt::Write;

    let bytes: [u8; 32] = rand::rng().random();
    // Hex is URL-safe and needs no padding.
    bytes.iter().fold(String::with_capacity(64), |mut acc, b| {
        // write! on a String is infallible, safe to ignore the result.
        let _ = write!(acc, "{b:02x}");
        acc
    })
}

/// Generate a PKCE code verifier (43-128 char URL-safe random string).
///
/// Uses 32 random bytes encoded as base64url (no padding) to produce
/// a 43-character verifier per RFC 7636.
fn generate_code_verifier() -> String {
    use base64::{engine::general_purpose::URL_SAFE_NO_PAD, Engine};

    let random_bytes: [u8; 32] = rand::rng().random();
    URL_SAFE_NO_PAD.encode(random_bytes)
}

/// Compute the S256 PKCE code challenge for `verifier` (RFC 7636).
///
/// `code_challenge = BASE64URL-NOPAD(SHA256(code_verifier))`
fn derive_code_challenge(verifier: &str) -> String {
    use base64::{engine::general_purpose::URL_SAFE_NO_PAD, Engine};
    use sha2::{Digest, Sha256};

    let mut hasher = Sha256::new();
    hasher.update(verifier.as_bytes());
    URL_SAFE_NO_PAD.encode(hasher.finalize())
}

/// Start the login flow by redirecting the browser to Keycloak.
///
/// Creates a fresh CSRF `state` and a PKCE verifier/challenge pair. The
/// verifier and the optional `redirect_url` are remembered in the
/// server-side `PendingOAuthStore` under that state. The browser is then
/// sent to the Keycloak authorization endpoint.
///
/// # Query Parameters
///
/// * `redirect_url` - Optional URL to redirect to after successful login.
///
/// # Errors
///
/// Returns `Error` if env vars are missing or the auth endpoint URL is invalid.
#[axum::debug_handler]
pub async fn auth_login(
    Extension(pending): Extension<PendingOAuthStore>,
    Query(params): Query<HashMap<String, String>>,
) -> Result<impl IntoResponse, Error> {
    let config = OAuthConfig::from_env()?;

    let state = generate_state();
    let code_verifier = generate_code_verifier();
    let code_challenge = derive_code_challenge(&code_verifier);

    // Everything the callback needs is keyed by the state value.
    pending.insert(
        state.clone(),
        PendingOAuthEntry {
            redirect_url: params.get("redirect_url").cloned(),
            code_verifier,
        },
    );

    let mut authorize_url = Url::parse(&config.auth_endpoint())
        .map_err(|e| Error::StateError(format!("invalid auth endpoint URL: {e}")))?;
    authorize_url.query_pairs_mut().extend_pairs([
        ("client_id", config.client_id.as_str()),
        ("redirect_uri", config.redirect_uri.as_str()),
        ("response_type", "code"),
        ("scope", "openid profile email"),
        ("state", state.as_str()),
        ("code_challenge", code_challenge.as_str()),
        ("code_challenge_method", "S256"),
    ]);

    Ok(Redirect::temporary(authorize_url.as_str()))
}

/// Token endpoint response from Keycloak.
#[derive(serde::Deserialize)] struct TokenResponse { access_token: String, refresh_token: Option, } /// Userinfo endpoint response from Keycloak. #[derive(serde::Deserialize)] struct UserinfoResponse { /// The subject identifier (unique user ID in Keycloak). sub: String, email: Option, /// Keycloak may include a picture/avatar URL via protocol mappers. picture: Option, } /// Handle the OAuth callback from Keycloak after the user authenticates. /// /// Validates the CSRF state against the `PendingOAuthStore`, exchanges /// the authorization code for tokens, fetches user info, stores the /// logged-in user in the tower-sessions session, and redirects to the app. /// /// # Query Parameters /// /// * `code` - The authorization code from Keycloak. /// * `state` - The CSRF state to verify against the pending store. /// /// # Errors /// /// Returns `Error` on CSRF mismatch, token exchange failure, or session issues. #[axum::debug_handler] pub async fn auth_callback( session: Session, Extension(pending): Extension, Query(params): Query>, ) -> Result { let config = OAuthConfig::from_env()?; // --- CSRF validation via the in-memory pending store --- let returned_state = params .get("state") .ok_or_else(|| Error::StateError("missing state parameter".into()))?; let entry = pending .take(returned_state) .ok_or_else(|| Error::StateError("unknown or expired oauth state".into()))?; // --- Exchange code for tokens (with PKCE code_verifier) --- let code = params .get("code") .ok_or_else(|| Error::StateError("missing code parameter".into()))?; let client = reqwest::Client::new(); let token_resp = client .post(config.token_endpoint()) .form(&[ ("grant_type", "authorization_code"), ("client_id", &config.client_id), ("redirect_uri", &config.redirect_uri), ("code", code), ("code_verifier", &entry.code_verifier), ]) .send() .await .map_err(|e| Error::StateError(format!("token request failed: {e}")))?; if !token_resp.status().is_success() { let body = 
token_resp.text().await.unwrap_or_default(); return Err(Error::StateError(format!("token exchange failed: {body}"))); } let tokens: TokenResponse = token_resp .json() .await .map_err(|e| Error::StateError(format!("token parse failed: {e}")))?; // --- Fetch userinfo --- let userinfo: UserinfoResponse = client .get(config.userinfo_endpoint()) .bearer_auth(&tokens.access_token) .send() .await .map_err(|e| Error::StateError(format!("userinfo request failed: {e}")))? .json() .await .map_err(|e| Error::StateError(format!("userinfo parse failed: {e}")))?; // --- Build user state and persist in session --- let user_state = UserStateInner { sub: userinfo.sub, access_token: tokens.access_token, refresh_token: tokens.refresh_token.unwrap_or_default(), user: User { email: userinfo.email.unwrap_or_default(), avatar_url: userinfo.picture.unwrap_or_default(), }, }; set_login_session(session, user_state).await?; let target = entry .redirect_url .filter(|u| !u.is_empty()) .unwrap_or_else(|| "/".into()); Ok(Redirect::temporary(&target)) } /// Clear the user session and redirect to Keycloak's logout endpoint. /// /// After Keycloak finishes its own logout flow it will redirect /// back to the application root. /// /// # Errors /// /// Returns `Error` if env vars are missing or the session cannot be flushed. #[axum::debug_handler] pub async fn logout(session: Session) -> Result { let config = OAuthConfig::from_env()?; // Flush all session data. session .flush() .await .map_err(|e| Error::StateError(format!("session flush failed: {e}")))?; let mut url = Url::parse(&config.logout_endpoint()) .map_err(|e| Error::StateError(format!("invalid logout endpoint URL: {e}")))?; url.query_pairs_mut() .append_pair("client_id", &config.client_id) .append_pair("post_logout_redirect_uri", &config.app_url); Ok(Redirect::temporary(url.as_str())) } /// Persist user data into the session. /// /// # Errors /// /// Returns `Error` if the session store write fails. 
pub async fn set_login_session(session: Session, data: UserStateInner) -> Result<(), Error> { session .insert(LOGGED_IN_USER_SESS_KEY, data) .await .map_err(|e| Error::StateError(format!("session insert failed: {e}"))) }