This repository has been archived by the owner on Jun 24, 2024. It is now read-only.

Commit 7c3d1cf
chore: fix precommit
philpax committed Nov 12, 2023
1 parent ab956c9
Showing 9 changed files with 35 additions and 32 deletions.
8 changes: 6 additions & 2 deletions crates/ggml/src/format/gguf/metadata.rs
@@ -79,13 +79,13 @@ impl Metadata
     // TODO: consider finding a way to automate getting with traits
     pub fn get_str(&self, key: &str) -> Result<&str, MetadataError> {
         let metadata_value = self.get(key)?;
-        Ok(metadata_value
+        metadata_value
             .as_string()
             .ok_or_else(|| MetadataError::InvalidType {
                 key: key.to_string(),
                 expected_type: MetadataValueType::String,
                 actual_type: metadata_value.value_type(),
-            })?)
+            })
     }

     pub fn get_countable(&self, key: &str) -> Result<usize, MetadataError> {
@@ -460,6 +460,10 @@ impl MetadataArrayValue
             _ => None,
         }
     }
+
+    pub fn is_empty(&self) -> bool {
+        self.len() == 0
+    }
 }

 // Shared
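The two changes above read like fixes for clippy's `needless_question_mark` lint (dropping the redundant `Ok(...?)` wrapper) and `len_without_is_empty` (adding `is_empty` next to an existing `len`). A minimal, self-contained sketch of the first rewrite, using a made-up `lookup` helper rather than the crate's real `Metadata` type:

// Sketch only: `lookup` and its map are invented for illustration; they are not
// part of the crate. The point is the `Ok(expr?)` to `expr` rewrite.
use std::collections::HashMap;

fn lookup(map: &HashMap<String, String>, key: &str) -> Result<String, String> {
    // Before: `Ok(map.get(key).cloned().ok_or_else(|| ...)?)`, where `?` unwraps a
    // Result only for `Ok(...)` to wrap it again, which clippy flags.
    // After: the expression already evaluates to the right Result, so return it directly.
    map.get(key)
        .cloned()
        .ok_or_else(|| format!("missing key: {}", key))
}

fn main() {
    let mut map = HashMap::new();
    map.insert("general.name".to_string(), "llm".to_string());
    assert_eq!(lookup(&map, "general.name").as_deref(), Ok("llm"));
    assert!(lookup(&map, "general.author").is_err());
}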
2 changes: 1 addition & 1 deletion crates/ggml/src/format/gguf/mod.rs
@@ -253,7 +253,7 @@ impl TensorInfo
             util::write_length(writer, ctx.use_64_bit_length, *dimension)?;
         }

-        util::write_u32(writer, ggml_type::from(self.element_type) as u32)?;
+        util::write_u32(writer, ggml_type::from(self.element_type))?;
         util::write_u64(writer, self.offset)?;

         Ok(())
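Dropping the `as u32` above looks like a fix for clippy's `unnecessary_cast` lint; that only works if `ggml_type` already is (or aliases) `u32`, which is an assumption about the sys bindings rather than something visible in this diff. A generic sketch of the lint with stand-in names:

// Stand-in alias playing the role of a bindgen-style C enum type; not the real binding.
#[allow(non_camel_case_types)]
type ggml_type = u32;

fn write_u32(out: &mut Vec<u8>, value: u32) {
    out.extend_from_slice(&value.to_le_bytes());
}

fn main() {
    let element_type: ggml_type = 2;
    let mut out = Vec::new();

    // Before: `write_u32(&mut out, element_type as u32)` casts a u32 to u32,
    // which clippy::unnecessary_cast flags.
    // After: the value already has the right type, so it is passed straight through.
    write_u32(&mut out, element_type);

    assert_eq!(out, 2u32.to_le_bytes());
}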
5 changes: 2 additions & 3 deletions crates/llm-base/src/loader.rs
@@ -39,8 +39,7 @@ impl TryFrom<llama_ftype> for FileType {
     type Error = llama_ftype;

     fn try_from(value: llama_ftype) -> Result<Self, Self::Error> {
-        let format =
-            FileTypeFormat::try_from(((value as u32) % ggml::QNT_VERSION_FACTOR) as llama_ftype)?;
+        let format = FileTypeFormat::try_from((value % ggml::QNT_VERSION_FACTOR) as llama_ftype)?;

         Ok(Self {
             format,
@@ -360,7 +359,7 @@ pub trait ModelFactory {
 /// This method returns a [`Box`], which means that the model will have single ownership.
 /// If you'd like to share ownership (i.e. to use the model in multiple threads), we
 /// suggest using [`Arc::from(Box<T>)`](https://doc.rust-lang.org/std/sync/struct.Arc.html#impl-From%3CBox%3CT,+Global%3E%3E-for-Arc%3CT%3E)
-/// to convert the [`Box`] into an [`Arc`](std::sync::Arc) after loading.
+/// to convert the [`Box`] into an [`Arc`] after loading.
 pub fn load(
     path: &Path,
     tokenizer_source: TokenizerSource,
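The intra-doc-link change above is cosmetic, but the advice in that comment, converting the boxed model into an `Arc` when shared ownership is needed, can be sketched with a stand-in trait, since loading a real model is out of scope here:

use std::sync::Arc;

// Stand-in trait; the real crate would use `dyn Model` returned by its `load` function.
trait Model: Send + Sync {
    fn name(&self) -> &str;
}

struct Dummy;
impl Model for Dummy {
    fn name(&self) -> &str {
        "dummy"
    }
}

fn main() {
    // Single ownership, as returned by a `load`-style function.
    let boxed: Box<dyn Model> = Box::new(Dummy);

    // Shared ownership for use across threads, per the doc comment above.
    let shared: Arc<dyn Model> = Arc::from(boxed);

    let handle = {
        let model = Arc::clone(&shared);
        std::thread::spawn(move || model.name().to_owned())
    };
    assert_eq!(handle.join().unwrap(), "dummy");
}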
1 change: 1 addition & 0 deletions crates/llm-base/src/tokenizer/embedded.rs
@@ -497,6 +497,7 @@ fn unescape_whitespace(text: &[u8]) -> Vec<u8> {
     let mut buffer: Vec<u8> = vec![];

     for &b in text {
+        #[allow(clippy::if_same_then_else)]
         if b == 0xE2 {
             // If the current byte is 0xE2, start buffering and check for the sequence.
             buffer.push(b);
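The added `#[allow(clippy::if_same_then_else)]` silences the lint that fires when consecutive `if`/`else if` branches have identical bodies. A small sketch of the situation, with made-up byte-handling logic rather than the crate's actual tokenizer code:

fn classify(b: u8, buffering: bool) -> &'static str {
    // The first two branches have identical bodies, which is what
    // clippy::if_same_then_else flags; the allow keeps the two conditions
    // spelled out separately instead of merging them with `||`.
    #[allow(clippy::if_same_then_else)]
    let action = if b == 0xE2 {
        "buffer"
    } else if buffering {
        "buffer"
    } else {
        "emit"
    };
    action
}

fn main() {
    assert_eq!(classify(0xE2, false), "buffer");
    assert_eq!(classify(0x41, true), "buffer");
    assert_eq!(classify(0x41, false), "emit");
}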
24 changes: 11 additions & 13 deletions crates/llm-base/src/tokenizer/mod.rs
@@ -118,7 +118,7 @@ impl TokenizerSource {
                     tokenizer_source: HuggingFaceTokenizerErrorSource::Remote(
                         identifier.clone(),
                     ),
-                    error: error.into(),
+                    error,
                 }
             })?,
         )
@@ -128,7 +128,7 @@
             tokenizers::Tokenizer::from_file(&path).map_err(|error| {
                 TokenizerLoadError::HuggingFaceTokenizerError {
                     tokenizer_source: HuggingFaceTokenizerErrorSource::File(path.clone()),
-                    error: error.into(),
+                    error,
                 }
             })?,
         )
@@ -139,20 +139,18 @@
             Self::Embedded => {
                 if let Ok(hf) = gguf.metadata.get_str("tokenizer.huggingface.json") {
                     Ok(Self::load_huggingface_json(hf)?)
-                } else {
-                    if EmbeddedTokenizer::is_present_in_metadata(&gguf.metadata) {
-                        if EMBEDDED_TOKENIZER_ENABLED {
-                            Ok(EmbeddedTokenizer::from_metadata(&gguf.metadata)?.into())
-                        } else {
-                            Err(TokenizerLoadError::NoSupportedTokenizersFound {
-                                unsupported_tokenizers: vec!["embedded".to_owned()],
-                            })
-                        }
+                } else if EmbeddedTokenizer::is_present_in_metadata(&gguf.metadata) {
+                    if EMBEDDED_TOKENIZER_ENABLED {
+                        Ok(EmbeddedTokenizer::from_metadata(&gguf.metadata)?.into())
                     } else {
                         Err(TokenizerLoadError::NoSupportedTokenizersFound {
-                            unsupported_tokenizers: vec![],
+                            unsupported_tokenizers: vec!["embedded".to_owned()],
                         })
                     }
+                } else {
+                    Err(TokenizerLoadError::NoSupportedTokenizersFound {
+                        unsupported_tokenizers: vec![],
+                    })
                 }
             }
         }
@@ -163,7 +161,7 @@
             HuggingFaceTokenizer::new(tokenizers::Tokenizer::from_str(tokenizer_json).map_err(
                 |error| TokenizerLoadError::HuggingFaceTokenizerError {
                     tokenizer_source: HuggingFaceTokenizerErrorSource::String,
-                    error: error.into(),
+                    error,
                 },
             )?)
             .into(),
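The restructuring above looks like clippy's `collapsible_else_if` fix: an `else { if ... }` nesting flattened into `else if`, with the innermost fallback error becoming the final `else`. A generic before/after sketch with made-up conditions, not the crate's tokenizer-selection logic:

// Sketch only: `pick` and its arguments are invented to show the shape of the rewrite.
fn pick(primary: Option<&str>, fallback_present: bool, fallback_enabled: bool) -> Result<String, String> {
    if let Some(p) = primary {
        Ok(p.to_owned())
    } else if fallback_present {
        // Before the fix this level read `else { if fallback_present { ... } }`,
        // which clippy::collapsible_else_if flattens into the `else if` used here.
        if fallback_enabled {
            Ok("fallback".to_owned())
        } else {
            Err("fallback present but disabled".to_owned())
        }
    } else {
        Err("no tokenizer source available".to_owned())
    }
}

fn main() {
    assert_eq!(pick(Some("primary"), false, false).unwrap(), "primary");
    assert_eq!(pick(None, true, true).unwrap(), "fallback");
    assert!(pick(None, true, false).is_err());
    assert!(pick(None, false, false).is_err());
}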
2 changes: 1 addition & 1 deletion crates/llm/Cargo.toml
@@ -36,7 +36,7 @@ default = ["models", "tokenizers-remote"]

 tokenizers-remote = ["llm-base/tokenizers-remote"]

-models = ["llama", "gptneox"] #, "gpt2", "gptj", "bloom", "mpt", "bert"]
+models = ["llama", "gptneox", "gpt2", "gptj", "bloom", "mpt", "bert"]
 llama = ["dep:llm-llama"]
 gpt2 = ["dep:llm-gpt2"]
 gptj = ["dep:llm-gptj"]
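Re-enabling the extra entries in the `models` feature only matters to code gated on those features. A minimal sketch of how downstream Rust code typically branches on such Cargo features; the feature name follows the manifest above, but the function itself is made up:

// Sketch only: illustrates cfg-gating on a Cargo feature such as "bert";
// the function is hypothetical and not part of the crate.
#[cfg(feature = "bert")]
fn supported_architectures() -> Vec<&'static str> {
    vec!["llama", "gptneox", "bert"]
}

#[cfg(not(feature = "bert"))]
fn supported_architectures() -> Vec<&'static str> {
    vec!["llama", "gptneox"]
}

fn main() {
    println!("built with support for: {:?}", supported_architectures());
}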
2 changes: 1 addition & 1 deletion crates/llm/examples/embeddings.rs
@@ -104,7 +104,7 @@ fn main() {

 fn get_embeddings(
     model: &dyn llm::Model,
-    inference_parameters: &llm::InferenceParameters,
+    _inference_parameters: &llm::InferenceParameters,
     query: &str,
 ) -> Vec<f32> {
     let mut session = model.start_session(Default::default());
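Prefixing the parameter with an underscore is the standard way to keep an argument in a function's signature while marking it intentionally unused, so the compiler's `unused_variables` warning stays quiet. A tiny sketch with stand-in types (the real example uses `llm::InferenceParameters`):

// Stand-in type for illustration only.
struct InferenceParameters {
    temperature: f32,
}

// The parameter stays in the signature for call-site compatibility, but the
// leading underscore tells the compiler the body intentionally ignores it.
fn get_embeddings(query: &str, _inference_parameters: &InferenceParameters) -> Vec<f32> {
    query.bytes().map(|b| b as f32).collect()
}

fn main() {
    let params = InferenceParameters { temperature: 0.8 };
    println!("temperature requested: {}", params.temperature);
    let embedding = get_embeddings("hi", &params);
    assert_eq!(embedding.len(), 2);
}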
19 changes: 10 additions & 9 deletions crates/llm/src/lib.rs
@@ -7,6 +7,7 @@
 //! - [GPT-NeoX](llm_gptneox)
 //! - [LLaMA](llm_llama)
 //! - [MPT](llm_mpt)
+//! - [BERT](llm_bert)
 //! - Falcon (currently disabled due to incompleteness)
 //!
 //! At present, the only supported backend is [GGML](https://github.com/ggerganov/ggml), but this is expected to
@@ -19,7 +20,7 @@
 //! use llm::Model;
 //!
 //! // load a GGML model from disk
-//! let llama = llm::load::<llm::models::Llama>(
+//! let llama = llm::load(
 //!     // path to GGML file
 //!     std::path::Path::new("/path/to/model"),
 //!     // llm::TokenizerSource
@@ -35,7 +36,7 @@
 //! let mut session = llama.start_session(Default::default());
 //! let res = session.infer::<std::convert::Infallible>(
 //!     // model to use for text generation
-//!     &llama,
+//!     llama.as_ref(),
 //!     // randomness provider
 //!     &mut rand::thread_rng(),
 //!     // the prompt to use for text generation, as well as other
@@ -94,7 +95,7 @@ pub use loader::{load, load_progress_callback_stdout, LoadError, LoadProgress};
 use serde::Serialize;

 macro_rules! define_models {
-    ($(($model_lowercase:ident, $model_lowercase_str:literal, $model_pascalcase:ident, $krate_ident:ident, $display_name:literal)),*) => {
+    ($(($model_lowercase:ident, $model_lowercase_str:literal, $model_pascalcase:ident, $krate_ident:ident, $display_name:literal),)*) => {
         /// All available models.
         pub mod models {
             $(
@@ -173,14 +174,14 @@ macro_rules! define_models {
 }

 define_models!(
-    (bert, "bert", Bert, llm_bert, "Bert"),
-    (bloom, "bloom", Bloom, llm_bloom, "BLOOM"),
-    (gpt2, "gpt2", Gpt2, llm_gpt2, "GPT-2"),
-    (gptj, "gptj", GptJ, llm_gptj, "GPT-J"),
+    // (bert, "bert", Bert, llm_bert, "Bert"),
+    // (bloom, "bloom", Bloom, llm_bloom, "BLOOM"),
+    // (gpt2, "gpt2", Gpt2, llm_gpt2, "GPT-2"),
+    // (gptj, "gptj", GptJ, llm_gptj, "GPT-J"),
     (gptneox, "gptneox", GptNeoX, llm_gptneox, "GPT-NeoX"),
     (llama, "llama", Llama, llm_llama, "LLaMA"),
-    (mpt, "mpt", Mpt, llm_mpt, "MPT"),
-    (falcon, "falcon", Falcon, llm_falcon, "Falcon")
+    // (mpt, "mpt", Mpt, llm_mpt, "MPT"),
+    // (falcon, "falcon", Falcon, llm_falcon, "Falcon"),
 );

 /// Used to dispatch some code based on the model architecture.
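The macro change above moves the separator comma inside the repetition, `$( ... ,)*` instead of `$( ... ),*`, so every entry, including the last one, must end with a comma; that makes it painless to comment entries in and out of the list, as the invocation now does. A toy illustration of the difference (not the real `define_models!`):

// `$( $name:ident ),*` accepts `a, b` but rejects a trailing comma, while
// `$( $name:ident ,)*` requires a comma after every entry, including the last.
macro_rules! names_comma_separated {
    ($($name:ident),*) => {
        [$(stringify!($name)),*]
    };
}

macro_rules! names_comma_terminated {
    ($($name:ident,)*) => {
        [$(stringify!($name)),*]
    };
}

fn main() {
    let a = names_comma_separated!(llama, gptneox);
    // A trailing comma here is rejected by the first macro but required by the second.
    let b = names_comma_terminated!(llama, gptneox,);
    assert_eq!(a, b);
    println!("{:?}", a);
}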
4 changes: 2 additions & 2 deletions crates/llm/src/loader.rs
@@ -23,13 +23,13 @@ pub fn load(
     params: ModelParameters,
     load_progress_callback: impl FnMut(LoadProgress),
 ) -> Result<Box<dyn Model>, LoadError> {
-    Ok(llm_base::loader::load(
+    llm_base::loader::load(
         path,
         tokenizer_source,
         params,
         VisitorModelFactory,
         load_progress_callback,
-    )?)
+    )
 }

 struct VisitorModelFactory;
