Skip to content

Commit

Permalink
Expose the t5 config fields + allow t5-large. (huggingface#1987)
Browse files Browse the repository at this point in the history
  • Loading branch information
LaurentMazare committed Apr 1, 2024
1 parent ea0d8d3 commit be9c200
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 16 deletions.
2 changes: 2 additions & 0 deletions candle-examples/examples/t5/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ const DTYPE: DType = DType::F32;
 enum Which {
     T5Base,
     T5Small,
+    T5Large,
     T5_3B,
     Mt5Base,
     Mt5Small,
Expand Down Expand Up @@ -108,6 +109,7 @@ impl T5ModelBuilder {
 let (default_model, default_revision) = match args.which {
     Which::T5Base => ("t5-base", "main"),
     Which::T5Small => ("t5-small", "refs/pr/15"),
+    Which::T5Large => ("t5-large", "main"),
     Which::T5_3B => ("t5-3b", "main"),
     Which::Mt5Base => ("google/mt5-base", "refs/pr/5"),
     Which::Mt5Small => ("google/mt5-small", "refs/pr/6"),
Expand Down
32 changes: 16 additions & 16 deletions candle-transformers/src/models/t5.rs
Original file line number Diff line number Diff line change
Expand Up @@ -70,26 +70,26 @@ where

 #[derive(Debug, Clone, PartialEq, Deserialize)]
 pub struct Config {
-    vocab_size: usize,
-    d_model: usize,
-    d_kv: usize,
-    d_ff: usize,
-    num_layers: usize,
-    num_decoder_layers: Option<usize>,
-    num_heads: usize,
-    relative_attention_num_buckets: usize,
+    pub vocab_size: usize,
+    pub d_model: usize,
+    pub d_kv: usize,
+    pub d_ff: usize,
+    pub num_layers: usize,
+    pub num_decoder_layers: Option<usize>,
+    pub num_heads: usize,
+    pub relative_attention_num_buckets: usize,
     #[serde(default = "default_relative_attention_max_distance")]
-    relative_attention_max_distance: usize,
-    dropout_rate: f64,
-    layer_norm_epsilon: f64,
-    initializer_factor: f64,
+    pub relative_attention_max_distance: usize,
+    pub dropout_rate: f64,
+    pub layer_norm_epsilon: f64,
+    pub initializer_factor: f64,
     #[serde(default, deserialize_with = "deserialize_feed_forward_proj_activation")]
-    feed_forward_proj: ActivationWithOptionalGating,
+    pub feed_forward_proj: ActivationWithOptionalGating,
     #[serde(default = "default_tie_word_embeddings")]
-    tie_word_embeddings: bool,
+    pub tie_word_embeddings: bool,
     #[serde(default = "default_is_decoder")]
-    is_decoder: bool,
-    is_encoder_decoder: bool,
+    pub is_decoder: bool,
+    pub is_encoder_decoder: bool,
     #[serde(default = "default_use_cache")]
     pub use_cache: bool,
     pub pad_token_id: usize,
Expand Down

0 comments on commit be9c200

Please sign in to comment.