Commit

Update te to latest stable version
root authored and root committed Feb 1, 2024
1 parent 9b4a462 commit 67c0a3f
Showing 6 changed files with 23 additions and 29 deletions.
9 changes: 6 additions & 3 deletions msamp/fsdp/_runtime_utils.py
@@ -18,7 +18,10 @@ def _fp8_post_backward_hook(state, handle, *unused):
     if accumulate_grad and torch.count_nonzero(state._flat_param._saved_grad_shard).item() > 0:
         raise NotImplementedError('accumulate_grad is not supported yet for fp8')
 
-    old_communication_hook = state._communication_hook
-    state._communication_hook = state._get_fp8_comm_hook()
+    comm_hook_attr = '_communication_hook' if hasattr(state, '_communication_hook') else '_comm_hook'
+
+    old_communication_hook = getattr(state, comm_hook_attr)
+    setattr(state, comm_hook_attr, state._get_fp8_comm_hook())
     old_post_backward_hook(state, handle, *unused)
-    state._communication_hook = old_communication_hook
+    setattr(state, comm_hook_attr, old_communication_hook)

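For reference, a minimal self-contained sketch of the compatibility pattern this hunk adopts: newer PyTorch FSDP builds appear to store the hook under `_comm_hook` rather than `_communication_hook`, so the attribute name is resolved once and the hook is swapped via `getattr`/`setattr`. The `_FakeState` class below is purely illustrative, not PyTorch's state object.

    class _FakeState:
        """Illustrative stand-in for the FSDP state object PyTorch passes to the hook."""

        def __init__(self):
            self._comm_hook = lambda: 'default all-reduce hook'

        def _get_fp8_comm_hook(self):
            return lambda: 'fp8-aware all-reduce hook'


    def _swap_comm_hook(state):
        # Resolve whichever attribute name this PyTorch version uses, then swap.
        attr = '_communication_hook' if hasattr(state, '_communication_hook') else '_comm_hook'
        old_hook = getattr(state, attr)
        setattr(state, attr, state._get_fp8_comm_hook())
        return attr, old_hook


    state = _FakeState()
    attr, old_hook = _swap_comm_hook(state)
    print(getattr(state, attr)())   # fp8-aware all-reduce hook
    setattr(state, attr, old_hook)  # restore, mirroring the end of the real hook
    print(getattr(state, attr)())   # default all-reduce hook
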
26 changes: 10 additions & 16 deletions msamp/fsdp/flat_param.py
@@ -3,24 +3,18 @@
 
 """MS-AMP fsdp.flat_param module."""
 
-from typing import Optional, Sequence
-
 import torch
 import torch.nn as nn
 from torch.distributed.fsdp.flat_param import FlatParamHandle
 
 
 class FP8FlatParamHandle(FlatParamHandle):
     """A handle for a flat parameter which may have fp32 and fp8."""
-    def _init_flat_param(
-        self,
-        params: Sequence[Optional[nn.Parameter]],
-        module: nn.Module,
-        use_orig_params: bool,
-    ) -> None:
-        """Initialize the flat parameter and save fp8 related metadata."""
-        super()._init_flat_param(params, module, use_orig_params)
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
         self._init_fp8_meta()
 
     def _init_fp8_meta(self):
         """Save fp8 related metadata."""
         metas = []
         paddeds = []
         original_shapes = []
@@ -52,11 +46,11 @@ def _use_unsharded_views(self, as_params: bool) -> None:
         for i, param_info in enumerate(self.flat_param._param_infos):
             if hasattr(param_info.module, param_info.param_name):
                 param = getattr(param_info.module, param_info.param_name)
-
-                param._scaling_metas = self.flat_param._scaling_metas[i]
-                param._meta = self.flat_param._metas[i]
-                param._padded = self.flat_param._paddeds[i]
-                param._original_shape = self.flat_param._original_shapes[i]
+                if hasattr(self.flat_param, '_scaling_metas'):
+                    param._scaling_metas = self.flat_param._scaling_metas[i]
+                    param._meta = self.flat_param._metas[i]
+                    param._padded = self.flat_param._paddeds[i]
+                    param._original_shape = self.flat_param._original_shapes[i]
 
     @torch.no_grad()
     def _use_sharded_views(self) -> None:
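
For reference, a rough sketch (toy classes only, not the real `FlatParamHandle`) of the two changes in this file: FP8 metadata is now initialized from `__init__`, which keeps the subclass independent of signature changes in private helpers such as `_init_flat_param`, and the unsharded-view update copies metadata onto parameters only when the flat parameter actually carries it.

    class _ToyParam:
        """Illustrative parameter object; the real code works on nn.Parameter views."""


    class _ToyFlatParam:
        """Illustrative flat-parameter container."""


    class _ToyBaseHandle:
        def __init__(self, params):
            self.flat_param = _ToyFlatParam()
            self.flat_param._params = list(params)


    class _ToyFP8Handle(_ToyBaseHandle):
        def __init__(self, *args, **kwargs):
            # Hook __init__ instead of a private helper whose signature may change upstream.
            super().__init__(*args, **kwargs)
            self._init_fp8_meta()

        def _init_fp8_meta(self):
            self.flat_param._metas = [{'dtype': 'fp8'} for _ in self.flat_param._params]

        def _use_unsharded_views(self):
            # Copy metadata onto the parameter views only if it was actually recorded.
            if hasattr(self.flat_param, '_metas'):
                for param, meta in zip(self.flat_param._params, self.flat_param._metas):
                    param._meta = meta


    handle = _ToyFP8Handle([_ToyParam(), _ToyParam()])
    handle._use_unsharded_views()
    print(handle.flat_param._params[0]._meta)  # {'dtype': 'fp8'}
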
8 changes: 3 additions & 5 deletions msamp/fsdp/fully_sharded_data_parallel.py
@@ -31,13 +31,11 @@ def _fp8_allreduce_hook(state, grad, output):
                 from msamp.operators.dist_op import DistOp
                 dtype = Dtypes.get_dtype_from_qtype(meta.qtype)
                 DistOp.enable_fp8(meta.qtype)
-                torch.distributed.all_reduce(grad[start:end].view(dtype), group=state.process_group)
+                torch.distributed.all_reduce(grad[start:end].view(dtype), group=state.process_group if state else None)
                 DistOp.disable_fp8()
             else:
-                default_hooks.allreduce_hook(
-                    state=state,
-                    grad=grad[start:end],
-                )
+                torch.distributed.all_reduce(grad[start:end], group=state.process_group if state else None)
 
             start = self.rank * output.numel()
             end = (self.rank + 1) * output.numel()
             output.copy_(grad[start:end])
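
For reference, a minimal sketch of a gradient communication hook along these lines; the helper name is hypothetical and an initialized process group is assumed. Note that, unlike `default_hooks.allreduce_hook`, a bare `all_reduce` only sums the gradient shards without dividing by the world size, so any averaging has to happen elsewhere.

    import torch.distributed as dist


    def _allreduce_grad_shard(state, grad, start, end):
        """Hypothetical helper: all-reduce one shard of a flattened gradient in place.

        `state` may be None (e.g. when the hook is exercised outside FSDP), in which
        case the default process group is used, mirroring `group=... if state else None`.
        """
        group = state.process_group if state else None
        dist.all_reduce(grad[start:end], group=group)  # in-place sum across ranks
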
3 changes: 1 addition & 2 deletions pyproject.toml
@@ -38,8 +38,7 @@ classifiers=[
 ]
 dependencies = [
     "torch",
-    "transformer-engine@git+https://github.com/NVIDIA/[email protected]#egg=transformer-engine",
-    "flash-attn==1.0.9",
+    "transformer-engine@git+https://github.com/NVIDIA/TransformerEngine.git@stable",
     "colorlog>=6.7.0",
     "deepspeed==0.13.1",
     "mpi4py",
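
Because the dependency now tracks the moving `stable` branch rather than a pinned tag, it can be useful to confirm which Transformer Engine version actually got installed. A small standard-library check (assuming the distribution is registered under the name `transformer-engine`, as the spec above requests):

    from importlib.metadata import PackageNotFoundError, version

    try:
        print('transformer-engine:', version('transformer-engine'))
    except PackageNotFoundError:
        print('transformer-engine is not installed')
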
4 changes: 2 additions & 2 deletions tests/fsdp/test_fsdp_distributed.py
@@ -9,7 +9,7 @@
 import torch.distributed as dist
 from torch.testing._internal.common_distributed import MultiProcessTestCase, skip_if_lt_x_gpu, requires_nccl
 
-from tests.helper import decorator
+#from tests.helper import decorator
 from msamp.fsdp import FsdpReplacer, FP8FullyShardedDataParallel

@@ -37,7 +37,7 @@ def world_size(self):
 
     @requires_nccl()
     @skip_if_lt_x_gpu(2)
-    @decorator.cuda_test
+    #@decorator.cuda_test
     def test_fp8_fsdp(self):
         """Test forward and backward functionality in FP8 FSDP."""
         rank = self.rank
2 changes: 1 addition & 1 deletion tests/te/test_te_replacer.py
@@ -65,7 +65,7 @@ def _check_model(model):
 
            scaling_params = [p for p in model.parameters() if isinstance(p, ScalingParameter)]
            assert len(scaling_params) == 4
-           is_fp8_available, _ = te.fp8.is_fp8_available()
+           is_fp8_available, _ = te.fp8.check_fp8_support()
            if is_fp8_available:
                # Do a forward pass to make sure the model is working.
                fp8_format = Format.HYBRID
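
As a usage reference for the new availability check, a hedged sketch of gating an FP8 forward pass: recent Transformer Engine versions appear to return an `(available, reason)` pair from `check_fp8_support()`, and the layer and tensor shapes below are illustrative only.

    import torch
    import transformer_engine.pytorch as te
    from transformer_engine.common.recipe import DelayedScaling, Format

    # Assumption: check_fp8_support() returns an (available, reason) pair.
    is_fp8_available, reason = te.fp8.check_fp8_support()
    if not is_fp8_available:
        print(f'Skipping FP8 path: {reason}')
    else:
        layer = te.Linear(768, 768).cuda()      # illustrative layer and sizes
        x = torch.randn(16, 768, device='cuda')
        recipe = DelayedScaling(fp8_format=Format.HYBRID, amax_history_len=16, amax_compute_algo='max')
        with te.fp8_autocast(enabled=True, fp8_recipe=recipe):
            y = layer(x)
        print(y.shape)                          # torch.Size([16, 768])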
