Add unit tests across all model and server infrastructure layers, increasing test count from 7 to 92. Covers serde round-trips, enum methods, defaults, config parsing, error mapping, PKCE crypto (with RFC 7636 test vector), OAuth store, and SearXNG ranking/dedup logic. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
422 lines
14 KiB
Rust
422 lines
14 KiB
Rust
use std::{
|
|
collections::HashMap,
|
|
sync::{Arc, RwLock},
|
|
};
|
|
|
|
use axum::{
|
|
extract::Query,
|
|
response::{IntoResponse, Redirect},
|
|
Extension,
|
|
};
|
|
use rand::RngExt;
|
|
use tower_sessions::Session;
|
|
use url::Url;
|
|
|
|
use crate::infrastructure::{
|
|
server_state::ServerState,
|
|
state::{User, UserStateInner},
|
|
Error,
|
|
};
|
|
|
|
/// Session key under which the logged-in user's state is stored
/// (written by `set_login_session`).
pub const LOGGED_IN_USER_SESS_KEY: &str = "logged-in-user";
|
|
|
|
/// Data stored alongside each pending OAuth state. Holds the optional
/// post-login redirect URL and the PKCE code verifier needed for the
/// token exchange.
#[derive(Debug, Clone)]
pub(crate) struct PendingOAuthEntry {
    /// Destination requested for after login, if the caller supplied one.
    pub(crate) redirect_url: Option<String>,
    /// PKCE code verifier generated for this login attempt.
    pub(crate) code_verifier: String,
}

/// How long a pending OAuth state stays redeemable after it is stored.
///
/// Login attempts that never reach the callback (abandoned tabs, bots
/// hitting the login endpoint) would otherwise accumulate forever.
const PENDING_OAUTH_TTL: std::time::Duration = std::time::Duration::from_secs(10 * 60);

/// A pending entry together with the instant it was stored, so that
/// stale states can be rejected and purged.
#[derive(Debug, Clone)]
struct StoredOAuthEntry {
    stored_at: std::time::Instant,
    entry: PendingOAuthEntry,
}

impl StoredOAuthEntry {
    /// Whether this entry is at least `PENDING_OAUTH_TTL` old at `now`.
    fn is_expired(&self, now: std::time::Instant) -> bool {
        // Saturating: a `now` earlier than `stored_at` counts as age zero.
        now.saturating_duration_since(self.stored_at) >= PENDING_OAUTH_TTL
    }
}

/// In-memory store for pending OAuth states. Keyed by the random state
/// string. This avoids dependence on the session cookie surviving the
/// Keycloak redirect round-trip (the `dx serve` proxy can drop
/// `Set-Cookie` headers on 307 responses).
///
/// Entries expire `PENDING_OAUTH_TTL` after insertion: expired entries are
/// purged whenever a new one is inserted and are never returned by `take`.
#[derive(Debug, Clone, Default)]
pub struct PendingOAuthStore(Arc<RwLock<HashMap<String, StoredOAuthEntry>>>);

impl PendingOAuthStore {
    /// Insert a pending state with an optional redirect URL and PKCE verifier.
    ///
    /// Also drops every entry that has already expired, which bounds the
    /// store's size by the number of logins started within one TTL window.
    ///
    /// # Panics
    ///
    /// Panics if the internal lock is poisoned.
    pub(crate) fn insert(&self, state: String, entry: PendingOAuthEntry) {
        self.insert_at(state, entry, std::time::Instant::now());
    }

    /// Remove and return the entry if the state was pending.
    /// Returns `None` if the state was never stored (CSRF failure), was
    /// already consumed, or has expired.
    ///
    /// # Panics
    ///
    /// Panics if the internal lock is poisoned.
    pub(crate) fn take(&self, state: &str) -> Option<PendingOAuthEntry> {
        self.take_at(state, std::time::Instant::now())
    }

    /// `insert` with an explicit clock reading, so expiry is testable
    /// without sleeping.
    fn insert_at(&self, state: String, entry: PendingOAuthEntry, now: std::time::Instant) {
        // RwLock::write only fails if the lock is poisoned, which
        // indicates a prior panic -- propagating is acceptable here.
        #[allow(clippy::expect_used)]
        let mut pending = self.0.write().expect("pending oauth store lock poisoned");
        pending.retain(|_, stored| !stored.is_expired(now));
        pending.insert(
            state,
            StoredOAuthEntry {
                stored_at: now,
                entry,
            },
        );
    }

