From 62b40d04b5f2bbb294eda8d2b2e49cb3d1f87cc1 Mon Sep 17 00:00:00 2001 From: Kavan <20254930+kavan-mevada@users.noreply.github.com> Date: Sun, 18 Aug 2024 05:24:35 -0400 Subject: [PATCH 1/4] added bert-large-uncased model (#439) Co-authored-by: guillaume-be --- src/models/bert/bert_model.rs | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/src/models/bert/bert_model.rs b/src/models/bert/bert_model.rs index 09b2e58f..a45f08e2 100644 --- a/src/models/bert/bert_model.rs +++ b/src/models/bert/bert_model.rs @@ -42,6 +42,11 @@ impl BertModelResources { "bert/model", "https://huggingface.co/bert-base-uncased/resolve/main/rust_model.ot", ); + /// Shared under Apache 2.0 license by the Google team at . Modified with conversion to C-array format. + pub const BERT_LARGE: (&'static str, &'static str) = ( + "bert-large/model", + "https://huggingface.co/bert-large-uncased/resolve/main/rust_model.ot", + ); /// Shared under MIT license by the MDZ Digital Library team at the Bavarian State Library at . Modified with conversion to C-array format. pub const BERT_NER: (&'static str, &'static str) = ( "bert-ner/model", @@ -75,6 +80,11 @@ impl BertConfigResources { "bert/config", "https://huggingface.co/bert-base-uncased/resolve/main/config.json", ); + /// Shared under Apache 2.0 license by the Google team at . Modified with conversion to C-array format. + pub const BERT_LARGE: (&'static str, &'static str) = ( + "bert-large/config", + "https://huggingface.co/bert-large-uncased/resolve/main/config.json", + ); /// Shared under MIT license by the MDZ Digital Library team at the Bavarian State Library at . Modified with conversion to C-array format. pub const BERT_NER: (&'static str, &'static str) = ( "bert-ner/config", @@ -108,6 +118,11 @@ impl BertVocabResources { "bert/vocab", "https://huggingface.co/bert-base-uncased/resolve/main/vocab.txt", ); + /// Shared under Apache 2.0 license by the Google team at . Modified with conversion to C-array format. 
+ pub const BERT_LARGE: (&'static str, &'static str) = ( + "bert-large/vocab", + "https://huggingface.co/bert-large-uncased/resolve/main/vocab.txt", + ); /// Shared under MIT license by the MDZ Digital Library team at the Bavarian State Library at . Modified with conversion to C-array format. pub const BERT_NER: (&'static str, &'static str) = ( "bert-ner/vocab", From db1ed8ccbce3b48d3d1d0c149d98f9b4a80317ba Mon Sep 17 00:00:00 2001 From: Yousef Abu Shanab <93343012+youzarsiph@users.noreply.github.com> Date: Sun, 18 Aug 2024 13:57:06 +0300 Subject: [PATCH 2/4] Fix formatting error in `src/pipelines/masked_language.rs` (#456) * Fix formatting error in src/piplines/masked_language.rs * Update Zero-shot classification docs to fix formatting errors --------- Co-authored-by: guillaume-be --- src/lib.rs | 18 +++++++++--------- src/pipelines/masked_language.rs | 14 +++++++------- src/pipelines/mod.rs | 18 +++++++++--------- src/pipelines/zero_shot_classification.rs | 18 +++++++++--------- 4 files changed, 34 insertions(+), 34 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index 3363374e..20164892 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -360,15 +360,15 @@ //! # use rust_bert::pipelines::zero_shot_classification::ZeroShotClassificationModel; //! # fn main() -> anyhow::Result<()> { //! let sequence_classification_model = ZeroShotClassificationModel::new(Default::default())?; -//! let input_sentence = "Who are you voting for in 2020?"; -//! let input_sequence_2 = "The prime minister has announced a stimulus package which was widely criticized by the opposition."; -//! let candidate_labels = &["politics", "public health", "economics", "sports"]; -//! let output = sequence_classification_model.predict_multilabel( -//! &[input_sentence, input_sequence_2], -//! candidate_labels, -//! None, -//! 128, -//! ); +//! let input_sentence = "Who are you voting for in 2020?"; +//! 
let input_sequence_2 = "The prime minister has announced a stimulus package which was widely criticized by the opposition."; +//! let candidate_labels = &["politics", "public health", "economics", "sports"]; +//! let output = sequence_classification_model.predict_multilabel( +//! &[input_sentence, input_sequence_2], +//! candidate_labels, +//! None, +//! 128, +//! ); //! # Ok(()) //! # } //! ``` diff --git a/src/pipelines/masked_language.rs b/src/pipelines/masked_language.rs index 2a3c4319..f83699c7 100644 --- a/src/pipelines/masked_language.rs +++ b/src/pipelines/masked_language.rs @@ -15,15 +15,15 @@ //! a masked word can be specified in the `MaskedLanguageConfig` (`mask_token`). and allows //! multiple masked tokens per input sequence. //! -//! ```no_run -//!use rust_bert::bert::{BertConfigResources, BertModelResources, BertVocabResources}; -//!use rust_bert::pipelines::common::ModelType; -//!use rust_bert::pipelines::masked_language::{MaskedLanguageConfig, MaskedLanguageModel}; -//!use rust_bert::resources::RemoteResource; -//! fn main() -> anyhow::Result<()> { +//! ```no_run +//! use rust_bert::bert::{BertConfigResources, BertModelResources, BertVocabResources}; +//! use rust_bert::pipelines::common::ModelType; +//! use rust_bert::pipelines::masked_language::{MaskedLanguageConfig, MaskedLanguageModel}; +//! use rust_bert::resources::RemoteResource; //! +//! fn main() -> anyhow::Result<()> { //! use rust_bert::pipelines::common::ModelResource; -//! let config = MaskedLanguageConfig::new( +//! let config = MaskedLanguageConfig::new( //! ModelType::Bert, //! ModelResource::Torch(Box::new(RemoteResource::from_pretrained(BertModelResources::BERT))), //! RemoteResource::from_pretrained(BertConfigResources::BERT), diff --git a/src/pipelines/mod.rs b/src/pipelines/mod.rs index e886d69d..8bda080d 100644 --- a/src/pipelines/mod.rs +++ b/src/pipelines/mod.rs @@ -205,15 +205,15 @@ //! # use rust_bert::pipelines::zero_shot_classification::ZeroShotClassificationModel; //! 
# fn main() -> anyhow::Result<()> { //! let sequence_classification_model = ZeroShotClassificationModel::new(Default::default())?; -//! let input_sentence = "Who are you voting for in 2020?"; -//! let input_sequence_2 = "The prime minister has announced a stimulus package which was widely criticized by the opposition."; -//! let candidate_labels = &["politics", "public health", "economics", "sports"]; -//! let output = sequence_classification_model.predict_multilabel( -//! &[input_sentence, input_sequence_2], -//! candidate_labels, -//! None, -//! 128, -//! ); +//! let input_sentence = "Who are you voting for in 2020?"; +//! let input_sequence_2 = "The prime minister has announced a stimulus package which was widely criticized by the opposition."; +//! let candidate_labels = &["politics", "public health", "economics", "sports"]; +//! let output = sequence_classification_model.predict_multilabel( +//! &[input_sentence, input_sequence_2], +//! candidate_labels, +//! None, +//! 128, +//! ); //! # Ok(()) //! # } //! ``` diff --git a/src/pipelines/zero_shot_classification.rs b/src/pipelines/zero_shot_classification.rs index 642bc35f..8cdb16bc 100644 --- a/src/pipelines/zero_shot_classification.rs +++ b/src/pipelines/zero_shot_classification.rs @@ -25,15 +25,15 @@ //! # use rust_bert::pipelines::zero_shot_classification::ZeroShotClassificationModel; //! # fn main() -> anyhow::Result<()> { //! let sequence_classification_model = ZeroShotClassificationModel::new(Default::default())?; -//! let input_sentence = "Who are you voting for in 2020?"; -//! let input_sequence_2 = "The prime minister has announced a stimulus package which was widely criticized by the opposition."; -//! let candidate_labels = &["politics", "public health", "economics", "sports"]; -//! let output = sequence_classification_model.predict_multilabel( -//! &[input_sentence, input_sequence_2], -//! candidate_labels, -//! None, -//! 128, -//! ); +//! 
let input_sentence = "Who are you voting for in 2020?"; +//! let input_sequence_2 = "The prime minister has announced a stimulus package which was widely criticized by the opposition."; +//! let candidate_labels = &["politics", "public health", "economics", "sports"]; +//! let output = sequence_classification_model.predict_multilabel( +//! &[input_sentence, input_sequence_2], +//! candidate_labels, +//! None, +//! 128, +//! ); //! # Ok(()) //! # } //! ``` From 9af98f8d60aa6c3ce97dbd55dec1800b3a2d0db0 Mon Sep 17 00:00:00 2001 From: Hiroki <39700763+hkfi@users.noreply.github.com> Date: Sun, 18 Aug 2024 14:07:19 +0300 Subject: [PATCH 3/4] Update README.md (#460) Fixing typos Co-authored-by: guillaume-be --- README.md | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/README.md b/README.md index 2fd91270..f91cf7ff 100644 --- a/README.md +++ b/README.md @@ -177,8 +177,8 @@ installation instructions/support. Most architectures (including encoders, decoders and encoder-decoders) are supported. the library aims at keeping compatibility with models exported using -the [optimum](https://github.com/huggingface/optimum) library. A detailed guide -on how to export a Transformer model to ONNX using optimum is available at +the [Optimum](https://github.com/huggingface/optimum) library. A detailed guide +on how to export a Transformer model to ONNX using Optimum is available at https://huggingface.co/docs/optimum/main/en/exporters/onnx/usage_guides/export_a_model The resources used to create ONNX models are similar to those based on Pytorch, replacing the pytorch by the ONNX model. Since ONNX models are less flexible @@ -197,7 +197,7 @@ Note that the computational efficiency will drop when the `decoder with past` file is optional but not provided since the model will not used cached past keys and values for the attention mechanism, leading to a high number of redundant computations. 
The Optimum library offers export options to ensure such a -`decoder with past` model file is created. he base encoder and decoder model +`decoder with past` model file is created. The base encoder and decoder model architecture are available (and exposed for convenience) in the `encoder` and `decoder` modules, respectively. From 9707981ec363e11f69ee45f883021558c4867836 Mon Sep 17 00:00:00 2001 From: Hsiang-Cheng Yang Date: Tue, 20 Aug 2024 04:16:58 +0800 Subject: [PATCH 4/4] Update async-sentiment.rs (#337) * Update async-sentiment.rs use `tokio::task::spawn_blocking` instead of `std::thread::spawn` * fmt --------- Co-authored-by: guillaume-be --- examples/async-sentiment.rs | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/examples/async-sentiment.rs b/examples/async-sentiment.rs index 9232773a..f3342d1a 100644 --- a/examples/async-sentiment.rs +++ b/examples/async-sentiment.rs @@ -1,11 +1,11 @@ -use std::{ - sync::mpsc, - thread::{self, JoinHandle}, -}; +use std::sync::mpsc; use anyhow::Result; use rust_bert::pipelines::sentiment::{Sentiment, SentimentConfig, SentimentModel}; -use tokio::{sync::oneshot, task}; +use tokio::{ + sync::oneshot, + task::{self, JoinHandle}, +}; #[tokio::main] async fn main() -> Result<()> { @@ -36,7 +36,7 @@ impl SentimentClassifier { /// to interact with it pub fn spawn() -> (JoinHandle<Result<()>>, SentimentClassifier) { let (sender, receiver) = mpsc::sync_channel(100); - let handle = thread::spawn(move || Self::runner(receiver)); + let handle = task::spawn_blocking(move || Self::runner(receiver)); (handle, SentimentClassifier { sender }) } @@ -57,7 +57,7 @@ impl SentimentClassifier { /// Make the runner predict a sample and return the result pub async fn predict(&self, texts: Vec<String>) -> Result<Vec<Sentiment>> { let (sender, receiver) = oneshot::channel(); - task::block_in_place(|| self.sender.send((texts, sender)))?; + self.sender.send((texts, sender))?; Ok(receiver.await?) } }