diff --git a/Cargo.lock b/Cargo.lock index 57bf95e..9e9ff42 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -776,6 +776,7 @@ dependencies = [ "maud", "mongodb", "petname", + "pretty_assertions", "pulldown-cmark", "rand 0.10.0", "reqwest 0.13.2", @@ -783,6 +784,7 @@ dependencies = [ "secrecy", "serde", "serde_json", + "serial_test", "sha2", "thiserror 2.0.18", "time", @@ -882,6 +884,12 @@ dependencies = [ "unicode-xid", ] +[[package]] +name = "diff" +version = "0.1.13" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "56254986775e3233ffa9c4d7d3faaf6d36a2c09d30b20687e9f88bc8bafc16c8" + [[package]] name = "digest" version = "0.10.7" @@ -3246,6 +3254,16 @@ version = "0.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "925383efa346730478fb4838dbe9137d2a47675ad789c546d150a6e1dd4ab31c" +[[package]] +name = "pretty_assertions" +version = "1.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3ae130e2f271fbc2ac3a40fb1d07180839cdbbe443c7a27e1e3c13c5cac0116d" +dependencies = [ + "diff", + "yansi", +] + [[package]] name = "prettyplease" version = "0.2.37" @@ -3823,6 +3841,15 @@ dependencies = [ "winapi-util", ] +[[package]] +name = "scc" +version = "2.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "46e6f046b7fef48e2660c57ed794263155d713de679057f2d0c169bfc6e756cc" +dependencies = [ + "sdd", +] + [[package]] name = "schannel" version = "0.1.28" @@ -3862,6 +3889,12 @@ dependencies = [ "untrusted", ] +[[package]] +name = "sdd" +version = "3.0.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "490dcfcbfef26be6800d11870ff2df8774fa6e86d047e3e8c8a76b25655e41ca" + [[package]] name = "secrecy" version = "0.10.3" @@ -4082,6 +4115,32 @@ dependencies = [ "syn 2.0.116", ] +[[package]] +name = "serial_test" +version = "3.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "911bd979bf1070a3f3aa7b691a3b3e9968f339ceeec89e08c280a8a22207a32f" +dependencies = [ + "futures-executor", + "futures-util", + "log", + "once_cell", + "parking_lot", + "scc", + "serial_test_derive", +] + +[[package]] +name = "serial_test_derive" +version = "3.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0a7d91949b85b0d2fb687445e448b40d322b6b3e4af6b44a29b21d9a5f33e6d9" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.116", +] + [[package]] name = "servo_arc" version = "0.4.3" @@ -5683,6 +5742,12 @@ version = "0.8.15" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "fdd20c5420375476fbd4394763288da7eb0cc0b8c11deed431a91562af7335d3" +[[package]] +name = "yansi" +version = "1.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cfe53a6657fd280eaa890a3bc59152892ffa3e30101319d168b781ed6529b049" + [[package]] name = "yazi" version = "0.1.6" diff --git a/Cargo.toml b/Cargo.toml index 8caa25c..9de0a77 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -112,6 +112,10 @@ server = [ "dep:bytes", ] +[dev-dependencies] +pretty_assertions = "1.4" +serial_test = "3.2" + [[bin]] name = "dashboard" path = "bin/main.rs" diff --git a/src/infrastructure/auth.rs b/src/infrastructure/auth.rs index 9894878..8c46836 100644 --- a/src/infrastructure/auth.rs +++ b/src/infrastructure/auth.rs @@ -24,9 +24,9 @@ pub const LOGGED_IN_USER_SESS_KEY: &str = "logged-in-user"; /// post-login redirect URL and the PKCE code verifier needed for the /// token exchange. #[derive(Debug, Clone)] -struct PendingOAuthEntry { - redirect_url: Option, - code_verifier: String, +pub(crate) struct PendingOAuthEntry { + pub(crate) redirect_url: Option, + pub(crate) code_verifier: String, } /// In-memory store for pending OAuth states. Keyed by the random state @@ -38,7 +38,7 @@ pub struct PendingOAuthStore(Arc>>); impl PendingOAuthStore { /// Insert a pending state with an optional redirect URL and PKCE verifier. - fn insert(&self, state: String, entry: PendingOAuthEntry) { + pub(crate) fn insert(&self, state: String, entry: PendingOAuthEntry) { // RwLock::write only panics if the lock is poisoned, which // indicates a prior panic -- propagating is acceptable here. #[allow(clippy::expect_used)] @@ -50,7 +50,7 @@ impl PendingOAuthStore { /// Remove and return the entry if the state was pending. /// Returns `None` if the state was never stored (CSRF failure). - fn take(&self, state: &str) -> Option { + pub(crate) fn take(&self, state: &str) -> Option { #[allow(clippy::expect_used)] self.0 .write() @@ -60,7 +60,8 @@ impl PendingOAuthStore { } /// Generate a cryptographically random state string for CSRF protection. -fn generate_state() -> String { +#[cfg_attr(test, allow(dead_code))] +pub(crate) fn generate_state() -> String { let bytes: [u8; 32] = rand::rng().random(); // Encode as hex to produce a URL-safe string without padding. bytes.iter().fold(String::with_capacity(64), |mut acc, b| { @@ -75,7 +76,7 @@ fn generate_state() -> String { /// /// Uses 32 random bytes encoded as base64url (no padding) to produce /// a 43-character verifier per RFC 7636. -fn generate_code_verifier() -> String { +pub(crate) fn generate_code_verifier() -> String { use base64::{engine::general_purpose::URL_SAFE_NO_PAD, Engine}; let bytes: [u8; 32] = rand::rng().random(); @@ -85,7 +86,7 @@ fn generate_code_verifier() -> String { /// Derive the S256 code challenge from a code verifier per RFC 7636. /// /// `code_challenge = BASE64URL(SHA256(code_verifier))` -fn derive_code_challenge(verifier: &str) -> String { +pub(crate) fn derive_code_challenge(verifier: &str) -> String { use base64::{engine::general_purpose::URL_SAFE_NO_PAD, Engine}; use sha2::{Digest, Sha256}; @@ -304,3 +305,117 @@ pub async fn set_login_session(session: Session, data: UserStateInner) -> Result .await .map_err(|e| Error::StateError(format!("session insert failed: {e}"))) } + +#[cfg(test)] +mod tests { + #![allow(clippy::unwrap_used, clippy::expect_used)] + + use super::*; + use pretty_assertions::assert_eq; + + // ----------------------------------------------------------------------- + // generate_state() + // ----------------------------------------------------------------------- + + #[test] + fn generate_state_length_is_64() { + let state = generate_state(); + assert_eq!(state.len(), 64); + } + + #[test] + fn generate_state_chars_are_hex() { + let state = generate_state(); + assert!(state.chars().all(|c| c.is_ascii_hexdigit())); + } + + #[test] + fn generate_state_two_calls_differ() { + let a = generate_state(); + let b = generate_state(); + assert_ne!(a, b); + } + + // ----------------------------------------------------------------------- + // generate_code_verifier() + // ----------------------------------------------------------------------- + + #[test] + fn code_verifier_length_is_43() { + let verifier = generate_code_verifier(); + assert_eq!(verifier.len(), 43); + } + + #[test] + fn code_verifier_chars_are_url_safe_base64() { + let verifier = generate_code_verifier(); + // URL-safe base64 without padding uses [A-Za-z0-9_-] + assert!(verifier + .chars() + .all(|c| c.is_ascii_alphanumeric() || c == '-' || c == '_')); + } + + // ----------------------------------------------------------------------- + // derive_code_challenge() + // ----------------------------------------------------------------------- + + #[test] + fn code_challenge_deterministic() { + let verifier = "dBjftJeZ4CVP-mB92K27uhbUJU1p1r_wW1gFWFOEjXk"; + let a = derive_code_challenge(verifier); + let b = derive_code_challenge(verifier); + assert_eq!(a, b); + } + + #[test] + fn code_challenge_rfc7636_test_vector() { + // RFC 7636 Appendix B test vector: + // verifier = "dBjftJeZ4CVP-mB92K27uhbUJU1p1r_wW1gFWFOEjXk" + // expected challenge = "E9Melhoa2OwvFrEMTJguCHaoeK1t8URWbuGJSstw-cM" + let verifier = "dBjftJeZ4CVP-mB92K27uhbUJU1p1r_wW1gFWFOEjXk"; + let challenge = derive_code_challenge(verifier); + assert_eq!(challenge, "E9Melhoa2OwvFrEMTJguCHaoeK1t8URWbuGJSstw-cM"); + } + + // ----------------------------------------------------------------------- + // PendingOAuthStore + // ----------------------------------------------------------------------- + + #[test] + fn pending_store_insert_and_take() { + let store = PendingOAuthStore::default(); + store.insert( + "state-1".into(), + PendingOAuthEntry { + redirect_url: Some("/dashboard".into()), + code_verifier: "verifier-1".into(), + }, + ); + let entry = store.take("state-1"); + assert!(entry.is_some()); + let entry = entry.unwrap(); + assert_eq!(entry.redirect_url, Some("/dashboard".into())); + assert_eq!(entry.code_verifier, "verifier-1"); + } + + #[test] + fn pending_store_take_removes_entry() { + let store = PendingOAuthStore::default(); + store.insert( + "state-2".into(), + PendingOAuthEntry { + redirect_url: None, + code_verifier: "v2".into(), + }, + ); + let _ = store.take("state-2"); + // Second take should return None since the entry was removed. + assert!(store.take("state-2").is_none()); + } + + #[test] + fn pending_store_take_unknown_returns_none() { + let store = PendingOAuthStore::default(); + assert!(store.take("nonexistent").is_none()); + } +} diff --git a/src/infrastructure/config.rs b/src/infrastructure/config.rs index c068aa7..3ce3ac5 100644 --- a/src/infrastructure/config.rs +++ b/src/infrastructure/config.rs @@ -251,3 +251,160 @@ impl LlmProvidersConfig { Ok(Self { providers }) } } + +#[cfg(test)] +mod tests { + #![allow(clippy::unwrap_used, clippy::expect_used)] + + use super::*; + use pretty_assertions::assert_eq; + use serial_test::serial; + + // ----------------------------------------------------------------------- + // KeycloakConfig endpoint methods (no env vars needed) + // ----------------------------------------------------------------------- + + fn sample_keycloak() -> KeycloakConfig { + KeycloakConfig { + url: "https://auth.example.com".into(), + realm: "myrealm".into(), + client_id: "dashboard".into(), + redirect_uri: "https://app.example.com/callback".into(), + app_url: "https://app.example.com".into(), + admin_client_id: String::new(), + admin_client_secret: SecretString::from(String::new()), + } + } + + #[test] + fn keycloak_auth_endpoint() { + let kc = sample_keycloak(); + assert_eq!( + kc.auth_endpoint(), + "https://auth.example.com/realms/myrealm/protocol/openid-connect/auth" + ); + } + + #[test] + fn keycloak_token_endpoint() { + let kc = sample_keycloak(); + assert_eq!( + kc.token_endpoint(), + "https://auth.example.com/realms/myrealm/protocol/openid-connect/token" + ); + } + + #[test] + fn keycloak_userinfo_endpoint() { + let kc = sample_keycloak(); + assert_eq!( + kc.userinfo_endpoint(), + "https://auth.example.com/realms/myrealm/protocol/openid-connect/userinfo" + ); + } + + #[test] + fn keycloak_logout_endpoint() { + let kc = sample_keycloak(); + assert_eq!( + kc.logout_endpoint(), + "https://auth.example.com/realms/myrealm/protocol/openid-connect/logout" + ); + } + + // ----------------------------------------------------------------------- + // LlmProvidersConfig::from_env() + // ----------------------------------------------------------------------- + + #[test] + #[serial] + fn llm_providers_empty_string() { + std::env::set_var("LLM_PROVIDERS", ""); + let cfg = LlmProvidersConfig::from_env().unwrap(); + assert!(cfg.providers.is_empty()); + std::env::remove_var("LLM_PROVIDERS"); + } + + #[test] + #[serial] + fn llm_providers_single() { + std::env::set_var("LLM_PROVIDERS", "ollama"); + let cfg = LlmProvidersConfig::from_env().unwrap(); + assert_eq!(cfg.providers, vec!["ollama"]); + std::env::remove_var("LLM_PROVIDERS"); + } + + #[test] + #[serial] + fn llm_providers_multiple() { + std::env::set_var("LLM_PROVIDERS", "ollama,openai,anthropic"); + let cfg = LlmProvidersConfig::from_env().unwrap(); + assert_eq!(cfg.providers, vec!["ollama", "openai", "anthropic"]); + std::env::remove_var("LLM_PROVIDERS"); + } + + #[test] + #[serial] + fn llm_providers_trims_whitespace() { + std::env::set_var("LLM_PROVIDERS", " ollama , openai "); + let cfg = LlmProvidersConfig::from_env().unwrap(); + assert_eq!(cfg.providers, vec!["ollama", "openai"]); + std::env::remove_var("LLM_PROVIDERS"); + } + + #[test] + #[serial] + fn llm_providers_filters_empty_entries() { + std::env::set_var("LLM_PROVIDERS", "ollama,,openai,"); + let cfg = LlmProvidersConfig::from_env().unwrap(); + assert_eq!(cfg.providers, vec!["ollama", "openai"]); + std::env::remove_var("LLM_PROVIDERS"); + } + + // ----------------------------------------------------------------------- + // ServiceUrls::from_env() defaults + // ----------------------------------------------------------------------- + + #[test] + #[serial] + fn service_urls_default_ollama_url() { + std::env::remove_var("OLLAMA_URL"); + let svc = ServiceUrls::from_env().unwrap(); + assert_eq!(svc.ollama_url, "http://localhost:11434"); + } + + #[test] + #[serial] + fn service_urls_default_ollama_model() { + std::env::remove_var("OLLAMA_MODEL"); + let svc = ServiceUrls::from_env().unwrap(); + assert_eq!(svc.ollama_model, "llama3.1:8b"); + } + + #[test] + #[serial] + fn service_urls_default_searxng_url() { + std::env::remove_var("SEARXNG_URL"); + let svc = ServiceUrls::from_env().unwrap(); + assert_eq!(svc.searxng_url, "http://localhost:8888"); + } + + #[test] + #[serial] + fn service_urls_custom_ollama_url() { + std::env::set_var("OLLAMA_URL", "http://gpu-host:11434"); + let svc = ServiceUrls::from_env().unwrap(); + assert_eq!(svc.ollama_url, "http://gpu-host:11434"); + std::env::remove_var("OLLAMA_URL"); + } + + #[test] + #[serial] + fn required_env_missing_returns_config_error() { + std::env::remove_var("__TEST_REQUIRED_MISSING__"); + let result = required_env("__TEST_REQUIRED_MISSING__"); + assert!(result.is_err()); + let err_msg = result.unwrap_err().to_string(); + assert!(err_msg.contains("__TEST_REQUIRED_MISSING__")); + } +} diff --git a/src/infrastructure/error.rs b/src/infrastructure/error.rs index 65b2d51..838d20a 100644 --- a/src/infrastructure/error.rs +++ b/src/infrastructure/error.rs @@ -41,3 +41,53 @@ impl IntoResponse for Error { } } } + +#[cfg(test)] +mod tests { + use super::*; + use axum::response::IntoResponse; + use pretty_assertions::assert_eq; + + #[test] + fn state_error_display() { + let err = Error::StateError("bad state".into()); + assert_eq!(err.to_string(), "bad state"); + } + + #[test] + fn database_error_display() { + let err = Error::DatabaseError("connection lost".into()); + assert_eq!(err.to_string(), "database error: connection lost"); + } + + #[test] + fn config_error_display() { + let err = Error::ConfigError("missing var".into()); + assert_eq!(err.to_string(), "configuration error: missing var"); + } + + #[test] + fn state_error_into_response_500() { + let resp = Error::StateError("oops".into()).into_response(); + assert_eq!(resp.status(), StatusCode::INTERNAL_SERVER_ERROR); + } + + #[test] + fn database_error_into_response_503() { + let resp = Error::DatabaseError("down".into()).into_response(); + assert_eq!(resp.status(), StatusCode::SERVICE_UNAVAILABLE); + } + + #[test] + fn config_error_into_response_500() { + let resp = Error::ConfigError("bad cfg".into()).into_response(); + assert_eq!(resp.status(), StatusCode::INTERNAL_SERVER_ERROR); + } + + #[test] + fn io_error_into_response_500() { + let io_err = std::io::Error::new(std::io::ErrorKind::NotFound, "not found"); + let resp = Error::IoError(io_err).into_response(); + assert_eq!(resp.status(), StatusCode::INTERNAL_SERVER_ERROR); + } +} diff --git a/src/infrastructure/searxng.rs b/src/infrastructure/searxng.rs index 713e67e..4e808e4 100644 --- a/src/infrastructure/searxng.rs +++ b/src/infrastructure/searxng.rs @@ -5,13 +5,13 @@ use dioxus::prelude::*; // The #[server] macro generates a client stub for the web build that // sends a network request instead of executing this function body. #[cfg(feature = "server")] -mod inner { +pub(crate) mod inner { use serde::Deserialize; use std::collections::HashSet; /// Individual result from the SearXNG search API. #[derive(Debug, Deserialize)] - pub(super) struct SearxngResult { + pub(crate) struct SearxngResult { pub title: String, pub url: String, pub content: Option, @@ -25,7 +25,7 @@ mod inner { /// Top-level response from the SearXNG search API. #[derive(Debug, Deserialize)] - pub(super) struct SearxngResponse { + pub(crate) struct SearxngResponse { pub results: Vec, } @@ -40,7 +40,7 @@ mod inner { /// # Returns /// /// The domain host or a fallback "Web" string - pub(super) fn extract_source(url_str: &str) -> String { + pub(crate) fn extract_source(url_str: &str) -> String { url::Url::parse(url_str) .ok() .and_then(|u| u.host_str().map(String::from)) @@ -64,7 +64,7 @@ mod inner { /// # Returns /// /// Filtered, deduplicated, and ranked results - pub(super) fn rank_and_deduplicate( + pub(crate) fn rank_and_deduplicate( mut results: Vec, max_results: usize, ) -> Vec { @@ -285,3 +285,166 @@ pub async fn get_trending_topics() -> Result, ServerFnError> { Ok(topics) } + +#[cfg(all(test, feature = "server"))] +mod tests { + #![allow(clippy::unwrap_used, clippy::expect_used)] + + use super::inner::*; + use pretty_assertions::assert_eq; + + // ----------------------------------------------------------------------- + // extract_source() + // ----------------------------------------------------------------------- + + #[test] + fn extract_source_strips_www() { + assert_eq!( + extract_source("https://www.example.com/page"), + "example.com" + ); + } + + #[test] + fn extract_source_returns_domain() { + assert_eq!( + extract_source("https://techcrunch.com/article"), + "techcrunch.com" + ); + } + + #[test] + fn extract_source_invalid_url_returns_web() { + assert_eq!(extract_source("not-a-url"), "Web"); + } + + #[test] + fn extract_source_no_scheme_returns_web() { + // url::Url::parse requires a scheme; bare domain fails + assert_eq!(extract_source("example.com/path"), "Web"); + } + + // ----------------------------------------------------------------------- + // rank_and_deduplicate() + // ----------------------------------------------------------------------- + + fn make_result(url: &str, content: &str, score: f64) -> SearxngResult { + SearxngResult { + title: "Title".into(), + url: url.into(), + content: if content.is_empty() { + None + } else { + Some(content.into()) + }, + published_date: None, + thumbnail: None, + score, + } + } + + #[test] + fn rank_filters_empty_content() { + let results = vec![ + make_result("https://a.com", "", 10.0), + make_result( + "https://b.com", + "This is meaningful content that passes the length filter", + 5.0, + ), + ]; + let ranked = rank_and_deduplicate(results, 10); + assert_eq!(ranked.len(), 1); + assert_eq!(ranked[0].url, "https://b.com"); + } + + #[test] + fn rank_filters_short_content() { + let results = vec![ + make_result("https://a.com", "short", 10.0), + make_result( + "https://b.com", + "This content is long enough to pass the 20-char filter threshold", + 5.0, + ), + ]; + let ranked = rank_and_deduplicate(results, 10); + assert_eq!(ranked.len(), 1); + } + + #[test] + fn rank_deduplicates_by_domain_keeps_highest() { + let results = vec![ + make_result( + "https://example.com/page1", + "First result with enough content here for the filter", + 3.0, + ), + make_result( + "https://example.com/page2", + "Second result with enough content here for the filter", + 8.0, + ), + ]; + let ranked = rank_and_deduplicate(results, 10); + assert_eq!(ranked.len(), 1); + // Should keep the highest-scored one (page2 with score 8.0) + assert_eq!(ranked[0].url, "https://example.com/page2"); + } + + #[test] + fn rank_sorts_by_score_descending() { + let results = vec![ + make_result( + "https://a.com/p", + "Content A that is long enough to pass the filter check", + 1.0, + ), + make_result( + "https://b.com/p", + "Content B that is long enough to pass the filter check", + 5.0, + ), + make_result( + "https://c.com/p", + "Content C that is long enough to pass the filter check", + 3.0, + ), + ]; + let ranked = rank_and_deduplicate(results, 10); + assert_eq!(ranked.len(), 3); + assert!(ranked[0].score >= ranked[1].score); + assert!(ranked[1].score >= ranked[2].score); + } + + #[test] + fn rank_truncates_to_max_results() { + let results: Vec<_> = (0..20) + .map(|i| { + make_result( + &format!("https://site{i}.com/page"), + &format!("Content for site {i} that is long enough to pass the filter"), + i as f64, + ) + }) + .collect(); + let ranked = rank_and_deduplicate(results, 5); + assert_eq!(ranked.len(), 5); + } + + #[test] + fn rank_empty_input_returns_empty() { + let ranked = rank_and_deduplicate(vec![], 10); + assert!(ranked.is_empty()); + } + + #[test] + fn rank_all_filtered_returns_empty() { + let results = vec![ + make_result("https://a.com", "", 10.0), + make_result("https://b.com", "too short", 5.0), + ]; + let ranked = rank_and_deduplicate(results, 10); + assert!(ranked.is_empty()); + } +} diff --git a/src/models/chat.rs b/src/models/chat.rs index e6f6134..aa869de 100644 --- a/src/models/chat.rs +++ b/src/models/chat.rs @@ -105,3 +105,163 @@ pub struct ChatMessage { pub attachments: Vec, pub timestamp: String, } + +#[cfg(test)] +mod tests { + use super::*; + use pretty_assertions::assert_eq; + + #[test] + fn chat_namespace_default_is_general() { + assert_eq!(ChatNamespace::default(), ChatNamespace::General); + } + + #[test] + fn chat_role_serde_round_trip() { + for role in [ChatRole::User, ChatRole::Assistant, ChatRole::System] { + let json = + serde_json::to_string(&role).unwrap_or_else(|_| panic!("serialize {:?}", role)); + let back: ChatRole = + serde_json::from_str(&json).unwrap_or_else(|_| panic!("deserialize {:?}", role)); + assert_eq!(role, back); + } + } + + #[test] + fn chat_namespace_serde_round_trip() { + for ns in [ChatNamespace::General, ChatNamespace::News] { + let json = serde_json::to_string(&ns).unwrap_or_else(|_| panic!("serialize {:?}", ns)); + let back: ChatNamespace = + serde_json::from_str(&json).unwrap_or_else(|_| panic!("deserialize {:?}", ns)); + assert_eq!(ns, back); + } + } + + #[test] + fn attachment_kind_serde_round_trip() { + for kind in [ + AttachmentKind::Image, + AttachmentKind::Document, + AttachmentKind::Code, + ] { + let json = + serde_json::to_string(&kind).unwrap_or_else(|_| panic!("serialize {:?}", kind)); + let back: AttachmentKind = + serde_json::from_str(&json).unwrap_or_else(|_| panic!("deserialize {:?}", kind)); + assert_eq!(kind, back); + } + } + + #[test] + fn attachment_serde_round_trip() { + let att = Attachment { + name: "photo.png".into(), + kind: AttachmentKind::Image, + size_bytes: 2048, + }; + let json = serde_json::to_string(&att).expect("serialize Attachment"); + let back: Attachment = serde_json::from_str(&json).expect("deserialize Attachment"); + assert_eq!(att, back); + } + + #[test] + fn chat_session_serde_round_trip() { + let session = ChatSession { + id: "abc123".into(), + user_sub: "user-1".into(), + title: "Test Chat".into(), + namespace: ChatNamespace::General, + provider: "ollama".into(), + model: "llama3.1:8b".into(), + created_at: "2025-01-01T00:00:00Z".into(), + updated_at: "2025-01-01T01:00:00Z".into(), + article_url: None, + }; + let json = serde_json::to_string(&session).expect("serialize ChatSession"); + let back: ChatSession = serde_json::from_str(&json).expect("deserialize ChatSession"); + assert_eq!(session, back); + } + + #[test] + fn chat_session_id_alias_deserialization() { + // MongoDB returns `_id` instead of `id` + let json = r#"{ + "_id": "mongo-id", + "user_sub": "u1", + "title": "t", + "provider": "ollama", + "model": "m", + "created_at": "2025-01-01", + "updated_at": "2025-01-01" + }"#; + let session: ChatSession = serde_json::from_str(json).expect("deserialize with _id"); + assert_eq!(session.id, "mongo-id"); + } + + #[test] + fn chat_session_empty_id_skips_serialization() { + let session = ChatSession { + id: String::new(), + user_sub: "u1".into(), + title: "t".into(), + namespace: ChatNamespace::default(), + provider: "ollama".into(), + model: "m".into(), + created_at: "2025-01-01".into(), + updated_at: "2025-01-01".into(), + article_url: None, + }; + let json = serde_json::to_string(&session).expect("serialize"); + // `id` field should be absent when empty due to skip_serializing_if + assert!(!json.contains("\"id\"")); + } + + #[test] + fn chat_session_none_article_url_skips_serialization() { + let session = ChatSession { + id: "s1".into(), + user_sub: "u1".into(), + title: "t".into(), + namespace: ChatNamespace::default(), + provider: "ollama".into(), + model: "m".into(), + created_at: "2025-01-01".into(), + updated_at: "2025-01-01".into(), + article_url: None, + }; + let json = serde_json::to_string(&session).expect("serialize"); + assert!(!json.contains("article_url")); + } + + #[test] + fn chat_message_serde_round_trip() { + let msg = ChatMessage { + id: "msg-1".into(), + session_id: "s1".into(), + role: ChatRole::User, + content: "Hello AI".into(), + attachments: vec![Attachment { + name: "doc.pdf".into(), + kind: AttachmentKind::Document, + size_bytes: 4096, + }], + timestamp: "2025-01-01T00:00:00Z".into(), + }; + let json = serde_json::to_string(&msg).expect("serialize ChatMessage"); + let back: ChatMessage = serde_json::from_str(&json).expect("deserialize ChatMessage"); + assert_eq!(msg, back); + } + + #[test] + fn chat_message_id_alias_deserialization() { + let json = r#"{ + "_id": "mongo-msg-id", + "session_id": "s1", + "role": "User", + "content": "hi", + "timestamp": "2025-01-01" + }"#; + let msg: ChatMessage = serde_json::from_str(json).expect("deserialize with _id"); + assert_eq!(msg.id, "mongo-msg-id"); + } +} diff --git a/src/models/developer.rs b/src/models/developer.rs index 1138e96..9ba530d 100644 --- a/src/models/developer.rs +++ b/src/models/developer.rs @@ -45,3 +45,63 @@ pub struct AnalyticsMetric { pub value: String, pub change_pct: f64, } + +#[cfg(test)] +mod tests { + use super::*; + use pretty_assertions::assert_eq; + + #[test] + fn agent_entry_serde_round_trip() { + let agent = AgentEntry { + id: "a1".into(), + name: "RAG Agent".into(), + description: "Retrieval-augmented generation".into(), + status: "running".into(), + }; + let json = serde_json::to_string(&agent).expect("serialize AgentEntry"); + let back: AgentEntry = serde_json::from_str(&json).expect("deserialize AgentEntry"); + assert_eq!(agent, back); + } + + #[test] + fn flow_entry_serde_round_trip() { + let flow = FlowEntry { + id: "f1".into(), + name: "Data Pipeline".into(), + node_count: 5, + last_run: Some("2025-06-01T12:00:00Z".into()), + }; + let json = serde_json::to_string(&flow).expect("serialize FlowEntry"); + let back: FlowEntry = serde_json::from_str(&json).expect("deserialize FlowEntry"); + assert_eq!(flow, back); + } + + #[test] + fn flow_entry_with_none_last_run() { + let flow = FlowEntry { + id: "f2".into(), + name: "New Flow".into(), + node_count: 0, + last_run: None, + }; + let json = serde_json::to_string(&flow).expect("serialize"); + let back: FlowEntry = serde_json::from_str(&json).expect("deserialize"); + assert_eq!(flow, back); + assert_eq!(back.last_run, None); + } + + #[test] + fn analytics_metric_negative_change_pct() { + let metric = AnalyticsMetric { + label: "Latency".into(), + value: "120ms".into(), + change_pct: -15.5, + }; + let json = serde_json::to_string(&metric).expect("serialize AnalyticsMetric"); + let back: AnalyticsMetric = + serde_json::from_str(&json).expect("deserialize AnalyticsMetric"); + assert_eq!(metric, back); + assert!(back.change_pct < 0.0); + } +} diff --git a/src/models/news.rs b/src/models/news.rs index 833920a..ffa3930 100644 --- a/src/models/news.rs +++ b/src/models/news.rs @@ -23,3 +23,61 @@ pub struct NewsCard { pub thumbnail_url: Option, pub published_at: String, } + +#[cfg(test)] +mod tests { + use super::*; + use pretty_assertions::assert_eq; + + #[test] + fn news_card_serde_round_trip() { + let card = NewsCard { + title: "AI Breakthrough".into(), + source: "techcrunch.com".into(), + summary: "New model released".into(), + content: "Full article content here".into(), + category: "AI".into(), + url: "https://example.com/article".into(), + thumbnail_url: Some("https://example.com/thumb.jpg".into()), + published_at: "2025-06-01".into(), + }; + let json = serde_json::to_string(&card).expect("serialize NewsCard"); + let back: NewsCard = serde_json::from_str(&json).expect("deserialize NewsCard"); + assert_eq!(card, back); + } + + #[test] + fn news_card_thumbnail_none() { + let card = NewsCard { + title: "No Thumb".into(), + source: "bbc.com".into(), + summary: "Summary".into(), + content: "Content".into(), + category: "Tech".into(), + url: "https://bbc.com/article".into(), + thumbnail_url: None, + published_at: "2025-06-01".into(), + }; + let json = serde_json::to_string(&card).expect("serialize"); + let back: NewsCard = serde_json::from_str(&json).expect("deserialize"); + assert_eq!(card, back); + } + + #[test] + fn news_card_thumbnail_some() { + let card = NewsCard { + title: "With Thumb".into(), + source: "cnn.com".into(), + summary: "Summary".into(), + content: "Content".into(), + category: "News".into(), + url: "https://cnn.com/article".into(), + thumbnail_url: Some("https://cnn.com/img.jpg".into()), + published_at: "2025-06-01".into(), + }; + let json = serde_json::to_string(&card).expect("serialize"); + assert!(json.contains("img.jpg")); + let back: NewsCard = serde_json::from_str(&json).expect("deserialize"); + assert_eq!(card.thumbnail_url, back.thumbnail_url); + } +} diff --git a/src/models/organization.rs b/src/models/organization.rs index 790e687..0c6745d 100644 --- a/src/models/organization.rs +++ b/src/models/organization.rs @@ -116,3 +116,122 @@ pub struct OrgBillingRecord { /// Number of tokens consumed during this cycle. pub tokens_used: u64, } + +#[cfg(test)] +mod tests { + use super::*; + use pretty_assertions::assert_eq; + + #[test] + fn member_role_label_admin() { + assert_eq!(MemberRole::Admin.label(), "Admin"); + } + + #[test] + fn member_role_label_member() { + assert_eq!(MemberRole::Member.label(), "Member"); + } + + #[test] + fn member_role_label_viewer() { + assert_eq!(MemberRole::Viewer.label(), "Viewer"); + } + + #[test] + fn member_role_all_returns_three_in_order() { + let all = MemberRole::all(); + assert_eq!(all.len(), 3); + assert_eq!(all[0], MemberRole::Admin); + assert_eq!(all[1], MemberRole::Member); + assert_eq!(all[2], MemberRole::Viewer); + } + + #[test] + fn member_role_serde_round_trip() { + for role in MemberRole::all() { + let json = + serde_json::to_string(role).unwrap_or_else(|_| panic!("serialize {:?}", role)); + let back: MemberRole = + serde_json::from_str(&json).unwrap_or_else(|_| panic!("deserialize {:?}", role)); + assert_eq!(*role, back); + } + } + + #[test] + fn org_member_serde_round_trip() { + let member = OrgMember { + id: "m1".into(), + name: "Alice".into(), + email: "alice@example.com".into(), + role: MemberRole::Admin, + joined_at: "2025-01-01T00:00:00Z".into(), + }; + let json = serde_json::to_string(&member).expect("serialize OrgMember"); + let back: OrgMember = serde_json::from_str(&json).expect("deserialize OrgMember"); + assert_eq!(member, back); + } + + #[test] + fn pricing_plan_with_max_seats() { + let plan = PricingPlan { + id: "team".into(), + name: "Team".into(), + price_eur: 49, + features: vec!["SSO".into(), "Priority".into()], + highlighted: true, + max_seats: Some(25), + }; + let json = serde_json::to_string(&plan).expect("serialize PricingPlan"); + let back: PricingPlan = serde_json::from_str(&json).expect("deserialize PricingPlan"); + assert_eq!(plan, back); + } + + #[test] + fn pricing_plan_without_max_seats() { + let plan = PricingPlan { + id: "enterprise".into(), + name: "Enterprise".into(), + price_eur: 199, + features: vec!["Unlimited".into()], + highlighted: false, + max_seats: None, + }; + let json = serde_json::to_string(&plan).expect("serialize PricingPlan"); + let back: PricingPlan = serde_json::from_str(&json).expect("deserialize PricingPlan"); + assert_eq!(plan, back); + assert!(json.contains("null") || !json.contains("max_seats")); + } + + #[test] + fn billing_usage_serde_round_trip() { + let usage = BillingUsage { + seats_used: 5, + seats_total: 10, + tokens_used: 1_000_000, + tokens_limit: 5_000_000, + billing_cycle_end: "2025-12-31".into(), + }; + let json = serde_json::to_string(&usage).expect("serialize BillingUsage"); + let back: BillingUsage = serde_json::from_str(&json).expect("deserialize BillingUsage"); + assert_eq!(usage, back); + } + + #[test] + fn org_settings_default() { + let settings = OrgSettings::default(); + assert_eq!(settings.org_id, ""); + assert_eq!(settings.plan_id, ""); + assert!(settings.enabled_features.is_empty()); + assert_eq!(settings.stripe_customer_id, ""); + } + + #[test] + fn org_billing_record_default() { + let record = OrgBillingRecord::default(); + assert_eq!(record.org_id, ""); + assert_eq!(record.cycle_start, ""); + assert_eq!(record.cycle_end, ""); + assert_eq!(record.seats_used, 0); + assert_eq!(record.tokens_used, 0); + } +} diff --git a/src/models/provider.rs b/src/models/provider.rs index a08a637..48ee498 100644 --- a/src/models/provider.rs +++ b/src/models/provider.rs @@ -72,3 +72,84 @@ pub struct ProviderConfig { pub selected_embedding: String, pub api_key_set: bool, } + +#[cfg(test)] +mod tests { + use super::*; + use pretty_assertions::assert_eq; + + #[test] + fn llm_provider_label_ollama() { + assert_eq!(LlmProvider::Ollama.label(), "Ollama"); + } + + #[test] + fn llm_provider_label_hugging_face() { + assert_eq!(LlmProvider::HuggingFace.label(), "Hugging Face"); + } + + #[test] + fn llm_provider_label_openai() { + assert_eq!(LlmProvider::OpenAi.label(), "OpenAI"); + } + + #[test] + fn llm_provider_label_anthropic() { + assert_eq!(LlmProvider::Anthropic.label(), "Anthropic"); + } + + #[test] + fn llm_provider_serde_round_trip() { + for variant in [ + LlmProvider::Ollama, + LlmProvider::HuggingFace, + LlmProvider::OpenAi, + LlmProvider::Anthropic, + ] { + let json = serde_json::to_string(&variant) + .unwrap_or_else(|_| panic!("serialize {:?}", variant)); + let back: LlmProvider = + serde_json::from_str(&json).unwrap_or_else(|_| panic!("deserialize {:?}", variant)); + assert_eq!(variant, back); + } + } + + #[test] + fn model_entry_serde_round_trip() { + let entry = ModelEntry { + id: "llama3.1:8b".into(), + name: "Llama 3.1 8B".into(), + provider: LlmProvider::Ollama, + context_window: 8192, + }; + let json = serde_json::to_string(&entry).expect("serialize ModelEntry"); + let back: ModelEntry = serde_json::from_str(&json).expect("deserialize ModelEntry"); + assert_eq!(entry, back); + } + + #[test] + fn embedding_entry_serde_round_trip() { + let entry = EmbeddingEntry { + id: "nomic-embed".into(), + name: "Nomic Embed".into(), + provider: LlmProvider::HuggingFace, + dimensions: 768, + }; + let json = serde_json::to_string(&entry).expect("serialize EmbeddingEntry"); + let back: EmbeddingEntry = serde_json::from_str(&json).expect("deserialize EmbeddingEntry"); + assert_eq!(entry, back); + } + + #[test] + fn provider_config_serde_round_trip() { + let cfg = ProviderConfig { + provider: LlmProvider::Anthropic, + selected_model: "claude-3".into(), + selected_embedding: "embed-v1".into(), + api_key_set: true, + }; + let json = serde_json::to_string(&cfg).expect("serialize ProviderConfig"); + let back: ProviderConfig = serde_json::from_str(&json).expect("deserialize ProviderConfig"); + assert_eq!(cfg, back); + } +} diff --git a/src/models/user.rs b/src/models/user.rs index a3367bd..5bbc8f9 100644 --- a/src/models/user.rs +++ b/src/models/user.rs @@ -70,3 +70,81 @@ pub struct UserPreferences { #[serde(default)] pub provider_config: UserProviderConfig, } + +#[cfg(test)] +mod tests { + use super::*; + use pretty_assertions::assert_eq; + + #[test] + fn user_data_default() { + let ud = UserData::default(); + assert_eq!(ud.name, ""); + } + + #[test] + fn auth_info_default_not_authenticated() { + let info = AuthInfo::default(); + assert!(!info.authenticated); + assert_eq!(info.sub, ""); + assert_eq!(info.email, ""); + assert_eq!(info.name, ""); + assert_eq!(info.avatar_url, ""); + assert_eq!(info.librechat_url, ""); + } + + #[test] + fn auth_info_serde_round_trip() { + let info = AuthInfo { + authenticated: true, + sub: "sub-123".into(), + email: "test@example.com".into(), + name: "Test User".into(), + avatar_url: "https://example.com/avatar.png".into(), + librechat_url: "https://chat.example.com".into(), + }; + let json = serde_json::to_string(&info).expect("serialize AuthInfo"); + let back: AuthInfo = serde_json::from_str(&json).expect("deserialize AuthInfo"); + assert_eq!(info, back); + } + + #[test] + fn user_preferences_default() { + let prefs = UserPreferences::default(); + assert_eq!(prefs.sub, ""); + assert_eq!(prefs.org_id, ""); + assert!(prefs.custom_topics.is_empty()); + assert!(prefs.recent_searches.is_empty()); + } + + #[test] + fn user_provider_config_optional_keys_skip_none() { + let cfg = UserProviderConfig { + default_provider: "ollama".into(), + default_model: "llama3.1:8b".into(), + openai_api_key: None, + anthropic_api_key: None, + huggingface_api_key: None, + ollama_url_override: String::new(), + }; + let json = serde_json::to_string(&cfg).expect("serialize UserProviderConfig"); + assert!(!json.contains("openai_api_key")); + assert!(!json.contains("anthropic_api_key")); + assert!(!json.contains("huggingface_api_key")); + } + + #[test] + fn user_provider_config_serde_round_trip_with_keys() { + let cfg = UserProviderConfig { + default_provider: "openai".into(), + default_model: "gpt-4o".into(), + openai_api_key: Some("sk-test".into()), + anthropic_api_key: Some("ak-test".into()), + huggingface_api_key: None, + ollama_url_override: "http://custom:11434".into(), + }; + let json = serde_json::to_string(&cfg).expect("serialize"); + let back: UserProviderConfig = serde_json::from_str(&json).expect("deserialize"); + assert_eq!(cfg, back); + } +}