use serde::{Deserialize, Serialize}; use super::client::LlmClient; use crate::error::AgentError; // ── Embedding types ──────────────────────────────────────────── #[derive(Serialize)] struct EmbeddingRequest { model: String, input: Vec, } #[derive(Deserialize)] struct EmbeddingResponse { data: Vec, } #[derive(Deserialize)] struct EmbeddingData { embedding: Vec, index: usize, } // ── Embedding implementation ─────────────────────────────────── impl LlmClient { pub fn embed_model(&self) -> &str { &self.embed_model } /// Generate embeddings for a batch of texts pub async fn embed(&self, texts: Vec) -> Result>, AgentError> { let url = format!("{}/v1/embeddings", self.base_url.trim_end_matches('/')); let request_body = EmbeddingRequest { model: self.embed_model.clone(), input: texts, }; let mut req = self .http .post(&url) .header("content-type", "application/json") .json(&request_body); if let Some(auth) = self.auth_header() { req = req.header("Authorization", auth); } let resp = req .send() .await .map_err(|e| AgentError::Other(format!("Embedding request failed: {e}")))?; if !resp.status().is_success() { let status = resp.status(); let body = resp.text().await.unwrap_or_default(); return Err(AgentError::Other(format!( "Embedding API returned {status}: {body}" ))); } let body: EmbeddingResponse = resp .json() .await .map_err(|e| AgentError::Other(format!("Failed to parse embedding response: {e}")))?; let mut data = body.data; data.sort_by_key(|d| d.index); Ok(data.into_iter().map(|d| d.embedding).collect()) } }