refactor: improve retrieve model (#1036)
- check the model type while retrieving a model (see the sketch below the changed-files summary)
- select a chat/reranker model even if it is missing from the client's models
- find predefined models for the openai-compatible client with startsWith
- remove client::ApiType
sigoden authored Dec 4, 2024
1 parent 7d42fe9 commit 3a33883
Showing 12 changed files with 155 additions and 118 deletions.
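
The first two bullets describe the headline change: a model is now retrieved together with a check that it has the expected type, and a chat/reranker model can still be selected even when the client's configured models omit it. The sketch below is only a minimal, self-contained illustration of that pattern; the names (ModelType, Model, retrieve_model) and shapes are simplified assumptions rather than aichat's actual API, and the real changes are spread across the 12 files summarized above.

// Illustrative sketch only (hypothetical names, not the actual aichat code).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum ModelType {
    Chat,
    Embedding,
    Reranker,
}

#[derive(Debug, Clone)]
struct Model {
    name: String,
    model_type: ModelType,
}

/// Retrieve a model by name, verifying that it has the expected type.
/// If the name is not among the known models, fall back to a model of the
/// requested type instead of failing outright (commit bullet 2).
fn retrieve_model(models: &[Model], name: &str, expected: ModelType) -> Result<Model, String> {
    match models.iter().find(|m| m.name == name) {
        Some(m) if m.model_type == expected => Ok(m.clone()),
        Some(m) => Err(format!(
            "model '{}' is a {:?} model, expected {:?}",
            m.name, m.model_type, expected
        )),
        None => Ok(Model {
            name: name.to_string(),
            model_type: expected,
        }),
    }
}
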
4 changes: 2 additions & 2 deletions src/client/bedrock.rs
@@ -67,7 +67,7 @@ impl BedrockClient {
let body = build_chat_completions_body(data, &self.model)?;

let mut request_data = RequestData::new("", body);
-self.patch_request_data(&mut request_data, ApiType::ChatCompletions);
+self.patch_request_data(&mut request_data);
let RequestData {
url: _,
headers,
@@ -118,7 +118,7 @@ impl BedrockClient {
});

let mut request_data = RequestData::new("", body);
-self.patch_request_data(&mut request_data, ApiType::Embeddings);
+self.patch_request_data(&mut request_data);
let RequestData {
url: _,
headers,
38 changes: 7 additions & 31 deletions src/client/common.rs
@@ -19,7 +19,7 @@ use tokio::sync::mpsc::unbounded_channel;
const MODELS_YAML: &str = include_str!("../../models.yaml");

lazy_static::lazy_static! {
-pub static ref ALL_MODELS: Vec<BuiltinModels> = serde_yaml::from_str(MODELS_YAML).unwrap();
+pub static ref ALL_PREDEFINED_MODELS: Vec<PredefinedModels> = serde_yaml::from_str(MODELS_YAML).unwrap();
static ref ESCAPE_SLASH_RE: Regex = Regex::new(r"(?<!\\)/").unwrap();
}

@@ -144,23 +144,23 @@ pub trait Client: Sync + Send {
&self,
client: &reqwest::Client,
mut request_data: RequestData,
-api_type: ApiType,
) -> RequestBuilder {
-self.patch_request_data(&mut request_data, api_type);
+self.patch_request_data(&mut request_data);
request_data.into_builder(client)
}

-fn patch_request_data(&self, request_data: &mut RequestData, api_type: ApiType) {
+fn patch_request_data(&self, request_data: &mut RequestData) {
+let model_type = self.model().model_type();
let map = std::env::var(get_env_name(&format!(
"patch_{}_{}",
self.model().client_name(),
-api_type.name(),
+model_type.api_name(),
)))
.ok()
.and_then(|v| serde_json::from_str(&v).ok())
.or_else(|| {
self.patch_config()
-.and_then(|v| api_type.extract_patch(v))
+.and_then(|v| model_type.extract_patch(v))
.cloned()
});
let map = match map {
@@ -200,30 +200,6 @@ pub struct RequestPatch {

pub type ApiPatch = IndexMap<String, Value>;

-#[derive(Debug, Clone, Copy, PartialEq, Eq)]
-pub enum ApiType {
-ChatCompletions,
-Embeddings,
-Rerank,
-}
-
-impl ApiType {
-pub fn name(&self) -> &str {
-match self {
-ApiType::ChatCompletions => "chat_completions",
-ApiType::Embeddings => "embeddings",
-ApiType::Rerank => "rerank",
-}
-}
-pub fn extract_patch<'a>(&self, patch: &'a RequestPatch) -> Option<&'a ApiPatch> {
-match self {
-ApiType::ChatCompletions => patch.chat_completions.as_ref(),
-ApiType::Embeddings => patch.embeddings.as_ref(),
-ApiType::Rerank => patch.rerank.as_ref(),
-}
-}
-}

pub struct RequestData {
pub url: String,
pub headers: IndexMap<String, String>,
@@ -383,7 +359,7 @@ pub fn create_openai_compatible_client_config(client: &str) -> Result<Option<(St
config["api_base"] = api_base.into();
}
prompts.push(("api_key", "API Key:", false, PromptKind::String));
-if !ALL_MODELS.iter().any(|v| v.platform == name) {
+if !ALL_PREDEFINED_MODELS.iter().any(|v| v.platform == name) {
prompts.extend([
("models[].name", "Model Name:", true, PromptKind::String),
(
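
In the hunks above, patch_request_data no longer takes a per-call ApiType: it derives both the env-var suffix (model_type.api_name()) and the patch section (model_type.extract_patch(...)) from the model's own type, and the ApiType enum is deleted outright. The ModelType side of that change is not part of this excerpt; assuming it simply mirrors the removed ApiType, it would look roughly like the sketch below (illustrative only, not the actual source; RequestPatch and ApiPatch are the types defined in this same file).

// Illustrative sketch of a ModelType mirroring the removed ApiType, as implied
// by the calls to model_type.api_name() and model_type.extract_patch(v) above.
// Variant names are assumptions, not taken from the repository.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ModelType {
    Chat,
    Embedding,
    Reranker,
}

impl ModelType {
    // Used when building the patch env var name: patch_{client}_{api_name}.
    pub fn api_name(&self) -> &str {
        match self {
            ModelType::Chat => "chat_completions",
            ModelType::Embedding => "embeddings",
            ModelType::Reranker => "rerank",
        }
    }

    // Selects the matching section of a RequestPatch for this model type.
    pub fn extract_patch<'a>(&self, patch: &'a RequestPatch) -> Option<&'a ApiPatch> {
        match self {
            ModelType::Chat => patch.chat_completions.as_ref(),
            ModelType::Embedding => patch.embeddings.as_ref(),
            ModelType::Reranker => patch.rerank.as_ref(),
        }
    }
}
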
8 changes: 4 additions & 4 deletions src/client/ernie.rs
@@ -41,7 +41,7 @@ impl Client for ErnieClient {
) -> Result<ChatCompletionsOutput> {
prepare_access_token(self, client).await?;
let request_data = prepare_chat_completions(self, data)?;
-let builder = self.request_builder(client, request_data, ApiType::ChatCompletions);
+let builder = self.request_builder(client, request_data);
chat_completions(builder, &self.model).await
}

@@ -53,7 +53,7 @@ impl Client for ErnieClient {
) -> Result<()> {
prepare_access_token(self, client).await?;
let request_data = prepare_chat_completions(self, data)?;
-let builder = self.request_builder(client, request_data, ApiType::ChatCompletions);
+let builder = self.request_builder(client, request_data);
chat_completions_streaming(builder, handler, &self.model).await
}

@@ -64,7 +64,7 @@
) -> Result<EmbeddingsOutput> {
prepare_access_token(self, client).await?;
let request_data = prepare_embeddings(self, data)?;
-let builder = self.request_builder(client, request_data, ApiType::Embeddings);
+let builder = self.request_builder(client, request_data);
embeddings(builder, &self.model).await
}

@@ -75,7 +75,7 @@
) -> Result<RerankOutput> {
prepare_access_token(self, client).await?;
let request_data = prepare_rerank(self, data)?;
-let builder = self.request_builder(client, request_data, ApiType::Rerank);
+let builder = self.request_builder(client, request_data);
rerank(builder, &self.model).await
}
}
45 changes: 27 additions & 18 deletions src/client/macros.rs
@@ -52,9 +52,10 @@ macro_rules! register_client {
pub fn list_models(local_config: &$config) -> Vec<Model> {
let client_name = Self::name(local_config);
if local_config.models.is_empty() {
-if let Some(models) = $crate::client::ALL_MODELS.iter().find(|v| {
+if let Some(models) = $crate::client::ALL_PREDEFINED_MODELS.iter().find(|v| {
v.platform == $name ||
-($name == OpenAICompatibleClient::NAME && local_config.name.as_deref() == Some(&v.platform))
+($name == OpenAICompatibleClient::NAME
+&& local_config.name.as_ref().map(|name| name.starts_with(&v.platform)).unwrap_or_default())
}) {
return Model::from_config(client_name, &models.models);
}
@@ -98,32 +99,40 @@ macro_rules! register_client {
anyhow::bail!("Unknown client '{}'", client)
}

-static ALL_CLIENT_MODELS: std::sync::OnceLock<Vec<$crate::client::Model>> = std::sync::OnceLock::new();
+static ALL_CLIENT_NAMES: std::sync::OnceLock<Vec<String>> = std::sync::OnceLock::new();

-pub fn list_models(config: &$crate::config::Config) -> Vec<&'static $crate::client::Model> {
-let models = ALL_CLIENT_MODELS.get_or_init(|| {
+pub fn list_client_names(config: &$crate::config::Config) -> Vec<&'static String> {
+let names = ALL_CLIENT_NAMES.get_or_init(|| {
config
.clients
.iter()
.flat_map(|v| match v {
-$(ClientConfig::$config(c) => $client::list_models(c),)+
+$(ClientConfig::$config(c) => vec![$client::name(c).to_string()],)+
ClientConfig::Unknown => vec![],
})
.collect()
});
-models.iter().collect()
+names.iter().collect()
}

-pub fn list_chat_models(config: &$crate::config::Config) -> Vec<&'static $crate::client::Model> {
-list_models(config).into_iter().filter(|v| v.model_type() == "chat").collect()
-}
+static ALL_MODELS: std::sync::OnceLock<Vec<$crate::client::Model>> = std::sync::OnceLock::new();

-pub fn list_embedding_models(config: &$crate::config::Config) -> Vec<&'static $crate::client::Model> {
-list_models(config).into_iter().filter(|v| v.model_type() == "embedding").collect()
+pub fn list_all_models(config: &$crate::config::Config) -> Vec<&'static $crate::client::Model> {
+let models = ALL_MODELS.get_or_init(|| {
+config
+.clients
+.iter()
+.flat_map(|v| match v {
+$(ClientConfig::$config(c) => $client::list_models(c),)+
+ClientConfig::Unknown => vec![],
+})
+.collect()
+});
+models.iter().collect()
}

-pub fn list_reranker_models(config: &$crate::config::Config) -> Vec<&'static $crate::client::Model> {
-list_models(config).into_iter().filter(|v| v.model_type() == "reranker").collect()
+pub fn list_models(config: &$crate::config::Config, model_type: $crate::client::ModelType) -> Vec<&'static $crate::client::Model> {
+list_all_models(config).into_iter().filter(|v| v.model_type() == model_type).collect()
}
};
}
@@ -175,7 +184,7 @@ macro_rules! impl_client_trait {
data: $crate::client::ChatCompletionsData,
) -> anyhow::Result<$crate::client::ChatCompletionsOutput> {
let request_data = $prepare_chat_completions(self, data)?;
-let builder = self.request_builder(client, request_data, ApiType::ChatCompletions);
+let builder = self.request_builder(client, request_data);
$chat_completions(builder, self.model()).await
}

@@ -186,7 +195,7 @@
data: $crate::client::ChatCompletionsData,
) -> Result<()> {
let request_data = $prepare_chat_completions(self, data)?;
-let builder = self.request_builder(client, request_data, ApiType::ChatCompletions);
+let builder = self.request_builder(client, request_data);
$chat_completions_streaming(builder, handler, self.model()).await
}

@@ -196,7 +205,7 @@
data: &$crate::client::EmbeddingsData,
) -> Result<$crate::client::EmbeddingsOutput> {
let request_data = $prepare_embeddings(self, data)?;
-let builder = self.request_builder(client, request_data, ApiType::Embeddings);
+let builder = self.request_builder(client, request_data);
$embeddings(builder, self.model()).await
}

@@ -206,7 +215,7 @@
data: &$crate::client::RerankData,
) -> Result<$crate::client::RerankOutput> {
let request_data = $prepare_rerank(self, data)?;
-let builder = self.request_builder(client, request_data, ApiType::Rerank);
+let builder = self.request_builder(client, request_data);
$rerank(builder, self.model()).await
}
}
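
Two things change in src/client/macros.rs: predefined models for the openai-compatible client are now matched with starts_with against the platform name instead of exact equality, and the separate list_chat_models / list_embedding_models / list_reranker_models helpers collapse into a single list_models(config, model_type) filter over list_all_models. The snippet below only illustrates the matching behavior; the names "openrouter" and "openrouter-free" are made-up examples, and the ModelType::Chat variant mentioned in the comment is an assumption, not code from the repository.

// Illustrative only: shows why starts_with matching is more permissive than
// the old exact comparison for openai-compatible client names.
fn matches_predefined(platform: &str, client_name: Option<&str>) -> bool {
    client_name
        .map(|name| name.starts_with(platform))
        .unwrap_or_default()
}

fn main() {
    // An exact name still matches...
    assert!(matches_predefined("openrouter", Some("openrouter")));
    // ...and so does a suffixed variant, which the old `==` comparison rejected.
    assert!(matches_predefined("openrouter", Some("openrouter-free")));
    assert!(!matches_predefined("openrouter", None));

    // The old per-type helpers are replaced by a single call along the lines of
    //     list_models(&config, ModelType::Chat)
    // (signature as in the diff above; the variant name is assumed).
}
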
(Diffs for the remaining 8 changed files are not shown here.)
