Bugfix: Delay pattern mask is applied twice #110

Open · wants to merge 4 commits into base: main
385 changes: 385 additions & 0 deletions helpers/voice_enrolment_demo/enrol.ipynb
Collaborator:
I've been able to test this with the following code, which also requires a small modification of the DAC code (adding main_input_name = "input_values" as a class attribute of DACModel):

import torch
from parler_tts import ParlerTTSForConditionalGeneration
from transformers import AutoTokenizer, set_seed, AutoFeatureExtractor
import soundfile as sf
import torchaudio

device = "cuda" if torch.cuda.is_available() else "cpu"

model_id = "parler-tts/parler-tts-mini-v1"

model = ParlerTTSForConditionalGeneration.from_pretrained(model_id).to(device)
tokenizer = AutoTokenizer.from_pretrained(model_id)
feature_extractor = AutoFeatureExtractor.from_pretrained(model_id)

init_audio, init_sr = torchaudio.load(PATH_TO_SPECIFY)
init_audio = torchaudio.functional.resample(init_audio, init_sr, model.config.sampling_rate)
init_audio = init_audio.mean(0)
init_prompt = "Here, write the transcript of the init audio"

prompt = "Is it really working ?"
description = "A man speaker speaks quickly with a low-pitched voice. The recording is of very high quality, with the speaker's voice sounding clear and very close up." # TODO: adapt the prompt to describe the input audio

input_ids = tokenizer(description, return_tensors="pt").input_ids.to(device)
prompt_input_ids = tokenizer(init_prompt + " " + prompt, return_tensors="pt").input_ids.to(device)
input_values = feature_extractor(init_audio, sampling_rate=model.config.sampling_rate, return_tensors="pt").input_values.to(device)
set_seed(2)
generation = model.generate(input_ids=input_ids, prompt_input_ids=prompt_input_ids, input_values=input_values)
audio_arr = generation.cpu().numpy().squeeze()
sf.write("parler_tts_out.wav", audio_arr, model.config.sampling_rate)
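The one-line DAC modification mentioned at the top of this comment might look roughly like this (a minimal sketch; the real DACModel in parler_tts subclasses a transformers base class and has a full model body):

```python
# Sketch of the change described above: declaring main_input_name as a class
# attribute so that generate() knows which keyword argument carries the audio
# tensor. The bare class body here is illustrative only.
class DACModel:
    main_input_name = "input_values"
```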

I found that Parler has difficulty generalizing to unseen speakers (i.e. speakers that were neither seen during training nor generated by Parler), so there's no real advantage to using it for voice cloning. However, in my experiments it works quite well with Parler-generated speech!


Hey @ylacombe, I tried the above code sample with both the mini and large models, but the generated audio file is noisy and inconsistent. I used input audio generated through Parler-TTS itself.

Author:

This is a clean snippet! When calling model.generate(...), is there a reason to prefer input_values=input_values? I was originally doing something along the lines of decoder_input_ids=input_values.squeeze().long().


4 changes: 3 additions & 1 deletion parler_tts/modeling_parler_tts.py
@@ -3387,7 +3387,8 @@ def generate(
)

# build the delay pattern mask for offsetting each codebook prediction by 1 (this behaviour is specific to Parler-TTS)
input_ids, decoder_delay_pattern_mask = self.decoder.build_delay_pattern_mask(
# but don't overwrite the input_ids tensor with the delay pattern mask. We perform that later
_, decoder_delay_pattern_mask = self.decoder.build_delay_pattern_mask(
Comment on lines +3390 to +3391
Collaborator:

As pointed out, this is a redundant operation that has no impact on the results!

Author:

Hmm, I think this line does indeed change the results when using enrolled tokens. Perhaps your setup works because it is slightly different, as you've described below. I'll try this and get back to you.

Author:

Ok, so my testing shows that this fix is required to get the right audio when doing the enrolment. Here is an example audio file generated with and without the fix:
audio.zip

input_ids,
bos_token_id=generation_config._bos_token_tensor,
pad_token_id=generation_config._pad_token_tensor,
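For context, the delay pattern discussed in this hunk offsets each successive codebook's predictions by one extra step. A minimal NumPy sketch of the idea follows (illustrative only; the real build_delay_pattern_mask in Parler-TTS also handles BOS/pad token bookkeeping and returns a separate mask):

```python
import numpy as np

def build_delay_pattern(codes: np.ndarray, pad_id: int) -> np.ndarray:
    """Shift codebook k right by k steps, filling freed positions with pad_id.

    codes: (num_codebooks, seq_len) array of audio tokens.
    Returns a (num_codebooks, seq_len + num_codebooks - 1) array.
    """
    num_codebooks, seq_len = codes.shape
    out = np.full((num_codebooks, seq_len + num_codebooks - 1), pad_id, dtype=codes.dtype)
    for k in range(num_codebooks):
        out[k, k : k + seq_len] = codes[k]
    return out

# Two codebooks of length 2: the second row starts one step later.
delayed = build_delay_pattern(np.array([[1, 2], [3, 4]]), pad_id=0)
```

Applying such a shift twice would misalign the codebooks by an extra step each, which is consistent with the corrupted-audio symptom reported in this thread.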
@@ -3442,6 +3443,7 @@ def generate(
generation_config=generation_config,
synced_gpus=synced_gpus,
streamer=streamer,
logits_warper=None,
Collaborator:

You should keep the logits_warper; I'm not sure why you removed it!

Author:

I didn't remove it! Originally, logits_warper wasn't being passed in, so this part of the code was failing. I believe logits_warper=None should be set when doing greedy search. Could you please double-check this?

**model_kwargs,
)
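To illustrate the point being debated above: a logits warper (temperature, top-k, etc.) reshapes the next-token distribution before sampling, whereas greedy search simply takes the argmax, so warping has no effect on the chosen token there. A hedged sketch with temperature warping only (not the transformers implementation):

```python
import torch

def warp_logits(logits: torch.Tensor, temperature: float) -> torch.Tensor:
    # Temperature warping: divide logits before softmax. Values < 1 sharpen
    # the distribution, values > 1 flatten it; the argmax is unchanged.
    return logits / temperature

logits = torch.tensor([1.0, 2.0, 3.0])
greedy_token = int(torch.argmax(logits))                     # greedy path: warper irrelevant
sample_probs = torch.softmax(warp_logits(logits, 0.7), -1)   # sampling path: warper applied
```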
