diff --git a/direct/data/mri_transforms.py b/direct/data/mri_transforms.py index e02b74a3..5b6d12db 100644 --- a/direct/data/mri_transforms.py +++ b/direct/data/mri_transforms.py @@ -1277,7 +1277,7 @@ def forward(self, sample: dict[str, Any]) -> dict[str, Any]: sample[new_key] = sample[key].detach().clone() # Copy Torch tensor else: sample[new_key] = copy.deepcopy(sample[key]) - return sample + return sample class CompressCoilModule(DirectModule): diff --git a/direct/engine.py b/direct/engine.py index f867c9c3..c599a206 100644 --- a/direct/engine.py +++ b/direct/engine.py @@ -326,8 +326,17 @@ def training_loop( gc.collect() torch.cuda.empty_cache() continue - - self.checkpoint_and_write_to_logs(iter_idx) + elif "Rejection sampled exceeded number of tries." in str(e): + if fail_counter == 10: + self.checkpoint_and_write_to_logs(iter_idx) + raise TrainingException(f"Rejection sampled exceeded number of tries 10 times in a row: {e}.") + fail_counter += 1 + self.logger.info(f"Rejection sampled exceeded number of tries. Retry {fail_counter}/10.") + self.__optimizer.zero_grad() + gc.collect() + torch.cuda.empty_cache() + continue + # self.checkpoint_and_write_to_logs(iter_idx) self.logger.info(f"Cannot recover from exception {e}. Exiting.") raise RuntimeError(e)