Skip to content

Commit

Permalink
Update train.py
Browse files Browse the repository at this point in the history
  • Loading branch information
KaiservonAfrika authored Oct 11, 2022
1 parent 197c816 commit c38b5e6
Showing 1 changed file with 8 additions and 8 deletions.
16 changes: 8 additions & 8 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,14 +230,14 @@ def train_and_evaluate(rank, epoch, hps, nets, optims, schedulers, scaler, loade
evaluate(hps, net_g, eval_loader, writer_eval)
utils.save_checkpoint(net_g, optim_g, hps.train.learning_rate, epoch, os.path.join(hps.model_dir, "G_{}.pth".format(global_step)))
utils.save_checkpoint(net_d, optim_d, hps.train.learning_rate, epoch, os.path.join(hps.model_dir, "D_{}.pth".format(global_step)))
shutil.copyfile(os.path.join(hps.model_dir, "D_{}.pth".format(global_step)), "/content/drive/My Drive/NewestD.pth")
shutil.copyfile(os.path.join(hps.model_dir, "G_{}.pth".format(global_step)), "/content/drive/My Drive/NewestG.pth")
# old_g=os.path.join(hps.model_dir, "G_{}.pth".format(global_step-2000))
# old_d=os.path.join(hps.model_dir, "D_{}.pth".format(global_step-2000))
# if os.path.exists(old_g):
# os.remove(old_g)
# if os.path.exists(old_d):
# os.remove(old_d)
# shutil.copyfile(os.path.join(hps.model_dir, "D_{}.pth".format(global_step)), "/content/drive/My Drive/tempmodel/NewestD.pth")
# shutil.copyfile(os.path.join(hps.model_dir, "G_{}.pth".format(global_step)), "/content/drive/My Drive/tempmodel/NewestG.pth")
old_g=os.path.join(hps.model_dir, "G_{}.pth".format(global_step-2000))
old_d=os.path.join(hps.model_dir, "D_{}.pth".format(global_step-2000))
if os.path.exists(old_g):
os.remove(old_g)
if os.path.exists(old_d):
os.remove(old_d)
global_step += 1

if rank == 0:
Expand Down

0 comments on commit c38b5e6

Please sign in to comment.