Files
certifai/src/infrastructure/litellm.rs
Sharang Parnerkar fe4f8e84ae
Some checks failed
CI / Format (push) Successful in 3s
CI / Clippy (push) Successful in 2m53s
CI / Security Audit (push) Successful in 1m42s
CI / Tests (push) Failing after 3m59s
CI / Deploy (push) Has been skipped
CI / E2E Tests (push) Has been skipped
feat: replaced ollama with litellm (#18)
Co-authored-by: Sharang Parnerkar <parnerkarsharang@gmail.com>
Reviewed-on: #18
2026-02-26 17:52:47 +00:00

404 lines
12 KiB
Rust

#[cfg(feature = "server")]
use std::collections::HashMap;
use dioxus::prelude::*;
use serde::{Deserialize, Serialize};
use crate::models::LitellmUsageStats;
#[cfg(feature = "server")]
use crate::models::ModelUsage;
/// Reachability snapshot of a LiteLLM proxy together with the models it lists.
///
/// Produced by [`get_litellm_status`]; serializable so it can cross the
/// server-function boundary.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct LitellmStatus {
    /// `true` when the proxy answered the models request with a success status.
    pub online: bool,
    /// Model IDs reported by the proxy; empty when offline or when the
    /// response body could not be parsed.
    pub models: Vec<String>,
}
/// Body returned by LiteLLM's OpenAI-compatible `GET /v1/models` endpoint.
///
/// Only the `data` array is deserialized.
#[cfg(feature = "server")]
#[derive(Deserialize)]
struct ModelsResponse {
    /// One entry per model exposed by the proxy.
    data: Vec<ModelObject>,
}
/// One element of the `data` array in a [`ModelsResponse`].
///
/// Only the identifier is deserialized.
#[cfg(feature = "server")]
#[derive(Deserialize)]
struct ModelObject {
    /// Model identifier as reported by the proxy.
    id: String,
}
/// Check the status of a LiteLLM proxy by querying its models endpoint.
///
/// Calls `GET <litellm_url>/v1/models` to list available models and determine
/// whether the instance is reachable.
///
/// The configured API key is sent as a Bearer token **only** when the request
/// targets the server-configured proxy URL. `litellm_url` is supplied by the
/// caller, so attaching the secret to an arbitrary override URL would let any
/// client exfiltrate it to a host of their choosing.
///
/// # Arguments
///
/// * `litellm_url` - Base URL of the LiteLLM proxy (e.g. "http://localhost:4000").
///   When empty, the server-configured URL is used.
///
/// # Returns
///
/// A `LitellmStatus` with `online: true` and model IDs if reachable,
/// `online: true` with an empty model list if the response body cannot be
/// parsed, or `online: false` with an empty model list on network failure
/// or a non-success HTTP status.
///
/// # Errors
///
/// Returns `ServerFnError` if the server state cannot be extracted or the
/// HTTP client cannot be constructed; network failures are caught and
/// returned as `online: false`.
#[post("/api/litellm-status")]
pub async fn get_litellm_status(litellm_url: String) -> Result<LitellmStatus, ServerFnError> {
    let state: crate::infrastructure::ServerState =
        dioxus_fullstack::FullstackContext::extract().await?;

    // Resolve the target and remember whether it is the configured proxy;
    // only that target is trusted with the API key.
    let configured = state.services.litellm_url.trim_end_matches('/');
    let (base, targets_configured_proxy) = if litellm_url.is_empty() {
        (configured, true)
    } else {
        let requested = litellm_url.trim_end_matches('/');
        (requested, requested == configured)
    };
    let url = format!("{base}/v1/models");

    let client = reqwest::Client::builder()
        .timeout(std::time::Duration::from_secs(5))
        .build()
        .map_err(|e| ServerFnError::new(format!("HTTP client error: {e}")))?;

    let mut request = client.get(&url);
    let api_key = &state.services.litellm_api_key;
    if targets_configured_proxy && !api_key.is_empty() {
        request = request.header("Authorization", format!("Bearer {api_key}"));
    }

    // Any transport error or non-2xx status means "offline".
    let resp = match request.send().await {
        Ok(r) if r.status().is_success() => r,
        _ => {
            return Ok(LitellmStatus {
                online: false,
                models: Vec::new(),
            });
        }
    };

    // The proxy answered, so it is online even if the body is not the
    // expected JSON shape; in that case the model list is left empty.
    let models = resp
        .json::<ModelsResponse>()
        .await
        .map(|body| body.data.into_iter().map(|m| m.id).collect())
        .unwrap_or_default();

    Ok(LitellmStatus {
        online: true,
        models,
    })
}
/// Body returned by LiteLLM's `GET /global/activity` endpoint.
///
/// Only the aggregate token total is deserialized here; any other fields in
/// the response are ignored.
#[cfg(feature = "server")]
#[derive(Debug, Deserialize)]
struct ActivityResponse {
    /// Token total for the requested date range across all models.
    /// Falls back to `0` when the field is absent.
    #[serde(default)]
    sum_total_tokens: u64,
}
/// One element of the array returned by `GET /global/activity/model`.
///
/// Pairs a model name with its aggregated token total. Both fields fall back
/// to their default value when missing from the payload.
#[cfg(feature = "server")]
#[derive(Debug, Deserialize)]
struct ActivityModelEntry {
    /// Model identifier; may be empty (such entries are skipped when merging).
    #[serde(default)]
    model: String,
    /// Tokens consumed by this model over the requested date range.
    #[serde(default)]
    sum_total_tokens: u64,
}
/// One element of the array returned by `GET /global/spend/models`.
///
/// Pairs a model name with its total spend in USD. Both fields fall back to
/// their default value when missing from the payload.
#[cfg(feature = "server")]
#[derive(Debug, Deserialize)]
struct SpendModelEntry {
    /// Model identifier; may be empty (such entries are skipped when merging).
    #[serde(default)]
    model: String,
    /// Total spend attributed to this model, in USD.
    #[serde(default)]
    total_spend: f64,
}
/// Merge per-model token counts and spend data into `ModelUsage` entries.
///
/// Joins `activity_models` (tokens) and `spend_models` (spend) by model
/// name using a HashMap for O(n + m) merge. Entries with empty model
/// names are skipped. If the same model name appears more than once in an
/// input list, the last entry's value wins (values are assigned, not summed).
///
/// # Arguments
///
/// * `activity_models` - Per-model token data from `/global/activity/model`
/// * `spend_models` - Per-model spend data from `/global/spend/models`
///
/// # Returns
///
/// Merged list sorted by total tokens descending; models with equal token
/// counts are ordered by model name ascending so the output is deterministic
/// regardless of `HashMap` iteration order.
#[cfg(feature = "server")]
fn merge_model_data(
    activity_models: Vec<ActivityModelEntry>,
    spend_models: Vec<SpendModelEntry>,
) -> Vec<ModelUsage> {
    let mut model_map: HashMap<String, ModelUsage> = HashMap::new();

    for entry in activity_models {
        if entry.model.is_empty() {
            continue;
        }
        model_map
            .entry(entry.model.clone())
            .or_insert_with(|| ModelUsage {
                model: entry.model,
                ..Default::default()
            })
            .total_tokens = entry.sum_total_tokens;
    }

    for entry in spend_models {
        if entry.model.is_empty() {
            continue;
        }
        model_map
            .entry(entry.model.clone())
            .or_insert_with(|| ModelUsage {
                model: entry.model,
                ..Default::default()
            })
            .spend = entry.total_spend;
    }

    let mut result: Vec<ModelUsage> = model_map.into_values().collect();
    // Model names are unique map keys, so the name tie-break yields a total
    // order and a stable result across runs.
    result.sort_by(|a, b| {
        b.total_tokens
            .cmp(&a.total_tokens)
            .then_with(|| a.model.cmp(&b.model))
    });
    result
}
/// Fetch aggregated usage statistics from LiteLLM's free-tier APIs.
///
/// Combines three endpoints to build a complete usage picture:
/// - `GET /global/activity` - total token counts
/// - `GET /global/activity/model` - per-model token breakdown
/// - `GET /global/spend/models` - per-model spend in USD
///
/// Both dates come from the caller and are placed in the query string of an
/// authenticated request, so they are validated against the `YYYY-MM-DD`
/// shape first; this prevents extra query parameters from being injected.
///
/// # Arguments
///
/// * `start_date` - Start of the reporting period in `YYYY-MM-DD` format
/// * `end_date` - End of the reporting period in `YYYY-MM-DD` format
///
/// # Returns
///
/// Aggregated usage stats; returns default (zeroed) stats when no proxy URL
/// is configured, and zeroed components for any endpoint that fails or
/// returns a non-success status or unparsable body
///
/// # Errors
///
/// Returns `ServerFnError` if the server state cannot be extracted, if
/// either date is not in `YYYY-MM-DD` form, or if the HTTP client cannot be
/// constructed
#[post("/api/litellm-usage")]
pub async fn get_litellm_usage(
    start_date: String,
    end_date: String,
) -> Result<LitellmUsageStats, ServerFnError> {
    // Shape check only (`NNNN-NN-NN`, ASCII digits and dashes); calendar
    // validity is left to the LiteLLM API.
    fn is_iso_date(value: &str) -> bool {
        let bytes = value.as_bytes();
        bytes.len() == 10
            && bytes.iter().enumerate().all(|(idx, byte)| match idx {
                4 | 7 => *byte == b'-',
                _ => byte.is_ascii_digit(),
            })
    }

    let state: crate::infrastructure::ServerState =
        dioxus_fullstack::FullstackContext::extract().await?;
    let base_url = &state.services.litellm_url;
    let api_key = &state.services.litellm_api_key;

    if base_url.is_empty() {
        return Ok(LitellmUsageStats::default());
    }

    if !is_iso_date(&start_date) || !is_iso_date(&end_date) {
        return Err(ServerFnError::new(format!(
            "Invalid date range: expected YYYY-MM-DD, got start_date={start_date:?}, end_date={end_date:?}"
        )));
    }

    let base = base_url.trim_end_matches('/');
    // Safe to interpolate: both values contain only digits and dashes.
    let date_params = format!("start_date={start_date}&end_date={end_date}");

    let client = reqwest::Client::builder()
        .timeout(std::time::Duration::from_secs(10))
        .build()
        .map_err(|e| ServerFnError::new(format!("HTTP client error: {e}")))?;

    // Helper closure to build an authenticated GET request
    let auth_get = |url: String| {
        let mut req = client.get(url);
        if !api_key.is_empty() {
            req = req.header("Authorization", format!("Bearer {api_key}"));
        }
        req
    };

    // Fire all three requests concurrently to minimise latency
    let (activity_res, model_activity_res, model_spend_res) = tokio::join!(
        auth_get(format!("{base}/global/activity?{date_params}")).send(),
        auth_get(format!("{base}/global/activity/model?{date_params}")).send(),
        auth_get(format!("{base}/global/spend/models?{date_params}")).send(),
    );

    // Parse total token count from /global/activity
    let total_tokens = match activity_res {
        Ok(r) if r.status().is_success() => r
            .json::<ActivityResponse>()
            .await
            .map(|a| a.sum_total_tokens)
            .unwrap_or(0),
        _ => 0,
    };

    // Parse per-model token breakdown from /global/activity/model
    let activity_models: Vec<ActivityModelEntry> = match model_activity_res {
        Ok(r) if r.status().is_success() => r.json().await.unwrap_or_default(),
        _ => Vec::new(),
    };

    // Parse per-model spend from /global/spend/models
    let spend_models: Vec<SpendModelEntry> = match model_spend_res {
        Ok(r) if r.status().is_success() => r.json().await.unwrap_or_default(),
        _ => Vec::new(),
    };

    // Sum before merging: the merge consumes `spend_models`.
    let total_spend: f64 = spend_models.iter().map(|m| m.total_spend).sum();
    let model_breakdown = merge_model_data(activity_models, spend_models);

    Ok(LitellmUsageStats {
        total_spend,
        // Free-tier endpoints don't provide prompt/completion split;
        // total_tokens comes from /global/activity.
        total_prompt_tokens: 0,
        total_completion_tokens: 0,
        total_tokens,
        model_breakdown,
    })
}
// Unit tests for `merge_model_data`.
//
// The function under test and its input types are compiled only with the
// `server` feature, so this module is gated on both `test` and that feature;
// with `cfg(test)` alone, `cargo test` without `--features server` fails to
// compile on unresolved names.
#[cfg(all(test, feature = "server"))]
mod tests {
    use super::*;

