All checks were successful
CI / Format (push) Successful in 4s
CI / Clippy (push) Successful in 4m19s
CI / Security Audit (push) Successful in 1m44s
CI / Detect Changes (push) Successful in 5s
CI / Tests (push) Successful in 5m15s
CI / Deploy Agent (push) Successful in 2s
CI / Deploy Dashboard (push) Successful in 2s
CI / Deploy Docs (push) Has been skipped
CI / Deploy MCP (push) Successful in 2s
218 lines
6.6 KiB
Rust
use secrecy::{ExposeSecret, SecretString};
|
|
|
|
use super::types::*;
|
|
use crate::error::AgentError;
|
|
|
|
/// Client for an OpenAI-compatible chat-completions endpoint (LiteLLM).
///
/// Holds connection settings plus a shared `reqwest::Client`; `Clone` lets
/// the same configuration be handed to multiple tasks.
#[derive(Clone)]
pub struct LlmClient {
    /// Base URL of the gateway; may carry a trailing slash — it is
    /// normalized in `chat_url`.
    pub(crate) base_url: String,
    /// API key; an empty key means "send no Authorization header"
    /// (see `auth_header`).
    pub(crate) api_key: SecretString,
    /// Model name used for chat completions.
    pub(crate) model: String,
    /// Model name for embeddings — not used in this file; presumably
    /// consumed by a sibling module (confirm).
    pub(crate) embed_model: String,
    /// HTTP client reused across requests.
    pub(crate) http: reqwest::Client,
}
|
|
|
|
impl LlmClient {
    /// Build a client for the given endpoint and credentials.
    ///
    /// `base_url` may include a trailing slash (normalized when the request
    /// URL is built). An empty `api_key` disables the Authorization header.
    pub fn new(
        base_url: String,
        api_key: SecretString,
        model: String,
        embed_model: String,
    ) -> Self {
        let http = reqwest::Client::new();
        Self {
            base_url,
            api_key,
            model,
            embed_model,
            http,
        }
    }
|
|
|
|
pub(crate) fn chat_url(&self) -> String {
|
|
format!(
|
|
"{}/v1/chat/completions",
|
|
self.base_url.trim_end_matches('/')
|
|
)
|
|
}
|
|
|
|
pub(crate) fn auth_header(&self) -> Option<String> {
|
|
let key = self.api_key.expose_secret();
|
|
if key.is_empty() {
|
|
None
|
|
} else {
|
|
Some(format!("Bearer {key}"))
|
|
}
|
|
}
|
|
|
|
/// Simple chat: system + user prompt → text response
|
|
pub async fn chat(
|
|
&self,
|
|
system_prompt: &str,
|
|
user_prompt: &str,
|
|
temperature: Option<f64>,
|
|
) -> Result<String, AgentError> {
|
|
let messages = vec![
|
|
ChatMessage {
|
|
role: "system".to_string(),
|
|
content: Some(system_prompt.to_string()),
|
|
tool_calls: None,
|
|
tool_call_id: None,
|
|
},
|
|
ChatMessage {
|
|
role: "user".to_string(),
|
|
content: Some(user_prompt.to_string()),
|
|
tool_calls: None,
|
|
tool_call_id: None,
|
|
},
|
|
];
|
|
|
|
let request_body = ChatCompletionRequest {
|
|
model: self.model.clone(),
|
|
messages,
|
|
temperature,
|
|
max_tokens: Some(4096),
|
|
tools: None,
|
|
};
|
|
|
|
self.send_chat_request(&request_body).await.map(|resp| {
|
|
match resp {
|
|
LlmResponse::Content(c) => c,
|
|
LlmResponse::ToolCalls { .. } => String::new(), // shouldn't happen without tools
|
|
}
|
|
})
|
|
}
|
|
|
|
/// Chat with a list of (role, content) messages → text response
|
|
#[allow(dead_code)]
|
|
pub async fn chat_with_messages(
|
|
&self,
|
|
messages: Vec<(String, String)>,
|
|
temperature: Option<f64>,
|
|
) -> Result<String, AgentError> {
|
|
let messages = messages
|
|
.into_iter()
|
|
.map(|(role, content)| ChatMessage {
|
|
role,
|
|
content: Some(content),
|
|
tool_calls: None,
|
|
tool_call_id: None,
|
|
})
|
|
.collect();
|
|
|
|
let request_body = ChatCompletionRequest {
|
|
model: self.model.clone(),
|
|
messages,
|
|
temperature,
|
|
max_tokens: Some(4096),
|
|
tools: None,
|
|
};
|
|
|
|
self.send_chat_request(&request_body)
|
|
.await
|
|
.map(|resp| match resp {
|
|
LlmResponse::Content(c) => c,
|
|
LlmResponse::ToolCalls { .. } => String::new(),
|
|
})
|
|
}
|
|
|
|
/// Chat with tool definitions — returns either content or tool calls.
|
|
/// Use this for the AI pentest orchestrator loop.
|
|
pub async fn chat_with_tools(
|
|
&self,
|
|
messages: Vec<ChatMessage>,
|
|
tools: &[ToolDefinition],
|
|
temperature: Option<f64>,
|
|
max_tokens: Option<u32>,
|
|
) -> Result<LlmResponse, AgentError> {
|
|
let tool_payloads: Vec<ToolDefinitionPayload> = tools
|
|
.iter()
|
|
.map(|t| ToolDefinitionPayload {
|
|
r#type: "function".to_string(),
|
|
function: ToolFunctionPayload {
|
|
name: t.name.clone(),
|
|
description: t.description.clone(),
|
|
parameters: t.parameters.clone(),
|
|
},
|
|
})
|
|
.collect();
|
|
|
|
let request_body = ChatCompletionRequest {
|
|
model: self.model.clone(),
|
|
messages,
|
|
temperature,
|
|
max_tokens: Some(max_tokens.unwrap_or(8192)),
|
|
tools: if tool_payloads.is_empty() {
|
|
None
|
|
} else {
|
|
Some(tool_payloads)
|
|
},
|
|
};
|
|
|
|
self.send_chat_request(&request_body).await
|
|
}
|
|
|
|
    /// Internal method to send a chat completion request and parse the response.
    ///
    /// Returns `LlmResponse::ToolCalls` when the first choice carries at
    /// least one tool call (any accompanying message text is kept as
    /// `reasoning`); otherwise `LlmResponse::Content` with the message text.
    /// Transport failures, non-success HTTP statuses, decode failures, and an
    /// empty `choices` list are all mapped to `AgentError::Other`.
    async fn send_chat_request(
        &self,
        request_body: &ChatCompletionRequest,
    ) -> Result<LlmResponse, AgentError> {
        let mut req = self
            .http
            .post(self.chat_url())
            .header("content-type", "application/json")
            .json(request_body);

        // An empty API key yields no header at all (see auth_header).
        if let Some(auth) = self.auth_header() {
            req = req.header("Authorization", auth);
        }

        let resp = req
            .send()
            .await
            .map_err(|e| AgentError::Other(format!("LiteLLM request failed: {e}")))?;

        if !resp.status().is_success() {
            let status = resp.status();
            // Best effort: include the error body; a failed body read
            // degrades to an empty string rather than masking the status.
            let body = resp.text().await.unwrap_or_default();
            return Err(AgentError::Other(format!(
                "LiteLLM returned {status}: {body}"
            )));
        }

        let body: ChatCompletionResponse = resp
            .json()
            .await
            .map_err(|e| AgentError::Other(format!("Failed to parse LiteLLM response: {e}")))?;

        // Only the first choice is consulted.
        let choice = body
            .choices
            .first()
            .ok_or_else(|| AgentError::Other("Empty response from LiteLLM".to_string()))?;

        // Check for tool calls first
        if let Some(tool_calls) = &choice.message.tool_calls {
            if !tool_calls.is_empty() {
                let calls: Vec<LlmToolCall> = tool_calls
                    .iter()
                    .map(|tc| {
                        // Malformed JSON arguments degrade to an empty object
                        // instead of failing the whole response.
                        let arguments = serde_json::from_str(&tc.function.arguments)
                            .unwrap_or(serde_json::Value::Object(serde_json::Map::new()));
                        LlmToolCall {
                            id: tc.id.clone(),
                            name: tc.function.name.clone(),
                            arguments,
                        }
                    })
                    .collect();
                // Capture any reasoning text the LLM included alongside tool calls
                let reasoning = choice.message.content.clone().unwrap_or_default();
                return Ok(LlmResponse::ToolCalls { calls, reasoning });
            }
        }

        // Otherwise return content
        let content = choice.message.content.clone().unwrap_or_default();
        Ok(LlmResponse::Content(content))
    }
}
|