handle correct calculation of sampled audio from the semantic token ids coming from spear tts
lucidrains committed Sep 27, 2023
1 parent 9d24c26 commit 205a163
Showing 2 changed files with 37 additions and 10 deletions.
2 changes: 1 addition & 1 deletion setup.py
@@ -3,7 +3,7 @@
 setup(
   name = 'voicebox-pytorch',
   packages = find_packages(exclude=[]),
-  version = '0.1.5',
+  version = '0.1.6',
   license='MIT',
   description = 'Voicebox - Pytorch',
   author = 'Phil Wang',
45 changes: 36 additions & 9 deletions voicebox_pytorch/voicebox_pytorch.py
@@ -375,6 +375,10 @@ def __init__(
 
         self.vocos = Vocos.from_pretrained(pretrained_vocos_path)
 
+    @property
+    def downsample_factor(self):
+        raise NotImplementedError
+
     @property
     def latent_dim(self):
         return self.num_mels
@@ -427,6 +431,10 @@ def __init__(
 
         self.register_buffer('bandwidth_id', torch.tensor([bandwidth_id]))
 
+    @property
+    def downsample_factor(self):
+        return self.encodec.downsample_factor
+
     @property
     def latent_dim(self):
         return self.encodec.codebook_dim
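
Both codec wrappers now expose a downsample_factor alongside latent_dim. A minimal sketch of the contract the new sampling code relies on (the class and the numbers below are illustrative assumptions, not code from this repo): one latent frame stands for downsample_factor waveform samples, so latent frames per second is sampling_rate / downsample_factor.

class AudioEncoderDecoderSketch:
    # hypothetical stand-in for the audio enc/dec interface used by Voicebox

    sampling_rate = 24000            # assumed codec sample rate (Hz)

    @property
    def downsample_factor(self):
        return 320                   # assumed: one latent frame per 320 waveform samples

enc_dec = AudioEncoderDecoderSketch()

# latent frames per second of waveform
frames_per_second = enc_dec.sampling_rate / enc_dec.downsample_factor  # 24000 / 320 = 75.0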
@@ -705,6 +713,7 @@ def __init__(
         if not condition_on_text:
             dim_cond_emb = 0
 
+        self.dim_cond_emb = dim_cond_emb
         self.condition_on_text = condition_on_text
         self.num_cond_tokens = num_cond_tokens
 
@@ -844,6 +853,11 @@ def forward(
             cond_emb = F.interpolate(cond_emb, (seq_len, 1), mode = 'bilinear')
             cond_emb = rearrange(cond_emb, 'b d n 1 -> b n d')
 
+            if exists(self_attn_mask):
+                self_attn_mask = rearrange(self_attn_mask.float(), 'b n -> b 1 n 1')
+                self_attn_mask = F.interpolate(self_attn_mask, (seq_len, 1), mode = 'bilinear')
+                self_attn_mask = rearrange(self_attn_mask, 'b 1 n 1 -> b n').bool()
+
         # concat source signal, semantic / phoneme conditioning embed, and conditioning
         # and project

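The added self_attn_mask block resizes a boolean padding mask to the interpolated sequence length by round-tripping through float, the same way the conditioning embeddings are resized just above it. A standalone sketch of the trick (the tensor values are made up for illustration):

import torch
import torch.nn.functional as F
from einops import rearrange

mask = torch.tensor([[True, True, True, False]])      # (batch, n) boolean padding mask
seq_len = 6                                           # target length after interpolation

resized = rearrange(mask.float(), 'b n -> b 1 n 1')   # 4d input, as F.interpolate bilinear expects
resized = F.interpolate(resized, (seq_len, 1), mode = 'bilinear')
resized = rearrange(resized, 'b 1 n 1 -> b n').bool() # any nonzero weight becomes True

print(resized)  # tensor([[ True,  True,  True,  True,  True, False]])

Casting back with .bool() treats any fractional boundary weight as attendable, so the resized mask errs on the side of keeping frames at the True/False boundary.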
@@ -947,7 +961,7 @@ def load(self, path, strict = True):
     def sample(
         self,
         *,
-        cond,
+        cond = None,
         texts: Optional[List[str]] = None,
         text_token_ids: Optional[Tensor] = None,
         semantic_token_ids = None,
@@ -957,9 +971,6 @@ def sample(
         cond_scale = 1.,
         decode_to_audio = True
     ):
-        shape = cond.shape
-        batch = shape[0]
-
         # take care of condition as raw audio
 
         cond_is_raw_audio = is_probably_audio_from_shape(cond)
@@ -997,19 +1008,35 @@
             else:
                 cond_token_ids = phoneme_ids
 
-            cond_length = cond.shape[-2]
             cond_tokens_seq_len = cond_token_ids.shape[-1]
 
+            # calculate the correct conditioning length
+            # based on the sampling freqs of wav2vec and audio-enc-dec, as well as downsample factor
+            # (cond_time x cond_sampling_freq / cond_downsample_factor) == (audio_time x audio_sampling_freq / audio_downsample_factor)
+
+            wav2vec = self.text_to_semantic.wav2vec
+            audio_enc_dec = self.voicebox.audio_enc_dec
+
+            cond_target_length = (cond_tokens_seq_len * wav2vec.target_sample_hz / wav2vec.downsample_factor) / (audio_enc_dec.sampling_rate / audio_enc_dec.downsample_factor)
+            cond_target_length = math.ceil(cond_target_length)
+
+            # curtail or pad (todo: generalize this to any dimension and put in a function)
+
-            if cond_length > cond_tokens_seq_len:
-                cond = cond[:, :cond_tokens_seq_len, :]
-            elif cond_length < cond_tokens_seq_len:
-                cond = F.pad(cond, (0, 0, 0, cond_tokens_seq_len), value = 0.)
+            if exists(cond):
+                cond_length = cond.shape[-2]
+
+                if cond_length > cond_target_length:
+                    cond = cond[:, :cond_target_length, :]
+                elif cond_length < cond_target_length:
+                    cond = F.pad(cond, (0, 0, 0, cond_target_length), value = 0.)
+            else:
+                cond = torch.zeros((cond_token_ids.shape[0], cond_target_length, self.dim_cond_emb), device = self.device)
         else:
             assert num_cond_inputs == 0, 'no conditioning inputs should be given if not conditioning on text'
 
         shape = cond.shape
         batch = shape[0]
 
         # neural ode
 
         self.voicebox.eval()
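
As a sanity check on the new length formula, a worked example with assumed rates (a HuBERT-like wav2vec at 16 kHz with 320x downsampling and an EnCodec-like codec at 24 kHz with 320x downsampling; none of these numbers come from this diff):

import math

cond_tokens_seq_len = 150           # semantic tokens from spear-tts
target_sample_hz = 16000            # assumed wav2vec sample rate
wav2vec_downsample_factor = 320     # assumed wav2vec downsampling
sampling_rate = 24000               # assumed codec sample rate
audio_downsample_factor = 320       # assumed codec downsampling

# the committed formula:
# (cond_time x cond_sampling_freq / cond_downsample_factor) == (audio_time x audio_sampling_freq / audio_downsample_factor)
cond_target_length = (cond_tokens_seq_len * target_sample_hz / wav2vec_downsample_factor) / (sampling_rate / audio_downsample_factor)
cond_target_length = math.ceil(cond_target_length)  # (150 * 50) / 75

print(cond_target_length)  # 100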
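The hunk's own todo asks for the curtail-or-pad to be generalized into a function. One possible sketch (the curtail_or_pad name and the dim argument are hypothetical, not part of this commit; note it pads by the difference in lengths, whereas the in-line F.pad call above passes the full target length as the pad amount):

import torch
import torch.nn.functional as F

def curtail_or_pad(t, target_length, dim = -2, value = 0.):
    # hypothetical helper: trim or right-pad tensor `t` along `dim`
    # so that t.shape[dim] == target_length
    length = t.shape[dim]

    if length > target_length:
        return t.narrow(dim, 0, target_length)

    if length < target_length:
        # F.pad counts dimensions from the last one backwards, two numbers per dim
        num_dims_from_end = (t.ndim - 1) - (dim % t.ndim)
        pad = [0, 0] * num_dims_from_end + [0, target_length - length]
        return F.pad(t, pad, value = value)

    return t

cond = torch.randn(2, 98, 512)
cond = curtail_or_pad(cond, 100)  # pads 2 frames of zeros -> (2, 100, 512)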
