Skip to content

Commit

Permalink
Fix L1 loss weighting
Browse files Browse the repository at this point in the history
  • Loading branch information
RossM committed Jun 11, 2024
1 parent 4c291b6 commit 67eaa58
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion dreambooth/train_dreambooth.py
Original file line number Diff line number Diff line change
Expand Up @@ -1833,7 +1833,7 @@ def loss_fn(model_pred, target):
latent_loss = F.mse_loss(model_pred.float(), target.float(), reduction="none").mean(dim=(1,2,3))

if args.l1_weight != 0:
l1_weight = args.l1_weight * (1 - (1 - alpha_prod) ** args.l1_gamma)
l1_weight = args.l1_weight * (1 - (1 - alpha_prod) ** args.l1_gamma) * snr ** 0.5
latent_loss = latent_loss + l1_weight * F.l1_loss(model_pred.float(), target.float(), reduction="none").mean(dim=(1,2,3))

loss = latent_loss
Expand Down

0 comments on commit 67eaa58

Please sign in to comment.