Skip to content

Commit

Permalink
Recover if error
Browse files: browse the repository at this point in the history
  • Loading branch information
georgeyiasemis committed Aug 5, 2024
1 parent e319efc commit f468869
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 3 deletions.
2 changes: 1 addition & 1 deletion direct/data/mri_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -1277,7 +1277,7 @@ def forward(self, sample: dict[str, Any]) -> dict[str, Any]:
sample[new_key] = sample[key].detach().clone() # Copy Torch tensor
else:
sample[new_key] = copy.deepcopy(sample[key])
return sample
return sample


class CompressCoilModule(DirectModule):
Expand Down
13 changes: 11 additions & 2 deletions direct/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -326,8 +326,17 @@ def training_loop(
gc.collect()
torch.cuda.empty_cache()
continue

self.checkpoint_and_write_to_logs(iter_idx)
elif "Rejection sampled exceeded number of tries." in str(e):
if fail_counter == 10:
self.checkpoint_and_write_to_logs(iter_idx)
raise TrainingException(f"Rejection sampled exceeded number of tries 10 times in a row: {e}.")
fail_counter += 1
self.logger.info(f"Rejection sampled exceeded number of tries. Retry {fail_counter}/10.")
self.__optimizer.zero_grad()
gc.collect()
torch.cuda.empty_cache()
continue
# self.checkpoint_and_write_to_logs(iter_idx)
self.logger.info(f"Cannot recover from exception {e}. Exiting.")
raise RuntimeError(e)

Expand Down

0 comments on commit f468869

Please sign in to comment.