    /// Two empty inputs yield an empty result.
    #[test]
    fn merge_empty_inputs() {
        let result = merge_model_data(Vec::new(), Vec::new());
        assert!(result.is_empty());
    }

    /// A model present only in activity data gets its tokens and zero spend.
    #[test]
    fn merge_activity_only() {
        let activity = vec![ActivityModelEntry {
            model: "gpt-4".into(),
            sum_total_tokens: 1500,
        }];
        let result = merge_model_data(activity, Vec::new());
        assert_eq!(result.len(), 1);
        assert_eq!(result[0].model, "gpt-4");
        assert_eq!(result[0].total_tokens, 1500);
        assert_eq!(result[0].spend, 0.0);
    }

    /// A model present only in spend data gets its spend and zero tokens.
    #[test]
    fn merge_spend_only() {
        let spend = vec![SpendModelEntry {
            model: "gpt-4".into(),
            total_spend: 2.5,
        }];
        let result = merge_model_data(Vec::new(), spend);
        assert_eq!(result.len(), 1);
        assert_eq!(result[0].model, "gpt-4");
        assert_eq!(result[0].spend, 2.5);
        assert_eq!(result[0].total_tokens, 0);
    }

    /// Tokens and spend for the same model name end up in one entry.
    #[test]
    fn merge_joins_by_model_name() {
        let activity = vec![
            ActivityModelEntry {
                model: "gpt-4".into(),
                sum_total_tokens: 5000,
            },
            ActivityModelEntry {
                model: "claude-3".into(),
                sum_total_tokens: 3000,
            },
        ];
        let spend = vec![
            SpendModelEntry {
                model: "gpt-4".into(),
                total_spend: 1.0,
            },
            SpendModelEntry {
                model: "claude-3".into(),
                total_spend: 0.5,
            },
        ];
        let result = merge_model_data(activity, spend);
        assert_eq!(result.len(), 2);
        // Sorted by tokens descending: gpt-4 (5000) before claude-3 (3000)
        assert_eq!(result[0].model, "gpt-4");
        assert_eq!(result[0].total_tokens, 5000);
        assert_eq!(result[0].spend, 1.0);
        assert_eq!(result[1].model, "claude-3");
        assert_eq!(result[1].total_tokens, 3000);
        assert_eq!(result[1].spend, 0.5);
    }

