From e9379a822afe4df5feceaa8de600cf70d9635392 Mon Sep 17 00:00:00 2001 From: Guillaume Becquin Date: Sat, 25 Nov 2023 08:32:07 +0000 Subject: [PATCH 1/7] updated tch version --- Cargo.toml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 6a6d7195..162cf097 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -76,7 +76,7 @@ features = ["doc-only"] [dependencies] rust_tokenizers = "8.1.1" -tch = "0.13.0" +tch = "0.14.0" serde_json = "1" serde = { version = "1", features = ["derive"] } ordered-float = "3" @@ -97,7 +97,7 @@ anyhow = "1" csv = "1" criterion = "0.4" tokio = { version = "1.24", features = ["sync", "rt-multi-thread", "macros"] } -torch-sys = "0.13.0" +torch-sys = "0.14.0" tempfile = "3" itertools = "0.10" tracing-subscriber = { version = "0.3", default-features = false, features = [ "env-filter", "fmt" ] } From bb4482fb091a09b2c5680485506b190b02a37507 Mon Sep 17 00:00:00 2001 From: Guillaume Becquin Date: Sat, 25 Nov 2023 09:28:01 +0000 Subject: [PATCH 2/7] Addition of casting operation for cpu compat --- examples/natural_language_inference_deberta.rs | 2 +- src/common/resources/mod.rs | 15 +++++++++------ src/models/bart/bart_model.rs | 7 ++++++- src/models/gpt2/gpt2_model.rs | 7 ++++++- src/models/gpt_j/gpt_j_model.rs | 7 ++++++- src/models/gpt_neo/gpt_neo_model.rs | 7 ++++++- src/models/longt5/longt5_model.rs | 7 ++++++- src/models/m2m_100/m2m_100_model.rs | 7 ++++++- src/models/marian/marian_model.rs | 7 ++++++- src/models/mbart/mbart_model.rs | 9 +++++++-- src/models/openai_gpt/openai_gpt_model.rs | 7 ++++++- src/models/pegasus/pegasus_model.rs | 7 ++++++- src/models/prophetnet/prophetnet_model.rs | 7 ++++++- src/models/reformer/reformer_model.rs | 7 ++++++- src/models/t5/t5_model.rs | 7 ++++++- src/models/xlnet/xlnet_model.rs | 7 ++++++- src/pipelines/common.rs | 9 +++++++++ src/pipelines/conversation.rs | 4 ++++ src/pipelines/generation_utils.rs | 5 ++++- src/pipelines/masked_language.rs | 8 ++++++-- src/pipelines/pos_tagging.rs | 1 + src/pipelines/question_answering.rs | 8 +++++++- src/pipelines/sentence_embeddings/builder.rs | 11 ++++++++++- src/pipelines/sentence_embeddings/config.rs | 16 ++++++++++------ src/pipelines/sentence_embeddings/pipeline.rs | 8 +++++++- src/pipelines/sequence_classification.rs | 6 +++++- src/pipelines/summarization.rs | 6 +++++- src/pipelines/text_generation.rs | 6 +++++- src/pipelines/token_classification.rs | 6 +++++- .../translation/translation_pipeline.rs | 6 +++++- src/pipelines/zero_shot_classification.rs | 9 ++++++++- tests/albert.rs | 2 +- tests/bart.rs | 3 ++- tests/onnx.rs | 6 +++--- 34 files changed, 192 insertions(+), 45 deletions(-) diff --git a/examples/natural_language_inference_deberta.rs b/examples/natural_language_inference_deberta.rs index debe53b8..4ab227a0 100644 --- a/examples/natural_language_inference_deberta.rs +++ b/examples/natural_language_inference_deberta.rs @@ -38,7 +38,7 @@ fn main() -> anyhow::Result<()> { )?; let config = DebertaConfig::from_file(config_path); let model = DebertaForSequenceClassification::new(vs.root(), &config)?; - load_weights(&model_resource, &mut vs)?; + load_weights(&model_resource, &mut vs, None, device)?; // Define input let input = [("I love you.", "I like you.")]; diff --git a/src/common/resources/mod.rs b/src/common/resources/mod.rs index d4edf617..4b26b496 100644 --- a/src/common/resources/mod.rs +++ b/src/common/resources/mod.rs @@ -30,6 +30,7 @@ use std::ops::DerefMut; use std::path::PathBuf; use std::sync::RwLockWriteGuard; use 
tch::nn::VarStore; +use tch::{Device, Kind}; pub enum Resource<'a> { PathBuf(PathBuf), @@ -84,17 +85,19 @@ impl<T: ResourceProvider + ?Sized> ResourceProvider for Box<T> { pub fn load_weights( rp: &(impl ResourceProvider + ?Sized), vs: &mut VarStore, + kind: Option<Kind>, + device: Device, ) -> Result<(), RustBertError> { match rp.get_resource()? { - Resource::Buffer(mut data) => { - vs.load_from_stream(std::io::Cursor::new(data.deref_mut()))?; - Ok(()) - } - Resource::PathBuf(path) => Ok(vs.load(path)?), - } + Resource::Buffer(mut data) => vs.load_from_stream(std::io::Cursor::new(data.deref_mut())), + Resource::PathBuf(path) => vs.load(path), + }?; + cast_var_store(vs, kind, device); + Ok(()) } #[cfg(feature = "remote")] mod remote; +use crate::pipelines::common::cast_var_store; #[cfg(feature = "remote")] pub use remote::RemoteResource; diff --git a/src/models/bart/bart_model.rs b/src/models/bart/bart_model.rs index 62246fac..cdc23d36 100644 --- a/src/models/bart/bart_model.rs +++ b/src/models/bart/bart_model.rs @@ -1004,7 +1004,12 @@ impl BartGenerator { let mut var_store = nn::VarStore::new(device); let config = BartConfig::from_file(config_path); let model = BartForConditionalGeneration::new(var_store.root(), &config); - crate::resources::load_weights(&generate_config.model_resource, &mut var_store)?; + crate::resources::load_weights( + &generate_config.model_resource, + &mut var_store, + generate_config.kind, + device, + )?; let bos_token_id = Some(config.bos_token_id.unwrap_or(0)); let eos_token_ids = Some(match config.eos_token_id { diff --git a/src/models/gpt2/gpt2_model.rs b/src/models/gpt2/gpt2_model.rs index 3fb5b7aa..0aac55d8 100644 --- a/src/models/gpt2/gpt2_model.rs +++ b/src/models/gpt2/gpt2_model.rs @@ -652,7 +652,12 @@ impl GPT2Generator { let config = Gpt2Config::from_file(config_path); let model = GPT2LMHeadModel::new(var_store.root(), &config); - crate::resources::load_weights(&generate_config.model_resource, &mut var_store)?; + crate::resources::load_weights( + &generate_config.model_resource, + &mut var_store, + generate_config.kind, + device, + )?; let bos_token_id = tokenizer.get_bos_id(); let eos_token_ids = tokenizer.get_eos_id().map(|id| vec![id]); diff --git a/src/models/gpt_j/gpt_j_model.rs b/src/models/gpt_j/gpt_j_model.rs index 8468d577..907f8053 100644 --- a/src/models/gpt_j/gpt_j_model.rs +++ b/src/models/gpt_j/gpt_j_model.rs @@ -625,7 +625,12 @@ impl GptJGenerator { if config.preload_on_cpu && device != Device::Cpu { var_store.set_device(Device::Cpu); } - crate::resources::load_weights(&generate_config.model_resource, &mut var_store)?; + crate::resources::load_weights( + &generate_config.model_resource, + &mut var_store, + generate_config.kind, + device, + )?; if device != Device::Cpu { var_store.set_device(device); } diff --git a/src/models/gpt_neo/gpt_neo_model.rs b/src/models/gpt_neo/gpt_neo_model.rs index ed7773c7..3faeb7af 100644 --- a/src/models/gpt_neo/gpt_neo_model.rs +++ b/src/models/gpt_neo/gpt_neo_model.rs @@ -672,7 +672,12 @@ impl GptNeoGenerator { let mut var_store = nn::VarStore::new(device); let config = GptNeoConfig::from_file(config_path); let model = GptNeoForCausalLM::new(var_store.root(), &config)?; - crate::resources::load_weights(&generate_config.model_resource, &mut var_store)?; + crate::resources::load_weights( + &generate_config.model_resource, + &mut var_store, + generate_config.kind, + device, + )?; let bos_token_id = tokenizer.get_bos_id(); let eos_token_ids = tokenizer.get_eos_id().map(|id| vec![id]); diff --git a/src/models/longt5/longt5_model.rs 
b/src/models/longt5/longt5_model.rs index 3efc47a6..69dff618 100644 --- a/src/models/longt5/longt5_model.rs +++ b/src/models/longt5/longt5_model.rs @@ -595,7 +595,12 @@ impl LongT5Generator { let config = LongT5Config::from_file(config_path); let model = LongT5ForConditionalGeneration::new(var_store.root(), &config); - crate::resources::load_weights(&generate_config.model_resource, &mut var_store)?; + crate::resources::load_weights( + &generate_config.model_resource, + &mut var_store, + generate_config.kind, + device, + )?; let bos_token_id = config.bos_token_id; let eos_token_ids = Some(match config.eos_token_id { diff --git a/src/models/m2m_100/m2m_100_model.rs b/src/models/m2m_100/m2m_100_model.rs index 9b5d65c0..c86ce604 100644 --- a/src/models/m2m_100/m2m_100_model.rs +++ b/src/models/m2m_100/m2m_100_model.rs @@ -544,7 +544,12 @@ impl M2M100Generator { let config = M2M100Config::from_file(config_path); let model = M2M100ForConditionalGeneration::new(var_store.root(), &config); - crate::resources::load_weights(&generate_config.model_resource, &mut var_store)?; + crate::resources::load_weights( + &generate_config.model_resource, + &mut var_store, + generate_config.kind, + device, + )?; let bos_token_id = Some(config.bos_token_id.unwrap_or(0)); let eos_token_ids = Some(match config.eos_token_id { diff --git a/src/models/marian/marian_model.rs b/src/models/marian/marian_model.rs index 44843157..5368037c 100644 --- a/src/models/marian/marian_model.rs +++ b/src/models/marian/marian_model.rs @@ -761,7 +761,12 @@ impl MarianGenerator { let config = BartConfig::from_file(config_path); let model = MarianForConditionalGeneration::new(var_store.root(), &config); - crate::resources::load_weights(&generate_config.model_resource, &mut var_store)?; + crate::resources::load_weights( + &generate_config.model_resource, + &mut var_store, + generate_config.kind, + device, + )?; let bos_token_id = Some(config.bos_token_id.unwrap_or(0)); let eos_token_ids = Some(match config.eos_token_id { diff --git a/src/models/mbart/mbart_model.rs b/src/models/mbart/mbart_model.rs index 80d52c9a..0de1d0b5 100644 --- a/src/models/mbart/mbart_model.rs +++ b/src/models/mbart/mbart_model.rs @@ -650,7 +650,7 @@ impl MBartForSequenceClassification { /// # let device = Device::Cpu; /// # let vs = nn::VarStore::new(device); /// # let config = MBartConfig::from_file(config_path); - /// # let mbart_model: MBartForSequenceClassification = MBartForSequenceClassification::new(&vs.root(), &config).unwrap();; + /// # let mbart_model: MBartForSequenceClassification = MBartForSequenceClassification::new(&vs.root(), &config).unwrap(); /// let (batch_size, source_sequence_length, target_sequence_length) = (64, 128, 56); /// let input_tensor = Tensor::rand(&[batch_size, source_sequence_length], (Int64, device)); /// let target_tensor = Tensor::rand(&[batch_size, target_sequence_length], (Int64, device)); @@ -800,7 +800,12 @@ impl MBartGenerator { let config = MBartConfig::from_file(config_path); let model = MBartForConditionalGeneration::new(var_store.root(), &config); - crate::resources::load_weights(&generate_config.model_resource, &mut var_store)?; + crate::resources::load_weights( + &generate_config.model_resource, + &mut var_store, + generate_config.kind, + device, + )?; let bos_token_id = Some(config.bos_token_id.unwrap_or(0)); let eos_token_ids = Some(match config.eos_token_id { diff --git a/src/models/openai_gpt/openai_gpt_model.rs b/src/models/openai_gpt/openai_gpt_model.rs index 55a3e5b9..7780f34d 100644 --- 
a/src/models/openai_gpt/openai_gpt_model.rs +++ b/src/models/openai_gpt/openai_gpt_model.rs @@ -498,7 +498,12 @@ impl OpenAIGenerator { let mut var_store = nn::VarStore::new(device); let config = Gpt2Config::from_file(config_path); let model = OpenAIGPTLMHeadModel::new(var_store.root(), &config); - crate::resources::load_weights(&generate_config.model_resource, &mut var_store)?; + crate::resources::load_weights( + &generate_config.model_resource, + &mut var_store, + generate_config.kind, + device, + )?; let bos_token_id = tokenizer.get_bos_id(); let eos_token_ids = tokenizer.get_eos_id().map(|id| vec![id]); diff --git a/src/models/pegasus/pegasus_model.rs b/src/models/pegasus/pegasus_model.rs index 557d647a..4fd9b6ba 100644 --- a/src/models/pegasus/pegasus_model.rs +++ b/src/models/pegasus/pegasus_model.rs @@ -505,7 +505,12 @@ impl PegasusConditionalGenerator { let mut var_store = nn::VarStore::new(device); let config = PegasusConfig::from_file(config_path); let model = PegasusForConditionalGeneration::new(var_store.root(), &config); - crate::resources::load_weights(&generate_config.model_resource, &mut var_store)?; + crate::resources::load_weights( + &generate_config.model_resource, + &mut var_store, + generate_config.kind, + device, + )?; let bos_token_id = Some(config.bos_token_id.unwrap_or(0)); let eos_token_ids = config diff --git a/src/models/prophetnet/prophetnet_model.rs b/src/models/prophetnet/prophetnet_model.rs index c8c746c3..e27aaf27 100644 --- a/src/models/prophetnet/prophetnet_model.rs +++ b/src/models/prophetnet/prophetnet_model.rs @@ -919,7 +919,12 @@ impl ProphetNetConditionalGenerator { let mut var_store = nn::VarStore::new(device); let config = ProphetNetConfig::from_file(config_path); let model = ProphetNetForConditionalGeneration::new(var_store.root(), &config)?; - crate::resources::load_weights(&generate_config.model_resource, &mut var_store)?; + crate::resources::load_weights( + &generate_config.model_resource, + &mut var_store, + generate_config.kind, + device, + )?; let bos_token_id = Some(config.bos_token_id); let eos_token_ids = Some(vec![config.eos_token_id]); diff --git a/src/models/reformer/reformer_model.rs b/src/models/reformer/reformer_model.rs index bd1711db..4d4f5292 100644 --- a/src/models/reformer/reformer_model.rs +++ b/src/models/reformer/reformer_model.rs @@ -1056,7 +1056,12 @@ impl ReformerGenerator { let mut var_store = nn::VarStore::new(device); let config = ReformerConfig::from_file(config_path); let model = ReformerModelWithLMHead::new(var_store.root(), &config)?; - crate::resources::load_weights(&generate_config.model_resource, &mut var_store)?; + crate::resources::load_weights( + &generate_config.model_resource, + &mut var_store, + generate_config.kind, + device, + )?; let bos_token_id = tokenizer.get_bos_id(); let eos_token_ids = tokenizer.get_eos_id().map(|id| vec![id]); diff --git a/src/models/t5/t5_model.rs b/src/models/t5/t5_model.rs index 53c60d9f..c715815e 100644 --- a/src/models/t5/t5_model.rs +++ b/src/models/t5/t5_model.rs @@ -763,7 +763,12 @@ impl T5Generator { let config = T5Config::from_file(config_path); let model = T5ForConditionalGeneration::new(var_store.root(), &config); - crate::resources::load_weights(&generate_config.model_resource, &mut var_store)?; + crate::resources::load_weights( + &generate_config.model_resource, + &mut var_store, + generate_config.kind, + device, + )?; let bos_token_id = Some(config.bos_token_id.unwrap_or(-1)); let eos_token_ids = Some(match config.eos_token_id { diff --git 
a/src/models/xlnet/xlnet_model.rs b/src/models/xlnet/xlnet_model.rs index d45afb2d..7246da66 100644 --- a/src/models/xlnet/xlnet_model.rs +++ b/src/models/xlnet/xlnet_model.rs @@ -1560,7 +1560,12 @@ impl XLNetGenerator { let config = XLNetConfig::from_file(config_path); let model = XLNetLMHeadModel::new(var_store.root(), &config); - crate::resources::load_weights(&generate_config.model_resource, &mut var_store)?; + crate::resources::load_weights( + &generate_config.model_resource, + &mut var_store, + generate_config.kind, + device, + )?; let bos_token_id = Some(config.bos_token_id); let eos_token_ids = Some(vec![config.eos_token_id]); diff --git a/src/pipelines/common.rs b/src/pipelines/common.rs index 5356671c..8575f5be 100644 --- a/src/pipelines/common.rs +++ b/src/pipelines/common.rs @@ -60,6 +60,7 @@ use std::convert::TryFrom; use std::fmt::Debug; use std::path::{Path, PathBuf}; +use tch::nn::VarStore; use tch::{Device, Kind, Tensor}; #[cfg(feature = "onnx")] @@ -2348,3 +2349,11 @@ impl TokenizerOption { } } } + +pub fn cast_var_store(varstore: &mut VarStore, kind: Option<Kind>, device: Device) { + match (kind, device) { + (Some(kind), _) => varstore.set_kind(kind), + (None, Device::Cpu) => varstore.set_kind(Kind::Float), + (None, _) => {} + } +} diff --git a/src/pipelines/conversation.rs b/src/pipelines/conversation.rs index 02e864ab..7e90f91a 100644 --- a/src/pipelines/conversation.rs +++ b/src/pipelines/conversation.rs @@ -115,6 +115,8 @@ pub struct ConversationConfig { pub diversity_penalty: Option<f64>, /// Device to place the model on (default: CUDA/GPU when available) pub device: Device, + /// Model weights precision. If not provided, will default to full precision on CPU, or the loaded weights precision otherwise + pub kind: Option<Kind>, } #[cfg(feature = "remote")] @@ -150,6 +152,7 @@ impl Default for ConversationConfig { num_beam_groups: None, diversity_penalty: None, device: Device::cuda_if_available(), + kind: None, } } } @@ -177,6 +180,7 @@ impl From<ConversationConfig> for GenerateConfig { num_beam_groups: config.num_beam_groups, diversity_penalty: config.diversity_penalty, device: config.device, + kind: config.kind, } } } diff --git a/src/pipelines/generation_utils.rs b/src/pipelines/generation_utils.rs index a6cc3893..1105bab1 100644 --- a/src/pipelines/generation_utils.rs +++ b/src/pipelines/generation_utils.rs @@ -67,7 +67,7 @@ //! ``` use tch::kind::Kind::Int64; -use tch::{no_grad, Device, Tensor}; +use tch::{no_grad, Device, Kind, Tensor}; use crate::bart::LayerState as BartLayerState; use crate::common::resources::ResourceProvider; @@ -136,6 +136,8 @@ pub struct GenerateConfig { pub diversity_penalty: Option<f64>, /// Device to place the model on (default: CUDA/GPU when available) pub device: Device, + /// Model weights precision. 
If not provided, will default to full precision on CPU, or the loaded weights precision otherwise + pub kind: Option<Kind>, } #[cfg(feature = "remote")] @@ -166,6 +168,7 @@ impl Default for GenerateConfig { num_beam_groups: None, diversity_penalty: None, device: Device::cuda_if_available(), + kind: None, } } } diff --git a/src/pipelines/masked_language.rs b/src/pipelines/masked_language.rs index 92b10f1c..2a3c4319 100644 --- a/src/pipelines/masked_language.rs +++ b/src/pipelines/masked_language.rs @@ -52,7 +52,7 @@ use crate::deberta::DebertaForMaskedLM; use crate::deberta_v2::DebertaV2ForMaskedLM; use crate::fnet::FNetForMaskedLM; use crate::pipelines::common::{ - get_device, ConfigOption, ModelResource, ModelType, TokenizerOption, + cast_var_store, get_device, ConfigOption, ModelResource, ModelType, TokenizerOption, }; use crate::resources::ResourceProvider; use crate::roberta::RobertaForMaskedLM; @@ -67,7 +67,7 @@ use crate::{ resources::RemoteResource, }; use tch::nn::VarStore; -use tch::{no_grad, Device, Tensor}; +use tch::{no_grad, Device, Kind, Tensor}; #[derive(Debug, Clone)] /// Output container for masked language model pipeline. @@ -103,6 +103,8 @@ pub struct MaskedLanguageConfig { pub mask_token: Option<String>, /// Device to place the model on (default: CUDA/GPU when available) pub device: Device, + /// Model weights precision. If not provided, will default to full precision on CPU, or the loaded weights precision otherwise + pub kind: Option<Kind>, } impl MaskedLanguageConfig { @@ -143,6 +145,7 @@ impl MaskedLanguageConfig { add_prefix_space: add_prefix_space.into(), mask_token: mask_token.into(), device: Device::cuda_if_available(), + kind: None, } } } @@ -285,6 +288,7 @@ impl MaskedLanguageOption { ))), }?; var_store.load(weights_path)?; + cast_var_store(&mut var_store, config.kind, device); Ok(model) } diff --git a/src/pipelines/pos_tagging.rs b/src/pipelines/pos_tagging.rs index 1e7e907c..d8d97fb2 100644 --- a/src/pipelines/pos_tagging.rs +++ b/src/pipelines/pos_tagging.rs @@ -138,6 +138,7 @@ impl Default for POSConfig { strip_accents: Some(true), add_prefix_space: None, device: Device::cuda_if_available(), + kind: None, label_aggregation_function: LabelAggregationOption::First, batch_size: 64, }, diff --git a/src/pipelines/question_answering.rs b/src/pipelines/question_answering.rs index 67651524..f8aedaba 100644 --- a/src/pipelines/question_answering.rs +++ b/src/pipelines/question_answering.rs @@ -52,7 +52,7 @@ use crate::fnet::FNetForQuestionAnswering; use crate::longformer::LongformerForQuestionAnswering; use crate::mobilebert::MobileBertForQuestionAnswering; use crate::pipelines::common::{ - get_device, ConfigOption, ModelResource, ModelType, TokenizerOption, + cast_var_store, get_device, ConfigOption, ModelResource, ModelType, TokenizerOption, }; use crate::reformer::ReformerForQuestionAnswering; use crate::resources::ResourceProvider; @@ -158,6 +158,8 @@ pub struct QuestionAnsweringConfig { pub max_query_length: usize, /// Maximum length for the answer pub max_answer_length: usize, + /// Model weights precision. 
If not provided, will default to full precision on CPU, or the loaded weights precision otherwise + pub kind: Option<Kind>, } impl QuestionAnsweringConfig { @@ -199,6 +201,7 @@ impl QuestionAnsweringConfig { doc_stride: 128, max_query_length: 64, max_answer_length: 15, + kind: None, } } @@ -248,6 +251,7 @@ impl QuestionAnsweringConfig { doc_stride: doc_stride.into().unwrap_or(128), max_query_length: max_query_length.into().unwrap_or(64), max_answer_length: max_answer_length.into().unwrap_or(15), + kind: None, } } } @@ -267,6 +271,7 @@ impl Default for QuestionAnsweringConfig { )), merges_resource: None, device: Device::cuda_if_available(), + kind: None, model_type: ModelType::DistilBert, lower_case: false, add_prefix_space: None, @@ -474,6 +479,7 @@ impl QuestionAnsweringOption { ))), }?; var_store.load(weights_path)?; + cast_var_store(&mut var_store, config.kind, device); Ok(model) } diff --git a/src/pipelines/sentence_embeddings/builder.rs b/src/pipelines/sentence_embeddings/builder.rs index 74bb1ecd..bab028d8 100644 --- a/src/pipelines/sentence_embeddings/builder.rs +++ b/src/pipelines/sentence_embeddings/builder.rs @@ -1,7 +1,7 @@ use std::path::PathBuf; use serde::Deserialize; -use tch::Device; +use tch::{Device, Kind}; use crate::pipelines::common::ModelType; use crate::pipelines::sentence_embeddings::{ @@ -21,6 +21,7 @@ use crate::{ /// (configuration and weights). pub struct SentenceEmbeddingsBuilder<T> { device: Device, + kind: Option<Kind>, inner: T, } @@ -29,6 +30,11 @@ impl<T> SentenceEmbeddingsBuilder<T> { self.device = device; self } + + pub fn with_kind(mut self, kind: Kind) -> Self { + self.kind = Some(kind); + self + } } pub struct Local { @@ -46,6 +52,7 @@ impl SentenceEmbeddingsBuilder<Local> { pub fn local<P: Into<PathBuf>>(model_dir: P) -> Self { Self { device: Device::cuda_if_available(), + kind: None, inner: Local { model_dir: model_dir.into(), }, } } @@ -106,6 +113,7 @@ impl SentenceEmbeddingsBuilder<Local> { tokenizer_vocab_resource: tokenizer_vocab.into(), tokenizer_merges_resource: tokenizer_merges.map(|r| r.into()), device: self.device, + kind: self.kind, }; SentenceEmbeddingsModel::new(config) } @@ -122,6 +130,7 @@ impl SentenceEmbeddingsBuilder<Remote> { pub fn remote(model_type: SentenceEmbeddingsModelType) -> Self { Self { device: Device::cuda_if_available(), + kind: None, inner: Remote { config: SentenceEmbeddingsConfig::from(model_type), }, } } diff --git a/src/pipelines/sentence_embeddings/config.rs b/src/pipelines/sentence_embeddings/config.rs index 8901c3cd..f6f6cb2b 100644 --- a/src/pipelines/sentence_embeddings/config.rs +++ b/src/pipelines/sentence_embeddings/config.rs @@ -1,5 +1,5 @@ use serde::{Deserialize, Serialize}; -use tch::Device; +use tch::{Device, Kind}; use crate::pipelines::common::ModelType; use crate::resources::ResourceProvider; @@ -55,6 +55,8 @@ pub struct SentenceEmbeddingsConfig { pub tokenizer_merges_resource: Option<Box<dyn ResourceProvider>>, /// Device to place the transformer model on pub device: Device, + /// Model weights precision. 
If not provided, will default to full precision on CPU, or the loaded weights precision otherwise + pub kind: Option<Kind>, } #[cfg(feature = "remote")] @@ -92,6 +94,7 @@ impl From<SentenceEmbeddingsModelType> for SentenceEmbeddingsConfig { )), tokenizer_merges_resource: None, device: Device::cuda_if_available(), + kind: None, }, SentenceEmbeddingsModelType::BertBaseNliMeanTokens => SentenceEmbeddingsConfig { @@ -121,6 +124,7 @@ impl From<SentenceEmbeddingsModelType> for SentenceEmbeddingsConfig { )), tokenizer_merges_resource: None, device: Device::cuda_if_available(), + kind: None, }, SentenceEmbeddingsModelType::AllMiniLmL12V2 => SentenceEmbeddingsConfig { @@ -149,7 +153,7 @@ impl From<SentenceEmbeddingsModelType> for SentenceEmbeddingsConfig { BertVocabResources::ALL_MINI_LM_L12_V2, )), tokenizer_merges_resource: None, - device: Device::cuda_if_available(), + device: Device::cuda_if_available(), kind: None, }, SentenceEmbeddingsModelType::AllMiniLmL6V2 => SentenceEmbeddingsConfig { @@ -178,7 +182,7 @@ impl From<SentenceEmbeddingsModelType> for SentenceEmbeddingsConfig { BertVocabResources::ALL_MINI_LM_L6_V2, )), tokenizer_merges_resource: None, - device: Device::cuda_if_available(), + device: Device::cuda_if_available(), kind: None, }, SentenceEmbeddingsModelType::AllDistilrobertaV1 => SentenceEmbeddingsConfig { @@ -209,7 +213,7 @@ impl From<SentenceEmbeddingsModelType> for SentenceEmbeddingsConfig { tokenizer_merges_resource: Some(Box::new(RemoteResource::from_pretrained( RobertaMergesResources::ALL_DISTILROBERTA_V1, ))), - device: Device::cuda_if_available(), + device: Device::cuda_if_available(), kind: None, }, SentenceEmbeddingsModelType::ParaphraseAlbertSmallV2 => SentenceEmbeddingsConfig { @@ -238,7 +242,7 @@ impl From<SentenceEmbeddingsModelType> for SentenceEmbeddingsConfig { AlbertVocabResources::PARAPHRASE_ALBERT_SMALL_V2, )), tokenizer_merges_resource: None, - device: Device::cuda_if_available(), + device: Device::cuda_if_available(), kind: None, }, SentenceEmbeddingsModelType::SentenceT5Base => SentenceEmbeddingsConfig { @@ -271,7 +275,7 @@ impl From<SentenceEmbeddingsModelType> for SentenceEmbeddingsConfig { T5VocabResources::SENTENCE_T5_BASE, )), tokenizer_merges_resource: None, - device: Device::cuda_if_available(), + device: Device::cuda_if_available(), kind: None, }, } } diff --git a/src/pipelines/sentence_embeddings/pipeline.rs b/src/pipelines/sentence_embeddings/pipeline.rs index f9d8f9a0..a634ce3f 100644 --- a/src/pipelines/sentence_embeddings/pipeline.rs +++ b/src/pipelines/sentence_embeddings/pipeline.rs @@ -236,6 +236,7 @@ impl SentenceEmbeddingsModel { dense_config_resource, dense_weights_resource, device, + kind, } = config; let modules = @@ -254,7 +255,12 @@ impl SentenceEmbeddingsModel { ); let transformer = SentenceEmbeddingsOption::new(transformer_type, var_store.root(), &transformer_config)?; - crate::resources::load_weights(&transformer_weights_resource, &mut var_store)?; + crate::resources::load_weights( + &transformer_weights_resource, + &mut var_store, + kind, + device, + )?; // Setup pooling layer let pooling_config = PoolingConfig::from_file(pooling_config_resource.get_local_path()?); diff --git a/src/pipelines/sequence_classification.rs b/src/pipelines/sequence_classification.rs index 3faf049b..fbc05a2e 100644 --- a/src/pipelines/sequence_classification.rs +++ b/src/pipelines/sequence_classification.rs @@ -68,7 +68,7 @@ use crate::fnet::FNetForSequenceClassification; use crate::longformer::LongformerForSequenceClassification; use crate::mobilebert::MobileBertForSequenceClassification; use crate::pipelines::common::{ - get_device, ConfigOption, ModelResource, ModelType, TokenizerOption, + cast_var_store, get_device, ConfigOption, ModelResource, ModelType, 
TokenizerOption, }; use crate::reformer::ReformerForSequenceClassification; use crate::resources::ResourceProvider; @@ -123,6 +123,8 @@ pub struct SequenceClassificationConfig { pub add_prefix_space: Option<bool>, /// Device to place the model on (default: CUDA/GPU when available) pub device: Device, + /// Model weights precision. If not provided, will default to full precision on CPU, or the loaded weights precision otherwise + pub kind: Option<Kind>, } impl SequenceClassificationConfig { @@ -160,6 +162,7 @@ impl SequenceClassificationConfig { strip_accents: strip_accents.into(), add_prefix_space: add_prefix_space.into(), device: Device::cuda_if_available(), + kind: None, } } } @@ -392,6 +395,7 @@ impl SequenceClassificationOption { ))), }?; var_store.load(weights_path)?; + cast_var_store(&mut var_store, config.kind, device); Ok(model) } diff --git a/src/pipelines/summarization.rs b/src/pipelines/summarization.rs index ed5a26fc..299049f9 100644 --- a/src/pipelines/summarization.rs +++ b/src/pipelines/summarization.rs @@ -62,7 +62,7 @@ //! # ; //! ``` -use tch::Device; +use tch::{Device, Kind}; use crate::bart::BartGenerator; use crate::common::error::RustBertError; @@ -126,6 +126,8 @@ pub struct SummarizationConfig { pub diversity_penalty: Option<f64>, /// Device to place the model on (default: CUDA/GPU when available) pub device: Device, + /// Model weights precision. If not provided, will default to full precision on CPU, or the loaded weights precision otherwise + pub kind: Option<Kind>, } impl SummarizationConfig { @@ -170,6 +172,7 @@ impl SummarizationConfig { num_beam_groups: None, diversity_penalty: None, device: Device::cuda_if_available(), + kind: None, } } } @@ -214,6 +217,7 @@ impl From<SummarizationConfig> for GenerateConfig { num_beam_groups: config.num_beam_groups, diversity_penalty: config.diversity_penalty, device: config.device, + kind: config.kind, } } } diff --git a/src/pipelines/text_generation.rs b/src/pipelines/text_generation.rs index 97c23070..c084aaa5 100644 --- a/src/pipelines/text_generation.rs +++ b/src/pipelines/text_generation.rs @@ -31,7 +31,7 @@ //! //! Customized text generation models can be loaded by overwriting the resources in the configuration. //! The dependencies will be downloaded to the user's home directory, e.g. under ~/.cache/.rustbert/gpt2 -use tch::Device; +use tch::{Device, Kind}; use crate::common::error::RustBertError; use crate::gpt2::GPT2Generator; @@ -97,6 +97,8 @@ pub struct TextGenerationConfig { pub diversity_penalty: Option<f64>, /// Device to place the model on (default: CUDA/GPU when available) pub device: Device, + /// Model weights precision. 
If not provided, will default to full precision on CPU, or the loaded weights precision otherwise + pub kind: Option<Kind>, } impl TextGenerationConfig { @@ -141,6 +143,7 @@ impl TextGenerationConfig { num_beam_groups: None, diversity_penalty: None, device: Device::cuda_if_available(), + kind: None, } } } @@ -185,6 +188,7 @@ impl From<TextGenerationConfig> for GenerateConfig { num_beam_groups: config.num_beam_groups, diversity_penalty: config.diversity_penalty, device: config.device, + kind: config.kind, } } } diff --git a/src/pipelines/token_classification.rs b/src/pipelines/token_classification.rs index 37957278..c29dde5d 100644 --- a/src/pipelines/token_classification.rs +++ b/src/pipelines/token_classification.rs @@ -122,7 +122,7 @@ use crate::fnet::FNetForTokenClassification; use crate::longformer::LongformerForTokenClassification; use crate::mobilebert::MobileBertForTokenClassification; use crate::pipelines::common::{ - get_device, ConfigOption, ModelResource, ModelType, TokenizerOption, + cast_var_store, get_device, ConfigOption, ModelResource, ModelType, TokenizerOption, }; use crate::resources::ResourceProvider; use crate::roberta::RobertaForTokenClassification; @@ -242,6 +242,8 @@ pub struct TokenClassificationConfig { pub add_prefix_space: Option<bool>, /// Device to place the model on (default: CUDA/GPU when available) pub device: Device, + /// Model weights precision. If not provided, will default to full precision on CPU, or the loaded weights precision otherwise + pub kind: Option<Kind>, /// Sub-tokens aggregation method (default: `LabelAggregationOption::First`) pub label_aggregation_function: LabelAggregationOption, /// Batch size for predictions @@ -284,6 +286,7 @@ impl TokenClassificationConfig { strip_accents: strip_accents.into(), add_prefix_space: add_prefix_space.into(), device: Device::cuda_if_available(), + kind: None, label_aggregation_function, batch_size: 64, } } @@ -506,6 +509,7 @@ impl TokenClassificationOption { ))), }?; var_store.load(weights_path)?; + cast_var_store(&mut var_store, config.kind, device); Ok(model) } diff --git a/src/pipelines/translation/translation_pipeline.rs b/src/pipelines/translation/translation_pipeline.rs index f4ff01ee..5a4c0a40 100644 --- a/src/pipelines/translation/translation_pipeline.rs +++ b/src/pipelines/translation/translation_pipeline.rs @@ -11,7 +11,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -use tch::Device; +use tch::{Device, Kind}; use crate::common::error::RustBertError; use crate::m2m_100::M2M100Generator; @@ -978,6 +978,8 @@ pub struct TranslationConfig { pub num_beam_groups: Option<i64>, /// Diversity penalty for diverse beam search. High values will enforce more difference between beam groups (default: 5.5) pub diversity_penalty: Option<f64>, + /// Model weights precision. 
If not provided, will default to full precision on CPU, or the loaded weights precision otherwise + pub kind: Option<Kind>, } impl TranslationConfig { @@ -1065,6 +1067,7 @@ impl TranslationConfig { num_return_sequences: 1, num_beam_groups: None, diversity_penalty: None, + kind: None, } } } @@ -1092,6 +1095,7 @@ impl From<TranslationConfig> for GenerateConfig { num_beam_groups: config.num_beam_groups, diversity_penalty: config.diversity_penalty, device: config.device, + kind: config.kind, } } } diff --git a/src/pipelines/zero_shot_classification.rs b/src/pipelines/zero_shot_classification.rs index 5625c2a3..642bc35f 100644 --- a/src/pipelines/zero_shot_classification.rs +++ b/src/pipelines/zero_shot_classification.rs @@ -106,7 +106,9 @@ use crate::deberta_v2::DebertaV2ForSequenceClassification; use crate::distilbert::DistilBertModelClassifier; use crate::longformer::LongformerForSequenceClassification; use crate::mobilebert::MobileBertForSequenceClassification; -use crate::pipelines::common::{ConfigOption, ModelResource, ModelType, TokenizerOption}; +use crate::pipelines::common::{ + cast_var_store, ConfigOption, ModelResource, ModelType, TokenizerOption, +}; use crate::pipelines::sequence_classification::Label; use crate::resources::ResourceProvider; use crate::roberta::RobertaForSequenceClassification; @@ -147,6 +149,8 @@ pub struct ZeroShotClassificationConfig { pub add_prefix_space: Option<bool>, /// Device to place the model on (default: CUDA/GPU when available) pub device: Device, + /// Model weights precision. If not provided, will default to full precision on CPU, or the loaded weights precision otherwise + pub kind: Option<Kind>, } impl ZeroShotClassificationConfig { @@ -184,6 +188,7 @@ impl ZeroShotClassificationConfig { strip_accents: strip_accents.into(), add_prefix_space: add_prefix_space.into(), device: Device::cuda_if_available(), + kind: None, } } } @@ -210,6 +215,7 @@ impl Default for ZeroShotClassificationConfig { strip_accents: None, add_prefix_space: None, device: Device::cuda_if_available(), + kind: None, } } } @@ -400,6 +406,7 @@ impl ZeroShotClassificationOption { ))), }?; var_store.load(weights_path)?; + cast_var_store(&mut var_store, config.kind, device); Ok(model) } diff --git a/tests/albert.rs b/tests/albert.rs index 8990f4d7..d60abde0 100644 --- a/tests/albert.rs +++ b/tests/albert.rs @@ -35,7 +35,7 @@ fn albert_masked_lm() -> anyhow::Result<()> { AlbertTokenizer::from_file(vocab_path.to_str().unwrap(), true, false)?; let config = AlbertConfig::from_file(config_path); let albert_model = AlbertForMaskedLM::new(vs.root(), &config); - load_weights(&weights_resource, &mut vs)?; + load_weights(&weights_resource, &mut vs, None, device)?; // Define input let input = [ diff --git a/tests/bart.rs b/tests/bart.rs index e307fe08..3b56ae45 100644 --- a/tests/bart.rs +++ b/tests/bart.rs @@ -2,7 +2,7 @@ use rust_bert::bart::{ BartConfig, BartConfigResources, BartMergesResources, BartModel, BartModelResources, BartVocabResources, }; -use rust_bert::pipelines::common::ModelResource; +use rust_bert::pipelines::common::{cast_var_store, ModelResource}; use rust_bert::pipelines::summarization::{SummarizationConfig, SummarizationModel}; use rust_bert::pipelines::zero_shot_classification::{ ZeroShotClassificationConfig, ZeroShotClassificationModel, @@ -44,6 +44,7 @@ fn bart_lm_model() -> anyhow::Result<()> { let config = BartConfig::from_file(config_path); let bart_model = BartModel::new(&vs.root() / "model", &config); vs.load(weights_path)?; + cast_var_store(&mut vs, None, device); // Define input let input = ["One two 
three four"]; diff --git a/tests/onnx.rs b/tests/onnx.rs index 7caa3689..8a0b216e 100644 --- a/tests/onnx.rs +++ b/tests/onnx.rs @@ -234,15 +234,15 @@ mod tests { ModelType::M2M100, ModelResource::ONNX(ONNXModelResources { encoder_resource: Some(Box::new(RemoteResource::new( - "https://huggingface.co/optimum/m2m100_418M/resolve/main/encoder_model.onnx", + "https://huggingface.co/optimum/m2m100_418M/blob/e775f50e63b178d82b8d736fc43fcf5ef15d2f6c/encoder_model.onnx", "onnx-m2m100_418M", ))), decoder_resource: Some(Box::new(RemoteResource::new( - "https://huggingface.co/optimum/m2m100_418M/resolve/main/decoder_model.onnx", + "https://huggingface.co/optimum/m2m100_418M/blob/e775f50e63b178d82b8d736fc43fcf5ef15d2f6c/decoder_model.onnx", "onnx-m2m100_418M", ))), decoder_with_past_resource: Some(Box::new(RemoteResource::new( - "https://huggingface.co/optimum/m2m100_418M/resolve/main/decoder_with_past_model.onnx", + "https://huggingface.co/optimum/m2m100_418M/blob/e775f50e63b178d82b8d736fc43fcf5ef15d2f6c/decoder_with_past_model.onnx", "onnx-m2m100_418M", ))), }), From a6b85d20f1156f74f19d57615c69edcf2bc11de2 Mon Sep 17 00:00:00 2001 From: Guillaume Becquin Date: Sat, 25 Nov 2023 09:45:03 +0000 Subject: [PATCH 3/7] Fix ONNX resource path --- tests/onnx.rs | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/onnx.rs b/tests/onnx.rs index 8a0b216e..ce1a9d68 100644 --- a/tests/onnx.rs +++ b/tests/onnx.rs @@ -234,15 +234,15 @@ mod tests { ModelType::M2M100, ModelResource::ONNX(ONNXModelResources { encoder_resource: Some(Box::new(RemoteResource::new( - "https://huggingface.co/optimum/m2m100_418M/blob/e775f50e63b178d82b8d736fc43fcf5ef15d2f6c/encoder_model.onnx", + "https://huggingface.co/optimum/m2m100_418M/resolve/e775f50e63b178d82b8d736fc43fcf5ef15d2f6c/encoder_model.onnx", "onnx-m2m100_418M", ))), decoder_resource: Some(Box::new(RemoteResource::new( - "https://huggingface.co/optimum/m2m100_418M/blob/e775f50e63b178d82b8d736fc43fcf5ef15d2f6c/decoder_model.onnx", + "https://huggingface.co/optimum/m2m100_418M/resolve/e775f50e63b178d82b8d736fc43fcf5ef15d2f6c/decoder_model.onnx", "onnx-m2m100_418M", ))), decoder_with_past_resource: Some(Box::new(RemoteResource::new( - "https://huggingface.co/optimum/m2m100_418M/blob/e775f50e63b178d82b8d736fc43fcf5ef15d2f6c/decoder_with_past_model.onnx", + "https://huggingface.co/optimum/m2m100_418M/resolve/e775f50e63b178d82b8d736fc43fcf5ef15d2f6c/decoder_with_past_model.onnx", "onnx-m2m100_418M", ))), }), From dd1cb2f4115b10669ded7f429a81066819c86184 Mon Sep 17 00:00:00 2001 From: Guillaume Becquin Date: Sat, 25 Nov 2023 10:26:36 +0000 Subject: [PATCH 4/7] Fix GPT-J bias bool loading --- src/models/gpt_j/attention.rs | 21 +++++++-------------- src/models/gpt_j/gpt_j_model.rs | 16 ---------------- src/models/gpt_j/transformer.rs | 9 --------- tests/gpt_j.rs | 15 ++++++++------- 4 files changed, 15 insertions(+), 46 deletions(-) diff --git a/src/models/gpt_j/attention.rs b/src/models/gpt_j/attention.rs index 4e7bcd91..2e46a3e7 100644 --- a/src/models/gpt_j/attention.rs +++ b/src/models/gpt_j/attention.rs @@ -68,11 +68,16 @@ impl GptJAttention { let p = p.borrow(); let max_positions = config.n_positions; - let bias = Tensor::ones([max_positions, max_positions], (Kind::Uint8, p.device())) + let bias_value = Tensor::ones([max_positions, max_positions], (Kind::Uint8, p.device())) .tril(0) .view([1, 1, max_positions, max_positions]) .requires_grad_(false); - let bias = p.var_copy("bias", &bias); + let mut bias = p + .f_ones_no_train("bias", &[1, 1, 
max_positions, max_positions]) + .unwrap() + .to_kind(Kind::Uint8) + .to_device(p.device()); + bias.copy_(&bias_value); let attn_pdrop = config.attn_pdrop.unwrap_or(0.1); let resid_pdrop = config.resid_pdrop.unwrap_or(0.1); @@ -95,21 +100,9 @@ impl GptJAttention { ..Default::default() }; let k_proj = nn::linear(p / "k_proj", config.n_embd, config.n_embd, linear_config); - if config.use_float16 { - (p / "k_proj").half(); - } let v_proj = nn::linear(p / "v_proj", config.n_embd, config.n_embd, linear_config); - if config.use_float16 { - (p / "v_proj").half(); - } let q_proj = nn::linear(p / "q_proj", config.n_embd, config.n_embd, linear_config); - if config.use_float16 { - (p / "q_proj").half(); - } let out_proj = nn::linear(p / "out_proj", config.n_embd, config.n_embd, linear_config); - if config.use_float16 { - (p / "out_proj").half(); - } GptJAttention { bias, diff --git a/src/models/gpt_j/gpt_j_model.rs b/src/models/gpt_j/gpt_j_model.rs index 907f8053..de548426 100644 --- a/src/models/gpt_j/gpt_j_model.rs +++ b/src/models/gpt_j/gpt_j_model.rs @@ -131,8 +131,6 @@ pub struct GptJConfig { pub rotary_dim: Option<i64>, pub vocab_size: i64, pub scale_attn_weights: Option<bool>, - #[serde(default = "default_use_float16")] - pub use_float16: bool, #[serde(default = "default_preload_on_cpu")] pub preload_on_cpu: bool, pub decoder_start_token_id: Option<i64>, @@ -164,7 +162,6 @@ impl Default for GptJConfig { rotary_dim: Some(64), vocab_size: 50400, scale_attn_weights: Some(true), - use_float16: default_use_float16(), preload_on_cpu: default_preload_on_cpu(), decoder_start_token_id: None, forced_bos_token_id: None, @@ -173,10 +170,6 @@ impl Default for GptJConfig { } } -fn default_use_float16() -> bool { - true -} - fn default_preload_on_cpu() -> bool { true } @@ -233,9 +226,6 @@ impl GptJModel { config.n_embd, Default::default(), ); - if config.use_float16 { - (&(&p / "wte") / "weight").half() - }; let embd_pdrop = config.embd_pdrop.unwrap_or(0.1); let drop = Dropout::new(embd_pdrop); @@ -245,9 +235,6 @@ impl GptJModel { ..Default::default() }; let ln_f = nn::layer_norm(&p / "ln_f", vec![config.n_embd], layer_norm_config); - if config.use_float16 { - (&p / "ln_f").half() - }; let mut h: Vec<GptJBlock> = vec![]; let h_path = &p / "h"; @@ -475,9 +462,6 @@ impl GptJLMHeadModel { config.vocab_size, Default::default(), ); - if config.use_float16 { - (p / "lm_head").half(); - } GptJLMHeadModel { transformer, diff --git a/src/models/gpt_j/transformer.rs b/src/models/gpt_j/transformer.rs index a00878c9..798ee750 100644 --- a/src/models/gpt_j/transformer.rs +++ b/src/models/gpt_j/transformer.rs @@ -43,18 +43,12 @@ impl GptJMLP { intermediate_size, Default::default(), ); - if config.use_float16 { - (p / "fc_in").half() - }; let fc_out = nn::linear( p / "fc_out", intermediate_size, config.n_embd, Default::default(), ); - if config.use_float16 { - (p / "fc_out").half() - }; let activation = match &config.afn { Some(activation_enum) => match activation_enum { @@ -100,9 +94,6 @@ impl GptJBlock { ..Default::default() }; let ln_1 = nn::layer_norm(p / "ln_1", vec![config.n_embd], layer_norm_config); - if config.use_float16 { - (p / "ln_1").half() - }; let attn = GptJAttention::new(p / "attn", config); let mlp = GptJMLP::new(p / "mlp", config); diff --git a/tests/gpt_j.rs b/tests/gpt_j.rs index d09e9a26..f2efc1d4 100644 --- a/tests/gpt_j.rs +++ b/tests/gpt_j.rs @@ -3,12 +3,12 @@ use rust_bert::gpt_j::{ GptJVocabResources, }; use rust_bert::pipelines::generation_utils::Cache; -use rust_bert::resources::{RemoteResource, ResourceProvider}; +use 
rust_bert::resources::{load_weights, RemoteResource, ResourceProvider}; use rust_bert::Config; use rust_tokenizers::tokenizer::{Gpt2Tokenizer, Tokenizer}; use rust_tokenizers::vocab::Vocab; use std::convert::TryFrom; -use tch::{nn, Device, Tensor}; +use tch::{nn, Device, Kind, Tensor}; /// Equivalent Python code: /// @@ -67,14 +67,15 @@ fn gpt_j_correctness() -> anyhow::Result<()> { let mut vs = nn::VarStore::new(device); let config_path = config_resource.get_local_path()?; - let weights_path = model_resource.get_local_path()?; - let mut config = GptJConfig::from_file(config_path); - config.use_float16 = matches!(device, Device::Cuda(_)); + let config = GptJConfig::from_file(config_path); let model = GptJLMHeadModel::new(vs.root(), &config); - vs.load(weights_path)?; + let kind = match device { + Device::Cpu => None, + _ => Some(Kind::Half), + }; + load_weights(&model_resource, &mut vs, kind, device)?; // Tokenize prompts - let prompts = [ "It was a very nice and sunny", "It was a gloom winter night, and", From e3a10049334b1c5b2c178888f02b71cf19918e22 Mon Sep 17 00:00:00 2001 From: Guillaume Becquin Date: Sat, 25 Nov 2023 10:30:38 +0000 Subject: [PATCH 5/7] Updated changelog --- CHANGELOG.md | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index da57eda6..e2ca290c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,12 +5,17 @@ All notable changes to this project will be documented in this file. The format ## Added - Addition of `new_with_tokenizer` constructor for `SentenceEmbeddingsModel` allowing passing custom tokenizers for sentence embeddings pipelines. - Support for [Tokenizers](https://github.com/huggingface/tokenizers) in pipelines, allowing loading `tokenizer.json` and `special_token_map.json` tokenizer files. +- Most model configurations can now take an optional `kind` parameter to specify the model weight precision. If not provided, will default to full precision on CPU, or the serialized weights precision otherwise. ## Fixed - (BREAKING) Fixed the keyword extraction pipeline for n-gram sizes > 2. Added a new configuration option `tokenizer_forbidden_ngram_chars` to specify characters that should be excluded from n-grams (allows filtering m-grams spanning multiple sentences). - Improved MPS device compatibility by setting the `sparse_grad` flag to false for `gather` operations - Updated ONNX runtime backend version to 1.15.x - Issue with incorrect results for QA models with a tokenizer not using segment ids - Issue with GPT-J that was incorrectly tracking the gradients for the attention bias + +## Changed +- (BREAKING) Upgraded to `torch` 2.1 (via `tch` 0.14.0). ## [0.21.0] - 2023-06-03 ## Added From 69b1991825a13a897ac0a75ff638fd088f8773a3 Mon Sep 17 00:00:00 2001 From: Guillaume Becquin Date: Sat, 25 Nov 2023 10:41:12 +0000 Subject: [PATCH 6/7] Fix Clippy warnings --- CHANGELOG.md | 2 +- benches/generation_benchmark.rs | 1 + src/models/longt5/encoder.rs | 4 ++-- src/models/reformer/attention.rs | 4 ++-- src/models/t5/attention.rs | 10 +++++----- src/models/t5/encoder.rs | 6 +++--- 6 files changed, 14 insertions(+), 13 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index e2ca290c..49cc059b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,7 +5,7 @@ All notable changes to this project will be documented in this file. The format ## Added - Addition of `new_with_tokenizer` constructor for `SentenceEmbeddingsModel` allowing passing custom tokenizers for sentence embeddings pipelines. 
- Support for [Tokenizers](https://github.com/huggingface/tokenizers) in pipelines, allowing loading `tokenizer.json` and `special_token_map.json` tokenizer files. -- Most model configurations can now take an optional `kind` parameter to specify the model weight precision. If not provided, will default to full precision on CPU, or the serialized weights precision otherwise. +- (BREAKING) Most model configurations can now take an optional `kind` parameter to specify the model weight precision. If not provided, will default to full precision on CPU, or the serialized weights precision otherwise. ## Fixed - (BREAKING) Fixed the keyword extraction pipeline for n-gram sizes > 2. Added a new configuration option `tokenizer_forbidden_ngram_chars` to specify characters that should be excluded from n-grams (allows filtering m-grams spanning multiple sentences). diff --git a/benches/generation_benchmark.rs b/benches/generation_benchmark.rs index c281250b..e54a9718 100644 --- a/benches/generation_benchmark.rs +++ b/benches/generation_benchmark.rs @@ -37,6 +37,7 @@ fn create_text_generation_model() -> TextGenerationModel { diversity_penalty: None, num_return_sequences: 5, device: Device::cuda_if_available(), + kind: None, }; TextGenerationModel::new(config).unwrap() } diff --git a/src/models/longt5/encoder.rs b/src/models/longt5/encoder.rs index 637e7e0a..0d4bfb5d 100644 --- a/src/models/longt5/encoder.rs +++ b/src/models/longt5/encoder.rs @@ -288,8 +288,8 @@ impl LongT5Stack { let (batch_size, sequence_length) = (input_shape[0], input_shape[1]); - let mask_seq_length = if old_layer_states.is_some() { - if old_layer_states.as_ref().unwrap()[0].0.is_some() { + let mask_seq_length = if let Some(old_layer_states_value) = &old_layer_states { + if old_layer_states_value[0].0.is_some() { old_layer_states.as_ref().unwrap()[0] .0 .as_ref() diff --git a/src/models/reformer/attention.rs b/src/models/reformer/attention.rs index 2784967a..598f7db7 100644 --- a/src/models/reformer/attention.rs +++ b/src/models/reformer/attention.rs @@ -1368,8 +1368,8 @@ impl ReformerAttention { let new_layer_state = if self.use_past { let prev_buckets = if let Some(buckets_value) = &buckets { if layer_state.is_none() | { - if layer_state.is_some() { - layer_state.as_ref().unwrap().prev_buckets.is_none() + if let Some(layer_state_value) = &layer_state { + layer_state_value.prev_buckets.is_none() } else { false } diff --git a/src/models/t5/attention.rs b/src/models/t5/attention.rs index 2abae806..777375c0 100644 --- a/src/models/t5/attention.rs +++ b/src/models/t5/attention.rs @@ -191,15 +191,15 @@ impl T5Attention { let q: Tensor = self.shape(hidden_states.as_ref().apply(&self.query), bs); - let (mut k, mut v) = if key_value_states.is_none() { + let (mut k, mut v) = if let Some(key_value_states_value) = key_value_states { ( - self.shape(hidden_states.apply(&self.key), bs), - self.shape(hidden_states.apply(&self.value), bs), + self.shape(key_value_states_value.apply(&self.key), bs), + self.shape(key_value_states_value.apply(&self.value), bs), ) } else { ( - self.shape(key_value_states.as_ref().unwrap().apply(&self.key), bs), - self.shape(key_value_states.as_ref().unwrap().apply(&self.value), bs), + self.shape(hidden_states.apply(&self.key), bs), + self.shape(hidden_states.apply(&self.value), bs), ) }; diff --git a/src/models/t5/encoder.rs b/src/models/t5/encoder.rs index b2e00ce1..74a2c4dc 100644 --- a/src/models/t5/encoder.rs +++ b/src/models/t5/encoder.rs @@ -383,9 +383,9 @@ impl T5Stack { let (batch_size, sequence_length) = 
(input_shape[0], input_shape[1]); - let mask_seq_length = if old_layer_states.is_some() { - if old_layer_states.as_ref().unwrap()[0].0.is_some() { - old_layer_states.as_ref().unwrap()[0] + let mask_seq_length = if let Some(old_layer_states_value) = &old_layer_states { + if old_layer_states_value[0].0.is_some() { + old_layer_states_value[0] .0 .as_ref() .unwrap() From c2ba3d747431169697c06c0027eb3c0969ec3e2d Mon Sep 17 00:00:00 2001 From: Guillaume Becquin Date: Sun, 26 Nov 2023 08:52:41 +0000 Subject: [PATCH 7/7] Updated readme --- README.md | 4 ++-- src/lib.rs | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/README.md b/README.md index d4fd0c45..2fb596c7 100644 --- a/README.md +++ b/README.md @@ -80,8 +80,8 @@ This cache location defaults to `~/.cache/.rustbert`, but can be changed by sett ### Manual installation (recommended) -1. Download `libtorch` from https://pytorch.org/get-started/locally/. This package requires `v2.0.0`: if this version is no longer available on the "get started" page, -the file should be accessible by modifying the target link, for example `https://download.pytorch.org/libtorch/cu118/libtorch-cxx11-abi-shared-with-deps-2.0.0%2Bcu118.zip` for a Linux version with CUDA11. **NOTE:** When using `rust-bert` as dependency from [crates.io](https://crates.io), please check the required `LIBTORCH` on the published package [readme](https://crates.io/crates/rust-bert) as it may differ from the version documented here (applying to the current repository version). +1. Download `libtorch` from https://pytorch.org/get-started/locally/. This package requires `v2.1`: if this version is no longer available on the "get started" page, +the file should be accessible by modifying the target link, for example `https://download.pytorch.org/libtorch/cu118/libtorch-cxx11-abi-shared-with-deps-2.1.1%2Bcu118.zip` for a Linux version with CUDA11. **NOTE:** When using `rust-bert` as dependency from [crates.io](https://crates.io), please check the required `LIBTORCH` on the published package [readme](https://crates.io/crates/rust-bert) as it may differ from the version documented here (applying to the current repository version). 2. Extract the library to a location of your choice 3. Set the following environment variables ##### Linux: diff --git a/src/lib.rs b/src/lib.rs index e84f3c09..e73b16ba 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -90,8 +90,8 @@ //! //! ### Manual installation (recommended) //! -//! 1. Download `libtorch` from <https://pytorch.org/get-started/locally/>. This package requires `v2.0`: if this version is no longer available on the "get started" page, -//! the file should be accessible by modifying the target link, for example `https://download.pytorch.org/libtorch/cu118/libtorch-cxx11-abi-shared-with-deps-2.0.0%2Bcu118.zip` for a Linux version with CUDA11. +//! 1. Download `libtorch` from <https://pytorch.org/get-started/locally/>. This package requires `v2.1`: if this version is no longer available on the "get started" page, +//! the file should be accessible by modifying the target link, for example `https://download.pytorch.org/libtorch/cu118/libtorch-cxx11-abi-shared-with-deps-2.1.1%2Bcu118.zip` for a Linux version with CUDA11. //! 2. Extract the library to a location of your choice //! 3. Set the following environment variables //! ##### Linux:
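
A minimal sketch of the new `kind` parameter on the pipeline configurations, assuming the `remote` feature for the default (GPT-2) resources; only the `kind` and `device` fields below depart from the defaults. `Some(Kind::Half)` casts the loaded weights to fp16, while `None` keeps full precision on CPU (see `cast_var_store`):

```rust
use rust_bert::pipelines::text_generation::{TextGenerationConfig, TextGenerationModel};
use tch::{Device, Kind};

fn main() -> anyhow::Result<()> {
    let config = TextGenerationConfig {
        device: Device::cuda_if_available(),
        kind: Some(Kind::Half), // new in this series: load the weights in fp16
        ..Default::default()
    };
    let model = TextGenerationModel::new(config)?;
    // Generate from a short prompt and print the raw result.
    let output = model.generate(&["The dog"], None);
    println!("{output:?}");
    Ok(())
}
```

The sentence embeddings pipeline exposes the same choice through the `SentenceEmbeddingsBuilder::with_kind(Kind)` method added above.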
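For models built manually against a `VarStore`, the extended `load_weights(resource, var_store, kind, device)` signature takes the target precision and device directly, as exercised by the updated `tests/albert.rs` and `tests/gpt_j.rs`. A sketch reusing the ALBERT resource constants from those tests, with the device-dependent `kind` selection from the GPT-J test:

```rust
use rust_bert::albert::{
    AlbertConfig, AlbertConfigResources, AlbertForMaskedLM, AlbertModelResources,
};
use rust_bert::resources::{load_weights, RemoteResource, ResourceProvider};
use rust_bert::Config;
use tch::{nn, Device, Kind};

fn main() -> anyhow::Result<()> {
    let device = Device::cuda_if_available();
    let mut vs = nn::VarStore::new(device);

    let config_resource = RemoteResource::from_pretrained(AlbertConfigResources::ALBERT_BASE_V2);
    let weights_resource = RemoteResource::from_pretrained(AlbertModelResources::ALBERT_BASE_V2);

    // Build the model against the var store first, then load the weights into it.
    let config = AlbertConfig::from_file(config_resource.get_local_path()?);
    let model = AlbertForMaskedLM::new(vs.root(), &config);

    // As in the GPT-J test: half precision on CUDA, full (fp32) precision on CPU.
    let kind = match device {
        Device::Cpu => None,
        _ => Some(Kind::Half),
    };
    load_weights(&weights_resource, &mut vs, kind, device)?;

    let _ = model; // ready for inference
    Ok(())
}
```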