Skip to content

Commit

Permalink
Fix MinSNR. Fix save samples button.
Browse files Browse the repository at this point in the history
  • Loading branch information
RossM committed Jun 7, 2024
1 parent 974c4be commit 4c291b6
Showing 1 changed file with 4 additions and 1 deletion.
5 changes: 4 additions & 1 deletion dreambooth/train_dreambooth.py
Original file line number Diff line number Diff line change
Expand Up @@ -1602,6 +1602,9 @@ def lora_save_function(weights, filename):
# Compute Soft-MinSNR loss weight
loss_weight = (snr ** -1 + args.min_snr_gamma ** -1) ** -1

if noise_scheduler.config.prediction_type == "epsilon":
loss_weight.div_(snr)

# Compute cumulative loss weight scaled to (0, 1)
cum_loss_weight = torch.cumsum(loss_weight, 0)
cum_loss_weight = cum_loss_weight / cum_loss_weight[-1]
Expand Down Expand Up @@ -2064,7 +2067,7 @@ def loss_fn(model_pred, target):
status_handler.end(status.textinfo)
break

if status.do_save_model or (global_step > 0 and global_step % 10000 == 0):
if status.do_save_model or status.do_save_samples or (global_step > 0 and global_step % 10000 == 0):
check_save(False)

accelerator.wait_for_everyone()
Expand Down

0 comments on commit 4c291b6

Please sign in to comment.