From 7c3d1cf3a877f139f3d248df269fecc2fd356ce2 Mon Sep 17 00:00:00 2001
From: Philpax
Date: Sun, 12 Nov 2023 23:08:14 +0100
Subject: [PATCH] chore: fix precommit

---
 crates/ggml/src/format/gguf/metadata.rs   |  8 ++++++--
 crates/ggml/src/format/gguf/mod.rs        |  2 +-
 crates/llm-base/src/loader.rs             |  5 ++---
 crates/llm-base/src/tokenizer/embedded.rs |  1 +
 crates/llm-base/src/tokenizer/mod.rs      | 24 +++++++++++-------------
 crates/llm/Cargo.toml                     |  2 +-
 crates/llm/examples/embeddings.rs         |  2 +-
 crates/llm/src/lib.rs                     | 19 ++++++++++---------
 crates/llm/src/loader.rs                  |  4 ++--
 9 files changed, 35 insertions(+), 32 deletions(-)

diff --git a/crates/ggml/src/format/gguf/metadata.rs b/crates/ggml/src/format/gguf/metadata.rs
index 70e20347..11da1916 100644
--- a/crates/ggml/src/format/gguf/metadata.rs
+++ b/crates/ggml/src/format/gguf/metadata.rs
@@ -79,13 +79,13 @@ impl Metadata {
     // TODO: consider finding a way to automate getting with traits
     pub fn get_str(&self, key: &str) -> Result<&str, MetadataError> {
         let metadata_value = self.get(key)?;
-        Ok(metadata_value
+        metadata_value
             .as_string()
             .ok_or_else(|| MetadataError::InvalidType {
                 key: key.to_string(),
                 expected_type: MetadataValueType::String,
                 actual_type: metadata_value.value_type(),
-            })?)
+            })
     }
 
     pub fn get_countable(&self, key: &str) -> Result<usize, MetadataError> {
@@ -460,6 +460,10 @@ impl MetadataArrayValue {
             _ => None,
         }
     }
+
+    pub fn is_empty(&self) -> bool {
+        self.len() == 0
+    }
 }
 
 // Shared
diff --git a/crates/ggml/src/format/gguf/mod.rs b/crates/ggml/src/format/gguf/mod.rs
index c58e7276..dc2ab9dd 100644
--- a/crates/ggml/src/format/gguf/mod.rs
+++ b/crates/ggml/src/format/gguf/mod.rs
@@ -253,7 +253,7 @@ impl TensorInfo {
             util::write_length(writer, ctx.use_64_bit_length, *dimension)?;
         }
 
-        util::write_u32(writer, ggml_type::from(self.element_type) as u32)?;
+        util::write_u32(writer, ggml_type::from(self.element_type))?;
         util::write_u64(writer, self.offset)?;
 
         Ok(())
