fix bug for None grad in DeepSpeed
tocean committed Nov 24, 2023
1 parent 4480ffa commit 3623fd2
Showing 1 changed file with 37 additions and 1 deletion.
msamp/deepspeed/runtime/engine.py (37 additions & 1 deletion)
@@ -4,14 +4,17 @@
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team

"""DeepSpeedEngine in MS-AMP."""
import torch
import deepspeed
from deepspeed.runtime.engine import SparseTensor, ZERO_OPTIMIZATION, AMP, amp, \
    FP16, BFLOAT16, logger, DeepSpeedEngine, instrument_w_nvtx, log_dist, \
    see_memory_usage, DummyOptim, DeepSpeedZeroOptimizer, DeepSpeedZeRoOffload, \
    PipelineModule, ZeroStageEnum
from deepspeed.moe.utils import is_moe_param

from msamp import initialize as msamp_initialize
from msamp.common.tensor import ScalingTensor, TensorDist
from msamp.common.dtype import Dtypes
from msamp.common.tensor import ScalingTensor, TensorDist, ScalingMeta
from msamp.optim import LBOptimizer
from msamp.deepspeed.runtime.fp8.fused_optimizer import FP8Optimizer
from msamp.deepspeed.runtime.zero import utils # noqa: F401
@@ -301,6 +304,38 @@ def _configure_zero_optimizer(self, optimizer):

        return optimizer

    def _get_gradients_for_reduction(self):
        """Get gradients for reduction, zero-filling missing grads (ScalingTensor-aware)."""
        non_expert_grads = []
        expert_grads = {}
        if self.has_moe_layers:
            for key in self.expert_data_parallel_group.keys():
                expert_grads[key] = []

        for param_name, param in self.module.named_parameters():
            if param.grad is None:
                # In cases where there is an imbalance of empty grads across
                # ranks we must create empty grads, this will ensure that every
                # rank is reducing the same size. In some cases it may make
                # sense in the future to support the ability to average not
                # w.r.t. world size but with a different value.
                if isinstance(param, ScalingTensor):
                    meta = ScalingMeta(Dtypes.kfloat8_e4m3)
                    param.grad = ScalingTensor(torch.zeros(param.size(), dtype=param.dtype, device=param.device), meta)
                else:
                    param.grad = torch.zeros(param.size(), dtype=param.dtype, device=param.device)

            grad_data = param.grad.data
            if param_name in self.sparse_tensor_module_names or grad_data.is_sparse:
                # Call param.grad without data to avoid problem with setting of updated grads
                grad_data = SparseTensor(param.grad)

            if is_moe_param(param):
                expert_grads[param.group_name].append(grad_data)
            else:
                non_expert_grads.append(grad_data)

        return non_expert_grads, expert_grads

    @instrument_w_nvtx
    def backward(  # noqa: C901
        self,
@@ -434,3 +469,4 @@ def msamp_optlevel(self):
    def msamp_optlevel(self):
        """Return the opt level of MS-AMP."""
        return self._config.msamp_optlevel

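For context, a minimal standalone sketch (not part of the commit) of the zero-filling behavior the new _get_gradients_for_reduction adds: when a parameter has no gradient on some rank (for example an unused MoE expert), the reduction would otherwise see mismatched sizes across ranks, so the engine creates a zero gradient, wrapped in a ScalingTensor when the parameter itself is a ScalingTensor. The helper function below is hypothetical; the ScalingTensor, ScalingMeta and Dtypes calls mirror the ones added in this diff.

import torch

from msamp.common.dtype import Dtypes
from msamp.common.tensor import ScalingMeta, ScalingTensor


def zero_fill_missing_grad(param):
    """Hypothetical helper: give a parameter a zero gradient of the matching type if it has none."""
    if param.grad is None:
        # Same construction as in the commit: zeros with the parameter's own dtype and device.
        zeros = torch.zeros(param.size(), dtype=param.dtype, device=param.device)
        if isinstance(param, ScalingTensor):
            # Wrap in a ScalingTensor with FP8 (e4m3) metadata so it reduces like the other scaled grads.
            param.grad = ScalingTensor(zeros, ScalingMeta(Dtypes.kfloat8_e4m3))
        else:
            param.grad = zeros
    return param.grad

In the engine itself this logic runs inside the named_parameters() loop, right before the gradients are collected into the non-expert and expert buckets for reduction.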