diff --git a/src/app.rs b/src/app.rs
index 4d3f658..3eda5bc 100644
--- a/src/app.rs
+++ b/src/app.rs
@@ -15,7 +15,7 @@ use crate::{
     database::{DatabaseManager, FileDatabaseProvider, InMemDatabaseProvider},
     dispatcher::build_dispatcher,
     module_mgr::ModuleManager,
-    modules::{admin::Admin, chat::Chat, prefs::Prefs, stats::Stats},
+    modules::{admin::Admin, chat::Chat, openai::OpenAI, prefs::Prefs, stats::Stats},
     types::HandlerResult,
 };
 
@@ -55,6 +55,7 @@ pub async fn run(config: SharedConfig) {
     debug!("Initializing modules...");
     let mut module_mgr = ModuleManager::new();
     module_mgr.register_module(crate::modules::config::Config::new(config.clone()));
+    module_mgr.register_module(OpenAI);
     module_mgr.register_module(Prefs::new(db_mgr.clone()));
     module_mgr.register_module(Admin::new(db_mgr.clone()));
     module_mgr.register_module(Stats::new(db_mgr.clone()));
diff --git a/src/config.rs b/src/config.rs
index b552420..4d663cf 100644
--- a/src/config.rs
+++ b/src/config.rs
@@ -86,6 +86,11 @@ pub struct Config {
     #[serde(default = "default_conversation_limit", rename = "conversationLimit")]
     pub conversation_limit: u64,
 
+    /// The maximum number of tokens allowed for the generated answer.
+    /// JSON key: `maxTokens`
+    #[serde(default, rename = "maxTokens")]
+    pub max_tokens: Option<u16>,
+
     /// A boolean value that indicates whether to parse and render the
     /// markdown contents. When set to `false`, the raw contents returned
     /// from OpenAI will be displayed. This is default to `false`.
diff --git a/src/modules/chat/mod.rs b/src/modules/chat/mod.rs
index 874463b..b299f00 100644
--- a/src/modules/chat/mod.rs
+++ b/src/modules/chat/mod.rs
@@ -2,7 +2,6 @@
 
 mod braille;
 mod markdown;
-mod openai_client;
 mod session;
 mod session_mgr;
 
@@ -11,7 +10,6 @@ use std::time::Duration;
 
 use anyhow::Error;
 use async_openai::types::{ChatCompletionRequestMessage, ChatCompletionRequestMessageArgs, Role};
-use async_openai::Client as OpenAIClient;
 use futures::StreamExt as FuturesStreamExt;
 use teloxide::dispatching::DpHandlerDescription;
 use teloxide::dptree::di::DependencySupplier;
@@ -22,12 +20,12 @@ use crate::{
     config::SharedConfig,
     dispatcher::noop_handler,
     module_mgr::{Command, Module},
+    modules::openai::{ChatModelResult, OpenAIClient},
     modules::{admin::MemberManager, stats::StatsManager},
     types::HandlerResult,
     utils::StreamExt,
 };
 use braille::BrailleProgress;
-use openai_client::ChatModelResult;
 pub(crate) use session::Session;
 pub(crate) use session_mgr::SessionManager;
 
@@ -328,9 +326,9 @@ async fn stream_model_result(
     openai_client: OpenAIClient,
     config: &SharedConfig,
 ) -> Result<ChatModelResult, Error> {
-    let estimated_prompt_tokens = openai_client::estimate_prompt_tokens(&msgs);
+    let estimated_prompt_tokens = openai_client.estimate_prompt_tokens(&msgs);
 
-    let stream = openai_client::request_chat_model(&openai_client, msgs).await?;
+    let stream = openai_client.request_chat_model(msgs).await?;
     let mut throttled_stream =
         stream.throttle_buffer::<Vec<ChatModelResult>>(Duration::from_millis(config.stream_throttle_interval));
 
@@ -379,7 +377,7 @@ async fn stream_model_result(
     // TODO: OpenAI currently doesn't support to give the token usage
     // in stream mode. Therefore we need to estimate it locally.
     last_response.token_usage =
-        openai_client::estimate_tokens(&last_response.content) + estimated_prompt_tokens;
+        openai_client.estimate_tokens(&last_response.content) + estimated_prompt_tokens;
 
     return Ok(last_response);
 }
@@ -407,7 +405,6 @@ impl Module for Chat {
         let config: Arc<SharedConfig> = dep_map.get();
 
         dep_map.insert(SessionManager::new(config.as_ref().clone()));
-        dep_map.insert(openai_client::new_client(&config.openai_api_key));
 
         Ok(())
     }
diff --git a/src/modules/chat/openai_client.rs b/src/modules/chat/openai_client.rs
deleted file mode 100644
index 45b4026..0000000
--- a/src/modules/chat/openai_client.rs
+++ /dev/null
@@ -1,57 +0,0 @@
-use std::pin::Pin;
-
-use anyhow::Error;
-use async_openai::types::{ChatCompletionRequestMessage, CreateChatCompletionRequestArgs};
-use async_openai::Client as OpenAIClient;
-use futures::{future, Stream, StreamExt};
-
-pub(crate) type ChatModelStream = Pin<Box<dyn Stream<Item = ChatModelResult> + Send>>;
-
-#[derive(Clone, Debug, Default, Eq, PartialEq)]
-pub(crate) struct ChatModelResult {
-    pub content: String,
-    pub token_usage: u32,
-}
-
-pub(crate) fn new_client(api_key: &str) -> OpenAIClient {
-    OpenAIClient::new().with_api_key(api_key)
-}
-
-pub(crate) async fn request_chat_model(
-    client: &OpenAIClient,
-    msgs: Vec<ChatCompletionRequestMessage>,
-) -> Result<ChatModelStream, Error> {
-    let req = CreateChatCompletionRequestArgs::default()
-        .model("gpt-3.5-turbo")
-        .temperature(0.6)
-        .messages(msgs)
-        .build()?;
-
-    let stream = client.chat().create_stream(req).await?;
-    Ok(stream
-        .scan(ChatModelResult::default(), |acc, cur| {
-            let content = cur
-                .as_ref()
-                .ok()
-                .and_then(|resp| resp.choices.first())
-                .and_then(|choice| choice.delta.content.as_ref());
-            if let Some(content) = content {
-                acc.content.push_str(content);
-            }
-            future::ready(Some(acc.clone()))
-        })
-        .boxed())
-}
-
-pub(crate) fn estimate_prompt_tokens(msgs: &Vec<ChatCompletionRequestMessage>) -> u32 {
-    let mut text_len = 0;
-    for msg in msgs {
-        text_len += msg.content.len();
-    }
-    ((text_len as f64) * 1.4) as _
-}
-
-pub(crate) fn estimate_tokens(text: &str) -> u32 {
-    let text_len = text.len();
-    ((text_len as f64) * 1.4) as _
-}
diff --git a/src/modules/mod.rs b/src/modules/mod.rs
index c25dc65..cf0d37b 100644
--- a/src/modules/mod.rs
+++ b/src/modules/mod.rs
@@ -3,5 +3,6 @@
 pub(crate) mod admin;
 pub(crate) mod chat;
 pub(crate) mod config;
+pub(crate) mod openai;
 pub(crate) mod prefs;
 pub(crate) mod stats;
diff --git a/src/modules/openai.rs b/src/modules/openai.rs
new file mode 100644
index 0000000..43c1b0c
--- /dev/null
+++ b/src/modules/openai.rs
@@ -0,0 +1,84 @@
+use std::pin::Pin;
+use std::sync::Arc;
+
+use anyhow::Error;
+use async_openai::types::{ChatCompletionRequestMessage, CreateChatCompletionRequestArgs};
+use async_openai::Client;
+use futures::{future, Stream, StreamExt};
+use teloxide::dptree::di::{DependencyMap, DependencySupplier};
+
+use crate::{config::SharedConfig, module_mgr::Module};
+
+pub(crate) type ChatModelStream = Pin<Box<dyn Stream<Item = ChatModelResult> + Send>>;
+
+#[derive(Clone, Debug, Default, Eq, PartialEq)]
+pub(crate) struct ChatModelResult {
+    pub content: String,
+    pub token_usage: u32,
+}
+
+#[derive(Clone)]
+pub(crate) struct OpenAIClient {
+    client: Client,
+    config: SharedConfig,
+}
+
+impl OpenAIClient {
+    pub(crate) async fn request_chat_model(
+        &self,
+        msgs: Vec<ChatCompletionRequestMessage>,
+    ) -> Result<ChatModelStream, Error> {
+        let client = &self.client;
+        let req = CreateChatCompletionRequestArgs::default()
+            .model("gpt-3.5-turbo")
+            .temperature(0.6)
+            .max_tokens(self.config.max_tokens.unwrap_or(4096))
+            .messages(msgs)
+            .build()?;
+
+        let stream = client.chat().create_stream(req).await?;
+        Ok(stream
+            .scan(ChatModelResult::default(), |acc, cur| {
+                let content = cur
+                    .as_ref()
+                    .ok()
+                    .and_then(|resp| resp.choices.first())
+                    .and_then(|choice| choice.delta.content.as_ref());
+                if let Some(content) = content {
+                    acc.content.push_str(content);
+                }
+                future::ready(Some(acc.clone()))
+            })
+            .boxed())
+    }
+
+    pub(crate) fn estimate_prompt_tokens(&self, msgs: &Vec<ChatCompletionRequestMessage>) -> u32 {
+        let mut text_len = 0;
+        for msg in msgs {
+            text_len += msg.content.len();
+        }
+        ((text_len as f64) * 1.4) as _
+    }
+
+    pub(crate) fn estimate_tokens(&self, text: &str) -> u32 {
+        let text_len = text.len();
+        ((text_len as f64) * 1.4) as _
+    }
+}
+
+pub(crate) struct OpenAI;
+
+#[async_trait]
+impl Module for OpenAI {
+    async fn register_dependency(&mut self, dep_map: &mut DependencyMap) -> Result<(), Error> {
+        let config: Arc<SharedConfig> = dep_map.get();
+
+        let openai_client = OpenAIClient {
+            client: Client::new().with_api_key(&config.openai_api_key),
+            config: config.as_ref().clone(),
+        };
+        dep_map.insert(openai_client);
+
+        Ok(())
+    }
+}
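Note on the token accounting above: because the streaming API does not report usage (see the TODO in `stream_model_result`), `OpenAIClient::estimate_prompt_tokens` and `estimate_tokens` approximate the count as 1.4 tokens per byte of content, summing the prompt messages first. A minimal standalone sketch of that heuristic follows; the free function and the sample strings are illustrative only and not part of this change.

// Sketch of the length-based heuristic used by OpenAIClient::estimate_tokens
// and estimate_prompt_tokens: roughly 1.4 estimated tokens per byte of text.
fn estimate_tokens(text: &str) -> u32 {
    ((text.len() as f64) * 1.4) as u32
}

fn main() {
    // "hello world" is 11 bytes, so the estimate is 11 * 1.4 = 15.4 -> 15.
    assert_eq!(estimate_tokens("hello world"), 15);

    // Prompt estimation sums the byte lengths of all message contents first,
    // then applies the same factor: (28 + 3) * 1.4 = 43.4 -> 43.
    let messages = ["You are a helpful assistant.", "Hi!"];
    let total_len: usize = messages.iter().map(|m| m.len()).sum();
    assert_eq!(((total_len as f64) * 1.4) as u32, 43);
}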