    /// `take` with an explicit clock reading.
    fn take_at(&self, state: &str, now: std::time::Instant) -> Option<PendingOAuthEntry> {
        // Same poisoning rationale as in `insert_at`. The entry is removed
        // even when expired, so a stale state can never be redeemed later.
        #[allow(clippy::expect_used)]
        let stored = self
            .0
            .write()
            .expect("pending oauth store lock poisoned")
            .remove(state)?;
        if stored.is_expired(now) {
            None
        } else {
            Some(stored.entry)
        }
    }
}
|
|
|
|
/// Generate a cryptographically random state string for CSRF protection.
|
|
#[cfg_attr(test, allow(dead_code))]
|
|
pub(crate) fn generate_state() -> String {
|
|
let bytes: [u8; 32] = rand::rng().random();
|
|
// Encode as hex to produce a URL-safe string without padding.
|
|
bytes.iter().fold(String::with_capacity(64), |mut acc, b| {
|
|
use std::fmt::Write;
|
|
// write! on a String is infallible, safe to ignore the result.
|
|
let _ = write!(acc, "{b:02x}");
|
|
acc
|
|
})
|
|
}
|
|
|
|
/// Generate a PKCE code verifier (43-128 char URL-safe random string).
|
|
///
|
|
/// Uses 32 random bytes encoded as base64url (no padding) to produce
|
|
/// a 43-character verifier per RFC 7636.
|
|
pub(crate) fn generate_code_verifier() -> String {
|
|
use base64::{engine::general_purpose::URL_SAFE_NO_PAD, Engine};
|
|
|
|
let bytes: [u8; 32] = rand::rng().random();
|
|
URL_SAFE_NO_PAD.encode(bytes)
|
|
}
|
|
|
|
/// Derive the S256 code challenge from a code verifier per RFC 7636.
|
|
///
|
|
/// `code_challenge = BASE64URL(SHA256(code_verifier))`
|
|
pub(crate) fn derive_code_challenge(verifier: &str) -> String {
|
|
use base64::{engine::general_purpose::URL_SAFE_NO_PAD, Engine};
|
|
use sha2::{Digest, Sha256};
|
|
|
|
let digest = Sha256::digest(verifier.as_bytes());
|
|
URL_SAFE_NO_PAD.encode(digest)
|
|
}
|
|
|
|
/// Redirect the user to Keycloak's authorization page.
|
|
///
|
|
/// Generates a random CSRF state, stores it (along with the optional
|
|
/// redirect URL) in the server-side `PendingOAuthStore`, and redirects
|
|
/// the browser to Keycloak.
|
|
///
|
|
/// # Query Parameters
|
|
///
|
|
/// * `redirect_url` - Optional URL to redirect to after successful login.
|
|
///
|
|
/// # Errors
|
|
///
|
|
/// Returns `Error` if the Keycloak config is missing or the URL is malformed.
|
|
#[axum::debug_handler]
|
|
pub async fn auth_login(
|
|
Extension(state): Extension<ServerState>,
|
|
Extension(pending): Extension<PendingOAuthStore>,
|
|
Query(params): Query<HashMap<String, String>>,
|
|
) -> Result<impl IntoResponse, Error> {
|
|
let kc = state.keycloak;
|
|
let csrf_state = generate_state();
|
|
let code_verifier = generate_code_verifier();
|
|
let code_challenge = derive_code_challenge(&code_verifier);
|
|
|
|
let redirect_url = params.get("redirect_url").cloned();
|
|
pending.insert(
|
|
csrf_state.clone(),
|
|
PendingOAuthEntry {
|
|
redirect_url,
|
|
code_verifier,
|
|
},
|
|
);
|
|
|
|
let mut url = Url::parse(&kc.auth_endpoint())
|
|
.map_err(|e| Error::StateError(format!("invalid auth endpoint URL: {e}")))?;
|
|
|
|
url.query_pairs_mut()
|
|
.append_pair("client_id", &kc.client_id)
|
|
.append_pair("redirect_uri", &kc.redirect_uri)
|
|
.append_pair("response_type", "code")
|
|
.append_pair("scope", "openid profile email")
|
|
.append_pair("state", &csrf_state)
|
|
.append_pair("code_challenge", &code_challenge)
|
|
.append_pair("code_challenge_method", "S256");
|
|
|
|
Ok(Redirect::temporary(url.as_str()))
|
|
}
|
|
|
|
/// Token endpoint response from Keycloak.
|
|
#[derive(serde::Deserialize)]
|
|
struct TokenResponse {
|
|
access_token: String,
|
|
refresh_token: Option<String>,
|
|
}
|
|
|
|
/// Userinfo endpoint response from Keycloak.
|
|
#[derive(serde::Deserialize)]
|
|
struct UserinfoResponse {
|
|
/// The subject identifier (unique user ID in Keycloak).
|
|
sub: String,
|
|
email: Option<String>,
|
|
/// Keycloak `preferred_username` claim.
|
|
preferred_username: Option<String>,
|
|
/// Full name from the Keycloak profile.
|
|
name: Option<String>,
|
|
/// Keycloak may include a picture/avatar URL via protocol mappers.
|
|
picture: Option<String>,
|
|
}
|
|
|
|
/// Handle the OAuth callback from Keycloak after the user authenticates.
|
|
///
|
|
/// Validates the CSRF state against the `PendingOAuthStore`, exchanges
|
|
/// the authorization code for tokens, fetches user info, stores the
|
|
/// logged-in user in the tower-sessions session, and redirects to the app.
|
|
///
|
|
/// # Query Parameters
|
|
///
|
|
/// * `code` - The authorization code from Keycloak.
|
|
/// * `state` - The CSRF state to verify against the pending store.
|
|
///
|
|
/// # Errors
|
|
///
|
|
/// Returns `Error` on CSRF mismatch, token exchange failure, or session issues.
|
|
#[axum::debug_handler]
|
|
pub async fn auth_callback(
|
|
session: Session,
|
|
Extension(state): Extension<ServerState>,
|
|
Extension(pending): Extension<PendingOAuthStore>,
|
|
Query(params): Query<HashMap<String, String>>,
|
|
) -> Result<impl IntoResponse, Error> {
|
|
let kc = state.keycloak;
|
|
|
|
// --- CSRF validation via the in-memory pending store ---
|
|
let returned_state = params
|
|
.get("state")
|
|
.ok_or_else(|| Error::StateError("missing state parameter".into()))?;
|
|
|
|
let entry = pending
|
|
.take(returned_state)
|
|
.ok_or_else(|| Error::StateError("unknown or expired oauth state".into()))?;
|
|
|
|
// --- Exchange code for tokens (with PKCE code_verifier) ---
|
|
let code = params
|
|
.get("code")
|
|
.ok_or_else(|| Error::StateError("missing code parameter".into()))?;
|
|
|
|
let client = reqwest::Client::new();
|
|
let token_resp = client
|
|
.post(kc.token_endpoint())
|
|
.form(&[
|
|
("grant_type", "authorization_code"),
|
|
("client_id", kc.client_id.as_str()),
|
|
("redirect_uri", kc.redirect_uri.as_str()),
|
|
("code", code),
|
|
("code_verifier", &entry.code_verifier),
|
|
])
|
|
.send()
|
|
.await
|
|
.map_err(|e| Error::StateError(format!("token request failed: {e}")))?;
|
|
|
|
if !token_resp.status().is_success() {
|
|
let body = token_resp.text().await.unwrap_or_default();
|
|
return Err(Error::StateError(format!("token exchange failed: {body}")));
|
|
}
|
|
|
|
let tokens: TokenResponse = token_resp
|
|
.json()
|
|
.await
|
|
.map_err(|e| Error::StateError(format!("token parse failed: {e}")))?;
|
|
|
|
// --- Fetch userinfo ---
|
|
let userinfo: UserinfoResponse = client
|
|
.get(kc.userinfo_endpoint())
|
|
.bearer_auth(&tokens.access_token)
|
|
.send()
|
|
.await
|
|
.map_err(|e| Error::StateError(format!("userinfo request failed: {e}")))?
|
|
.json()
|
|
.await
|
|
.map_err(|e| Error::StateError(format!("userinfo parse failed: {e}")))?;
|
|
|
|
// Prefer `name`, fall back to `preferred_username`, then empty.
|
|
let display_name = userinfo
|
|
.name
|
|
.or(userinfo.preferred_username)
|
|
.unwrap_or_default();
|
|
|
|
// --- Build user state and persist in session ---
|
|
let user_state = UserStateInner {
|
|
sub: userinfo.sub,
|
|
access_token: tokens.access_token,
|
|
refresh_token: tokens.refresh_token.unwrap_or_default(),
|
|
user: User {
|
|
email: userinfo.email.unwrap_or_default(),
|
|
name: display_name,
|
|
avatar_url: userinfo.picture.unwrap_or_default(),
|
|
},
|
|
};
|
|
|
|
set_login_session(session, user_state).await?;
|
|
|
|
let target = entry
|
|
.redirect_url
|
|
.filter(|u| !u.is_empty())
|
|
.unwrap_or_else(|| "/".into());
|
|
|
|
Ok(Redirect::temporary(&target))
|
|
}
|
|
|
|
/// Clear the user session and redirect to Keycloak's logout endpoint.
|
|
///
|
|
/// After Keycloak finishes its own logout flow it will redirect
|
|
/// back to the application root.
|
|
///
|
|
/// # Errors
|
|
///
|
|
/// Returns `Error` if the session cannot be flushed or the URL is malformed.
|
|
#[axum::debug_handler]
|
|
pub async fn logout(
|
|
session: Session,
|
|
Extension(state): Extension<ServerState>,
|
|
) -> Result<impl IntoResponse, Error> {
|
|
let kc = state.keycloak;
|
|
|
|
// Flush all session data.
|
|
session
|
|
.flush()
|
|
.await
|
|
.map_err(|e| Error::StateError(format!("session flush failed: {e}")))?;
|
|
|
|
let mut url = Url::parse(&kc.logout_endpoint())
|
|
.map_err(|e| Error::StateError(format!("invalid logout endpoint URL: {e}")))?;
|
|
|
|
url.query_pairs_mut()
|
|
.append_pair("client_id", &kc.client_id)
|
|
.append_pair("post_logout_redirect_uri", &kc.app_url);
|
|
|
|
Ok(Redirect::temporary(url.as_str()))
|
|
}
|
|
|
|
/// Persist user data into the session.
|
|
///
|
|
/// # Errors
|
|
///
|
|
/// Returns `Error` if the session store write fails.
|
|
pub async fn set_login_session(session: Session, data: UserStateInner) -> Result<(), Error> {
|
|
session
|
|
.insert(LOGGED_IN_USER_SESS_KEY, data)
|
|
.await
|
|
.map_err(|e| Error::StateError(format!("session insert failed: {e}")))
|
|
}
|
|
|
|
#[cfg(test)]
mod tests {
    #![allow(clippy::unwrap_used, clippy::expect_used)]

