diff --git a/src/infrastructure/chat.rs b/src/infrastructure/chat.rs index 03c3015..5b5e99a 100644 --- a/src/infrastructure/chat.rs +++ b/src/infrastructure/chat.rs @@ -440,7 +440,12 @@ pub async fn chat_complete( let session = doc_to_chat_session(&session_doc); // Resolve provider URL and model - let (base_url, model) = resolve_provider_url(&state, &session.provider, &session.model); + let (base_url, model) = resolve_provider_url( + &state.services.ollama_url, + &state.services.ollama_model, + &session.provider, + &session.model, + ); // Parse messages from JSON let chat_msgs: Vec = serde_json::from_str(&messages_json) @@ -480,10 +485,22 @@ pub async fn chat_complete( .ok_or_else(|| ServerFnError::new("empty LLM response")) } -/// Resolve the base URL for a provider, falling back to server defaults. +/// Resolve the base URL for a provider, falling back to Ollama defaults. +/// +/// # Arguments +/// +/// * `ollama_url` - Default Ollama base URL from config +/// * `ollama_model` - Default Ollama model from config +/// * `provider` - Provider name (e.g. "openai", "anthropic", "huggingface") +/// * `model` - Model ID (may be empty for Ollama default) +/// +/// # Returns +/// +/// A `(base_url, model)` tuple resolved for the given provider. 
#[cfg(feature = "server")] -fn resolve_provider_url( - state: &crate::infrastructure::ServerState, +pub(crate) fn resolve_provider_url( + ollama_url: &str, + ollama_model: &str, provider: &str, model: &str, ) -> (String, String) { @@ -496,12 +513,229 @@ fn resolve_provider_url( ), // Default to Ollama _ => ( - state.services.ollama_url.clone(), + ollama_url.to_string(), if model.is_empty() { - state.services.ollama_model.clone() + ollama_model.to_string() } else { model.to_string() }, ), } } + +#[cfg(test)] +mod tests { + // ----------------------------------------------------------------------- + // BSON document conversion tests (server feature required) + // ----------------------------------------------------------------------- + + #[cfg(feature = "server")] + mod server_tests { + use super::super::{doc_to_chat_message, doc_to_chat_session, resolve_provider_url}; + use crate::models::{ChatNamespace, ChatRole}; + use mongodb::bson::{doc, oid::ObjectId, Document}; + use pretty_assertions::assert_eq; + + // -- doc_to_chat_session -- + + fn sample_session_doc() -> (ObjectId, Document) { + let oid = ObjectId::new(); + let doc = doc! 
{ + "_id": oid, + "user_sub": "user-42", + "title": "Test Session", + "namespace": "News", + "provider": "openai", + "model": "gpt-4", + "created_at": "2025-01-01T00:00:00Z", + "updated_at": "2025-01-02T00:00:00Z", + "article_url": "https://example.com/article", + }; + (oid, doc) + } + + #[test] + fn doc_to_chat_session_extracts_id_as_hex() { + let (oid, doc) = sample_session_doc(); + let session = doc_to_chat_session(&doc); + assert_eq!(session.id, oid.to_hex()); + } + + #[test] + fn doc_to_chat_session_maps_news_namespace() { + let (_, doc) = sample_session_doc(); + let session = doc_to_chat_session(&doc); + assert_eq!(session.namespace, ChatNamespace::News); + } + + #[test] + fn doc_to_chat_session_defaults_to_general_for_unknown() { + let mut doc = sample_session_doc().1; + doc.insert("namespace", "SomethingElse"); + let session = doc_to_chat_session(&doc); + assert_eq!(session.namespace, ChatNamespace::General); + } + + #[test] + fn doc_to_chat_session_extracts_all_string_fields() { + let (_, doc) = sample_session_doc(); + let session = doc_to_chat_session(&doc); + assert_eq!(session.user_sub, "user-42"); + assert_eq!(session.title, "Test Session"); + assert_eq!(session.provider, "openai"); + assert_eq!(session.model, "gpt-4"); + assert_eq!(session.created_at, "2025-01-01T00:00:00Z"); + assert_eq!(session.updated_at, "2025-01-02T00:00:00Z"); + } + + #[test] + fn doc_to_chat_session_handles_missing_article_url() { + let oid = ObjectId::new(); + let doc = doc! { + "_id": oid, + "user_sub": "u", + "title": "t", + "provider": "ollama", + "model": "m", + "created_at": "c", + "updated_at": "u", + }; + let session = doc_to_chat_session(&doc); + assert_eq!(session.article_url, None); + } + + #[test] + fn doc_to_chat_session_filters_empty_article_url() { + let oid = ObjectId::new(); + let doc = doc! 
{ + "_id": oid, + "user_sub": "u", + "title": "t", + "namespace": "News", + "provider": "ollama", + "model": "m", + "created_at": "c", + "updated_at": "u", + "article_url": "", + }; + let session = doc_to_chat_session(&doc); + assert_eq!(session.article_url, None); + } + + // -- doc_to_chat_message -- + + fn sample_message_doc() -> (ObjectId, Document) { + let oid = ObjectId::new(); + let doc = doc! { + "_id": oid, + "session_id": "sess-1", + "role": "Assistant", + "content": "Hello there!", + "timestamp": "2025-01-01T12:00:00Z", + }; + (oid, doc) + } + + #[test] + fn doc_to_chat_message_extracts_id_as_hex() { + let (oid, doc) = sample_message_doc(); + let msg = doc_to_chat_message(&doc); + assert_eq!(msg.id, oid.to_hex()); + } + + #[test] + fn doc_to_chat_message_maps_assistant_role() { + let (_, doc) = sample_message_doc(); + let msg = doc_to_chat_message(&doc); + assert_eq!(msg.role, ChatRole::Assistant); + } + + #[test] + fn doc_to_chat_message_maps_system_role() { + let mut doc = sample_message_doc().1; + doc.insert("role", "System"); + let msg = doc_to_chat_message(&doc); + assert_eq!(msg.role, ChatRole::System); + } + + #[test] + fn doc_to_chat_message_defaults_to_user_for_unknown() { + let mut doc = sample_message_doc().1; + doc.insert("role", "SomethingElse"); + let msg = doc_to_chat_message(&doc); + assert_eq!(msg.role, ChatRole::User); + } + + #[test] + fn doc_to_chat_message_extracts_content_and_timestamp() { + let (_, doc) = sample_message_doc(); + let msg = doc_to_chat_message(&doc); + assert_eq!(msg.content, "Hello there!"); + assert_eq!(msg.timestamp, "2025-01-01T12:00:00Z"); + assert_eq!(msg.session_id, "sess-1"); + } + + #[test] + fn doc_to_chat_message_attachments_always_empty() { + let (_, doc) = sample_message_doc(); + let msg = doc_to_chat_message(&doc); + assert!(msg.attachments.is_empty()); + } + + // -- resolve_provider_url -- + + const TEST_OLLAMA_URL: &str = "http://localhost:11434"; + const TEST_OLLAMA_MODEL: &str = "llama3.1:8b"; + + 
#[test] + fn resolve_openai_returns_api_openai() { + let (url, model) = + resolve_provider_url(TEST_OLLAMA_URL, TEST_OLLAMA_MODEL, "openai", "gpt-4o"); + assert_eq!(url, "https://api.openai.com"); + assert_eq!(model, "gpt-4o"); + } + + #[test] + fn resolve_anthropic_returns_api_anthropic() { + let (url, model) = resolve_provider_url( + TEST_OLLAMA_URL, + TEST_OLLAMA_MODEL, + "anthropic", + "claude-3-opus", + ); + assert_eq!(url, "https://api.anthropic.com"); + assert_eq!(model, "claude-3-opus"); + } + + #[test] + fn resolve_huggingface_returns_model_url() { + let (url, model) = resolve_provider_url( + TEST_OLLAMA_URL, + TEST_OLLAMA_MODEL, + "huggingface", + "meta-llama/Llama-2-7b", + ); + assert_eq!( + url, + "https://api-inference.huggingface.co/models/meta-llama/Llama-2-7b" + ); + assert_eq!(model, "meta-llama/Llama-2-7b"); + } + + #[test] + fn resolve_unknown_defaults_to_ollama() { + let (url, model) = + resolve_provider_url(TEST_OLLAMA_URL, TEST_OLLAMA_MODEL, "ollama", "mistral:7b"); + assert_eq!(url, TEST_OLLAMA_URL); + assert_eq!(model, "mistral:7b"); + } + + #[test] + fn resolve_empty_model_falls_back_to_server_default() { + let (url, model) = + resolve_provider_url(TEST_OLLAMA_URL, TEST_OLLAMA_MODEL, "ollama", ""); + assert_eq!(url, TEST_OLLAMA_URL); + assert_eq!(model, TEST_OLLAMA_MODEL); + } + } +} diff --git a/src/infrastructure/llm.rs b/src/infrastructure/llm.rs index 07379c0..b68e2ab 100644 --- a/src/infrastructure/llm.rs +++ b/src/infrastructure/llm.rs @@ -72,7 +72,25 @@ mod inner { } let html = resp.text().await.ok()?; - let document = scraper::Html::parse_document(&html); + parse_article_html(&html) + } + + /// Parse article text from raw HTML without any network I/O. + /// + /// Uses a tiered extraction strategy: + /// 1. Try content within `
<article>`, `<main>
`, or `[role="main"]` + /// 2. Fall back to all `<p>

