This repository has been archived by the owner on Aug 7, 2024. It is now read-only.

add unit tests for FSDP2 + torch.compile(transformer block) #321

Closed · wants to merge 6 commits
2 changes: 1 addition & 1 deletion README.md
@@ -139,7 +139,7 @@ pytest test/test_numerics_integration.py
./test/test_dtensor.sh

# run integration tests on the FSDP2 integration
-python test/test_fsdp2/test_fsdp2_eager.py
+python test/test_fsdp2/test_fsdp2.py

# run all of these tests
./test/test_everything.sh
2 changes: 0 additions & 2 deletions float8_experimental/float8_dynamic_utils.py
@@ -4,8 +4,6 @@
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.

-from typing import Any, Optional, Tuple
Contributor Author (@weifengpy):

fix linter from the trunk


import torch

from float8_experimental.float8_tensor import (
6 changes: 4 additions & 2 deletions float8_experimental/fsdp_utils.py
@@ -64,7 +64,9 @@ def precompute_float8_dynamic_scale_for_fsdp(module: nn.Module) -> None:
    scale_tensor = torch.clamp(scale_tensor, max=torch.finfo(torch.float16).max)
    scales = torch.split(scale_tensor, 1)  # Replicate
    for scale, float8_linear in zip(scales, float8_linears):
-        float8_linear.weight._local_tensor._precomputed_scale = scale._local_tensor
+        float8_linear.weight._local_tensor._precomputed_scale = (
+            scale._local_tensor.squeeze()
Contributor Author (@weifengpy), Jul 17, 2024:

Make sure the tensor is like tensor(4674.8633) instead of tensor([[4674.8633]]).

Otherwise torch.compile errors out in its guards: torch._C._dynamo.guards.assert_size_stride(new_inputs[3], (), ())
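For illustration, a small standalone sketch of the shape difference the guard cares about (the numeric value is just an example):

import torch

# torch.split(scale_tensor, 1) keeps the split dimension, so the stored scale
# stays e.g. [1, 1]-shaped unless it is squeezed down to a 0-d tensor.
scale_2d = torch.tensor([[4674.8633]])  # size (1, 1), stride (1, 1)
scale_0d = scale_2d.squeeze()           # size (),     stride ()

print(scale_2d.size(), scale_0d.size())  # torch.Size([1, 1]) torch.Size([])
# A guard like assert_size_stride(t, (), ()) only passes for the 0-d form.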

Contributor:

Can we check the traces? I want to make sure there is no CPU sync point introduced from making this tensor a scalar tensor.

Contributor:

Do we know the reasoning for why the current behavior is not supported with compile? This might not scale long term as we add other scaling granularities like rowwise or blockwise.

Contributor Author (@weifengpy):

> Can we check the traces? I want to make sure there is no CPU sync point introduced from making this tensor a scalar tensor.

Sure. I checked the trace and it seems to be executed purely on the CPU: no kernels are launched, and there is no cudaStreamSynchronize, if that is what you mean by "CPU sync point".
[Screenshot: profiler trace, 2024-07-17 2:26 PM]
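For reference, a rough sketch of how such a trace can be collected with torch.profiler (the linear layer and input below are stand-ins for the real FSDP2 setup, and a CUDA device is assumed):

import torch
from torch.profiler import ProfilerActivity, profile, record_function

# Hypothetical stand-ins for the real model and batch.
model = torch.nn.Linear(16, 16).cuda()
inp = torch.randn(4, 16, device="cuda")

with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof:
    with record_function("forward_backward"):
        model(inp).sum().backward()

# Open the exported trace in Perfetto / chrome://tracing; a CPU sync point would
# show up as cudaStreamSynchronize / cudaDeviceSynchronize blocking the CPU thread.
prof.export_chrome_trace("trace.json")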

Contributor:

I mean when the scalar is used later downstream.

Contributor Author (@weifengpy), Jul 17, 2024:

> Do we know the reasoning for why the current behavior is not supported with compile? This might not scale long term as we add other scaling granularities like rowwise or blockwise.

TL;DR: this is more of a bug in my implementation of precompute_float8_dynamic_scale_for_fsdp.

For the 1st iteration, self._precomputed_scale is None, so we calculate the scale through cast_to_float8_e4m3_dynamic (code), where the scale comes out as tensor(4674.8633). Dynamo generates a guard assertion on tensor(4674.8633).size() and tensor(4674.8633).stride(), so it expects the same input shapes in the 2nd iteration.

For the 2nd iteration, after precompute_float8_dynamic_scale_for_fsdp, we have self._precomputed_scale = tensor([[4674.8633]]) because I only called torch.split(scale_tensor, 1) without .squeeze. The guard assertion finds that .size() and .stride() changed and throws the error.

Does it make sense to say this is a bug in user code, rather than a malfunction in Dynamo?
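A minimal standalone sketch of the shape specialization described above (this is not the FSDP2 code path; in this toy form the mismatch shows up as a recompile rather than the assert_size_stride error):

import torch

@torch.compile(dynamic=False)
def scale_weight(weight, scale):
    # Dynamo guards on scale.size()/stride() the first time this runs.
    return weight * scale

weight = torch.randn(16, 16)

scale_weight(weight, torch.tensor(4674.8633))      # 1st call: guards on size ()
scale_weight(weight, torch.tensor([[4674.8633]]))  # 2nd call: size (1, 1) misses the
                                                   # shape guard and triggers a new compile

# Keeping the precomputed scale squeezed to size () makes both iterations hit
# the same guard.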

Contributor Author (@weifengpy):

> I mean when the scalar is used later downstream.

Ah, I see. I should be looking for cudaStreamSynchronize, right?

Contributor:

I think it would be like cudaDeviceSynchronize if I understand correctly (but basically you would see the CPU thread blocked).
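For illustration, the kind of downstream use that would introduce such a block is a host-side read of the scalar. A hedged sketch with made-up values (the PR itself does not do this; assumes a CUDA device):

import torch

weight = torch.randn(16, 16, device="cuda")
scale = torch.tensor(4674.8633, device="cuda")

# Stays on the GPU: the CPU thread only enqueues the multiply and moves on.
scaled = weight * scale

# Reading the value on the host blocks the CPU thread until the GPU catches up;
# this is the kind of sync that would show up in the trace.
value = scale.item()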

Contributor Author (@weifengpy), Jul 17, 2024:

> I mean when the scalar is used later downstream.

_precomputed_scale is used inside fsdp_pre_all_gather when calling the following code: https://github.com/pytorch-labs/float8_experimental/blob/main/float8_experimental/fsdp_utils.py#L167

# cast self._tensor to float8 using the already-precomputed scale
float8_tensor = Float8Tensor.to_float8(
    self._tensor,
    self._precomputed_scale,
    torch.float8_e4m3fn,
    mm_config=self._mm_config,
)

I annotated the function with record_function("Float8Tensor.to_float8"). Here are the snapshots of the CPU thread and the CUDA stream:

[Screenshots: profiler trace snapshots of the CPU thread and the CUDA stream, 2024-07-17]

In both cases I do not see cudaStreamSynchronize, and the CUDA stream stays ahead of the CPU thread.

Any worries?

Contributor:

Sounds good! Should be fine.

+        )


# FSDP pads its local tensor on dim-0. The subclass should be preserved such
@@ -301,7 +303,7 @@ def __tensor_flatten__(self):
            ],
            {
                "mm_config": self._mm_config,
-                "is_amax_initialized": is_amax_initialized,
Contributor Author (@weifengpy):

pre-commit run --all-files complains about an undefined is_amax_initialized in trunk; fixing it so I can commit without bypassing the linter.

"is_amax_initialized": self.is_amax_initialized,
},
)

2 changes: 1 addition & 1 deletion test/test_everything.sh
@@ -15,7 +15,7 @@ then
./test/test_fsdp.sh
./test/test_fsdp_compile.sh
./test/test_dtensor.sh
-pytest test/test_fsdp2/test_fsdp2_eager.py
+pytest test/test_fsdp2/test_fsdp2.py
fi

echo "all tests successful"
@@ -89,6 +89,7 @@ def test_transformer_parity(self):
                    TensorScalingType.DYNAMIC,
                    TensorScalingType.DELAYED,
                ],
+                "compile_transformer_block": [False, True],
            },
            self._test_transformer_parity,
        )
@@ -98,6 +99,7 @@ def _test_transformer_parity(
        enable_fsdp_fp8_all_gather: bool,
        precompute: bool,
        scaling_type_w: TensorScalingType,
+        compile_transformer_block: bool,
    ):
        if not enable_fsdp_fp8_all_gather and precompute:
            return
@@ -112,11 +114,17 @@ def _test_transformer_parity(
        module = self.init_transformer(weight_tying=weight_tying).cuda()
        ref_module = copy.deepcopy(module)
        swap_linear_with_float8_linear(ref_module, scaling_type_w=scaling_type_w)
+        if compile_transformer_block:
+            for layer_id, transformer_block in ref_module.layers.named_children():
+                transformer_block = torch.compile(transformer_block, dynamic=False)
+                ref_module.layers.register_module(layer_id, transformer_block)
        with set_enable_fsdp_fp8_all_gather(enable_fsdp_fp8_all_gather):
            swap_linear_with_float8_linear(module, scaling_type_w=scaling_type_w)
-        for submodule in module.modules():
-            if isinstance(submodule, TransformerBlock):
-                fully_shard(submodule)
+        for layer_id, transformer_block in module.layers.named_children():
+            if compile_transformer_block:
+                transformer_block = torch.compile(transformer_block, dynamic=False)
Contributor:

Is compiling the transformer block instead of the entire model related to this issue, or are we just trying to match torchtitan behavior?

Optionally, if possible, it would be good to compile the whole model here instead, as long as that can catch the issues relevant to us and keep the more advanced "how to apply compile" logic localized to torchtitan.

Contributor Author (@weifengpy), Jul 18, 2024:

> Is compiling the transformer block instead of the entire model related to this issue, or are we just trying to match torchtitan behavior?

This is just trying to match torchtitan's behavior. The .squeeze() is needed regardless of whether we compile transformer blocks or the whole model.

> Optionally, if possible, it would be good to compile the whole model here instead, as long as that can catch the issues relevant to us.

I want to check at PR time that float8_experimental is compatible with torchtitan (thus compiling the transformer block).

For float8_experimental, I agree it is good to also cover compiling the full model. For FSDP2 it should work. For FSDP + TP, I remember there is some problem with compiling the full model. I will see if I can follow up.
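For context, a toy sketch (hypothetical TinyModel, no FSDP or float8 involved) contrasting the two approaches discussed here:

import torch
import torch.nn as nn


class Block(nn.Module):
    def __init__(self, dim: int):
        super().__init__()
        self.linear = nn.Linear(dim, dim)

    def forward(self, x):
        return torch.relu(self.linear(x))


class TinyModel(nn.Module):
    def __init__(self, dim: int = 16, n_layers: int = 2):
        super().__init__()
        self.layers = nn.ModuleDict({str(i): Block(dim) for i in range(n_layers)})

    def forward(self, x):
        for layer in self.layers.values():
            x = layer(x)
        return x


model = TinyModel()

# Per-block compile, as in the test (and torchtitan): each block becomes its own
# compiled module and is registered back in place of the eager one.
for layer_id, block in model.layers.named_children():
    model.layers.register_module(layer_id, torch.compile(block, dynamic=False))

# Whole-model alternative suggested above: a single compile call over the model.
# whole_model = torch.compile(TinyModel(), dynamic=False)

out = model(torch.randn(4, 16))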

+            fully_shard(transformer_block)
+            module.layers.register_module(layer_id, transformer_block)
        fully_shard(module)
        ref_optim = torch.optim.Adam(ref_module.parameters(), lr=1e-2)
        optim = torch.optim.Adam(module.parameters(), lr=1e-2, foreach=True)
@@ -132,6 +140,7 @@ def _test_transformer_parity(
            local_inp,
            precompute,
            scaling_type_w=scaling_type_w,
+            compile_transformer_block=compile_transformer_block,
        )

@skip_if_lt_x_gpu(2)
8 changes: 6 additions & 2 deletions test/test_fsdp2/test_fsdp2_common.py
@@ -6,7 +6,7 @@
import torch
import torch.distributed as dist
import torch.nn as nn
-from float8_experimental.float8_linear import Float8Linear, TensorScalingType
+from float8_experimental.float8_linear import TensorScalingType
from float8_experimental.float8_linear_utils import (
    linear_requires_sync,
    sync_float8_amax_and_scale_history,
@@ -23,6 +23,7 @@ def check_parity_no_mp(
    local_inp: torch.Tensor,
    precompute: bool = False,
    scaling_type_w: TensorScalingType = TensorScalingType.DYNAMIC,
+    compile_transformer_block: bool = False,
):
    for iter_idx in range(10):
        losses: List[torch.Tensor] = []
@@ -46,7 +47,10 @@
        ):
            precompute_float8_dynamic_scale_for_fsdp(model)

-        test_cls.assertEqual(losses[0], losses[1])
+        if compile_transformer_block:
+            test_cls.assertEqual(losses[0], losses[1], atol=1e-4, rtol=1e-4)
+        else:
+            test_cls.assertEqual(losses[0], losses[1])


def check_parity_bf16_mp(