diff --git a/didactic/tasks/cardiac_multimodal_representation.py b/didactic/tasks/cardiac_multimodal_representation.py
index 31468ca..b67285e 100644
--- a/didactic/tasks/cardiac_multimodal_representation.py
+++ b/didactic/tasks/cardiac_multimodal_representation.py
@@ -421,17 +421,17 @@ def tokenize(
         return tokens, notna_mask
 
     @auto_move_data
-    def encode(self, tokens: Tensor, avail_mask: Tensor, disable_augments: bool = False) -> Tensor:
+    def encode(self, tokens: Tensor, avail_mask: Tensor, enable_augments: bool = False) -> Tensor:
         """Embeds input sequences using the encoder model, optionally selecting/pooling output tokens for the embedding.
 
         Args:
             tokens: (N, S, E), Tokens to feed to the encoder.
             avail_mask: (N, S), Boolean mask indicating available (i.e. non-missing) tokens. Missing tokens can thus
                 be treated distinctly from others (e.g. replaced w/ a specific mask).
-            disable_augments: Whether to perform augments on the tokens (e.g. masking). Normally augments will
-                be performed differently (if not outright disabled) when not in training, but this parameter allows to
-                disable them even during training. This is useful to compute "uncorrupted" views of the data for
-                contrastive learning.
+            enable_augments: Whether to perform augments on the tokens (e.g. masking) to obtain a "corrupted" view for
+                contrastive learning. Augments are already configured differently for training/testing (to avoid
+                stochastic test-time predictions), so this parameter is simply useful to easily toggle augments on/off
+                to obtain contrasting views.
 
         Returns:
             (N, E), Embeddings of the input sequences.
         """
@@ -445,7 +445,7 @@ def encode(self, tokens: Tensor, avail_mask: Tensor, disable_augments: bool = Fa
             tokens = mask_tokens(tokens, mask_token, ~avail_mask)
 
         mtr_p = self.train_mtr_p if self.training else self.test_mtr_p
-        if mtr_p and disable_augments:
+        if mtr_p and enable_augments:
             # Mask Token Replacement (MTR) data augmentation
             # Replace random non-missing tokens with the mask token to perturb the input
             tokens, _ = random_masking(tokens, mask_token, mtr_p)
@@ -603,8 +603,9 @@ def _prediction_shared_step(
     def _contrastive_shared_step(
         self, batch: PatientData, batch_idx: int, in_tokens: Tensor, avail_mask: Tensor, out_features: Tensor
     ) -> Dict[str, Tensor]:
-        corrupted_out_features = out_features  # Features from a view corrupted by augmentations
-        anchor_out_features = self.encode(in_tokens, avail_mask, disable_augments=True)
+        # Extract features from the original view + from a view corrupted by augmentations
+        anchor_out_features = out_features
+        corrupted_out_features = self.encode(in_tokens, avail_mask, enable_augments=True)
 
         # Compute the contrastive loss/metrics
         metrics = {
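
For context, here is a minimal, self-contained sketch of the two-view pattern the renamed flag supports: the anchor embedding comes from the unaugmented tokens, while the corrupted embedding is produced by calling `encode` with `enable_augments=True`. `ToyEncoder`, the simplified `random_masking` (the repository's version, per the diff, also returns a second value), and the cosine objective are illustrative stand-ins, not the project's actual code.

```python
import torch
import torch.nn.functional as F
from torch import Tensor, nn


def random_masking(tokens: Tensor, mask_token: Tensor, p: float) -> Tensor:
    """Independently replaces each token with `mask_token` with probability `p`."""
    mask = torch.rand(tokens.shape[:-1], device=tokens.device) < p
    return torch.where(mask.unsqueeze(-1), mask_token, tokens)


class ToyEncoder(nn.Module):
    """Stand-in encoder: a learned mask token, a linear projection, mean-pooling."""

    def __init__(self, embed_dim: int = 8):
        super().__init__()
        self.mask_token = nn.Parameter(torch.zeros(embed_dim))
        self.proj = nn.Linear(embed_dim, embed_dim)

    def encode(self, tokens: Tensor, enable_augments: bool = False) -> Tensor:
        # Corruption is opt-in, mirroring the renamed `enable_augments` flag
        if enable_augments:
            tokens = random_masking(tokens, self.mask_token, p=0.15)
        return self.proj(tokens).mean(dim=1)  # (N, S, E) -> (N, E) by mean-pooling


def contrastive_step(model: ToyEncoder, tokens: Tensor) -> Tensor:
    anchor = model.encode(tokens, enable_augments=False)  # uncorrupted view
    corrupted = model.encode(tokens, enable_augments=True)  # corrupted view
    # Pull the two views of each sequence together (placeholder objective)
    return 1 - F.cosine_similarity(anchor, corrupted).mean()


loss = contrastive_step(ToyEncoder(), torch.randn(4, 10, 8))
loss.backward()
```

Making corruption opt-in also appears to resolve an inversion visible in the old hunks: `if mtr_p and disable_augments:` applied MTR precisely when augments were supposed to be disabled, so the `disable_augments=True` call in `_contrastive_shared_step` would have corrupted the intended anchor view.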