Files
certifai/src/models/provider.rs
Sharang Parnerkar 01b285ba20
All checks were successful
CI / Format (push) Successful in 3s
CI / Clippy (push) Successful in 2m56s
CI / Security Audit (push) Has been skipped
CI / Tests (push) Has been skipped
CI / Deploy (push) Has been skipped
CI / Format (pull_request) Successful in 22s
CI / Clippy (pull_request) Successful in 2m51s
CI / Security Audit (pull_request) Has been skipped
CI / Tests (pull_request) Has been skipped
CI / Deploy (pull_request) Has been skipped
test: add comprehensive unit test suite (~85 new tests)
Add unit tests across all model and server infrastructure layers,
increasing test count from 7 to 92. Covers serde round-trips, enum
methods, defaults, config parsing, error mapping, PKCE crypto (with
RFC 7636 test vector), OAuth store, and SearXNG ranking/dedup logic.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-24 18:43:27 +01:00

156 lines
4.5 KiB
Rust

use serde::{Deserialize, Serialize};
/// Supported LLM provider backends.
///
/// With serde's default (externally tagged) enum representation and no
/// `rename` attributes, each variant serializes as its bare name — in JSON a
/// plain string such as `"HuggingFace"`. Renaming a variant therefore changes
/// the serialized form; use [`LlmProvider::label`] for user-facing text.
///
/// `Eq` and `Hash` are derived so providers can be compared totally and used
/// as keys in hash-based collections.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum LlmProvider {
    /// Self-hosted models via Ollama
    Ollama,
    /// Hugging Face Inference API
    HuggingFace,
    /// OpenAI-compatible endpoints
    OpenAi,
    /// Anthropic Claude API
    Anthropic,
}
impl LlmProvider {
/// Returns the display name for a provider.
pub fn label(&self) -> &'static str {
match self {
Self::Ollama => "Ollama",
Self::HuggingFace => "Hugging Face",
Self::OpenAi => "OpenAI",
Self::Anthropic => "Anthropic",
}
}
}
/// A model offered by an [`LlmProvider`].
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ModelEntry {
    /// Unique model identifier (e.g. `"llama3.1:8b"`).
    pub id: String,
    /// Human-readable display name.
    pub name: String,
    /// Provider that hosts this model.
    pub provider: LlmProvider,
    /// Maximum context length, measured in tokens.
    pub context_window: u32,
}
/// An embedding model offered by an [`LlmProvider`].
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct EmbeddingEntry {
    /// Unique identifier of the embedding model.
    pub id: String,
    /// Human-readable display name.
    pub name: String,
    /// Provider that hosts this model.
    pub provider: LlmProvider,
    /// Number of dimensions in the embeddings this model outputs.
    pub dimensions: u32,
}
/// The currently active provider configuration.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ProviderConfig {
    /// Provider currently selected.
    pub provider: LlmProvider,
    /// ID of the active chat model (presumably a `ModelEntry::id` — confirm).
    pub selected_model: String,
    /// ID of the active embedding model (presumably an `EmbeddingEntry::id`).
    pub selected_embedding: String,
    /// Whether an API key has been configured; the key itself is not held by
    /// this struct.
    pub api_key_set: bool,
}
#[cfg(test)]
mod tests {
    use super::*;
    use pretty_assertions::assert_eq;

    /// Every `LlmProvider` variant, shared by the table-driven tests below.
    /// Extend this list whenever a variant is added.
    const ALL_PROVIDERS: [LlmProvider; 4] = [
        LlmProvider::Ollama,
        LlmProvider::HuggingFace,
        LlmProvider::OpenAi,
        LlmProvider::Anthropic,
    ];

    #[test]
    fn llm_provider_label_ollama() {
        assert_eq!(LlmProvider::Ollama.label(), "Ollama");
    }

    #[test]
    fn llm_provider_label_hugging_face() {
        assert_eq!(LlmProvider::HuggingFace.label(), "Hugging Face");
    }

    #[test]
    fn llm_provider_label_openai() {
        assert_eq!(LlmProvider::OpenAi.label(), "OpenAI");
    }

    #[test]
    fn llm_provider_label_anthropic() {
        assert_eq!(LlmProvider::Anthropic.label(), "Anthropic");
    }

