From 35bf7aafd83bb72e69ca7b9884ee357c0323057c Mon Sep 17 00:00:00 2001 From: wjian23 Date: Thu, 8 Feb 2024 16:37:01 +0800 Subject: [PATCH] =?UTF-8?q?fix:=20=F0=9F=90=9B=20fix=20panic=20when=20serv?= =?UTF-8?q?er=20response=20error=20message?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- Cargo.toml | 4 +-- core/Cargo.toml | 3 +- core/src/error.rs | 55 ++++++++++++++++++++++++++++++++++++- core/src/service/open_ai.rs | 38 +++++++++---------------- core/src/service/req.rs | 35 +++++++++++++---------- gui/Cargo.toml | 3 +- gui/src/req.rs | 7 +---- gui/src/widgets/helper.rs | 8 +++++- 8 files changed, 102 insertions(+), 51 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 6e150dd..aacb7f0 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -11,14 +11,14 @@ description = "An AI Q&A chat util written using Rust" [workspace.dependencies] tokio = { version = "1.20.0", features = ["macros", "rt-multi-thread"] } -reqwest = { version = "0.11.18", features = ["json"] } +reqwest = { version = "0.11.18", features = ["json", "stream"] } serde = { version = "1.0.188", features = ["derive"] } serde_json = "1.0.107" serde_json_path = "0.6.4" log = "0.4.20" uuid = { version = "1.3.3", features = ["v4", "serde", "fast-rng"] } env_logger = "0.10.0" -reqwest-eventsource = "0.5.0" +eventsource-stream = { version ="0.2.3", features = ["std"] } once_cell = "1.19.0" chrono = { version = "0.4.19", features = ["clock", "serde"] } sqlx = { version = "0.7.3", features = ["sqlite", "runtime-tokio", "tls-native-tls", "uuid", "json", "chrono" ] } diff --git a/core/Cargo.toml b/core/Cargo.toml index 4158ceb..e10b784 100644 --- a/core/Cargo.toml +++ b/core/Cargo.toml @@ -14,7 +14,7 @@ log = "0.4.20" home = { version = "0.5.5" } thiserror = "1.0.49" reqwest = { version = "0.11.18", features = ["json"] } -reqwest-eventsource.workspace = true +eventsource-stream.workspace = true futures-util = "0.3.28" uuid = { version = "1.3.3", features = ["v4", "serde", "fast-rng"] } @@ -29,6 +29,7 @@ once_cell.workspace = true chrono.workspace = true serde_json_path.workspace = true + [dev-dependencies] testing_logger = "0.1.1" diff --git a/core/src/error.rs b/core/src/error.rs index 76e8a7c..d9abc15 100644 --- a/core/src/error.rs +++ b/core/src/error.rs @@ -1,3 +1,4 @@ +use serde::{Deserialize, Serialize}; use thiserror::Error; #[derive(Error, Debug)] @@ -9,7 +10,7 @@ pub enum PolestarError { #[error("reqwest error: {0}")] Reqwest(#[from] reqwest::Error), #[error("eventsource error: {0}")] - EventSource(#[from] reqwest_eventsource::Error), + EventSource(#[from] eventsource_stream::EventStreamError), #[error("database not found")] DatabaseNotFound, #[error("database error: {0}")] @@ -18,6 +19,58 @@ pub enum PolestarError { UTF8(#[from] std::string::FromUtf8Error), #[error("Token not found")] TokenNotFound, + #[error("{}: {}.", .0.message, "Please try again later or contact us at Discord")] + PolestarServerError(PolestarServerError), +} + +#[derive(Debug, Serialize, Deserialize, Clone)] +pub struct PolestarServerError { + pub kind: PolestarServerErrType, + pub message: String, +} + +#[derive(Debug, Serialize, Deserialize, Clone, Copy, PartialEq, PartialOrd)] +pub enum PolestarServerErrType { + /// [no retry] unauthorized access to the token + UnAuthed, + /// [no retry] token expired need to refresh + Expires, + /// [no retry] quota been exceeded + OverQuota, + /// [no retry] request parameter invalid + InvalidContent, + /// [no retry] server internal error + ServerError, + /// [retry] response service trigger network error + NetWork, + /// [no retry] not found the resource + NotFound, + /// [no retry] undefined server error can't retry + Unknown, + /// [retry] local request timeout can retry + TimedOut, + /// [no retry] internal error + InternalError, + /// [no retry] attachment not found + AttachmentNotFound, +} + +impl ToString for PolestarServerErrType { + fn to_string(&self) -> String { + match self { + PolestarServerErrType::UnAuthed => "AppError UnAuthed".to_string(), + PolestarServerErrType::Expires => "AppError Expires".to_string(), + PolestarServerErrType::OverQuota => "AppError OverQuota".to_string(), + PolestarServerErrType::InvalidContent => "AppError InvalidContent".to_string(), + PolestarServerErrType::ServerError => "AppError ServerError".to_string(), + PolestarServerErrType::NetWork => "AppError NetWork".to_string(), + PolestarServerErrType::NotFound => "AppError NotFound".to_string(), + PolestarServerErrType::Unknown => "AppError Unknown".to_string(), + PolestarServerErrType::TimedOut => "AppError TimedOut".to_string(), + PolestarServerErrType::InternalError => "AppError InternalError".to_string(), + PolestarServerErrType::AttachmentNotFound => "AppError AttachmentNotFound".to_string(), + } + } } pub type PolestarResult = Result; diff --git a/core/src/service/open_ai.rs b/core/src/service/open_ai.rs index 96c3506..44b5659 100644 --- a/core/src/service/open_ai.rs +++ b/core/src/service/open_ai.rs @@ -1,5 +1,5 @@ +use eventsource_stream::Event; use futures_util::{Stream, StreamExt}; -use reqwest_eventsource::Event; use serde::{Deserialize, Serialize}; use crate::{error::PolestarError, model::MsgRole}; @@ -55,7 +55,7 @@ impl From for Role { } pub async fn deal_open_ai_stream( - stream: &mut (impl Stream> + Unpin), + stream: &mut (impl Stream> + Unpin), mut delta_op: impl FnMut(String), ) -> Result { let mut answer = String::default(); @@ -91,7 +91,7 @@ pub fn mock_stream_string(_content: &str, mut delta_op: impl FnMut(String)) { } async fn stream_event_source_handler( - stream: &mut (impl Stream> + Unpin), + stream: &mut (impl Stream> + Unpin), ) -> Result, PolestarError> { let terminated = "[DONE]"; let chunk_size = 256; @@ -101,30 +101,18 @@ async fn stream_event_source_handler( let mut delta = String::default(); for item in items { - match item { - Ok(event) => { - if let Event::Message(event) = event { - if event.data == terminated { - break; - } - let obj = - serde_json::from_str::(&event.data).unwrap(); - let choices = obj.choices; - assert!(choices.len() == 1); + let data = item?.data; + if data == terminated { + break; + } + let obj = serde_json::from_str::(&data).unwrap(); + let choices = obj.choices; + assert!(choices.len() == 1); - if let Some(content) = &choices[0].delta.content { - delta.push_str(content); - } - } - } - Err(reqwest_eventsource::Error::StreamEnded) => match delta.is_empty() { - true => return Ok(None), - false => return Ok(Some(delta)), - }, - Err(e) => { - return Err(PolestarError::EventSource(e)); - } + if let Some(content) = &choices[0].delta.content { + delta.push_str(content); } } + Ok(Some(delta)) } diff --git a/core/src/service/req.rs b/core/src/service/req.rs index 711d2a4..2c0c481 100644 --- a/core/src/service/req.rs +++ b/core/src/service/req.rs @@ -1,17 +1,17 @@ -use futures_util::StreamExt; +use eventsource_stream::{Event, Eventsource}; +use futures_util::{Stream, TryStreamExt}; use log::warn; use regex::Regex; use reqwest::{ header::{HeaderMap, HeaderName, HeaderValue, AUTHORIZATION, CONTENT_TYPE, USER_AGENT}, - Method, RequestBuilder, + Method, RequestBuilder, StatusCode, }; -use reqwest_eventsource::EventSource; use serde::{Deserialize, Serialize}; use serde_json::{json, value::Value as JsonValue}; use serde_json_path::JsonPath; use crate::{ - error::{PolestarError, PolestarResult}, + error::{PolestarError, PolestarResult, PolestarServerError}, model::{ AppInfo, Bot, BotId, Channel, FeedbackMessageListForServer, FeedbackTimestamp, GlbVar, Quota, ServerProvider, UserFeedbackMessageForServer, GLOBAL_VARS, @@ -44,14 +44,20 @@ async fn req_stream( method: Method, headers: HeaderMap, body: Option, -) -> Result { +) -> Result>, PolestarError> { let req_builder = req_builder(&url, method, headers, body); - let mut stream = EventSource::new(req_builder).unwrap(); - let stream_resp = stream.next().await; - if let Some(Err(err)) = stream_resp { - return Err(PolestarError::EventSource(err)); + let resp = req_builder.send().await?; + let content = resp.headers().get("content-type"); + let stream_content = "text/event-stream"; + let content_type = content.and_then(|t| t.to_str().ok()); + if resp.status() == StatusCode::OK && content_type == Some(stream_content) { + let eventsource = resp.bytes_stream().eventsource(); + Ok(eventsource.map_err(|e| e.into())) + } else { + Err(PolestarError::PolestarServerError( + resp.json::().await?, + )) } - Ok(stream) } pub fn create_text_request(info: &AppInfo, bot_id: BotId) -> TextStreamReq { @@ -205,7 +211,10 @@ pub struct TextStreamReq { } impl TextStreamReq { - pub async fn request(self, body: String) -> Result { + pub async fn request( + self, + body: String, + ) -> Result>, PolestarError> { req_stream( self.url.clone(), Method::POST, @@ -278,9 +287,7 @@ fn to_value_str(val: &JsonValue) -> String { #[derive(Debug, Serialize, Deserialize)] pub struct UserQuota { - pub user_id: u64, - pub limits: f32, - pub used: f32, + // pub user_id: u64, pub statistics: serde_json::Value, } diff --git a/gui/Cargo.toml b/gui/Cargo.toml index 268f26f..dc85afd 100644 --- a/gui/Cargo.toml +++ b/gui/Cargo.toml @@ -15,7 +15,8 @@ serde.workspace = true serde_json.workspace = true uuid.workspace = true reqwest.workspace = true -reqwest-eventsource.workspace = true + +eventsource-stream.workspace = true once_cell.workspace = true tokio = { workspace = true, features = ["macros", "rt-multi-thread"] } diff --git a/gui/src/req.rs b/gui/src/req.rs index f5034b1..d1b5586 100644 --- a/gui/src/req.rs +++ b/gui/src/req.rs @@ -19,12 +19,7 @@ pub async fn query_open_ai( println!("request content: {}", content); - let mut stream = req - .request(content) - .to_ribir_future() - .await - .unwrap() - .to_ribir_stream(); + let mut stream = req.request(content).to_ribir_future().await?; deal_open_ai_stream(&mut stream, delta_op).await } diff --git a/gui/src/widgets/helper.rs b/gui/src/widgets/helper.rs index f39efef..737b0d8 100644 --- a/gui/src/widgets/helper.rs +++ b/gui/src/widgets/helper.rs @@ -51,11 +51,17 @@ pub fn send_msg( }) .unwrap_or(content); - let _ = query_open_ai(chat.map_reader(|chat| chat.info()), bot_id, text, |delta| { + let res = query_open_ai(chat.map_reader(|chat| chat.info()), bot_id, text, |delta| { update_msg(MsgAction::Receiving(MsgBody::Text(Some(delta)))); }) .await; + if let Err(e) = res { + update_msg(MsgAction::Receiving(MsgBody::Text(Some(format!( + "Error: {}", + e + ))))); + } update_msg(MsgAction::Fulfilled); }); }