diff --git a/transformer_engine/pytorch/optimizers/fused_adam.py b/transformer_engine/pytorch/optimizers/fused_adam.py
index d972fd96ab..070f46e937 100644
--- a/transformer_engine/pytorch/optimizers/fused_adam.py
+++ b/transformer_engine/pytorch/optimizers/fused_adam.py
@@ -3,8 +3,12 @@
 # See LICENSE for license information.
 
 """Fused Adam optimizer."""
+from __future__ import annotations
+from collections.abc import Iterable
 from copy import deepcopy
 from itertools import chain
+from typing import Optional
+import warnings
 
 import torch
 import transformer_engine_torch as tex
@@ -52,8 +56,6 @@ class FusedAdam(torch.optim.Optimizer):
         params (iterable): iterable of parameters to optimize or dicts defining
             parameter groups.
         lr (float, optional): learning rate. (default: 1e-3)
-        bias_correction (bool, optional): apply correction factor to
-            moment estimates. (default: True)
         betas (Tuple[float, float], optional): coefficients used for computing
             running averages of gradient and its square. (default: (0.9, 0.999))
         eps (float, optional): term added to the denominator to improve
@@ -62,10 +64,10 @@ class FusedAdam(torch.optim.Optimizer):
         amsgrad (boolean, optional): whether to use the AMSGrad variant of this
             algorithm from the paper `On the Convergence of Adam and Beyond`_
             (default: False) NOT SUPPORTED in FusedAdam!
+        bias_correction (bool, optional): apply correction factor to
+            moment estimates. (default: True)
         adam_w_mode (boolean, optional): Apply L2 regularization or weight decay
             True for decoupled weight decay(also known as AdamW) (default: True)
-        set_grad_none (bool, optional): whether set grad to None when zero_grad()
-            method is called. (default: True)
         capturable (bool, optional): whether to use the version of the optimizer
             that can be used with CUDA Graphs. (default: False)
         master_weights (bool, optional): whether to maintain FP32 master weights
