extract OpenAI client as a separate module & add max_tokens setting
unixzii authored and ktiays committed Mar 17, 2023
1 parent ec5da5d commit d0386f3
Showing 6 changed files with 96 additions and 65 deletions.
3 changes: 2 additions & 1 deletion src/app.rs
@@ -15,7 +15,7 @@ use crate::{
database::{DatabaseManager, FileDatabaseProvider, InMemDatabaseProvider},
dispatcher::build_dispatcher,
module_mgr::ModuleManager,
- modules::{admin::Admin, chat::Chat, prefs::Prefs, stats::Stats},
+ modules::{admin::Admin, chat::Chat, openai::OpenAI, prefs::Prefs, stats::Stats},
types::HandlerResult,
};

@@ -55,6 +55,7 @@ pub async fn run(config: SharedConfig) {
debug!("Initializing modules...");
let mut module_mgr = ModuleManager::new();
module_mgr.register_module(crate::modules::config::Config::new(config.clone()));
+ module_mgr.register_module(OpenAI);
module_mgr.register_module(Prefs::new(db_mgr.clone()));
module_mgr.register_module(Admin::new(db_mgr.clone()));
module_mgr.register_module(Stats::new(db_mgr.clone()));
5 changes: 5 additions & 0 deletions src/config.rs
@@ -86,6 +86,11 @@ pub struct Config {
#[serde(default = "default_conversation_limit", rename = "conversationLimit")]
pub conversation_limit: u64,

+ /// The maximum number of tokens allowed for the generated answer.
+ /// JSON key: `maxTokens`
+ #[serde(default, rename = "maxTokens")]
+ pub max_tokens: Option<u16>,

/// A boolean value that indicates whether to parse and render the
/// markdown contents. When set to `false`, the raw contents returned
/// from OpenAI will be displayed. This is default to `false`.
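For reference, the new field is optional in the bot's JSON configuration and is exposed under the camelCase key maxTokens. A minimal round-trip sketch of the intended serde behavior, using a hypothetical stripped-down PartialConfig struct rather than the real Config and assuming serde_json is available:

use serde::Deserialize;

#[derive(Deserialize)]
struct PartialConfig {
    #[serde(default, rename = "maxTokens")]
    max_tokens: Option<u16>,
}

fn main() {
    // Key present in the config file: picked up as Some(1024).
    let with_limit: PartialConfig =
        serde_json::from_str(r#"{ "maxTokens": 1024 }"#).unwrap();
    assert_eq!(with_limit.max_tokens, Some(1024));

    // Key omitted entirely: the field falls back to None, so existing
    // config files keep working without changes.
    let without_limit: PartialConfig = serde_json::from_str("{}").unwrap();
    assert_eq!(without_limit.max_tokens, None);
}
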
11 changes: 4 additions & 7 deletions src/modules/chat/mod.rs
@@ -2,7 +2,6 @@

mod braille;
mod markdown;
- mod openai_client;
mod session;
mod session_mgr;

@@ -11,7 +10,6 @@ use std::time::Duration;

use anyhow::Error;
use async_openai::types::{ChatCompletionRequestMessage, ChatCompletionRequestMessageArgs, Role};
- use async_openai::Client as OpenAIClient;
use futures::StreamExt as FuturesStreamExt;
use teloxide::dispatching::DpHandlerDescription;
use teloxide::dptree::di::DependencySupplier;
@@ -22,12 +20,12 @@ use crate::{
config::SharedConfig,
dispatcher::noop_handler,
module_mgr::{Command, Module},
+ modules::openai::{ChatModelResult, OpenAIClient},
modules::{admin::MemberManager, stats::StatsManager},
types::HandlerResult,
utils::StreamExt,
};
use braille::BrailleProgress;
- use openai_client::ChatModelResult;
pub(crate) use session::Session;
pub(crate) use session_mgr::SessionManager;

@@ -328,9 +326,9 @@ async fn stream_model_result(
openai_client: OpenAIClient,
config: &SharedConfig,
) -> Result<ChatModelResult, Error> {
- let estimated_prompt_tokens = openai_client::estimate_prompt_tokens(&msgs);
+ let estimated_prompt_tokens = openai_client.estimate_prompt_tokens(&msgs);

- let stream = openai_client::request_chat_model(&openai_client, msgs).await?;
+ let stream = openai_client.request_chat_model(msgs).await?;
let mut throttled_stream =
stream.throttle_buffer::<Vec<_>>(Duration::from_millis(config.stream_throttle_interval));

@@ -379,7 +377,7 @@ async fn stream_model_result(
// TODO: OpenAI currently doesn't support to give the token usage
// in stream mode. Therefore we need to estimate it locally.
last_response.token_usage =
- openai_client::estimate_tokens(&last_response.content) + estimated_prompt_tokens;
+ openai_client.estimate_tokens(&last_response.content) + estimated_prompt_tokens;

return Ok(last_response);
}
@@ -407,7 +405,6 @@ impl Module for Chat {
let config: Arc<SharedConfig> = dep_map.get();

dep_map.insert(SessionManager::new(config.as_ref().clone()));
- dep_map.insert(openai_client::new_client(&config.openai_api_key));

Ok(())
}
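With this change the chat handlers no longer build their own client; the shared OpenAIClient is registered once by the new OpenAI module and fetched from the dependency map, in the same way the Chat module already fetches SharedConfig. A rough sketch of that retrieval pattern (the get_openai_client helper is illustrative, not part of the commit):

use std::sync::Arc;

use teloxide::dptree::di::{DependencyMap, DependencySupplier};

use crate::modules::openai::OpenAIClient;

fn get_openai_client(dep_map: &DependencyMap) -> Arc<OpenAIClient> {
    // The value was inserted by OpenAI::register_dependency; dptree's DI
    // returns it as an Arc, just like the SharedConfig lookup above.
    dep_map.get()
}
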
57 changes: 0 additions & 57 deletions src/modules/chat/openai_client.rs

This file was deleted.

1 change: 1 addition & 0 deletions src/modules/mod.rs
@@ -3,5 +3,6 @@
pub(crate) mod admin;
pub(crate) mod chat;
pub(crate) mod config;
+ pub(crate) mod openai;
pub(crate) mod prefs;
pub(crate) mod stats;
84 changes: 84 additions & 0 deletions src/modules/openai.rs
@@ -0,0 +1,84 @@
use std::pin::Pin;
use std::sync::Arc;

use anyhow::Error;
use async_openai::types::{ChatCompletionRequestMessage, CreateChatCompletionRequestArgs};
use async_openai::Client;
use futures::{future, Stream, StreamExt};
use teloxide::dptree::di::{DependencyMap, DependencySupplier};

use crate::{config::SharedConfig, module_mgr::Module};

pub(crate) type ChatModelStream = Pin<Box<dyn Stream<Item = ChatModelResult> + Send>>;

#[derive(Clone, Debug, Default, Eq, PartialEq)]
pub(crate) struct ChatModelResult {
pub content: String,
pub token_usage: u32,
}

#[derive(Clone)]
pub(crate) struct OpenAIClient {
client: Client,
config: SharedConfig,
}

impl OpenAIClient {
pub(crate) async fn request_chat_model(
&self,
msgs: Vec<ChatCompletionRequestMessage>,
) -> Result<ChatModelStream, Error> {
let client = &self.client;
let req = CreateChatCompletionRequestArgs::default()
.model("gpt-3.5-turbo")
.temperature(0.6)
.max_tokens(self.config.max_tokens.unwrap_or(4096))
.messages(msgs)
.build()?;

let stream = client.chat().create_stream(req).await?;
Ok(stream
.scan(ChatModelResult::default(), |acc, cur| {
let content = cur
.as_ref()
.ok()
.and_then(|resp| resp.choices.first())
.and_then(|choice| choice.delta.content.as_ref());
if let Some(content) = content {
acc.content.push_str(content);
}
future::ready(Some(acc.clone()))
})
.boxed())
}

pub(crate) fn estimate_prompt_tokens(&self, msgs: &Vec<ChatCompletionRequestMessage>) -> u32 {
let mut text_len = 0;
for msg in msgs {
text_len += msg.content.len();
}
((text_len as f64) * 1.4) as _
}

pub(crate) fn estimate_tokens(&self, text: &str) -> u32 {
let text_len = text.len();
((text_len as f64) * 1.4) as _
}
}

pub(crate) struct OpenAI;

#[async_trait]
impl Module for OpenAI {
async fn register_dependency(&mut self, dep_map: &mut DependencyMap) -> Result<(), Error> {
let config: Arc<SharedConfig> = dep_map.get();

let openai_client = OpenAIClient {
client: Client::new().with_api_key(&config.openai_api_key),
config: config.as_ref().clone(),
};
dep_map.insert(openai_client);

Ok(())
}
}
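
For orientation, a consumer of the new module sees a cumulative stream: every item is a ChatModelResult whose content grows as deltas arrive, so only the last item needs to be kept, and token_usage is estimated locally rather than reported by the API. A usage sketch under that assumption (the demo function and msgs argument are placeholders; the real call site is stream_model_result in src/modules/chat/mod.rs above):

use anyhow::Error;
use async_openai::types::ChatCompletionRequestMessage;
use futures::StreamExt;

use crate::modules::openai::OpenAIClient;

async fn demo(client: OpenAIClient, msgs: Vec<ChatCompletionRequestMessage>) -> Result<(), Error> {
    // Estimate prompt tokens before msgs is moved into the request.
    let prompt_tokens = client.estimate_prompt_tokens(&msgs);

    let mut stream = client.request_chat_model(msgs).await?;
    let mut last = None;
    while let Some(partial) = stream.next().await {
        // Each item already contains the full text accumulated so far.
        last = Some(partial);
    }

    if let Some(mut result) = last {
        // The stream does not fill token_usage; estimate it as chat/mod.rs does.
        result.token_usage = client.estimate_tokens(&result.content) + prompt_tokens;
        println!("answer ({} estimated tokens): {}", result.token_usage, result.content);
    }
    Ok(())
}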
