
Commit

Update finetuning callback to match new positional encoding/sequential pooling names
nathanpainchaud committed Oct 24, 2023
1 parent b25147d commit 9d4fc7b
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions didactic/callbacks/finetune.py
@@ -71,7 +71,7 @@ def freeze_before_training(self, pl_module: CardiacMultimodalRepresentationTask)
# If encoder's first layer is frozen, then it is also necessary to freeze everything upstream of
# the encoder (e.g. CLS token, tokenizers, positional embedding, etc.) to make sure the encoder's
# inputs remain the same
- params_to_freeze.append(pl_module.positional_embedding)
+ params_to_freeze.append(pl_module.positional_encoding)

# Check if tokenizers are models before marking them to be frozen, so that we'll not try to freeze
# other possible types of tokenizers that are not `nn.Module`s (e.g. functions)
@@ -84,7 +84,7 @@ def freeze_before_training(self, pl_module: CardiacMultimodalRepresentationTask)
if pl_module.hparams.latent_token:
modules_to_freeze.append(pl_module.latent_token)
if pl_module.hparams.sequential_pooling:
- modules_to_freeze.append(pl_module.attention_pool)
+ modules_to_freeze.append(pl_module.sequential_pooling)

if layers_to_freeze == list(range(num_layers)):
# If all layers of the encoder are frozen
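For context, the patched `freeze_before_training(self, pl_module)` method follows the API of a PyTorch Lightning `BaseFinetuning`-style callback: it freezes the encoder and everything upstream of it (positional encoding, tokenizers, pooling) so the encoder's inputs stay fixed during finetuning. Below is a minimal sketch of that pattern, not the repository's implementation; attribute names such as `encoder` and `tokenizers` are illustrative assumptions, while `positional_encoding` and `sequential_pooling` come from the diff above.

# Sketch of a BaseFinetuning callback in the spirit of the patched code.
# `encoder` and `tokenizers` are assumed attribute names, not the actual
# attributes of CardiacMultimodalRepresentationTask.
from torch import nn
from pytorch_lightning.callbacks import BaseFinetuning


class FreezeEncoderSketch(BaseFinetuning):
    """Illustrative callback; not the actual didactic implementation."""

    def freeze_before_training(self, pl_module):
        # Freeze the encoder and the modules feeding it so its inputs stay fixed
        modules_to_freeze = [pl_module.encoder]  # assumed attribute

        if pl_module.hparams.sequential_pooling:
            modules_to_freeze.append(pl_module.sequential_pooling)

        # Only nn.Module tokenizers can be frozen; function-based tokenizers are skipped
        for tokenizer in getattr(pl_module, "tokenizers", []):  # assumed attribute
            if isinstance(tokenizer, nn.Module):
                modules_to_freeze.append(tokenizer)

        self.freeze(modules_to_freeze, train_bn=False)

        # Learned parameters that are not modules (e.g. the positional encoding)
        # are frozen by disabling their gradients directly
        pl_module.positional_encoding.requires_grad_(False)

    def finetune_function(self, pl_module, current_epoch, optimizer):
        # No gradual unfreezing in this sketch (signature for Lightning >= 2.0)
        pass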
