[FSDP2] precompute scale after optimizer.step for dynamic scaling #266
Conversation
```
            requires_grad=tensor.requires_grad,
        )

    def __init__(self, tensor: torch.Tensor, mm_config: ScaledMMConfig):
        self._tensor = tensor
        self._mm_config = mm_config
        # Optional cache for pre-computed fp8 data/scale
        self._fp8_data: Optional[torch.Tensor] = None
```
One major requirement for tensor subclasses that I don't think is respected here: `__tensor_flatten__` and `__tensor_unflatten__` must properly convey every inner tensor on the subclass. So when we call `__tensor_flatten__` on this subclass, if either of `_fp8_data`/`scale`/`amax` are set to valid tensors, they need to be returned there (and similarly `__tensor_unflatten__` needs to handle them as extra args).
Thanks for pointing this out! This saves me a lot of debugging time. I can give it a try by including `_fp8_data`/`scale`/`amax` in `__tensor_flatten__` and `__tensor_unflatten__`.
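For reference, a minimal, self-contained sketch of the idea (not the PR's actual class; it uses a single optional cached tensor named `_precomputed_scale` for brevity) showing how an optional inner tensor can be conveyed through flatten/unflatten:

```python
from typing import Optional

import torch


class CachedScaleTensor(torch.Tensor):
    """Illustrative wrapper subclass with one optional cached inner tensor."""

    @staticmethod
    def __new__(cls, tensor: torch.Tensor, precomputed_scale: Optional[torch.Tensor] = None):
        return torch.Tensor._make_wrapper_subclass(
            cls,
            tensor.shape,
            dtype=tensor.dtype,
            device=tensor.device,
            requires_grad=tensor.requires_grad,
        )

    def __init__(self, tensor: torch.Tensor, precomputed_scale: Optional[torch.Tensor] = None):
        self._tensor = tensor
        self._precomputed_scale = precomputed_scale

    def __tensor_flatten__(self):
        # every inner tensor that is currently set must be listed here
        if self._precomputed_scale is not None:
            return ["_tensor", "_precomputed_scale"], None
        return ["_tensor"], None

    @staticmethod
    def __tensor_unflatten__(inner_tensors, flatten_spec, outer_size, outer_stride):
        # ...and accepted back here as extra inner tensors
        return CachedScaleTensor(
            inner_tensors["_tensor"],
            inner_tensors.get("_precomputed_scale"),
        )
```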
`torch.compile` works after patching pytorch/pytorch#127431. Will compare traces in the 2nd PR.
```
    def compute_amaxes(weights: List[DTensor]):
        max_weights = torch._foreach_norm(weights, ord=math.inf)
        amax_tensor = torch.vstack(max_weights)
        amax_tensor = torch.clamp(amax_tensor, EPS)  # R
```
`torch.clamp` calls all_reduce. I avoided calling it again in `amax_to_scale(clamp_amax=False)`.
So you are relying on `torch.clamp` to run the all-reduce implicitly from changing sharding from partial to replicate? If this fragments the code, could we just all-reduce the amax tensor and then leave the clamp to `amax_to_scale`? I agree the current way is faster since we are doing one clamp for all amaxes, but in case float8 folks are not happy with this fragmentation, this seems like another way.
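For concreteness, a hedged sketch of that alternative (assumes DTensor weights, an already-initialized device mesh, and that `max_weights` holds the per-weight amaxes; names are illustrative, not the PR's final code):

```python
import torch
from torch.distributed._tensor import Replicate


def compute_amaxes_with_explicit_all_reduce(max_weights):
    # max_weights: list of per-weight amax DTensors (Partial placement)
    amax_tensor = torch.vstack(max_weights)
    # one explicit all-reduce: Partial -> Replicate
    amax_tensor = amax_tensor.redistribute(placements=[Replicate()])
    # clamping is then left to amax_to_scale (with its default clamp behavior)
    return amax_tensor
```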
thanks for the suggestions. I can collect feedback from float8 folks if they have a preference
can we just comment with what is going on? I think it's fine as long as the code is easy to understand and there is no magic.
agreed
```
@@ -190,9 +191,20 @@ def __repr__(self):
        return f"WeightWithDynamicFloat8CastTensor(tensor={self._tensor}, mm_config={self._mm_config})"

    def fsdp_pre_all_gather(self, mesh):
```
If `_pre_computed_amax` is set, we skip `tensor_to_amax` and directly do `amax_to_scale`.
This seems reasonable to me! I want to check with float8 folks on the `amax_to_scale` change.
Nice! Can we include the intended user API in the PR summary?
```
@@ -151,6 +151,7 @@ def __new__(cls, tensor: torch.Tensor, mm_config: ScaledMMConfig):
    def __init__(self, tensor: torch.Tensor, mm_config: ScaledMMConfig):
        self._tensor = tensor
        self._mm_config = mm_config
        self._pre_computed_amax = None
```
Does this need to be added to `__tensor_flatten__`? Can we add some comments on intended usage of this?
+1 on adding to flatten/unflatten and comments/intended usage
done
```
@@ -322,3 +328,34 @@ def inner_func():
    for child in fp8_layers:
        # Set a flag to signal amaxes/scales are ready
        child.amax_and_scale_synced = True


def precompute_float8_amax(module: nn.Module) -> None:
```
Can we put this in `distributed_utils.py`? I think the function name should include that this is intended for FSDP2 with float8 all-gather.
Moving to `fsdp_utils.py` according to PR #310.
Indicating FSDP by renaming to `precompute_float8_amax_for_fsdp`.
```
    weights: List[DTensor] = [float8_linear.weight for float8_linear in float8_linears]

    def compute_amaxes(weights: List[DTensor]):
        max_weights = torch._foreach_norm(weights, ord=math.inf)
```
Maybe add a comment that this is equivalent to `max(abs(w))`?
done
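A quick check of that equivalence (hedged: plain local tensors here rather than the DTensor weights used in the PR):

```python
import math

import torch

weights = [torch.randn(4, 4), torch.randn(8, 2)]
# inf-norm of each tensor is its max absolute value, i.e. the amax
amaxes = torch._foreach_norm(weights, ord=math.inf)
expected = [w.abs().max() for w in weights]
assert all(torch.equal(a, e) for a, e in zip(amaxes, expected))
```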
float8_experimental/float8_utils.py (outdated)
""" | ||
scale = torch.empty_like(amax, dtype=torch.float32) | ||
if float8_dtype in FP8_TYPES: | ||
res = torch.finfo(float8_dtype).max / torch.clamp(amax, min=EPS) | ||
if clamp_amax: |
nit: I think if you have this on a separate line, `amax = clamp(amax, eps) if clamp_amax else amax`, it makes the logic a little easier to follow.
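A hedged sketch of `amax_to_scale` with that nit applied (constants and the float16 handling from the real `float8_utils.py` are simplified or omitted here):

```python
import torch

EPS = 1e-12  # assumed; mirrors float8_experimental.float8_utils
FP8_TYPES = {torch.float8_e4m3fn, torch.float8_e5m2}


def amax_to_scale(amax, float8_dtype, orig_dtype, clamp_amax=True):
    # resolve the optional clamp on its own line, as suggested above
    amax = torch.clamp(amax, min=EPS) if clamp_amax else amax
    scale = torch.empty_like(amax, dtype=torch.float32)
    if float8_dtype in FP8_TYPES:
        res = torch.finfo(float8_dtype).max / amax
    else:
        raise ValueError(f"Unsupported float8_dtype: {float8_dtype}")
    scale.copy_(res)
    return scale
```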
@weifengpy has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.
```
@@ -130,20 +146,38 @@ def unwrap(t):
     )

     def __tensor_flatten__(self):
-        return ["_tensor"], self._mm_config
+        if self._precomputed_amax:
+            return ["_tensor", "_precomputed_amax"], self._mm_config
```
Does having `Optional[torch.Tensor]` as a subclass field work with torch.compile? Or do we not care about torch.compile in this code path?
torch.compile assumes every tensor from `__tensor_flatten__` is not None. I added if-else to make torch.compile work. I verified it in pytorch/pytorch#129457.
float8_experimental/float8_utils.py (outdated)
```
):
    """Converts the amax value of a tensor to the fp8 scale.
    Args:
        amax: The amax value of the tensor.
        float8_dtype: The float8 dtype.
        orig_dtype: The original dtype of the tensor.
        clamp_amax: default is True. False for FSDP fp8 all-gather since FSDP applied `torch.clamp` during pre-compute after optimizer.step
```
this is a bit confusing. How about precomputing the scale instead so we don't have to have gotchas like this?
Good suggestion! I changed the API to precompute scale, and it shows another 9% speedup in the unit test vs. precomputing amax. `fsdp_pre_all_gather` is also greatly simplified because of using `self._precomputed_scale`.
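A condensed sketch of what the precomputed-scale path ends up doing after `optimizer.step` (hedged: names follow the PR discussion, but the exact implementation in `fsdp_utils.py` may differ, e.g. in dtype handling and the delayed-scaling check):

```python
import math
from typing import List

import torch
import torch.nn as nn
from torch.distributed._tensor import DTensor

from float8_experimental.float8_linear import Float8Linear
from float8_experimental.float8_utils import EPS


@torch.no_grad()
def precompute_float8_dynamic_scale_for_fsdp(module: nn.Module) -> None:
    weights: List[DTensor] = [
        m.weight for m in module.modules() if isinstance(m, Float8Linear)
    ]
    if not weights:
        return
    # one local max(abs(w)) per weight, stacked into a single Partial DTensor
    amax_tensor = torch.vstack(torch._foreach_norm(weights, ord=math.inf))
    # clamp is dispatched through DTensor and issues a single all-reduce (Partial -> Replicate)
    amax_tensor = torch.clamp(amax_tensor, EPS)
    scale_tensor = torch.finfo(torch.float8_e4m3fn).max / amax_tensor
    for scale, weight in zip(torch.split(scale_tensor, 1), weights):
        weight._local_tensor._precomputed_scale = scale._local_tensor
```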
@weifengpy has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.
README.md (outdated)
```
y.sum().backward()
optimizer.step()

# specific to fsdp2 + float8 with dynamic scaling
```
should we say that this is specific to FSDP2 with float8 all-gather turned on? Also, maybe we can show how to turn that on, since I don't think it's documented in the README yet? Can be a follow-up PR.
> should we say that this is specific to FSDP2 with float8 all-gather turned on?

Will change it in this PR.

> maybe we can show how to turn that on, since I don't think it's documented in the README yet

Good catch. Will polish the README again after landing changes in torchtitan to turn on/off fp8 all-gather.
```
        for m in module.modules()
    ):
        raise NotImplementedError("Only supports delayed scaling")
    float8_linears: List[Float8Linear] = [
```
Is this expensive for real models? If yes, maybe we can offer an option to precompute this?
My intuition is that this should be pretty fast, as the number of `nn.Module`s in the model is usually at most in the thousands and this is pure Python overhead. @weifengpy you can check the traces you have to see if there are any noticeable gaps from this.
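To gauge that overhead, a rough, illustrative micro-benchmark of a pure-Python walk over a model with about a thousand modules (the model and numbers here are made up, not from the PR):

```python
import time

import torch.nn as nn

model = nn.Sequential(*[nn.Linear(8, 8) for _ in range(1000)])

start = time.perf_counter()
linears = [m for m in model.modules() if isinstance(m, nn.Linear)]
elapsed_ms = (time.perf_counter() - start) * 1e3
print(f"filtered {len(linears)} modules in {elapsed_ms:.2f} ms")
```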
float8_experimental/fsdp_utils.py (outdated)
```
    ]
    weights: List[DTensor] = [float8_linear.weight for float8_linear in float8_linears]

    def compute_scales(weights: List[DTensor]):
```
optional nit: maybe move outside to prevent nested functions?
Curious, what is the downside of nested functions?

@weifengpy By the way, this was originally a nested function just so that we could try to `torch.compile` it effectively in the `scales = compute_scales(weights)` line. Does it still need to be a separate function for `torch.compile` reasons? If so, we should probably add a comment before the `def compute_scales` mentioning that it is separate for `torch.compile`; otherwise, we can consider inlining the function.
I will remove the nested function to make the code easy to read. I profiled the unit test and `precompute_float8_scale_for_fsdp` takes 1.9ms, which is a tiny portion of the overall training loop, so there is no obvious reason to speed it up with `torch.compile` yet. I can bring back the nested function in case we need `torch.compile` again.
looks great! had some final comments. thanks for doing this!
float8_experimental/fsdp_utils.py (outdated)
```
from float8_experimental.float8_utils import EPS


def precompute_float8_scale_for_fsdp(module: nn.Module) -> None:
```
Should we add a `@torch.no_grad()` decorator on this?
Good catch, adding `@torch.no_grad()`.
float8_experimental/fsdp_utils.py (outdated)
```
    Calculate scale for all float8 parameters after optimizer step
    It performs a single all-reduce instead of many all-reduces for each parameter
```
Suggestion:

```
-    Calculate scale for all float8 parameters after optimizer step
-    It performs a single all-reduce instead of many all-reduces for each parameter
+    Calculate scale dynamically for all float8 parameters.
+    This should be run after the optimizer step. It performs a single all-reduce to compute the
+    amaxes for all float8 weights.
```
float8_experimental/fsdp_utils.py (outdated)
""" | ||
Calculate scale for all float8 parameters after optimizer step | ||
It performs a single all-reduce instead of many all-reduces for each parameter | ||
Exmaple usage: |
nit (typo):

```
-    Exmaple usage:
+    Example usage:
```

@vkuzo I assume that there are no docs builds for `float8_experimental`, so this example is for users who will read the code itself? Otherwise, we might need to check the formatting -- I recall the format for examples being a bit different.
float8_experimental/fsdp_utils.py (outdated)
```
            float8_linear.weight._local_tensor._precomputed_scale = scale._local_tensor
    else:
        warnings.warn(
            "Calling precompute_float8_weights without any weights using FSDP fp8 all-gather!"
```
The function name in the warning needs to be updated. I am okay with not including this warning, by the way; it was also to help debugging, to make sure we actually found `weights`.
Got you. I am removing the warning for simplicity.
@weifengpy has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.
@weifengpy merged this pull request in 6cba2ae.
We have landed fp8 all-gather optimizations in float8_experimental pytorch-labs/float8_experimental#266. This PR proposes the torchtitan changes and also includes fp8 in CI.

```
from float8_experimental.fsdp_utils import precompute_float8_dynamic_scale_for_fsdp

# inside the training loop
model(input).sum().backward()
optim.step()
precompute_float8_dynamic_scale_for_fsdp(model)
```

FSDP2 fp8 all-gather is added to CI:

```
CONFIG_FILE="./train_configs/debug_model.toml" ./run_llama_train.sh --training.enable_fp8_linear
CONFIG_FILE="./train_configs/debug_model.toml" ./run_llama_train.sh --training.enable_fp8_linear --training.enable_fsdp_fp8_all_gather
CONFIG_FILE="./train_configs/debug_model.toml" ./run_llama_train.sh --training.enable_fp8_linear --training.enable_fsdp_fp8_all_gather --training.precompute_float8_dynamic_scale_for_fsdp
```

TP fp8 all-gather is locally tested. Will add it to CI after uploading a new tokenizer with vocab size 2560 (divisible by 16):

```
CONFIG_FILE="./train_configs/llama3_8b.toml" NGPU=4 ./run_llama_train.sh --training.enable_fp8_linear --training.data_parallel_degree 1 --training.tensor_parallel_degree 4
CONFIG_FILE="./train_configs/llama3_8b.toml" NGPU=4 ./run_llama_train.sh --training.enable_fp8_linear --training.data_parallel_degree 2 --training.tensor_parallel_degree 2
```

Precompute scales after optimizer.step (screenshot): https://github.com/user-attachments/assets/1c55bd89-9183-42ca-9445-23f3b95e0817

FSDP2 pre-all-gather does not have any small all-reduces (screenshot): https://github.com/user-attachments/assets/1a00dc70-a8ca-4ce1-a93c-316f22efdb08

TODO:
* upload tokenizer with vocab size 2560 to enable CI on TP fp8 all-gather
* torch.compile complains about fp8
* add delayed scaling and brainstorm about the best config option to express fp8
* compare perf between delayed scaling and dynamic scaling

https://github.com/pytorch-labs/float8_experimental/pull/312/files
Goal: improve float8 all-gather perf in FSDP2 by precomputing scales for all float8 params with a single all-reduce.

Updated README for API usage: call `precompute_float8_scale_for_fsdp` inside the training loop after the optimizer step.

Unit test: `pytest -s test/test_fsdp2/test_fsdp2_eager.py -k test_transformer_parity_dynamic`

FSDP pre-forward: shortened from 3ms to 1.8ms by doing 1 all-reduce instead of N small all-reduces.

Pre-computing amax: shortened from 5ms to 1.7ms by switching from `torch._foreach_abs` + `torch.max(a)` to `torch._foreach_norm(weights, ord=math.inf)`.