diff --git a/smdebug/pytorch/hook.py b/smdebug/pytorch/hook.py index 180319eb3..d4488830c 100644 --- a/smdebug/pytorch/hook.py +++ b/smdebug/pytorch/hook.py @@ -495,6 +495,11 @@ def back(grad): self._save_for_tensor(self.GRADIENT_PREFIX + tname, grad) self._save_custom_tensors_post_step() + # update step time + now = time.time() + if self.step_event: + self.step_event.update_end_time(now) + return back def _backward_apply(self, module):