@@ -106,15 +108,15 @@ class FusedAdam(torch.optim.Optimizer):
 
     def __init__(
         self,
-        params,
-        lr=1e-3,
+        params: Iterable[torch.nn.Parameter | dict],
+        lr: float = 1e-3,
+        betas: tuple[float, float] = (0.9, 0.999),
+        eps: float = 1e-8,
+        weight_decay: float = 0.0,
+        amsgrad: bool = False,
+        *,
         bias_correction=True,
-        betas=(0.9, 0.999),
-        eps=1e-8,
         adam_w_mode=True,
-        weight_decay=0.0,
-        amsgrad=False,
-        set_grad_none=True,
         capturable=False,
         master_weights=False,
         master_weight_dtype=torch.float32,
@@ -122,6 +124,7 @@ def __init__(
         exp_avg_sq_dtype=torch.float32,
         use_decoupled_grad=False,
         store_param_remainders=False,
+        set_grad_none: Optional[bool] = None,  # deprecated
     ):
 
         if amsgrad:
@@ -160,7 +163,6 @@ def __init__(
         }
         super().__init__(params, defaults)
         self.adam_w_mode = 1 if adam_w_mode else 0
-        self.set_grad_none = set_grad_none
 
         self.capturable = capturable
         self.master_weights = master_weights
@@ -204,19 +206,46 @@ def __init__(
             store_param_remainders and master_weights and master_weight_dtype == torch.float32
         )
 
-    def zero_grad(self):
-        # pylint: disable=missing-function-docstring
-        if not self.use_decoupled_grad and not self.set_grad_none:
-            super().zero_grad()
+        # Deprecated options
+        self.set_grad_none = set_grad_none
+        if self.set_grad_none is not None:
+            warnings.warn(
+                "set_grad_none kwarg in FusedAdam constructor is deprecated. "
+                "Use set_to_none kwarg in zero_grad instead.",
+                DeprecationWarning,
+            )
+
+    def zero_grad(self, set_to_none: Optional[bool] = None) -> None:
+        """Reset parameter gradients.
+
+        Arguments:
+            set_to_none (bool, optional): whether to set grads to `None`
+                instead of zeroing out buffers. (default: True)
+
+        """
+
+        # Handle deprecated set_grad_none option
+        if self.set_grad_none is not None:
+            if set_to_none is not None and set_to_none != self.set_grad_none:
+                raise ValueError(
+                    f"Called zero_grad with set_to_none={set_to_none}, "
+                    f"but FusedAdam was initialized with set_grad_none={self.set_grad_none}"
+                )
+            set_to_none = self.set_grad_none
+        if set_to_none is None:
+            set_to_none = True
+
+        if not self.use_decoupled_grad and not set_to_none:
+            super().zero_grad(set_to_none=set_to_none)
             return
 
         for group in self.param_groups:
             for p in group["params"]:
-                if self.use_decoupled_grad and self.set_grad_none:
+                if self.use_decoupled_grad and set_to_none:
                     p.decoupled_grad = None
-                elif self.use_decoupled_grad and not self.set_grad_none:
+                elif self.use_decoupled_grad and not set_to_none:
                     p.decoupled_grad.zero_()
-                elif not self.use_decoupled_grad and self.set_grad_none:
+                elif not self.use_decoupled_grad and set_to_none:
                     p.grad = None
 
     def _apply_scale(self, state_name, unscaled_state, scaled_state, scale):
diff --git a/transformer_engine/pytorch/optimizers/fused_sgd.py b/transformer_engine/pytorch/optimizers/fused_sgd.py
index 53fa59821c..8a76ec5901 100644
--- a/transformer_engine/pytorch/optimizers/fused_sgd.py
+++ b/transformer_engine/pytorch/optimizers/fused_sgd.py
@@ -3,6 +3,11 @@
 # See LICENSE for license information.
 
 """Fused SGD optimizer."""
+from __future__ import annotations
+from collections.abc import Iterable
+from typing import Any, Optional
+import warnings
+
 import torch
 from torch.optim.optimizer import Optimizer, required
 
@@ -37,8 +42,8 @@ class FusedSGD(Optimizer):
             parameter groups
         lr (float): learning rate
         momentum (float, optional): momentum factor (default: 0)
-        weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
         dampening (float, optional): dampening for momentum (default: 0)
+        weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
         nesterov (bool, optional): enables Nesterov momentum (default: False)
 
     Example:
@@ -74,15 +79,16 @@ class FusedSGD(Optimizer):
 
     def __init__(
         self,
-        params,
-        lr=required,
-        momentum=0,
-        dampening=0,
-        weight_decay=0,
-        nesterov=False,
+        params: Iterable[torch.nn.Parameter | dict],
+        lr: float | Any = required,
+        momentum: float = 0.0,
+        dampening: float = 0.0,
+        weight_decay: float = 0.0,
+        nesterov: bool = False,
+        *,
         wd_after_momentum=False,
         materialize_master_grads=True,
-        set_grad_none=False,
+        set_grad_none: Optional[bool] = None,  # deprecated
     ):
         if lr is not required and lr < 0.0:
             raise ValueError(f"Invalid learning rate: {lr}")
@@ -98,7 +104,7 @@ def __init__(
             "weight_decay": weight_decay,
             "nesterov": nesterov,
         }
-        if nesterov and (momentum <= 0 or dampening != 0):
+        if nesterov and (momentum <= 0.0 or dampening != 0.0):
             raise ValueError("Nesterov momentum requires a momentum and zero dampening")
         super().__init__(params, defaults)
 
@@ -106,7 +112,6 @@ def __init__(
         self.materialize_master_grads = materialize_master_grads
         self.most_recent_scale = 1.0
        self.scale_set_by_backward = False
-        self.set_grad_none = set_grad_none
 
         # Skip buffer
         self._dummy_overflow_buf = torch.tensor(
@@ -114,14 +119,42 @@ def __init__(
         )
         self.multi_tensor_sgd = tex.multi_tensor_sgd
 
+        # Deprecated options
+        self.set_grad_none = set_grad_none
+        if self.set_grad_none is not None:
+            warnings.warn(
+                "set_grad_none kwarg in FusedSGD constructor is deprecated. "
" + "Use set_to_none kwarg in zero_grad instead.", + DeprecationWarning, + ) + def __setstate__(self, state): super().__setstate__(state) for group in self.param_groups: group.setdefault("nesterov", False) - def zero_grad(self): - # pylint: disable=missing-function-docstring - if self.set_grad_none: + def zero_grad(self, set_to_none: Optional[bool] = None) -> None: + """Reset parameter gradients. + + Arguments: + set_to_none (bool, optional): whether to set grads to `None` + instead of zeroing out buffers. (default: True) + + """ + + # Handle deprecated set_grad_none option + if self.set_grad_none is not None: + if set_to_none is not None and set_to_none != self.set_grad_none: + raise ValueError( + f"Called zero_grad with set_to_none={set_to_none}, " + f"but FusedAdam was initialized with set_grad_none={self.set_grad_none}" + ) + set_to_none = self.set_grad_none + if set_to_none is None: + set_to_none = True + + # Reset grads + if set_to_none: for group in self.param_groups: for p in group["params"]: p.grad = None