refactor: modularize codebase and add 404 unit tests (#13)
All checks were successful
CI / Format (push) Successful in 4s
CI / Clippy (push) Successful in 4m19s
CI / Security Audit (push) Successful in 1m44s
CI / Detect Changes (push) Successful in 5s
CI / Tests (push) Successful in 5m15s
CI / Deploy Agent (push) Successful in 2s
CI / Deploy Dashboard (push) Successful in 2s
CI / Deploy Docs (push) Has been skipped
CI / Deploy MCP (push) Successful in 2s

This commit was merged in pull request #13.
This commit is contained in:
2026-03-13 08:03:45 +00:00
parent acc5b86aa4
commit 3bb690e5bb
89 changed files with 11884 additions and 6046 deletions

View File

@@ -1,147 +1,17 @@
use secrecy::{ExposeSecret, SecretString};
use serde::{Deserialize, Serialize};
use super::types::*;
use crate::error::AgentError;
#[derive(Clone)]
pub struct LlmClient {
base_url: String,
api_key: SecretString,
model: String,
embed_model: String,
http: reqwest::Client,
pub(crate) base_url: String,
pub(crate) api_key: SecretString,
pub(crate) model: String,
pub(crate) embed_model: String,
pub(crate) http: reqwest::Client,
}
// ── Request types ──────────────────────────────────────────────
#[derive(Serialize, Clone, Debug)]
pub struct ChatMessage {
pub role: String,
pub content: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub tool_calls: Option<Vec<ToolCallRequest>>,
#[serde(skip_serializing_if = "Option::is_none")]
pub tool_call_id: Option<String>,
}
#[derive(Serialize)]
struct ChatCompletionRequest {
model: String,
messages: Vec<ChatMessage>,
#[serde(skip_serializing_if = "Option::is_none")]
temperature: Option<f64>,
#[serde(skip_serializing_if = "Option::is_none")]
max_tokens: Option<u32>,
#[serde(skip_serializing_if = "Option::is_none")]
tools: Option<Vec<ToolDefinitionPayload>>,
}
#[derive(Serialize)]
struct ToolDefinitionPayload {
r#type: String,
function: ToolFunctionPayload,
}
#[derive(Serialize)]
struct ToolFunctionPayload {
name: String,
description: String,
parameters: serde_json::Value,
}
// ── Response types ─────────────────────────────────────────────
#[derive(Deserialize)]
struct ChatCompletionResponse {
choices: Vec<ChatChoice>,
}
#[derive(Deserialize)]
struct ChatChoice {
message: ChatResponseMessage,
}
#[derive(Deserialize)]
struct ChatResponseMessage {
#[serde(default)]
content: Option<String>,
#[serde(default)]
tool_calls: Option<Vec<ToolCallResponse>>,
}
#[derive(Deserialize)]
struct ToolCallResponse {
id: String,
function: ToolCallFunction,
}
#[derive(Deserialize)]
struct ToolCallFunction {
name: String,
arguments: String,
}
// ── Public types for tool calling ──────────────────────────────
/// Definition of a tool that the LLM can invoke
#[derive(Debug, Clone, Serialize)]
pub struct ToolDefinition {
pub name: String,
pub description: String,
pub parameters: serde_json::Value,
}
/// A tool call request from the LLM
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LlmToolCall {
pub id: String,
pub name: String,
pub arguments: serde_json::Value,
}
/// A tool call in the request message format (for sending back tool_calls in assistant messages)
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolCallRequest {
pub id: String,
pub r#type: String,
pub function: ToolCallRequestFunction,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolCallRequestFunction {
pub name: String,
pub arguments: String,
}
/// Response from the LLM — either content or tool calls
#[derive(Debug, Clone)]
pub enum LlmResponse {
Content(String),
/// Tool calls with optional reasoning text from the LLM
ToolCalls { calls: Vec<LlmToolCall>, reasoning: String },
}
// ── Embedding types ────────────────────────────────────────────
#[derive(Serialize)]
struct EmbeddingRequest {
model: String,
input: Vec<String>,
}
#[derive(Deserialize)]
struct EmbeddingResponse {
data: Vec<EmbeddingData>,
}
#[derive(Deserialize)]
struct EmbeddingData {
embedding: Vec<f64>,
index: usize,
}
// ── Implementation ─────────────────────────────────────────────
impl LlmClient {
pub fn new(
base_url: String,
@@ -158,18 +28,14 @@ impl LlmClient {
}
}
pub fn embed_model(&self) -> &str {
&self.embed_model
}
fn chat_url(&self) -> String {
pub(crate) fn chat_url(&self) -> String {
format!(
"{}/v1/chat/completions",
self.base_url.trim_end_matches('/')
)
}
fn auth_header(&self) -> Option<String> {
pub(crate) fn auth_header(&self) -> Option<String> {
let key = self.api_key.expose_secret();
if key.is_empty() {
None
@@ -241,12 +107,12 @@ impl LlmClient {
tools: None,
};
self.send_chat_request(&request_body).await.map(|resp| {
match resp {
self.send_chat_request(&request_body)
.await
.map(|resp| match resp {
LlmResponse::Content(c) => c,
LlmResponse::ToolCalls { .. } => String::new(),
}
})
})
}
/// Chat with tool definitions — returns either content or tool calls.
@@ -292,7 +158,7 @@ impl LlmClient {
) -> Result<LlmResponse, AgentError> {
let mut req = self
.http
.post(&self.chat_url())
.post(self.chat_url())
.header("content-type", "application/json")
.json(request_body);
@@ -345,54 +211,7 @@ impl LlmClient {
}
// Otherwise return content
let content = choice
.message
.content
.clone()
.unwrap_or_default();
let content = choice.message.content.clone().unwrap_or_default();
Ok(LlmResponse::Content(content))
}
/// Generate embeddings for a batch of texts
pub async fn embed(&self, texts: Vec<String>) -> Result<Vec<Vec<f64>>, AgentError> {
let url = format!("{}/v1/embeddings", self.base_url.trim_end_matches('/'));
let request_body = EmbeddingRequest {
model: self.embed_model.clone(),
input: texts,
};
let mut req = self
.http
.post(&url)
.header("content-type", "application/json")
.json(&request_body);
if let Some(auth) = self.auth_header() {
req = req.header("Authorization", auth);
}
let resp = req
.send()
.await
.map_err(|e| AgentError::Other(format!("Embedding request failed: {e}")))?;
if !resp.status().is_success() {
let status = resp.status();
let body = resp.text().await.unwrap_or_default();
return Err(AgentError::Other(format!(
"Embedding API returned {status}: {body}"
)));
}
let body: EmbeddingResponse = resp
.json()
.await
.map_err(|e| AgentError::Other(format!("Failed to parse embedding response: {e}")))?;
let mut data = body.data;
data.sort_by_key(|d| d.index);
Ok(data.into_iter().map(|d| d.embedding).collect())
}
}

View File

@@ -0,0 +1,74 @@
use serde::{Deserialize, Serialize};
use super::client::LlmClient;
use crate::error::AgentError;
// ── Embedding types ────────────────────────────────────────────
#[derive(Serialize)]
struct EmbeddingRequest {
model: String,
input: Vec<String>,
}
#[derive(Deserialize)]
struct EmbeddingResponse {
data: Vec<EmbeddingData>,
}
#[derive(Deserialize)]
struct EmbeddingData {
embedding: Vec<f64>,
index: usize,
}
// ── Embedding implementation ───────────────────────────────────
impl LlmClient {
pub fn embed_model(&self) -> &str {
&self.embed_model
}
/// Generate embeddings for a batch of texts
pub async fn embed(&self, texts: Vec<String>) -> Result<Vec<Vec<f64>>, AgentError> {
let url = format!("{}/v1/embeddings", self.base_url.trim_end_matches('/'));
let request_body = EmbeddingRequest {
model: self.embed_model.clone(),
input: texts,
};
let mut req = self
.http
.post(&url)
.header("content-type", "application/json")
.json(&request_body);
if let Some(auth) = self.auth_header() {
req = req.header("Authorization", auth);
}
let resp = req
.send()
.await
.map_err(|e| AgentError::Other(format!("Embedding request failed: {e}")))?;
if !resp.status().is_success() {
let status = resp.status();
let body = resp.text().await.unwrap_or_default();
return Err(AgentError::Other(format!(
"Embedding API returned {status}: {body}"
)));
}
let body: EmbeddingResponse = resp
.json()
.await
.map_err(|e| AgentError::Other(format!("Failed to parse embedding response: {e}")))?;
let mut data = body.data;
data.sort_by_key(|d| d.index);
Ok(data.into_iter().map(|d| d.embedding).collect())
}
}

View File

@@ -1,11 +1,16 @@
pub mod client;
#[allow(dead_code)]
pub mod descriptions;
pub mod embedding;
#[allow(dead_code)]
pub mod fixes;
#[allow(dead_code)]
pub mod pr_review;
pub mod review_prompts;
pub mod triage;
pub mod types;
pub use client::LlmClient;
pub use types::{
ChatMessage, LlmResponse, ToolCallRequest, ToolCallRequestFunction, ToolDefinition,
};

View File

@@ -278,3 +278,220 @@ struct TriageResult {
fn default_action() -> String {
"confirm".to_string()
}
#[cfg(test)]
mod tests {
use super::*;
use compliance_core::models::Severity;
// ── classify_file_path ───────────────────────────────────────
#[test]
fn classify_none_path() {
assert_eq!(classify_file_path(None), "unknown");
}
#[test]
fn classify_production_path() {
assert_eq!(classify_file_path(Some("src/main.rs")), "production");
assert_eq!(classify_file_path(Some("lib/core/engine.py")), "production");
}
#[test]
fn classify_test_paths() {
assert_eq!(classify_file_path(Some("src/test/helper.rs")), "test");
assert_eq!(classify_file_path(Some("src/tests/unit.rs")), "test");
assert_eq!(classify_file_path(Some("foo_test.go")), "test");
assert_eq!(classify_file_path(Some("bar.test.js")), "test");
assert_eq!(classify_file_path(Some("baz.spec.ts")), "test");
assert_eq!(
classify_file_path(Some("data/fixtures/sample.json")),
"test"
);
assert_eq!(classify_file_path(Some("src/testdata/input.txt")), "test");
}
#[test]
fn classify_example_paths() {
assert_eq!(
classify_file_path(Some("docs/examples/basic.rs")),
"example"
);
// /example matches because contains("/example")
assert_eq!(classify_file_path(Some("src/example/main.py")), "example");
assert_eq!(classify_file_path(Some("src/demo/run.sh")), "example");
assert_eq!(classify_file_path(Some("src/sample/lib.rs")), "example");
}
#[test]
fn classify_generated_paths() {
assert_eq!(
classify_file_path(Some("src/generated/api.rs")),
"generated"
);
assert_eq!(
classify_file_path(Some("proto/gen/service.go")),
"generated"
);
assert_eq!(classify_file_path(Some("api.generated.ts")), "generated");
assert_eq!(classify_file_path(Some("service.pb.go")), "generated");
assert_eq!(classify_file_path(Some("model_generated.rs")), "generated");
}
#[test]
fn classify_vendored_paths() {
// Implementation checks for /vendor/, /node_modules/, /third_party/ (with slashes)
assert_eq!(
classify_file_path(Some("src/vendor/lib/foo.go")),
"vendored"
);
assert_eq!(
classify_file_path(Some("src/node_modules/pkg/index.js")),
"vendored"
);
assert_eq!(
classify_file_path(Some("src/third_party/lib.c")),
"vendored"
);
}
#[test]
fn classify_is_case_insensitive() {
assert_eq!(classify_file_path(Some("src/TEST/Helper.rs")), "test");
assert_eq!(classify_file_path(Some("src/VENDOR/lib.go")), "vendored");
assert_eq!(
classify_file_path(Some("src/GENERATED/foo.ts")),
"generated"
);
}
// ── adjust_confidence ────────────────────────────────────────
#[test]
fn adjust_confidence_production() {
assert_eq!(adjust_confidence(8.0, "production"), 8.0);
}
#[test]
fn adjust_confidence_test() {
assert_eq!(adjust_confidence(10.0, "test"), 5.0);
}
#[test]
fn adjust_confidence_example() {
assert_eq!(adjust_confidence(10.0, "example"), 6.0);
}
#[test]
fn adjust_confidence_generated() {
assert_eq!(adjust_confidence(10.0, "generated"), 3.0);
}
#[test]
fn adjust_confidence_vendored() {
assert_eq!(adjust_confidence(10.0, "vendored"), 4.0);
}
#[test]
fn adjust_confidence_unknown_classification() {
assert_eq!(adjust_confidence(7.0, "unknown"), 7.0);
assert_eq!(adjust_confidence(7.0, "something_else"), 7.0);
}
#[test]
fn adjust_confidence_zero() {
assert_eq!(adjust_confidence(0.0, "test"), 0.0);
assert_eq!(adjust_confidence(0.0, "production"), 0.0);
}
// ── downgrade_severity ───────────────────────────────────────
#[test]
fn downgrade_severity_all_levels() {
assert_eq!(downgrade_severity(&Severity::Critical), Severity::High);
assert_eq!(downgrade_severity(&Severity::High), Severity::Medium);
assert_eq!(downgrade_severity(&Severity::Medium), Severity::Low);
assert_eq!(downgrade_severity(&Severity::Low), Severity::Info);
assert_eq!(downgrade_severity(&Severity::Info), Severity::Info);
}
#[test]
fn downgrade_severity_info_is_floor() {
// Downgrading Info twice should still be Info
let s = downgrade_severity(&Severity::Info);
assert_eq!(downgrade_severity(&s), Severity::Info);
}
// ── upgrade_severity ─────────────────────────────────────────
#[test]
fn upgrade_severity_all_levels() {
assert_eq!(upgrade_severity(&Severity::Info), Severity::Low);
assert_eq!(upgrade_severity(&Severity::Low), Severity::Medium);
assert_eq!(upgrade_severity(&Severity::Medium), Severity::High);
assert_eq!(upgrade_severity(&Severity::High), Severity::Critical);
assert_eq!(upgrade_severity(&Severity::Critical), Severity::Critical);
}
#[test]
fn upgrade_severity_critical_is_ceiling() {
let s = upgrade_severity(&Severity::Critical);
assert_eq!(upgrade_severity(&s), Severity::Critical);
}
// ── upgrade/downgrade roundtrip ──────────────────────────────
#[test]
fn upgrade_then_downgrade_is_identity_for_middle_values() {
for sev in [Severity::Low, Severity::Medium, Severity::High] {
assert_eq!(downgrade_severity(&upgrade_severity(&sev)), sev);
}
}
// ── TriageResult deserialization ─────────────────────────────
#[test]
fn triage_result_full() {
let json = r#"{"action":"dismiss","confidence":8.5,"rationale":"false positive","remediation":"remove code"}"#;
let r: TriageResult = serde_json::from_str(json).unwrap();
assert_eq!(r.action, "dismiss");
assert_eq!(r.confidence, 8.5);
assert_eq!(r.rationale, "false positive");
assert_eq!(r.remediation.as_deref(), Some("remove code"));
}
#[test]
fn triage_result_defaults() {
let json = r#"{}"#;
let r: TriageResult = serde_json::from_str(json).unwrap();
assert_eq!(r.action, "confirm");
assert_eq!(r.confidence, 0.0);
assert_eq!(r.rationale, "");
assert!(r.remediation.is_none());
}
#[test]
fn triage_result_partial() {
let json = r#"{"action":"downgrade","confidence":6.0}"#;
let r: TriageResult = serde_json::from_str(json).unwrap();
assert_eq!(r.action, "downgrade");
assert_eq!(r.confidence, 6.0);
assert_eq!(r.rationale, "");
assert!(r.remediation.is_none());
}
#[test]
fn triage_result_with_markdown_fences() {
// Simulate LLM wrapping response in markdown code fences
let raw = "```json\n{\"action\":\"upgrade\",\"confidence\":9,\"rationale\":\"critical\",\"remediation\":null}\n```";
let cleaned = raw
.trim()
.trim_start_matches("```json")
.trim_start_matches("```")
.trim_end_matches("```")
.trim();
let r: TriageResult = serde_json::from_str(cleaned).unwrap();
assert_eq!(r.action, "upgrade");
assert_eq!(r.confidence, 9.0);
}
}

View File

@@ -0,0 +1,369 @@
use serde::{Deserialize, Serialize};
// ── Request types ──────────────────────────────────────────────
#[derive(Serialize, Clone, Debug)]
pub struct ChatMessage {
pub role: String,
pub content: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub tool_calls: Option<Vec<ToolCallRequest>>,
#[serde(skip_serializing_if = "Option::is_none")]
pub tool_call_id: Option<String>,
}
#[derive(Serialize)]
pub(crate) struct ChatCompletionRequest {
pub(crate) model: String,
pub(crate) messages: Vec<ChatMessage>,
#[serde(skip_serializing_if = "Option::is_none")]
pub(crate) temperature: Option<f64>,
#[serde(skip_serializing_if = "Option::is_none")]
pub(crate) max_tokens: Option<u32>,
#[serde(skip_serializing_if = "Option::is_none")]
pub(crate) tools: Option<Vec<ToolDefinitionPayload>>,
}
#[derive(Serialize)]
pub(crate) struct ToolDefinitionPayload {
pub(crate) r#type: String,
pub(crate) function: ToolFunctionPayload,
}
#[derive(Serialize)]
pub(crate) struct ToolFunctionPayload {
pub(crate) name: String,
pub(crate) description: String,
pub(crate) parameters: serde_json::Value,
}
// ── Response types ─────────────────────────────────────────────
#[derive(Deserialize)]
pub(crate) struct ChatCompletionResponse {
pub(crate) choices: Vec<ChatChoice>,
}
#[derive(Deserialize)]
pub(crate) struct ChatChoice {
pub(crate) message: ChatResponseMessage,
}
#[derive(Deserialize)]
pub(crate) struct ChatResponseMessage {
#[serde(default)]
pub(crate) content: Option<String>,
#[serde(default)]
pub(crate) tool_calls: Option<Vec<ToolCallResponse>>,
}
#[derive(Deserialize)]
pub(crate) struct ToolCallResponse {
pub(crate) id: String,
pub(crate) function: ToolCallFunction,
}
#[derive(Deserialize)]
pub(crate) struct ToolCallFunction {
pub(crate) name: String,
pub(crate) arguments: String,
}
// ── Public types for tool calling ──────────────────────────────
/// Definition of a tool that the LLM can invoke
#[derive(Debug, Clone, Serialize)]
pub struct ToolDefinition {
pub name: String,
pub description: String,
pub parameters: serde_json::Value,
}
/// A tool call request from the LLM
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LlmToolCall {
pub id: String,
pub name: String,
pub arguments: serde_json::Value,
}
/// A tool call in the request message format (for sending back tool_calls in assistant messages)
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolCallRequest {
pub id: String,
pub r#type: String,
pub function: ToolCallRequestFunction,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolCallRequestFunction {
pub name: String,
pub arguments: String,
}
/// Response from the LLM — either content or tool calls
#[derive(Debug, Clone)]
pub enum LlmResponse {
Content(String),
/// Tool calls with optional reasoning text from the LLM
ToolCalls {
calls: Vec<LlmToolCall>,
reasoning: String,
},
}
#[cfg(test)]
mod tests {
use super::*;
use serde_json::json;
// ── ChatMessage ──────────────────────────────────────────────
#[test]
fn chat_message_serializes_minimal() {
let msg = ChatMessage {
role: "user".to_string(),
content: Some("hello".to_string()),
tool_calls: None,
tool_call_id: None,
};
let v = serde_json::to_value(&msg).unwrap();
assert_eq!(v["role"], "user");
assert_eq!(v["content"], "hello");
// None fields with skip_serializing_if should be absent
assert!(v.get("tool_calls").is_none());
assert!(v.get("tool_call_id").is_none());
}
#[test]
fn chat_message_serializes_with_tool_calls() {
let msg = ChatMessage {
role: "assistant".to_string(),
content: None,
tool_calls: Some(vec![ToolCallRequest {
id: "call_1".to_string(),
r#type: "function".to_string(),
function: ToolCallRequestFunction {
name: "get_weather".to_string(),
arguments: r#"{"city":"NYC"}"#.to_string(),
},
}]),
tool_call_id: None,
};
let v = serde_json::to_value(&msg).unwrap();
assert!(v["tool_calls"].is_array());
assert_eq!(v["tool_calls"][0]["function"]["name"], "get_weather");
}
#[test]
fn chat_message_content_null_when_none() {
let msg = ChatMessage {
role: "assistant".to_string(),
content: None,
tool_calls: None,
tool_call_id: None,
};
let v = serde_json::to_value(&msg).unwrap();
assert!(v["content"].is_null());
}
// ── ToolDefinition ───────────────────────────────────────────
#[test]
fn tool_definition_serializes() {
let td = ToolDefinition {
name: "search".to_string(),
description: "Search the web".to_string(),
parameters: json!({"type": "object", "properties": {"q": {"type": "string"}}}),
};
let v = serde_json::to_value(&td).unwrap();
assert_eq!(v["name"], "search");
assert_eq!(v["parameters"]["type"], "object");
}
#[test]
fn tool_definition_empty_parameters() {
let td = ToolDefinition {
name: "noop".to_string(),
description: "".to_string(),
parameters: json!({}),
};
let v = serde_json::to_value(&td).unwrap();
assert_eq!(v["parameters"], json!({}));
}
// ── LlmToolCall ──────────────────────────────────────────────
#[test]
fn llm_tool_call_roundtrip() {
let call = LlmToolCall {
id: "tc_42".to_string(),
name: "run_scan".to_string(),
arguments: json!({"path": "/tmp", "verbose": true}),
};
let serialized = serde_json::to_string(&call).unwrap();
let deserialized: LlmToolCall = serde_json::from_str(&serialized).unwrap();
assert_eq!(deserialized.id, "tc_42");
assert_eq!(deserialized.name, "run_scan");
assert_eq!(deserialized.arguments["path"], "/tmp");
assert_eq!(deserialized.arguments["verbose"], true);
}
#[test]
fn llm_tool_call_empty_arguments() {
let call = LlmToolCall {
id: "tc_0".to_string(),
name: "noop".to_string(),
arguments: json!({}),
};
let rt: LlmToolCall = serde_json::from_str(&serde_json::to_string(&call).unwrap()).unwrap();
assert!(rt.arguments.as_object().unwrap().is_empty());
}
// ── ToolCallRequest / ToolCallRequestFunction ────────────────
#[test]
fn tool_call_request_roundtrip() {
let req = ToolCallRequest {
id: "call_abc".to_string(),
r#type: "function".to_string(),
function: ToolCallRequestFunction {
name: "my_func".to_string(),
arguments: r#"{"x":1}"#.to_string(),
},
};
let json_str = serde_json::to_string(&req).unwrap();
let back: ToolCallRequest = serde_json::from_str(&json_str).unwrap();
assert_eq!(back.id, "call_abc");
assert_eq!(back.r#type, "function");
assert_eq!(back.function.name, "my_func");
assert_eq!(back.function.arguments, r#"{"x":1}"#);
}
#[test]
fn tool_call_request_type_field_serializes_as_type() {
let req = ToolCallRequest {
id: "id".to_string(),
r#type: "function".to_string(),
function: ToolCallRequestFunction {
name: "f".to_string(),
arguments: "{}".to_string(),
},
};
let v = serde_json::to_value(&req).unwrap();
// The field should be "type" in JSON, not "r#type"
assert!(v.get("type").is_some());
assert!(v.get("r#type").is_none());
}
// ── ChatCompletionRequest ────────────────────────────────────
#[test]
fn chat_completion_request_skips_none_fields() {
let req = ChatCompletionRequest {
model: "gpt-4".to_string(),
messages: vec![],
temperature: None,
max_tokens: None,
tools: None,
};
let v = serde_json::to_value(&req).unwrap();
assert_eq!(v["model"], "gpt-4");
assert!(v.get("temperature").is_none());
assert!(v.get("max_tokens").is_none());
assert!(v.get("tools").is_none());
}
#[test]
fn chat_completion_request_includes_set_fields() {
let req = ChatCompletionRequest {
model: "gpt-4".to_string(),
messages: vec![],
temperature: Some(0.7),
max_tokens: Some(1024),
tools: Some(vec![]),
};
let v = serde_json::to_value(&req).unwrap();
assert_eq!(v["temperature"], 0.7);
assert_eq!(v["max_tokens"], 1024);
assert!(v["tools"].is_array());
}
// ── ChatCompletionResponse deserialization ───────────────────
#[test]
fn chat_completion_response_deserializes_content() {
let json_str = r#"{"choices":[{"message":{"content":"Hello!"}}]}"#;
let resp: ChatCompletionResponse = serde_json::from_str(json_str).unwrap();
assert_eq!(resp.choices.len(), 1);
assert_eq!(resp.choices[0].message.content.as_deref(), Some("Hello!"));
assert!(resp.choices[0].message.tool_calls.is_none());
}
#[test]
fn chat_completion_response_deserializes_tool_calls() {
let json_str = r#"{
"choices": [{
"message": {
"tool_calls": [{
"id": "call_1",
"function": {"name": "search", "arguments": "{\"q\":\"rust\"}"}
}]
}
}]
}"#;
let resp: ChatCompletionResponse = serde_json::from_str(json_str).unwrap();
let tc = resp.choices[0].message.tool_calls.as_ref().unwrap();
assert_eq!(tc.len(), 1);
assert_eq!(tc[0].id, "call_1");
assert_eq!(tc[0].function.name, "search");
}
#[test]
fn chat_completion_response_defaults_missing_fields() {
// content and tool_calls are both missing — should default to None
let json_str = r#"{"choices":[{"message":{}}]}"#;
let resp: ChatCompletionResponse = serde_json::from_str(json_str).unwrap();
assert!(resp.choices[0].message.content.is_none());
assert!(resp.choices[0].message.tool_calls.is_none());
}
// ── LlmResponse ─────────────────────────────────────────────
#[test]
fn llm_response_content_variant() {
let resp = LlmResponse::Content("answer".to_string());
match resp {
LlmResponse::Content(s) => assert_eq!(s, "answer"),
_ => panic!("expected Content variant"),
}
}
#[test]
fn llm_response_tool_calls_variant() {
let resp = LlmResponse::ToolCalls {
calls: vec![LlmToolCall {
id: "1".to_string(),
name: "f".to_string(),
arguments: json!({}),
}],
reasoning: "because".to_string(),
};
match resp {
LlmResponse::ToolCalls { calls, reasoning } => {
assert_eq!(calls.len(), 1);
assert_eq!(reasoning, "because");
}
_ => panic!("expected ToolCalls variant"),
}
}
#[test]
fn llm_response_empty_content() {
let resp = LlmResponse::Content(String::new());
match resp {
LlmResponse::Content(s) => assert!(s.is_empty()),
_ => panic!("expected Content variant"),
}
}
}