use secrecy::{ExposeSecret, SecretString}; use super::types::*; use crate::error::AgentError; #[derive(Clone)] pub struct LlmClient { pub(crate) base_url: String, pub(crate) api_key: SecretString, pub(crate) model: String, pub(crate) embed_model: String, pub(crate) http: reqwest::Client, } impl LlmClient { pub fn new( base_url: String, api_key: SecretString, model: String, embed_model: String, ) -> Self { Self { base_url, api_key, model, embed_model, http: reqwest::Client::new(), } } pub(crate) fn chat_url(&self) -> String { format!( "{}/v1/chat/completions", self.base_url.trim_end_matches('/') ) } pub(crate) fn auth_header(&self) -> Option { let key = self.api_key.expose_secret(); if key.is_empty() { None } else { Some(format!("Bearer {key}")) } } /// Simple chat: system + user prompt → text response pub async fn chat( &self, system_prompt: &str, user_prompt: &str, temperature: Option, ) -> Result { let messages = vec![ ChatMessage { role: "system".to_string(), content: Some(system_prompt.to_string()), tool_calls: None, tool_call_id: None, }, ChatMessage { role: "user".to_string(), content: Some(user_prompt.to_string()), tool_calls: None, tool_call_id: None, }, ]; let request_body = ChatCompletionRequest { model: self.model.clone(), messages, temperature, max_tokens: Some(4096), tools: None, }; self.send_chat_request(&request_body).await.map(|resp| { match resp { LlmResponse::Content(c) => c, LlmResponse::ToolCalls { .. } => String::new(), // shouldn't happen without tools } }) } /// Chat with a list of (role, content) messages → text response #[allow(dead_code)] pub async fn chat_with_messages( &self, messages: Vec<(String, String)>, temperature: Option, ) -> Result { let messages = messages .into_iter() .map(|(role, content)| ChatMessage { role, content: Some(content), tool_calls: None, tool_call_id: None, }) .collect(); let request_body = ChatCompletionRequest { model: self.model.clone(), messages, temperature, max_tokens: Some(4096), tools: None, }; self.send_chat_request(&request_body) .await .map(|resp| match resp { LlmResponse::Content(c) => c, LlmResponse::ToolCalls { .. } => String::new(), }) } /// Chat with tool definitions — returns either content or tool calls. /// Use this for the AI pentest orchestrator loop. pub async fn chat_with_tools( &self, messages: Vec, tools: &[ToolDefinition], temperature: Option, max_tokens: Option, ) -> Result { let tool_payloads: Vec = tools .iter() .map(|t| ToolDefinitionPayload { r#type: "function".to_string(), function: ToolFunctionPayload { name: t.name.clone(), description: t.description.clone(), parameters: t.parameters.clone(), }, }) .collect(); let request_body = ChatCompletionRequest { model: self.model.clone(), messages, temperature, max_tokens: Some(max_tokens.unwrap_or(8192)), tools: if tool_payloads.is_empty() { None } else { Some(tool_payloads) }, }; self.send_chat_request(&request_body).await } /// Internal method to send a chat completion request and parse the response async fn send_chat_request( &self, request_body: &ChatCompletionRequest, ) -> Result { let mut req = self .http .post(self.chat_url()) .header("content-type", "application/json") .json(request_body); if let Some(auth) = self.auth_header() { req = req.header("Authorization", auth); } let resp = req .send() .await .map_err(|e| AgentError::Other(format!("LiteLLM request failed: {e}")))?; if !resp.status().is_success() { let status = resp.status(); let body = resp.text().await.unwrap_or_default(); return Err(AgentError::Other(format!( "LiteLLM returned {status}: {body}" ))); } let body: ChatCompletionResponse = resp .json() .await .map_err(|e| AgentError::Other(format!("Failed to parse LiteLLM response: {e}")))?; let choice = body .choices .first() .ok_or_else(|| AgentError::Other("Empty response from LiteLLM".to_string()))?; // Check for tool calls first if let Some(tool_calls) = &choice.message.tool_calls { if !tool_calls.is_empty() { let calls: Vec = tool_calls .iter() .map(|tc| { let arguments = serde_json::from_str(&tc.function.arguments) .unwrap_or(serde_json::Value::Object(serde_json::Map::new())); LlmToolCall { id: tc.id.clone(), name: tc.function.name.clone(), arguments, } }) .collect(); // Capture any reasoning text the LLM included alongside tool calls let reasoning = choice.message.content.clone().unwrap_or_default(); return Ok(LlmResponse::ToolCalls { calls, reasoning }); } } // Otherwise return content let content = choice.message.content.clone().unwrap_or_default(); Ok(LlmResponse::Content(content)) } }