From d74145ce4a4406e7e9b684121898e0f7bdce7451 Mon Sep 17 00:00:00 2001 From: Jeff Jennings Date: Tue, 28 Nov 2023 22:54:57 -0500 Subject: [PATCH] run_crossval: save/load pre-trained model state --- src/mpol/crossval.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/src/mpol/crossval.py b/src/mpol/crossval.py index 2f5bdee6..0bc266b2 100644 --- a/src/mpol/crossval.py +++ b/src/mpol/crossval.py @@ -174,8 +174,12 @@ def run_crossval(self, dataset): ) # initial short training loop to get model image to approximate dirty image model_pretrained = train_to_dirty_image(model=model, imager=self._imager) + # save the model to a state we can load in subsequent kfolds + torch.save(model_pretrained.state_dict(), f=self._save_prefix + "_dirty_image_model.pt") + else: + # create a new model for this kfold, initializing it to the model pretrained on the dirty image + model.load_state_dict(torch.load(self._save_prefix + "_dirty_image_model.pt")) - trainer = TrainTest( imager=self._imager, optimizer=optimizer, epochs=self._epochs,