    /// Entries with an empty model name are dropped from both inputs.
    #[test]
    fn merge_skips_empty_model_names() {
        let activity = vec![
            ActivityModelEntry {
                model: "".into(),
                sum_total_tokens: 100,
            },
            ActivityModelEntry {
                model: "gpt-4".into(),
                sum_total_tokens: 500,
            },
        ];
        let spend = vec![SpendModelEntry {
            model: "".into(),
            total_spend: 0.01,
        }];
        let result = merge_model_data(activity, spend);
        assert_eq!(result.len(), 1);
        assert_eq!(result[0].model, "gpt-4");
    }

    /// Models that appear in only one input still appear in the output.
    #[test]
    fn merge_unmatched_models_appear_in_both_directions() {
        let activity = vec![ActivityModelEntry {
            model: "tokens-only".into(),
            sum_total_tokens: 1000,
        }];
        let spend = vec![SpendModelEntry {
            model: "spend-only".into(),
            total_spend: 0.5,
        }];
        let result = merge_model_data(activity, spend);
        assert_eq!(result.len(), 2);
        // tokens-only has 1000 tokens, spend-only has 0 tokens
        assert_eq!(result[0].model, "tokens-only");
        assert_eq!(result[0].total_tokens, 1000);
        assert_eq!(result[1].model, "spend-only");
        assert_eq!(result[1].spend, 0.5);
    }
}