diff --git a/dreambooth/train_dreambooth.py b/dreambooth/train_dreambooth.py index aae613a5..18d56971 100644 --- a/dreambooth/train_dreambooth.py +++ b/dreambooth/train_dreambooth.py @@ -1602,6 +1602,9 @@ def lora_save_function(weights, filename): # Compute Soft-MinSNR loss weight loss_weight = (snr ** -1 + args.min_snr_gamma ** -1) ** -1 + if noise_scheduler.config.prediction_type == "epsilon": + loss_weight.div_(snr) + # Compute cumulative loss weight scaled to (0, 1) cum_loss_weight = torch.cumsum(loss_weight, 0) cum_loss_weight = cum_loss_weight / cum_loss_weight[-1] @@ -2064,7 +2067,7 @@ def loss_fn(model_pred, target): status_handler.end(status.textinfo) break - if status.do_save_model or (global_step > 0 and global_step % 10000 == 0): + if status.do_save_model or status.do_save_samples or (global_step > 0 and global_step % 10000 == 0): check_save(False) accelerator.wait_for_everyone()