Files
certifai/src/infrastructure/auth.rs
Sharang Parnerkar e49590313f fix(ci): clippy fixes
2026-02-18 10:15:17 +01:00

308 lines
10 KiB
Rust

use std::{
collections::HashMap,
sync::{Arc, RwLock},
};
use axum::{
extract::Query,
response::{IntoResponse, Redirect},
Extension,
};
use rand::RngExt;
use tower_sessions::Session;
use url::Url;
use crate::infrastructure::{state::User, Error, UserStateInner};
/// Session key under which `set_login_session` stores the logged-in user's
/// `UserStateInner`.
pub const LOGGED_IN_USER_SESS_KEY: &str = "logged-in-user";
/// In-memory store for pending OAuth states and their associated redirect
/// URLs. Keyed by the random state string. This avoids dependence on the
/// session cookie surviving the Keycloak redirect round-trip (the `dx serve`
/// proxy can drop `Set-Cookie` headers on 307 responses).
///
/// Each entry remembers when it was created. States older than
/// [`PendingOAuthStore::TTL`] are rejected by `take` and purged on the next
/// `insert`, so abandoned login attempts cannot accumulate without bound and
/// a leaked state value does not stay valid forever.
#[derive(Debug, Clone, Default)]
pub struct PendingOAuthStore(
    Arc<RwLock<HashMap<String, (std::time::Instant, Option<String>)>>>,
);

impl PendingOAuthStore {
    /// How long a login attempt may stay pending before its state is
    /// treated as expired. Chosen to be generous for an interactive login.
    const TTL: std::time::Duration = std::time::Duration::from_secs(30 * 60);

    /// Insert a pending state with an optional post-login redirect URL.
    ///
    /// Also drops every entry that has outlived [`Self::TTL`].
    fn insert(&self, state: String, redirect_url: Option<String>) {
        self.insert_at(state, redirect_url, std::time::Instant::now());
    }

    /// Remove and return the redirect URL if the state was pending.
    ///
    /// Returns `None` if the state was never stored, was already consumed,
    /// or has expired (CSRF failure). An expired entry is still removed.
    fn take(&self, state: &str) -> Option<Option<String>> {
        self.take_at(state, std::time::Instant::now())
    }

    /// Whether an entry created at `created_at` is still valid at `now`.
    fn is_fresh(created_at: std::time::Instant, now: std::time::Instant) -> bool {
        // Saturating: a clock reading earlier than `created_at` counts as age zero.
        now.saturating_duration_since(created_at) < Self::TTL
    }

    /// `insert` with an explicit clock reading, so expiry is testable
    /// without sleeping.
    fn insert_at(&self, state: String, redirect_url: Option<String>, now: std::time::Instant) {
        // RwLock::write only panics if the lock is poisoned, which
        // indicates a prior panic -- propagating is acceptable here.
        #[allow(clippy::expect_used)]
        let mut pending = self.0.write().expect("pending oauth store lock poisoned");
        // Purge stale entries while we already hold the write lock.
        pending.retain(|_, (created_at, _)| Self::is_fresh(*created_at, now));
        pending.insert(state, (now, redirect_url));
    }

