From d25fb84889129c8e461cbd5bd331cc2d7a854726 Mon Sep 17 00:00:00 2001
From: Nathan Painchaud <23144457+nathanpainchaud@users.noreply.github.com>
Date: Tue, 31 Oct 2023 17:18:48 +0100
Subject: [PATCH] Refactor `random_masking` to extract `mask_tokens` function
 for replacement according to pre-defined mask (#178)

---
 vital/data/augmentation/base.py | 49 ++++++++++++++++++++++++---------
 1 file changed, 36 insertions(+), 13 deletions(-)

diff --git a/vital/data/augmentation/base.py b/vital/data/augmentation/base.py
index a1d38b60..c902df38 100644
--- a/vital/data/augmentation/base.py
+++ b/vital/data/augmentation/base.py
@@ -4,8 +4,8 @@
 from torch import Tensor
 
 
-def random_masking(x: Tensor, mask_token: Tensor, p: float) -> Tuple[Tensor, Tensor]:
-    """Masks random tokens in sequences by replacing them with a predefined `mask_token`.
+def mask_tokens(x: Tensor, mask_token: Tensor, mask: Tensor) -> Tensor:
+    """Replaces tokens in a batch of sequences with a predefined `mask_token`.
 
     References:
         - Adapted from the random masking implementation from the paper that introduced Mask Token Replacement (MTR):
@@ -14,21 +14,14 @@ def random_masking(x: Tensor, mask_token: Tensor, p: float) -> Tuple[Tensor, Ten
     Args:
         x: (N, S, E) Batch of sequences of tokens.
         mask_token: (E) or (S, E) Mask to replace masked tokens with. If a single token of dimension (E), then the mask
-            will be used to replace any tokens in the sequence. Otherwise, each token in the sequence has to to have its
+            will be used to replace any tokens in the sequence. Otherwise, each token in the sequence has to have its
             own MASK token to be replaced with.
-        p: Probability to replace a token by the mask token.
+        mask: (N, S) Boolean mask of tokens in each sequence, with (True) representing tokens to replace.
 
     Returns:
-        x_masked: (N, S, E) Input tokens, where some tokens have been replaced by the mask token.
-        mask: (N, S) Mask of tokens that were masked, with (1) representing tokens that were masked.
+        (N, S, E) Input tokens, where the requested tokens have been replaced by the mask token.
     """
     n, s, d = x.shape
-    mask_dist = torch.full((n, s), p)  # Token-wise masking probability
-
-    # Repeat the sampling in case all tokens are masked for an item in the batch
-    mask = torch.bernoulli(mask_dist)
-    while not mask.any(dim=1).all(dim=0):
-        mask = torch.bernoulli(mask_dist)
 
     broadcast_mask = mask.unsqueeze(-1).to(device=x.device, dtype=torch.float)
     broadcast_mask = broadcast_mask.repeat(1, 1, d)
@@ -44,4 +37,34 @@ def random_masking(x: Tensor, mask_token: Tensor, p: float) -> Tuple[Tensor, Ten
     )
 
     x_masked = x * (1 - broadcast_mask) + mask_tokens * broadcast_mask
-    return x_masked, mask
+    return x_masked
+
+
+def random_masking(x: Tensor, mask_token: Tensor, p: float) -> Tuple[Tensor, Tensor]:
+    """Masks random tokens in sequences by replacing them with a predefined `mask_token`.
+
+    References:
+        - Adapted from the random masking implementation from the paper that introduced Mask Token Replacement (MTR):
+          https://github.com/somaonishi/MTR/blob/33b87b37a63d120aff24c041da711fd8b714c00e/model/mask_token.py#L52-L68
+
+    Args:
+        x: (N, S, E) Batch of sequences of tokens.
+        mask_token: (E) or (S, E) Mask to replace masked tokens with. If a single token of dimension (E), then the mask
+            will be used to replace any tokens in the sequence. Otherwise, each token in the sequence has to have its
+            own MASK token to be replaced with.
+        p: Probability to replace a token by the mask token.
+
+    Returns:
+        x_masked: (N, S, E) Input tokens, where some tokens have been replaced by the mask token.
+        mask: (N, S) Boolean mask of tokens that were masked, with (True) representing tokens that were masked.
+    """
+    n, s, d = x.shape
+    mask_dist = torch.full((n, s), p)  # Token-wise masking probability
+
+    # Repeat the sampling in case all tokens are masked for an item in the batch
+    mask = torch.bernoulli(mask_dist)
+    while not mask.any(dim=1).all(dim=0):
+        mask = torch.bernoulli(mask_dist)
+    mask = mask.bool()  # Cast from 0/1 int to bool
+
+    return mask_tokens(x, mask_token, mask), mask
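
Usage note (not part of the patch itself): a minimal sketch of how the two functions compose after this refactor. It assumes the package is importable as `vital.data.augmentation.base`; the tensor shapes, the 0.15 masking probability, and the `custom_mask` pattern are illustrative values, not anything prescribed by the change.

    import torch

    from vital.data.augmentation.base import mask_tokens, random_masking

    x = torch.randn(2, 5, 8)     # (N, S, E) batch of token sequences
    mask_token = torch.zeros(8)  # single MASK token of dimension (E)

    # Existing behavior, preserved by `random_masking`: sample the mask internally
    x_masked, mask = random_masking(x, mask_token, p=0.15)

    # New entry point: replace tokens according to a pre-defined boolean mask
    custom_mask = torch.zeros(2, 5, dtype=torch.bool)  # (N, S) tokens to replace
    custom_mask[:, 0] = True  # e.g. always mask the first token of each sequence
    x_custom = mask_tokens(x, mask_token, custom_mask)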