figured out part of the whole masking thing, still need to take care of p_drop, which I assume is just like MLM per-token prob of masking
lucidrains committed Aug 6, 2023
1 parent 6ecd507 commit 9b12220
Showing 2 changed files with 35 additions and 18 deletions.
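
Note: the p_drop mentioned in the commit message is not addressed in this commit. If it is, as the message assumes, an MLM-style per-token masking probability, the eventual logic would look roughly like the sketch below (illustration only, not code from this repository; p_drop, batch and seq_len are placeholder names):

import torch

# hypothetical MLM-style dropout: each position is masked independently
# with probability p_drop, rather than as one contiguous span
p_drop = 0.15
batch, seq_len = 2, 10

token_mask = torch.rand(batch, seq_len) < p_drop   # True = masked position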
setup.py: 2 changes (1 addition, 1 deletion)
@@ -3,7 +3,7 @@
 setup(
   name = 'voicebox-pytorch',
   packages = find_packages(exclude=[]),
-  version = '0.0.9',
+  version = '0.0.10',
   license='MIT',
   description = 'Voicebox - Pytorch',
   author = 'Phil Wang',
voicebox_pytorch/voicebox_pytorch.py: 51 changes (34 additions, 17 deletions)
@@ -8,6 +8,7 @@
 from torchdiffeq import odeint_adjoint as odeint

 from beartype import beartype
+from beartype.typing import Tuple

 from einops import rearrange, repeat, reduce, pack, unpack

@@ -263,13 +264,16 @@ def __init__(
         ff_mult = 4,
         conv_pos_embed_kernel_size = 31,
         conv_pos_embed_groups = None,
-        attn_flash = False
+        attn_flash = False,
+        frac_lengths_mask: Tuple[float, float] = (0.1, 1.)
     ):
         super().__init__()

         self.null_phoneme_id = num_phoneme_tokens # use last phoneme token as null token for CFG
         self.to_phoneme_emb = nn.Embedding(num_phoneme_tokens + 1, dim_phoneme_emb)

+        self.frac_lengths_mask = frac_lengths_mask
+
         self.to_embed = nn.Linear(dim * 2 + dim_phoneme_emb, dim)

         self.null_cond = nn.Parameter(torch.zeros(dim))
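
The new frac_lengths_mask hyperparameter bounds what fraction of each training sequence gets masked; the forward pass further down samples one fraction per batch element from this range. A quick illustration of what the (0.1, 1.) default produces (standalone sketch, values are random):

import torch

frac_lengths_mask = (0.1, 1.)   # DurationPredictor default from this commit
batch = 4

# one masked-span fraction per sample, drawn uniformly from the configured range
frac_lengths = torch.zeros((batch,)).float().uniform_(*frac_lengths_mask)
print(frac_lengths)   # e.g. tensor([0.23, 0.91, 0.55, 0.78])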
@@ -294,6 +298,10 @@ def __init__(
             Rearrange('... 1 -> ...')
         )

+    @property
+    def device(self):
+        return next(self.parameters()).device
+
     @torch.inference_mode()
     def forward_with_cond_scale(
         self,
@@ -319,7 +327,14 @@ def forward(
         target = None,
         mask = None
     ):
-        assert cond.shape[-1] == x.shape[-1]
+        batch, seq_len, cond_dim = cond.shape
+        assert cond_dim == x.shape[-1]
+
+        # construct mask if not given
+
+        if not exists(mask):
+            frac_lengths = torch.zeros((batch,), device = self.device).float().uniform_(*self.frac_lengths_mask)
+            mask = mask_from_frac_lengths(seq_len, frac_lengths)

         # classifier free guidance

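
mask_from_frac_lengths(seq_len, frac_lengths) is called here but not shown in this diff. A plausible sketch of such a helper, assuming it marks one contiguous span of roughly frac * seq_len positions at a random offset per batch element (an assumption about the helper, not the repository's actual implementation):

import torch

def mask_from_frac_lengths(seq_len, frac_lengths):
    # hypothetical sketch: boolean mask of shape (batch, seq_len) with one
    # contiguous True span per sample, covering ~frac of the sequence
    device = frac_lengths.device
    lengths = (frac_lengths * seq_len).long()
    max_start = (seq_len - lengths).clamp(min = 0)

    start = (max_start * torch.rand_like(frac_lengths)).long()
    end = start + lengths

    seq = torch.arange(seq_len, device = device)
    return (seq >= start[:, None]) & (seq < end[:, None])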
@@ -376,14 +391,17 @@ def __init__(
         ff_mult = 4,
         conv_pos_embed_kernel_size = 31,
         conv_pos_embed_groups = None,
-        attn_flash = False
+        attn_flash = False,
+        frac_lengths_mask: Tuple[float, float] = (0.7, 1.)
     ):
         super().__init__()
         self.sinu_pos_emb = LearnedSinusoidalPosEmb(dim)

         self.null_phoneme_id = num_phoneme_tokens # use last phoneme token as null token for CFG
         self.to_phoneme_emb = nn.Embedding(num_phoneme_tokens + 1, dim_phoneme_emb)

+        self.frac_lengths_mask = frac_lengths_mask
+
         self.to_embed = nn.Linear(dim * 2 + dim_phoneme_emb, dim)

         self.null_cond = nn.Parameter(torch.zeros(dim))
@@ -405,6 +423,10 @@ def __init__(

         self.to_pred = nn.Linear(dim, dim, bias = False)

+    @property
+    def device(self):
+        return next(self.parameters()).device
+
     @torch.inference_mode()
     def forward_with_cond_scale(
         self,
@@ -431,7 +453,8 @@ def forward(
         target = None,
         mask = None,
     ):
-        assert cond.shape[-1] == x.shape[-1]
+        batch, seq_len, cond_dim = cond.shape
+        assert cond_dim == x.shape[-1]

         # auto manage shape of times, for odeint times

@@ -441,6 +464,12 @@ def forward(
         if times.ndim == 1 and times.shape[0] == 1:
             times = repeat(times, '1 -> b', b = cond.shape[0])

+        # construct mask if not given
+
+        if not exists(mask):
+            frac_lengths = torch.zeros((batch,), device = self.device).float().uniform_(*self.frac_lengths_mask)
+            mask = mask_from_frac_lengths(seq_len, frac_lengths)
+
         # classifier free guidance

         if cond_drop_prob > 0.:
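
Worth noting: the DurationPredictor defaults to frac_lengths_mask = (0.1, 1.) while the VoiceBox model here uses (0.7, 1.), so the audio model is trained to infill much larger spans. A rough check of the expected coverage under each range (illustrative only):

import torch

torch.manual_seed(0)

dp_fracs = torch.zeros(10_000).uniform_(0.1, 1.)   # DurationPredictor range
vb_fracs = torch.zeros(10_000).uniform_(0.7, 1.)   # VoiceBox range

print(dp_fracs.mean())   # ~0.55 of frames masked on average
print(vb_fracs.mean())   # ~0.85 of frames masked on average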
@@ -511,16 +540,14 @@ def __init__(
         ode_atol = 1e-5,
         ode_rtol = 1e-5,
         ode_method = 'dopri5',
-        cond_drop_prob = 0.,
-        prob_seq_drop = 0.3 # not entirely sure
+        cond_drop_prob = 0.
     ):
         super().__init__()
         self.sigma = sigma

         self.voicebox = voicebox

         self.cond_drop_prob = cond_drop_prob
-        self.prob_seq_drop = prob_seq_drop

         self.odeint_kwargs = dict(
             atol = ode_atol,
@@ -591,16 +618,6 @@ def forward(

         flow = x1 - (1 - σ) * x0

-        # construct mask if not given
-
-        if (
-            not exists(mask) and
-            exists(self.prob_seq_drop) and
-            self.prob_seq_drop > 0.
-        ):
-            frac_lengths = torch.full((batch,), self.prob_seq_drop, device = self.device)
-            mask = mask_from_frac_lengths(seq_len, frac_lengths)
-
         # predict

         self.voicebox.train()
