Skip to content

Commit

Permalink
fix load model hook
Browse files Browse the repository at this point in the history
  • Loading branch information
patil-suraj committed Jun 17, 2023
1 parent 3b7708d commit 2a03657
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion training/train_muse.py
Original file line number Diff line number Diff line change
Expand Up @@ -267,7 +267,7 @@ def load_model_hook(models, input_dir):
load_model = EMAModel.from_pretrained(os.path.join(input_dir, "ema_model"), model_cls=model_cls)
ema.load_state_dict(load_model.state_dict())
ema.to(accelerator.device)
del ema
del load_model

def save_model_hook(models, weights, output_dir):
ema.save_pretrained(os.path.join(output_dir, "ema_model"))
Expand Down

0 comments on commit 2a03657

Please sign in to comment.