Skip to content

Commit

Permalink
fix: 🐛 fix panic when server response error message
Browse files Browse the repository at this point in the history
  • Loading branch information
wjian23 committed Feb 8, 2024
1 parent 6d573fc commit 35bf7aa
Show file tree
Hide file tree
Showing 8 changed files with 102 additions and 51 deletions.
4 changes: 2 additions & 2 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,14 +11,14 @@ description = "An AI Q&A chat util written using Rust"

[workspace.dependencies]
tokio = { version = "1.20.0", features = ["macros", "rt-multi-thread"] }
reqwest = { version = "0.11.18", features = ["json"] }
reqwest = { version = "0.11.18", features = ["json", "stream"] }
serde = { version = "1.0.188", features = ["derive"] }
serde_json = "1.0.107"
serde_json_path = "0.6.4"
log = "0.4.20"
uuid = { version = "1.3.3", features = ["v4", "serde", "fast-rng"] }
env_logger = "0.10.0"
reqwest-eventsource = "0.5.0"
eventsource-stream = { version ="0.2.3", features = ["std"] }
once_cell = "1.19.0"
chrono = { version = "0.4.19", features = ["clock", "serde"] }
sqlx = { version = "0.7.3", features = ["sqlite", "runtime-tokio", "tls-native-tls", "uuid", "json", "chrono" ] }
3 changes: 2 additions & 1 deletion core/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ log = "0.4.20"
home = { version = "0.5.5" }
thiserror = "1.0.49"
reqwest = { version = "0.11.18", features = ["json"] }
reqwest-eventsource.workspace = true
eventsource-stream.workspace = true

futures-util = "0.3.28"
uuid = { version = "1.3.3", features = ["v4", "serde", "fast-rng"] }
Expand All @@ -29,6 +29,7 @@ once_cell.workspace = true
chrono.workspace = true
serde_json_path.workspace = true


[dev-dependencies]
testing_logger = "0.1.1"

Expand Down
55 changes: 54 additions & 1 deletion core/src/error.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
use serde::{Deserialize, Serialize};
use thiserror::Error;