    use super::*;
    use pretty_assertions::assert_eq;

    /// Code verifier from the RFC 7636 Appendix B example.
    const RFC7636_VERIFIER: &str = "dBjftJeZ4CVP-mB92K27uhbUJU1p1r_wW1gFWFOEjXk";
    /// S256 challenge RFC 7636 Appendix B gives for `RFC7636_VERIFIER`.
    const RFC7636_CHALLENGE: &str = "E9Melhoa2OwvFrEMTJguCHaoeK1t8URWbuGJSstw-cM";

    /// Build a pending entry from borrowed parts.
    fn pending_entry(redirect_url: Option<&str>, code_verifier: &str) -> PendingOAuthEntry {
        PendingOAuthEntry {
            redirect_url: redirect_url.map(str::to_owned),
            code_verifier: code_verifier.to_owned(),
        }
    }

    // ---- generate_state() ----

    #[test]
    fn generate_state_length_is_64() {
        assert_eq!(generate_state().len(), 64);
    }

    #[test]
    fn generate_state_chars_are_hex() {
        let state = generate_state();
        assert!(state.bytes().all(|b| b.is_ascii_hexdigit()));
    }

    #[test]
    fn generate_state_two_calls_differ() {
        let first = generate_state();
        let second = generate_state();
        assert_ne!(first, second);
    }

