Skip to content

Commit

Permalink
Use L1 pixel loss
Browse files Browse the repository at this point in the history
  • Loading branch information
RossM committed May 20, 2024
1 parent 72213a8 commit abbc723
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions dreambooth/train_dreambooth.py
Original file line number Diff line number Diff line change
Expand Up @@ -1829,10 +1829,10 @@ def loss_fn(model_pred, target):
pixel_target = vae.decode(latents[pixel_mask].to(weight_dtype)).sample
pixel_pred = vae.decode(pred_latents[pixel_mask].to(weight_dtype)).sample

pixel_loss[pixel_mask] = F.mse_loss(pixel_pred.float(), pixel_target.float(), reduction="none").mean(dim=(1,2,3))
pixel_loss[pixel_mask] = F.l1_loss(pixel_pred.float(), pixel_target.float(), reduction="none").mean(dim=(1,2,3))

# Ad hoc weight factor to make losses have approximately the same scale
pixel_loss = 700 * pixel_loss
pixel_loss = 25 * pixel_loss

else:
pixel_loss = latent_loss.new_zeros([])
Expand Down

0 comments on commit abbc723

Please sign in to comment.