Skip to content

Commit

Permalink
Pixel loss improvements
Browse files Browse the repository at this point in the history
  • Loading branch information
RossM committed May 9, 2024
1 parent 467df26 commit 7247f1a
Showing 1 changed file with 21 additions and 6 deletions.
27 changes: 21 additions & 6 deletions dreambooth/train_dreambooth.py
Original file line number Diff line number Diff line change
Expand Up @@ -471,6 +471,9 @@ def create_vae():
torch_dtype=torch.float32,
)

cum_latent_loss = 0
cum_pixel_loss = 0

printm("Created tenc")
pbar2.set_description("Loading VAE...")
pbar2.update()
Expand Down Expand Up @@ -1797,13 +1800,13 @@ def lora_save_function(weights, filename):
model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample

def loss_fn(model_pred, target):
alpha_prod = noise_scheduler.alphas_cumprod.to(timesteps.device)[timesteps]
sqrt_alpha_prod = alpha_prod ** 0.5
sqrt_one_minus_alpha_prod = (1 - alpha_prod) ** 0.5
snr = alpha_prod / (1 - alpha_prod)

latent_loss = F.mse_loss(model_pred.float(), target.float(), reduction="none").mean(dim=(1,2,3))
if args.pixel_loss_weight != 0:
alpha_prod = noise_scheduler.alphas_cumprod.to(timesteps.device)[
timesteps, None, None, None]
sqrt_alpha_prod = alpha_prod ** 0.5
sqrt_one_minus_alpha_prod = (1 - alpha_prod) ** 0.5

if noise_scheduler.config.prediction_type == "epsilon":
pred_latents = (noisy_latents - sqrt_one_minus_alpha_prod * model_pred) / sqrt_alpha_prod
elif noise_scheduler.config.prediction_type == "v_prediction":
Expand All @@ -1816,10 +1819,22 @@ def loss_fn(model_pred, target):
pixel_pred = vae.decode(pred_latents.to(weight_dtype)).sample

pixel_loss = F.mse_loss(pixel_pred.float(), pixel_target.float(), reduction="none").mean(dim=(1,2,3))

# Ad hoc weight factor to make losses have approximately the same scale
pixel_loss = 2000 * pixel_loss
else:
pixel_loss = latent_loss.new_zeros([])

nonlocal cum_pixel_loss, cum_latent_loss
cum_pixel_loss += pixel_loss.mean().item()
cum_latent_loss += latent_loss.mean().item()

pixel_loss_weight = args.pixel_loss_weight * (1 - sqrt_one_minus_alpha_prod)

loss = torch.lerp(latent_loss, pixel_loss, pixel_loss_weight).mean()

#progress_bar.write(f"{latent_loss.mean()}, {pixel_loss.mean()}, {loss}, {cum_latent_loss / cum_pixel_loss}")

loss = torch.lerp(latent_loss.mean(), pixel_loss.mean(), args.pixel_loss_weight)
return loss

if args.model_type != "SDXL":
Expand Down

0 comments on commit 7247f1a

Please sign in to comment.