feat: use librechat instead of own chat (#14)
Co-authored-by: Sharang Parnerkar <parnerkarsharang@gmail.com> Reviewed-on: #14
This commit was merged in pull request #14.
This commit is contained in:
@@ -1,266 +0,0 @@
|
||||
//! SSE streaming endpoint for chat completions.
|
||||
//!
|
||||
//! Exposes `GET /api/chat/stream?session_id=<id>` which:
|
||||
//! 1. Authenticates the user via tower-sessions
|
||||
//! 2. Loads the session and its messages from MongoDB
|
||||
//! 3. Streams LLM tokens as SSE events to the frontend
|
||||
//! 4. Persists the complete assistant message on finish
|
||||
|
||||
use axum::{
|
||||
extract::Query,
|
||||
response::{
|
||||
sse::{Event, KeepAlive, Sse},
|
||||
IntoResponse, Response,
|
||||
},
|
||||
Extension,
|
||||
};
|
||||
use futures::stream::Stream;
|
||||
use reqwest::StatusCode;
|
||||
use serde::Deserialize;
|
||||
use tower_sessions::Session;
|
||||
|
||||
use super::{
|
||||
auth::LOGGED_IN_USER_SESS_KEY,
|
||||
chat::{doc_to_chat_message, doc_to_chat_session},
|
||||
provider_client::{send_chat_request, ProviderMessage},
|
||||
server_state::ServerState,
|
||||
state::UserStateInner,
|
||||
};
|
||||
use crate::models::{ChatMessage, ChatRole};
|
||||
|
||||
/// Query parameters for the SSE stream endpoint.
|
||||
#[derive(Deserialize)]
|
||||
pub struct StreamQuery {
|
||||
session_id: String,
|
||||
}
|
||||
|
||||
/// SSE streaming handler for chat completions.
|
||||
///
|
||||
/// Reads the session's provider/model config, loads conversation history,
|
||||
/// sends to the LLM with `stream: true`, and forwards tokens as SSE events.
|
||||
///
|
||||
/// # SSE Event Format
|
||||
///
|
||||
/// - `data: {"token": "..."}` -- partial token
|
||||
/// - `data: {"done": true, "message_id": "..."}` -- stream complete
|
||||
/// - `data: {"error": "..."}` -- on failure
|
||||
pub async fn chat_stream_handler(
|
||||
session: Session,
|
||||
Extension(state): Extension<ServerState>,
|
||||
Query(params): Query<StreamQuery>,
|
||||
) -> Response {
|
||||
// Authenticate
|
||||
let user_state: Option<UserStateInner> = match session.get(LOGGED_IN_USER_SESS_KEY).await {
|
||||
Ok(u) => u,
|
||||
Err(_) => return (StatusCode::UNAUTHORIZED, "session error").into_response(),
|
||||
};
|
||||
let user = match user_state {
|
||||
Some(u) => u,
|
||||
None => return (StatusCode::UNAUTHORIZED, "not authenticated").into_response(),
|
||||
};
|
||||
|
||||
// Load session from MongoDB (raw document to handle ObjectId -> String)
|
||||
let chat_session = {
|
||||
use mongodb::bson::{doc, oid::ObjectId};
|
||||
let oid = match ObjectId::parse_str(¶ms.session_id) {
|
||||
Ok(o) => o,
|
||||
Err(_) => return (StatusCode::BAD_REQUEST, "invalid session_id").into_response(),
|
||||
};
|
||||
match state
|
||||
.db
|
||||
.raw_collection("chat_sessions")
|
||||
.find_one(doc! { "_id": oid, "user_sub": &user.sub })
|
||||
.await
|
||||
{
|
||||
Ok(Some(doc)) => doc_to_chat_session(&doc),
|
||||
Ok(None) => return (StatusCode::NOT_FOUND, "session not found").into_response(),
|
||||
Err(e) => {
|
||||
tracing::error!("db error loading session: {e}");
|
||||
return (StatusCode::INTERNAL_SERVER_ERROR, "db error").into_response();
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
// Load messages (raw documents to handle ObjectId -> String)
|
||||
let messages = {
|
||||
use mongodb::bson::doc;
|
||||
use mongodb::options::FindOptions;
|
||||
|
||||
let opts = FindOptions::builder().sort(doc! { "timestamp": 1 }).build();
|
||||
|
||||
match state
|
||||
.db
|
||||
.raw_collection("chat_messages")
|
||||
.find(doc! { "session_id": ¶ms.session_id })
|
||||
.with_options(opts)
|
||||
.await
|
||||
{
|
||||
Ok(mut cursor) => {
|
||||
use futures::TryStreamExt;
|
||||
let mut msgs = Vec::new();
|
||||
while let Some(doc) = TryStreamExt::try_next(&mut cursor).await.unwrap_or(None) {
|
||||
msgs.push(doc_to_chat_message(&doc));
|
||||
}
|
||||
msgs
|
||||
}
|
||||
Err(e) => {
|
||||
tracing::error!("db error loading messages: {e}");
|
||||
return (StatusCode::INTERNAL_SERVER_ERROR, "db error").into_response();
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
// Convert to provider format
|
||||
let provider_msgs: Vec<ProviderMessage> = messages
|
||||
.iter()
|
||||
.map(|m| ProviderMessage {
|
||||
role: match m.role {
|
||||
ChatRole::User => "user".to_string(),
|
||||
ChatRole::Assistant => "assistant".to_string(),
|
||||
ChatRole::System => "system".to_string(),
|
||||
},
|
||||
content: m.content.clone(),
|
||||
})
|
||||
.collect();
|
||||
|
||||
let provider = chat_session.provider.clone();
|
||||
let model = chat_session.model.clone();
|
||||
let session_id = params.session_id.clone();
|
||||
|
||||
// TODO: Load user's API key from preferences for non-Ollama providers.
|
||||
// For now, Ollama (no key needed) is the default path.
|
||||
let api_key: Option<String> = None;
|
||||
|
||||
// Send streaming request to LLM
|
||||
let llm_resp = match send_chat_request(
|
||||
&state,
|
||||
&provider,
|
||||
&model,
|
||||
&provider_msgs,
|
||||
api_key.as_deref(),
|
||||
true,
|
||||
)
|
||||
.await
|
||||
{
|
||||
Ok(r) => r,
|
||||
Err(e) => {
|
||||
tracing::error!("LLM request failed: {e}");
|
||||
return (StatusCode::BAD_GATEWAY, "LLM request failed").into_response();
|
||||
}
|
||||
};
|
||||
|
||||
if !llm_resp.status().is_success() {
|
||||
let status = llm_resp.status();
|
||||
let body = llm_resp.text().await.unwrap_or_default();
|
||||
tracing::error!("LLM returned {status}: {body}");
|
||||
return (StatusCode::BAD_GATEWAY, format!("LLM error: {status}")).into_response();
|
||||
}
|
||||
|
||||
// Stream the response bytes as SSE events
|
||||
let byte_stream = llm_resp.bytes_stream();
|
||||
let state_clone = state.clone();
|
||||
|
||||
let sse_stream = build_sse_stream(byte_stream, state_clone, session_id, provider.clone());
|
||||
|
||||
Sse::new(sse_stream)
|
||||
.keep_alive(KeepAlive::default())
|
||||
.into_response()
|
||||
}
|
||||
|
||||
/// Build an SSE stream that parses OpenAI-compatible streaming chunks
|
||||
/// and emits token events. On completion, persists the full message.
|
||||
fn build_sse_stream(
|
||||
byte_stream: impl Stream<Item = Result<bytes::Bytes, reqwest::Error>> + Send + 'static,
|
||||
state: ServerState,
|
||||
session_id: String,
|
||||
_provider: String,
|
||||
) -> impl Stream<Item = Result<Event, std::convert::Infallible>> + Send + 'static {
|
||||
// Use an async stream to process chunks
|
||||
async_stream::stream! {
|
||||
use futures::StreamExt;
|
||||
|
||||
let mut full_content = String::new();
|
||||
let mut buffer = String::new();
|
||||
|
||||
// Pin the byte stream for iteration
|
||||
let mut stream = std::pin::pin!(byte_stream);
|
||||
|
||||
while let Some(chunk_result) = StreamExt::next(&mut stream).await {
|
||||
let chunk = match chunk_result {
|
||||
Ok(bytes) => bytes,
|
||||
Err(e) => {
|
||||
let err_json = serde_json::json!({ "error": e.to_string() });
|
||||
yield Ok(Event::default().data(err_json.to_string()));
|
||||
break;
|
||||
}
|
||||
};
|
||||
|
||||
let text = String::from_utf8_lossy(&chunk);
|
||||
buffer.push_str(&text);
|
||||
|
||||
// Process complete SSE lines from the buffer.
|
||||
// OpenAI streaming format: `data: {...}\n\n`
|
||||
while let Some(line_end) = buffer.find('\n') {
|
||||
let line = buffer[..line_end].trim().to_string();
|
||||
buffer = buffer[line_end + 1..].to_string();
|
||||
|
||||
if line.is_empty() || line == "data: [DONE]" {
|
||||
continue;
|
||||
}
|
||||
|
||||
if let Some(json_str) = line.strip_prefix("data: ") {
|
||||
if let Ok(parsed) = serde_json::from_str::<serde_json::Value>(json_str) {
|
||||
// Extract token from OpenAI delta format
|
||||
if let Some(token) = parsed["choices"][0]["delta"]["content"].as_str() {
|
||||
full_content.push_str(token);
|
||||
let event_data = serde_json::json!({ "token": token });
|
||||
yield Ok(Event::default().data(event_data.to_string()));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Persist the complete assistant message
|
||||
if !full_content.is_empty() {
|
||||
let now = chrono::Utc::now().to_rfc3339();
|
||||
let message = ChatMessage {
|
||||
id: String::new(),
|
||||
session_id: session_id.clone(),
|
||||
role: ChatRole::Assistant,
|
||||
content: full_content,
|
||||
attachments: Vec::new(),
|
||||
timestamp: now.clone(),
|
||||
};
|
||||
|
||||
let msg_id = match state.db.chat_messages().insert_one(&message).await {
|
||||
Ok(result) => result
|
||||
.inserted_id
|
||||
.as_object_id()
|
||||
.map(|oid| oid.to_hex())
|
||||
.unwrap_or_default(),
|
||||
Err(e) => {
|
||||
tracing::error!("failed to persist assistant message: {e}");
|
||||
String::new()
|
||||
}
|
||||
};
|
||||
|
||||
// Update session timestamp
|
||||
if let Ok(session_oid) =
|
||||
mongodb::bson::oid::ObjectId::parse_str(&session_id)
|
||||
{
|
||||
let _ = state
|
||||
.db
|
||||
.chat_sessions()
|
||||
.update_one(
|
||||
mongodb::bson::doc! { "_id": session_oid },
|
||||
mongodb::bson::doc! { "$set": { "updated_at": &now } },
|
||||
)
|
||||
.await;
|
||||
}
|
||||
|
||||
let done_data = serde_json::json!({ "done": true, "message_id": msg_id });
|
||||
yield Ok(Event::default().data(done_data.to_string()));
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -12,8 +12,6 @@ mod auth;
|
||||
#[cfg(feature = "server")]
|
||||
mod auth_middleware;
|
||||
#[cfg(feature = "server")]
|
||||
mod chat_stream;
|
||||
#[cfg(feature = "server")]
|
||||
pub mod config;
|
||||
#[cfg(feature = "server")]
|
||||
pub mod database;
|
||||
@@ -33,8 +31,6 @@ pub use auth::*;
|
||||
#[cfg(feature = "server")]
|
||||
pub use auth_middleware::*;
|
||||
#[cfg(feature = "server")]
|
||||
pub use chat_stream::*;
|
||||
#[cfg(feature = "server")]
|
||||
pub use error::*;
|
||||
#[cfg(feature = "server")]
|
||||
pub use server::*;
|
||||
|
||||
@@ -6,7 +6,7 @@ use time::Duration;
|
||||
use tower_sessions::{cookie::Key, MemoryStore, SessionManagerLayer};
|
||||
|
||||
use crate::infrastructure::{
|
||||
auth_callback, auth_login, chat_stream_handler,
|
||||
auth_callback, auth_login,
|
||||
config::{KeycloakConfig, LlmProvidersConfig, ServiceUrls, SmtpConfig, StripeConfig},
|
||||
database::Database,
|
||||
logout, require_auth,
|
||||
@@ -82,7 +82,6 @@ pub fn server_start(app: fn() -> Element) -> Result<(), super::Error> {
|
||||
.route("/auth", get(auth_login))
|
||||
.route("/auth/callback", get(auth_callback))
|
||||
.route("/logout", get(logout))
|
||||
.route("/api/chat/stream", get(chat_stream_handler))
|
||||
.serve_dioxus_application(ServeConfig::new(), app)
|
||||
.layer(Extension(PendingOAuthStore::default()))
|
||||
.layer(Extension(server_state))
|
||||
|
||||
Reference in New Issue
Block a user