    // ---- generate_code_verifier() ----

    #[test]
    fn code_verifier_length_is_43() {
        assert_eq!(generate_code_verifier().len(), 43);
    }

    #[test]
    fn code_verifier_chars_are_url_safe_base64() {
        // Unpadded URL-safe base64 draws only from [A-Za-z0-9_-].
        let verifier = generate_code_verifier();
        assert!(verifier
            .chars()
            .all(|c| matches!(c, 'A'..='Z' | 'a'..='z' | '0'..='9' | '-' | '_')));
    }

    // ---- derive_code_challenge() ----

    #[test]
    fn code_challenge_deterministic() {
        let first = derive_code_challenge(RFC7636_VERIFIER);
        let second = derive_code_challenge(RFC7636_VERIFIER);
        assert_eq!(first, second);
    }

    #[test]
    fn code_challenge_rfc7636_test_vector() {
        assert_eq!(derive_code_challenge(RFC7636_VERIFIER), RFC7636_CHALLENGE);
    }

    // ---- PendingOAuthStore ----

    #[test]
    fn pending_store_insert_and_take() {
        let store = PendingOAuthStore::default();
        store.insert(
            "state-1".into(),
            pending_entry(Some("/dashboard"), "verifier-1"),
        );

        let entry = store
            .take("state-1")
            .expect("inserted state should be pending");

        assert_eq!(entry.redirect_url.as_deref(), Some("/dashboard"));
        assert_eq!(entry.code_verifier, "verifier-1");
    }

    #[test]
    fn pending_store_take_removes_entry() {
        let store = PendingOAuthStore::default();
        store.insert("state-2".into(), pending_entry(None, "v2"));

        let _ = store.take("state-2");

        // The first take consumed the entry, so a second one finds nothing.
        assert!(store.take("state-2").is_none());
    }

    #[test]
    fn pending_store_take_unknown_returns_none() {
        let store = PendingOAuthStore::default();
        assert!(store.take("nonexistent").is_none());
    }
}
|