Skip to content

Commit

Permalink
make ZeroShotClassification support the roberta-large-mnli
Browse files Browse the repository at this point in the history
I don't know if it's correct, so a review would be appreciated
  • Loading branch information
Charles Samuels committed Sep 21, 2023
1 parent 0a1c3df commit 49b7a5d
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions src/pipelines/zero_shot_classification.rs
Original file line number Diff line number Diff line change
Expand Up @@ -322,7 +322,7 @@ impl ZeroShotClassificationOption {
}
}
ModelType::Roberta => {
if let ConfigOption::Bert(config) = model_config {
if let ConfigOption::Roberta(config) = model_config {
Ok(Self::Roberta(
RobertaForSequenceClassification::new(var_store.root(), config)?,
))
Expand Down Expand Up @@ -491,7 +491,7 @@ impl ZeroShotClassificationOption {
.forward_t(
input_ids,
mask,
token_type_ids,
None,
position_ids,
input_embeds,
train,
Expand Down

0 comments on commit 49b7a5d

Please sign in to comment.