
Commit ad03e3e

David Li (RL) authored and facebook-github-bot committed
gradient regularization loss
Differential Revision: D70343465
Privacy Context Container: L1277806
1 parent 714ae04 commit ad03e3e

File tree

1 file changed: +19 -0 lines changed


torchtnt/framework/auto_unit.py

```diff
@@ -725,6 +725,25 @@ def maybe_enable_compiled_autograd(
         ):
             loss.backward(retain_graph=self.loss_backward_retain_graph)
 
+        # l1 grad norm
+        if isinstance(self.module, DDP):
+            _module = self.module.module
+        else:
+            _module = self.module
+        if hasattr(_module, "gradient_l1_loss"):
+            gradient_l1_loss = _module.gradient_l1_loss()
+            with maybe_enable_compiled_autograd(self.enable_compiled_autograd):
+                if grad_scaler:
+                    gradient_l1_loss = grad_scaler.scale(gradient_l1_loss)
+                    gradient_l1_loss.backward(
+                        retain_graph=self.loss_backward_retain_graph
+                    )
+                else:
+                    gradient_l1_loss.backward(
+                        retain_graph=self.loss_backward_retain_graph
+                    )
+            loss = loss + gradient_l1_loss
+
         total_grad_norm = None
         if should_update_weights:
             total_grad_norm = self._update_weights(state)
```
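The hook is duck-typed: the unit only checks `hasattr(_module, "gradient_l1_loss")` after unwrapping DDP, so any module that defines such a method opts into the extra backward pass. Below is a minimal sketch of what such a module could look like, assuming an L1 penalty on input gradients; the class name, attributes, and `l1_weight` are illustrative, not part of torchtnt. For the penalty's graph to survive the main `loss.backward()` that runs first, `loss_backward_retain_graph` would presumably need to be `True`.

```python
from typing import Optional

import torch
from torch import nn


class GradPenalizedNet(nn.Module):
    """Illustrative module exposing the duck-typed gradient_l1_loss() hook."""

    def __init__(self, in_features: int = 16, l1_weight: float = 1e-3) -> None:
        super().__init__()
        self.linear = nn.Linear(in_features, 1)
        self.l1_weight = l1_weight  # hypothetical penalty coefficient
        self._last_input: Optional[torch.Tensor] = None
        self._last_output: Optional[torch.Tensor] = None

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Keep differentiable handles so the penalty can be built after
        # the main loss.backward() has already run.
        x = x.detach().requires_grad_(True)
        out = self.linear(x)
        self._last_input, self._last_output = x, out
        return out

    def gradient_l1_loss(self) -> torch.Tensor:
        # L1 norm of d(output)/d(input). create_graph=True makes the
        # penalty itself differentiable, so the separate backward() call
        # in the diff above reaches the weights.
        (grad,) = torch.autograd.grad(
            self._last_output.sum(),
            self._last_input,
            create_graph=True,
            retain_graph=True,
        )
        return self.l1_weight * grad.abs().sum()
```

Note the order of operations in the diff: the penalty is backward()ed on its own, scaled by `grad_scaler` when AMP is active (mirroring how the main loss is scaled elsewhere in `auto_unit.py`), and only then added to `loss`, so its gradients accumulate into the same optimizer step while the returned loss reflects the combined value.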
