Update deepspeed to latest version
tocean committed Jan 30, 2024
1 parent 2fbe898 commit 8a9ec1f
Showing 4 changed files with 53 additions and 26 deletions.
29 changes: 17 additions & 12 deletions msamp/deepspeed/runtime/engine.py
@@ -12,6 +12,7 @@
see_memory_usage, DummyOptim, DeepSpeedZeroOptimizer, DeepSpeedZeRoOffload, \
PipelineModule, ZeroStageEnum
from deepspeed.moe.utils import is_moe_param
from deepspeed.accelerator import get_accelerator

from msamp import initialize as msamp_initialize
from msamp.common.dtype import Dtypes
@@ -32,21 +33,25 @@ def split_half_float_double_sparse(tensors):
Returns:
list: list of buckets, each bucket is a tuple of (dtype, list of tensors).
"""
supported_types = [
'torch.cuda.HalfTensor', 'torch.cuda.FloatTensor', 'torch.cuda.DoubleTensor', 'torch.cuda.BFloat16Tensor',
'msamp.common.tensor.tensor.ScalingTensor',
SparseTensor.type()
]
supported_types = get_accelerator().supported_dtypes() + [torch.int8, torch.uint8]

for t in tensors:
assert t.type() in supported_types, f'attempting to reduce an unsupported grad type: {t.type()}'
assert t.dtype in supported_types, f'attempting to reduce an unsupported grad type: {t.type()}'

buckets = []
sparse_tensor_buckets, dense_tensor_buckets = [], []
for _, dtype in enumerate(supported_types):
bucket = [t for t in tensors if t.type() == dtype]
if bucket:
buckets.append((dtype, bucket))
return buckets
sparse_bucket, dense_bucket = [], []
for t in tensors:
if t.dtype == dtype:
if isinstance(t, SparseTensor):
sparse_bucket.append(t)
else:
dense_bucket.append(t)
if sparse_bucket:
sparse_tensor_buckets.append((dtype, sparse_bucket))
if dense_bucket:
dense_tensor_buckets.append((dtype, dense_bucket))
return sparse_tensor_buckets, dense_tensor_buckets


deepspeed.runtime.engine.split_half_float_double_sparse = split_half_float_double_sparse
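For readers skimming the diff, here is a minimal standalone sketch of the new bucketing behaviour: gradients are grouped by torch dtype rather than by tensor type string, and sparse and dense tensors now land in separate bucket lists. The sketch uses plain PyTorch only; `is_sparse` stands in for DeepSpeed's `SparseTensor` check, and the dtype list is illustrative rather than whatever `get_accelerator().supported_dtypes()` returns on a given device.

import torch

def bucket_by_dtype(tensors, supported_dtypes=(torch.float16, torch.bfloat16, torch.float32, torch.uint8)):
    """Group tensors by dtype, keeping sparse and dense tensors in separate bucket lists."""
    sparse_buckets, dense_buckets = [], []
    for dtype in supported_dtypes:
        sparse, dense = [], []
        for t in tensors:
            if t.dtype == dtype:
                (sparse if t.is_sparse else dense).append(t)
        if sparse:
            sparse_buckets.append((dtype, sparse))
        if dense:
            dense_buckets.append((dtype, dense))
    return sparse_buckets, dense_buckets

# fp16 dense gradients plus a uint8 tensor standing in for FP8 ScalingTensor data.
grads = [torch.randn(4).half(), torch.zeros(4, dtype=torch.uint8)]
sparse_buckets, dense_buckets = bucket_by_dtype(grads)
assert sparse_buckets == [] and len(dense_buckets) == 2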
@@ -233,7 +238,7 @@ def _configure_zero_optimizer(self, optimizer):
expert_data_parallel_group=self.expert_data_parallel_group if self.has_moe_layers else None,
reduce_scatter=self.zero_reduce_scatter(),
overlap_comm=overlap_comm,
cpu_offload=self.zero_cpu_offload(),
offload_optimizer_config=self.zero_offload_optimizer(),
mpu=self.mpu,
postscale_gradients=self.postscale_gradients(),
gradient_predivide_factor=self.gradient_predivide_factor(),
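As context for the `cpu_offload` to `offload_optimizer_config` change above: newer DeepSpeed releases hand the ZeRO-1/2 optimizer the parsed `offload_optimizer` sub-config instead of a plain boolean. A hedged, illustrative config fragment of the kind `zero_offload_optimizer()` would return (values chosen for the example, not taken from this repository):

# Illustrative DeepSpeed config dict; only the zero_optimization block matters here.
ds_config = {
    'train_batch_size': 32,
    'zero_optimization': {
        'stage': 2,
        'offload_optimizer': {
            'device': 'cpu',      # offload optimizer states to host memory
            'pin_memory': True,   # pinned buffers speed up host<->device copies
        },
    },
}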
38 changes: 30 additions & 8 deletions msamp/deepspeed/runtime/zero/fp8_stage_1_and_2.py
@@ -10,7 +10,7 @@
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
from deepspeed import comm as dist
from deepspeed.runtime.zero.stage_1_and_2 import all_gather_dp_groups, DeepSpeedZeroOptimizer, \
get_accelerator, move_to_cpu, logger, see_memory_usage
get_accelerator, logger, see_memory_usage

from msamp.common.tensor import ScalingTensor, ScalingMeta
from msamp.common.dtype import Dtypes
@@ -178,7 +178,8 @@ def _pad_and_flat(self, values_partitions, group_fp8_mems, group_id):

# the number of elements in each partition is the same.
values = list(chain(*values_partitions))
move_to_cpu(values)
for value in values:
value.data = value.data.cpu()
# flat tensors
flat = _flatten_dense_tensors(values).cuda()
for p, q in zip(values, _unflatten_dense_tensors(flat, values)):
@@ -632,6 +633,26 @@ def fp8_get_flat_partition(self, tensor_list):
flat_tensor_list.append(tensor.grad)
return flat_tensor_list

def start_timers(self, timer_names):
if self.timers is None:
return

for name in timer_names:
self.timers(name).start()

def stop_timers(self, timer_names):
if self.timers is None:
return

for name in timer_names:
self.timers(name).stop()

def log_timers(self, timer_names):
if self.timers is None:
return

self.timers.log(names=list(timer_names))

def step(self, closure=None): # noqa C901
"""Performs a single optimization step. closure is not supported."""
self.micro_step_id = -1
@@ -758,14 +779,14 @@ def step(self, closure=None): # noqa C901
self.start_timers([OPTIMIZER_ALLGATHER])
# Gather the updated weights from everyone.
# Then all partitions of the model parameters are updated and ready for next round forward.
all_gather_dp_groups(
partitioned_param_groups=self.parallel_partitioned_bit16_groups,
dp_process_group=self.real_dp_process_group,
start_alignment_factor=self.nccl_start_alignment_factor,
allgather_bucket_size=self.allgather_bucket_size
)
all_gather_dp_groups(groups_flat=self.bit16_groups_flat,
partitioned_param_groups=self.parallel_partitioned_bit16_groups,
dp_process_group=self.real_dp_process_group,
start_alignment_factor=self.nccl_start_alignment_factor,
allgather_bucket_size=self.allgather_bucket_size)

all_gather_dp_groups(
groups_flat=list(filter(lambda g: g is not None, self.fp8_groups_flat)),
partitioned_param_groups=list(filter(lambda g: g is not None, self.fp8_parallel_partitioned_groups)),
dp_process_group=self.real_dp_process_group,
start_alignment_factor=self.fp8_nccl_start_alignment_factor,
@@ -813,6 +834,7 @@ def all_gather_fp8_metas(self):

# step 2. all gather
all_gather_dp_groups(
groups_flat=flats,
partitioned_param_groups=scale_invs_parallel_partitioned_groups,
dp_process_group=self.real_dp_process_group,
start_alignment_factor=self.fp8_nccl_start_alignment_factor,
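The `start_timers`/`stop_timers`/`log_timers` helpers added in this file guard every call behind `if self.timers is None`, so timing becomes a silent no-op when no timer object is configured. A minimal self-contained sketch of that pattern (the stub classes below are illustrative stand-ins, not DeepSpeed's timer API):

class _StubTimer:
    def __init__(self, name):
        self.name = name

    def start(self):
        print(f'start {self.name}')

    def stop(self):
        print(f'stop {self.name}')


class _StubTimers:
    """Callable registry mimicking the self.timers(name) access pattern."""
    def __init__(self):
        self._timers = {}

    def __call__(self, name):
        return self._timers.setdefault(name, _StubTimer(name))


class TimedStep:
    def __init__(self, timers=None):
        self.timers = timers  # may be None when timing is disabled

    def start_timers(self, timer_names):
        if self.timers is None:
            return
        for name in timer_names:
            self.timers(name).start()

    def stop_timers(self, timer_names):
        if self.timers is None:
            return
        for name in timer_names:
            self.timers(name).stop()


TimedStep().start_timers(['optimizer_step'])   # no timers configured: silently no-ops
timed = TimedStep(_StubTimers())
timed.start_timers(['optimizer_step'])
timed.stop_timers(['optimizer_step'])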
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -41,7 +41,7 @@ dependencies = [
"transformer-engine@git+https://github.com/NVIDIA/[email protected]#egg=transformer-engine",
"flash-attn==1.0.9",
"colorlog>=6.7.0",
"deepspeed==0.9.2",
"deepspeed==0.13.1",
"mpi4py",
]
dynamic = ["version"]
10 changes: 5 additions & 5 deletions tests/deepspeed/test_engine.py
@@ -33,9 +33,9 @@ def test_split_half_float_double_sparse(self):
"""Test split_half_float_double_sparse method."""
tensors = []

dtype_list = [torch.float32, torch.float16, torch.float64, torch.bfloat16]
dtype_list = [torch.float32, torch.float16, torch.bfloat16]

size_list = [3, 4, 5, 6]
size_list = [3, 4, 5]

for i, size in enumerate(size_list):
for _ in range(size):
@@ -46,13 +46,13 @@ def test_split_half_float_double_sparse(self):
for i in range(num_scaling_tensor):
tensor = torch.randn(2, 2, dtype=torch.float32, device='cuda').cast(Dtypes.kfloat8_e4m3)
tensors.append(tensor)
buckets = split_half_float_double_sparse(tensors)
_, buckets = split_half_float_double_sparse(tensors)

assert len(buckets) == 5
assert len(buckets) == 4

has_scaling_tensor = False
for dtype, bucket in buckets:
if dtype == 'msamp.common.tensor.tensor.ScalingTensor':
if dtype == torch.uint8:
assert len(bucket) == num_scaling_tensor
has_scaling_tensor = True
assert has_scaling_tensor
