diff --git a/msamp/deepspeed/runtime/engine.py b/msamp/deepspeed/runtime/engine.py
index c081513f..0209bb7b 100644
--- a/msamp/deepspeed/runtime/engine.py
+++ b/msamp/deepspeed/runtime/engine.py
@@ -4,14 +4,17 @@
 # SPDX-License-Identifier: Apache-2.0. DeepSpeed Team)
 
 """DeepSpeedEngine in MS-AMP."""
+import torch
 import deepspeed
 from deepspeed.runtime.engine import SparseTensor, ZERO_OPTIMIZATION, AMP, amp, \
     FP16, BFLOAT16, logger, DeepSpeedEngine, instrument_w_nvtx, log_dist, \
     see_memory_usage, DummyOptim, DeepSpeedZeroOptimizer, DeepSpeedZeRoOffload, \
     PipelineModule, ZeroStageEnum
+from deepspeed.moe.utils import is_moe_param
 
 from msamp import initialize as msamp_initialize
-from msamp.common.tensor import ScalingTensor, TensorDist
+from msamp.common.dtype import Dtypes
+from msamp.common.tensor import ScalingTensor, TensorDist, ScalingMeta
 from msamp.optim import LBOptimizer
 from msamp.deepspeed.runtime.fp8.fused_optimizer import FP8Optimizer
 from msamp.deepspeed.runtime.zero import utils    # noqa: F401
@@ -301,6 +304,38 @@ def _configure_zero_optimizer(self, optimizer):
 
         return optimizer
 
+    def _get_gradients_for_reduction(self):
+        non_expert_grads = []
+        expert_grads = {}
+        if self.has_moe_layers:
+            for key in self.expert_data_parallel_group.keys():
+                expert_grads[key] = []
+
+        for param_name, param in self.module.named_parameters():
+            if param.grad is None:
+                # In cases where there is an imbalance of empty grads across
+                # ranks we must create empty grads, this will ensure that every
+                # rank is reducing the same size. In some cases it may make
+                # sense in the future to support the ability to average not
+                # w.r.t. world size but with a different value.
+                if isinstance(param, ScalingTensor):
+                    meta = ScalingMeta(Dtypes.kfloat8_e4m3)
+                    param.grad = ScalingTensor(torch.zeros(param.size(), dtype=param.dtype, device=param.device), meta)
+                else:
+                    param.grad = torch.zeros(param.size(), dtype=param.dtype, device=param.device)
+
+            grad_data = param.grad.data
+            if param_name in self.sparse_tensor_module_names or grad_data.is_sparse:
+                # Call param.grad without data to avoid problem with setting of updated grads
+                grad_data = SparseTensor(param.grad)
+
+            if is_moe_param(param):
+                expert_grads[param.group_name].append(grad_data)
+            else:
+                non_expert_grads.append(grad_data)
+
+        return non_expert_grads, expert_grads
+
     @instrument_w_nvtx
     def backward(  # noqa: C901
         self,
@@ -434,3 +469,4 @@ def msamp_enabled(self):
     def msamp_optlevel(self):
         """Return the opt level of MS-AMP."""
         return self._config.msamp_optlevel
+
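
Note on the FP8 empty-grad handling added above: when a parameter arrives with param.grad set to None, every rank must still contribute an equally sized buffer to the reduction, so _get_gradients_for_reduction materializes a zero gradient; for an MS-AMP ScalingTensor parameter that zero gradient is itself a ScalingTensor carrying a fresh ScalingMeta(Dtypes.kfloat8_e4m3). Below is a minimal standalone sketch of that branch, using only the MS-AMP constructors already imported in the diff; the helper name make_zero_grad is hypothetical and not part of this patch.

import torch

from msamp.common.dtype import Dtypes
from msamp.common.tensor import ScalingMeta, ScalingTensor


def make_zero_grad(param):
    """Return a zero gradient shaped like param, so every rank reduces equal sizes."""
    zeros = torch.zeros(param.size(), dtype=param.dtype, device=param.device)
    if isinstance(param, ScalingTensor):
        # FP8 params carry scaling metadata; the empty grad gets its own e4m3 meta,
        # mirroring the ScalingTensor branch in _get_gradients_for_reduction above.
        return ScalingTensor(zeros, ScalingMeta(Dtypes.kfloat8_e4m3))
    return zeros

The resulting per-parameter grads are then bucketed into non_expert_grads or expert_grads[param.group_name] as in upstream DeepSpeed's _get_gradients_for_reduction, so MoE expert gradients can be reduced within their own expert data-parallel groups.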