diff --git a/config/default.yaml b/config/default.yaml index c0d63c7..6a5e298 100644 --- a/config/default.yaml +++ b/config/default.yaml @@ -39,6 +39,7 @@ train: final: 0.05 summary_interval: 1 checkpoint_interval: 1000 + complex_loss_ratio: 0.1 # lambda: weight of the complex_loss term added to magnitude_loss in the power-law compressed loss --- log: chkpt_dir: 'chkpt' diff --git a/utils/train.py b/utils/train.py index 1fd8d94..015e65f 100644 --- a/utils/train.py +++ b/utils/train.py @@ -64,9 +64,16 @@ def train(args, pt_dir, chkpt_path, trainloader, testloader, writer, logger, hp, mask = model(mixed_mag, dvec) output = mixed_mag * mask - # output = torch.pow(torch.clamp(output, min=0.0), hp.audio.power) - # target_mag = torch.pow(torch.clamp(target_mag, min=0.0), hp.audio.power) - loss = criterion(output, target_mag) + # Power-law compression + magnitude_loss = criterion( + torch.pow(torch.abs(output), hp.audio.power), + torch.pow(torch.abs(target_mag), hp.audio.power), + ) + complex_loss = criterion( + torch.pow(torch.clamp(output, min=0.0), hp.audio.power), + torch.pow(torch.clamp(target_mag, min=0.0), hp.audio.power), + ) + loss = magnitude_loss + complex_loss * hp.train.complex_loss_ratio optimizer.zero_grad() loss.backward()