feat(db): Added database setup and basic types (#5)
All checks were successful
CI / Format (push) Successful in 3s
CI / Clippy (push) Successful in 2m21s
CI / Security Audit (push) Successful in 1m44s
CI / Tests (push) Successful in 2m55s
CI / Deploy (push) Successful in 2s

Co-authored-by: Sharang Parnerkar <parnerkarsharang@gmail.com>
Reviewed-on: #5
This commit was merged in pull request #5.
This commit is contained in:
2026-02-20 14:58:14 +00:00
parent 5ce600e32b
commit e68f840f2b
23 changed files with 1375 additions and 480 deletions

View File

@@ -12,7 +12,11 @@ use rand::RngExt;
use tower_sessions::Session;
use url::Url;
use crate::infrastructure::{state::User, Error, UserStateInner};
use crate::infrastructure::{
server_state::ServerState,
state::{User, UserStateInner},
Error,
};
pub const LOGGED_IN_USER_SESS_KEY: &str = "logged-in-user";
@@ -55,70 +59,6 @@ impl PendingOAuthStore {
}
}
/// Configuration loaded from environment variables for Keycloak OAuth.
struct OAuthConfig {
keycloak_url: String,
realm: String,
client_id: String,
redirect_uri: String,
app_url: String,
}
impl OAuthConfig {
/// Load OAuth configuration from environment variables.
///
/// # Errors
///
/// Returns `Error::StateError` if any required env var is missing.
fn from_env() -> Result<Self, Error> {
dotenvy::dotenv().ok();
Ok(Self {
keycloak_url: std::env::var("KEYCLOAK_URL")
.map_err(|_| Error::StateError("KEYCLOAK_URL not set".into()))?,
realm: std::env::var("KEYCLOAK_REALM")
.map_err(|_| Error::StateError("KEYCLOAK_REALM not set".into()))?,
client_id: std::env::var("KEYCLOAK_CLIENT_ID")
.map_err(|_| Error::StateError("KEYCLOAK_CLIENT_ID not set".into()))?,
redirect_uri: std::env::var("REDIRECT_URI")
.map_err(|_| Error::StateError("REDIRECT_URI not set".into()))?,
app_url: std::env::var("APP_URL")
.map_err(|_| Error::StateError("APP_URL not set".into()))?,
})
}
/// Build the Keycloak OpenID Connect authorization endpoint URL.
fn auth_endpoint(&self) -> String {
format!(
"{}/realms/{}/protocol/openid-connect/auth",
self.keycloak_url, self.realm
)
}
/// Build the Keycloak OpenID Connect token endpoint URL.
fn token_endpoint(&self) -> String {
format!(
"{}/realms/{}/protocol/openid-connect/token",
self.keycloak_url, self.realm
)
}
/// Build the Keycloak OpenID Connect userinfo endpoint URL.
fn userinfo_endpoint(&self) -> String {
format!(
"{}/realms/{}/protocol/openid-connect/userinfo",
self.keycloak_url, self.realm
)
}
/// Build the Keycloak OpenID Connect end-session (logout) endpoint URL.
fn logout_endpoint(&self) -> String {
format!(
"{}/realms/{}/protocol/openid-connect/logout",
self.keycloak_url, self.realm
)
}
}
/// Generate a cryptographically random state string for CSRF protection.
fn generate_state() -> String {
let bytes: [u8; 32] = rand::rng().random();
@@ -165,35 +105,36 @@ fn derive_code_challenge(verifier: &str) -> String {
///
/// # Errors
///
/// Returns `Error` if env vars are missing.
/// Returns `Error` if the Keycloak config is missing or the URL is malformed.
#[axum::debug_handler]
pub async fn auth_login(
Extension(state): Extension<ServerState>,
Extension(pending): Extension<PendingOAuthStore>,
Query(params): Query<HashMap<String, String>>,
) -> Result<impl IntoResponse, Error> {
let config = OAuthConfig::from_env()?;
let state = generate_state();
let kc = state.keycloak;
let csrf_state = generate_state();
let code_verifier = generate_code_verifier();
let code_challenge = derive_code_challenge(&code_verifier);
let redirect_url = params.get("redirect_url").cloned();
pending.insert(
state.clone(),
csrf_state.clone(),
PendingOAuthEntry {
redirect_url,
code_verifier,
},
);
let mut url = Url::parse(&config.auth_endpoint())
let mut url = Url::parse(&kc.auth_endpoint())
.map_err(|e| Error::StateError(format!("invalid auth endpoint URL: {e}")))?;
url.query_pairs_mut()
.append_pair("client_id", &config.client_id)
.append_pair("redirect_uri", &config.redirect_uri)
.append_pair("client_id", &kc.client_id)
.append_pair("redirect_uri", &kc.redirect_uri)
.append_pair("response_type", "code")
.append_pair("scope", "openid profile email")
.append_pair("state", &state)
.append_pair("state", &csrf_state)
.append_pair("code_challenge", &code_challenge)
.append_pair("code_challenge_method", "S256");
@@ -213,6 +154,10 @@ struct UserinfoResponse {
/// The subject identifier (unique user ID in Keycloak).
sub: String,
email: Option<String>,
/// Keycloak `preferred_username` claim.
preferred_username: Option<String>,
/// Full name from the Keycloak profile.
name: Option<String>,
/// Keycloak may include a picture/avatar URL via protocol mappers.
picture: Option<String>,
}
@@ -234,10 +179,11 @@ struct UserinfoResponse {
#[axum::debug_handler]
pub async fn auth_callback(
session: Session,
Extension(state): Extension<ServerState>,
Extension(pending): Extension<PendingOAuthStore>,
Query(params): Query<HashMap<String, String>>,
) -> Result<impl IntoResponse, Error> {
let config = OAuthConfig::from_env()?;
let kc = state.keycloak;
// --- CSRF validation via the in-memory pending store ---
let returned_state = params
@@ -255,11 +201,11 @@ pub async fn auth_callback(
let client = reqwest::Client::new();
let token_resp = client
.post(config.token_endpoint())
.post(kc.token_endpoint())
.form(&[
("grant_type", "authorization_code"),
("client_id", &config.client_id),
("redirect_uri", &config.redirect_uri),
("client_id", kc.client_id.as_str()),
("redirect_uri", kc.redirect_uri.as_str()),
("code", code),
("code_verifier", &entry.code_verifier),
])
@@ -279,7 +225,7 @@ pub async fn auth_callback(
// --- Fetch userinfo ---
let userinfo: UserinfoResponse = client
.get(config.userinfo_endpoint())
.get(kc.userinfo_endpoint())
.bearer_auth(&tokens.access_token)
.send()
.await
@@ -288,6 +234,12 @@ pub async fn auth_callback(
.await
.map_err(|e| Error::StateError(format!("userinfo parse failed: {e}")))?;
// Prefer `name`, fall back to `preferred_username`, then empty.
let display_name = userinfo
.name
.or(userinfo.preferred_username)
.unwrap_or_default();
// --- Build user state and persist in session ---
let user_state = UserStateInner {
sub: userinfo.sub,
@@ -295,6 +247,7 @@ pub async fn auth_callback(
refresh_token: tokens.refresh_token.unwrap_or_default(),
user: User {
email: userinfo.email.unwrap_or_default(),
name: display_name,
avatar_url: userinfo.picture.unwrap_or_default(),
},
};
@@ -316,10 +269,13 @@ pub async fn auth_callback(
///
/// # Errors
///
/// Returns `Error` if env vars are missing or the session cannot be flushed.
/// Returns `Error` if the session cannot be flushed or the URL is malformed.
#[axum::debug_handler]
pub async fn logout(session: Session) -> Result<impl IntoResponse, Error> {
let config = OAuthConfig::from_env()?;
pub async fn logout(
session: Session,
Extension(state): Extension<ServerState>,
) -> Result<impl IntoResponse, Error> {
let kc = state.keycloak;
// Flush all session data.
session
@@ -327,12 +283,12 @@ pub async fn logout(session: Session) -> Result<impl IntoResponse, Error> {
.await
.map_err(|e| Error::StateError(format!("session flush failed: {e}")))?;
let mut url = Url::parse(&config.logout_endpoint())
let mut url = Url::parse(&kc.logout_endpoint())
.map_err(|e| Error::StateError(format!("invalid logout endpoint URL: {e}")))?;
url.query_pairs_mut()
.append_pair("client_id", &config.client_id)
.append_pair("post_logout_redirect_uri", &config.app_url);
.append_pair("client_id", &kc.client_id)
.append_pair("post_logout_redirect_uri", &kc.app_url);
Ok(Redirect::temporary(url.as_str()))
}

View File

@@ -0,0 +1,36 @@
use crate::models::AuthInfo;
use dioxus::prelude::*;
/// Check the current user's authentication state.
///
/// Reads the tower-sessions session on the server and returns an
/// [`AuthInfo`] describing the logged-in user. When no valid session
/// exists, `authenticated` is `false` and all other fields are empty.
///
/// # Errors
///
/// Returns `ServerFnError` if the session store cannot be read.
#[server(endpoint = "check-auth")]
pub async fn check_auth() -> Result<AuthInfo, ServerFnError> {
use crate::infrastructure::auth::LOGGED_IN_USER_SESS_KEY;
use crate::infrastructure::state::UserStateInner;
use dioxus_fullstack::FullstackContext;
let session: tower_sessions::Session = FullstackContext::extract().await?;
let user_state: Option<UserStateInner> = session
.get(LOGGED_IN_USER_SESS_KEY)
.await
.map_err(|e| ServerFnError::new(format!("session read failed: {e}")))?;
match user_state {
Some(u) => Ok(AuthInfo {
authenticated: true,
sub: u.sub,
email: u.user.email,
name: u.user.name,
avatar_url: u.user.avatar_url,
}),
None => Ok(AuthInfo::default()),
}
}

View File

@@ -0,0 +1,41 @@
use axum::{
extract::Request,
middleware::Next,
response::{IntoResponse, Response},
};
use reqwest::StatusCode;
use tower_sessions::Session;
use crate::infrastructure::auth::LOGGED_IN_USER_SESS_KEY;
use crate::infrastructure::state::UserStateInner;
/// Server function endpoints that are allowed without authentication.
///
/// `check-auth` must be public so the frontend can determine login state.
const PUBLIC_API_ENDPOINTS: &[&str] = &["/api/check-auth"];
/// Axum middleware that enforces authentication on `/api/` server
/// function endpoints.
///
/// Requests whose path starts with `/api/` (except those listed in
/// [`PUBLIC_API_ENDPOINTS`]) are rejected with `401 Unauthorized` when
/// no valid session exists. All other paths pass through untouched.
pub async fn require_auth(session: Session, request: Request, next: Next) -> Response {
let path = request.uri().path();
// Only gate /api/ server function routes.
if path.starts_with("/api/") && !PUBLIC_API_ENDPOINTS.contains(&path) {
let is_authed = session
.get::<UserStateInner>(LOGGED_IN_USER_SESS_KEY)
.await
.ok()
.flatten()
.is_some();
if !is_authed {
return (StatusCode::UNAUTHORIZED, "Authentication required").into_response();
}
}
next.run(request).await
}

View File

@@ -0,0 +1,253 @@
//! Configuration structs loaded once at startup from environment variables.
//!
//! Each struct provides a `from_env()` constructor that reads `std::env::var`
//! values. Required variables cause an `Error::ConfigError` on failure;
//! optional ones default to an empty string.
use secrecy::SecretString;
use super::Error;
// ---------------------------------------------------------------------------
// Helpers
// ---------------------------------------------------------------------------
/// Read a required environment variable or return `Error::ConfigError`.
fn required_env(name: &str) -> Result<String, Error> {
std::env::var(name).map_err(|_| Error::ConfigError(format!("{name} is required but not set")))
}
/// Read an optional environment variable, defaulting to an empty string.
fn optional_env(name: &str) -> String {
std::env::var(name).unwrap_or_default()
}
// ---------------------------------------------------------------------------
// KeycloakConfig
// ---------------------------------------------------------------------------
/// Keycloak OpenID Connect settings for the public (frontend) client.
///
/// Also carries the admin service-account credentials used for
/// server-to-server calls (e.g. user management APIs).
#[derive(Debug)]
pub struct KeycloakConfig {
/// Base URL of the Keycloak instance (e.g. `http://localhost:8080`).
pub url: String,
/// Keycloak realm name.
pub realm: String,
/// Public client ID used by the dashboard frontend.
pub client_id: String,
/// OAuth redirect URI registered in Keycloak.
pub redirect_uri: String,
/// Root URL of this application (used for post-logout redirect).
pub app_url: String,
/// Confidential client ID for admin/server-to-server calls.
pub admin_client_id: String,
/// Confidential client secret (wrapped for debug safety).
pub admin_client_secret: SecretString,
}
impl KeycloakConfig {
/// Load Keycloak configuration from environment variables.
///
/// # Errors
///
/// Returns `Error::ConfigError` if a required variable is missing.
pub fn from_env() -> Result<Self, Error> {
Ok(Self {
url: required_env("KEYCLOAK_URL")?,
realm: required_env("KEYCLOAK_REALM")?,
client_id: required_env("KEYCLOAK_CLIENT_ID")?,
redirect_uri: required_env("REDIRECT_URI")?,
app_url: required_env("APP_URL")?,
admin_client_id: optional_env("KEYCLOAK_ADMIN_CLIENT_ID"),
admin_client_secret: SecretString::from(optional_env("KEYCLOAK_ADMIN_CLIENT_SECRET")),
})
}
/// OpenID Connect authorization endpoint URL.
pub fn auth_endpoint(&self) -> String {
format!(
"{}/realms/{}/protocol/openid-connect/auth",
self.url, self.realm
)
}
/// OpenID Connect token endpoint URL.
pub fn token_endpoint(&self) -> String {
format!(
"{}/realms/{}/protocol/openid-connect/token",
self.url, self.realm
)
}
/// OpenID Connect userinfo endpoint URL.
pub fn userinfo_endpoint(&self) -> String {
format!(
"{}/realms/{}/protocol/openid-connect/userinfo",
self.url, self.realm
)
}
/// OpenID Connect end-session (logout) endpoint URL.
pub fn logout_endpoint(&self) -> String {
format!(
"{}/realms/{}/protocol/openid-connect/logout",
self.url, self.realm
)
}
}
// ---------------------------------------------------------------------------
// SmtpConfig
// ---------------------------------------------------------------------------
/// SMTP mail settings for transactional emails (invites, alerts, etc.).
#[derive(Debug)]
pub struct SmtpConfig {
/// SMTP server hostname.
pub host: String,
/// SMTP server port (as string for flexibility, e.g. "587").
pub port: String,
/// SMTP username.
pub username: String,
/// SMTP password (wrapped for debug safety).
pub password: SecretString,
/// Sender address shown in the `From:` header.
pub from_address: String,
}
impl SmtpConfig {
/// Load SMTP configuration from environment variables.
///
/// All fields are optional; defaults to empty strings when absent.
///
/// # Errors
///
/// Currently infallible but returns `Result` for consistency.
pub fn from_env() -> Result<Self, Error> {
Ok(Self {
host: optional_env("SMTP_HOST"),
port: optional_env("SMTP_PORT"),
username: optional_env("SMTP_USERNAME"),
password: SecretString::from(optional_env("SMTP_PASSWORD")),
from_address: optional_env("SMTP_FROM_ADDRESS"),
})
}
}
// ---------------------------------------------------------------------------
// ServiceUrls
// ---------------------------------------------------------------------------
/// URLs and credentials for external services (Ollama, SearXNG, S3, etc.).
#[derive(Debug)]
pub struct ServiceUrls {
/// Ollama LLM instance base URL.
pub ollama_url: String,
/// Default Ollama model to use.
pub ollama_model: String,
/// SearXNG meta-search engine base URL.
pub searxng_url: String,
/// LangChain service URL.
pub langchain_url: String,
/// LangGraph service URL.
pub langgraph_url: String,
/// Langfuse observability URL.
pub langfuse_url: String,
/// Vector database URL.
pub vectordb_url: String,
/// S3-compatible object storage URL.
pub s3_url: String,
/// S3 access key.
pub s3_access_key: String,
/// S3 secret key (wrapped for debug safety).
pub s3_secret_key: SecretString,
}
impl ServiceUrls {
/// Load service URLs from environment variables.
///
/// All fields are optional with sensible defaults where applicable.
///
/// # Errors
///
/// Currently infallible but returns `Result` for consistency.
pub fn from_env() -> Result<Self, Error> {
Ok(Self {
ollama_url: std::env::var("OLLAMA_URL")
.unwrap_or_else(|_| "http://localhost:11434".into()),
ollama_model: std::env::var("OLLAMA_MODEL").unwrap_or_else(|_| "llama3.1:8b".into()),
searxng_url: std::env::var("SEARXNG_URL")
.unwrap_or_else(|_| "http://localhost:8888".into()),
langchain_url: optional_env("LANGCHAIN_URL"),
langgraph_url: optional_env("LANGGRAPH_URL"),
langfuse_url: optional_env("LANGFUSE_URL"),
vectordb_url: optional_env("VECTORDB_URL"),
s3_url: optional_env("S3_URL"),
s3_access_key: optional_env("S3_ACCESS_KEY"),
s3_secret_key: SecretString::from(optional_env("S3_SECRET_KEY")),
})
}
}
// ---------------------------------------------------------------------------
// StripeConfig
// ---------------------------------------------------------------------------
/// Stripe billing configuration.
#[derive(Debug)]
pub struct StripeConfig {
/// Stripe secret API key (wrapped for debug safety).
pub secret_key: SecretString,
/// Stripe webhook signing secret (wrapped for debug safety).
pub webhook_secret: SecretString,
/// Stripe publishable key (safe to expose to the frontend).
pub publishable_key: String,
}
impl StripeConfig {
/// Load Stripe configuration from environment variables.
///
/// # Errors
///
/// Currently infallible but returns `Result` for consistency.
pub fn from_env() -> Result<Self, Error> {
Ok(Self {
secret_key: SecretString::from(optional_env("STRIPE_SECRET_KEY")),
webhook_secret: SecretString::from(optional_env("STRIPE_WEBHOOK_SECRET")),
publishable_key: optional_env("STRIPE_PUBLISHABLE_KEY"),
})
}
}
// ---------------------------------------------------------------------------
// LlmProvidersConfig
// ---------------------------------------------------------------------------
/// Comma-separated list of enabled LLM provider identifiers.
///
/// For example: `LLM_PROVIDERS=ollama,openai,anthropic`
#[derive(Debug)]
pub struct LlmProvidersConfig {
/// Parsed provider names.
pub providers: Vec<String>,
}
impl LlmProvidersConfig {
/// Load the provider list from `LLM_PROVIDERS`.
///
/// # Errors
///
/// Currently infallible but returns `Result` for consistency.
pub fn from_env() -> Result<Self, Error> {
let raw = optional_env("LLM_PROVIDERS");
let providers: Vec<String> = raw
.split(',')
.map(|s| s.trim().to_string())
.filter(|s| !s.is_empty())
.collect();
Ok(Self { providers })
}
}

View File

@@ -0,0 +1,52 @@
//! MongoDB connection wrapper with typed collection accessors.
use mongodb::{bson::doc, Client, Collection};
use super::Error;
use crate::models::{OrgBillingRecord, OrgSettings, UserPreferences};
/// Thin wrapper around [`mongodb::Database`] that provides typed
/// collection accessors for the application's domain models.
#[derive(Clone, Debug)]
pub struct Database {
inner: mongodb::Database,
}
impl Database {
/// Connect to MongoDB, select the given database, and verify
/// connectivity with a `ping` command.
///
/// # Arguments
///
/// * `uri` - MongoDB connection string (e.g. `mongodb://localhost:27017`)
/// * `db_name` - Database name to use
///
/// # Errors
///
/// Returns `Error::DatabaseError` if the client cannot be created
/// or the ping fails.
pub async fn connect(uri: &str, db_name: &str) -> Result<Self, Error> {
let client = Client::with_uri_str(uri).await?;
let db = client.database(db_name);
// Verify the connection is alive.
db.run_command(doc! { "ping": 1 }).await?;
Ok(Self { inner: db })
}
/// Collection for per-user preferences (theme, custom topics, etc.).
pub fn user_preferences(&self) -> Collection<UserPreferences> {
self.inner.collection("user_preferences")
}
/// Collection for organisation-level settings.
pub fn org_settings(&self) -> Collection<OrgSettings> {
self.inner.collection("org_settings")
}
/// Collection for per-cycle billing records.
pub fn org_billing(&self) -> Collection<OrgBillingRecord> {
self.inner.collection("org_billing")
}
}

View File

@@ -1,22 +1,43 @@
use axum::response::IntoResponse;
use reqwest::StatusCode;
/// Central error type for infrastructure-layer failures.
///
/// Each variant maps to an appropriate HTTP status code when converted
/// into an Axum response.
#[derive(thiserror::Error, Debug)]
pub enum Error {
#[error("{0}")]
StateError(String),
#[error("database error: {0}")]
DatabaseError(String),
#[error("configuration error: {0}")]
ConfigError(String),
#[error("IoError: {0}")]
IoError(#[from] std::io::Error),
}
impl From<mongodb::error::Error> for Error {
fn from(err: mongodb::error::Error) -> Self {
Self::DatabaseError(err.to_string())
}
}
impl IntoResponse for Error {
fn into_response(self) -> axum::response::Response {
let msg = self.to_string();
tracing::error!("Converting Error to Response: {msg}");
match self {
Self::StateError(e) => (StatusCode::INTERNAL_SERVER_ERROR, e).into_response(),
_ => (StatusCode::INTERNAL_SERVER_ERROR, "Unknown error").into_response(),
Self::StateError(e) | Self::ConfigError(e) => {
(StatusCode::INTERNAL_SERVER_ERROR, e).into_response()
}
Self::DatabaseError(e) => (StatusCode::SERVICE_UNAVAILABLE, e).into_response(),
Self::IoError(_) => {
(StatusCode::INTERNAL_SERVER_ERROR, "Unknown error").into_response()
}
}
}
}

View File

@@ -166,19 +166,20 @@ pub async fn summarize_article(
ollama_url: String,
model: String,
) -> Result<String, ServerFnError> {
dotenvy::dotenv().ok();
use inner::{fetch_article_text, ChatMessage, OllamaChatRequest, OllamaChatResponse};
// Fall back to env var or default if the URL is empty
let state: crate::infrastructure::ServerState =
dioxus_fullstack::FullstackContext::extract().await?;
// Use caller-provided values or fall back to ServerState config
let base_url = if ollama_url.is_empty() {
std::env::var("OLLAMA_URL").unwrap_or_else(|_| "http://localhost:11434".into())
state.services.ollama_url.clone()
} else {
ollama_url
};
// Fall back to env var or default if the model is empty
let model = if model.is_empty() {
std::env::var("OLLAMA_MODEL").unwrap_or_else(|_| "llama3.1:8b".into())
state.services.ollama_model.clone()
} else {
model
};
@@ -264,17 +265,19 @@ pub async fn chat_followup(
ollama_url: String,
model: String,
) -> Result<String, ServerFnError> {
dotenvy::dotenv().ok();
use inner::{ChatMessage, OllamaChatRequest, OllamaChatResponse};
let state: crate::infrastructure::ServerState =
dioxus_fullstack::FullstackContext::extract().await?;
let base_url = if ollama_url.is_empty() {
std::env::var("OLLAMA_URL").unwrap_or_else(|_| "http://localhost:11434".into())
state.services.ollama_url.clone()
} else {
ollama_url
};
let model = if model.is_empty() {
std::env::var("OLLAMA_MODEL").unwrap_or_else(|_| "llama3.1:8b".into())
state.services.ollama_model.clone()
} else {
model
};

View File

@@ -1,24 +1,37 @@
// Server function modules (compiled for both web and server features;
// the #[server] macro generates client stubs for the web target)
pub mod auth_check;
pub mod llm;
pub mod ollama;
pub mod searxng;
// Server-only modules (Axum handlers, state, etc.)
// Server-only modules (Axum handlers, state, configs, DB, etc.)
#[cfg(feature = "server")]
mod auth;
#[cfg(feature = "server")]
mod auth_middleware;
#[cfg(feature = "server")]
pub mod config;
#[cfg(feature = "server")]
pub mod database;
#[cfg(feature = "server")]
mod error;
#[cfg(feature = "server")]
mod server;
#[cfg(feature = "server")]
pub mod server_state;
#[cfg(feature = "server")]
mod state;
#[cfg(feature = "server")]
pub use auth::*;
#[cfg(feature = "server")]
pub use auth_middleware::*;
#[cfg(feature = "server")]
pub use error::*;
#[cfg(feature = "server")]
pub use server::*;
#[cfg(feature = "server")]
pub use server_state::*;
#[cfg(feature = "server")]
pub use state::*;

View File

@@ -47,10 +47,11 @@ struct OllamaModel {
/// are caught and returned as `online: false`
#[post("/api/ollama-status")]
pub async fn get_ollama_status(ollama_url: String) -> Result<OllamaStatus, ServerFnError> {
dotenvy::dotenv().ok();
let state: crate::infrastructure::ServerState =
dioxus_fullstack::FullstackContext::extract().await?;
let base_url = if ollama_url.is_empty() {
std::env::var("OLLAMA_URL").unwrap_or_else(|_| "http://localhost:11434".into())
state.services.ollama_url.clone()
} else {
ollama_url
};

View File

@@ -112,11 +112,11 @@ mod inner {
/// Returns `ServerFnError` if the SearXNG request fails or response parsing fails
#[post("/api/search")]
pub async fn search_topic(query: String) -> Result<Vec<NewsCard>, ServerFnError> {
dotenvy::dotenv().ok();
use inner::{extract_source, rank_and_deduplicate, SearxngResponse};
let searxng_url =
std::env::var("SEARXNG_URL").unwrap_or_else(|_| "http://localhost:8888".into());
let state: crate::infrastructure::ServerState =
dioxus_fullstack::FullstackContext::extract().await?;
let searxng_url = state.services.searxng_url.clone();
// Enrich the query with "latest news" context for better results,
// similar to how Perplexity reformulates queries before searching.
@@ -198,12 +198,12 @@ pub async fn search_topic(query: String) -> Result<Vec<NewsCard>, ServerFnError>
/// Returns `ServerFnError` if the SearXNG search request fails
#[get("/api/trending")]
pub async fn get_trending_topics() -> Result<Vec<String>, ServerFnError> {
dotenvy::dotenv().ok();
use inner::SearxngResponse;
use std::collections::HashMap;
let searxng_url =
std::env::var("SEARXNG_URL").unwrap_or_else(|_| "http://localhost:8888".into());
let state: crate::infrastructure::ServerState =
dioxus_fullstack::FullstackContext::extract().await?;
let searxng_url = state.services.searxng_url.clone();
// Use POST to match SearXNG's default `method: "POST"` setting
let search_url = format!("{searxng_url}/search");

View File

@@ -1,54 +1,94 @@
use crate::infrastructure::{
auth_callback, auth_login, logout, PendingOAuthStore, UserState, UserStateInner,
};
use dioxus::prelude::*;
use axum::routing::get;
use axum::Extension;
use axum::{middleware, Extension};
use time::Duration;
use tower_sessions::{cookie::Key, MemoryStore, SessionManagerLayer};
use crate::infrastructure::{
auth_callback, auth_login,
config::{KeycloakConfig, LlmProvidersConfig, ServiceUrls, SmtpConfig, StripeConfig},
database::Database,
logout, require_auth,
server_state::{ServerState, ServerStateInner},
PendingOAuthStore,
};
/// Start the Axum server with Dioxus fullstack, session management,
/// and Keycloak OAuth routes.
/// MongoDB, and Keycloak OAuth routes.
///
/// Loads all configuration from environment variables once, connects
/// to MongoDB, and builds a [`ServerState`] shared across every request.
///
/// # Errors
///
/// Returns `Error` if the tokio runtime or TCP listener fails to start.
/// Returns `Error` if the tokio runtime, config loading, DB connection,
/// or TCP listener fails.
pub fn server_start(app: fn() -> Element) -> Result<(), super::Error> {
tokio::runtime::Runtime::new()?.block_on(async move {
let state: UserState = UserStateInner {
access_token: "abcd".into(),
sub: "abcd".into(),
refresh_token: "abcd".into(),
..Default::default()
// Load .env once at startup.
dotenvy::dotenv().ok();
// ---- Load and leak config structs for 'static lifetime ----
let keycloak: &'static KeycloakConfig = Box::leak(Box::new(KeycloakConfig::from_env()?));
let smtp: &'static SmtpConfig = Box::leak(Box::new(SmtpConfig::from_env()?));
let services: &'static ServiceUrls = Box::leak(Box::new(ServiceUrls::from_env()?));
let stripe: &'static StripeConfig = Box::leak(Box::new(StripeConfig::from_env()?));
let llm_providers: &'static LlmProvidersConfig =
Box::leak(Box::new(LlmProvidersConfig::from_env()?));
tracing::info!("Configuration loaded");
// ---- Connect to MongoDB ----
let mongo_uri =
std::env::var("MONGODB_URI").unwrap_or_else(|_| "mongodb://localhost:27017".into());
let mongo_db = std::env::var("MONGODB_DATABASE").unwrap_or_else(|_| "certifai".into());
let db = Database::connect(&mongo_uri, &mongo_db).await?;
tracing::info!("Connected to MongoDB (database: {mongo_db})");
// ---- Build ServerState ----
let server_state: ServerState = ServerStateInner {
db,
keycloak,
smtp,
services,
stripe,
llm_providers,
}
.into();
// ---- Session layer ----
let key = Key::generate();
let store = MemoryStore::default();
let session = SessionManagerLayer::new(store)
.with_secure(false)
// Lax is required so the browser sends the session cookie
// on the redirect back from Keycloak (cross-origin GET).
// Strict would silently drop the cookie on that navigation.
.with_same_site(tower_sessions::cookie::SameSite::Lax)
.with_expiry(tower_sessions::Expiry::OnInactivity(Duration::hours(24)))
.with_signed(key);
// ---- Build router ----
let addr = dioxus_cli_config::fullstack_address_or_localhost();
let listener = tokio::net::TcpListener::bind(addr).await?;
// Layers are applied AFTER serve_dioxus_application so they
// wrap both the custom Axum routes AND the Dioxus server
// function routes (e.g. check_auth needs Session access).
// Layers wrap in reverse order: session (outermost) -> auth
// middleware -> extensions -> route handlers. The session layer
// must be outermost so the `Session` extractor is available to
// the auth middleware, which gates all `/api/` server function
// routes (except `check-auth`).
let router = axum::Router::new()
.route("/auth", get(auth_login))
.route("/auth/callback", get(auth_callback))
.route("/logout", get(logout))
.serve_dioxus_application(ServeConfig::new(), app)
.layer(Extension(PendingOAuthStore::default()))
.layer(Extension(state))
.layer(Extension(server_state))
.layer(middleware::from_fn(require_auth))
.layer(session);
info!("Serving at {addr}");
tracing::info!("Serving at {addr}");
axum::serve(listener, router.into_make_service()).await?;
Ok(())

View File

@@ -0,0 +1,74 @@
//! Application-wide server state available in both Axum handlers and
//! Dioxus server functions via `extract()`.
//!
//! ```rust,ignore
//! // Inside a #[server] function:
//! let state: ServerState = extract().await?;
//! ```
use std::{ops::Deref, sync::Arc};
use super::{
config::{KeycloakConfig, LlmProvidersConfig, ServiceUrls, SmtpConfig, StripeConfig},
database::Database,
Error,
};
/// Cheap-to-clone handle to the shared server state.
///
/// Stored as an Axum `Extension` so it is accessible from both
/// route handlers and Dioxus `#[server]` functions.
#[derive(Clone)]
pub struct ServerState(Arc<ServerStateInner>);
impl Deref for ServerState {
type Target = ServerStateInner;
fn deref(&self) -> &Self::Target {
&self.0
}
}
impl From<ServerStateInner> for ServerState {
fn from(value: ServerStateInner) -> Self {
Self(Arc::new(value))
}
}
/// Inner struct holding all long-lived application resources.
///
/// Config references are `&'static` because they are `Box::leak`ed
/// at startup -- they never change at runtime.
pub struct ServerStateInner {
/// MongoDB connection pool.
pub db: Database,
/// Keycloak / OAuth2 settings.
pub keycloak: &'static KeycloakConfig,
/// Outbound email settings.
pub smtp: &'static SmtpConfig,
/// URLs for Ollama, SearXNG, LangChain, S3, etc.
pub services: &'static ServiceUrls,
/// Stripe billing keys.
pub stripe: &'static StripeConfig,
/// Enabled LLM provider list.
pub llm_providers: &'static LlmProvidersConfig,
}
// `FromRequestParts` lets us `extract::<ServerState>()` inside
// Dioxus server functions and regular Axum handlers alike.
impl<S> axum::extract::FromRequestParts<S> for ServerState
where
S: Send + Sync,
{
type Rejection = Error;
async fn from_request_parts(
parts: &mut axum::http::request::Parts,
_state: &S,
) -> Result<Self, Self::Rejection> {
parts
.extensions
.get::<ServerState>()
.cloned()
.ok_or(Error::StateError("ServerState extension not found".into()))
}
}

View File

@@ -1,8 +1,8 @@
use std::{ops::Deref, sync::Arc};
use axum::extract::FromRequestParts;
use serde::{Deserialize, Serialize};
/// Cheap-to-clone handle to per-session user data.
#[derive(Debug, Clone)]
pub struct UserState(Arc<UserStateInner>);
@@ -19,39 +19,28 @@ impl From<UserStateInner> for UserState {
}
}
/// Per-session user data stored in the tower-sessions session store.
///
/// Persisted across requests for the lifetime of the session.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct UserStateInner {
/// Subject in Oauth
/// Subject identifier from Keycloak (unique user ID).
pub sub: String,
/// Access Token
/// OAuth2 access token.
pub access_token: String,
/// Refresh Token
/// OAuth2 refresh token.
pub refresh_token: String,
/// User
/// Basic user profile.
pub user: User,
}
/// Basic user profile stored alongside the session.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct User {
/// Email
/// Email address.
pub email: String,
/// Avatar Url
/// Display name (preferred_username or full name from Keycloak).
pub name: String,
/// Avatar / profile picture URL.
pub avatar_url: String,
}
impl<S> FromRequestParts<S> for UserState
where
S: std::marker::Sync + std::marker::Send,
{
type Rejection = super::Error;
async fn from_request_parts(
parts: &mut axum::http::request::Parts,
_: &S,
) -> Result<Self, super::Error> {
parts
.extensions
.get::<UserState>()
.cloned()
.ok_or(super::Error::StateError("Unable to get extension".into()))
}
}