` tags outside excluded containers + /// + /// # Arguments + /// + /// * `html` - Raw HTML string to parse + /// + /// # Returns + /// + /// The extracted text, or `None` if extraction yields < 100 chars. + /// Output is capped at 8000 characters. + pub(crate) fn parse_article_html(html: &str) -> Option { + let document = scraper::Html::parse_document(html); // Strategy 1: Extract from semantic article containers. // Most news sites wrap the main content in

,
, @@ -134,7 +152,7 @@ mod inner { } /// Sum the total character length of all collected text parts. - fn joined_len(parts: &[String]) -> usize { + pub(crate) fn joined_len(parts: &[String]) -> usize { parts.iter().map(|s| s.len()).sum() } } @@ -325,3 +343,150 @@ pub async fn chat_followup( .map(|choice| choice.message.content.clone()) .ok_or_else(|| ServerFnError::new("Empty response from Ollama")) } + +#[cfg(test)] +mod tests { + use super::*; + use pretty_assertions::assert_eq; + + // ----------------------------------------------------------------------- + // FollowUpMessage serde tests + // ----------------------------------------------------------------------- + + #[test] + fn followup_message_serde_round_trip() { + let msg = FollowUpMessage { + role: "assistant".into(), + content: "Here is my answer.".into(), + }; + let json = serde_json::to_string(&msg).expect("serialize FollowUpMessage"); + let back: FollowUpMessage = + serde_json::from_str(&json).expect("deserialize FollowUpMessage"); + assert_eq!(msg, back); + } + + #[test] + fn followup_message_deserialize_from_json_literal() { + let json = r#"{"role":"system","content":"You are helpful."}"#; + let msg: FollowUpMessage = serde_json::from_str(json).expect("deserialize literal"); + assert_eq!(msg.role, "system"); + assert_eq!(msg.content, "You are helpful."); + } + + // ----------------------------------------------------------------------- + // joined_len and parse_article_html tests (server feature required) + // ----------------------------------------------------------------------- + + #[cfg(feature = "server")] + mod server_tests { + use super::super::inner::{joined_len, parse_article_html}; + use pretty_assertions::assert_eq; + + #[test] + fn joined_len_empty_input() { + assert_eq!(joined_len(&[]), 0); + } + + #[test] + fn joined_len_sums_correctly() { + let parts = vec!["abc".into(), "de".into(), "fghij".into()]; + assert_eq!(joined_len(&parts), 10); + } + + // 
------------------------------------------------------------------- + // parse_article_html tests + // ------------------------------------------------------------------- + + // Helper: generate a string of given length from a repeated word. + fn lorem(len: usize) -> String { + "Lorem ipsum dolor sit amet consectetur adipiscing elit " + .repeat((len / 55) + 1) + .chars() + .take(len) + .collect() + } + + #[test] + fn article_tag_extracts_text() { + let body = lorem(250); + let html = format!("

{body}

"); + let result = parse_article_html(&html); + assert!(result.is_some(), "expected Some for article tag"); + assert!(result.unwrap().contains("Lorem")); + } + + #[test] + fn main_tag_extracts_text() { + let body = lorem(250); + let html = format!("

{body}

"); + let result = parse_article_html(&html); + assert!(result.is_some(), "expected Some for main tag"); + } + + #[test] + fn fallback_to_p_tags_when_article_main_yield_little() { + // No
<article>/<main>
, so falls back to <p>

tags + let body = lorem(250); + let html = format!("

{body}

"); + let result = parse_article_html(&html); + assert!(result.is_some(), "expected fallback to

tags"); + } + + #[test] + fn excludes_nav_footer_aside_content() { + // Content only inside excluded containers -- should be excluded + let body = lorem(250); + let html = format!( + "\ +

\ +

{body}

\ + \ + " + ); + let result = parse_article_html(&html); + assert!(result.is_none(), "expected None for excluded-only content"); + } + + #[test] + fn returns_none_when_text_too_short() { + let html = "

Short.

"; + let result = parse_article_html(html); + assert!(result.is_none(), "expected None for short text"); + } + + #[test] + fn truncates_at_8000_chars() { + let body = lorem(10000); + let html = format!("

{body}

"); + let result = parse_article_html(&html).expect("expected Some"); + assert!( + result.len() <= 8000, + "expected <= 8000 chars, got {}", + result.len() + ); + } + + #[test] + fn skips_fragments_under_30_chars() { + // Only fragments < 30 chars -- should yield None + let html = "
\ +

Short frag one

\ +

Another tiny bit

\ +
"; + let result = parse_article_html(html); + assert!(result.is_none(), "expected None for tiny fragments"); + } + + #[test] + fn extracts_from_role_main_attribute() { + let body = lorem(250); + let html = format!( + "\ +

{body}

\ + " + ); + let result = parse_article_html(&html); + assert!(result.is_some(), "expected Some for role=main"); + } + } +} diff --git a/src/infrastructure/provider_client.rs b/src/infrastructure/provider_client.rs index ce915b1..804eba6 100644 --- a/src/infrastructure/provider_client.rs +++ b/src/infrastructure/provider_client.rs @@ -146,3 +146,30 @@ pub async fn send_chat_request( } } } + +#[cfg(test)] +mod tests { + use super::*; + use pretty_assertions::assert_eq; + + #[test] + fn provider_message_serde_round_trip() { + let msg = ProviderMessage { + role: "assistant".into(), + content: "Hello, world!".into(), + }; + let json = serde_json::to_string(&msg).expect("serialize ProviderMessage"); + let back: ProviderMessage = + serde_json::from_str(&json).expect("deserialize ProviderMessage"); + assert_eq!(msg.role, back.role); + assert_eq!(msg.content, back.content); + } + + #[test] + fn provider_message_deserialize_from_json_literal() { + let json = r#"{"role":"user","content":"What is Rust?"}"#; + let msg: ProviderMessage = serde_json::from_str(json).expect("deserialize from literal"); + assert_eq!(msg.role, "user"); + assert_eq!(msg.content, "What is Rust?"); + } +} diff --git a/src/infrastructure/state.rs b/src/infrastructure/state.rs index d6c2bc1..9d3a75e 100644 --- a/src/infrastructure/state.rs +++ b/src/infrastructure/state.rs @@ -44,3 +44,91 @@ pub struct User { /// Avatar / profile picture URL. 
pub avatar_url: String, } + +#[cfg(test)] +mod tests { + use super::*; + use pretty_assertions::assert_eq; + + #[test] + fn user_state_inner_default_has_empty_strings() { + let inner = UserStateInner::default(); + assert_eq!(inner.sub, ""); + assert_eq!(inner.access_token, ""); + assert_eq!(inner.refresh_token, ""); + assert_eq!(inner.user.email, ""); + assert_eq!(inner.user.name, ""); + assert_eq!(inner.user.avatar_url, ""); + } + + #[test] + fn user_default_has_empty_strings() { + let user = User::default(); + assert_eq!(user.email, ""); + assert_eq!(user.name, ""); + assert_eq!(user.avatar_url, ""); + } + + #[test] + fn user_state_inner_serde_round_trip() { + let inner = UserStateInner { + sub: "user-123".into(), + access_token: "tok-abc".into(), + refresh_token: "ref-xyz".into(), + user: User { + email: "a@b.com".into(), + name: "Alice".into(), + avatar_url: "https://img.example.com/a.png".into(), + }, + }; + let json = serde_json::to_string(&inner).expect("serialize UserStateInner"); + let back: UserStateInner = serde_json::from_str(&json).expect("deserialize UserStateInner"); + assert_eq!(inner.sub, back.sub); + assert_eq!(inner.access_token, back.access_token); + assert_eq!(inner.refresh_token, back.refresh_token); + assert_eq!(inner.user.email, back.user.email); + assert_eq!(inner.user.name, back.user.name); + assert_eq!(inner.user.avatar_url, back.user.avatar_url); + } + + #[test] + fn user_state_from_inner_and_deref() { + let inner = UserStateInner { + sub: "sub-1".into(), + access_token: "at".into(), + refresh_token: "rt".into(), + user: User { + email: "e@e.com".into(), + name: "Eve".into(), + avatar_url: "".into(), + }, + }; + let state = UserState::from(inner); + // Deref should give access to inner fields + assert_eq!(state.sub, "sub-1"); + assert_eq!(state.user.name, "Eve"); + } + + #[test] + fn user_serde_round_trip() { + let user = User { + email: "bob@test.com".into(), + name: "Bob".into(), + avatar_url: "https://avatars.io/bob".into(), + }; + 
let json = serde_json::to_string(&user).expect("serialize User"); + let back: User = serde_json::from_str(&json).expect("deserialize User"); + assert_eq!(user.email, back.email); + assert_eq!(user.name, back.name); + assert_eq!(user.avatar_url, back.avatar_url); + } + + #[test] + fn user_state_clone_is_cheap() { + let inner = UserStateInner::default(); + let state = UserState::from(inner); + let cloned = state.clone(); + // Both point to the same Arc allocation + assert_eq!(state.sub, cloned.sub); + } +}