Poor Audio Quality with input_values Input in Parler_TTS #81

Open
LiuZH-19 opened this issue Jul 2, 2024 · 3 comments

Comments

@LiuZH-19
LiuZH-19 commented Jul 2, 2024

I am using the Parler_TTS model with a reference audio (input_values) during inference, similar to MusicGen, to perform continuation tasks.

model.generate(input_ids=input_ids, prompt_input_ids=prompt_input_ids, input_values=input_values)

While the model continues in the style of the reference audio, the resulting audio quality is poor and contains a lot of noise.
Why does the audio quality degrade when using a reference audio input, and how can this be improved?

Thank you!

@ylacombe
Collaborator

ylacombe commented Aug 1, 2024

We should remove input_values, as it's not used in this model. It's an artifact left over from when I based the design on the MusicGen architecture.

@stg1205
stg1205 commented Aug 14, 2024

Change the code after this comment: "# revert the pattern delay mask by filtering the eos and bos token ids from the delay pattern mask"
to

if "input_values" in model_kwargs:
    # Reference audio was passed: build the mask directly from the generated ids.
    mask = (output_ids != generation_config.bos_token_id) & (output_ids != generation_config.pad_token_id)
else:
    # No reference audio: rebuild the delay pattern mask from input_ids.
    _, mask = self.decoder.build_delay_pattern_mask(
        input_ids,
        bos_token_id=generation_config.bos_token_id,
        pad_token_id=generation_config.pad_token_id,
        max_length=output_ids.shape[1],
    )
    mask = (mask != generation_config.bos_token_id) & (mask != generation_config.pad_token_id)

I haven't looked into the details, but for now it works. I found this bug by comparing output_ids with the original input_ids encoded by DAC: some of the delays in output_ids were wrong.
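
For readers unfamiliar with the delay pattern, here is a minimal toy sketch of the idea in plain Python, not the actual Parler-TTS implementation. It assumes codebook k is shifted right by k steps, with BOS filling the front and PAD filling the tail. The token ids and both helper functions are hypothetical and chosen only for illustration. Reverting the pattern amounts to dropping every BOS/PAD position, which is what the mask in the first branch above does on output_ids.

```python
# Hypothetical special token ids, for illustration only.
BOS_ID = 1025
PAD_ID = 1024


def apply_delay_pattern(codes, bos_id=BOS_ID, pad_id=PAD_ID):
    """Shift codebook k right by k steps (simplified assumption).

    `codes` is a list of rows, one per codebook, all of the same length.
    The front of each row is filled with BOS, the tail with PAD.
    """
    num_codebooks = len(codes)
    return [
        [bos_id] * k + list(row) + [pad_id] * (num_codebooks - 1 - k)
        for k, row in enumerate(codes)
    ]


def revert_delay_pattern(delayed, bos_id=BOS_ID, pad_id=PAD_ID):
    """Undo the delay by keeping only positions that are neither BOS nor PAD."""
    return [[tok for tok in row if tok != bos_id and tok != pad_id] for row in delayed]


# Three codebooks, three frames each.
codes = [[10, 11, 12], [20, 21, 22], [30, 31, 32]]
delayed = apply_delay_pattern(codes)
restored = revert_delay_pattern(delayed)

# Comparing the restored codes with the originals, codebook by codebook,
# is the same kind of check used to spot the wrong delays described above.
mismatched_codebooks = [k for k, (a, b) in enumerate(zip(codes, restored)) if a != b]
```

If the mask used for filtering does not line up with the delays actually present in the generated ids, the restored rows end up misaligned across codebooks, which would plausibly explain noisy audio after decoding.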

@Guppy16
Guppy16 commented Aug 17, 2024

#110

Does this help? I have an example notebook that does continuation as well. You need to use the decode_input_ids argument, and also fix a bug similar to the one @stg1205 showed above.

apresence pushed a commit to apresence/parler-tts that referenced this issue Sep 24, 2024
Credits:

1. ylacombe
- Add input_values to DACModel
- dac_wrapper/modeling_dac.py
- huggingface#110 (comment)

2. stg2015
- Delay mask adjustment for input_values
- modeling_parler_tts.py
- huggingface#81 (comment)