This repository has been archived by the owner on Oct 22, 2023. It is now read-only.

Fix "save every n epochs" handling #123

Open · wants to merge 2 commits into base: main
58 changes: 26 additions & 32 deletions scripts/trainer.py
@@ -1108,16 +1108,15 @@ def save_and_sample_weights(step,context='checkpoint',save_model=True):
 
         if args.stop_text_encoder_training == True:
             save_dir = frozen_directory
-        if step != 0:
-            if save_model:
-                pipeline.save_pretrained(save_dir,safe_serialization=True)
-                with open(os.path.join(save_dir, "args.json"), "w") as f:
-                    json.dump(args.__dict__, f, indent=2)
-                if args.stop_text_encoder_training == True:
-                    #delete every folder in frozen_directory but the text encoder
-                    for folder in os.listdir(save_dir):
-                        if folder != "text_encoder" and os.path.isdir(os.path.join(save_dir, folder)):
-                            shutil.rmtree(os.path.join(save_dir, folder))
+        if save_model:
+            pipeline.save_pretrained(save_dir,safe_serialization=True)
+            with open(os.path.join(save_dir, "args.json"), "w") as f:
+                json.dump(args.__dict__, f, indent=2)
+            if args.stop_text_encoder_training == True:
+                #delete every folder in frozen_directory but the text encoder
+                for folder in os.listdir(save_dir):
+                    if folder != "text_encoder" and os.path.isdir(os.path.join(save_dir, folder)):
+                        shutil.rmtree(os.path.join(save_dir, folder))
         imgs = []
         if args.add_sample_prompt is not None or batch_prompts != [] and args.stop_text_encoder_training != True:
             prompts = []
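Reviewer note on this hunk: the removed `if step != 0:` guard made any call to `save_and_sample_weights` with step 0 skip saving entirely, so the first scheduled checkpoint (e.g. an epoch-0 save with `save_every_n_epoch == 1`) was silently dropped. A minimal sketch of the behavior change (hypothetical stand-in functions, not the trainer's API):

# Sketch of the guard removal (illustration only, not project code).
def would_save_old(step, save_model=True):
    return step != 0 and save_model   # old: step/epoch 0 never saved

def would_save_new(step, save_model=True):
    return save_model                 # new: saving depends only on save_model

assert would_save_old(0) is False     # epoch-0 checkpoint was dropped
assert would_save_new(0) is True      # epoch-0 checkpoint is written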
@@ -1511,42 +1510,38 @@ def help(event=None):
                 progress_bar_e.refresh()
                 global_step += 1
 
-                if mid_quit_step==True:
-                    accelerator.wait_for_everyone()
-                    save_and_sample_weights(global_step,'quit_step')
-                    quit()
                 if mid_generation==True:
                     mid_train_playground(global_step)
                     mid_generation=False
                 if mid_checkpoint_step == True:
                     save_and_sample_weights(global_step,'step',save_model=True)
                     mid_checkpoint_step=False
                     if mid_sample_step == True:
                         mid_sample_step=False
                 elif mid_sample_step == True:
                     save_and_sample_weights(global_step,'step',save_model=False)
                     mid_sample_step=False
+                if mid_quit_step==True:
+                    accelerator.wait_for_everyone()
+                    save_and_sample_weights(global_step,'quit_step')
+                    quit()
                 if global_step >= args.max_train_steps:
                     break
             progress_bar_e.update(1)
             if mid_quit==True:
                 accelerator.wait_for_everyone()
                 save_and_sample_weights(epoch,'quit_epoch')
                 quit()
-            if not epoch % args.save_every_n_epoch:
-                if args.save_every_n_epoch == 1 and epoch == 0:
-                    save_and_sample_weights(epoch,'epoch')
-                if epoch != 0:
-                    save_and_sample_weights(epoch,'epoch')
-            else:
-                pass
-                #save_and_sample_weights(epoch,'epoch',False)
             print_instructions()
-            if epoch % args.save_every_n_epoch and mid_checkpoint==True or mid_sample==True:
-                if mid_checkpoint==True:
-                    save_and_sample_weights(epoch,'epoch',True)
-                    mid_checkpoint=False
-                elif mid_sample==True:
-                    save_and_sample_weights(epoch,'epoch',False)
-                    mid_sample=False
+            if epoch == args.num_train_epochs - 1:
+                save_and_sample_weights(epoch,'epoch',True)
+            elif args.save_every_n_epoch and (epoch + 1) % args.save_every_n_epoch == 0:
+                save_and_sample_weights(epoch,'epoch',True)
+            elif mid_checkpoint==True:
+                save_and_sample_weights(epoch,'epoch',True)
+                mid_checkpoint=False
+                mid_sample=False
+            elif mid_sample==True:
+                save_and_sample_weights(epoch,'epoch',False)
+                mid_sample=False
             accelerator.wait_for_everyone()
     except Exception:
         try:
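This hunk is the heart of the fix. The old cadence test `if not epoch % args.save_every_n_epoch:` fires when `epoch % n == 0`, i.e. on 0-based epochs 0, n, 2n, ..., which is off by one (with n == 3 the first save lands only after the fourth epoch) and required the epoch-0 special cases; the new test `(epoch + 1) % args.save_every_n_epoch == 0` fires after every n-th completed epoch, and the final epoch now always saves. Moving the `mid_quit_step` check below the checkpoint/sample handling also lets a pending mid-step save complete before quitting. A standalone comparison of the two cadences (reviewer sketch with hypothetical helper names, assuming 0-based epochs as in the loop above):

# Compare old vs new save cadence (illustration only, not project code).
def saves_old(num_epochs, n):
    # old: "if not epoch % n", with epoch 0 suppressed unless n == 1
    return [e for e in range(num_epochs)
            if e % n == 0 and (e != 0 or n == 1)]

def saves_new(num_epochs, n):
    # new: save after every n-th completed epoch, plus the final epoch
    return [e for e in range(num_epochs)
            if (e + 1) % n == 0 or e == num_epochs - 1]

print(saves_old(10, 3))  # [3, 6, 9]    -> after the 4th, 7th and 10th epoch
print(saves_new(10, 3))  # [2, 5, 8, 9] -> after the 3rd, 6th, 9th and final epoch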
@@ -1558,7 +1553,6 @@ def help(event=None):
         raise
     except KeyboardInterrupt:
         send_telegram_message("Training stopped", args.telegram_chat_id, args.telegram_token)
-    save_and_sample_weights(args.num_train_epochs,'epoch')
     try:
         send_telegram_message("Training finished!", args.telegram_chat_id, args.telegram_token)
     except:
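The last hunk drops the trailing `save_and_sample_weights(args.num_train_epochs,'epoch')`. With the new cadence the final epoch already saves inside the loop, so this call appears to have produced a second, redundant final checkpoint, and it passed `args.num_train_epochs` rather than the last 0-based epoch index, labeling that extra checkpoint one epoch past the loop's. A toy illustration of the redundancy (sketch only, not project code):

# The loop's last iteration already checkpoints; a trailing save repeats it
# under a different (off-by-one) label.
num_train_epochs = 5
saved = []
for epoch in range(num_train_epochs):
    if epoch == num_train_epochs - 1:        # in-loop final save (kept)
        saved.append(('epoch', epoch))       # -> ('epoch', 4)
# old trailing call (removed by this PR):
# saved.append(('epoch', num_train_epochs)) # -> ('epoch', 5), a duplicate
print(saved)                                 # [('epoch', 4)]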