diff --git a/mel2wav/modules.py b/mel2wav/modules.py index da39c5d..7f7d846 100644 --- a/mel2wav/modules.py +++ b/mel2wav/modules.py @@ -40,7 +40,7 @@ def __init__( ############################################## window = torch.hann_window(win_length).float() mel_basis = librosa_mel_fn( - sampling_rate, n_fft, n_mel_channels, mel_fmin, mel_fmax + sr=sampling_rate, n_fft=n_fft, n_mels=n_mel_channels, fmin=mel_fmin, fmax=mel_fmax ) mel_basis = torch.from_numpy(mel_basis).float() self.register_buffer("mel_basis", mel_basis) @@ -61,8 +61,10 @@ def forward(self, audio): win_length=self.win_length, window=self.window, center=False, + return_complex=False ) - real_part, imag_part = fft.unbind(-1) + + real_part, imag_part = fft[:, :, :, 0], fft[:, :, :, 1] magnitude = torch.sqrt(real_part ** 2 + imag_part ** 2) mel_output = torch.matmul(self.mel_basis, magnitude) log_mel_spec = torch.log10(torch.clamp(mel_output, min=1e-5))