From 931b717eec1804117909be86796f5f44b1d92f69 Mon Sep 17 00:00:00 2001 From: sigoden Date: Wed, 4 Dec 2024 20:54:38 +0800 Subject: [PATCH] update --- src/client/common.rs | 5 +++-- src/client/model.rs | 10 ++++++---- 2 files changed, 9 insertions(+), 6 deletions(-) diff --git a/src/client/common.rs b/src/client/common.rs index 00307eaa..e0ec8619 100644 --- a/src/client/common.rs +++ b/src/client/common.rs @@ -150,16 +150,17 @@ pub trait Client: Sync + Send { } fn patch_request_data(&self, request_data: &mut RequestData) { + let model_type = self.model().model_type(); let map = std::env::var(get_env_name(&format!( "patch_{}_{}", self.model().client_name(), - self.model().model_type().api_name(), + model_type.api_name(), ))) .ok() .and_then(|v| serde_json::from_str(&v).ok()) .or_else(|| { self.patch_config() - .and_then(|v| self.model().model_type().extract_patch(v)) + .and_then(|v| model_type.extract_patch(v)) .cloned() }); let map = match map { diff --git a/src/client/model.rs b/src/client/model.rs index c3fee566..e0f9d291 100644 --- a/src/client/model.rs +++ b/src/client/model.rs @@ -104,10 +104,12 @@ impl Model { } pub fn model_type(&self) -> ModelType { - match self.data.model_type.as_str() { - "embed" | "embedding" => ModelType::Embedding, - "rerank" | "reranker" => ModelType::Reranker, - _ => ModelType::Chat, + if self.data.model_type.starts_with("embed") { + return ModelType::Embedding; + } else if self.data.model_type.starts_with("rerank") { + return ModelType::Reranker; + } else { + ModelType::Chat } }