diff --git a/tests/fsdp/test_fsdp_distributed.py b/tests/fsdp/test_fsdp_distributed.py
index f11098eb..9379cb75 100644
--- a/tests/fsdp/test_fsdp_distributed.py
+++ b/tests/fsdp/test_fsdp_distributed.py
@@ -9,7 +9,7 @@
 import torch.distributed as dist
 from torch.testing._internal.common_distributed import MultiProcessTestCase, skip_if_lt_x_gpu, requires_nccl
 
-#from tests.helper import decorator
+from tests.helper import decorator
 from msamp.fsdp import FsdpReplacer, FP8FullyShardedDataParallel
 
@@ -37,7 +37,7 @@ def world_size(self):
 
     @requires_nccl()
     @skip_if_lt_x_gpu(2)
-    #@decorator.cuda_test
+    @decorator.cuda_test
     def test_fp8_fsdp(self):
        """Test forward and backward functionality in FP8 FSDP."""
        rank = self.rank