From 9b12220bb2820c8750cea264d68e833926438f5d Mon Sep 17 00:00:00 2001
From: Phil Wang
Date: Sun, 6 Aug 2023 09:59:26 -0700
Subject: [PATCH] figured out part of the whole masking thing, still need to
 take care of p_drop, which I assume is just like MLM per-token prob of
 masking

---
 setup.py                             |  2 +-
 voicebox_pytorch/voicebox_pytorch.py | 51 ++++++++++++++++++----------
 2 files changed, 35 insertions(+), 18 deletions(-)

diff --git a/setup.py b/setup.py
index a387389..16fa66a 100644
--- a/setup.py
+++ b/setup.py
@@ -3,7 +3,7 @@
 setup(
   name = 'voicebox-pytorch',
   packages = find_packages(exclude=[]),
-  version = '0.0.9',
+  version = '0.0.10',
   license='MIT',
   description = 'Voicebox - Pytorch',
   author = 'Phil Wang',
diff --git a/voicebox_pytorch/voicebox_pytorch.py b/voicebox_pytorch/voicebox_pytorch.py
index de683fb..6789e8f 100644
--- a/voicebox_pytorch/voicebox_pytorch.py
+++ b/voicebox_pytorch/voicebox_pytorch.py
@@ -8,6 +8,7 @@
 from torchdiffeq import odeint_adjoint as odeint
 
 from beartype import beartype
+from beartype.typing import Tuple
 
 from einops import rearrange, repeat, reduce, pack, unpack
 
@@ -263,13 +264,16 @@ def __init__(
         ff_mult = 4,
         conv_pos_embed_kernel_size = 31,
         conv_pos_embed_groups = None,
-        attn_flash = False
+        attn_flash = False,
+        frac_lengths_mask: Tuple[float, float] = (0.1, 1.)
     ):
         super().__init__()
 
         self.null_phoneme_id = num_phoneme_tokens # use last phoneme token as null token for CFG
         self.to_phoneme_emb = nn.Embedding(num_phoneme_tokens + 1, dim_phoneme_emb)
 
+        self.frac_lengths_mask = frac_lengths_mask
+
         self.to_embed = nn.Linear(dim * 2 + dim_phoneme_emb, dim)
 
         self.null_cond = nn.Parameter(torch.zeros(dim))
@@ -294,6 +298,10 @@ def __init__(
             Rearrange('... 1 -> ...')
         )
 
+    @property
+    def device(self):
+        return next(self.parameters()).device
+
     @torch.inference_mode()
     def forward_with_cond_scale(
         self,
@@ -319,7 +327,14 @@ def forward(
         target = None,
         mask = None
     ):
-        assert cond.shape[-1] == x.shape[-1]
+        batch, seq_len, cond_dim = cond.shape
+        assert cond_dim == x.shape[-1]
+
+        # construct mask if not given
+
+        if not exists(mask):
+            frac_lengths = torch.zeros((batch,), device = self.device).float().uniform_(*self.frac_lengths_mask)
+            mask = mask_from_frac_lengths(seq_len, frac_lengths)
 
         # classifier free guidance
 
@@ -376,7 +391,8 @@ def __init__(
         ff_mult = 4,
         conv_pos_embed_kernel_size = 31,
         conv_pos_embed_groups = None,
-        attn_flash = False
+        attn_flash = False,
+        frac_lengths_mask: Tuple[float, float] = (0.7, 1.)
     ):
         super().__init__()
         self.sinu_pos_emb = LearnedSinusoidalPosEmb(dim)
@@ -384,6 +400,8 @@ def __init__(
         self.null_phoneme_id = num_phoneme_tokens # use last phoneme token as null token for CFG
         self.to_phoneme_emb = nn.Embedding(num_phoneme_tokens + 1, dim_phoneme_emb)
 
+        self.frac_lengths_mask = frac_lengths_mask
+
         self.to_embed = nn.Linear(dim * 2 + dim_phoneme_emb, dim)
 
         self.null_cond = nn.Parameter(torch.zeros(dim))
@@ -405,6 +423,10 @@ def __init__(
 
         self.to_pred = nn.Linear(dim, dim, bias = False)
 
+    @property
+    def device(self):
+        return next(self.parameters()).device
+
     @torch.inference_mode()
     def forward_with_cond_scale(
         self,
@@ -431,7 +453,8 @@ def forward(
         target = None,
         mask = None,
     ):
-        assert cond.shape[-1] == x.shape[-1]
+        batch, seq_len, cond_dim = cond.shape
+        assert cond_dim == x.shape[-1]
 
         # auto manage shape of times, for odeint times
 
@@ -441,6 +464,12 @@ def forward(
         if times.ndim == 1 and times.shape[0] == 1:
             times = repeat(times, '1 -> b', b = cond.shape[0])
 
+        # construct mask if not given
+
+        if not exists(mask):
+            frac_lengths = torch.zeros((batch,), device = self.device).float().uniform_(*self.frac_lengths_mask)
+            mask = mask_from_frac_lengths(seq_len, frac_lengths)
+
         # classifier free guidance
 
         if cond_drop_prob > 0.:
@@ -511,8 +540,7 @@ def __init__(
         ode_atol = 1e-5,
         ode_rtol = 1e-5,
         ode_method = 'dopri5',
-        cond_drop_prob = 0.,
-        prob_seq_drop = 0.3 # not entirely sure
+        cond_drop_prob = 0.
     ):
         super().__init__()
         self.sigma = sigma
@@ -520,7 +548,6 @@ def __init__(
         self.voicebox = voicebox
 
         self.cond_drop_prob = cond_drop_prob
-        self.prob_seq_drop = prob_seq_drop
 
         self.odeint_kwargs = dict(
             atol = ode_atol,
@@ -591,16 +618,6 @@ def forward(
 
         flow = x1 - (1 - σ) * x0
 
-        # construct mask if not given
-
-        if (
-            not exists(mask) and
-            exists(self.prob_seq_drop) and
-            self.prob_seq_drop > 0.
-        ):
-            frac_lengths = torch.full((batch,), self.prob_seq_drop, device = self.device)
-            mask = mask_from_frac_lengths(seq_len, frac_lengths)
-
        # predict
 
         self.voicebox.train()
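Note: the patch calls a mask_from_frac_lengths helper that is not shown in this diff. The sketch below is a minimal, assumed implementation, not the repository's confirmed code: it samples one contiguous span per batch element whose length is the given fraction of the sequence, mirroring how forward() draws frac_lengths when no mask is passed in.

    import torch

    def mask_from_frac_lengths(seq_len, frac_lengths):
        # convert fractions of the sequence into integer span lengths
        lengths = (frac_lengths * seq_len).long()

        # sample a random start so each span fits inside the sequence
        max_start = seq_len - lengths
        start = (max_start * torch.rand_like(frac_lengths)).long().clamp(min = 0)
        end = start + lengths

        # boolean mask, True inside the sampled contiguous span
        seq = torch.arange(seq_len, device = frac_lengths.device)
        return (seq >= start[:, None]) & (seq < end[:, None])

    # mirrors the sampling added in forward() when mask is not given,
    # using the (0.7, 1.) frac_lengths_mask default from the second model in this patch
    batch, seq_len = 2, 1024
    frac_lengths = torch.zeros((batch,)).float().uniform_(0.7, 1.)
    mask = mask_from_frac_lengths(seq_len, frac_lengths)  # (batch, seq_len) bool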