Skip to content

Commit

Permalink
run_crossval: pre-train to dirty image
Browse files Browse the repository at this point in the history
  • Loading branch information
jeffjennings committed Nov 29, 2023
1 parent fdcf265 commit aa23919
Showing 1 changed file with 8 additions and 6 deletions.
14 changes: 8 additions & 6 deletions src/mpol/crossval.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,12 +166,14 @@ def run_crossval(self, dataset):
# if hasattr(self._device,'type') and self._device.type == 'cuda': # TODO: confirm which objects need to be passed to gpu
# train_set, test_set = train_set.to(self._device), test_set.to(self._device)

# create a new model and optimizer for this k_fold
self._model = SimpleNet(coords=self._coords, nchan=self._imager.nchan)
# if hasattr(self._device,'type') and self._device.type == 'cuda': # TODO: confirm which objects need to be passed to gpu
# self._model = self._model.to(self._device)

optimizer = torch.optim.Adam(self._model.parameters(), lr=self._learn_rate)
if self._start_dirty_image is True:
if kk == 0:
if self._verbose:
logging.info(
"\n Pre-training to dirty image to initialize subsequent optimization loops"
)
# initial short training loop to get model image to approximate dirty image
model_pretrained = train_to_dirty_image(model=model, imager=self._imager)

trainer = TrainTest(
imager=self._imager,
Expand Down

0 comments on commit aa23919

Please sign in to comment.