diff --git a/msamp/fsdp/_runtime_utils.py b/msamp/fsdp/_runtime_utils.py
index 0e8de20d..919e589a 100644
--- a/msamp/fsdp/_runtime_utils.py
+++ b/msamp/fsdp/_runtime_utils.py
@@ -18,7 +18,10 @@ def _fp8_post_backward_hook(state, handle, *unused):
     if accumulate_grad and torch.count_nonzero(state._flat_param._saved_grad_shard).item() > 0:
         raise NotImplementedError('accumulate_grad is not supported yet for fp8')
 
-    old_communication_hook = state._communication_hook
-    state._communication_hook = state._get_fp8_comm_hook()
+    comm_hook_attr = '_communication_hook' if hasattr(state, '_communication_hook') else '_comm_hook'
+
+    old_communication_hook = getattr(state, comm_hook_attr)
+    setattr(state, comm_hook_attr, state._get_fp8_comm_hook())
     old_post_backward_hook(state, handle, *unused)
-    state._communication_hook = old_communication_hook
+    setattr(state, comm_hook_attr, old_communication_hook)
+
diff --git a/msamp/fsdp/flat_param.py b/msamp/fsdp/flat_param.py
index 35807f7f..393304a4 100644
--- a/msamp/fsdp/flat_param.py
+++ b/msamp/fsdp/flat_param.py
@@ -3,24 +3,18 @@
 
 """MS-AMP fsdp.flat_param module."""
 
-from typing import Optional, Sequence
-
 import torch
-import torch.nn as nn
 from torch.distributed.fsdp.flat_param import FlatParamHandle
 
 
 class FP8FlatParamHandle(FlatParamHandle):
     """A handle for a flat parameter which may have fp32 and fp8."""
-    def _init_flat_param(
-        self,
-        params: Sequence[Optional[nn.Parameter]],
-        module: nn.Module,
-        use_orig_params: bool,
-    ) -> None:
-        """Initialize the flat parameter and save fp8 related metadata."""
-        super()._init_flat_param(params, module, use_orig_params)
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self._init_fp8_meta()
 
+    def _init_fp8_meta(self):
+        """Save fp8 related metadata."""
         metas = []
         paddeds = []
         original_shapes = []
@@ -52,11 +46,11 @@ def _use_unsharded_views(self, as_params: bool) -> None:
         for i, param_info in enumerate(self.flat_param._param_infos):
             if hasattr(param_info.module, param_info.param_name):
                 param = getattr(param_info.module, param_info.param_name)
-
-                param._scaling_metas = self.flat_param._scaling_metas[i]
-                param._meta = self.flat_param._metas[i]
-                param._padded = self.flat_param._paddeds[i]
-                param._original_shape = self.flat_param._original_shapes[i]
+                if hasattr(self.flat_param, '_scaling_metas'):
+                    param._scaling_metas = self.flat_param._scaling_metas[i]
+                    param._meta = self.flat_param._metas[i]
+                    param._padded = self.flat_param._paddeds[i]
+                    param._original_shape = self.flat_param._original_shapes[i]
 
     @torch.no_grad()
     def _use_sharded_views(self) -> None:
diff --git a/msamp/fsdp/fully_sharded_data_parallel.py b/msamp/fsdp/fully_sharded_data_parallel.py
index 2fc9438a..7f47d4d8 100644
--- a/msamp/fsdp/fully_sharded_data_parallel.py
+++ b/msamp/fsdp/fully_sharded_data_parallel.py
@@ -31,13 +31,11 @@ def _fp8_allreduce_hook(state, grad, output):
                     from msamp.operators.dist_op import DistOp
                     dtype = Dtypes.get_dtype_from_qtype(meta.qtype)
                     DistOp.enable_fp8(meta.qtype)
-                    torch.distributed.all_reduce(grad[start:end].view(dtype), group=state.process_group)
+                    torch.distributed.all_reduce(grad[start:end].view(dtype), group=state.process_group if state else None)
                     DistOp.disable_fp8()
                 else:
-                    default_hooks.allreduce_hook(
-                        state=state,
-                        grad=grad[start:end],
-                    )
+                    torch.distributed.all_reduce(grad[start:end], group=state.process_group if state else None)
+
             start = self.rank * output.numel()
             end = (self.rank + 1) * output.numel()
             output.copy_(grad[start:end])
diff --git a/pyproject.toml b/pyproject.toml
index 04f1fb4d..b91a1d37 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -38,8 +38,7 @@ classifiers=[
 ]
 dependencies = [
     "torch",
-    "transformer-engine@git+https://github.com/NVIDIA/TransformerEngine.git@v0.11#egg=transformer-engine",
-    "flash-attn==1.0.9",
+    "transformer-engine@git+https://github.com/NVIDIA/TransformerEngine.git@stable",
     "colorlog>=6.7.0",
     "deepspeed==0.13.1",
     "mpi4py",
diff --git a/tests/fsdp/test_fsdp_distributed.py b/tests/fsdp/test_fsdp_distributed.py
index 9379cb75..f11098eb 100644
--- a/tests/fsdp/test_fsdp_distributed.py
+++ b/tests/fsdp/test_fsdp_distributed.py
@@ -9,7 +9,7 @@
 import torch.distributed as dist
 from torch.testing._internal.common_distributed import MultiProcessTestCase, skip_if_lt_x_gpu, requires_nccl
 
-from tests.helper import decorator
+# from tests.helper import decorator
 from msamp.fsdp import FsdpReplacer, FP8FullyShardedDataParallel
 
 
@@ -37,7 +37,7 @@ def world_size(self):
 
     @requires_nccl()
     @skip_if_lt_x_gpu(2)
-    @decorator.cuda_test
+    # @decorator.cuda_test
     def test_fp8_fsdp(self):
         """Test forward and backward functionality in FP8 FSDP."""
         rank = self.rank
diff --git a/tests/te/test_te_replacer.py b/tests/te/test_te_replacer.py
index a015fc70..f336c776 100644
--- a/tests/te/test_te_replacer.py
+++ b/tests/te/test_te_replacer.py
@@ -65,7 +65,7 @@ def _check_model(model):
             scaling_params = [p for p in model.parameters() if isinstance(p, ScalingParameter)]
             assert len(scaling_params) == 4
 
-        is_fp8_available, _ = te.fp8.is_fp8_available()
+        is_fp8_available, _ = te.fp8.check_fp8_support()
         if is_fp8_available:
            # Do a forward pass to make sure the model is working.
            fp8_format = Format.HYBRID
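
Note for reviewers (not part of the patch): the `_runtime_utils.py` hunk above relies on an attribute-fallback pattern, since newer PyTorch FSDP builds expose the communication hook as `_comm_hook` rather than `_communication_hook`. A minimal sketch of that pattern is shown below; the `swap_comm_hook` context manager is illustrative only and does not exist in MS-AMP or PyTorch, which perform the same getattr/setattr swap inline.

from contextlib import contextmanager


@contextmanager
def swap_comm_hook(state, new_hook):
    """Temporarily install `new_hook` on an FSDP state object, restoring the old hook on exit."""
    # Pick whichever private attribute this PyTorch version uses for the comm hook.
    attr = '_communication_hook' if hasattr(state, '_communication_hook') else '_comm_hook'
    old_hook = getattr(state, attr)
    setattr(state, attr, new_hook)
    try:
        yield
    finally:
        setattr(state, attr, old_hook)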