Skip to content

Commit

Permalink
run_crossval: save/load pre-trained model state
Browse files Browse the repository at this point in the history
  • Loading branch information
jeffjennings committed Nov 29, 2023
1 parent aa23919 commit d74145c
Showing 1 changed file with 5 additions and 1 deletion.
6 changes: 5 additions & 1 deletion src/mpol/crossval.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,8 +174,12 @@ def run_crossval(self, dataset):
)
# initial short training loop to get model image to approximate dirty image
model_pretrained = train_to_dirty_image(model=model, imager=self._imager)
# save the model to a state we can load in subsequent kfolds
torch.save(model_pretrained.state_dict(), f=self._save_prefix + "_dirty_image_model.pt")
else:
# create a new model for this kfold, initializing it to the model pretrained on the dirty image
model.load_state_dict(torch.load(self._save_prefix + "_dirty_image_model.pt"))

trainer = TrainTest(
imager=self._imager,
optimizer=optimizer,
epochs=self._epochs,
Expand Down

0 comments on commit d74145c

Please sign in to comment.