From 2e48bb27eca234d04262177d388a91d5de37abec Mon Sep 17 00:00:00 2001
From: "David Li (RL)"
Date: Tue, 4 Mar 2025 04:53:49 -0800
Subject: [PATCH] gradient regularization loss

Differential Revision: D70343465

Privacy Context Container: L1277806
---
 torchtnt/framework/auto_unit.py | 19 +++++++++++++++++++
 1 file changed, 19 insertions(+)

diff --git a/torchtnt/framework/auto_unit.py b/torchtnt/framework/auto_unit.py
index c39e9be528..f7ddfbf276 100644
--- a/torchtnt/framework/auto_unit.py
+++ b/torchtnt/framework/auto_unit.py
@@ -725,6 +725,25 @@ def maybe_enable_compiled_autograd(
         ):
             loss.backward(retain_graph=self.loss_backward_retain_graph)
 
+        # L1 gradient regularization loss
+        if isinstance(self.module, DDP):
+            _module = self.module.module
+        else:
+            _module = self.module
+        if hasattr(_module, "gradient_l1_loss"):
+            gradient_l1_loss = _module.gradient_l1_loss()
+            with maybe_enable_compiled_autograd(self.enable_compiled_autograd):
+                if grad_scaler:
+                    gradient_l1_loss = grad_scaler.scale(gradient_l1_loss)
+                    gradient_l1_loss.backward(
+                        retain_graph=self.loss_backward_retain_graph
+                    )
+                else:
+                    gradient_l1_loss.backward(
+                        retain_graph=self.loss_backward_retain_graph
+                    )
+            loss = loss + gradient_l1_loss
+
         total_grad_norm = None
         if should_update_weights:
             total_grad_norm = self._update_weights(state)
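
The hook introduced by this patch is duck-typed: after unwrapping DDP, the unit
only checks hasattr(_module, "gradient_l1_loss") and calls the method with no
arguments, expecting a scalar tensor it can backpropagate. The sketch below
shows one way a module could satisfy that contract. It is a hypothetical
example, not code from this patch: the class name, the choice of penalty (the
L1 norm of the output's gradient with respect to the input, cached during
forward), and all shapes and parameters are illustrative assumptions.

# Hypothetical example, not part of this patch: a module exposing the
# duck-typed gradient_l1_loss() hook that the AutoUnit change looks for.
from typing import Optional

import torch
import torch.nn as nn


class GradRegularizedModule(nn.Module):
    """Caches an L1 input-gradient penalty during forward (illustrative)."""

    def __init__(self, in_features: int = 16, out_features: int = 1) -> None:
        super().__init__()
        self.linear = nn.Linear(in_features, out_features)
        self._grad_l1: Optional[torch.Tensor] = None

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if not x.requires_grad:
            # Make the input a differentiable leaf so we can take its gradient.
            x = x.clone().requires_grad_(True)
        out = self.linear(x)
        # create_graph=True keeps the gradient graph alive so the penalty
        # itself is differentiable and can be backpropagated later.
        (input_grad,) = torch.autograd.grad(out.sum(), x, create_graph=True)
        self._grad_l1 = input_grad.abs().sum()
        return out

    def gradient_l1_loss(self) -> torch.Tensor:
        # Called by the unit after the main loss.backward(); returns the
        # cached scalar penalty.
        assert self._grad_l1 is not None, "run forward() first"
        return self._grad_l1

Because the penalty is backpropagated in a second backward pass after the main
loss.backward(), a unit paired with such a module would likely also need
loss_backward_retain_graph=True so that autograd buffers shared between the two
passes are not freed by the first one.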