Skip to content

Commit

Permalink
Fix name of token augmentations flag to properly explain its behavior
Browse files Browse the repository at this point in the history
  • Loading branch information
nathanpainchaud committed Nov 20, 2023
1 parent 4c802ba commit 3d23c8d
Showing 1 changed file with 9 additions and 8 deletions.
17 changes: 9 additions & 8 deletions didactic/tasks/cardiac_multimodal_representation.py
Original file line number Diff line number Diff line change
Expand Up @@ -421,17 +421,17 @@ def tokenize(
return tokens, notna_mask

@auto_move_data
def encode(self, tokens: Tensor, avail_mask: Tensor, disable_augments: bool = False) -> Tensor:
def encode(self, tokens: Tensor, avail_mask: Tensor, enable_augments: bool = False) -> Tensor:
"""Embeds input sequences using the encoder model, optionally selecting/pooling output tokens for the embedding.
Args:
tokens: (N, S, E), Tokens to feed to the encoder.
avail_mask: (N, S), Boolean mask indicating available (i.e. non-missing) tokens. Missing tokens can thus be
treated distinctly from others (e.g. replaced w/ a specific mask).
disable_augments: Whether to perform augments on the tokens (e.g. masking). Normally augments will
be performed differently (if not outright disabled) when not in training, but this parameter allows to
disable them even during training. This is useful to compute "uncorrupted" views of the data for
contrastive learning.
enable_augments: Whether to perform augments on the tokens (e.g. masking) to obtain a "corrupted" view for
contrastive learning. Augments are already configured differently for training/testing (to avoid
stochastic test-time predictions), so this parameter is simply useful to easily toggle augments on/off
to obtain contrasting views.
Returns: (N, E), Embeddings of the input sequences.
"""
Expand All @@ -445,7 +445,7 @@ def encode(self, tokens: Tensor, avail_mask: Tensor, disable_augments: bool = Fa
tokens = mask_tokens(tokens, mask_token, ~avail_mask)

mtr_p = self.train_mtr_p if self.training else self.test_mtr_p
if mtr_p and disable_augments:
if mtr_p and enable_augments:
# Mask Token Replacement (MTR) data augmentation
# Replace random non-missing tokens with the mask token to perturb the input
tokens, _ = random_masking(tokens, mask_token, mtr_p)
Expand Down Expand Up @@ -603,8 +603,9 @@ def _prediction_shared_step(
def _contrastive_shared_step(
self, batch: PatientData, batch_idx: int, in_tokens: Tensor, avail_mask: Tensor, out_features: Tensor
) -> Dict[str, Tensor]:
corrupted_out_features = out_features # Features from a view corrupted by augmentations
anchor_out_features = self.encode(in_tokens, avail_mask, disable_augments=True)
# Extract features from the original view + from a view corrupted by augmentations
anchor_out_features = out_features
corrupted_out_features = self.encode(in_tokens, avail_mask, enable_augments=True)

# Compute the contrastive loss/metrics
metrics = {
Expand Down

0 comments on commit 3d23c8d

Please sign in to comment.