Files
compliance-scanner-agent/compliance-dashboard/src/infrastructure/auth.rs
Sharang Parnerkar 7e12d1433a
All checks were successful
CI / Clippy (push) Successful in 3m17s
CI / Security Audit (push) Successful in 1m36s
CI / Format (push) Successful in 2s
CI / Tests (push) Successful in 4m38s
docs: added vite-press docs (#4)
Co-authored-by: Sharang Parnerkar <parnerkarsharang@gmail.com>
Reviewed-on: #4
2026-03-08 13:59:50 +00:00

235 lines
6.9 KiB
Rust

use std::{
collections::HashMap,
sync::{Arc, RwLock},
};
use axum::{
extract::Query,
response::{IntoResponse, Redirect},
Extension,
};
use rand::Rng;
use tower_sessions::Session;
use url::Url;
use super::{
error::DashboardError,
server_state::ServerState,
user_state::{User, UserStateInner},
};
/// Session key under which `auth_callback` stores the logged-in user's
/// `UserStateInner` after a successful Keycloak login.
pub const LOGGED_IN_USER_SESS_KEY: &str = "logged-in-user";
/// State remembered between `auth_login` and `auth_callback` for one
/// in-flight login, keyed by the random OAuth `state` value.
#[derive(Debug, Clone)]
pub(crate) struct PendingOAuthEntry {
    /// Optional post-login destination taken from the login request's
    /// `redirect_url` query parameter. User-supplied: validate before use.
    pub(crate) redirect_url: Option<String>,
    /// PKCE code verifier whose S256 challenge was sent to the authorization
    /// endpoint; it is presented again at the token endpoint.
    pub(crate) code_verifier: String,
}

/// How long a pending login may wait for its callback before it is discarded.
const PENDING_OAUTH_TTL: std::time::Duration = std::time::Duration::from_secs(10 * 60);

/// OAuth `state` -> (creation instant, pending entry).
type PendingOAuthMap = HashMap<String, (std::time::Instant, PendingOAuthEntry)>;

/// Shared in-memory store of in-flight OAuth logins.
///
/// Clones share the same underlying map. Entries are single-use (see
/// [`PendingOAuthStore::take`]) and expire after [`PENDING_OAUTH_TTL`].
/// Expired entries are pruned on every insert, so logins that never reach the
/// callback cannot grow the map without bound. This includes a client
/// repeatedly hitting the login endpoint.
#[derive(Debug, Clone, Default)]
pub struct PendingOAuthStore(Arc<RwLock<PendingOAuthMap>>);

impl PendingOAuthStore {
    /// Remembers `entry` under `state`, first discarding any expired entries.
    pub(crate) fn insert(&self, state: String, entry: PendingOAuthEntry) {
        let now = std::time::Instant::now();
        let mut map = self.write_map();
        map.retain(|_, (created, _)| now.duration_since(*created) < PENDING_OAUTH_TTL);
        map.insert(state, (now, entry));
    }

    /// Removes and returns the entry for `state`.
    ///
    /// Returns `None` when the state is unknown, was already consumed, or is
    /// older than [`PENDING_OAUTH_TTL`]. An expired entry is removed either way.
    pub(crate) fn take(&self, state: &str) -> Option<PendingOAuthEntry> {
        let (created, entry) = self.write_map().remove(state)?;
        (created.elapsed() < PENDING_OAUTH_TTL).then_some(entry)
    }

    /// Locks the map for writing.
    ///
    /// Only simple `HashMap` operations run while this lock is held. Recovering
    /// from poisoning therefore keeps the map usable, and it stops one panicking
    /// request from breaking every later login.
    fn write_map(&self) -> std::sync::RwLockWriteGuard<'_, PendingOAuthMap> {
        self.0.write().unwrap_or_else(std::sync::PoisonError::into_inner)
    }
}
/// Produces an unguessable OAuth `state` value (CSRF token).
///
/// 32 bytes drawn from `rand::rng()` are rendered as 64 lowercase
/// hexadecimal characters.
pub(crate) fn generate_state() -> String {
    use std::fmt::Write as _;

    let random: [u8; 32] = rand::rng().random();
    let mut encoded = String::with_capacity(2 * random.len());
    for byte in random {
        // Formatting into a String cannot fail, so the fmt::Result is ignored.
        let _ = write!(encoded, "{byte:02x}");
    }
    encoded
}
pub(crate) fn generate_code_verifier() -> String {
use base64::{engine::general_purpose::URL_SAFE_NO_PAD, Engine};
let bytes: [u8; 32] = rand::rng().random();
URL_SAFE_NO_PAD.encode(bytes)
}
pub(crate) fn derive_code_challenge(verifier: &str) -> String {
use base64::{engine::general_purpose::URL_SAFE_NO_PAD, Engine};
use sha2::{Digest, Sha256};
let digest = Sha256::digest(verifier.as_bytes());
URL_SAFE_NO_PAD.encode(digest)
}
/// Begins the Keycloak authorization-code login with PKCE.
///
/// A fresh CSRF `state` and PKCE verifier are generated. They are parked in
/// the [`PendingOAuthStore`] together with the optional `redirect_url` query
/// parameter. The browser is then redirected to Keycloak's authorization
/// endpoint, carrying the `state` and the S256 `code_challenge`.
///
/// # Errors
/// [`DashboardError::Other`] when Keycloak is not configured or its
/// authorization endpoint cannot be parsed as a URL.
#[axum::debug_handler]
pub async fn auth_login(
    Extension(state): Extension<ServerState>,
    Extension(pending): Extension<PendingOAuthStore>,
    Query(params): Query<HashMap<String, String>>,
) -> Result<impl IntoResponse, DashboardError> {
    let Some(kc) = state.keycloak else {
        return Err(DashboardError::Other("Keycloak not configured".into()));
    };

    let csrf_state = generate_state();
    let code_verifier = generate_code_verifier();
    let code_challenge = derive_code_challenge(&code_verifier);

    // Keyed by the state value Keycloak will echo back to the callback.
    pending.insert(
        csrf_state.clone(),
        PendingOAuthEntry {
            redirect_url: params.get("redirect_url").cloned(),
            code_verifier,
        },
    );

    let mut authorize_url = Url::parse(&kc.auth_endpoint())
        .map_err(|e| DashboardError::Other(format!("invalid auth endpoint URL: {e}")))?;
    authorize_url.query_pairs_mut().extend_pairs([
        ("client_id", kc.client_id.as_str()),
        ("redirect_uri", kc.redirect_uri.as_str()),
        ("response_type", "code"),
        ("scope", "openid profile email"),
        ("state", csrf_state.as_str()),
        ("code_challenge", code_challenge.as_str()),
        ("code_challenge_method", "S256"),
    ]);
    Ok(Redirect::temporary(authorize_url.as_str()))
}
/// Fields read from the token endpoint's JSON response in `auth_callback`.
/// Other fields in the response are not captured.
#[derive(serde::Deserialize)]
struct TokenResponse {
    /// Bearer token sent to the userinfo endpoint and kept in the session.
    access_token: String,
    /// Refresh token, if one is issued. Stored in the session as an empty
    /// string when absent.
    refresh_token: Option<String>,
}
/// Fields read from the userinfo endpoint's JSON response in `auth_callback`.
#[derive(serde::Deserialize)]
struct UserinfoResponse {
    /// Subject identifier. Non-optional, so deserialization fails without it.
    /// Stored as `UserStateInner::sub`.
    sub: String,
    /// Becomes the user's email (empty string when absent).
    email: Option<String>,
    /// Display-name fallback used when `name` is absent.
    preferred_username: Option<String>,
    /// Preferred display name.
    name: Option<String>,
    /// Becomes the user's avatar URL (empty string when absent).
    picture: Option<String>,
}
/// Returns `true` when `path` is a path-absolute reference (e.g.
/// `/findings?id=1`) that a browser cannot reinterpret as another host.
///
/// The following are rejected:
/// - scheme-relative `//host` forms;
/// - any backslash, because browsers treat `/\host` like `//host`;
/// - control characters, because browsers strip tabs and newlines, which can
///   turn `/\t/host` into `//host`.
fn is_local_redirect_path(path: &str) -> bool {
    path.starts_with('/')
        && !path.starts_with("//")
        && !path.chars().any(|c| c == '\\' || c.is_control())
}

