diff --git a/torchtnt/framework/auto_unit.py b/torchtnt/framework/auto_unit.py index c39e9be528..f7ddfbf276 100644 --- a/torchtnt/framework/auto_unit.py +++ b/torchtnt/framework/auto_unit.py @@ -725,6 +725,25 @@ def maybe_enable_compiled_autograd( ): loss.backward(retain_graph=self.loss_backward_retain_graph) + # l1 grad nom + if isinstance(self.module, DDP): + _module = self.module.module + else: + _module = self.module + if hasattr(_module, "gradient_l1_loss"): + gradient_l1_loss = _module.gradient_l1_loss() + with maybe_enable_compiled_autograd(self.enable_compiled_autograd): + if grad_scaler: + gradient_l1_loss = grad_scaler.scale(gradient_l1_loss) + gradient_l1_loss.backward( + retain_graph=self.loss_backward_retain_graph + ) + else: + gradient_l1_loss.backward( + retain_graph=self.loss_backward_retain_graph + ) + loss = loss + gradient_l1_loss + total_grad_norm = None if should_update_weights: total_grad_norm = self._update_weights(state)