diff --git a/crates/llm-base/src/loader.rs b/crates/llm-base/src/loader.rs
index 52732155..58b4ab4d 100644
--- a/crates/llm-base/src/loader.rs
+++ b/crates/llm-base/src/loader.rs
@@ -39,8 +39,7 @@ impl TryFrom<llama_ftype> for FileType {
     type Error = llama_ftype;
 
     fn try_from(value: llama_ftype) -> Result<Self, Self::Error> {
-        let format =
-            FileTypeFormat::try_from(((value as u32) % ggml::QNT_VERSION_FACTOR) as llama_ftype)?;
+        let format = FileTypeFormat::try_from((value % ggml::QNT_VERSION_FACTOR) as llama_ftype)?;
 
         Ok(Self {
             format,
@@ -360,7 +359,7 @@ pub trait ModelFactory {
 /// This method returns a [`Box`], which means that the model will have single ownership.
 /// If you'd like to share ownership (i.e. to use the model in multiple threads), we
 /// suggest using [`Arc::from(Box<T>)`](https://doc.rust-lang.org/std/sync/struct.Arc.html#impl-From%3CBox%3CT,+Global%3E%3E-for-Arc%3CT%3E)
-/// to convert the [`Box`] into an [`Arc`](std::sync::Arc) after loading.
+/// to convert the [`Box`] into an [`Arc`] after loading.
 pub fn load(
     path: &Path,
     tokenizer_source: TokenizerSource,
diff --git a/crates/llm-base/src/tokenizer/embedded.rs b/crates/llm-base/src/tokenizer/embedded.rs
index 02acb4c3..25387d23 100644
--- a/crates/llm-base/src/tokenizer/embedded.rs
+++ b/crates/llm-base/src/tokenizer/embedded.rs
@@ -497,6 +497,7 @@ fn unescape_whitespace(text: &[u8]) -> Vec<u8> {
     let mut buffer: Vec<u8> = vec![];
 
     for &b in text {
+        #[allow(clippy::if_same_then_else)]
         if b == 0xE2 {
             // If the current byte is 0xE2, start buffering and check for the sequence.
             buffer.push(b);
diff --git a/crates/llm-base/src/tokenizer/mod.rs b/crates/llm-base/src/tokenizer/mod.rs
index afa8c9d6..9852993e 100644
--- a/crates/llm-base/src/tokenizer/mod.rs
+++ b/crates/llm-base/src/tokenizer/mod.rs
@@ -118,7 +118,7 @@ impl TokenizerSource {
                         tokenizer_source: HuggingFaceTokenizerErrorSource::Remote(
                             identifier.clone(),
                         ),
-                        error: error.into(),
+                        error,
                     }
                 })?,
             )
@@ -128,7 +128,7 @@ impl TokenizerSource {
                 tokenizers::Tokenizer::from_file(&path).map_err(|error| {
                     TokenizerLoadError::HuggingFaceTokenizerError {
                         tokenizer_source: HuggingFaceTokenizerErrorSource::File(path.clone()),
-                        error: error.into(),
+                        error,
                     }
                 })?,
             )
@@ -139,20 +139,18 @@ impl TokenizerSource {
             Self::Embedded => {
                 if let Ok(hf) = gguf.metadata.get_str("tokenizer.huggingface.json") {
                     Ok(Self::load_huggingface_json(hf)?)
-                } else {
-                    if EmbeddedTokenizer::is_present_in_metadata(&gguf.metadata) {
-                        if EMBEDDED_TOKENIZER_ENABLED {
-                            Ok(EmbeddedTokenizer::from_metadata(&gguf.metadata)?.into())
-                        } else {
-                            Err(TokenizerLoadError::NoSupportedTokenizersFound {
-                                unsupported_tokenizers: vec!["embedded".to_owned()],
-                            })
-                        }
+                } else if EmbeddedTokenizer::is_present_in_metadata(&gguf.metadata) {
+                    if EMBEDDED_TOKENIZER_ENABLED {
+                        Ok(EmbeddedTokenizer::from_metadata(&gguf.metadata)?.into())
                     } else {
                         Err(TokenizerLoadError::NoSupportedTokenizersFound {
-                            unsupported_tokenizers: vec![],
+                            unsupported_tokenizers: vec!["embedded".to_owned()],
                         })
                     }
+                } else {
+                    Err(TokenizerLoadError::NoSupportedTokenizersFound {
+                        unsupported_tokenizers: vec![],
+                    })
                 }
             }
         }
@@ -163,7 +161,7 @@ impl TokenizerSource {
         HuggingFaceTokenizer::new(tokenizers::Tokenizer::from_str(tokenizer_json).map_err(
             |error| TokenizerLoadError::HuggingFaceTokenizerError {
                 tokenizer_source: HuggingFaceTokenizerErrorSource::String,
-                error: error.into(),
+                error,
             },
         )?)
         .into(),
diff --git a/crates/llm/Cargo.toml b/crates/llm/Cargo.toml
index 5db0ec0b..159950f8 100644
--- a/crates/llm/Cargo.toml
+++ b/crates/llm/Cargo.toml
@@ -36,7 +36,7 @@ default = ["models", "tokenizers-remote"]
 
 tokenizers-remote = ["llm-base/tokenizers-remote"]
 
-models = ["llama", "gptneox"] #, "gpt2", "gptj", "bloom", "mpt", "bert"]
+models = ["llama", "gptneox", "gpt2", "gptj", "bloom", "mpt", "bert"]
 llama = ["dep:llm-llama"]
 gpt2 = ["dep:llm-gpt2"]
 gptj = ["dep:llm-gptj"]
diff --git a/crates/llm/examples/embeddings.rs b/crates/llm/examples/embeddings.rs
index 795d9740..64fc4009 100644
--- a/crates/llm/examples/embeddings.rs
+++ b/crates/llm/examples/embeddings.rs
@@ -104,7 +104,7 @@ fn main() {
 
 fn get_embeddings(
     model: &dyn llm::Model,
-    inference_parameters: &llm::InferenceParameters,
+    _inference_parameters: &llm::InferenceParameters,
     query: &str,
 ) -> Vec<f32> {
     let mut session = model.start_session(Default::default());
diff --git a/crates/llm/src/lib.rs b/crates/llm/src/lib.rs
index d588d034..39069f06 100644
--- a/crates/llm/src/lib.rs
+++ b/crates/llm/src/lib.rs
@@ -7,6 +7,7 @@
 //! - [GPT-NeoX](llm_gptneox)
 //! - [LLaMA](llm_llama)
 //! - [MPT](llm_mpt)
+//! - [BERT](llm_bert)
 //! - Falcon (currently disabled due to incompleteness)
 //!
 //! At present, the only supported backend is [GGML](https://github.com/ggerganov/ggml), but this is expected to
@@ -19,7 +20,7 @@
 //! use llm::Model;
 //!
 //! // load a GGML model from disk
-//! let llama = llm::load::<llm::models::Llama>(
+//! let llama = llm::load(
 //!     // path to GGML file
 //!     std::path::Path::new("/path/to/model"),
 //!     // llm::TokenizerSource
@@ -35,7 +36,7 @@
 //! let mut session = llama.start_session(Default::default());
 //! let res = session.infer::<std::convert::Infallible>(
 //!     // model to use for text generation
-//!     &llama,
+//!     llama.as_ref(),
 //!     // randomness provider
 //!     &mut rand::thread_rng(),
 //!     // the prompt to use for text generation, as well as other
@@ -94,7 +95,7 @@ pub use loader::{load, load_progress_callback_stdout, LoadError, LoadProgress};
 use serde::Serialize;
 
 macro_rules! define_models {
-    ($(($model_lowercase:ident, $model_lowercase_str:literal, $model_pascalcase:ident, $krate_ident:ident, $display_name:literal)),*) => {
+    ($(($model_lowercase:ident, $model_lowercase_str:literal, $model_pascalcase:ident, $krate_ident:ident, $display_name:literal),)*) => {
         /// All available models.
         pub mod models {
             $(
@@ -173,14 +174,14 @@ macro_rules! define_models {
 }
 
 define_models!(
-    (bert, "bert", Bert, llm_bert, "Bert"),
-    (bloom, "bloom", Bloom, llm_bloom, "BLOOM"),
-    (gpt2, "gpt2", Gpt2, llm_gpt2, "GPT-2"),
-    (gptj, "gptj", GptJ, llm_gptj, "GPT-J"),
+    // (bert, "bert", Bert, llm_bert, "Bert"),
+    // (bloom, "bloom", Bloom, llm_bloom, "BLOOM"),
+    // (gpt2, "gpt2", Gpt2, llm_gpt2, "GPT-2"),
+    // (gptj, "gptj", GptJ, llm_gptj, "GPT-J"),
     (gptneox, "gptneox", GptNeoX, llm_gptneox, "GPT-NeoX"),
     (llama, "llama", Llama, llm_llama, "LLaMA"),
-    (mpt, "mpt", Mpt, llm_mpt, "MPT"),
-    (falcon, "falcon", Falcon, llm_falcon, "Falcon")
+    // (mpt, "mpt", Mpt, llm_mpt, "MPT"),
+    // (falcon, "falcon", Falcon, llm_falcon, "Falcon"),
 );
 
 /// Used to dispatch some code based on the model architecture.
diff --git a/crates/llm/src/loader.rs b/crates/llm/src/loader.rs
index bc4c871c..dd75fd4a 100644
--- a/crates/llm/src/loader.rs
+++ b/crates/llm/src/loader.rs
@@ -23,13 +23,13 @@ pub fn load(
     params: ModelParameters,
     load_progress_callback: impl FnMut(LoadProgress),
 ) -> Result<Box<dyn Model>, LoadError> {
-    Ok(llm_base::loader::load(
+    llm_base::loader::load(
         path,
         tokenizer_source,
         params,
         VisitorModelFactory,
         load_progress_callback,
-    )?)
+    )
 }
 
 struct VisitorModelFactory;