File tree 1 file changed +19
-0
lines changed
1 file changed +19
-0
lines changed Original file line number Diff line number Diff line change @@ -725,6 +725,25 @@ def maybe_enable_compiled_autograd(
725
725
):
726
726
loss .backward (retain_graph = self .loss_backward_retain_graph )
727
727
728
+ # l1 grad nom
729
+ if isinstance (self .module , DDP ):
730
+ _module = self .module .module
731
+ else :
732
+ _module = self .module
733
+ if hasattr (_module , "gradient_l1_loss" ):
734
+ gradient_l1_loss = _module .gradient_l1_loss ()
735
+ with maybe_enable_compiled_autograd (self .enable_compiled_autograd ):
736
+ if grad_scaler :
737
+ gradient_l1_loss = grad_scaler .scale (gradient_l1_loss )
738
+ gradient_l1_loss .backward (
739
+ retain_graph = self .loss_backward_retain_graph
740
+ )
741
+ else :
742
+ gradient_l1_loss .backward (
743
+ retain_graph = self .loss_backward_retain_graph
744
+ )
745
+ loss = loss + gradient_l1_loss
746
+
728
747
total_grad_norm = None
729
748
if should_update_weights :
730
749
total_grad_norm = self ._update_weights (state )
You can’t perform that action at this time.
0 commit comments