Refactor random_masking to extract mask_tokens function for replacement according to pre-defined mask (#178)
nathanpainchaud authored Oct 31, 2023
1 parent e91dcab commit d25fb84
Showing 1 changed file with 36 additions and 13 deletions.
49 changes: 36 additions & 13 deletions vital/data/augmentation/base.py
@@ -4,8 +4,8 @@
 from torch import Tensor
 
 
-def random_masking(x: Tensor, mask_token: Tensor, p: float) -> Tuple[Tensor, Tensor]:
-    """Masks random tokens in sequences by replacing them with a predefined `mask_token`.
+def mask_tokens(x: Tensor, mask_token: Tensor, mask: Tensor) -> Tensor:
+    """Replaces tokens in a batch of sequences with a predefined `mask_token`.
 
     References:
         - Adapted from the random masking implementation from the paper that introduced Mask Token Replacement (MTR):
@@ -14,21 +14,14 @@ def random_masking(x: Tensor, mask_token: Tensor, p: float) -> Tuple[Tensor, Tensor]:
     Args:
         x: (N, S, E) Batch of sequences of tokens.
         mask_token: (E) or (S, E) Mask to replace masked tokens with. If a single token of dimension (E), then the mask
-            will be used to replace any tokens in the sequence. Otherwise, each token in the sequence has to to have its
+            will be used to replace any tokens in the sequence. Otherwise, each token in the sequence has to have its
             own MASK token to be replaced with.
-        p: Probability to replace a token by the mask token.
+        mask: (N, S) Boolean mask of tokens in each sequence, with (True) representing tokens to replace.
 
     Returns:
-        x_masked: (N, S, E) Input tokens, where some tokens have been replaced by the mask token.
-        mask: (N, S) Mask of tokens that were masked, with (1) representing tokens that were masked.
+        (N, S, E) Input tokens, where the requested tokens have been replaced by the mask token.
     """
     n, s, d = x.shape
-    mask_dist = torch.full((n, s), p)  # Token-wise masking probability
 
-    # Repeat the sampling in case all tokens are masked for an item in the batch
-    mask = torch.bernoulli(mask_dist)
-    while not mask.any(dim=1).all(dim=0):
-        mask = torch.bernoulli(mask_dist)
-
     broadcast_mask = mask.unsqueeze(-1).to(device=x.device, dtype=torch.float)
     broadcast_mask = broadcast_mask.repeat(1, 1, d)
@@ -44,4 +37,34 @@
     )
 
     x_masked = x * (1 - broadcast_mask) + mask_tokens * broadcast_mask
-    return x_masked, mask
+    return x_masked
+
+
+def random_masking(x: Tensor, mask_token: Tensor, p: float) -> Tuple[Tensor, Tensor]:
+    """Masks random tokens in sequences by replacing them with a predefined `mask_token`.
+
+    References:
+        - Adapted from the random masking implementation from the paper that introduced Mask Token Replacement (MTR):
+          https://github.com/somaonishi/MTR/blob/33b87b37a63d120aff24c041da711fd8b714c00e/model/mask_token.py#L52-L68
+
+    Args:
+        x: (N, S, E) Batch of sequences of tokens.
+        mask_token: (E) or (S, E) Mask to replace masked tokens with. If a single token of dimension (E), then the mask
+            will be used to replace any tokens in the sequence. Otherwise, each token in the sequence has to have its
+            own MASK token to be replaced with.
+        p: Probability to replace a token by the mask token.
+
+    Returns:
+        x_masked: (N, S, E) Input tokens, where some tokens have been replaced by the mask token.
+        mask: (N, S) Boolean mask of tokens that were masked, with (True) representing tokens that were masked.
+    """
+    n, s, d = x.shape
+    mask_dist = torch.full((n, s), p)  # Token-wise masking probability
+
+    # Repeat the sampling in case all tokens are masked for an item in the batch
+    mask = torch.bernoulli(mask_dist)
+    while not mask.any(dim=1).all(dim=0):
+        mask = torch.bernoulli(mask_dist)
+    mask = mask.bool()  # Cast from 0/1 int to bool
+
+    return mask_tokens(x, mask_token, mask), mask
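
For reference, a minimal usage sketch of the two functions as they stand after this commit. The tensor shapes, the zero-valued mask token, the probability value, and the printed outputs are illustrative assumptions, as is the import path (taken from the changed file, vital/data/augmentation/base.py).

import torch

from vital.data.augmentation.base import mask_tokens, random_masking

# Batch of N=2 sequences of S=4 tokens with embedding size E=3
x = torch.randn(2, 4, 3)
mask_token = torch.zeros(3)  # a single (E) mask token shared by all positions

# Random masking: each token is independently replaced with probability p,
# and the (N, S) boolean mask of replaced positions is returned alongside
x_masked, mask = random_masking(x, mask_token, p=0.25)
print(x_masked.shape, mask.shape)  # torch.Size([2, 4, 3]) torch.Size([2, 4])

# Replacement according to a pre-defined mask, the use case enabled by the extracted mask_tokens
fixed_mask = torch.tensor([[True, False, False, True],
                           [False, True, False, False]])
x_fixed = mask_tokens(x, mask_token, fixed_mask)
print(torch.equal(x_fixed[0, 0], mask_token))  # True: position (0, 0) was replaced by the mask token

The refactor leaves random_masking's sampling behaviour unchanged; it now delegates the replacement itself to mask_tokens, which can also be called directly with a pre-defined mask.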
