Skip to content

Commit

Permalink
Prep for Voice Steering feature
Browse files Browse the repository at this point in the history
Credits:

1. ylacombe
- Add input_values to DACModel
- dac_wrapper/modeling_dac.py
- huggingface#110 (comment)

2. stg2015
- Delay mask adjustment for input_values
- modeling_parler_tts.py
- huggingface#81 (comment)
  • Loading branch information
apresence committed Sep 24, 2024
1 parent dcaed95 commit d3c7fdc
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 7 deletions.
3 changes: 3 additions & 0 deletions parler_tts/dac_wrapper/modeling_dac.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,9 @@
class DACModel(PreTrainedModel):
config_class = DACConfig

# Set main input to 'input_values' for voice steering
main_input_name = "input_values"

def __init__(self, config):
super().__init__(config)

Expand Down
20 changes: 13 additions & 7 deletions parler_tts/modeling_parler_tts.py
Original file line number Diff line number Diff line change
Expand Up @@ -3483,13 +3483,19 @@ def generate(
# Apply the pattern mask to the final ids
output_ids = self.decoder.apply_delay_pattern_mask(output_ids, model_kwargs["decoder_delay_pattern_mask"])

# Revert the pattern delay mask by filtering the eos and bos token ids from the delay pattern mask
_, mask = self.decoder.build_delay_pattern_mask(
input_ids,
bos_token_id=generation_config._bos_token_tensor,
pad_token_id=generation_config._pad_token_tensor,
max_length=output_ids.shape[1],
)
if "input_values" in model_kwargs:
# Handle input_values for voice steering
mask = (output_ids != generation_config.bos_token_id) & (output_ids != generation_config.pad_token_id)
else:
# Revert the pattern delay mask by filtering the eos and bos token ids from the delay pattern mask
_, mask = self.decoder.build_delay_pattern_mask(
input_ids,
bos_token_id=generation_config.bos_token_id,
pad_token_id=generation_config.pad_token_id,
max_length=output_ids.shape[1],
)
mask = (mask != generation_config.bos_token_id) & (mask != generation_config.pad_token_id)


mask = (mask != generation_config.bos_token_id) & (mask != generation_config.pad_token_id)
output_ids = output_ids[mask].reshape(batch_size, self.decoder.num_codebooks, -1)
Expand Down

0 comments on commit d3c7fdc

Please sign in to comment.