diff --git a/scripts/trainer.py b/scripts/trainer.py
index 92f1b91..aab3146 100644
--- a/scripts/trainer.py
+++ b/scripts/trainer.py
@@ -1108,16 +1108,15 @@ def save_and_sample_weights(step,context='checkpoint',save_model=True):
         if args.stop_text_encoder_training == True:
             save_dir = frozen_directory
-        if step != 0:
-            if save_model:
-                pipeline.save_pretrained(save_dir,safe_serialization=True)
-                with open(os.path.join(save_dir, "args.json"), "w") as f:
-                    json.dump(args.__dict__, f, indent=2)
-            if args.stop_text_encoder_training == True:
-                #delete every folder in frozen_directory but the text encoder
-                for folder in os.listdir(save_dir):
-                    if folder != "text_encoder" and os.path.isdir(os.path.join(save_dir, folder)):
-                        shutil.rmtree(os.path.join(save_dir, folder))
+        if save_model:
+            pipeline.save_pretrained(save_dir,safe_serialization=True)
+            with open(os.path.join(save_dir, "args.json"), "w") as f:
+                json.dump(args.__dict__, f, indent=2)
+        if args.stop_text_encoder_training == True:
+            #delete every folder in frozen_directory but the text encoder
+            for folder in os.listdir(save_dir):
+                if folder != "text_encoder" and os.path.isdir(os.path.join(save_dir, folder)):
+                    shutil.rmtree(os.path.join(save_dir, folder))
         imgs = []
         if args.add_sample_prompt is not None or batch_prompts != [] and args.stop_text_encoder_training != True:
             prompts = []
@@ -1511,19 +1510,20 @@ def help(event=None):
                     progress_bar_e.refresh()
                     global_step += 1
+                    if mid_quit_step==True:
+                        accelerator.wait_for_everyone()
+                        save_and_sample_weights(global_step,'quit_step')
+                        quit()
                     if mid_generation==True:
                         mid_train_playground(global_step)
                         mid_generation=False
                     if mid_checkpoint_step == True:
                         save_and_sample_weights(global_step,'step',save_model=True)
                         mid_checkpoint_step=False
-                    if mid_sample_step == True:
+                        mid_sample_step=False
+                    elif mid_sample_step == True:
                         save_and_sample_weights(global_step,'step',save_model=False)
                         mid_sample_step=False
-                    if mid_quit_step==True:
-                        accelerator.wait_for_everyone()
-                        save_and_sample_weights(global_step,'quit_step')
-                        quit()
                     if global_step >= args.max_train_steps:
                         break
             progress_bar_e.update(1)
@@ -1531,22 +1531,17 @@ def help(event=None):
                 accelerator.wait_for_everyone()
                 save_and_sample_weights(epoch,'quit_epoch')
                 quit()
-            if not epoch % args.save_every_n_epoch:
-                if args.save_every_n_epoch == 1 and epoch == 0:
-                    save_and_sample_weights(epoch,'epoch')
-                if epoch != 0:
-                    save_and_sample_weights(epoch,'epoch')
-            else:
-                pass
-                #save_and_sample_weights(epoch,'epoch',False)
-                print_instructions()
-            if epoch % args.save_every_n_epoch and mid_checkpoint==True or mid_sample==True:
-                if mid_checkpoint==True:
-                    save_and_sample_weights(epoch,'epoch',True)
-                    mid_checkpoint=False
-                elif mid_sample==True:
-                    save_and_sample_weights(epoch,'epoch',False)
-                    mid_sample=False
+            if epoch == args.num_train_epochs - 1:
+                save_and_sample_weights(epoch,'epoch',True)
+            elif args.save_every_n_epoch and (epoch + 1) % args.save_every_n_epoch == 0:
+                save_and_sample_weights(epoch,'epoch',True)
+            elif mid_checkpoint==True:
+                save_and_sample_weights(epoch,'epoch',True)
+                mid_checkpoint=False
+                mid_sample=False
+            elif mid_sample==True:
+                save_and_sample_weights(epoch,'epoch',False)
+                mid_sample=False
         accelerator.wait_for_everyone()
     except Exception:
         try:
@@ -1558,7 +1553,6 @@ def help(event=None):
         raise
     except KeyboardInterrupt:
         send_telegram_message("Training stopped", args.telegram_chat_id, args.telegram_token)
-        save_and_sample_weights(args.num_train_epochs,'epoch')
     try:
         send_telegram_message("Training finished!", args.telegram_chat_id, args.telegram_token)
     except: