Co-authored-by: Sharang Parnerkar <parnerkarsharang@gmail.com> Reviewed-on: #4
235 lines
6.9 KiB
Rust
235 lines
6.9 KiB
Rust
use std::{
|
|
collections::HashMap,
|
|
sync::{Arc, RwLock},
|
|
};
|
|
|
|
use axum::{
|
|
extract::Query,
|
|
response::{IntoResponse, Redirect},
|
|
Extension,
|
|
};
|
|
use rand::Rng;
|
|
use tower_sessions::Session;
|
|
use url::Url;
|
|
|
|
use super::{
|
|
error::DashboardError,
|
|
server_state::ServerState,
|
|
user_state::{User, UserStateInner},
|
|
};
|
|
|
|
/// Session key under which `auth_callback` stores the logged-in user's
/// `UserStateInner` (tokens plus profile).
pub const LOGGED_IN_USER_SESS_KEY: &str = "logged-in-user";
|
|
|
|
/// Server-side data remembered between `auth_login` and `auth_callback` for a
/// single in-flight login, stored in `PendingOAuthStore` under its CSRF `state`.
#[derive(Debug, Clone)]
pub(crate) struct PendingOAuthEntry {
    /// Value of the `redirect_url` query parameter given to `auth_login`, used
    /// as the post-login redirect target (absent or empty falls back to `/`).
    pub(crate) redirect_url: Option<String>,
    /// PKCE code verifier. Only its derived challenge is sent in the
    /// authorization request; the verifier itself is presented at the token
    /// exchange in `auth_callback`.
    pub(crate) code_verifier: String,
}
|
|
|
|
#[derive(Debug, Clone, Default)]
|
|
pub struct PendingOAuthStore(Arc<RwLock<HashMap<String, PendingOAuthEntry>>>);
|
|
|
|
impl PendingOAuthStore {
|
|
pub(crate) fn insert(&self, state: String, entry: PendingOAuthEntry) {
|
|
#[allow(clippy::expect_used)]
|
|
self.0
|
|
.write()
|
|
.expect("pending oauth store lock poisoned")
|
|
.insert(state, entry);
|
|
}
|
|
|
|
pub(crate) fn take(&self, state: &str) -> Option<PendingOAuthEntry> {
|
|
#[allow(clippy::expect_used)]
|
|
self.0
|
|
.write()
|
|
.expect("pending oauth store lock poisoned")
|
|
.remove(state)
|
|
}
|
|
}
|
|
|
|
pub(crate) fn generate_state() -> String {
|
|
let bytes: [u8; 32] = rand::rng().random();
|
|
bytes.iter().fold(String::with_capacity(64), |mut acc, b| {
|
|
use std::fmt::Write;
|
|
let _ = write!(acc, "{b:02x}");
|
|
acc
|
|
})
|
|
}
|
|
|
|
pub(crate) fn generate_code_verifier() -> String {
|
|
use base64::{engine::general_purpose::URL_SAFE_NO_PAD, Engine};
|
|
let bytes: [u8; 32] = rand::rng().random();
|
|
URL_SAFE_NO_PAD.encode(bytes)
|
|
}
|
|
|
|
pub(crate) fn derive_code_challenge(verifier: &str) -> String {
|
|
use base64::{engine::general_purpose::URL_SAFE_NO_PAD, Engine};
|
|
use sha2::{Digest, Sha256};
|
|
let digest = Sha256::digest(verifier.as_bytes());
|
|
URL_SAFE_NO_PAD.encode(digest)
|
|
}
|
|
|
|
#[axum::debug_handler]
|
|
pub async fn auth_login(
|
|
Extension(state): Extension<ServerState>,
|
|
Extension(pending): Extension<PendingOAuthStore>,
|
|
Query(params): Query<HashMap<String, String>>,
|
|
) -> Result<impl IntoResponse, DashboardError> {
|
|
let kc = state
|
|
.keycloak
|
|
.ok_or(DashboardError::Other("Keycloak not configured".into()))?;
|
|
let csrf_state = generate_state();
|
|
let code_verifier = generate_code_verifier();
|
|
let code_challenge = derive_code_challenge(&code_verifier);
|
|
|
|
let redirect_url = params.get("redirect_url").cloned();
|
|
pending.insert(
|
|
csrf_state.clone(),
|
|
PendingOAuthEntry {
|
|
redirect_url,
|
|
code_verifier,
|
|
},
|
|
);
|
|
|
|
let mut url = Url::parse(&kc.auth_endpoint())
|
|
.map_err(|e| DashboardError::Other(format!("invalid auth endpoint URL: {e}")))?;
|
|
|
|
url.query_pairs_mut()
|
|
.append_pair("client_id", &kc.client_id)
|
|
.append_pair("redirect_uri", &kc.redirect_uri)
|
|
.append_pair("response_type", "code")
|
|
.append_pair("scope", "openid profile email")
|
|
.append_pair("state", &csrf_state)
|
|
.append_pair("code_challenge", &code_challenge)
|
|
.append_pair("code_challenge_method", "S256");
|
|
|
|
Ok(Redirect::temporary(url.as_str()))
|
|
}
|
|
|
|
/// The fields of the token endpoint's JSON response that `auth_callback` uses
/// (other fields in the response are ignored during deserialization).
#[derive(serde::Deserialize)]
struct TokenResponse {
    /// Bearer token sent to the userinfo endpoint and stored in the session.
    access_token: String,
    /// Refresh token, if one was issued; stored as an empty string when absent.
    refresh_token: Option<String>,
}
|
|
|
|
/// Claims read from the userinfo endpoint's JSON response.
#[derive(serde::Deserialize)]
struct UserinfoResponse {
    /// Subject identifier; the only claim that must be present.
    sub: String,
    /// Email address; stored as an empty string when absent.
    email: Option<String>,
    /// Used as the display name when `name` is absent.
    preferred_username: Option<String>,
    /// Preferred display name.
    name: Option<String>,
    /// Avatar URL; stored as an empty string when absent.
    picture: Option<String>,
}
|
|
|
|
#[axum::debug_handler]
|
|
pub async fn auth_callback(
|
|
session: Session,
|
|
Extension(state): Extension<ServerState>,
|
|
Extension(pending): Extension<PendingOAuthStore>,
|
|
Query(params): Query<HashMap<String, String>>,
|
|
) -> Result<impl IntoResponse, DashboardError> {
|
|
let kc = state
|
|
.keycloak
|
|
.ok_or(DashboardError::Other("Keycloak not configured".into()))?;
|
|
|
|
let returned_state = params
|
|
.get("state")
|
|
.ok_or_else(|| DashboardError::Other("missing state parameter".into()))?;
|
|
|
|
let entry = pending
|
|
.take(returned_state)
|
|
.ok_or_else(|| DashboardError::Other("unknown or expired oauth state".into()))?;
|
|
|
|
let code = params
|
|
.get("code")
|
|
.ok_or_else(|| DashboardError::Other("missing code parameter".into()))?;
|
|
|
|
let client = reqwest::Client::new();
|
|
let token_resp = client
|
|
.post(kc.token_endpoint())
|
|
.form(&[
|
|
("grant_type", "authorization_code"),
|
|
("client_id", kc.client_id.as_str()),
|
|
("redirect_uri", kc.redirect_uri.as_str()),
|
|
("code", code),
|
|
("code_verifier", &entry.code_verifier),
|
|
])
|
|
.send()
|
|
.await
|
|
.map_err(|e| DashboardError::Other(format!("token request failed: {e}")))?;
|
|
|
|
if !token_resp.status().is_success() {
|
|
let body = token_resp.text().await.unwrap_or_default();
|
|
return Err(DashboardError::Other(format!(
|
|
"token exchange failed: {body}"
|
|
)));
|
|
}
|
|
|
|
let tokens: TokenResponse = token_resp
|
|
.json()
|
|
.await
|
|
.map_err(|e| DashboardError::Other(format!("token parse failed: {e}")))?;
|
|
|
|
let userinfo: UserinfoResponse = client
|
|
.get(kc.userinfo_endpoint())
|
|
.bearer_auth(&tokens.access_token)
|
|
.send()
|
|
.await
|
|
.map_err(|e| DashboardError::Other(format!("userinfo request failed: {e}")))?
|
|
.json()
|
|
.await
|
|
.map_err(|e| DashboardError::Other(format!("userinfo parse failed: {e}")))?;
|
|
|
|
let display_name = userinfo
|
|
.name
|
|
.or(userinfo.preferred_username)
|
|
.unwrap_or_default();
|
|
|
|
let user_state = UserStateInner {
|
|
sub: userinfo.sub,
|
|
access_token: tokens.access_token,
|
|
refresh_token: tokens.refresh_token.unwrap_or_default(),
|
|
user: User {
|
|
email: userinfo.email.unwrap_or_default(),
|
|
name: display_name,
|
|
avatar_url: userinfo.picture.unwrap_or_default(),
|
|
},
|
|
};
|
|
|
|
session
|
|
.insert(LOGGED_IN_USER_SESS_KEY, user_state)
|
|
.await
|
|
.map_err(|e| DashboardError::Other(format!("session insert failed: {e}")))?;
|
|
|
|
let target = entry
|
|
.redirect_url
|
|
.filter(|u| !u.is_empty())
|
|
.unwrap_or_else(|| "/".into());
|
|
|
|
Ok(Redirect::temporary(&target))
|
|
}
|
|
|
|
#[axum::debug_handler]
|
|
pub async fn logout(
|
|
session: Session,
|
|
Extension(state): Extension<ServerState>,
|
|
) -> Result<impl IntoResponse, DashboardError> {
|
|
let kc = state
|
|
.keycloak
|
|
.ok_or(DashboardError::Other("Keycloak not configured".into()))?;
|
|
|
|
session
|
|
.flush()
|
|
.await
|
|
.map_err(|e| DashboardError::Other(format!("session flush failed: {e}")))?;
|
|
|
|
let mut url = Url::parse(&kc.logout_endpoint())
|
|
.map_err(|e| DashboardError::Other(format!("invalid logout endpoint URL: {e}")))?;
|
|
|
|
url.query_pairs_mut()
|
|
.append_pair("client_id", &kc.client_id)
|
|
.append_pair("post_logout_redirect_uri", &kc.app_url);
|
|
|
|
Ok(Redirect::temporary(url.as_str()))
|
|
}
|