/// Chooses where to send the browser after login.
///
/// `candidate` is the `redirect_url` supplied to `/auth/login`, which anyone
/// can craft. It is honoured only when it is a local path, or an absolute URL
/// whose remainder after `app_url` is a local path. Anything else, including
/// a missing or empty value, falls back to `/`. This prevents an open redirect
/// to an attacker's site after a genuine login.
fn safe_redirect_target(candidate: Option<String>, app_url: &str) -> String {
    let app_base = app_url.trim_end_matches('/');
    candidate
        .filter(|target| {
            let path = target.strip_prefix(app_base).unwrap_or(target);
            is_local_redirect_path(path)
        })
        .unwrap_or_else(|| "/".into())
}

/// Completes the Keycloak login started by `auth_login`.
///
/// The handler runs these steps in order:
/// 1. Validates and consumes the OAuth `state`.
/// 2. Exchanges the authorization `code` (with the PKCE verifier) for tokens.
/// 3. Fetches the user's profile from the userinfo endpoint.
/// 4. Rotates the session ID and stores the user in the session.
/// 5. Redirects to a validated post-login target.
///
/// # Errors
/// [`DashboardError::Other`] when any of the following happens:
/// - Keycloak is not configured;
/// - `state` is missing, unknown or expired, or `code` is missing;
/// - the token or userinfo request fails or returns a non-success status;
/// - a response cannot be parsed;
/// - the session cannot be updated.
#[axum::debug_handler]
pub async fn auth_callback(
    session: Session,
    Extension(state): Extension<ServerState>,
    Extension(pending): Extension<PendingOAuthStore>,
    Query(params): Query<HashMap<String, String>>,
) -> Result<impl IntoResponse, DashboardError> {
    let kc = state
        .keycloak
        .ok_or_else(|| DashboardError::Other("Keycloak not configured".into()))?;

    // The state check is the CSRF defence and consumes the entry, so a forged
    // or replayed callback is rejected before anything else happens.
    let returned_state = params
        .get("state")
        .ok_or_else(|| DashboardError::Other("missing state parameter".into()))?;
    let entry = pending
        .take(returned_state)
        .ok_or_else(|| DashboardError::Other("unknown or expired oauth state".into()))?;
    let code = params
        .get("code")
        .ok_or_else(|| DashboardError::Other("missing code parameter".into()))?;

    let client = reqwest::Client::new();
    let token_resp = client
        .post(kc.token_endpoint())
        .form(&[
            ("grant_type", "authorization_code"),
            ("client_id", kc.client_id.as_str()),
            ("redirect_uri", kc.redirect_uri.as_str()),
            ("code", code),
            ("code_verifier", &entry.code_verifier),
        ])
        .send()
        .await
        .map_err(|e| DashboardError::Other(format!("token request failed: {e}")))?;
    if !token_resp.status().is_success() {
        let body = token_resp.text().await.unwrap_or_default();
        return Err(DashboardError::Other(format!(
            "token exchange failed: {body}"
        )));
    }
    let tokens: TokenResponse = token_resp
        .json()
        .await
        .map_err(|e| DashboardError::Other(format!("token parse failed: {e}")))?;

    let userinfo: UserinfoResponse = client
        .get(kc.userinfo_endpoint())
        .bearer_auth(&tokens.access_token)
        .send()
        .await
        // Report 4xx/5xx as request failures instead of letting them surface
        // as a confusing JSON parse error below.
        .and_then(reqwest::Response::error_for_status)
        .map_err(|e| DashboardError::Other(format!("userinfo request failed: {e}")))?
        .json()
        .await
        .map_err(|e| DashboardError::Other(format!("userinfo parse failed: {e}")))?;

    let display_name = userinfo
        .name
        .or(userinfo.preferred_username)
        .unwrap_or_default();
    let user_state = UserStateInner {
        sub: userinfo.sub,
        access_token: tokens.access_token,
        refresh_token: tokens.refresh_token.unwrap_or_default(),
        user: User {
            email: userinfo.email.unwrap_or_default(),
            name: display_name,
            avatar_url: userinfo.picture.unwrap_or_default(),
        },
    };

    // Issue a fresh session ID on login. A session identifier planted in the
    // browser beforehand (session fixation) then cannot ride this login.
    session
        .cycle_id()
        .await
        .map_err(|e| DashboardError::Other(format!("session id rotation failed: {e}")))?;
    session
        .insert(LOGGED_IN_USER_SESS_KEY, user_state)
        .await
        .map_err(|e| DashboardError::Other(format!("session insert failed: {e}")))?;

    let target = safe_redirect_target(entry.redirect_url, &kc.app_url);
    Ok(Redirect::temporary(&target))
}
/// Logs the user out of the dashboard and then out of Keycloak.
///
/// The local session is flushed first. The browser is then redirected to
/// Keycloak's logout endpoint with `client_id` and a
/// `post_logout_redirect_uri` pointing back at the configured `app_url`.
///
/// # Errors
/// [`DashboardError::Other`] when Keycloak is not configured, the session
/// cannot be flushed, or the logout endpoint is not a valid URL.
#[axum::debug_handler]
pub async fn logout(
    session: Session,
    Extension(state): Extension<ServerState>,
) -> Result<impl IntoResponse, DashboardError> {
    let Some(kc) = state.keycloak else {
        return Err(DashboardError::Other("Keycloak not configured".into()));
    };

    if let Err(e) = session.flush().await {
        return Err(DashboardError::Other(format!("session flush failed: {e}")));
    }

    let mut end_session_url = Url::parse(&kc.logout_endpoint())
        .map_err(|e| DashboardError::Other(format!("invalid logout endpoint URL: {e}")))?;
    {
        // The serializer writes the query back into the URL when it is dropped
        // at the end of this scope.
        let mut query = end_session_url.query_pairs_mut();
        query.append_pair("client_id", &kc.client_id);
        query.append_pair("post_logout_redirect_uri", &kc.app_url);
    }
    Ok(Redirect::temporary(end_session_url.as_str()))
}