Skip to content

Commit

Permalink
implement power low compression loss
Browse files Browse the repository at this point in the history
  • Loading branch information
stegben committed Sep 1, 2019
1 parent 36b59da commit 0acfd8e
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 3 deletions.
1 change: 1 addition & 0 deletions config/default.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ train:
final: 0.05
summary_interval: 1
checkpoint_interval: 1000
complex_loss_ratio: 0.1 # the lambda in power-law compression loss computation
---
log:
chkpt_dir: 'chkpt'
Expand Down
13 changes: 10 additions & 3 deletions utils/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,9 +64,16 @@ def train(args, pt_dir, chkpt_path, trainloader, testloader, writer, logger, hp,
mask = model(mixed_mag, dvec)
output = mixed_mag * mask

# output = torch.pow(torch.clamp(output, min=0.0), hp.audio.power)
# target_mag = torch.pow(torch.clamp(target_mag, min=0.0), hp.audio.power)
loss = criterion(output, target_mag)
# Power-law compression
magnitude_loss = criterion(
torch.pow(torch.abs(output), hp.audio.power),
torch.pow(torch.abs(target_mag), hp.audio.power),
)
complex_loss = criterion(
torch.pow(torch.clamp(output, min=0.0), hp.audio.power),
torch.pow(torch.clamp(target_mag, min=0.0), hp.audio.power),
)
loss = magnitude_loss + complex_loss * hp.train.complex_loss_ratio

optimizer.zero_grad()
loss.backward()
Expand Down

0 comments on commit 0acfd8e

Please sign in to comment.