diff --git a/README.md b/README.md index dcebf15c..e3f380f7 100644 --- a/README.md +++ b/README.md @@ -177,8 +177,8 @@ installation instructions/support. Most architectures (including encoders, decoders and encoder-decoders) are supported. the library aims at keeping compatibility with models exported using -the [optimum](https://github.com/huggingface/optimum) library. A detailed guide -on how to export a Transformer model to ONNX using optimum is available at +the [Optimum](https://github.com/huggingface/optimum) library. A detailed guide +on how to export a Transformer model to ONNX using Optimum is available at https://huggingface.co/docs/optimum/main/en/exporters/onnx/usage_guides/export_a_model The resources used to create ONNX models are similar to those based on Pytorch, replacing the pytorch by the ONNX model. Since ONNX models are less flexible @@ -197,7 +197,7 @@ Note that the computational efficiency will drop when the `decoder with past` file is optional but not provided since the model will not used cached past keys and values for the attention mechanism, leading to a high number of redundant computations. The Optimum library offers export options to ensure such a -`decoder with past` model file is created. he base encoder and decoder model +`decoder with past` model file is created. The base encoder and decoder model architecture are available (and exposed for convenience) in the `encoder` and `decoder` modules, respectively. 
diff --git a/examples/async-sentiment.rs b/examples/async-sentiment.rs index 9232773a..f3342d1a 100644 --- a/examples/async-sentiment.rs +++ b/examples/async-sentiment.rs @@ -1,11 +1,11 @@ -use std::{ - sync::mpsc, - thread::{self, JoinHandle}, -}; +use std::sync::mpsc; use anyhow::Result; use rust_bert::pipelines::sentiment::{Sentiment, SentimentConfig, SentimentModel}; -use tokio::{sync::oneshot, task}; +use tokio::{ + sync::oneshot, + task::{self, JoinHandle}, +}; #[tokio::main] async fn main() -> Result<()> { @@ -36,7 +36,7 @@ impl SentimentClassifier { /// to interact with it pub fn spawn() -> (JoinHandle<Result<()>>, SentimentClassifier) { let (sender, receiver) = mpsc::sync_channel(100); - let handle = thread::spawn(move || Self::runner(receiver)); + let handle = task::spawn_blocking(move || Self::runner(receiver)); (handle, SentimentClassifier { sender }) } @@ -57,7 +57,7 @@ impl SentimentClassifier { /// Make the runner predict a sample and return the result pub async fn predict(&self, texts: Vec<String>) -> Result<Vec<Sentiment>> { let (sender, receiver) = oneshot::channel(); - task::block_in_place(|| self.sender.send((texts, sender)))?; + self.sender.send((texts, sender))?; Ok(receiver.await?) } } diff --git a/src/lib.rs b/src/lib.rs index aebc99ac..511348d1 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -360,15 +360,15 @@ //! # use rust_bert::pipelines::zero_shot_classification::ZeroShotClassificationModel; //! # fn main() -> anyhow::Result<()> { //! let sequence_classification_model = ZeroShotClassificationModel::new(Default::default())?; -//! let input_sentence = "Who are you voting for in 2020?"; -//! let input_sequence_2 = "The prime minister has announced a stimulus package which was widely criticized by the opposition."; -//! let candidate_labels = &["politics", "public health", "economics", "sports"]; -//! let output = sequence_classification_model.predict_multilabel( -//! &[input_sentence, input_sequence_2], -//! candidate_labels, -//! None, -//! 128, -//! ); +//! 
let input_sentence = "Who are you voting for in 2020?"; +//! let input_sequence_2 = "The prime minister has announced a stimulus package which was widely criticized by the opposition."; +//! let candidate_labels = &["politics", "public health", "economics", "sports"]; +//! let output = sequence_classification_model.predict_multilabel( +//! &[input_sentence, input_sequence_2], +//! candidate_labels, +//! None, +//! 128, +//! ); //! # Ok(()) //! # } //! ``` diff --git a/src/models/bert/bert_model.rs b/src/models/bert/bert_model.rs index 09b2e58f..a45f08e2 100644 --- a/src/models/bert/bert_model.rs +++ b/src/models/bert/bert_model.rs @@ -42,6 +42,11 @@ impl BertModelResources { "bert/model", "https://huggingface.co/bert-base-uncased/resolve/main/rust_model.ot", ); + /// Shared under Apache 2.0 license by the Google team at . Modified with conversion to C-array format. + pub const BERT_LARGE: (&'static str, &'static str) = ( + "bert-large/model", + "https://huggingface.co/bert-large-uncased/resolve/main/rust_model.ot", + ); /// Shared under MIT license by the MDZ Digital Library team at the Bavarian State Library at . Modified with conversion to C-array format. pub const BERT_NER: (&'static str, &'static str) = ( "bert-ner/model", @@ -75,6 +80,11 @@ impl BertConfigResources { "bert/config", "https://huggingface.co/bert-base-uncased/resolve/main/config.json", ); + /// Shared under Apache 2.0 license by the Google team at . Modified with conversion to C-array format. + pub const BERT_LARGE: (&'static str, &'static str) = ( + "bert-large/config", + "https://huggingface.co/bert-large-uncased/resolve/main/config.json", + ); /// Shared under MIT license by the MDZ Digital Library team at the Bavarian State Library at . Modified with conversion to C-array format. 
pub const BERT_NER: (&'static str, &'static str) = ( "bert-ner/config", @@ -108,6 +118,11 @@ impl BertVocabResources { "bert/vocab", "https://huggingface.co/bert-base-uncased/resolve/main/vocab.txt", ); + /// Shared under Apache 2.0 license by the Google team at . Modified with conversion to C-array format. + pub const BERT_LARGE: (&'static str, &'static str) = ( + "bert-large/vocab", + "https://huggingface.co/bert-large-uncased/resolve/main/vocab.txt", + ); /// Shared under MIT license by the MDZ Digital Library team at the Bavarian State Library at . Modified with conversion to C-array format. pub const BERT_NER: (&'static str, &'static str) = ( "bert-ner/vocab", diff --git a/src/pipelines/masked_language.rs b/src/pipelines/masked_language.rs index 2a3c4319..f83699c7 100644 --- a/src/pipelines/masked_language.rs +++ b/src/pipelines/masked_language.rs @@ -15,15 +15,15 @@ //! a masked word can be specified in the `MaskedLanguageConfig` (`mask_token`). and allows //! multiple masked tokens per input sequence. //! -//! ```no_run -//!use rust_bert::bert::{BertConfigResources, BertModelResources, BertVocabResources}; -//!use rust_bert::pipelines::common::ModelType; -//!use rust_bert::pipelines::masked_language::{MaskedLanguageConfig, MaskedLanguageModel}; -//!use rust_bert::resources::RemoteResource; -//! fn main() -> anyhow::Result<()> { +//! ```no_run +//! use rust_bert::bert::{BertConfigResources, BertModelResources, BertVocabResources}; +//! use rust_bert::pipelines::common::ModelType; +//! use rust_bert::pipelines::masked_language::{MaskedLanguageConfig, MaskedLanguageModel}; +//! use rust_bert::resources::RemoteResource; //! +//! fn main() -> anyhow::Result<()> { //! use rust_bert::pipelines::common::ModelResource; -//! let config = MaskedLanguageConfig::new( +//! let config = MaskedLanguageConfig::new( //! ModelType::Bert, //! ModelResource::Torch(Box::new(RemoteResource::from_pretrained(BertModelResources::BERT))), //! 
RemoteResource::from_pretrained(BertConfigResources::BERT), diff --git a/src/pipelines/mod.rs b/src/pipelines/mod.rs index e886d69d..8bda080d 100644 --- a/src/pipelines/mod.rs +++ b/src/pipelines/mod.rs @@ -205,15 +205,15 @@ //! # use rust_bert::pipelines::zero_shot_classification::ZeroShotClassificationModel; //! # fn main() -> anyhow::Result<()> { //! let sequence_classification_model = ZeroShotClassificationModel::new(Default::default())?; -//! let input_sentence = "Who are you voting for in 2020?"; -//! let input_sequence_2 = "The prime minister has announced a stimulus package which was widely criticized by the opposition."; -//! let candidate_labels = &["politics", "public health", "economics", "sports"]; -//! let output = sequence_classification_model.predict_multilabel( -//! &[input_sentence, input_sequence_2], -//! candidate_labels, -//! None, -//! 128, -//! ); +//! let input_sentence = "Who are you voting for in 2020?"; +//! let input_sequence_2 = "The prime minister has announced a stimulus package which was widely criticized by the opposition."; +//! let candidate_labels = &["politics", "public health", "economics", "sports"]; +//! let output = sequence_classification_model.predict_multilabel( +//! &[input_sentence, input_sequence_2], +//! candidate_labels, +//! None, +//! 128, +//! ); //! # Ok(()) //! # } //! ``` diff --git a/src/pipelines/zero_shot_classification.rs b/src/pipelines/zero_shot_classification.rs index 642bc35f..8cdb16bc 100644 --- a/src/pipelines/zero_shot_classification.rs +++ b/src/pipelines/zero_shot_classification.rs @@ -25,15 +25,15 @@ //! # use rust_bert::pipelines::zero_shot_classification::ZeroShotClassificationModel; //! # fn main() -> anyhow::Result<()> { //! let sequence_classification_model = ZeroShotClassificationModel::new(Default::default())?; -//! let input_sentence = "Who are you voting for in 2020?"; -//! 
let input_sequence_2 = "The prime minister has announced a stimulus package which was widely criticized by the opposition."; -//! let candidate_labels = &["politics", "public health", "economics", "sports"]; -//! let output = sequence_classification_model.predict_multilabel( -//! &[input_sentence, input_sequence_2], -//! candidate_labels, -//! None, -//! 128, -//! ); +//! let input_sentence = "Who are you voting for in 2020?"; +//! let input_sequence_2 = "The prime minister has announced a stimulus package which was widely criticized by the opposition."; +//! let candidate_labels = &["politics", "public health", "economics", "sports"]; +//! let output = sequence_classification_model.predict_multilabel( +//! &[input_sentence, input_sequence_2], +//! candidate_labels, +//! None, +//! 128, +//! ); //! # Ok(()) //! # } //! ```