Skip to content

Commit

Permalink
Merge branch 'main' into tch-update
Browse files Browse the repository at this point in the history
  • Loading branch information
guillaume-be authored Aug 19, 2024
2 parents be58b94 + 9707981 commit c003088
Show file tree
Hide file tree
Showing 7 changed files with 59 additions and 44 deletions.
6 changes: 3 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -177,8 +177,8 @@ installation instructions/support.

Most architectures (including encoders, decoders and encoder-decoders) are
supported. The library aims at keeping compatibility with models exported using
the [optimum](https://github.com/huggingface/optimum) library. A detailed guide
on how to export a Transformer model to ONNX using optimum is available at
the [Optimum](https://github.com/huggingface/optimum) library. A detailed guide
on how to export a Transformer model to ONNX using Optimum is available at
https://huggingface.co/docs/optimum/main/en/exporters/onnx/usage_guides/export_a_model
The resources used to create ONNX models are similar to those based on PyTorch,
replacing the PyTorch model file with the ONNX model. Since ONNX models are less flexible
Expand All @@ -197,7 +197,7 @@ Note that the computational efficiency will drop when the `decoder with past`
file is optional but not provided, since the model will not use cached past keys
and values for the attention mechanism, leading to a high number of redundant
computations. The Optimum library offers export options to ensure such a
`decoder with past` model file is created. he base encoder and decoder model
`decoder with past` model file is created. The base encoder and decoder model
architectures are available (and exposed for convenience) in the `encoder` and
`decoder` modules, respectively.

Expand Down
14 changes: 7 additions & 7 deletions examples/async-sentiment.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
use std::{
sync::mpsc,
thread::{self, JoinHandle},
};
use std::sync::mpsc;

use anyhow::Result;
use rust_bert::pipelines::sentiment::{Sentiment, SentimentConfig, SentimentModel};
use tokio::{sync::oneshot, task};
use tokio::{
sync::oneshot,
task::{self, JoinHandle},
};

#[tokio::main]
async fn main() -> Result<()> {
Expand Down Expand Up @@ -36,7 +36,7 @@ impl SentimentClassifier {
/// to interact with it
pub fn spawn() -> (JoinHandle<Result<()>>, SentimentClassifier) {
let (sender, receiver) = mpsc::sync_channel(100);
let handle = thread::spawn(move || Self::runner(receiver));
let handle = task::spawn_blocking(move || Self::runner(receiver));
(handle, SentimentClassifier { sender })
}

Expand All @@ -57,7 +57,7 @@ impl SentimentClassifier {
/// Make the runner predict a sample and return the result
pub async fn predict(&self, texts: Vec<String>) -> Result<Vec<Sentiment>> {
let (sender, receiver) = oneshot::channel();
task::block_in_place(|| self.sender.send((texts, sender)))?;
self.sender.send((texts, sender))?;
Ok(receiver.await?)
}
}
18 changes: 9 additions & 9 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -360,15 +360,15 @@
//! # use rust_bert::pipelines::zero_shot_classification::ZeroShotClassificationModel;
//! # fn main() -> anyhow::Result<()> {
//! let sequence_classification_model = ZeroShotClassificationModel::new(Default::default())?;
//! let input_sentence = "Who are you voting for in 2020?";
//! let input_sequence_2 = "The prime minister has announced a stimulus package which was widely criticized by the opposition.";
//! let candidate_labels = &["politics", "public health", "economics", "sports"];
//! let output = sequence_classification_model.predict_multilabel(
//! &[input_sentence, input_sequence_2],
//! candidate_labels,
//! None,
//! 128,
//! );
//! let input_sentence = "Who are you voting for in 2020?";
//! let input_sequence_2 = "The prime minister has announced a stimulus package which was widely criticized by the opposition.";
//! let candidate_labels = &["politics", "public health", "economics", "sports"];
//! let output = sequence_classification_model.predict_multilabel(
//! &[input_sentence, input_sequence_2],
//! candidate_labels,
//! None,
//! 128,
//! );
//! # Ok(())
//! # }
//! ```
Expand Down
15 changes: 15 additions & 0 deletions src/models/bert/bert_model.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,11 @@ impl BertModelResources {
"bert/model",
"https://huggingface.co/bert-base-uncased/resolve/main/rust_model.ot",
);
/// Shared under Apache 2.0 license by the Google team at <https://github.com/google-research/bert>. Modified with conversion to C-array format.
pub const BERT_LARGE: (&'static str, &'static str) = (
"bert-large/model",
"https://huggingface.co/bert-large-uncased/resolve/main/rust_model.ot",
);
/// Shared under MIT license by the MDZ Digital Library team at the Bavarian State Library at <https://github.com/dbmdz/berts>. Modified with conversion to C-array format.
pub const BERT_NER: (&'static str, &'static str) = (
"bert-ner/model",
Expand Down Expand Up @@ -75,6 +80,11 @@ impl BertConfigResources {
"bert/config",
"https://huggingface.co/bert-base-uncased/resolve/main/config.json",
);
/// Shared under Apache 2.0 license by the Google team at <https://github.com/google-research/bert>. Modified with conversion to C-array format.
pub const BERT_LARGE: (&'static str, &'static str) = (
"bert-large/config",
"https://huggingface.co/bert-large-uncased/resolve/main/config.json",
);
/// Shared under MIT license by the MDZ Digital Library team at the Bavarian State Library at <https://github.com/dbmdz/berts>. Modified with conversion to C-array format.
pub const BERT_NER: (&'static str, &'static str) = (
"bert-ner/config",
Expand Down Expand Up @@ -108,6 +118,11 @@ impl BertVocabResources {
"bert/vocab",
"https://huggingface.co/bert-base-uncased/resolve/main/vocab.txt",
);
/// Shared under Apache 2.0 license by the Google team at <https://github.com/google-research/bert>. Modified with conversion to C-array format.
pub const BERT_LARGE: (&'static str, &'static str) = (
"bert-large/vocab",
"https://huggingface.co/bert-large-uncased/resolve/main/vocab.txt",
);
/// Shared under MIT license by the MDZ Digital Library team at the Bavarian State Library at <https://github.com/dbmdz/berts>. Modified with conversion to C-array format.
pub const BERT_NER: (&'static str, &'static str) = (
"bert-ner/vocab",
Expand Down
14 changes: 7 additions & 7 deletions src/pipelines/masked_language.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,15 +15,15 @@
//! a masked word can be specified in the `MaskedLanguageConfig` (`mask_token`), and allows
//! multiple masked tokens per input sequence.
//!
//! ```no_run
//!use rust_bert::bert::{BertConfigResources, BertModelResources, BertVocabResources};
//!use rust_bert::pipelines::common::ModelType;
//!use rust_bert::pipelines::masked_language::{MaskedLanguageConfig, MaskedLanguageModel};
//!use rust_bert::resources::RemoteResource;
//! fn main() -> anyhow::Result<()> {
//! ```no_run
//! use rust_bert::bert::{BertConfigResources, BertModelResources, BertVocabResources};
//! use rust_bert::pipelines::common::ModelType;
//! use rust_bert::pipelines::masked_language::{MaskedLanguageConfig, MaskedLanguageModel};
//! use rust_bert::resources::RemoteResource;
//!
//! fn main() -> anyhow::Result<()> {
//! use rust_bert::pipelines::common::ModelResource;
//! let config = MaskedLanguageConfig::new(
//! let config = MaskedLanguageConfig::new(
//! ModelType::Bert,
//! ModelResource::Torch(Box::new(RemoteResource::from_pretrained(BertModelResources::BERT))),
//! RemoteResource::from_pretrained(BertConfigResources::BERT),
Expand Down
18 changes: 9 additions & 9 deletions src/pipelines/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -205,15 +205,15 @@
//! # use rust_bert::pipelines::zero_shot_classification::ZeroShotClassificationModel;
//! # fn main() -> anyhow::Result<()> {
//! let sequence_classification_model = ZeroShotClassificationModel::new(Default::default())?;
//! let input_sentence = "Who are you voting for in 2020?";
//! let input_sequence_2 = "The prime minister has announced a stimulus package which was widely criticized by the opposition.";
//! let candidate_labels = &["politics", "public health", "economics", "sports"];
//! let output = sequence_classification_model.predict_multilabel(
//! &[input_sentence, input_sequence_2],
//! candidate_labels,
//! None,
//! 128,
//! );
//! let input_sentence = "Who are you voting for in 2020?";
//! let input_sequence_2 = "The prime minister has announced a stimulus package which was widely criticized by the opposition.";
//! let candidate_labels = &["politics", "public health", "economics", "sports"];
//! let output = sequence_classification_model.predict_multilabel(
//! &[input_sentence, input_sequence_2],
//! candidate_labels,
//! None,
//! 128,
//! );
//! # Ok(())
//! # }
//! ```
Expand Down
18 changes: 9 additions & 9 deletions src/pipelines/zero_shot_classification.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,15 +25,15 @@
//! # use rust_bert::pipelines::zero_shot_classification::ZeroShotClassificationModel;
//! # fn main() -> anyhow::Result<()> {
//! let sequence_classification_model = ZeroShotClassificationModel::new(Default::default())?;
//! let input_sentence = "Who are you voting for in 2020?";
//! let input_sequence_2 = "The prime minister has announced a stimulus package which was widely criticized by the opposition.";
//! let candidate_labels = &["politics", "public health", "economics", "sports"];
//! let output = sequence_classification_model.predict_multilabel(
//! &[input_sentence, input_sequence_2],
//! candidate_labels,
//! None,
//! 128,
//! );
//! let input_sentence = "Who are you voting for in 2020?";
//! let input_sequence_2 = "The prime minister has announced a stimulus package which was widely criticized by the opposition.";
//! let candidate_labels = &["politics", "public health", "economics", "sports"];
//! let output = sequence_classification_model.predict_multilabel(
//! &[input_sentence, input_sequence_2],
//! candidate_labels,
//! None,
//! 128,
//! );
//! # Ok(())
//! # }
//! ```
Expand Down

0 comments on commit c003088

Please sign in to comment.