Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Bugfix: Delay pattern mask is applied twice #110

Open
wants to merge 4 commits into
base: main
Choose a base branch
from

Conversation

Guppy16
Copy link

@Guppy16 Guppy16 commented Aug 16, 2024

There are a few bugfixes / contributions:

  1. (Bugfix) In ParlerTTSForConditionalGeneration::generate(), the delay patter mask is built and applied to input_ids before calling _sample(). However, we do not want to apply the mask until we're inside the _sample() function. The above bug doesn't affect the current inference setup because the input_ids returned happens to be the same as what's passed in.

  2. (Contribution) I've provided an example of how to do audio enrolment to improve the consistency of audio generation in (helpers/voice_enrolment/enrol.ipynb ). I believe it helps, but I'm not sure if it's always better. Nonetheless, I mainly provided this as an example to demonstrate how the bugfix helps: we can try to provide the enrolment as prefix tokens in decoder_input_ids args when calling model.generate(). You should notice that without the bugfix, the audio sounds "crackly", which is because the mask has effectively been applied twice on the prefix.

  3. (Bugfix) When performing deterministic greedy decoding (by passing in do_sample=False, there is bug where the logits_warper is not passed in. I believe this should just be None(?), which I've commited in this PR. Related to this, I also want to raise an issue that deterministic sampling by setting do_sample=False or temperatute=0.1 tends to generate random noise.

@Guppy16
Copy link
Author

Guppy16 commented Aug 19, 2024

Perhaps the bugfixes also need to be applied in ParlerTTSForCausalLM? (I haven't touched this class so I'm not sure about it's intended use)

@Guppy16
Copy link
Author

Guppy16 commented Aug 21, 2024

@ylacombe Would it be possible for you to review this?

@ylacombe
Copy link
Collaborator

ylacombe commented Sep 2, 2024

Hey @Guppy16, thanks for opening this ! I'll take a look in the coming days!

@ylacombe
Copy link
Collaborator

ylacombe commented Sep 5, 2024

(Bugfix) In ParlerTTSForConditionalGeneration::generate(), the delay patter mask is built and applied to input_ids before calling _sample(). However, we do not want to apply the mask until we're inside the _sample() function. The above bug doesn't affect the current inference setup because the input_ids returned happens to be the same as what's passed in.

So it's actually a redundant operation that changes nothing, right ? Just want to make sure it's not a bug. When I experiment, it seems that it doesn't change anything

Copy link
Collaborator

@ylacombe ylacombe left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hey @Guppy16, so even though the modification you're proposing makes sense, I'm not sure that your other additions should be merged:

  1. the logit wrapper modif is an error
  2. I'd rather have an issue opened that explain how to do your voice enrolment thing than a notebook ! Would you like to write a guide about this and add it to the issues ? I can ping it afterwards

@@ -3442,6 +3443,7 @@ def generate(
generation_config=generation_config,
synced_gpus=synced_gpus,
streamer=streamer,
logits_warper=None,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You should keep the logits_warper, I'm not sure why you removed it!

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I didn't remove it! Originally, logits_warper wasn't being passed in, so this part of the code was failing. I believe when doing greedy search, logits_warper=None should be set. Please could you double check this!

Comment on lines +3390 to +3391
# but don't overwrite the input_ids tensor with the delay pattern mask. We perform that later
_, decoder_delay_pattern_mask = self.decoder.build_delay_pattern_mask(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As pointed out, this is a redundant operation that has no impact on the results!

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmm, I think this line does indeed change the results when using enrolled tokens. Perhaps your setup is working because it is slightly different as you've described below. I shall try this and get back to you

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok, so my testing shows that this fix is required to get the right audio when doing the enrolment. Here is an example audio file generated with and without the fix:
audio.zip

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've been able to test this with the following code, which also requires a small modification of the DAC code (adding main_input_name = "input_values" as a class attribute of DACModel) :

import torch
from parler_tts import ParlerTTSForConditionalGeneration
from transformers import AutoTokenizer, set_seed, AutoFeatureExtractor
import soundfile as sf
import torchaudio

device = "cuda" if torch.cuda.is_available() else "cpu"

model_id = "parler-tts/parler-tts-mini-v1"

model = ParlerTTSForConditionalGeneration.from_pretrained(model_id).to(device)
tokenizer = AutoTokenizer.from_pretrained(model_id)
feature_extractor = AutoFeatureExtractor.from_pretrained(model_id)

init_audio, init_sr = torchaudio.load(PATH_TO_SPECIFY)
init_audio = torchaudio.functional.resample(init_audio, init_sr, model.config.sampling_rate)
init_audio = init_audio.mean(0)
init_prompt = "Here, write the transcript of the init audio"

prompt = "Is it really working ?"
description = "A man speaker speaks quickly with a low-pitched voice. The recording is of very high quality, with the speaker's voice sounding clear and very close up." # TODO: adapt the prompt to describe the input audio

input_ids = tokenizer(description, return_tensors="pt").input_ids.to(device)
prompt_input_ids = tokenizer(init_prompt + " " + prompt, return_tensors="pt").input_ids.to(device)
input_values = feature_extractor(init_audio, sampling_rate=model.config.sampling_rate, return_tensors="pt").input_values.to(device)
set_seed(2)
generation = model.generate(input_ids=input_ids, prompt_input_ids=prompt_input_ids, input_values = input_values)
audio_arr = generation.cpu().numpy().squeeze()
sf.write("parler_tts_out.wav", audio_arr, model.config.sampling_rate)

I found that Parler has difficulty generalizing to unseen speakers (meaning using a speaker that has not been seen during training or that has not been generated by Parler), so there's no actual edge of using it for voice cloning. However, from my experiment, it's working quite well with Parler generation!

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hey @ylacombe , I tried the above code sample with both the mini and large model but the audio file generated is noisy and inconsistent. I've used the input audio generated through ParlerTTS itself.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a clean snippet! When calling model.generate(...) Is there a preference for using input_values=input_values? I was originally doing something along the lines of decoder_input_ids=input_values.squeeze().long().

apresence pushed a commit to apresence/parler-tts that referenced this pull request Sep 24, 2024
Credits:

1. ylacombe
- Add input_values to DACModel
- dac_wrapper/modeling_dac.py
- huggingface#110 (comment)

2. stg2015
- Delay mask adjustment for input_values
- modeling_parler_tts.py
- huggingface#81 (comment)
@Guppy16
Copy link
Author

Guppy16 commented Sep 25, 2024

Thanks a lot for reviewing this, as well as your great suggestions! I'll work on this in the coming few days.

@Guppy16
Copy link
Author

Guppy16 commented Sep 27, 2024

  1. I'd rather have an issue opened that explain how to do your voice enrolment thing than a notebook ! Would you like to write a guide about this and add it to the issues ? I can ping it afterwards

Looks like @apresence has made a start on this! I've added a modified version of ur snippet there (#139)

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants