diff --git a/src/mpol/crossval.py b/src/mpol/crossval.py index 2f5bdee6..0bc266b2 100644 --- a/src/mpol/crossval.py +++ b/src/mpol/crossval.py @@ -174,8 +174,12 @@ def run_crossval(self, dataset): ) # initial short training loop to get model image to approximate dirty image model_pretrained = train_to_dirty_image(model=model, imager=self._imager) + # save the model to a state we can load in subsequent kfolds + torch.save(model_pretrained.state_dict(), f=self._save_prefix + "_dirty_image_model.pt") + else: + # create a new model for this kfold, initializing it to the model pretrained on the dirty image + model.load_state_dict(torch.load(self._save_prefix + "_dirty_image_model.pt")) - trainer = TrainTest( imager=self._imager, optimizer=optimizer, epochs=self._epochs,