#[derive(Error, Debug)]
Expand All @@ -9,7 +10,7 @@ pub enum PolestarError {
#[error("reqwest error: {0}")]
Reqwest(#[from] reqwest::Error),
#[error("eventsource error: {0}")]
EventSource(#[from] reqwest_eventsource::Error),
EventSource(#[from] eventsource_stream::EventStreamError<reqwest::Error>),
#[error("database not found")]
DatabaseNotFound,
#[error("database error: {0}")]
Expand All @@ -18,6 +19,58 @@ pub enum PolestarError {
UTF8(#[from] std::string::FromUtf8Error),
#[error("Token not found")]
TokenNotFound,
#[error("{}: {}.", .0.message, "Please try again later or contact us at Discord")]
PolestarServerError(PolestarServerError),
}

#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct PolestarServerError {
pub kind: PolestarServerErrType,
pub message: String,
}

#[derive(Debug, Serialize, Deserialize, Clone, Copy, PartialEq, PartialOrd)]
pub enum PolestarServerErrType {
/// [no retry] unauthorized access to the token
UnAuthed,
/// [no retry] token expired need to refresh
Expires,
/// [no retry] quota been exceeded
OverQuota,
/// [no retry] request parameter invalid
InvalidContent,
/// [no retry] server internal error
ServerError,
/// [retry] response service trigger network error
NetWork,
/// [no retry] not found the resource
NotFound,
/// [no retry] undefined server error can't retry
Unknown,
/// [retry] local request timeout can retry
TimedOut,
/// [no retry] internal error
InternalError,
/// [no retry] attachment not found
AttachmentNotFound,
}

impl ToString for PolestarServerErrType {
fn to_string(&self) -> String {
match self {
PolestarServerErrType::UnAuthed => "AppError UnAuthed".to_string(),
PolestarServerErrType::Expires => "AppError Expires".to_string(),
PolestarServerErrType::OverQuota => "AppError OverQuota".to_string(),
PolestarServerErrType::InvalidContent => "AppError InvalidContent".to_string(),
PolestarServerErrType::ServerError => "AppError ServerError".to_string(),
PolestarServerErrType::NetWork => "AppError NetWork".to_string(),
PolestarServerErrType::NotFound => "AppError NotFound".to_string(),
PolestarServerErrType::Unknown => "AppError Unknown".to_string(),
PolestarServerErrType::TimedOut => "AppError TimedOut".to_string(),
PolestarServerErrType::InternalError => "AppError InternalError".to_string(),
PolestarServerErrType::AttachmentNotFound => "AppError AttachmentNotFound".to_string(),
}
}
}

pub type PolestarResult<T> = Result<T, PolestarError>;
38 changes: 13 additions & 25 deletions core/src/service/open_ai.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use eventsource_stream::Event;
use futures_util::{Stream, StreamExt};
use reqwest_eventsource::Event;
use serde::{Deserialize, Serialize};

use crate::{error::PolestarError, model::MsgRole};
Expand Down Expand Up @@ -55,7 +55,7 @@ impl From<MsgRole> for Role {
}

pub async fn deal_open_ai_stream(
stream: &mut (impl Stream<Item = Result<Event, reqwest_eventsource::Error>> + Unpin),
stream: &mut (impl Stream<Item = Result<Event, PolestarError>> + Unpin),
mut delta_op: impl FnMut(String),
) -> Result<String, PolestarError> {
let mut answer = String::default();
Expand Down Expand Up @@ -91,7 +91,7 @@ pub fn mock_stream_string(_content: &str, mut delta_op: impl FnMut(String)) {
}

async fn stream_event_source_handler(
stream: &mut (impl Stream<Item = Result<Event, reqwest_eventsource::Error>> + Unpin),
stream: &mut (impl Stream<Item = Result<Event, PolestarError>> + Unpin),
) -> Result<Option<String>, PolestarError> {
let terminated = "[DONE]";
let chunk_size = 256;
Expand All @@ -101,30 +101,18 @@ async fn stream_event_source_handler(

let mut delta = String::default();
for item in items {
match item {
Ok(event) => {
if let Event::Message(event) = event {
if event.data == terminated {
break;
}
let obj =
serde_json::from_str::<CreateChatCompletionStreamResponse>(&event.data).unwrap();
let choices = obj.choices;
assert!(choices.len() == 1);
let data = item?.data;
if data == terminated {
break;
}
let obj = serde_json::from_str::<CreateChatCompletionStreamResponse>(&data).unwrap();
let choices = obj.choices;
assert!(choices.len() == 1);

if let Some(content) = &choices[0].delta.content {
delta.push_str(content);
}
}
}
Err(reqwest_eventsource::Error::StreamEnded) => match delta.is_empty() {
true => return Ok(None),
false => return Ok(Some(delta)),
},
Err(e) => {
return Err(PolestarError::EventSource(e));
}
if let Some(content) = &choices[0].delta.content {
delta.push_str(content);
}
}

Ok(Some(delta))
}
35 changes: 21 additions & 14 deletions core/src/service/req.rs
Original file line number Diff line number Diff line change
@@ -1,17 +1,17 @@
use futures_util::StreamExt;
use eventsource_stream::{Event, Eventsource};
use futures_util::{Stream, TryStreamExt};
use log::warn;
use regex::Regex;
use reqwest::{
header::{HeaderMap, HeaderName, HeaderValue, AUTHORIZATION, CONTENT_TYPE, USER_AGENT},
Method, RequestBuilder,
Method, RequestBuilder, StatusCode,
};
use reqwest_eventsource::EventSource;
use serde::{Deserialize, Serialize};
use serde_json::{json, value::Value as JsonValue};
use serde_json_path::JsonPath;

use crate::{
error::{PolestarError, PolestarResult},
error::{PolestarError, PolestarResult, PolestarServerError},
model::{
AppInfo, Bot, BotId, Channel, FeedbackMessageListForServer, FeedbackTimestamp, GlbVar, Quota,
ServerProvider, UserFeedbackMessageForServer, GLOBAL_VARS,
Expand Down Expand Up @@ -44,14 +44,20 @@ async fn req_stream(
method: Method,
headers: HeaderMap,
body: Option<String>,
) -> Result<EventSource, PolestarError> {
) -> Result<impl Stream<Item = Result<Event, PolestarError>>, PolestarError> {
let req_builder = req_builder(&url, method, headers, body);
let mut stream = EventSource::new(req_builder).unwrap();
let stream_resp = stream.next().await;
if let Some(Err(err)) = stream_resp {
return Err(PolestarError::EventSource(err));
let resp = req_builder.send().await?;
let content = resp.headers().get("content-type");
let stream_content = "text/event-stream";
let content_type = content.and_then(|t| t.to_str().ok());
if resp.status() == StatusCode::OK && content_type == Some(stream_content) {
let eventsource = resp.bytes_stream().eventsource();
Ok(eventsource.map_err(|e| e.into()))
} else {
Err(PolestarError::PolestarServerError(
resp.json::<PolestarServerError>().await?,
))
}
Ok(stream)
}

pub fn create_text_request(info: &AppInfo, bot_id: BotId) -> TextStreamReq {
Expand Down Expand Up @@ -205,7 +211,10 @@ pub struct TextStreamReq {
}

impl TextStreamReq {
pub async fn request(self, body: String) -> Result<EventSource, PolestarError> {
pub async fn request(
self,
body: String,
) -> Result<impl Stream<Item = Result<Event, PolestarError>>, PolestarError> {
req_stream(
self.url.clone(),
Method::POST,
Expand Down Expand Up @@ -278,9 +287,7 @@ fn to_value_str(val: &JsonValue) -> String {

#[derive(Debug, Serialize, Deserialize)]
pub struct UserQuota {
pub user_id: u64,
pub limits: f32,
pub used: f32,
// pub user_id: u64,
pub statistics: serde_json::Value,
}

Expand Down
3 changes: 2 additions & 1 deletion gui/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@ serde.workspace = true
serde_json.workspace = true
uuid.workspace = true
reqwest.workspace = true
reqwest-eventsource.workspace = true

eventsource-stream.workspace = true
once_cell.workspace = true
tokio = { workspace = true, features = ["macros", "rt-multi-thread"] }

Expand Down
7 changes: 1 addition & 6 deletions gui/src/req.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,7 @@ pub async fn query_open_ai(

println!("request content: {}", content);

let mut stream = req
.request(content)
.to_ribir_future()
.await
.unwrap()
.to_ribir_stream();
let mut stream = req.request(content).to_ribir_future().await?;

deal_open_ai_stream(&mut stream, delta_op).await
}
Expand Down
8 changes: 7 additions & 1 deletion gui/src/widgets/helper.rs
Original file line number Diff line number Diff line change
Expand Up @@ -51,11 +51,17 @@ pub fn send_msg(
})
.unwrap_or(content);

let _ = query_open_ai(chat.map_reader(|chat| chat.info()), bot_id, text, |delta| {
let res = query_open_ai(chat.map_reader(|chat| chat.info()), bot_id, text, |delta| {
update_msg(MsgAction::Receiving(MsgBody::Text(Some(delta))));
})
.await;

if let Err(e) = res {
update_msg(MsgAction::Receiving(MsgBody::Text(Some(format!(
"Error: {}",
e
)))));
}
update_msg(MsgAction::Fulfilled);
});
}
Expand Down

0 comments on commit 35bf7aa

Please sign in to comment.