Simplify config switch between CLS token and sequence pooling
nathanpainchaud committed Jun 25, 2024
1 parent a592b03 commit 7beffaa
Showing 3 changed files with 7 additions and 22 deletions.
2 changes: 1 addition & 1 deletion didactic/callbacks/finetune.py
@@ -83,7 +83,7 @@ def freeze_before_training(self, pl_module: CardiacMultimodalRepresentationTask)
         # Add optional models/parameters if they're used in the model
         if pl_module.hparams.cls_token:
             modules_to_freeze.append(pl_module.cls_token)
-        if pl_module.hparams.sequence_pooling:
+        else:
             modules_to_freeze.append(pl_module.sequence_pooling)

         if layers_to_freeze == list(range(num_layers)):
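Whichever pooling module is active gets appended to `modules_to_freeze`, keeping its weights fixed while the rest of the model finetunes. As a rough standalone sketch (not the project's code; `nn.Linear` stands in for `CLSToken`/`SequencePooling`), freezing a module amounts to disabling gradient tracking on its parameters:

from torch import nn

# Stand-in for whichever pooling module the flag selected above (sketch only).
pooling = nn.Linear(8, 8)

# "Freezing" turns off gradient tracking, so optimizer steps leave it untouched.
for param in pooling.parameters():
    param.requires_grad_(False)

assert all(not p.requires_grad for p in pooling.parameters())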
1 change: 0 additions & 1 deletion
@@ -111,7 +111,6 @@ task:
   embed_dim: 8
   views: ${data.patients_kwargs.views}
   cls_token: True
-  sequence_pooling: False
   mtr_p: 0
   mt_by_attr: False

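With `sequence_pooling` dropped from the config, a single `cls_token` flag now selects the output head, so the previously invalid flag combinations can no longer be expressed (which is also why the `ValueError` check disappears from the task below). A small sketch using OmegaConf, the config library underlying Hydra-style configs like this one; the `task` keys here are a minimal stand-in mirroring the snippet above, not the full project config:

from omegaconf import OmegaConf

# Minimal stand-in for the task config (sketch only): one boolean picks the output head.
cfg = OmegaConf.create({"task": {"embed_dim": 8, "cls_token": True}})
cfg.task.cls_token = False  # False now implies sequence pooling; no second flag to keep in sync
print(OmegaConf.to_yaml(cfg))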
26 changes: 6 additions & 20 deletions didactic/tasks/cardiac_multimodal_representation.py
@@ -43,7 +43,6 @@ def __init__(
         tabular_tokenizer: Optional[TabularEmbedding | DictConfig] = None,
         time_series_tokenizer: Optional[TimeSeriesEmbedding | DictConfig] = None,
         cls_token: bool = True,
-        sequence_pooling: bool = False,
         mtr_p: float | Tuple[float, float] = 0,
         mt_by_attr: bool = False,
         *args,
@@ -68,10 +67,8 @@ def __init__(
             tabular_tokenizer: Tokenizer that can process tabular, i.e. patient records, data.
             time_series_tokenizer: Tokenizer that can process time-series data.
             cross_attention_module: Module to use for cross-attention between the tabular and time-series tokens.
-            cls_token: Whether to add a CLS token to use as the encoder's output token. Mutually exclusive parameter
-                with `sequence_pooling`.
-            sequence_pooling: Whether to perform sequence pooling on the encoder's output tokens. Mutually exclusive
-                parameter with `cls_token`.
+            cls_token: If `True`, adds a CLS token to use as the encoder's output token.
+                If `False`, the output token is obtained by sequence pooling over the encoder's output tokens.
             mtr_p: Probability to replace tokens by the learned MASK token, following the Mask Token Replacement (MTR)
                 data augmentation method.
                 If a float, the value will be used as masking rate during training (disabled during inference).
@@ -99,12 +96,6 @@ def __init__(
                 "model in fully-supervised mode, with the self-supervised loss as an auxiliary term."
             )

-        if cls_token == sequence_pooling:
-            raise ValueError(
-                "You should specify either `cls_token` or `sequence_pooling` as the method to reduce the "
-                "dimensionality of the encoder's output from a sequence of tokens to only one token."
-            )
-
         if not tabular_tokenizer and tabular_attrs:
             raise ValueError(
                 f"You have requested the following tabular attributes: "
@@ -273,7 +264,7 @@ def __init__(
         # Initialize parameters of method for reducing the dimensionality of the encoder's output to only one token
         if self.hparams.cls_token:
             self.cls_token = CLSToken(self.hparams.embed_dim)
-        if self.hparams.sequence_pooling:
+        else:
             self.sequence_pooling = SequencePooling(self.hparams.embed_dim)

         if self.hparams.mtr_p:
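`CLSToken` itself is not part of this diff. A minimal sketch of the common pattern (e.g. the `rtdl` library's `CLSToken` — an assumption, not the project's confirmed implementation): a learned embedding appended as the last token, which is why `encode` below reads the CLS output from position -1.

import torch
from torch import nn

# Sketch only: assumed to mirror an rtdl-style CLSToken; not the project's code.
class CLSToken(nn.Module):
    """Appends a learned [CLS] embedding as the last token of the sequence."""

    def __init__(self, embed_dim: int) -> None:
        super().__init__()
        self.weight = nn.Parameter(torch.empty(embed_dim))
        nn.init.normal_(self.weight, std=0.02)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # (N, S, E) -> (N, S+1, E), with the CLS token in last position
        cls = self.weight.expand(x.shape[0], 1, -1)
        return torch.cat([x, cls], dim=1)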
@@ -493,17 +484,12 @@ def encode(self, tokens: Tensor, avail_mask: Tensor, enable_augments: bool = False)
         # Forward pass through the transformer encoder
         out_tokens = self.encoder(tokens)

-        if self.hparams.sequence_pooling:
-            # Perform sequence pooling of the transformers' output tokens
-            out_features = self.sequence_pooling(out_tokens)  # (N, S, E) -> (N, E)
-        elif self.hparams.cls_token:
+        if self.hparams.cls_token:
             # Only keep the CLS token (i.e. the last token) from the tokens outputted by the encoder
             out_features = out_tokens[:, -1, :]  # (N, S, E) -> (N, E)
         else:
-            raise AssertionError(
-                "Either `cls_token` or `sequence_pooling` should have been enabled as the method to reduce the "
-                "dimensionality of the encoder's output from a sequence of tokens to only one token."
-            )
+            # Perform sequence pooling of the transformers' output tokens
+            out_features = self.sequence_pooling(out_tokens)  # (N, S, E) -> (N, E)

         return out_features
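With the `AssertionError` fallback gone, the two branches are exhaustive: `encode` always reduces `(N, S, E)` to a single `(N, E)` feature vector, so downstream heads are unaffected by which mode is configured. `SequencePooling` is likewise not shown in this diff; a plausible sketch, assuming the attention-weighted pooling of Compact Convolutional Transformers (Hassani et al., 2021) — the actual module may differ:

import torch
from torch import nn

# Sketch only: assumes CCT-style attention pooling; not the project's code.
class SequencePooling(nn.Module):
    """Attention-weighted pooling of a token sequence: (N, S, E) -> (N, E)."""

    def __init__(self, embed_dim: int) -> None:
        super().__init__()
        self.attention = nn.Linear(embed_dim, 1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        weights = self.attention(x).softmax(dim=1)  # (N, S, 1), sums to 1 over S
        return (weights * x).sum(dim=1)             # weighted sum over the S tokens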
