refactor: modularize codebase and add 404 unit tests (#13)
All checks were successful
CI / Format (push) Successful in 4s
CI / Clippy (push) Successful in 4m19s
CI / Security Audit (push) Successful in 1m44s
CI / Detect Changes (push) Successful in 5s
CI / Tests (push) Successful in 5m15s
CI / Deploy Agent (push) Successful in 2s
CI / Deploy Dashboard (push) Successful in 2s
CI / Deploy Docs (push) Has been skipped
CI / Deploy MCP (push) Successful in 2s
All checks were successful
CI / Format (push) Successful in 4s
CI / Clippy (push) Successful in 4m19s
CI / Security Audit (push) Successful in 1m44s
CI / Detect Changes (push) Successful in 5s
CI / Tests (push) Successful in 5m15s
CI / Deploy Agent (push) Successful in 2s
CI / Deploy Dashboard (push) Successful in 2s
CI / Deploy Docs (push) Has been skipped
CI / Deploy MCP (push) Successful in 2s
This commit was merged in pull request #13.
This commit is contained in:
@@ -1,147 +1,17 @@
|
||||
use secrecy::{ExposeSecret, SecretString};
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use super::types::*;
|
||||
use crate::error::AgentError;
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct LlmClient {
|
||||
base_url: String,
|
||||
api_key: SecretString,
|
||||
model: String,
|
||||
embed_model: String,
|
||||
http: reqwest::Client,
|
||||
pub(crate) base_url: String,
|
||||
pub(crate) api_key: SecretString,
|
||||
pub(crate) model: String,
|
||||
pub(crate) embed_model: String,
|
||||
pub(crate) http: reqwest::Client,
|
||||
}
|
||||
|
||||
// ── Request types ──────────────────────────────────────────────
|
||||
|
||||
#[derive(Serialize, Clone, Debug)]
|
||||
pub struct ChatMessage {
|
||||
pub role: String,
|
||||
pub content: Option<String>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub tool_calls: Option<Vec<ToolCallRequest>>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub tool_call_id: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Serialize)]
|
||||
struct ChatCompletionRequest {
|
||||
model: String,
|
||||
messages: Vec<ChatMessage>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
temperature: Option<f64>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
max_tokens: Option<u32>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
tools: Option<Vec<ToolDefinitionPayload>>,
|
||||
}
|
||||
|
||||
#[derive(Serialize)]
|
||||
struct ToolDefinitionPayload {
|
||||
r#type: String,
|
||||
function: ToolFunctionPayload,
|
||||
}
|
||||
|
||||
#[derive(Serialize)]
|
||||
struct ToolFunctionPayload {
|
||||
name: String,
|
||||
description: String,
|
||||
parameters: serde_json::Value,
|
||||
}
|
||||
|
||||
// ── Response types ─────────────────────────────────────────────
|
||||
|
||||
#[derive(Deserialize)]
|
||||
struct ChatCompletionResponse {
|
||||
choices: Vec<ChatChoice>,
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
struct ChatChoice {
|
||||
message: ChatResponseMessage,
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
struct ChatResponseMessage {
|
||||
#[serde(default)]
|
||||
content: Option<String>,
|
||||
#[serde(default)]
|
||||
tool_calls: Option<Vec<ToolCallResponse>>,
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
struct ToolCallResponse {
|
||||
id: String,
|
||||
function: ToolCallFunction,
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
struct ToolCallFunction {
|
||||
name: String,
|
||||
arguments: String,
|
||||
}
|
||||
|
||||
// ── Public types for tool calling ──────────────────────────────
|
||||
|
||||
/// Definition of a tool that the LLM can invoke
|
||||
#[derive(Debug, Clone, Serialize)]
|
||||
pub struct ToolDefinition {
|
||||
pub name: String,
|
||||
pub description: String,
|
||||
pub parameters: serde_json::Value,
|
||||
}
|
||||
|
||||
/// A tool call request from the LLM
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct LlmToolCall {
|
||||
pub id: String,
|
||||
pub name: String,
|
||||
pub arguments: serde_json::Value,
|
||||
}
|
||||
|
||||
/// A tool call in the request message format (for sending back tool_calls in assistant messages)
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ToolCallRequest {
|
||||
pub id: String,
|
||||
pub r#type: String,
|
||||
pub function: ToolCallRequestFunction,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ToolCallRequestFunction {
|
||||
pub name: String,
|
||||
pub arguments: String,
|
||||
}
|
||||
|
||||
/// Response from the LLM — either content or tool calls
|
||||
#[derive(Debug, Clone)]
|
||||
pub enum LlmResponse {
|
||||
Content(String),
|
||||
/// Tool calls with optional reasoning text from the LLM
|
||||
ToolCalls { calls: Vec<LlmToolCall>, reasoning: String },
|
||||
}
|
||||
|
||||
// ── Embedding types ────────────────────────────────────────────
|
||||
|
||||
#[derive(Serialize)]
|
||||
struct EmbeddingRequest {
|
||||
model: String,
|
||||
input: Vec<String>,
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
struct EmbeddingResponse {
|
||||
data: Vec<EmbeddingData>,
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
struct EmbeddingData {
|
||||
embedding: Vec<f64>,
|
||||
index: usize,
|
||||
}
|
||||
|
||||
// ── Implementation ─────────────────────────────────────────────
|
||||
|
||||
impl LlmClient {
|
||||
pub fn new(
|
||||
base_url: String,
|
||||
@@ -158,18 +28,14 @@ impl LlmClient {
|
||||
}
|
||||
}
|
||||
|
||||
pub fn embed_model(&self) -> &str {
|
||||
&self.embed_model
|
||||
}
|
||||
|
||||
fn chat_url(&self) -> String {
|
||||
pub(crate) fn chat_url(&self) -> String {
|
||||
format!(
|
||||
"{}/v1/chat/completions",
|
||||
self.base_url.trim_end_matches('/')
|
||||
)
|
||||
}
|
||||
|
||||
fn auth_header(&self) -> Option<String> {
|
||||
pub(crate) fn auth_header(&self) -> Option<String> {
|
||||
let key = self.api_key.expose_secret();
|
||||
if key.is_empty() {
|
||||
None
|
||||
@@ -241,12 +107,12 @@ impl LlmClient {
|
||||
tools: None,
|
||||
};
|
||||
|
||||
self.send_chat_request(&request_body).await.map(|resp| {
|
||||
match resp {
|
||||
self.send_chat_request(&request_body)
|
||||
.await
|
||||
.map(|resp| match resp {
|
||||
LlmResponse::Content(c) => c,
|
||||
LlmResponse::ToolCalls { .. } => String::new(),
|
||||
}
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
/// Chat with tool definitions — returns either content or tool calls.
|
||||
@@ -292,7 +158,7 @@ impl LlmClient {
|
||||
) -> Result<LlmResponse, AgentError> {
|
||||
let mut req = self
|
||||
.http
|
||||
.post(&self.chat_url())
|
||||
.post(self.chat_url())
|
||||
.header("content-type", "application/json")
|
||||
.json(request_body);
|
||||
|
||||
@@ -345,54 +211,7 @@ impl LlmClient {
|
||||
}
|
||||
|
||||
// Otherwise return content
|
||||
let content = choice
|
||||
.message
|
||||
.content
|
||||
.clone()
|
||||
.unwrap_or_default();
|
||||
let content = choice.message.content.clone().unwrap_or_default();
|
||||
Ok(LlmResponse::Content(content))
|
||||
}
|
||||
|
||||
/// Generate embeddings for a batch of texts
|
||||
pub async fn embed(&self, texts: Vec<String>) -> Result<Vec<Vec<f64>>, AgentError> {
|
||||
let url = format!("{}/v1/embeddings", self.base_url.trim_end_matches('/'));
|
||||
|
||||
let request_body = EmbeddingRequest {
|
||||
model: self.embed_model.clone(),
|
||||
input: texts,
|
||||
};
|
||||
|
||||
let mut req = self
|
||||
.http
|
||||
.post(&url)
|
||||
.header("content-type", "application/json")
|
||||
.json(&request_body);
|
||||
|
||||
if let Some(auth) = self.auth_header() {
|
||||
req = req.header("Authorization", auth);
|
||||
}
|
||||
|
||||
let resp = req
|
||||
.send()
|
||||
.await
|
||||
.map_err(|e| AgentError::Other(format!("Embedding request failed: {e}")))?;
|
||||
|
||||
if !resp.status().is_success() {
|
||||
let status = resp.status();
|
||||
let body = resp.text().await.unwrap_or_default();
|
||||
return Err(AgentError::Other(format!(
|
||||
"Embedding API returned {status}: {body}"
|
||||
)));
|
||||
}
|
||||
|
||||
let body: EmbeddingResponse = resp
|
||||
.json()
|
||||
.await
|
||||
.map_err(|e| AgentError::Other(format!("Failed to parse embedding response: {e}")))?;
|
||||
|
||||
let mut data = body.data;
|
||||
data.sort_by_key(|d| d.index);
|
||||
|
||||
Ok(data.into_iter().map(|d| d.embedding).collect())
|
||||
}
|
||||
}
|
||||
|
||||
74
compliance-agent/src/llm/embedding.rs
Normal file
74
compliance-agent/src/llm/embedding.rs
Normal file
@@ -0,0 +1,74 @@
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use super::client::LlmClient;
|
||||
use crate::error::AgentError;
|
||||
|
||||
// ── Embedding types ────────────────────────────────────────────
|
||||
|
||||
#[derive(Serialize)]
|
||||
struct EmbeddingRequest {
|
||||
model: String,
|
||||
input: Vec<String>,
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
struct EmbeddingResponse {
|
||||
data: Vec<EmbeddingData>,
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
struct EmbeddingData {
|
||||
embedding: Vec<f64>,
|
||||
index: usize,
|
||||
}
|
||||
|
||||
// ── Embedding implementation ───────────────────────────────────
|
||||
|
||||
impl LlmClient {
|
||||
pub fn embed_model(&self) -> &str {
|
||||
&self.embed_model
|
||||
}
|
||||
|
||||
/// Generate embeddings for a batch of texts
|
||||
pub async fn embed(&self, texts: Vec<String>) -> Result<Vec<Vec<f64>>, AgentError> {
|
||||
let url = format!("{}/v1/embeddings", self.base_url.trim_end_matches('/'));
|
||||
|
||||
let request_body = EmbeddingRequest {
|
||||
model: self.embed_model.clone(),
|
||||
input: texts,
|
||||
};
|
||||
|
||||
let mut req = self
|
||||
.http
|
||||
.post(&url)
|
||||
.header("content-type", "application/json")
|
||||
.json(&request_body);
|
||||
|
||||
if let Some(auth) = self.auth_header() {
|
||||
req = req.header("Authorization", auth);
|
||||
}
|
||||
|
||||
let resp = req
|
||||
.send()
|
||||
.await
|
||||
.map_err(|e| AgentError::Other(format!("Embedding request failed: {e}")))?;
|
||||
|
||||
if !resp.status().is_success() {
|
||||
let status = resp.status();
|
||||
let body = resp.text().await.unwrap_or_default();
|
||||
return Err(AgentError::Other(format!(
|
||||
"Embedding API returned {status}: {body}"
|
||||
)));
|
||||
}
|
||||
|
||||
let body: EmbeddingResponse = resp
|
||||
.json()
|
||||
.await
|
||||
.map_err(|e| AgentError::Other(format!("Failed to parse embedding response: {e}")))?;
|
||||
|
||||
let mut data = body.data;
|
||||
data.sort_by_key(|d| d.index);
|
||||
|
||||
Ok(data.into_iter().map(|d| d.embedding).collect())
|
||||
}
|
||||
}
|
||||
@@ -1,11 +1,16 @@
|
||||
pub mod client;
|
||||
#[allow(dead_code)]
|
||||
pub mod descriptions;
|
||||
pub mod embedding;
|
||||
#[allow(dead_code)]
|
||||
pub mod fixes;
|
||||
#[allow(dead_code)]
|
||||
pub mod pr_review;
|
||||
pub mod review_prompts;
|
||||
pub mod triage;
|
||||
pub mod types;
|
||||
|
||||
pub use client::LlmClient;
|
||||
pub use types::{
|
||||
ChatMessage, LlmResponse, ToolCallRequest, ToolCallRequestFunction, ToolDefinition,
|
||||
};
|
||||
|
||||
@@ -278,3 +278,220 @@ struct TriageResult {
|
||||
fn default_action() -> String {
|
||||
"confirm".to_string()
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use compliance_core::models::Severity;
|
||||
|
||||
// ── classify_file_path ───────────────────────────────────────
|
||||
|
||||
#[test]
|
||||
fn classify_none_path() {
|
||||
assert_eq!(classify_file_path(None), "unknown");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn classify_production_path() {
|
||||
assert_eq!(classify_file_path(Some("src/main.rs")), "production");
|
||||
assert_eq!(classify_file_path(Some("lib/core/engine.py")), "production");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn classify_test_paths() {
|
||||
assert_eq!(classify_file_path(Some("src/test/helper.rs")), "test");
|
||||
assert_eq!(classify_file_path(Some("src/tests/unit.rs")), "test");
|
||||
assert_eq!(classify_file_path(Some("foo_test.go")), "test");
|
||||
assert_eq!(classify_file_path(Some("bar.test.js")), "test");
|
||||
assert_eq!(classify_file_path(Some("baz.spec.ts")), "test");
|
||||
assert_eq!(
|
||||
classify_file_path(Some("data/fixtures/sample.json")),
|
||||
"test"
|
||||
);
|
||||
assert_eq!(classify_file_path(Some("src/testdata/input.txt")), "test");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn classify_example_paths() {
|
||||
assert_eq!(
|
||||
classify_file_path(Some("docs/examples/basic.rs")),
|
||||
"example"
|
||||
);
|
||||
// /example matches because contains("/example")
|
||||
assert_eq!(classify_file_path(Some("src/example/main.py")), "example");
|
||||
assert_eq!(classify_file_path(Some("src/demo/run.sh")), "example");
|
||||
assert_eq!(classify_file_path(Some("src/sample/lib.rs")), "example");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn classify_generated_paths() {
|
||||
assert_eq!(
|
||||
classify_file_path(Some("src/generated/api.rs")),
|
||||
"generated"
|
||||
);
|
||||
assert_eq!(
|
||||
classify_file_path(Some("proto/gen/service.go")),
|
||||
"generated"
|
||||
);
|
||||
assert_eq!(classify_file_path(Some("api.generated.ts")), "generated");
|
||||
assert_eq!(classify_file_path(Some("service.pb.go")), "generated");
|
||||
assert_eq!(classify_file_path(Some("model_generated.rs")), "generated");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn classify_vendored_paths() {
|
||||
// Implementation checks for /vendor/, /node_modules/, /third_party/ (with slashes)
|
||||
assert_eq!(
|
||||
classify_file_path(Some("src/vendor/lib/foo.go")),
|
||||
"vendored"
|
||||
);
|
||||
assert_eq!(
|
||||
classify_file_path(Some("src/node_modules/pkg/index.js")),
|
||||
"vendored"
|
||||
);
|
||||
assert_eq!(
|
||||
classify_file_path(Some("src/third_party/lib.c")),
|
||||
"vendored"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn classify_is_case_insensitive() {
|
||||
assert_eq!(classify_file_path(Some("src/TEST/Helper.rs")), "test");
|
||||
assert_eq!(classify_file_path(Some("src/VENDOR/lib.go")), "vendored");
|
||||
assert_eq!(
|
||||
classify_file_path(Some("src/GENERATED/foo.ts")),
|
||||
"generated"
|
||||
);
|
||||
}
|
||||
|
||||
// ── adjust_confidence ────────────────────────────────────────
|
||||
|
||||
#[test]
|
||||
fn adjust_confidence_production() {
|
||||
assert_eq!(adjust_confidence(8.0, "production"), 8.0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn adjust_confidence_test() {
|
||||
assert_eq!(adjust_confidence(10.0, "test"), 5.0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn adjust_confidence_example() {
|
||||
assert_eq!(adjust_confidence(10.0, "example"), 6.0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn adjust_confidence_generated() {
|
||||
assert_eq!(adjust_confidence(10.0, "generated"), 3.0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn adjust_confidence_vendored() {
|
||||
assert_eq!(adjust_confidence(10.0, "vendored"), 4.0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn adjust_confidence_unknown_classification() {
|
||||
assert_eq!(adjust_confidence(7.0, "unknown"), 7.0);
|
||||
assert_eq!(adjust_confidence(7.0, "something_else"), 7.0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn adjust_confidence_zero() {
|
||||
assert_eq!(adjust_confidence(0.0, "test"), 0.0);
|
||||
assert_eq!(adjust_confidence(0.0, "production"), 0.0);
|
||||
}
|
||||
|
||||
// ── downgrade_severity ───────────────────────────────────────
|
||||
|
||||
#[test]
|
||||
fn downgrade_severity_all_levels() {
|
||||
assert_eq!(downgrade_severity(&Severity::Critical), Severity::High);
|
||||
assert_eq!(downgrade_severity(&Severity::High), Severity::Medium);
|
||||
assert_eq!(downgrade_severity(&Severity::Medium), Severity::Low);
|
||||
assert_eq!(downgrade_severity(&Severity::Low), Severity::Info);
|
||||
assert_eq!(downgrade_severity(&Severity::Info), Severity::Info);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn downgrade_severity_info_is_floor() {
|
||||
// Downgrading Info twice should still be Info
|
||||
let s = downgrade_severity(&Severity::Info);
|
||||
assert_eq!(downgrade_severity(&s), Severity::Info);
|
||||
}
|
||||
|
||||
// ── upgrade_severity ─────────────────────────────────────────
|
||||
|
||||
#[test]
|
||||
fn upgrade_severity_all_levels() {
|
||||
assert_eq!(upgrade_severity(&Severity::Info), Severity::Low);
|
||||
assert_eq!(upgrade_severity(&Severity::Low), Severity::Medium);
|
||||
assert_eq!(upgrade_severity(&Severity::Medium), Severity::High);
|
||||
assert_eq!(upgrade_severity(&Severity::High), Severity::Critical);
|
||||
assert_eq!(upgrade_severity(&Severity::Critical), Severity::Critical);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn upgrade_severity_critical_is_ceiling() {
|
||||
let s = upgrade_severity(&Severity::Critical);
|
||||
assert_eq!(upgrade_severity(&s), Severity::Critical);
|
||||
}
|
||||
|
||||
// ── upgrade/downgrade roundtrip ──────────────────────────────
|
||||
|
||||
#[test]
|
||||
fn upgrade_then_downgrade_is_identity_for_middle_values() {
|
||||
for sev in [Severity::Low, Severity::Medium, Severity::High] {
|
||||
assert_eq!(downgrade_severity(&upgrade_severity(&sev)), sev);
|
||||
}
|
||||
}
|
||||
|
||||
// ── TriageResult deserialization ─────────────────────────────
|
||||
|
||||
#[test]
|
||||
fn triage_result_full() {
|
||||
let json = r#"{"action":"dismiss","confidence":8.5,"rationale":"false positive","remediation":"remove code"}"#;
|
||||
let r: TriageResult = serde_json::from_str(json).unwrap();
|
||||
assert_eq!(r.action, "dismiss");
|
||||
assert_eq!(r.confidence, 8.5);
|
||||
assert_eq!(r.rationale, "false positive");
|
||||
assert_eq!(r.remediation.as_deref(), Some("remove code"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn triage_result_defaults() {
|
||||
let json = r#"{}"#;
|
||||
let r: TriageResult = serde_json::from_str(json).unwrap();
|
||||
assert_eq!(r.action, "confirm");
|
||||
assert_eq!(r.confidence, 0.0);
|
||||
assert_eq!(r.rationale, "");
|
||||
assert!(r.remediation.is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn triage_result_partial() {
|
||||
let json = r#"{"action":"downgrade","confidence":6.0}"#;
|
||||
let r: TriageResult = serde_json::from_str(json).unwrap();
|
||||
assert_eq!(r.action, "downgrade");
|
||||
assert_eq!(r.confidence, 6.0);
|
||||
assert_eq!(r.rationale, "");
|
||||
assert!(r.remediation.is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn triage_result_with_markdown_fences() {
|
||||
// Simulate LLM wrapping response in markdown code fences
|
||||
let raw = "```json\n{\"action\":\"upgrade\",\"confidence\":9,\"rationale\":\"critical\",\"remediation\":null}\n```";
|
||||
let cleaned = raw
|
||||
.trim()
|
||||
.trim_start_matches("```json")
|
||||
.trim_start_matches("```")
|
||||
.trim_end_matches("```")
|
||||
.trim();
|
||||
let r: TriageResult = serde_json::from_str(cleaned).unwrap();
|
||||
assert_eq!(r.action, "upgrade");
|
||||
assert_eq!(r.confidence, 9.0);
|
||||
}
|
||||
}
|
||||
|
||||
369
compliance-agent/src/llm/types.rs
Normal file
369
compliance-agent/src/llm/types.rs
Normal file
@@ -0,0 +1,369 @@
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
// ── Request types ──────────────────────────────────────────────
|
||||
|
||||
#[derive(Serialize, Clone, Debug)]
|
||||
pub struct ChatMessage {
|
||||
pub role: String,
|
||||
pub content: Option<String>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub tool_calls: Option<Vec<ToolCallRequest>>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub tool_call_id: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Serialize)]
|
||||
pub(crate) struct ChatCompletionRequest {
|
||||
pub(crate) model: String,
|
||||
pub(crate) messages: Vec<ChatMessage>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub(crate) temperature: Option<f64>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub(crate) max_tokens: Option<u32>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub(crate) tools: Option<Vec<ToolDefinitionPayload>>,
|
||||
}
|
||||
|
||||
#[derive(Serialize)]
|
||||
pub(crate) struct ToolDefinitionPayload {
|
||||
pub(crate) r#type: String,
|
||||
pub(crate) function: ToolFunctionPayload,
|
||||
}
|
||||
|
||||
#[derive(Serialize)]
|
||||
pub(crate) struct ToolFunctionPayload {
|
||||
pub(crate) name: String,
|
||||
pub(crate) description: String,
|
||||
pub(crate) parameters: serde_json::Value,
|
||||
}
|
||||
|
||||
// ── Response types ─────────────────────────────────────────────
|
||||
|
||||
#[derive(Deserialize)]
|
||||
pub(crate) struct ChatCompletionResponse {
|
||||
pub(crate) choices: Vec<ChatChoice>,
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
pub(crate) struct ChatChoice {
|
||||
pub(crate) message: ChatResponseMessage,
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
pub(crate) struct ChatResponseMessage {
|
||||
#[serde(default)]
|
||||
pub(crate) content: Option<String>,
|
||||
#[serde(default)]
|
||||
pub(crate) tool_calls: Option<Vec<ToolCallResponse>>,
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
pub(crate) struct ToolCallResponse {
|
||||
pub(crate) id: String,
|
||||
pub(crate) function: ToolCallFunction,
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
pub(crate) struct ToolCallFunction {
|
||||
pub(crate) name: String,
|
||||
pub(crate) arguments: String,
|
||||
}
|
||||
|
||||
// ── Public types for tool calling ──────────────────────────────
|
||||
|
||||
/// Definition of a tool that the LLM can invoke
|
||||
#[derive(Debug, Clone, Serialize)]
|
||||
pub struct ToolDefinition {
|
||||
pub name: String,
|
||||
pub description: String,
|
||||
pub parameters: serde_json::Value,
|
||||
}
|
||||
|
||||
/// A tool call request from the LLM
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct LlmToolCall {
|
||||
pub id: String,
|
||||
pub name: String,
|
||||
pub arguments: serde_json::Value,
|
||||
}
|
||||
|
||||
/// A tool call in the request message format (for sending back tool_calls in assistant messages)
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ToolCallRequest {
|
||||
pub id: String,
|
||||
pub r#type: String,
|
||||
pub function: ToolCallRequestFunction,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ToolCallRequestFunction {
|
||||
pub name: String,
|
||||
pub arguments: String,
|
||||
}
|
||||
|
||||
/// Response from the LLM — either content or tool calls
|
||||
#[derive(Debug, Clone)]
|
||||
pub enum LlmResponse {
|
||||
Content(String),
|
||||
/// Tool calls with optional reasoning text from the LLM
|
||||
ToolCalls {
|
||||
calls: Vec<LlmToolCall>,
|
||||
reasoning: String,
|
||||
},
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use serde_json::json;
|
||||
|
||||
// ── ChatMessage ──────────────────────────────────────────────
|
||||
|
||||
#[test]
|
||||
fn chat_message_serializes_minimal() {
|
||||
let msg = ChatMessage {
|
||||
role: "user".to_string(),
|
||||
content: Some("hello".to_string()),
|
||||
tool_calls: None,
|
||||
tool_call_id: None,
|
||||
};
|
||||
let v = serde_json::to_value(&msg).unwrap();
|
||||
assert_eq!(v["role"], "user");
|
||||
assert_eq!(v["content"], "hello");
|
||||
// None fields with skip_serializing_if should be absent
|
||||
assert!(v.get("tool_calls").is_none());
|
||||
assert!(v.get("tool_call_id").is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn chat_message_serializes_with_tool_calls() {
|
||||
let msg = ChatMessage {
|
||||
role: "assistant".to_string(),
|
||||
content: None,
|
||||
tool_calls: Some(vec![ToolCallRequest {
|
||||
id: "call_1".to_string(),
|
||||
r#type: "function".to_string(),
|
||||
function: ToolCallRequestFunction {
|
||||
name: "get_weather".to_string(),
|
||||
arguments: r#"{"city":"NYC"}"#.to_string(),
|
||||
},
|
||||
}]),
|
||||
tool_call_id: None,
|
||||
};
|
||||
let v = serde_json::to_value(&msg).unwrap();
|
||||
assert!(v["tool_calls"].is_array());
|
||||
assert_eq!(v["tool_calls"][0]["function"]["name"], "get_weather");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn chat_message_content_null_when_none() {
|
||||
let msg = ChatMessage {
|
||||
role: "assistant".to_string(),
|
||||
content: None,
|
||||
tool_calls: None,
|
||||
tool_call_id: None,
|
||||
};
|
||||
let v = serde_json::to_value(&msg).unwrap();
|
||||
assert!(v["content"].is_null());
|
||||
}
|
||||
|
||||
// ── ToolDefinition ───────────────────────────────────────────
|
||||
|
||||
#[test]
|
||||
fn tool_definition_serializes() {
|
||||
let td = ToolDefinition {
|
||||
name: "search".to_string(),
|
||||
description: "Search the web".to_string(),
|
||||
parameters: json!({"type": "object", "properties": {"q": {"type": "string"}}}),
|
||||
};
|
||||
let v = serde_json::to_value(&td).unwrap();
|
||||
assert_eq!(v["name"], "search");
|
||||
assert_eq!(v["parameters"]["type"], "object");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn tool_definition_empty_parameters() {
|
||||
let td = ToolDefinition {
|
||||
name: "noop".to_string(),
|
||||
description: "".to_string(),
|
||||
parameters: json!({}),
|
||||
};
|
||||
let v = serde_json::to_value(&td).unwrap();
|
||||
assert_eq!(v["parameters"], json!({}));
|
||||
}
|
||||
|
||||
// ── LlmToolCall ──────────────────────────────────────────────
|
||||
|
||||
#[test]
|
||||
fn llm_tool_call_roundtrip() {
|
||||
let call = LlmToolCall {
|
||||
id: "tc_42".to_string(),
|
||||
name: "run_scan".to_string(),
|
||||
arguments: json!({"path": "/tmp", "verbose": true}),
|
||||
};
|
||||
let serialized = serde_json::to_string(&call).unwrap();
|
||||
let deserialized: LlmToolCall = serde_json::from_str(&serialized).unwrap();
|
||||
assert_eq!(deserialized.id, "tc_42");
|
||||
assert_eq!(deserialized.name, "run_scan");
|
||||
assert_eq!(deserialized.arguments["path"], "/tmp");
|
||||
assert_eq!(deserialized.arguments["verbose"], true);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn llm_tool_call_empty_arguments() {
|
||||
let call = LlmToolCall {
|
||||
id: "tc_0".to_string(),
|
||||
name: "noop".to_string(),
|
||||
arguments: json!({}),
|
||||
};
|
||||
let rt: LlmToolCall = serde_json::from_str(&serde_json::to_string(&call).unwrap()).unwrap();
|
||||
assert!(rt.arguments.as_object().unwrap().is_empty());
|
||||
}
|
||||
|
||||
// ── ToolCallRequest / ToolCallRequestFunction ────────────────
|
||||
|
||||
#[test]
|
||||
fn tool_call_request_roundtrip() {
|
||||
let req = ToolCallRequest {
|
||||
id: "call_abc".to_string(),
|
||||
r#type: "function".to_string(),
|
||||
function: ToolCallRequestFunction {
|
||||
name: "my_func".to_string(),
|
||||
arguments: r#"{"x":1}"#.to_string(),
|
||||
},
|
||||
};
|
||||
let json_str = serde_json::to_string(&req).unwrap();
|
||||
let back: ToolCallRequest = serde_json::from_str(&json_str).unwrap();
|
||||
assert_eq!(back.id, "call_abc");
|
||||
assert_eq!(back.r#type, "function");
|
||||
assert_eq!(back.function.name, "my_func");
|
||||
assert_eq!(back.function.arguments, r#"{"x":1}"#);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn tool_call_request_type_field_serializes_as_type() {
|
||||
let req = ToolCallRequest {
|
||||
id: "id".to_string(),
|
||||
r#type: "function".to_string(),
|
||||
function: ToolCallRequestFunction {
|
||||
name: "f".to_string(),
|
||||
arguments: "{}".to_string(),
|
||||
},
|
||||
};
|
||||
let v = serde_json::to_value(&req).unwrap();
|
||||
// The field should be "type" in JSON, not "r#type"
|
||||
assert!(v.get("type").is_some());
|
||||
assert!(v.get("r#type").is_none());
|
||||
}
|
||||
|
||||
// ── ChatCompletionRequest ────────────────────────────────────
|
||||
|
||||
#[test]
|
||||
fn chat_completion_request_skips_none_fields() {
|
||||
let req = ChatCompletionRequest {
|
||||
model: "gpt-4".to_string(),
|
||||
messages: vec![],
|
||||
temperature: None,
|
||||
max_tokens: None,
|
||||
tools: None,
|
||||
};
|
||||
let v = serde_json::to_value(&req).unwrap();
|
||||
assert_eq!(v["model"], "gpt-4");
|
||||
assert!(v.get("temperature").is_none());
|
||||
assert!(v.get("max_tokens").is_none());
|
||||
assert!(v.get("tools").is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn chat_completion_request_includes_set_fields() {
|
||||
let req = ChatCompletionRequest {
|
||||
model: "gpt-4".to_string(),
|
||||
messages: vec![],
|
||||
temperature: Some(0.7),
|
||||
max_tokens: Some(1024),
|
||||
tools: Some(vec![]),
|
||||
};
|
||||
let v = serde_json::to_value(&req).unwrap();
|
||||
assert_eq!(v["temperature"], 0.7);
|
||||
assert_eq!(v["max_tokens"], 1024);
|
||||
assert!(v["tools"].is_array());
|
||||
}
|
||||
|
||||
// ── ChatCompletionResponse deserialization ───────────────────
|
||||
|
||||
#[test]
|
||||
fn chat_completion_response_deserializes_content() {
|
||||
let json_str = r#"{"choices":[{"message":{"content":"Hello!"}}]}"#;
|
||||
let resp: ChatCompletionResponse = serde_json::from_str(json_str).unwrap();
|
||||
assert_eq!(resp.choices.len(), 1);
|
||||
assert_eq!(resp.choices[0].message.content.as_deref(), Some("Hello!"));
|
||||
assert!(resp.choices[0].message.tool_calls.is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn chat_completion_response_deserializes_tool_calls() {
|
||||
let json_str = r#"{
|
||||
"choices": [{
|
||||
"message": {
|
||||
"tool_calls": [{
|
||||
"id": "call_1",
|
||||
"function": {"name": "search", "arguments": "{\"q\":\"rust\"}"}
|
||||
}]
|
||||
}
|
||||
}]
|
||||
}"#;
|
||||
let resp: ChatCompletionResponse = serde_json::from_str(json_str).unwrap();
|
||||
let tc = resp.choices[0].message.tool_calls.as_ref().unwrap();
|
||||
assert_eq!(tc.len(), 1);
|
||||
assert_eq!(tc[0].id, "call_1");
|
||||
assert_eq!(tc[0].function.name, "search");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn chat_completion_response_defaults_missing_fields() {
|
||||
// content and tool_calls are both missing — should default to None
|
||||
let json_str = r#"{"choices":[{"message":{}}]}"#;
|
||||
let resp: ChatCompletionResponse = serde_json::from_str(json_str).unwrap();
|
||||
assert!(resp.choices[0].message.content.is_none());
|
||||
assert!(resp.choices[0].message.tool_calls.is_none());
|
||||
}
|
||||
|
||||
// ── LlmResponse ─────────────────────────────────────────────
|
||||
|
||||
#[test]
|
||||
fn llm_response_content_variant() {
|
||||
let resp = LlmResponse::Content("answer".to_string());
|
||||
match resp {
|
||||
LlmResponse::Content(s) => assert_eq!(s, "answer"),
|
||||
_ => panic!("expected Content variant"),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn llm_response_tool_calls_variant() {
|
||||
let resp = LlmResponse::ToolCalls {
|
||||
calls: vec![LlmToolCall {
|
||||
id: "1".to_string(),
|
||||
name: "f".to_string(),
|
||||
arguments: json!({}),
|
||||
}],
|
||||
reasoning: "because".to_string(),
|
||||
};
|
||||
match resp {
|
||||
LlmResponse::ToolCalls { calls, reasoning } => {
|
||||
assert_eq!(calls.len(), 1);
|
||||
assert_eq!(reasoning, "because");
|
||||
}
|
||||
_ => panic!("expected ToolCalls variant"),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn llm_response_empty_content() {
|
||||
let resp = LlmResponse::Content(String::new());
|
||||
match resp {
|
||||
LlmResponse::Content(s) => assert!(s.is_empty()),
|
||||
_ => panic!("expected Content variant"),
|
||||
}
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user