Files
certifai/src/infrastructure/auth.rs
2026-02-16 21:23:25 +01:00

110 lines
3.5 KiB
Rust

use super::error::{Error, Result};
use axum::Extension;
use axum::{
extract::FromRequestParts,
http::request::Parts,
response::{IntoResponse, Redirect, Response},
};
use url::form_urlencoded;
/// Keycloak connection settings; `base_url`, `realm` and `client_id` are read by
/// `logout` in this module to build the OpenID Connect logout URL.
pub struct KeycloakVariables {
    /// Base URL of the Keycloak server; `/realms/{realm}/...` is appended to it.
    pub base_url: String,
    /// Name of the Keycloak realm the application authenticates against.
    pub realm: String,
    /// Client identifier of this application, sent as `client_id` on logout.
    pub client_id: String,
    /// Client secret for `client_id`.
    /// NOTE(review): not read in this file — presumably used by the login / token exchange elsewhere.
    pub client_secret: String,
    /// NOTE(review): not read in this file — presumably toggles a built-in test user; confirm.
    pub enable_test_user: bool,
}
/// Session data available to the backend when the user is logged in.
///
/// Serialized into the `tower_sessions` session by `login` and removed again by `logout`.
///
/// `Debug` is implemented by hand so that the ID token is never written to logs.
#[derive(serde::Serialize, serde::Deserialize)]
pub struct LoggedInData {
    /// Identifier of the logged-in user (NOTE(review): presumably the Keycloak subject — confirm).
    pub id: String,
    /// ID Token value associated with the authenticated session.
    ///
    /// Sent back to Keycloak as `id_token_hint` when the user logs out.
    pub token_id: String,
    /// Name of the logged-in user.
    pub username: String,
    /// URL of the user's avatar image, if one is known.
    pub avatar_url: Option<String>,
}

impl std::fmt::Debug for LoggedInData {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // The ID token is a bearer credential: keep it out of `{:?}` output.
        f.debug_struct("LoggedInData")
            .field("id", &self.id)
            .field("token_id", &"<redacted>")
            .field("username", &self.username)
            .field("avatar_url", &self.avatar_url)
            .finish()
    }
}
/// Used for extracting in the server functions.
/// If the `data` is `Some`, the user is logged in.
pub struct UserSession {
data: Option<LoggedInData>,
}
impl UserSession {
/// Get the [`LoggedInData`].
///
/// Raises a [`Error::UserNotLoggedIn`] error if the user is not logged in.
pub fn data(self) -> Result<LoggedInData> {
self.data.ok_or(Error::UserNotLoggedIn)
}
}
/// Session key under which the serialized [`LoggedInData`] is stored.
const LOGGED_IN_USER_SESSION_KEY: &str = "logged_in_data";
impl<S: std::marker::Sync + std::marker::Send> FromRequestParts<S> for UserSession {
type Rejection = Error;
async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self> {
let session = parts
.extensions
.get::<tower_sessions::Session>()
.cloned()
.ok_or(Error::AuthSessionLayerNotFound(
"Auth Session Layer not found".to_string(),
))?;
let data: Option<LoggedInData> = session
.get::<LoggedInData>(LOGGED_IN_USER_SESSION_KEY)
.await?;
Ok(Self { data })
}
}
/// Logs the user in by storing `data` in the session.
///
/// The session ID is cycled first. Otherwise an ID issued before authentication
/// (for example one planted by an attacker) would keep working once the session
/// becomes authenticated (session fixation). Per the tower-sessions docs,
/// `cycle_id` keeps the data already stored in the session.
///
/// # Errors
///
/// Propagates any session-store error raised while cycling the ID or inserting
/// the value.
pub async fn login(session: &tower_sessions::Session, data: &LoggedInData) -> Result<()> {
    session.cycle_id().await?;
    session.insert(LOGGED_IN_USER_SESSION_KEY, data).await?;
    Ok(())
}
/// Handler to run when the user wants to log out.
///
/// Removes the [`LoggedInData`] from the session. If the user was logged in, it
/// redirects to Keycloak's OpenID Connect logout endpoint with the stored ID
/// token as `id_token_hint` and the dashboard home page as
/// `post_logout_redirect_uri`. Otherwise it redirects straight to the home page.
///
/// # Errors
///
/// Propagates any error raised while removing the value from the session.
#[axum::debug_handler]
pub async fn logout(
    state: Extension<super::server_state::ServerState>,
    session: tower_sessions::Session,
) -> Result<Response> {
    // NOTE(review): hard-coded for local development; a deployed dashboard
    // presumably needs this from configuration — confirm.
    let dashboard_base_url = "http://localhost:8000";
    let redirect_uri = format!("{dashboard_base_url}/");

    // Clear the login data for this session.
    let Some(login_data) = session
        .remove::<LoggedInData>(LOGGED_IN_USER_SESSION_KEY)
        .await?
    else {
        // No id_token in session; just redirect to homepage.
        return Ok(Redirect::to(&redirect_uri).into_response());
    };

    let keycloak = &state.keycloak_variables;
    // Needed for running locally: the backend may reach Keycloak through a
    // `keycloak` hostname (presumably a container name) that the browser cannot
    // resolve. Only that exact host is rewritten; other URLs pass through unchanged.
    let routed_kc_base_url = route_local_keycloak_url(&keycloak.base_url);

    // Every query value is percent-encoded; the serializer inserts the `&`/`=` separators.
    let query = form_urlencoded::Serializer::new(String::new())
        .append_pair("post_logout_redirect_uri", &redirect_uri)
        .append_pair("client_id", &keycloak.client_id)
        .append_pair("id_token_hint", &login_data.token_id)
        .finish();

    // Redirect to the Keycloak logout endpoint.
    let logout_url = format!(
        "{routed_kc_base_url}/realms/{}/protocol/openid-connect/logout?{query}",
        keycloak.realm
    );
    Ok(Redirect::to(&logout_url).into_response())
}

/// Rewrites `base_url` to point at `localhost` when its host is exactly
/// `keycloak` (ASCII case-insensitive). Scheme, port and path are kept, and any
/// other URL is returned unchanged.
///
/// Unlike a plain substring replace, this leaves URLs that merely contain
/// "keycloak" intact (e.g. `https://keycloak.example.com`).
fn route_local_keycloak_url(base_url: &str) -> String {
    if let Some((scheme, rest)) = base_url.split_once("://") {
        // The host ends at the first port, path, query or fragment delimiter.
        let host_end = rest
            .find(|c: char| matches!(c, ':' | '/' | '?' | '#'))
            .unwrap_or(rest.len());
        let (host, remainder) = rest.split_at(host_end);
        if host.eq_ignore_ascii_case("keycloak") {
            return format!("{scheme}://localhost{remainder}");
        }
    }
    base_url.to_string()
}