Skip to content
This repository has been archived by the owner on Aug 7, 2024. It is now read-only.

add unit tests for FSDP2 + torch.compile(transformer block) #321

Closed
wants to merge 6 commits into from

Conversation

weifengpy
Copy link
Contributor

@weifengpy weifengpy commented Jul 17, 2024

TorchTitan complains about FSDP2 + float8 + torch.compile(transformer block).

there is a mismatch in float8 scale so dynamo guards assersion failed torch._C._dynamo.guards.assert_size_stride(new_inputs[3], (), ())

  • in 1st iteration, we calculate float8 scale through cast_to_float8_e4m3_dynamic (code). scale is a scalar tensor, eg tensor(4674.8633)
  • in 2nd iteration, we calulate float8 scale through precompute_float8_dynamic_scale, but scale is NOT a scalar tensor, eg tensor([[4674.8633]]
  • this PR calls .squeeze to make sure scales are always scalar tensors, and dynamo guards assersion always hold true

added unit test so we can catch the isssue at PR time

TODO: add fp8 + torch.compile to CI in torchtitan

weifengpy and others added 3 commits July 17, 2024 15:11
Summary:

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:
Summary:

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:
@facebook-github-bot facebook-github-bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Jul 17, 2024
@@ -301,7 +303,7 @@ def __tensor_flatten__(self):
],
{
"mm_config": self._mm_config,
"is_amax_initialized": is_amax_initialized,
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

pre-commit run --all-files complains about undefined is_amax_initialized in trunk. fixing it so I can commit without bypassing linter

@@ -46,7 +47,10 @@ def check_parity_no_mp(
):
precompute_float8_dynamic_scale_for_fsdp(model)

test_cls.assertEqual(losses[0], losses[1])
if compile_transformer_block:
torch.testing.assert_close(losses[0], losses[1], atol=9.5e-2, rtol=9.5e-2)
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This seems kind of high 🤔 I wonder how this value was determined. Can we instead compare the ref as also compiling each transformer block (but without FSDP applied)?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

will try to switch the ref model to Float8Linear + torch.compiled

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

after applying torch.compile to ref_model, we can achieve atol/rtol=1e-4. I can dig more as follow ups if we want to reach higher numeric parity like 1e-5

@@ -64,7 +64,9 @@ def precompute_float8_dynamic_scale_for_fsdp(module: nn.Module) -> None:
scale_tensor = torch.clamp(scale_tensor, max=torch.finfo(torch.float16).max)
scales = torch.split(scale_tensor, 1) # Replicate
for scale, float8_linear in zip(scales, float8_linears):
float8_linear.weight._local_tensor._precomputed_scale = scale._local_tensor
float8_linear.weight._local_tensor._precomputed_scale = (
scale._local_tensor.squeeze()
Copy link
Contributor Author

@weifengpy weifengpy Jul 17, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

make sure tensor is like tensor(4674.8633) instead of tensor([[4674.8633]]

otherwise torch.compile errors out in gurads, torch._C._dynamo.guards.assert_size_stride(new_inputs[3], (), ())

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we check the traces? I want to make sure there is no CPU sync point introduced from making this tensor a scalar tensor.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

do we know the reasoning for why the current behavior is not supported with compile? This might not scale long term as we add other scaling granularities like rowwise or blockwise.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we check the traces? I want to make sure there is no CPU sync point introduced from making this tensor a scalar tensor

sure. I checked the trace and it's seems to be purely executed on cpu, no kernels launch, and no cudaStreamSynchronize if that's what you refer as "CPU sync point"
Screenshot 2024-07-17 at 2 26 24 PM

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I mean when the scalar is used later downstream.

Copy link
Contributor Author

@weifengpy weifengpy Jul 17, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

do we know the reasoning for why the current behavior is not supported with compile? This might not scale long term as we add other scaling granularities like rowwise or blockwise.

TL;DR: this is more like a bug when I implement precompute_float8_dynamic_scale_for_fsdp

for the 1st iteration, self._precomputed_scale is None and thus we calcuclate scale through cast_to_float8_e4m3_dynamic (code) , where scale are in tensor(4674.8633). Dynamo generates a guard assersion on tensor(4674.8633).size() and tensor(4674.8633).stride(), so it expect same input shapes in 2nd iteration

for the 2nd iteration after precompute_float8_dynamic_scale_for_fsdp, we have self._precomputed_scale=tensor([[4674.8633]]) because I only called torch.split(scale_tensor, 1) without .squeeze. Guard assersion find out .size() and .stride() changed and throw out the error

does it make sense to say this is a bug in user code, instead of a misfunction in dynamo ?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I mean when the scalar is used later downstream.

ah, I see. I should be looking for cudaStreamSynchronize, right ?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it would be like cudaDeviceSynchronize if I understand correctly (but basically you would see the CPU thread blocked).

Copy link
Contributor Author

@weifengpy weifengpy Jul 17, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I mean when the scalar is used later downstream.

_precomputed_scale will be used inside fsdp_pre_all_gather when calling following code https://github.com/pytorch-labs/float8_experimental/blob/main/float8_experimental/fsdp_utils.py#L167

float8_tensor = Float8Tensor.to_float8(
    self._tensor,
    self._precomputed_scale,
    torch.float8_e4m3fn,
    mm_config=self._mm_config,
)

I annotated the function with record_function("Float8Tensor.to_float8"). Here are the snapshots for cpu thread and cuda stream

Screenshot 2024-07-17 at 4 33 53 PM Screenshot 2024-07-17 at 4 34 31 PM

in both cases, I do not see cudaStreamSynchronize and cuda stream stays ahead of cpu thread

any worries ?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

sounds good! should be fine

Summary:

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:
@@ -4,8 +4,6 @@
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.

from typing import Any, Optional, Tuple
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

fix linter from the trunk

Summary:

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:
Summary:

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:
@facebook-github-bot
Copy link
Contributor

@weifengpy has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.

fully_shard(submodule)
for layer_id, transformer_block in module.layers.named_children():
if compile_transformer_block:
transformer_block = torch.compile(transformer_block, dynamic=False)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is compiling the transformer block instead of the entire model related to this issue, or are we just trying to match torchtitan behavior?

optionally, if possible, would be good to compile the whole model here instead as long as that can catch the issues relevant to us and keep the more advanced "how to apply compile" logic localized to torchtitan.

Copy link
Contributor Author

@weifengpy weifengpy Jul 18, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is compiling the transformer block instead of the entire model related to this issue, or are we just trying to match torchtitan behavior?

This is just trying to match torchtitan's behavior. The .squeeze is needed regardless of compiling transformer blocks or compiling whole model.

optionally, if possible, would be good to compile the whole model here instead as long as that can catch the issues relevant to us

I want to check at PR time that float8_experimental are compatiable with torchtitan (thus compiling transformer block)

for float8_experimental, I am with you it's good to also cover compiling full model.
For FSDP2, it should work. For FSDP+TP, I remember there is some problem in to compile full model. Will see if I can follow up

Copy link
Contributor

@vkuzo vkuzo left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

looks great, thanks for fixing this!

Copy link
Contributor

@bdhirsh bdhirsh left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nice catch!

@facebook-github-bot
Copy link
Contributor

@weifengpy merged this pull request in 7f0d6bb.

weifengpy added a commit to pytorch/torchtitan that referenced this pull request Jul 19, 2024
fixed my bug in float8_experimental. now we can torch.compile
transfromer blocks with FSDP float8 all-gather
pytorch-labs/float8_experimental#321

local test: `CONFIG_FILE="./train_configs/debug_model.toml"
./run_llama_train.sh --training.enable_float8_linear
--training.enable_fsdp_float8_all_gather
--training.precompute_float8_dynamic_scale_for_fsdp --training.compile`

profiler traces: I can see compiled region in cpu thread and float8
malmul `sm90_xmma_gemm_e4m3bf16...` in cuda stream
<img width="1468" alt="Screenshot 2024-07-18 at 4 22 17 PM"
src="https://github.com/user-attachments/assets/0cf58dee-aae1-4582-a3f1-b8aa48b45129">
Sign up for free to subscribe to this conversation on GitHub. Already have an account? Sign in.
Labels
CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. Merged
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants