From 9eedc1b3d4d39dff77a3aa5d0ae882fd187d5a73 Mon Sep 17 00:00:00 2001 From: Nathan Painchaud Date: Wed, 1 Nov 2023 00:48:10 +0100 Subject: [PATCH] Remove `attrs_droput` data augmentation dependent on PyTorch transformer API Unify missing data handling using Mask Token Replacement (MTR), w/ missing tokens being replaced by mask token --- .../cardinal/multimodal-xformer-finetune.yaml | 1 - .../cardinal/multimodal-xformer-head.yaml | 1 - .../cardinal/multimodal-xformer-scratch.yaml | 1 - .../cardinal/multimodal-xformer.yaml | 4 +- .../experiment/cardinal/xtab-finetune.yaml | 2 +- .../cardiac_multimodal_representation.py | 84 +++++-------------- 6 files changed, 24 insertions(+), 69 deletions(-) diff --git a/didactic/config/experiment/cardinal/multimodal-xformer-finetune.yaml b/didactic/config/experiment/cardinal/multimodal-xformer-finetune.yaml index 5fd51e45..2595ed0a 100644 --- a/didactic/config/experiment/cardinal/multimodal-xformer-finetune.yaml +++ b/didactic/config/experiment/cardinal/multimodal-xformer-finetune.yaml @@ -15,7 +15,6 @@ task: contrastive_loss: _target_: vital.metrics.train.metric.NTXent contrastive_loss_weight: 0 - attrs_dropout: [ 0.1, 0 ] mtr_p: [ 0.3, 0 ] callbacks: diff --git a/didactic/config/experiment/cardinal/multimodal-xformer-head.yaml b/didactic/config/experiment/cardinal/multimodal-xformer-head.yaml index 54f8760b..3532b60b 100644 --- a/didactic/config/experiment/cardinal/multimodal-xformer-head.yaml +++ b/didactic/config/experiment/cardinal/multimodal-xformer-head.yaml @@ -12,7 +12,6 @@ excluded_clinical_attrs: ${oc.dict.keys:task.predict_losses} task: predict_losses: ??? - attrs_dropout: [0.1, 0] callbacks: transformer_encoder_freeze: diff --git a/didactic/config/experiment/cardinal/multimodal-xformer-scratch.yaml b/didactic/config/experiment/cardinal/multimodal-xformer-scratch.yaml index da3b8d45..434dbfee 100644 --- a/didactic/config/experiment/cardinal/multimodal-xformer-scratch.yaml +++ b/didactic/config/experiment/cardinal/multimodal-xformer-scratch.yaml @@ -13,7 +13,6 @@ task: contrastive_loss: _target_: vital.metrics.train.metric.NTXent contrastive_loss_weight: 0 - attrs_dropout: [ 0.1, 0 ] mtr_p: [ 0.3, 0 ] hydra: diff --git a/didactic/config/experiment/cardinal/multimodal-xformer.yaml b/didactic/config/experiment/cardinal/multimodal-xformer.yaml index 84a4c2df..9252fee6 100644 --- a/didactic/config/experiment/cardinal/multimodal-xformer.yaml +++ b/didactic/config/experiment/cardinal/multimodal-xformer.yaml @@ -104,7 +104,6 @@ task: sequential_pooling: False mtr_p: 0 mt_by_attr: False - attrs_dropout: 0 clinical_tokenizer: _target_: rtdl.FeatureTokenizer @@ -123,7 +122,7 @@ callbacks: _target_: pytorch_lightning.callbacks.LearningRateFinder -experiment_dirname: encoder=${hydra:runtime.choices.task/model}/img_tokenizer=${hydra:runtime.choices.task/img_tokenizer/model}/n_clinical_attrs=${n_clinical_attrs},n_img_attrs=${n_img_attrs}/contrastive=${oc.select:task.contrastive_loss_weight,0}/embed_dim=${task.embed_dim},depth=${task.model.encoder.num_layers},nhead=${task.model.encoder.encoder_layer.nhead},dropout=${task.model.encoder.encoder_layer.dropout}/mtr_p=${task.mtr_p},mt_by_attr=${task.mt_by_attr},attrs_dropout=${task.attrs_dropout} +experiment_dirname: encoder=${hydra:runtime.choices.task/model}/img_tokenizer=${hydra:runtime.choices.task/img_tokenizer/model}/n_clinical_attrs=${n_clinical_attrs},n_img_attrs=${n_img_attrs}/contrastive=${oc.select:task.contrastive_loss_weight,0}/embed_dim=${task.embed_dim},depth=${task.model.encoder.num_layers},nhead=${task.model.encoder.encoder_layer.nhead},dropout=${task.model.encoder.encoder_layer.dropout}/mtr_p=${task.mtr_p},mt_by_attr=${task.mt_by_attr} hydra: job: config: @@ -158,7 +157,6 @@ hydra: - task.constraint.clustering_model - task.mtr_p - task.mt_by_attr - - task.attrs_dropout - task.embed_dim - task/img_tokenizer/model diff --git a/didactic/config/experiment/cardinal/xtab-finetune.yaml b/didactic/config/experiment/cardinal/xtab-finetune.yaml index d9864a1a..89c61018 100644 --- a/didactic/config/experiment/cardinal/xtab-finetune.yaml +++ b/didactic/config/experiment/cardinal/xtab-finetune.yaml @@ -34,7 +34,7 @@ ckpt: ??? # Make it mandatory to provide a checkpoint weights_only: True # Only load the weights and ignore the hyperparameters strict: False # Only load weights where they match the defined network, to only some changes (e.g. heads, etc.) -experiment_dirname: encoder=${hydra:runtime.choices.task/model}/img_tokenizer=${hydra:runtime.choices.task/img_tokenizer/model}/n_clinical_attrs=${n_clinical_attrs},n_img_attrs=${n_img_attrs}/contrastive=${oc.select:task.contrastive_loss_weight,0}/embed_dim=${task.embed_dim},depth=${task.model.encoder.n_blocks},nhead=${task.model.encoder.attention_n_heads},dropout=${task.model.encoder.attention_dropout},${task.model.encoder.ffn_dropout},${task.model.encoder.residual_dropout}/mtr_p=${task.mtr_p},mt_by_attr=${task.mt_by_attr},attrs_dropout=${task.attrs_dropout} +experiment_dirname: encoder=${hydra:runtime.choices.task/model}/img_tokenizer=${hydra:runtime.choices.task/img_tokenizer/model}/n_clinical_attrs=${n_clinical_attrs},n_img_attrs=${n_img_attrs}/contrastive=${oc.select:task.contrastive_loss_weight,0}/embed_dim=${task.embed_dim},depth=${task.model.encoder.n_blocks},nhead=${task.model.encoder.attention_n_heads},dropout=${task.model.encoder.attention_dropout},${task.model.encoder.ffn_dropout},${task.model.encoder.residual_dropout}/mtr_p=${task.mtr_p},mt_by_attr=${task.mt_by_attr} hydra: run: dir: ${oc.env:CARDIAC_MULTIMODAL_REPR_PATH}/xtab-finetune/${experiment_dirname}/targets=${oc.dict.keys:task.predict_losses}/${hydra.job.override_dirname} diff --git a/didactic/tasks/cardiac_multimodal_representation.py b/didactic/tasks/cardiac_multimodal_representation.py index 81fbc81d..5a7d8127 100644 --- a/didactic/tasks/cardiac_multimodal_representation.py +++ b/didactic/tasks/cardiac_multimodal_representation.py @@ -7,13 +7,12 @@ import hydra import rtdl import torch -import torch.nn.functional as F from omegaconf import DictConfig from rtdl import FeatureTokenizer from torch import Tensor, nn from torch.nn import Parameter, ParameterDict, init from torchmetrics.functional import accuracy, mean_absolute_error -from vital.data.augmentation.base import random_masking +from vital.data.augmentation.base import mask_tokens, random_masking from vital.data.cardinal.config import CardinalTag, ClinicalAttribute, ImageAttribute from vital.data.cardinal.config import View as ViewEnum from vital.data.cardinal.datapipes import MISSING_CAT_ATTR, PatientData, filter_image_attributes @@ -45,7 +44,6 @@ def __init__( sequential_pooling: bool = False, mtr_p: float | Tuple[float, float] = 0, mt_by_attr: bool = False, - attrs_dropout: float | Tuple[float, float] = 0, *args, **kwargs, ): @@ -73,9 +71,6 @@ def __init__( If a tuple, specify a masking rate to use during training and inference, respectively. mt_by_attr: Whether to use one MASK token per attribute (`True`), or one universal MASK token for all attributes (`False`). - attrs_dropout: Probability of randomly masking tokens, effectively dropping them, to simulate missing data. - If a float, the value will be used as dropout rate during training (disabled during inference). - If a tuple, specify a dropout rate to use during training and inference, respectively. *args: Positional arguments to pass to the parent's constructor. **kwargs: Keyword arguments to pass to the parent's constructor. """ @@ -88,8 +83,6 @@ def __init__( # If dropout/masking are not single numbers, make sure they are tuples (and not another container type) if not isinstance(mtr_p, (int, float)): mtr_p = tuple(mtr_p) - if not isinstance(attrs_dropout, (int, float)): - attrs_dropout = tuple(attrs_dropout) if contrastive_loss is None and predict_losses is None: raise ValueError( @@ -149,12 +142,7 @@ def __init__( len(CLINICAL_CAT_ATTR_LABELS[cat_attr]) for cat_attr in self.clinical_cat_attrs ] - # Extract train/test dropout/masking probabilities from their configs - if isinstance(self.hparams.attrs_dropout, tuple): - self.train_attrs_dropout, self.test_attrs_dropout = self.hparams.attrs_dropout - else: - self.train_attrs_dropout = self.hparams.attrs_dropout - self.test_attrs_dropout = 0 + # Extract train/test masking probabilities from their configs if isinstance(self.hparams.mtr_p, tuple): self.train_mtr_p, self.test_mtr_p = self.hparams.mtr_p else: @@ -367,9 +355,10 @@ def tokenize( [clinical_attrs[attr].unsqueeze(1) for attr in self.clinical_cat_attrs] ) # (N, S_cat) # Use "sanitized" version of the inputs, where invalid values are replaced by null/default values, for the - # tokenization process. Since the embeddings of the missing tokens will be ignored later on using the attention - # mask anyway, it doesn't matter that the embeddings returned are not "accurate"; it only matters that the - # tokenization doesn't crash or returns NaNs + # tokenization process. This is done to avoid propagating NaNs to available/valid values. + # If the embeddings cannot be ignored later on (e.g. by using an attention mask during inference), they + # should be replaced w/ a more distinct value to indicate that they are missing (e.g. a specific token), + # instead of their current null/default values. # 1) Convert missing numerical attributes (NaNs) to numbers to avoid propagating NaNs # 2) Clip categorical labels to convert indicators of missing data (-1) into valid indices (0) clinical_attrs_tokens = self.clinical_tokenizer( @@ -400,14 +389,14 @@ def tokenize( return tokens, notna_mask @auto_move_data - def encode(self, tokens: Tensor, avail_mask: Tensor, apply_augments: bool = True) -> Tensor: + def encode(self, tokens: Tensor, avail_mask: Tensor, disable_augments: bool = False) -> Tensor: """Embeds input sequences using the encoder model, optionally selecting/pooling output tokens for the embedding. Args: tokens: (N, S, E), Tokens to feed to the encoder. - avail_mask: (N, S), Mask indicating available (i.e. non-missing) tokens. Missing tokens will not be attended - to by the encoder. - apply_augments: Whether to perform augments on the tokens (e.g. dropout, masking). Normally augments will + avail_mask: (N, S), Boolean mask indicating available (i.e. non-missing) tokens. Missing tokens can thus be + treated distinctly from others (e.g. replaced w/ a specific mask). + disable_augments: Whether to perform augments on the tokens (e.g. masking). Normally augments will be performed differently (if not outright disabled) when not in training, but this parameter allows to disable them even during training. This is useful to compute "uncorrupted" views of the data for contrastive learning. @@ -415,56 +404,27 @@ def encode(self, tokens: Tensor, avail_mask: Tensor, apply_augments: bool = True Returns: (N, E) or (N, S * E), Embeddings of the input sequences. The shape of the embeddings depends on the selection/pooling applied on the output tokens. """ - # Cast attention map to float to be able to perform matmul (and the underlying addmul operations), since Pytorch - # doesn't support addmul for int types (see this issue: https://github.com/pytorch/pytorch/issues/44428) - avail_mask = avail_mask.float() - # Default to attend to all non-missing tokens - attn_mask = torch.ones_like(avail_mask) - - dropout = self.train_attrs_dropout if self.training else self.test_attrs_dropout - if dropout and apply_augments: - # Draw independent Bernoulli samples for each item/attribute pair in the batch, representing whether - # to keep (1) or drop (0) attributes for each item - dropout_dist = torch.full_like(avail_mask, 1 - dropout) - keep_mask = torch.bernoulli(dropout_dist) - - # Repeat the sampling in case all attributes are dropped, missing or masked for an item - while not (keep_mask * avail_mask).any(dim=1).all(dim=0): - keep_mask = torch.bernoulli(dropout_dist) + mask_token = self.mask_token + if isinstance(mask_token, ParameterDict): + mask_token = torch.stack(list(mask_token.values())) - attn_mask *= keep_mask + if mask_token is not None: + # If a mask token is configured, substitute the missing tokens with the mask token to distinguish them from + # the other tokens + tokens = mask_tokens(tokens, mask_token, ~avail_mask) mtr_p = self.train_mtr_p if self.training else self.test_mtr_p - if mtr_p and apply_augments: + if mtr_p and disable_augments: # Mask Token Replacement (MTR) data augmentation - mask_token = self.mask_token - if isinstance(mask_token, ParameterDict): - mask_token = torch.stack(list(mask_token.values())) + # Replace random non-missing tokens with the mask token to perturb the input tokens, _ = random_masking(tokens, mask_token, mtr_p) if self.hparams.latent_token: # Add the latent token to the end of each item in the batch tokens = self.latent_token(tokens) - # Pad attention mask to account for latent token only after dropout, so that latent token is always kept - attn_mask = F.pad(attn_mask, (0, 1), value=1) - - # Build attention mask that avoids attending to missing tokens - attn_mask = torch.stack( - [item_attn_mask[None].T @ item_attn_mask[None] for item_attn_mask in attn_mask] - ) # (N, S, S) - # Cast attention mask back to bool and flip (because Pytorch's MHA expects true/non-zero values to mark where - # NOT to attend) - attn_mask = ~(attn_mask.bool()) - # Repeat the mask to have it be identical for each head of the multi-head attention - # (to respect Pytorch's expected attention mask format) - attn_mask = attn_mask.repeat_interleave(self.nhead, dim=0) # (N * nhead, S, S) - - # Add positional embedding to the tokens + forward pass through the transformer encoder - kwargs = {} - if isinstance(self.encoder, nn.TransformerEncoder): - kwargs["mask"] = attn_mask - out_tokens = self.encoder(self.positional_encoding(tokens), **kwargs) + # Forward pass through the transformer encoder + out_tokens = self.encoder(self.positional_encoding(tokens)) if self.hparams.sequential_pooling: # Perform sequential pooling of the transformers' output tokens @@ -573,7 +533,7 @@ def _contrastive_shared_step( self, batch: PatientData, batch_idx: int, in_tokens: Tensor, avail_mask: Tensor, out_features: Tensor ) -> Dict[str, Tensor]: corrupted_out_features = out_features # Features from a view corrupted by augmentations - anchor_out_features = self.encode(in_tokens, avail_mask, apply_augments=False) + anchor_out_features = self.encode(in_tokens, avail_mask, disable_augments=True) # Compute the contrastive loss/metrics metrics = {