diff --git a/dreambooth/train_dreambooth.py b/dreambooth/train_dreambooth.py index dd31af8f..23c3e1cd 100644 --- a/dreambooth/train_dreambooth.py +++ b/dreambooth/train_dreambooth.py @@ -1829,10 +1829,10 @@ def loss_fn(model_pred, target): pixel_target = vae.decode(latents[pixel_mask].to(weight_dtype)).sample pixel_pred = vae.decode(pred_latents[pixel_mask].to(weight_dtype)).sample - pixel_loss[pixel_mask] = F.mse_loss(pixel_pred.float(), pixel_target.float(), reduction="none").mean(dim=(1,2,3)) + pixel_loss[pixel_mask] = F.l1_loss(pixel_pred.float(), pixel_target.float(), reduction="none").mean(dim=(1,2,3)) # Ad hoc weight factor to make losses have approximately the same scale - pixel_loss = 700 * pixel_loss + pixel_loss = 25 * pixel_loss else: pixel_loss = latent_loss.new_zeros([])