From 81a63fd80f8a4db1cfef912a2c8d108202ad473b Mon Sep 17 00:00:00 2001 From: cyril Date: Sat, 7 Oct 2023 17:07:39 +0800 Subject: [PATCH] fix the bug about loading checkpoint --- video_llama/runners/runner_base.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/video_llama/runners/runner_base.py b/video_llama/runners/runner_base.py index c9441239..4f7ef54b 100644 --- a/video_llama/runners/runner_base.py +++ b/video_llama/runners/runner_base.py @@ -627,14 +627,14 @@ def _load_checkpoint(self, url_or_filename): cached_file = download_cached_file( url_or_filename, check_hash=False, progress=True ) - checkpoint = torch.load(cached_file, map_location=self.device, strict=False) + checkpoint = torch.load(cached_file, map_location=self.device) elif os.path.isfile(url_or_filename): - checkpoint = torch.load(url_or_filename, map_location=self.device, strict=False) + checkpoint = torch.load(url_or_filename, map_location=self.device) else: raise RuntimeError("checkpoint url or path is invalid") state_dict = checkpoint["model"] - self.unwrap_dist_model(self.model).load_state_dict(state_dict) + self.unwrap_dist_model(self.model).load_state_dict(state_dict, strict=False) self.optimizer.load_state_dict(checkpoint["optimizer"]) if self.scaler and "scaler" in checkpoint: