feat: added oauth based login and registration (#1)
Co-authored-by: Sharang Parnerkar <parnerkarsharang@gmail.com> Reviewed-on: #1
This commit was merged in pull request #1.
This commit is contained in:
307
src/infrastructure/auth.rs
Normal file
307
src/infrastructure/auth.rs
Normal file
@@ -0,0 +1,307 @@
|
||||
use std::{
|
||||
collections::HashMap,
|
||||
sync::{Arc, RwLock},
|
||||
};
|
||||
|
||||
use axum::{
|
||||
extract::Query,
|
||||
response::{IntoResponse, Redirect},
|
||||
Extension,
|
||||
};
|
||||
use rand::RngExt;
|
||||
use tower_sessions::Session;
|
||||
use url::Url;
|
||||
|
||||
use crate::infrastructure::{state::User, Error, UserStateInner};
|
||||
|
||||
pub const LOGGED_IN_USER_SESS_KEY: &str = "logged-in-user";
|
||||
|
||||
/// In-memory store for pending OAuth states and their associated redirect
|
||||
/// URLs. Keyed by the random state string. This avoids dependence on the
|
||||
/// session cookie surviving the Keycloak redirect round-trip (the `dx serve`
|
||||
/// proxy can drop `Set-Cookie` headers on 307 responses).
|
||||
#[derive(Debug, Clone, Default)]
|
||||
pub struct PendingOAuthStore(Arc<RwLock<HashMap<String, Option<String>>>>);
|
||||
|
||||
impl PendingOAuthStore {
|
||||
/// Insert a pending state with an optional post-login redirect URL.
|
||||
fn insert(&self, state: String, redirect_url: Option<String>) {
|
||||
// RwLock::write only panics if the lock is poisoned, which
|
||||
// indicates a prior panic -- propagating is acceptable here.
|
||||
#[allow(clippy::expect_used)]
|
||||
self.0
|
||||
.write()
|
||||
.expect("pending oauth store lock poisoned")
|
||||
.insert(state, redirect_url);
|
||||
}
|
||||
|
||||
/// Remove and return the redirect URL if the state was pending.
|
||||
/// Returns `None` if the state was never stored (CSRF failure).
|
||||
fn take(&self, state: &str) -> Option<Option<String>> {
|
||||
#[allow(clippy::expect_used)]
|
||||
self.0
|
||||
.write()
|
||||
.expect("pending oauth store lock poisoned")
|
||||
.remove(state)
|
||||
}
|
||||
}
|
||||
|
||||
/// Configuration loaded from environment variables for Keycloak OAuth.
|
||||
struct OAuthConfig {
|
||||
keycloak_url: String,
|
||||
realm: String,
|
||||
client_id: String,
|
||||
redirect_uri: String,
|
||||
app_url: String,
|
||||
}
|
||||
|
||||
impl OAuthConfig {
|
||||
/// Load OAuth configuration from environment variables.
|
||||
///
|
||||
/// # Errors
|
||||
///
|
||||
/// Returns `Error::StateError` if any required env var is missing.
|
||||
fn from_env() -> Result<Self, Error> {
|
||||
dotenvy::dotenv().ok();
|
||||
Ok(Self {
|
||||
keycloak_url: std::env::var("KEYCLOAK_URL")
|
||||
.map_err(|_| Error::StateError("KEYCLOAK_URL not set".into()))?,
|
||||
realm: std::env::var("KEYCLOAK_REALM")
|
||||
.map_err(|_| Error::StateError("KEYCLOAK_REALM not set".into()))?,
|
||||
client_id: std::env::var("KEYCLOAK_CLIENT_ID")
|
||||
.map_err(|_| Error::StateError("KEYCLOAK_CLIENT_ID not set".into()))?,
|
||||
redirect_uri: std::env::var("REDIRECT_URI")
|
||||
.map_err(|_| Error::StateError("REDIRECT_URI not set".into()))?,
|
||||
app_url: std::env::var("APP_URL")
|
||||
.map_err(|_| Error::StateError("APP_URL not set".into()))?,
|
||||
})
|
||||
}
|
||||
|
||||
/// Build the Keycloak OpenID Connect authorization endpoint URL.
|
||||
fn auth_endpoint(&self) -> String {
|
||||
format!(
|
||||
"{}/realms/{}/protocol/openid-connect/auth",
|
||||
self.keycloak_url, self.realm
|
||||
)
|
||||
}
|
||||
|
||||
/// Build the Keycloak OpenID Connect token endpoint URL.
|
||||
fn token_endpoint(&self) -> String {
|
||||
format!(
|
||||
"{}/realms/{}/protocol/openid-connect/token",
|
||||
self.keycloak_url, self.realm
|
||||
)
|
||||
}
|
||||
|
||||
/// Build the Keycloak OpenID Connect userinfo endpoint URL.
|
||||
fn userinfo_endpoint(&self) -> String {
|
||||
format!(
|
||||
"{}/realms/{}/protocol/openid-connect/userinfo",
|
||||
self.keycloak_url, self.realm
|
||||
)
|
||||
}
|
||||
|
||||
/// Build the Keycloak OpenID Connect end-session (logout) endpoint URL.
|
||||
fn logout_endpoint(&self) -> String {
|
||||
format!(
|
||||
"{}/realms/{}/protocol/openid-connect/logout",
|
||||
self.keycloak_url, self.realm
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
/// Generate a cryptographically random state string for CSRF protection.
|
||||
fn generate_state() -> String {
|
||||
let bytes: [u8; 32] = rand::rng().random();
|
||||
// Encode as hex to produce a URL-safe string without padding.
|
||||
bytes.iter().fold(String::with_capacity(64), |mut acc, b| {
|
||||
use std::fmt::Write;
|
||||
// write! on a String is infallible, safe to ignore the result.
|
||||
let _ = write!(acc, "{b:02x}");
|
||||
acc
|
||||
})
|
||||
}
|
||||
|
||||
/// Redirect the user to Keycloak's authorization page.
|
||||
///
|
||||
/// Generates a random CSRF state, stores it (along with the optional
|
||||
/// redirect URL) in the server-side `PendingOAuthStore`, and redirects
|
||||
/// the browser to Keycloak.
|
||||
///
|
||||
/// # Query Parameters
|
||||
///
|
||||
/// * `redirect_url` - Optional URL to redirect to after successful login.
|
||||
///
|
||||
/// # Errors
|
||||
///
|
||||
/// Returns `Error` if env vars are missing.
|
||||
#[axum::debug_handler]
|
||||
pub async fn auth_login(
|
||||
Extension(pending): Extension<PendingOAuthStore>,
|
||||
Query(params): Query<HashMap<String, String>>,
|
||||
) -> Result<impl IntoResponse, Error> {
|
||||
let config = OAuthConfig::from_env()?;
|
||||
let state = generate_state();
|
||||
|
||||
let redirect_url = params.get("redirect_url").cloned();
|
||||
pending.insert(state.clone(), redirect_url);
|
||||
|
||||
let mut url = Url::parse(&config.auth_endpoint())
|
||||
.map_err(|e| Error::StateError(format!("invalid auth endpoint URL: {e}")))?;
|
||||
|
||||
url.query_pairs_mut()
|
||||
.append_pair("client_id", &config.client_id)
|
||||
.append_pair("redirect_uri", &config.redirect_uri)
|
||||
.append_pair("response_type", "code")
|
||||
.append_pair("scope", "openid profile email")
|
||||
.append_pair("state", &state);
|
||||
|
||||
Ok(Redirect::temporary(url.as_str()))
|
||||
}
|
||||
|
||||
/// Token endpoint response from Keycloak.
|
||||
#[derive(serde::Deserialize)]
|
||||
struct TokenResponse {
|
||||
access_token: String,
|
||||
refresh_token: Option<String>,
|
||||
}
|
||||
|
||||
/// Userinfo endpoint response from Keycloak.
|
||||
#[derive(serde::Deserialize)]
|
||||
struct UserinfoResponse {
|
||||
/// The subject identifier (unique user ID in Keycloak).
|
||||
sub: String,
|
||||
email: Option<String>,
|
||||
/// Keycloak may include a picture/avatar URL via protocol mappers.
|
||||
picture: Option<String>,
|
||||
}
|
||||
|
||||
/// Handle the OAuth callback from Keycloak after the user authenticates.
|
||||
///
|
||||
/// Validates the CSRF state against the `PendingOAuthStore`, exchanges
|
||||
/// the authorization code for tokens, fetches user info, stores the
|
||||
/// logged-in user in the tower-sessions session, and redirects to the app.
|
||||
///
|
||||
/// # Query Parameters
|
||||
///
|
||||
/// * `code` - The authorization code from Keycloak.
|
||||
/// * `state` - The CSRF state to verify against the pending store.
|
||||
///
|
||||
/// # Errors
|
||||
///
|
||||
/// Returns `Error` on CSRF mismatch, token exchange failure, or session issues.
|
||||
#[axum::debug_handler]
|
||||
pub async fn auth_callback(
|
||||
session: Session,
|
||||
Extension(pending): Extension<PendingOAuthStore>,
|
||||
Query(params): Query<HashMap<String, String>>,
|
||||
) -> Result<impl IntoResponse, Error> {
|
||||
let config = OAuthConfig::from_env()?;
|
||||
|
||||
// --- CSRF validation via the in-memory pending store ---
|
||||
let returned_state = params
|
||||
.get("state")
|
||||
.ok_or_else(|| Error::StateError("missing state parameter".into()))?;
|
||||
|
||||
let redirect_url = pending
|
||||
.take(returned_state)
|
||||
.ok_or_else(|| Error::StateError("unknown or expired oauth state".into()))?;
|
||||
|
||||
// --- Exchange code for tokens ---
|
||||
let code = params
|
||||
.get("code")
|
||||
.ok_or_else(|| Error::StateError("missing code parameter".into()))?;
|
||||
|
||||
let client = reqwest::Client::new();
|
||||
let token_resp = client
|
||||
.post(&config.token_endpoint())
|
||||
.form(&[
|
||||
("grant_type", "authorization_code"),
|
||||
("client_id", &config.client_id),
|
||||
("redirect_uri", &config.redirect_uri),
|
||||
("code", code),
|
||||
])
|
||||
.send()
|
||||
.await
|
||||
.map_err(|e| Error::StateError(format!("token request failed: {e}")))?;
|
||||
|
||||
if !token_resp.status().is_success() {
|
||||
let body = token_resp.text().await.unwrap_or_default();
|
||||
return Err(Error::StateError(format!("token exchange failed: {body}")));
|
||||
}
|
||||
|
||||
let tokens: TokenResponse = token_resp
|
||||
.json()
|
||||
.await
|
||||
.map_err(|e| Error::StateError(format!("token parse failed: {e}")))?;
|
||||
|
||||
// --- Fetch userinfo ---
|
||||
let userinfo: UserinfoResponse = client
|
||||
.get(&config.userinfo_endpoint())
|
||||
.bearer_auth(&tokens.access_token)
|
||||
.send()
|
||||
.await
|
||||
.map_err(|e| Error::StateError(format!("userinfo request failed: {e}")))?
|
||||
.json()
|
||||
.await
|
||||
.map_err(|e| Error::StateError(format!("userinfo parse failed: {e}")))?;
|
||||
|
||||
// --- Build user state and persist in session ---
|
||||
let user_state = UserStateInner {
|
||||
sub: userinfo.sub,
|
||||
access_token: tokens.access_token,
|
||||
refresh_token: tokens.refresh_token.unwrap_or_default(),
|
||||
user: User {
|
||||
email: userinfo.email.unwrap_or_default(),
|
||||
avatar_url: userinfo.picture.unwrap_or_default(),
|
||||
},
|
||||
};
|
||||
|
||||
set_login_session(session, user_state).await?;
|
||||
|
||||
let target = redirect_url
|
||||
.filter(|u| !u.is_empty())
|
||||
.unwrap_or_else(|| "/".into());
|
||||
|
||||
Ok(Redirect::temporary(&target))
|
||||
}
|
||||
|
||||
/// Clear the user session and redirect to Keycloak's logout endpoint.
|
||||
///
|
||||
/// After Keycloak finishes its own logout flow it will redirect
|
||||
/// back to the application root.
|
||||
///
|
||||
/// # Errors
|
||||
///
|
||||
/// Returns `Error` if env vars are missing or the session cannot be flushed.
|
||||
#[axum::debug_handler]
|
||||
pub async fn logout(session: Session) -> Result<impl IntoResponse, Error> {
|
||||
let config = OAuthConfig::from_env()?;
|
||||
|
||||
// Flush all session data.
|
||||
session
|
||||
.flush()
|
||||
.await
|
||||
.map_err(|e| Error::StateError(format!("session flush failed: {e}")))?;
|
||||
|
||||
let mut url = Url::parse(&config.logout_endpoint())
|
||||
.map_err(|e| Error::StateError(format!("invalid logout endpoint URL: {e}")))?;
|
||||
|
||||
url.query_pairs_mut()
|
||||
.append_pair("client_id", &config.client_id)
|
||||
.append_pair("post_logout_redirect_uri", &config.app_url);
|
||||
|
||||
Ok(Redirect::temporary(url.as_str()))
|
||||
}
|
||||
|
||||
/// Persist user data into the session.
|
||||
///
|
||||
/// # Errors
|
||||
///
|
||||
/// Returns `Error` if the session store write fails.
|
||||
pub async fn set_login_session(session: Session, data: UserStateInner) -> Result<(), Error> {
|
||||
session
|
||||
.insert(LOGGED_IN_USER_SESS_KEY, data)
|
||||
.await
|
||||
.map_err(|e| Error::StateError(format!("session insert failed: {e}")))
|
||||
}
|
||||
22
src/infrastructure/error.rs
Normal file
22
src/infrastructure/error.rs
Normal file
@@ -0,0 +1,22 @@
|
||||
use axum::response::IntoResponse;
|
||||
use reqwest::StatusCode;
|
||||
|
||||
#[derive(thiserror::Error, Debug)]
|
||||
pub enum Error {
|
||||
#[error("{0}")]
|
||||
StateError(String),
|
||||
|
||||
#[error("IoError: {0}")]
|
||||
IoError(#[from] std::io::Error),
|
||||
}
|
||||
|
||||
impl IntoResponse for Error {
|
||||
fn into_response(self) -> axum::response::Response {
|
||||
let msg = self.to_string();
|
||||
tracing::error!("Converting Error to Response: {msg}");
|
||||
match self {
|
||||
Self::StateError(e) => (StatusCode::INTERNAL_SERVER_ERROR, e).into_response(),
|
||||
_ => (StatusCode::INTERNAL_SERVER_ERROR, "Unknown error").into_response(),
|
||||
}
|
||||
}
|
||||
}
|
||||
10
src/infrastructure/mod.rs
Normal file
10
src/infrastructure/mod.rs
Normal file
@@ -0,0 +1,10 @@
|
||||
#![cfg(feature = "server")]
|
||||
mod auth;
|
||||
mod error;
|
||||
mod server;
|
||||
mod state;
|
||||
|
||||
pub use auth::*;
|
||||
pub use error::*;
|
||||
pub use server::*;
|
||||
pub use state::*;
|
||||
56
src/infrastructure/server.rs
Normal file
56
src/infrastructure/server.rs
Normal file
@@ -0,0 +1,56 @@
|
||||
use crate::infrastructure::{
|
||||
auth_callback, auth_login, logout, PendingOAuthStore, UserState, UserStateInner,
|
||||
};
|
||||
|
||||
use dioxus::prelude::*;
|
||||
|
||||
use axum::routing::get;
|
||||
use axum::Extension;
|
||||
use time::Duration;
|
||||
use tower_sessions::{cookie::Key, MemoryStore, SessionManagerLayer};
|
||||
|
||||
/// Start the Axum server with Dioxus fullstack, session management,
|
||||
/// and Keycloak OAuth routes.
|
||||
///
|
||||
/// # Errors
|
||||
///
|
||||
/// Returns `Error` if the tokio runtime or TCP listener fails to start.
|
||||
pub fn server_start(app: fn() -> Element) -> Result<(), super::Error> {
|
||||
tokio::runtime::Runtime::new()?.block_on(async move {
|
||||
let state: UserState = UserStateInner {
|
||||
access_token: "abcd".into(),
|
||||
sub: "abcd".into(),
|
||||
refresh_token: "abcd".into(),
|
||||
..Default::default()
|
||||
}
|
||||
.into();
|
||||
let key = Key::generate();
|
||||
let store = MemoryStore::default();
|
||||
let session = SessionManagerLayer::new(store)
|
||||
.with_secure(false)
|
||||
// Lax is required so the browser sends the session cookie
|
||||
// on the redirect back from Keycloak (cross-origin GET).
|
||||
// Strict would silently drop the cookie on that navigation.
|
||||
.with_same_site(tower_sessions::cookie::SameSite::Lax)
|
||||
.with_expiry(tower_sessions::Expiry::OnInactivity(Duration::hours(24)))
|
||||
.with_signed(key);
|
||||
let addr = dioxus_cli_config::fullstack_address_or_localhost();
|
||||
let listener = tokio::net::TcpListener::bind(addr).await?;
|
||||
// Layers are applied AFTER serve_dioxus_application so they
|
||||
// wrap both the custom Axum routes AND the Dioxus server
|
||||
// function routes (e.g. check_auth needs Session access).
|
||||
let router = axum::Router::new()
|
||||
.route("/auth", get(auth_login))
|
||||
.route("/auth/callback", get(auth_callback))
|
||||
.route("/logout", get(logout))
|
||||
.serve_dioxus_application(ServeConfig::new(), app)
|
||||
.layer(Extension(PendingOAuthStore::default()))
|
||||
.layer(Extension(state))
|
||||
.layer(session);
|
||||
|
||||
info!("Serving at {addr}");
|
||||
axum::serve(listener, router.into_make_service()).await?;
|
||||
|
||||
Ok(())
|
||||
})
|
||||
}
|
||||
61
src/infrastructure/state.rs
Normal file
61
src/infrastructure/state.rs
Normal file
@@ -0,0 +1,61 @@
|
||||
use std::{
|
||||
ops::{Deref, DerefMut},
|
||||
sync::Arc,
|
||||
};
|
||||
|
||||
use axum::extract::FromRequestParts;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use tracing::debug;
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct UserState(Arc<UserStateInner>);
|
||||
|
||||
impl Deref for UserState {
|
||||
type Target = UserStateInner;
|
||||
fn deref(&self) -> &Self::Target {
|
||||
&self.0
|
||||
}
|
||||
}
|
||||
|
||||
impl From<UserStateInner> for UserState {
|
||||
fn from(value: UserStateInner) -> Self {
|
||||
Self(Arc::new(value))
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
|
||||
pub struct UserStateInner {
|
||||
/// Subject in Oauth
|
||||
pub sub: String,
|
||||
/// Access Token
|
||||
pub access_token: String,
|
||||
/// Refresh Token
|
||||
pub refresh_token: String,
|
||||
/// User
|
||||
pub user: User,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
|
||||
pub struct User {
|
||||
/// Email
|
||||
pub email: String,
|
||||
/// Avatar Url
|
||||
pub avatar_url: String,
|
||||
}
|
||||
|
||||
impl<S> FromRequestParts<S> for UserState
|
||||
where
|
||||
S: std::marker::Sync + std::marker::Send,
|
||||
{
|
||||
type Rejection = super::Error;
|
||||
async fn from_request_parts(
|
||||
parts: &mut axum::http::request::Parts,
|
||||
_: &S,
|
||||
) -> Result<Self, super::Error> {
|
||||
parts
|
||||
.extensions
|
||||
.get::<UserState>()
|
||||
.cloned()
|
||||
.ok_or(super::Error::StateError("Unable to get extension".into()))
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user