add unit tests for FSDP2 + torch.compile(transformer block) #321
Changes from all commits: b5cad8d, a6b8913, 272e85b, 097ceed, b6ebf8d, 2eaa51b
@@ -64,7 +64,9 @@ def precompute_float8_dynamic_scale_for_fsdp(module: nn.Module) -> None:
     scale_tensor = torch.clamp(scale_tensor, max=torch.finfo(torch.float16).max)
     scales = torch.split(scale_tensor, 1)  # Replicate
     for scale, float8_linear in zip(scales, float8_linears):
-        float8_linear.weight._local_tensor._precomputed_scale = scale._local_tensor
+        float8_linear.weight._local_tensor._precomputed_scale = (
+            scale._local_tensor.squeeze()
+        )

     # FSDP pads its local tensor on dim-0. The subclass should be preserved such

Review thread on the squeeze() change:

- Make sure the tensor is like […]; otherwise torch.compile errors out in guards.
- Can we check the traces? I want to make sure there is no CPU sync point introduced from making this tensor a scalar tensor.
- Do we know the reasoning for why the current behavior is not supported with compile? This might not scale long term as we add other scaling granularities like rowwise or blockwise.
- I mean when the scalar is used later downstream.
- TL;DR: this is more like a bug when I implemented […]: […] for the 1st iteration, […] for the 2nd iteration and after. Does it make sense to say this is a bug in user code, instead of a malfunction in dynamo?
- Ah, I see. I should be looking for […].
- I think it would be like […].
- I annotated the function with […] in both cases; I do not see any worries?
- Sounds good! Should be fine.
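As background for the thread above, here is a minimal standalone sketch (my own illustration, not code from this PR) of the shape semantics behind the `squeeze()` change: `torch.split` on a 1-D tensor produces shards of shape `[1]`, and `squeeze()` turns such a shard into a 0-dim scalar tensor, which keeps the precomputed scale's rank consistent across iterations so compile guards do not trip; `squeeze()` returns a view, so by itself it should not introduce a CPU sync.

```python
import torch

# Illustration only: mirrors the shapes involved in the change above.
scale_tensor = torch.tensor([1.0, 2.0, 3.0])   # one stacked scale per linear
scales = torch.split(scale_tensor, 1)          # each shard has shape [1]

shard = scales[0]
precomputed = shard.squeeze()                  # 0-dim scalar tensor, shape ()

assert shard.shape == (1,)
assert precomputed.shape == ()

# squeeze() is a view op: same storage, no .item() / device-to-host copy,
# so no CPU sync point is introduced by the squeeze itself.
assert precomputed.data_ptr() == shard.data_ptr()
```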
@@ -301,7 +303,7 @@ def __tensor_flatten__(self):
         ],
         {
             "mm_config": self._mm_config,
-            "is_amax_initialized": is_amax_initialized,
+            "is_amax_initialized": self.is_amax_initialized,
         },
     )
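For context on why the metadata has to be read off `self`, here is a hedged sketch of the traceable tensor-subclass flatten/unflatten contract that this hunk touches. `ScaledTensor` below is a hypothetical minimal example, not this repo's Float8 tensor, and a real subclass would also need `__torch_dispatch__`:

```python
import torch

class ScaledTensor(torch.Tensor):
    """Hypothetical minimal wrapper subclass, shown only to illustrate the contract."""

    @staticmethod
    def __new__(cls, data: torch.Tensor, is_amax_initialized: bool):
        return torch.Tensor._make_wrapper_subclass(
            cls, data.shape, dtype=data.dtype, device=data.device
        )

    def __init__(self, data: torch.Tensor, is_amax_initialized: bool):
        self._data = data
        self.is_amax_initialized = is_amax_initialized

    def __tensor_flatten__(self):
        # Inner-tensor attribute names + plain-Python metadata; the metadata
        # must be read from instance attributes (self.*), as in the hunk above.
        return ["_data"], {"is_amax_initialized": self.is_amax_initialized}

    @staticmethod
    def __tensor_unflatten__(inner_tensors, metadata, outer_size, outer_stride):
        return ScaledTensor(inner_tensors["_data"], metadata["is_amax_initialized"])
```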
@@ -89,6 +89,7 @@ def test_transformer_parity(self):
                 TensorScalingType.DYNAMIC,
                 TensorScalingType.DELAYED,
             ],
+            "compile_transformer_block": [False, True],
         },
         self._test_transformer_parity,
     )
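The dict of lists above is expanded by the test harness into a cross product of subtest configurations, each passed to `_test_transformer_parity`. A rough standalone sketch of that expansion (using `itertools.product`, not necessarily the actual helper this test suite uses; the keys mirror the test's parameters, with plain strings standing in for `TensorScalingType`):

```python
from itertools import product

# Hypothetical illustration of how a {param: [values]} config dict expands
# into individual subtest keyword-argument combinations.
subtest_config = {
    "enable_fsdp_fp8_all_gather": [False, True],
    "precompute": [False, True],
    "scaling_type_w": ["dynamic", "delayed"],
    "compile_transformer_block": [False, True],
}

keys = list(subtest_config)
for values in product(*(subtest_config[k] for k in keys)):
    kwargs = dict(zip(keys, values))
    # each kwargs dict would be passed as _test_transformer_parity(**kwargs)
    print(kwargs)
```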
@@ -98,6 +99,7 @@ def _test_transformer_parity(
         enable_fsdp_fp8_all_gather: bool,
         precompute: bool,
         scaling_type_w: TensorScalingType,
+        compile_transformer_block: bool,
     ):
         if not enable_fsdp_fp8_all_gather and precompute:
             return
@@ -112,11 +114,17 @@ def _test_transformer_parity(
         module = self.init_transformer(weight_tying=weight_tying).cuda()
         ref_module = copy.deepcopy(module)
         swap_linear_with_float8_linear(ref_module, scaling_type_w=scaling_type_w)
+        if compile_transformer_block:
+            for layer_id, transformer_block in ref_module.layers.named_children():
+                transformer_block = torch.compile(transformer_block, dynamic=False)
+                ref_module.layers.register_module(layer_id, transformer_block)
         with set_enable_fsdp_fp8_all_gather(enable_fsdp_fp8_all_gather):
             swap_linear_with_float8_linear(module, scaling_type_w=scaling_type_w)
-        for submodule in module.modules():
-            if isinstance(submodule, TransformerBlock):
-                fully_shard(submodule)
+        for layer_id, transformer_block in module.layers.named_children():
+            if compile_transformer_block:
+                transformer_block = torch.compile(transformer_block, dynamic=False)
+            fully_shard(transformer_block)
+            module.layers.register_module(layer_id, transformer_block)
         fully_shard(module)
         ref_optim = torch.optim.Adam(ref_module.parameters(), lr=1e-2)
         optim = torch.optim.Adam(module.parameters(), lr=1e-2, foreach=True)

Review thread on compiling the transformer block:

- Is compiling the transformer block instead of the entire model related to this issue, or are we just trying to match torchtitan behavior? Optionally, if possible, it would be good to compile the whole model here instead, as long as that can catch the issues relevant to us and keep the more advanced "how to apply compile" logic localized to torchtitan.
- This is just trying to match torchtitan's behavior. I want to check at PR time that float8_experimental is compatible with torchtitan (thus compiling the transformer block). For float8_experimental, I am with you that it is good to also cover compiling the full model.
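To make the trade-off in that thread concrete, here is a small self-contained sketch of the two ways to apply compile. `ToyTransformer` and `Block` are hypothetical stand-ins for illustration only, not the test's actual model:

```python
import torch
import torch.nn as nn

class Block(nn.Module):
    def __init__(self, dim: int = 16):
        super().__init__()
        self.lin = nn.Linear(dim, dim)

    def forward(self, x):
        return torch.relu(self.lin(x))

class ToyTransformer(nn.Module):  # hypothetical stand-in for the test's transformer
    def __init__(self, n_layers: int = 2, dim: int = 16):
        super().__init__()
        self.layers = nn.ModuleDict({str(i): Block(dim) for i in range(n_layers)})

    def forward(self, x):
        for layer in self.layers.values():
            x = layer(x)
        return x

model = ToyTransformer()

# Option A (what the test does, matching torchtitan): compile each transformer
# block and register the compiled module back under the same name.
for layer_id, block in model.layers.named_children():
    model.layers.register_module(layer_id, torch.compile(block, dynamic=False))

# Option B (the reviewer's suggestion): compile the whole model in one call.
whole_model = torch.compile(ToyTransformer(), dynamic=False)

x = torch.randn(4, 16)
_ = model(x)
_ = whole_model(x)
```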
@@ -132,6 +140,7 @@ def _test_transformer_parity(
             local_inp,
             precompute,
             scaling_type_w=scaling_type_w,
+            compile_transformer_block=compile_transformer_block,
         )

     @skip_if_lt_x_gpu(2)
Review comment: "fix linter from the trunk"