Refactor encoder model config into its own config folder
This makes the model config more intuitive and flexible: the encoder becomes its own Hydra config group (task/model/encoder) that experiments can select and override independently of the rest of the task's submodels.
nathanpainchaud committed Nov 1, 2023
1 parent 3545183 commit 47f2902
Showing 8 changed files with 52 additions and 54 deletions.
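
For orientation, the resulting layout of the model configs (a summary reconstructed from the file list below, not part of the diff itself):

didactic/config/task/model/
  encoder/
    cardinal-ft-transformer.yaml   # replaces didactic/config/task/model/cardinal-ft-transformer.yaml
    torch-transformer-encoder.yaml # replaces didactic/config/task/model/torch-transformer-encoder.yaml
    xtab-ft-transformer.yaml       # replaces didactic/config/task/model/xtab-ft-transformer.yaml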
5 changes: 3 additions & 2 deletions didactic/config/experiment/cardinal/multimodal-xformer.yaml
@@ -2,7 +2,8 @@

 defaults:
   - /task/img_tokenizer/model: linear-embedding
-  - override /task/model: cardinal-ft-transformer
+  - /task/model/encoder: ???
+  - override /task/model: null # Set this to null because we specify multiple submodels instead of a singleton model
   - override /task/optim: adamw
   - override /data: cardinal

@@ -118,7 +119,7 @@ callbacks:
 _target_: pytorch_lightning.callbacks.LearningRateFinder


-experiment_dirname: encoder=${hydra:runtime.choices.task/model}/img_tokenizer=${hydra:runtime.choices.task/img_tokenizer/model}/n_clinical_attrs=${n_clinical_attrs},n_img_attrs=${n_img_attrs}/contrastive=${oc.select:task.contrastive_loss_weight,0}/embed_dim=${task.embed_dim},depth=${task.model.encoder.num_layers},nhead=${task.model.encoder.encoder_layer.nhead},dropout=${task.model.encoder.encoder_layer.dropout}/mtr_p=${task.mtr_p},mt_by_attr=${task.mt_by_attr}
+experiment_dirname: encoder=${hydra:runtime.choices.task/model/encoder}/img_tokenizer=${hydra:runtime.choices.task/img_tokenizer/model}/n_clinical_attrs=${n_clinical_attrs},n_img_attrs=${n_img_attrs}/contrastive=${oc.select:task.contrastive_loss_weight,0}/embed_dim=${task.embed_dim},depth=${task.model.encoder.num_layers},nhead=${task.model.encoder.encoder_layer.nhead},dropout=${task.model.encoder.encoder_layer.dropout}/mtr_p=${task.mtr_p},mt_by_attr=${task.mt_by_attr}
 hydra:
   job:
     config:
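
To make the effect of these defaults concrete, here is a sketch (an assumption based on the interpolations above, not part of the diff) of the composed config once an encoder is selected, e.g. task/model/encoder=cardinal-ft-transformer. Nulling /task/model leaves the encoder as a named submodel under task.model, which is why the interpolations address ${task.model.encoder.*}:

task:
  model:
    encoder:
      _target_: torch.nn.TransformerEncoder
      num_layers: 6
      # ... remaining fields composed from task/model/encoder/cardinal-ft-transformer.yaml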
4 changes: 2 additions & 2 deletions didactic/config/experiment/cardinal/xtab-finetune.yaml
@@ -2,7 +2,7 @@

 defaults:
   - cardinal/multimodal-xformer
-  - override /task/model: xtab-ft-transformer
+  - override /task/model/encoder: xtab-ft-transformer

 trainer:
   max_steps: 2000
@@ -34,7 +34,7 @@ ckpt: ??? # Make it mandatory to provide a checkpoint
 weights_only: True # Only load the weights and ignore the hyperparameters
 strict: False # Only load weights where they match the defined network, to allow some changes (e.g. heads, etc.)

-experiment_dirname: encoder=${hydra:runtime.choices.task/model}/img_tokenizer=${hydra:runtime.choices.task/img_tokenizer/model}/n_clinical_attrs=${n_clinical_attrs},n_img_attrs=${n_img_attrs}/contrastive=${oc.select:task.contrastive_loss_weight,0}/embed_dim=${task.embed_dim},depth=${task.model.encoder.n_blocks},nhead=${task.model.encoder.attention_n_heads},dropout=${task.model.encoder.attention_dropout},${task.model.encoder.ffn_dropout},${task.model.encoder.residual_dropout}/mtr_p=${task.mtr_p},mt_by_attr=${task.mt_by_attr}
+experiment_dirname: encoder=${hydra:runtime.choices.task/model/encoder}/img_tokenizer=${hydra:runtime.choices.task/img_tokenizer/model}/n_clinical_attrs=${n_clinical_attrs},n_img_attrs=${n_img_attrs}/contrastive=${oc.select:task.contrastive_loss_weight,0}/embed_dim=${task.embed_dim},depth=${task.model.encoder.n_blocks},nhead=${task.model.encoder.attention_n_heads},dropout=${task.model.encoder.attention_dropout},${task.model.encoder.ffn_dropout},${task.model.encoder.residual_dropout}/mtr_p=${task.mtr_p},mt_by_attr=${task.mt_by_attr}
 hydra:
   run:
     dir: ${oc.env:CARDIAC_MULTIMODAL_REPR_PATH}/xtab-finetune/${experiment_dirname}/targets=${oc.dict.keys:task.predict_losses}/${hydra.job.override_dirname}
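
Because the encoder is now an ordinary config group, swapping it no longer requires a dedicated model config. A hypothetical experiment config (not part of this commit) that composes the same base but uses the cardinal encoder instead would only change the selection:

defaults:
  - cardinal/multimodal-xformer
  - override /task/model/encoder: cardinal-ft-transformer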
10 changes: 0 additions & 10 deletions didactic/config/task/model/cardinal-ft-transformer.yaml

This file was deleted.

9 changes: 9 additions & 0 deletions didactic/config/task/model/encoder/cardinal-ft-transformer.yaml
@@ -0,0 +1,9 @@
defaults:
  - torch-transformer-encoder

num_layers: 6
encoder_layer:
  d_model: ${task.embed_dim}
  nhead: 2
  dim_feedforward: ${op.mul:1.5,${task.model.encoder.encoder_layer.d_model},int}
  dropout: 0.1
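
Composed with its torch-transformer-encoder base (below), this file resolves to roughly the following. This is a sketch assuming task.embed_dim=8 and assuming the project-specific op.mul resolver multiplies its arguments and casts the result to the named type:

_target_: torch.nn.TransformerEncoder
num_layers: 6
norm:
  _target_: torch.nn.LayerNorm
  normalized_shape: 8
encoder_layer:
  _target_: torch.nn.TransformerEncoderLayer
  d_model: 8
  nhead: 2
  dim_feedforward: 12 # int(1.5 * 8)
  dropout: 0.1
  activation: relu
  batch_first: True
  norm_first: True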
16 changes: 16 additions & 0 deletions didactic/config/task/model/encoder/torch-transformer-encoder.yaml
@@ -0,0 +1,16 @@
_target_: torch.nn.TransformerEncoder
num_layers: 1

norm:
  _target_: torch.nn.LayerNorm
  normalized_shape: ${task.model.encoder.encoder_layer.d_model}

encoder_layer:
  _target_: torch.nn.TransformerEncoderLayer
  d_model: ???
  nhead: 1
  dim_feedforward: 2048
  dropout: 0.1
  activation: relu
  batch_first: True
  norm_first: True
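
These three top-level nodes mirror the signature of torch.nn.TransformerEncoder. A sketch of the instantiated object, assuming the task recursively instantiates this config (e.g. via hydra.utils.instantiate):

# TransformerEncoder(
#     encoder_layer=TransformerEncoderLayer(d_model=..., nhead=1, ...),
#     num_layers=1,
#     norm=LayerNorm(normalized_shape=...),
# )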
22 changes: 22 additions & 0 deletions didactic/config/task/model/encoder/xtab-ft-transformer.yaml
@@ -0,0 +1,22 @@
_target_: autogluon.multimodal.models.ft_transformer.FT_Transformer

d_token: 192
n_blocks: 3
attention_n_heads: 8
attention_dropout: 0.2
attention_initialization: kaiming
attention_normalization: layer_norm
ffn_d_hidden: ${task.model.encoder.d_token}
ffn_dropout: 0.1
ffn_activation: reglu
ffn_normalization: layer_norm
residual_dropout: 0.1
prenormalization: True
first_prenormalization: False
last_layer_query_idx: null
n_tokens: null # Only used when compressing the input sequence (`kv_compression_ratio is not None`)
kv_compression_ratio: null # Only used when compressing the input sequence (`kv_compression_ratio is not None`)
kv_compression_sharing: null # Only used when compressing the input sequence (`kv_compression_ratio is not None`)
head_activation: False # Only used when using a projection head (`projection=True`)
head_normalization: null # Only used when using a projection head (`projection=True`)
d_out: null # Only used when using a projection head (`projection=True`)
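
Since these are now plain fields of the task/model/encoder group, an experiment can tweak a single hyperparameter without redefining the whole model. A hypothetical override (values invented for illustration, not part of this commit):

defaults:
  - cardinal/multimodal-xformer
  - override /task/model/encoder: xtab-ft-transformer

task:
  model:
    encoder:
      attention_dropout: 0.3 # hypothetical value, replacing the 0.2 default above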
17 changes: 0 additions & 17 deletions didactic/config/task/model/torch-transformer-encoder.yaml

This file was deleted.

23 changes: 0 additions & 23 deletions didactic/config/task/model/xtab-ft-transformer.yaml

This file was deleted.
