use std::{ collections::HashMap, sync::{Arc, RwLock}, }; use axum::{ extract::Query, response::{IntoResponse, Redirect}, Extension, }; use rand::Rng; use tower_sessions::Session; use url::Url; use super::{ error::DashboardError, server_state::ServerState, user_state::{User, UserStateInner}, }; pub const LOGGED_IN_USER_SESS_KEY: &str = "logged-in-user"; #[derive(Debug, Clone)] pub(crate) struct PendingOAuthEntry { pub(crate) redirect_url: Option, pub(crate) code_verifier: String, } #[derive(Debug, Clone, Default)] pub struct PendingOAuthStore(Arc>>); impl PendingOAuthStore { pub(crate) fn insert(&self, state: String, entry: PendingOAuthEntry) { #[allow(clippy::expect_used)] self.0 .write() .expect("pending oauth store lock poisoned") .insert(state, entry); } pub(crate) fn take(&self, state: &str) -> Option { #[allow(clippy::expect_used)] self.0 .write() .expect("pending oauth store lock poisoned") .remove(state) } } pub(crate) fn generate_state() -> String { let bytes: [u8; 32] = rand::rng().random(); bytes.iter().fold(String::with_capacity(64), |mut acc, b| { use std::fmt::Write; let _ = write!(acc, "{b:02x}"); acc }) } pub(crate) fn generate_code_verifier() -> String { use base64::{engine::general_purpose::URL_SAFE_NO_PAD, Engine}; let bytes: [u8; 32] = rand::rng().random(); URL_SAFE_NO_PAD.encode(bytes) } pub(crate) fn derive_code_challenge(verifier: &str) -> String { use base64::{engine::general_purpose::URL_SAFE_NO_PAD, Engine}; use sha2::{Digest, Sha256}; let digest = Sha256::digest(verifier.as_bytes()); URL_SAFE_NO_PAD.encode(digest) } #[axum::debug_handler] pub async fn auth_login( Extension(state): Extension, Extension(pending): Extension, Query(params): Query>, ) -> Result { let kc = state .keycloak .ok_or(DashboardError::Other("Keycloak not configured".into()))?; let csrf_state = generate_state(); let code_verifier = generate_code_verifier(); let code_challenge = derive_code_challenge(&code_verifier); let redirect_url = params.get("redirect_url").cloned(); pending.insert( csrf_state.clone(), PendingOAuthEntry { redirect_url, code_verifier, }, ); let mut url = Url::parse(&kc.auth_endpoint()) .map_err(|e| DashboardError::Other(format!("invalid auth endpoint URL: {e}")))?; url.query_pairs_mut() .append_pair("client_id", &kc.client_id) .append_pair("redirect_uri", &kc.redirect_uri) .append_pair("response_type", "code") .append_pair("scope", "openid profile email") .append_pair("state", &csrf_state) .append_pair("code_challenge", &code_challenge) .append_pair("code_challenge_method", "S256"); Ok(Redirect::temporary(url.as_str())) } #[derive(serde::Deserialize)] struct TokenResponse { access_token: String, refresh_token: Option, } #[derive(serde::Deserialize)] struct UserinfoResponse { sub: String, email: Option, preferred_username: Option, name: Option, picture: Option, } #[axum::debug_handler] pub async fn auth_callback( session: Session, Extension(state): Extension, Extension(pending): Extension, Query(params): Query>, ) -> Result { let kc = state .keycloak .ok_or(DashboardError::Other("Keycloak not configured".into()))?; let returned_state = params .get("state") .ok_or_else(|| DashboardError::Other("missing state parameter".into()))?; let entry = pending .take(returned_state) .ok_or_else(|| DashboardError::Other("unknown or expired oauth state".into()))?; let code = params .get("code") .ok_or_else(|| DashboardError::Other("missing code parameter".into()))?; let client = reqwest::Client::new(); let token_resp = client .post(kc.token_endpoint()) .form(&[ ("grant_type", "authorization_code"), ("client_id", kc.client_id.as_str()), ("redirect_uri", kc.redirect_uri.as_str()), ("code", code), ("code_verifier", &entry.code_verifier), ]) .send() .await .map_err(|e| DashboardError::Other(format!("token request failed: {e}")))?; if !token_resp.status().is_success() { let body = token_resp.text().await.unwrap_or_default(); return Err(DashboardError::Other(format!( "token exchange failed: {body}" ))); } let tokens: TokenResponse = token_resp .json() .await .map_err(|e| DashboardError::Other(format!("token parse failed: {e}")))?; let userinfo: UserinfoResponse = client .get(kc.userinfo_endpoint()) .bearer_auth(&tokens.access_token) .send() .await .map_err(|e| DashboardError::Other(format!("userinfo request failed: {e}")))? .json() .await .map_err(|e| DashboardError::Other(format!("userinfo parse failed: {e}")))?; let display_name = userinfo .name .or(userinfo.preferred_username) .unwrap_or_default(); let user_state = UserStateInner { sub: userinfo.sub, access_token: tokens.access_token, refresh_token: tokens.refresh_token.unwrap_or_default(), user: User { email: userinfo.email.unwrap_or_default(), name: display_name, avatar_url: userinfo.picture.unwrap_or_default(), }, }; session .insert(LOGGED_IN_USER_SESS_KEY, user_state) .await .map_err(|e| DashboardError::Other(format!("session insert failed: {e}")))?; let target = entry .redirect_url .filter(|u| !u.is_empty()) .unwrap_or_else(|| "/".into()); Ok(Redirect::temporary(&target)) } #[axum::debug_handler] pub async fn logout( session: Session, Extension(state): Extension, ) -> Result { let kc = state .keycloak .ok_or(DashboardError::Other("Keycloak not configured".into()))?; session .flush() .await .map_err(|e| DashboardError::Other(format!("session flush failed: {e}")))?; let mut url = Url::parse(&kc.logout_endpoint()) .map_err(|e| DashboardError::Other(format!("invalid logout endpoint URL: {e}")))?; url.query_pairs_mut() .append_pair("client_id", &kc.client_id) .append_pair("post_logout_redirect_uri", &kc.app_url); Ok(Redirect::temporary(url.as_str())) }