    /// Two providers sharing a label would be indistinguishable in the UI.
    #[test]
    fn llm_provider_labels_are_distinct() {
        let labels: std::collections::HashSet<&str> =
            ALL_PROVIDERS.iter().map(LlmProvider::label).collect();
        assert_eq!(labels.len(), ALL_PROVIDERS.len());
    }

    #[test]
    fn llm_provider_serde_round_trip() {
        for variant in ALL_PROVIDERS {
            // Include the serde error in the panic so failures are diagnosable.
            let json = serde_json::to_string(&variant)
                .unwrap_or_else(|e| panic!("serialize {:?}: {}", variant, e));
            let back: LlmProvider = serde_json::from_str(&json)
                .unwrap_or_else(|e| panic!("deserialize {:?}: {}", variant, e));
            assert_eq!(variant, back);
        }
    }

    /// A round-trip alone cannot catch a variant rename, which would break
    /// previously stored or externally produced JSON; pin the exact strings.
    #[test]
    fn llm_provider_wire_format() {
        let cases = [
            (LlmProvider::Ollama, "\"Ollama\""),
            (LlmProvider::HuggingFace, "\"HuggingFace\""),
            (LlmProvider::OpenAi, "\"OpenAi\""),
            (LlmProvider::Anthropic, "\"Anthropic\""),
        ];
        for (variant, expected) in cases {
            let json = serde_json::to_string(&variant)
                .unwrap_or_else(|e| panic!("serialize {:?}: {}", variant, e));
            assert_eq!(json, expected);
        }
    }

    #[test]
    fn llm_provider_rejects_unknown_variant() {
        assert!(serde_json::from_str::<LlmProvider>("\"Gemini\"").is_err());
        // Variant names are matched case-sensitively.
        assert!(serde_json::from_str::<LlmProvider>("\"ollama\"").is_err());
    }

    #[test]
    fn model_entry_serde_round_trip() {
        let entry = ModelEntry {
            id: "llama3.1:8b".into(),
            name: "Llama 3.1 8B".into(),
            provider: LlmProvider::Ollama,
            context_window: 8192,
        };
        let json = serde_json::to_string(&entry).expect("serialize ModelEntry");
        let back: ModelEntry = serde_json::from_str(&json).expect("deserialize ModelEntry");
        assert_eq!(entry, back);
    }

    /// Pins the JSON field names so a struct field rename is caught.
    #[test]
    fn model_entry_json_shape() {
        let entry = ModelEntry {
            id: "llama3.1:8b".into(),
            name: "Llama 3.1 8B".into(),
            provider: LlmProvider::Ollama,
            context_window: 8192,
        };
        let value = serde_json::to_value(&entry).expect("serialize ModelEntry");
        assert_eq!(
            value,
            serde_json::json!({
                "id": "llama3.1:8b",
                "name": "Llama 3.1 8B",
                "provider": "Ollama",
                "context_window": 8192,
            })
        );
    }

    #[test]
    fn embedding_entry_serde_round_trip() {
        let entry = EmbeddingEntry {
            id: "nomic-embed".into(),
            name: "Nomic Embed".into(),
            provider: LlmProvider::HuggingFace,
            dimensions: 768,
        };
        let json = serde_json::to_string(&entry).expect("serialize EmbeddingEntry");
        let back: EmbeddingEntry =
            serde_json::from_str(&json).expect("deserialize EmbeddingEntry");
        assert_eq!(entry, back);
    }

    #[test]
    fn provider_config_serde_round_trip() {
        let cfg = ProviderConfig {
            provider: LlmProvider::Anthropic,
            selected_model: "claude-3".into(),
            selected_embedding: "embed-v1".into(),
            api_key_set: true,
        };
        let json = serde_json::to_string(&cfg).expect("serialize ProviderConfig");
        let back: ProviderConfig =
            serde_json::from_str(&json).expect("deserialize ProviderConfig");
        assert_eq!(cfg, back);
    }

    /// Pins the JSON field names so a struct field rename is caught.
    #[test]
    fn provider_config_json_shape() {
        let cfg = ProviderConfig {
            provider: LlmProvider::Anthropic,
            selected_model: "claude-3".into(),
            selected_embedding: "embed-v1".into(),
            api_key_set: true,
        };
        let value = serde_json::to_value(&cfg).expect("serialize ProviderConfig");
        assert_eq!(
            value,
            serde_json::json!({
                "provider": "Anthropic",
                "selected_model": "claude-3",
                "selected_embedding": "embed-v1",
                "api_key_set": true,
            })
        );
    }
}