
Commit

Remove attrs_dropout data augmentation dependent on PyTorch transformer API

Unify missing data handling using Mask Token Replacement (MTR), with missing tokens being replaced by the mask token
nathanpainchaud committed Oct 31, 2023
1 parent 5660c9f commit 8f7f6f7
Showing 3 changed files with 24 additions and 76 deletions.
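The gist of the change: previously, missing tokens (and tokens randomly dropped via `attrs_dropout`) were meant to be excluded through an attention mask, which required a native PyTorch `TransformerEncoder`. After this commit, missing tokens are instead overwritten with the learned mask token before encoding, the same mechanism Mask Token Replacement (MTR) uses for augmentation, so any encoder architecture can be used. A minimal illustrative sketch of that replacement (a plain-PyTorch stand-in, not vital's `mask_tokens` helper):

import torch
from torch import Tensor


def replace_missing_with_mask(tokens: Tensor, avail_mask: Tensor, mask_token: Tensor) -> Tensor:
    """Overwrites missing tokens with a learned mask token.

    Args:
        tokens: (N, S, E), input tokens.
        avail_mask: (N, S), boolean mask, True where a token is available (non-missing).
        mask_token: (E,), learned mask embedding.

    Returns:
        (N, S, E), tokens where missing positions were replaced by the mask token.
    """
    missing = ~avail_mask  # (N, S), True where a token is missing
    return torch.where(missing.unsqueeze(-1), mask_token.expand_as(tokens), tokens)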
4 changes: 1 addition & 3 deletions didactic/config/experiment/cardinal/multimodal-xformer.yaml
@@ -104,7 +104,6 @@ task:
sequential_pooling: False
mtr_p: 0
mt_by_attr: False
attrs_dropout: 0

clinical_tokenizer:
_target_: rtdl.FeatureTokenizer
@@ -138,7 +137,7 @@ callbacks:
# attention_rollout_kwargs:
# includes_cls_token: ${task.latent_token}

experiment_dirname: encoder=${hydra:runtime.choices.task/model}/img_tokenizer=${hydra:runtime.choices.task/img_tokenizer/model}/n_clinical_attrs=${n_clinical_attrs},n_img_attrs=${n_img_attrs}/contrastive=${oc.select:task.contrastive_loss_weight,0}/embed_dim=${task.embed_dim},depth=${task.model.encoder.num_layers},nhead=${task.model.encoder.encoder_layer.nhead},dropout=${task.model.encoder.encoder_layer.dropout}/mtr_p=${task.mtr_p},mt_by_attr=${task.mt_by_attr},attrs_dropout=${task.attrs_dropout}
experiment_dirname: encoder=${hydra:runtime.choices.task/model}/img_tokenizer=${hydra:runtime.choices.task/img_tokenizer/model}/n_clinical_attrs=${n_clinical_attrs},n_img_attrs=${n_img_attrs}/contrastive=${oc.select:task.contrastive_loss_weight,0}/embed_dim=${task.embed_dim},depth=${task.model.encoder.num_layers},nhead=${task.model.encoder.encoder_layer.nhead},dropout=${task.model.encoder.encoder_layer.dropout}/mtr_p=${task.mtr_p},mt_by_attr=${task.mt_by_attr}
hydra:
job:
config:
@@ -173,7 +172,6 @@ hydra:
- task.constraint.clustering_model
- task.mtr_p
- task.mt_by_attr
- task.attrs_dropout

- task.embed_dim
- task/img_tokenizer/model
2 changes: 1 addition & 1 deletion didactic/config/experiment/cardinal/xtab-finetune.yaml
@@ -34,7 +34,7 @@ ckpt: ??? # Make it mandatory to provide a checkpoint
weights_only: True # Only load the weights and ignore the hyperparameters
strict: False # Only load weights where they match the defined network, to allow only some changes (e.g. heads, etc.)

experiment_dirname: encoder=${hydra:runtime.choices.task/model}/img_tokenizer=${hydra:runtime.choices.task/img_tokenizer/model}/n_clinical_attrs=${n_clinical_attrs},n_img_attrs=${n_img_attrs}/contrastive=${oc.select:task.contrastive_loss_weight,0}/embed_dim=${task.embed_dim},depth=${task.model.encoder.n_blocks},nhead=${task.model.encoder.attention_n_heads},dropout=${task.model.encoder.attention_dropout},${task.model.encoder.ffn_dropout},${task.model.encoder.residual_dropout}/mtr_p=${task.mtr_p},mt_by_attr=${task.mt_by_attr},attrs_dropout=${task.attrs_dropout}
experiment_dirname: encoder=${hydra:runtime.choices.task/model}/img_tokenizer=${hydra:runtime.choices.task/img_tokenizer/model}/n_clinical_attrs=${n_clinical_attrs},n_img_attrs=${n_img_attrs}/contrastive=${oc.select:task.contrastive_loss_weight,0}/embed_dim=${task.embed_dim},depth=${task.model.encoder.n_blocks},nhead=${task.model.encoder.attention_n_heads},dropout=${task.model.encoder.attention_dropout},${task.model.encoder.ffn_dropout},${task.model.encoder.residual_dropout}/mtr_p=${task.mtr_p},mt_by_attr=${task.mt_by_attr}
hydra:
run:
dir: ${oc.env:CARDIAC_MULTIMODAL_REPR_PATH}/xtab-finetune/${experiment_dirname}/targets=${oc.dict.keys:task.predict_losses}/${hydra.job.override_dirname}
94 changes: 22 additions & 72 deletions didactic/tasks/cardiac_multimodal_representation.py
@@ -7,13 +7,12 @@
import hydra
import rtdl
import torch
import torch.nn.functional as F
from omegaconf import DictConfig
from rtdl import FeatureTokenizer
from torch import Tensor, nn
from torch.nn import Parameter, ParameterDict, init
from torchmetrics.functional import accuracy, mean_absolute_error
from vital.data.augmentation.base import random_masking
from vital.data.augmentation.base import mask_tokens, random_masking
from vital.data.cardinal.config import CardinalTag, ClinicalAttribute, ImageAttribute
from vital.data.cardinal.config import View as ViewEnum
from vital.data.cardinal.datapipes import MISSING_CAT_ATTR, PatientData, filter_image_attributes
@@ -47,7 +46,6 @@ def __init__(
sequential_pooling: bool = False,
mtr_p: float | Tuple[float, float] = 0,
mt_by_attr: bool = False,
attrs_dropout: float | Tuple[float, float] = 0,
*args,
**kwargs,
):
@@ -78,9 +76,6 @@ def __init__(
If a tuple, specify a masking rate to use during training and inference, respectively.
mt_by_attr: Whether to use one MASK token per attribute (`True`), or one universal MASK token for all
attributes (`False`).
attrs_dropout: Probability of randomly masking tokens, effectively dropping them, to simulate missing data.
If a float, the value will be used as dropout rate during training (disabled during inference).
If a tuple, specify a dropout rate to use during training and inference, respectively.
*args: Positional arguments to pass to the parent's constructor.
**kwargs: Keyword arguments to pass to the parent's constructor.
"""
@@ -93,8 +88,6 @@ def __init__(
# If dropout/masking are not single numbers, make sure they are tuples (and not another container type)
if not isinstance(mtr_p, (int, float)):
mtr_p = tuple(mtr_p)
if not isinstance(attrs_dropout, (int, float)):
attrs_dropout = tuple(attrs_dropout)

if contrastive_loss is None and predict_losses is None:
raise ValueError(
@@ -154,12 +147,7 @@ def __init__(
len(CLINICAL_CAT_ATTR_LABELS[cat_attr]) for cat_attr in self.clinical_cat_attrs
]

# Extract train/test dropout/masking probabilities from their configs
if isinstance(self.hparams.attrs_dropout, tuple):
self.train_attrs_dropout, self.test_attrs_dropout = self.hparams.attrs_dropout
else:
self.train_attrs_dropout = self.hparams.attrs_dropout
self.test_attrs_dropout = 0
# Extract train/test masking probabilities from their configs
if isinstance(self.hparams.mtr_p, tuple):
self.train_mtr_p, self.test_mtr_p = self.hparams.mtr_p
else:
@@ -217,16 +205,6 @@ def __init__(
self.unimodal_parametrization_heads,
) = self.configure_model()

# Check compatibility between config options and the encoder's architecture
if (self.train_attrs_dropout or self.test_attrs_dropout) and not isinstance(
self.encoder, nn.TransformerEncoder
):
raise ValueError(
"You have requested to apply dropout on the encoder's input tokens (`attrs_dropout` flag), but the "
"encoder is not a native PyTorch `TransformerEncoder`. `attrs_dropout` is only supported for native "
"PyTorch `TransformerEncoder`, since they can be provided with attention masks."
)

# Configure tokenizers and extract relevant info about the models' architectures
if isinstance(self.encoder, nn.TransformerEncoder): # Native PyTorch `TransformerEncoder`
self.nhead = self.encoder.layers[0].self_attn.num_heads
@@ -395,9 +373,10 @@ def tokenize(
[clinical_attrs[attr].unsqueeze(1) for attr in self.clinical_cat_attrs]
) # (N, S_cat)
# Use "sanitized" version of the inputs, where invalid values are replaced by null/default values, for the
# tokenization process. Since the embeddings of the missing tokens will be ignored later on using the attention
# mask anyway, it doesn't matter that the embeddings returned are not "accurate"; it only matters that the
# tokenization doesn't crash or return NaNs
# tokenization process. This is done to avoid propagating NaNs to available/valid values.
# If the embeddings cannot be ignored later on (e.g. by using an attention mask during inference), they
# should be replaced w/ a more distinct value to indicate that they are missing (e.g. a specific token),
# instead of their current null/default values.
# 1) Convert missing numerical attributes (NaNs) to numbers to avoid propagating NaNs
# 2) Clip categorical labels to convert indicators of missing data (-1) into valid indices (0)
clinical_attrs_tokens = self.clinical_tokenizer(
@@ -428,71 +407,42 @@ def tokenize(
return tokens, notna_mask

@auto_move_data
def encode(self, tokens: Tensor, avail_mask: Tensor, apply_augments: bool = True) -> Tensor:
def encode(self, tokens: Tensor, avail_mask: Tensor, disable_augments: bool = False) -> Tensor:
"""Embeds input sequences using the encoder model, optionally selecting/pooling output tokens for the embedding.
Args:
tokens: (N, S, E), Tokens to feed to the encoder.
avail_mask: (N, S), Mask indicating available (i.e. non-missing) tokens. Missing tokens will not be attended
to by the encoder.
apply_augments: Whether to perform augments on the tokens (e.g. dropout, masking). Normally augments will
avail_mask: (N, S), Boolean mask indicating available (i.e. non-missing) tokens. Missing tokens can thus be
treated distinctly from others (e.g. replaced w/ a specific mask).
disable_augments: Whether to disable augments (e.g. masking) on the tokens. Normally augments will
be performed differently (if not outright disabled) when not in training, but this parameter allows
disabling them even during training. This is useful to compute "uncorrupted" views of the data for
contrastive learning.
Returns: (N, E) or (N, S * E), Embeddings of the input sequences. The shape of the embeddings depends on the
selection/pooling applied on the output tokens.
"""
# Cast attention map to float to be able to perform matmul (and the underlying addmul operations), since Pytorch
# doesn't support addmul for int types (see this issue: https://github.com/pytorch/pytorch/issues/44428)
avail_mask = avail_mask.float()
# Default to attend to all non-missing tokens
attn_mask = torch.ones_like(avail_mask)

dropout = self.train_attrs_dropout if self.training else self.test_attrs_dropout
if dropout and apply_augments:
# Draw independent Bernoulli samples for each item/attribute pair in the batch, representing whether
# to keep (1) or drop (0) attributes for each item
dropout_dist = torch.full_like(avail_mask, 1 - dropout)
keep_mask = torch.bernoulli(dropout_dist)

# Repeat the sampling in case all attributes are dropped, missing or masked for an item
while not (keep_mask * avail_mask).any(dim=1).all(dim=0):
keep_mask = torch.bernoulli(dropout_dist)
mask_token = self.mask_token
if isinstance(mask_token, ParameterDict):
mask_token = torch.stack(list(mask_token.values()))

attn_mask *= keep_mask
if mask_token is not None:
# If a mask token is configured, substitute the missing tokens with the mask token to distinguish them from
# the other tokens
tokens = mask_tokens(tokens, mask_token, ~avail_mask)

mtr_p = self.train_mtr_p if self.training else self.test_mtr_p
if mtr_p and apply_augments:
if mtr_p and not disable_augments:
# Mask Token Replacement (MTR) data augmentation
mask_token = self.mask_token
if isinstance(mask_token, ParameterDict):
mask_token = torch.stack(list(mask_token.values()))
# Replace random non-missing tokens with the mask token to perturb the input
tokens, _ = random_masking(tokens, mask_token, mtr_p)

if self.hparams.latent_token:
# Add the latent token to the end of each item in the batch
tokens = self.latent_token(tokens)

# Pad attention mask to account for latent token only after dropout, so that latent token is always kept
attn_mask = F.pad(attn_mask, (0, 1), value=1)

# Build attention mask that avoids attending to missing tokens
attn_mask = torch.stack(
[item_attn_mask[None].T @ item_attn_mask[None] for item_attn_mask in attn_mask]
) # (N, S, S)
# Cast attention mask back to bool and flip (because Pytorch's MHA expects true/non-zero values to mark where
# NOT to attend)
attn_mask = ~(attn_mask.bool())
# Repeat the mask to have it be identical for each head of the multi-head attention
# (to respect Pytorch's expected attention mask format)
attn_mask = attn_mask.repeat_interleave(self.nhead, dim=0) # (N * nhead, S, S)

# Add positional embedding to the tokens + forward pass through the transformer encoder
kwargs = {}
if isinstance(self.encoder, nn.TransformerEncoder):
kwargs["mask"] = attn_mask
out_tokens = self.encoder(self.positional_encoding(tokens), **kwargs)
# Forward pass through the transformer encoder
out_tokens = self.encoder(self.positional_encoding(tokens))

if self.hparams.sequential_pooling:
# Perform sequential pooling of the transformers' output tokens
Expand Down Expand Up @@ -640,7 +590,7 @@ def _contrastive_shared_step(
self, batch: PatientData, batch_idx: int, in_tokens: Tensor, avail_mask: Tensor, out_features: Tensor
) -> Dict[str, Tensor]:
corrupted_out_features = out_features # Features from a view corrupted by augmentations
anchor_out_features = self.encode(in_tokens, avail_mask, apply_augments=False)
anchor_out_features = self.encode(in_tokens, avail_mask, disable_augments=True)

# Compute the contrastive loss/metrics
metrics = {
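For orientation, here is a condensed, illustrative sketch of the `encode` flow after this commit (plain-PyTorch stand-ins for vital's `mask_tokens`/`random_masking`; the `encoder`, `mask_token` and `mtr_p` parameters stand in for the module's attributes, and positional encoding, latent token and pooling are left out):

import torch
from torch import Tensor, nn


def encode_sketch(
    encoder: nn.Module,      # e.g. an nn.TransformerEncoder with batch_first=True
    tokens: Tensor,          # (N, S, E), input tokens
    avail_mask: Tensor,      # (N, S), True where a token is available (non-missing)
    mask_token: Tensor,      # (E,), learned mask embedding
    mtr_p: float,            # Mask Token Replacement rate (train/test value already selected)
    disable_augments: bool = False,
) -> Tensor:
    # 1) Missing tokens are no longer hidden behind an attention mask: they are
    #    overwritten with the mask token, so any encoder architecture can handle them.
    tokens = torch.where((~avail_mask).unsqueeze(-1), mask_token.expand_as(tokens), tokens)

    # 2) MTR augmentation: additionally replace random tokens with the mask token to
    #    perturb the input (missing tokens are already masked, so this effectively only
    #    perturbs available tokens). Skipped when computing "uncorrupted" views.
    if mtr_p and not disable_augments:
        replace = torch.rand(avail_mask.shape, device=tokens.device) < mtr_p
        tokens = torch.where(replace.unsqueeze(-1), mask_token.expand_as(tokens), tokens)

    # 3) Plain forward pass; no per-item attention mask is built anymore.
    return encoder(tokens)

The contrastive step then encodes each batch twice: once with augments enabled (corrupted view) and once with `disable_augments=True` (anchor view), as shown in the `_contrastive_shared_step` hunk above.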
