From 2a0365707c0c3d91b0b91bc456161db693652ff1 Mon Sep 17 00:00:00 2001 From: patil-suraj Date: Sat, 17 Jun 2023 15:17:21 +0530 Subject: [PATCH] fix load model hook --- training/train_muse.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/training/train_muse.py b/training/train_muse.py index 8eb1e3fd..ef59b786 100644 --- a/training/train_muse.py +++ b/training/train_muse.py @@ -267,7 +267,7 @@ def load_model_hook(models, input_dir): load_model = EMAModel.from_pretrained(os.path.join(input_dir, "ema_model"), model_cls=model_cls) ema.load_state_dict(load_model.state_dict()) ema.to(accelerator.device) - del ema + del load_model def save_model_hook(models, weights, output_dir): ema.save_pretrained(os.path.join(output_dir, "ema_model"))