diff --git a/test/test_fsdp2/test_fsdp2_common.py b/test/test_fsdp2/test_fsdp2_common.py index 55d7681..5c7d21a 100644 --- a/test/test_fsdp2/test_fsdp2_common.py +++ b/test/test_fsdp2/test_fsdp2_common.py @@ -6,7 +6,7 @@ import torch import torch.distributed as dist import torch.nn as nn -from float8_experimental.fsdp_utils import precompute_float8_amax_for_fsdp +from float8_experimental.fsdp_utils import precompute_float8_scale_for_fsdp def check_parity_no_mp( @@ -31,7 +31,7 @@ def check_parity_no_mp( # TODO(future): add amax syncing once delayed scaling is supported optim.step() if model is fsdp_model and precompute: - precompute_float8_amax_for_fsdp(model) + precompute_float8_scale_for_fsdp(model) test_cls.assertEqual(losses[0], losses[1])