From f834cc9badded34c167064d0013013eb3bc6783c Mon Sep 17 00:00:00 2001 From: Dotan Nahum Date: Mon, 11 Sep 2023 18:33:12 +0300 Subject: [PATCH 1/2] Including DeBertaV2 for zero-shot --- src/pipelines/zero_shot_classification.rs | 28 +++++++++++++++++++++++ 1 file changed, 28 insertions(+) diff --git a/src/pipelines/zero_shot_classification.rs b/src/pipelines/zero_shot_classification.rs index ea0cc327..10ff9f25 100644 --- a/src/pipelines/zero_shot_classification.rs +++ b/src/pipelines/zero_shot_classification.rs @@ -102,6 +102,7 @@ use crate::albert::AlbertForSequenceClassification; use crate::bart::BartForSequenceClassification; use crate::bert::BertForSequenceClassification; use crate::deberta::DebertaForSequenceClassification; +use crate::deberta_v2::DebertaV2ForSequenceClassification; use crate::distilbert::DistilBertModelClassifier; use crate::longformer::LongformerForSequenceClassification; use crate::mobilebert::MobileBertForSequenceClassification; @@ -222,6 +223,8 @@ pub enum ZeroShotClassificationOption { Bart(BartForSequenceClassification), /// DeBERTa for Sequence Classification Deberta(DebertaForSequenceClassification), + /// DeBERTaV2 for Sequence Classification + DebertaV2(DebertaV2ForSequenceClassification), /// Bert for Sequence Classification Bert(BertForSequenceClassification), /// DistilBert for Sequence Classification @@ -288,6 +291,17 @@ impl ZeroShotClassificationOption { )) } } + ModelType::DebertaV2 => { + if let ConfigOption::DebertaV2(config) = model_config { + Ok(Self::DebertaV2( + DebertaV2ForSequenceClassification::new(var_store.root(), config)?, + )) + } else { + Err(RustBertError::InvalidConfigurationError( + "You can only supply a DebertaConfig for DeBERTaV2!".to_string(), + )) + } + } ModelType::Bert => { if let ConfigOption::Bert(config) = model_config { Ok(Self::Bert( @@ -413,6 +427,7 @@ impl ZeroShotClassificationOption { match *self { Self::Bart(_) => ModelType::Bart, Self::Deberta(_) => ModelType::Deberta, + Self::DebertaV2(_) => ModelType::DebertaV2, Self::Bert(_) => ModelType::Bert, Self::Roberta(_) => ModelType::Roberta, Self::XLMRoberta(_) => ModelType::Roberta, @@ -474,6 +489,19 @@ impl ZeroShotClassificationOption { .expect("Error in DeBERTa forward_t") .logits } + Self::DebertaV2(ref model) => { + model + .forward_t( + input_ids, + mask, + token_type_ids, + position_ids, + input_embeds, + train, + ) + .expect("Error in DeBERTaV2 forward_t") + .logits + } Self::DistilBert(ref model) => { model .forward_t(input_ids, mask, input_embeds, train) From 00b547c6a643f115db3fbfb4b67674088304c6af Mon Sep 17 00:00:00 2001 From: Guillaume Becquin Date: Sat, 21 Oct 2023 08:52:38 +0100 Subject: [PATCH 2/2] Ignore Clippy warning --- src/pipelines/zero_shot_classification.rs | 1 + 1 file changed, 1 insertion(+) diff --git a/src/pipelines/zero_shot_classification.rs b/src/pipelines/zero_shot_classification.rs index 26c01959..aeddb74a 100644 --- a/src/pipelines/zero_shot_classification.rs +++ b/src/pipelines/zero_shot_classification.rs @@ -218,6 +218,7 @@ impl Default for ZeroShotClassificationConfig { /// The models are using a classification architecture that should be trained on Natural Language Inference. /// The models should output a Tensor of size > 2 in the label dimension, with the first logit corresponding /// to contradiction and the last logit corresponding to entailment. +#[allow(clippy::large_enum_variant)] pub enum ZeroShotClassificationOption { /// Bart for Sequence Classification Bart(BartForSequenceClassification),