    /// `take` with an explicit clock reading, so expiry is testable
    /// without sleeping.
    fn take_at(&self, state: &str, now: std::time::Instant) -> Option<Option<String>> {
        // Same poisoning rationale as in `insert_at`.
        #[allow(clippy::expect_used)]
        let removed = self
            .0
            .write()
            .expect("pending oauth store lock poisoned")
            .remove(state);
        let (created_at, redirect_url) = removed?;
        Self::is_fresh(created_at, now).then_some(redirect_url)
    }
}
/// Configuration loaded from environment variables for Keycloak OAuth.
///
/// Built by `OAuthConfig::from_env`; every field is required.
struct OAuthConfig {
/// Keycloak base URL (`KEYCLOAK_URL`); the `/realms/...` endpoint paths
/// are appended to it.
keycloak_url: String,
/// Keycloak realm name (`KEYCLOAK_REALM`), used as a path segment in
/// every endpoint URL.
realm: String,
/// OAuth client identifier (`KEYCLOAK_CLIENT_ID`), sent as `client_id`
/// in the authorization, token and logout requests.
client_id: String,
/// Callback URL (`REDIRECT_URI`), sent as `redirect_uri` in the
/// authorization and token requests.
redirect_uri: String,
/// Application URL (`APP_URL`), sent as `post_logout_redirect_uri`
/// when logging out.
app_url: String,
}
impl OAuthConfig {
/// Load OAuth configuration from environment variables.
///
/// # Errors
///
/// Returns `Error::StateError` if any required env var is missing.
fn from_env() -> Result<Self, Error> {
dotenvy::dotenv().ok();
Ok(Self {
keycloak_url: std::env::var("KEYCLOAK_URL")
.map_err(|_| Error::StateError("KEYCLOAK_URL not set".into()))?,
realm: std::env::var("KEYCLOAK_REALM")
.map_err(|_| Error::StateError("KEYCLOAK_REALM not set".into()))?,
client_id: std::env::var("KEYCLOAK_CLIENT_ID")
.map_err(|_| Error::StateError("KEYCLOAK_CLIENT_ID not set".into()))?,
redirect_uri: std::env::var("REDIRECT_URI")
.map_err(|_| Error::StateError("REDIRECT_URI not set".into()))?,
app_url: std::env::var("APP_URL")
.map_err(|_| Error::StateError("APP_URL not set".into()))?,
})
}
/// Build the Keycloak OpenID Connect authorization endpoint URL.
fn auth_endpoint(&self) -> String {
format!(
"{}/realms/{}/protocol/openid-connect/auth",
self.keycloak_url, self.realm
)
}
/// Build the Keycloak OpenID Connect token endpoint URL.
fn token_endpoint(&self) -> String {
format!(
"{}/realms/{}/protocol/openid-connect/token",
self.keycloak_url, self.realm
)
}
/// Build the Keycloak OpenID Connect userinfo endpoint URL.
fn userinfo_endpoint(&self) -> String {
format!(
"{}/realms/{}/protocol/openid-connect/userinfo",
self.keycloak_url, self.realm
)
}
/// Build the Keycloak OpenID Connect end-session (logout) endpoint URL.
fn logout_endpoint(&self) -> String {
format!(
"{}/realms/{}/protocol/openid-connect/logout",
self.keycloak_url, self.realm
)
}
}
/// Generate a cryptographically random state string for CSRF protection.
fn generate_state() -> String {
let bytes: [u8; 32] = rand::rng().random();
// Encode as hex to produce a URL-safe string without padding.
bytes.iter().fold(String::with_capacity(64), |mut acc, b| {
use std::fmt::Write;
// write! on a String is infallible, safe to ignore the result.
let _ = write!(acc, "{b:02x}");
acc
})
}
/// Redirect the user to Keycloak's authorization page.
///
/// Generates a random CSRF state, stores it (along with the optional
/// redirect URL) in the server-side `PendingOAuthStore`, and redirects
/// the browser to Keycloak.
///
/// # Query Parameters
///
/// * `redirect_url` - Optional URL to redirect to after successful login.
///
/// # Errors
///
/// Returns `Error` if env vars are missing.
#[axum::debug_handler]
pub async fn auth_login(
Extension(pending): Extension<PendingOAuthStore>,
Query(params): Query<HashMap<String, String>>,
) -> Result<impl IntoResponse, Error> {
let config = OAuthConfig::from_env()?;
let state = generate_state();
let redirect_url = params.get("redirect_url").cloned();
pending.insert(state.clone(), redirect_url);
let mut url = Url::parse(&config.auth_endpoint())
.map_err(|e| Error::StateError(format!("invalid auth endpoint URL: {e}")))?;
url.query_pairs_mut()
.append_pair("client_id", &config.client_id)
.append_pair("redirect_uri", &config.redirect_uri)
.append_pair("response_type", "code")
.append_pair("scope", "openid profile email")
.append_pair("state", &state);
Ok(Redirect::temporary(url.as_str()))
}
/// Token endpoint response from Keycloak.
///
/// Only the fields the callback handler uses are deserialized.
#[derive(serde::Deserialize)]
struct TokenResponse {
/// Bearer token used for the userinfo request and then stored in the
/// session's user state.
access_token: String,
/// Refresh token, if one was returned; stored as an empty string when
/// absent.
refresh_token: Option<String>,
}
/// Userinfo endpoint response from Keycloak.
///
/// Optional fields fall back to an empty string when the callback handler
/// builds the session's user state.
#[derive(serde::Deserialize)]
struct UserinfoResponse {
/// The subject identifier (unique user ID in Keycloak).
sub: String,
/// The user's email address, if present in the response.
email: Option<String>,
/// Keycloak may include a picture/avatar URL via protocol mappers.
/// Stored as the user's `avatar_url`.
picture: Option<String>,
}
/// Handle the OAuth callback from Keycloak after the user authenticates.
///
/// Validates the CSRF state against the `PendingOAuthStore`, exchanges
/// the authorization code for tokens, fetches user info, stores the
/// logged-in user in the tower-sessions session, and redirects to the app.
///
/// # Query Parameters
///
/// * `code` - The authorization code from Keycloak.
/// * `state` - The CSRF state to verify against the pending store.
///
/// # Errors
///
/// Returns `Error` on CSRF mismatch, token exchange failure, or session issues.
#[axum::debug_handler]
pub async fn auth_callback(
session: Session,
Extension(pending): Extension<PendingOAuthStore>,
Query(params): Query<HashMap<String, String>>,
) -> Result<impl IntoResponse, Error> {
let config = OAuthConfig::from_env()?;
// --- CSRF validation via the in-memory pending store ---
let returned_state = params
.get("state")
.ok_or_else(|| Error::StateError("missing state parameter".into()))?;
let redirect_url = pending
.take(returned_state)
.ok_or_else(|| Error::StateError("unknown or expired oauth state".into()))?;
// --- Exchange code for tokens ---
let code = params
.get("code")
.ok_or_else(|| Error::StateError("missing code parameter".into()))?;
let client = reqwest::Client::new();
let token_resp = client
.post(config.token_endpoint())
.form(&[
("grant_type", "authorization_code"),
("client_id", &config.client_id),
("redirect_uri", &config.redirect_uri),
("code", code),
])
.send()
.await
.map_err(|e| Error::StateError(format!("token request failed: {e}")))?;
if !token_resp.status().is_success() {
let body = token_resp.text().await.unwrap_or_default();
return Err(Error::StateError(format!("token exchange failed: {body}")));
}
let tokens: TokenResponse = token_resp
.json()
.await
.map_err(|e| Error::StateError(format!("token parse failed: {e}")))?;
// --- Fetch userinfo ---
let userinfo: UserinfoResponse = client
.get(config.userinfo_endpoint())
.bearer_auth(&tokens.access_token)
.send()
.await
.map_err(|e| Error::StateError(format!("userinfo request failed: {e}")))?
.json()
.await
.map_err(|e| Error::StateError(format!("userinfo parse failed: {e}")))?;
// --- Build user state and persist in session ---
let user_state = UserStateInner {
sub: userinfo.sub,
access_token: tokens.access_token,
refresh_token: tokens.refresh_token.unwrap_or_default(),
user: User {
email: userinfo.email.unwrap_or_default(),
avatar_url: userinfo.picture.unwrap_or_default(),
},
};
set_login_session(session, user_state).await?;
let target = redirect_url
.filter(|u| !u.is_empty())
.unwrap_or_else(|| "/".into());
Ok(Redirect::temporary(&target))
}
/// Clear the user session and redirect to Keycloak's logout endpoint.
///
/// After Keycloak finishes its own logout flow it will redirect
/// back to the application root.
///
/// # Errors
///
/// Returns `Error` if env vars are missing or the session cannot be flushed.
#[axum::debug_handler]
pub async fn logout(session: Session) -> Result<impl IntoResponse, Error> {
let config = OAuthConfig::from_env()?;
// Flush all session data.
session
.flush()
.await
.map_err(|e| Error::StateError(format!("session flush failed: {e}")))?;
let mut url = Url::parse(&config.logout_endpoint())
.map_err(|e| Error::StateError(format!("invalid logout endpoint URL: {e}")))?;
url.query_pairs_mut()
.append_pair("client_id", &config.client_id)
.append_pair("post_logout_redirect_uri", &config.app_url);
Ok(Redirect::temporary(url.as_str()))
}
/// Persist user data into the session.
///
/// # Errors
///
/// Returns `Error` if the session store write fails.
pub async fn set_login_session(session: Session, data: UserStateInner) -> Result<(), Error> {
session
.insert(LOGGED_IN_USER_SESS_KEY, data)
.await
.map_err(|e| Error::StateError(format!("session insert failed: {e}")))
}