Skip to content

Commit

Permalink
making MSD and MPD optional
Browse files Browse the repository at this point in the history
  • Loading branch information
JinZr committed Oct 6, 2024
1 parent f9340cc commit e788bb4
Show file tree
Hide file tree
Showing 2 changed files with 115 additions and 93 deletions.
100 changes: 65 additions & 35 deletions egs/libritts/CODEC/encodec/encodec.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import math
import random
from typing import List
from typing import List, Optional

import numpy as np
import torch
Expand All @@ -25,8 +25,8 @@ def __init__(
quantizer: nn.Module,
decoder: nn.Module,
multi_scale_discriminator: nn.Module,
multi_period_discriminator: nn.Module,
multi_scale_stft_discriminator: nn.Module,
multi_period_discriminator: Optional[nn.Module] = None,
multi_scale_stft_discriminator: Optional[nn.Module] = None,
cache_generator_outputs: bool = False,
):
super(Encodec, self).__init__()
Expand Down Expand Up @@ -113,28 +113,42 @@ def _forward_generator(
with torch.no_grad():
# do not store discriminator gradient in generator turn
y, fmap = self.multi_scale_stft_discriminator(speech.contiguous())
y_p, y_p_hat, fmap_p, fmap_p_hat = self.multi_period_discriminator(
speech.contiguous(),
speech_hat.contiguous(),
)
y_s, y_s_hat, fmap_s, fmap_s_hat = self.multi_scale_discriminator(
speech.contiguous(),
speech_hat.contiguous(),
)

gen_period_adv_loss = torch.tensor(0.0)
feature_period_loss = torch.tensor(0.0)
if self.multi_period_discriminator is not None:
y_p, y_p_hat, fmap_p, fmap_p_hat = self.multi_period_discriminator(
speech.contiguous(),
speech_hat.contiguous(),
)

gen_scale_adv_loss = torch.tensor(0.0)
feature_scale_loss = torch.tensor(0.0)
if self.multi_scale_discriminator is not None:
y_s, y_s_hat, fmap_s, fmap_s_hat = self.multi_scale_discriminator(
speech.contiguous(),
speech_hat.contiguous(),
)

# calculate losses
with autocast(enabled=False):
gen_stft_adv_loss = self.generator_adversarial_loss(outputs=y_hat)
gen_period_adv_loss = self.generator_adversarial_loss(outputs=y_p_hat)
gen_scale_adv_loss = self.generator_adversarial_loss(outputs=y_s_hat)

if self.multi_period_discriminator is not None:
gen_period_adv_loss = self.generator_adversarial_loss(outputs=y_p_hat)
if self.multi_scale_discriminator is not None:
gen_scale_adv_loss = self.generator_adversarial_loss(outputs=y_s_hat)

feature_stft_loss = self.feature_match_loss(feats=fmap, feats_hat=fmap_hat)
feature_period_loss = self.feature_match_loss(
feats=fmap_p, feats_hat=fmap_p_hat
)
feature_scale_loss = self.feature_match_loss(
feats=fmap_s, feats_hat=fmap_s_hat
)

if self.multi_period_discriminator is not None:
feature_period_loss = self.feature_match_loss(
feats=fmap_p, feats_hat=fmap_p_hat
)
if self.multi_scale_discriminator is not None:
feature_scale_loss = self.feature_match_loss(
feats=fmap_s, feats_hat=fmap_s_hat
)

wav_reconstruction_loss = self.wav_reconstruction_loss(
x=speech, x_hat=speech_hat
Expand Down Expand Up @@ -245,28 +259,44 @@ def _forward_discriminator(
y_hat, fmap_hat = self.multi_scale_stft_discriminator(
speech_hat.contiguous().detach()
)
y_p, y_p_hat, fmap_p, fmap_p_hat = self.multi_period_discriminator(
speech.contiguous(),
speech_hat.contiguous().detach(),
)
y_s, y_s_hat, fmap_s, fmap_s_hat = self.multi_scale_discriminator(
speech.contiguous(),
speech_hat.contiguous().detach(),
)

disc_period_real_adv_loss, disc_period_fake_adv_loss = torch.tensor(
0.0
), torch.tensor(0.0)
if self.multi_period_discriminator is not None:
y_p, y_p_hat, fmap_p, fmap_p_hat = self.multi_period_discriminator(
speech.contiguous(),
speech_hat.contiguous().detach(),
)

disc_scale_real_adv_loss, disc_scale_fake_adv_loss = torch.tensor(
0.0
), torch.tensor(0.0)
if self.multi_scale_discriminator is not None:
y_s, y_s_hat, fmap_s, fmap_s_hat = self.multi_scale_discriminator(
speech.contiguous(),
speech_hat.contiguous().detach(),
)
# calculate losses
with autocast(enabled=False):
(
disc_stft_real_adv_loss,
disc_stft_fake_adv_loss,
) = self.discriminator_adversarial_loss(outputs=y, outputs_hat=y_hat)
(
disc_period_real_adv_loss,
disc_period_fake_adv_loss,
) = self.discriminator_adversarial_loss(outputs=y_p, outputs_hat=y_p_hat)
(
disc_scale_real_adv_loss,
disc_scale_fake_adv_loss,
) = self.discriminator_adversarial_loss(outputs=y_s, outputs_hat=y_s_hat)
if self.multi_period_discriminator is not None:
(
disc_period_real_adv_loss,
disc_period_fake_adv_loss,
) = self.discriminator_adversarial_loss(
outputs=y_p, outputs_hat=y_p_hat
)
if self.multi_scale_discriminator is not None:
(
disc_scale_real_adv_loss,
disc_scale_fake_adv_loss,
) = self.discriminator_adversarial_loss(
outputs=y_s, outputs_hat=y_s_hat
)

stats = dict(
discriminator_stft_real_adv_loss=disc_stft_real_adv_loss.item(),
Expand Down
108 changes: 50 additions & 58 deletions egs/libritts/CODEC/encodec/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -313,8 +313,8 @@ def get_model(params: AttributeDict) -> nn.Module:
encoder=encoder,
quantizer=quantizer,
decoder=decoder,
multi_scale_discriminator=MultiScaleDiscriminator(),
multi_period_discriminator=MultiPeriodDiscriminator(),
multi_scale_discriminator=None,
multi_period_discriminator=None,
multi_scale_stft_discriminator=MultiScaleSTFTDiscriminator(
n_filters=params.stft_discriminator_n_filters
),
Expand Down Expand Up @@ -456,17 +456,13 @@ def save_bad_model(suffix: str = ""):
forward_generator=False,
)
disc_loss = (
(
disc_stft_real_adv_loss
+ disc_stft_fake_adv_loss
+ disc_period_real_adv_loss
+ disc_period_fake_adv_loss
+ disc_scale_real_adv_loss
+ disc_scale_fake_adv_loss
)
* d_weight
/ 3
)
disc_stft_real_adv_loss
+ disc_stft_fake_adv_loss
+ disc_period_real_adv_loss
+ disc_period_fake_adv_loss
+ disc_scale_real_adv_loss
+ disc_scale_fake_adv_loss
) * d_weight
for k, v in stats_d.items():
loss_info[k] = v * batch_size
# update discriminator
Expand Down Expand Up @@ -499,13 +495,11 @@ def save_bad_model(suffix: str = ""):
return_sample=params.batch_idx_train % params.log_interval == 0,
)
gen_adv_loss = (
(gen_stft_adv_loss + gen_period_adv_loss + gen_scale_adv_loss)
* g_weight
/ 3
)
gen_stft_adv_loss + gen_period_adv_loss + gen_scale_adv_loss
) * g_weight
feature_loss = (
feature_stft_loss + feature_period_loss + feature_scale_loss
) / 3
)
reconstruction_loss = (
params.lambda_wav * wav_reconstruction_loss
+ mel_reconstruction_loss
Expand Down Expand Up @@ -714,17 +708,13 @@ def compute_validation_loss(
forward_generator=False,
)
disc_loss = (
(
disc_stft_real_adv_loss
+ disc_stft_fake_adv_loss
+ disc_period_real_adv_loss
+ disc_period_fake_adv_loss
+ disc_scale_real_adv_loss
+ disc_scale_fake_adv_loss
)
* d_weight
/ 3
)
disc_stft_real_adv_loss
+ disc_stft_fake_adv_loss
+ disc_period_real_adv_loss
+ disc_period_fake_adv_loss
+ disc_scale_real_adv_loss
+ disc_scale_fake_adv_loss
) * d_weight
assert disc_loss.requires_grad is False
for k, v in stats_d.items():
loss_info[k] = v * batch_size
Expand Down Expand Up @@ -753,13 +743,9 @@ def compute_validation_loss(
return_sample=False,
)
gen_adv_loss = (
(gen_stft_adv_loss + gen_period_adv_loss + gen_scale_adv_loss)
* g_weight
/ 3
)
feature_loss = (
feature_stft_loss + feature_period_loss + feature_scale_loss
) / 3
gen_stft_adv_loss + gen_period_adv_loss + gen_scale_adv_loss
) * g_weight
feature_loss = feature_stft_loss + feature_period_loss + feature_scale_loss
reconstruction_loss = (
params.lambda_wav * wav_reconstruction_loss + mel_reconstruction_loss
)
Expand Down Expand Up @@ -836,20 +822,16 @@ def scan_pessimistic_batches_for_oom(
forward_generator=False,
)
loss_d = (
(
disc_stft_real_adv_loss
+ disc_stft_fake_adv_loss
+ disc_period_real_adv_loss
+ disc_period_fake_adv_loss
+ disc_scale_real_adv_loss
+ disc_scale_fake_adv_loss
)
* adopt_weight(
params.lambda_adv,
params.batch_idx_train,
threshold=params.discriminator_iter_start,
)
/ 3
disc_stft_real_adv_loss
+ disc_stft_fake_adv_loss
+ disc_period_real_adv_loss
+ disc_period_fake_adv_loss
+ disc_scale_real_adv_loss
+ disc_scale_fake_adv_loss
) * adopt_weight(
params.lambda_adv,
params.batch_idx_train,
threshold=params.discriminator_iter_start,
)
optimizer_d.zero_grad()
loss_d.backward()
Expand Down Expand Up @@ -879,7 +861,6 @@ def scan_pessimistic_batches_for_oom(
params.batch_idx_train,
threshold=params.discriminator_iter_start,
)
/ 3
+ params.lambda_rec
* (
params.lambda_wav * wav_reconstruction_loss
Expand Down Expand Up @@ -962,9 +943,17 @@ def run(rank, world_size, args):
logging.info(f"Number of parameters in decoder: {num_param_d}")
num_param_q = sum([p.numel() for p in quantizer.parameters()])
logging.info(f"Number of parameters in quantizer: {num_param_q}")
num_param_ds = sum([p.numel() for p in multi_scale_discriminator.parameters()])
num_param_ds = (
sum([p.numel() for p in multi_scale_discriminator.parameters()])
if multi_scale_discriminator is not None
else 0
)
logging.info(f"Number of parameters in multi_scale_discriminator: {num_param_ds}")
num_param_dp = sum([p.numel() for p in multi_period_discriminator.parameters()])
num_param_dp = (
sum([p.numel() for p in multi_period_discriminator.parameters()])
if multi_period_discriminator is not None
else 0
)
logging.info(f"Number of parameters in multi_period_discriminator: {num_param_dp}")
num_param_dstft = sum(
[p.numel() for p in multi_scale_stft_discriminator.parameters()]
Expand Down Expand Up @@ -998,12 +987,15 @@ def run(rank, world_size, args):
lr=params.lr,
betas=(0.5, 0.9),
)
discriminator_params = [
multi_scale_stft_discriminator.parameters(),
]
if multi_scale_discriminator is not None:
discriminator_params.append(multi_scale_discriminator.parameters())
if multi_period_discriminator is not None:
discriminator_params.append(multi_period_discriminator.parameters())
optimizer_d = torch.optim.AdamW(
itertools.chain(
multi_scale_stft_discriminator.parameters(),
multi_scale_discriminator.parameters(),
multi_period_discriminator.parameters(),
),
itertools.chain(*discriminator_params),
lr=params.lr,
betas=(0.5, 0.9),
)
Expand Down

0 comments on commit e788bb4

Please sign in to comment.