This repository has been archived by the owner on Aug 7, 2024. It is now read-only.

[FSDP2] precompute scale after optimizer.step for dynamic scaling #266

Closed · wants to merge 32 commits

Conversation

weifengpy (Contributor) commented May 21, 2024:

Goal: improve float8 all-gather perf in FSDP2 by precomputing scales for all float8 params with a single all-reduce

updated README for API usage: call precompute_float8_dynamic_scale_for_fsdp inside the training loop after the optimizer step

```
from float8_experimental.fsdp_utils import precompute_float8_dynamic_scale_for_fsdp
# inside the training loop
model(input).sum().backward()
optim.step()
precompute_float8_dynamic_scale_for_fsdp(model)
```

Unit test: `pytest -s test/test_fsdp2/test_fsdp2_eager.py -k test_transformer_parity_dynamic`

FSDP pre-forward: shortened from 3ms to 1.8ms by doing one all-reduce instead of N small all-reduces.

(profiler trace screenshots: before and after)

Pre-computing amax: shortened from 5ms to 1.7ms, by switching from torch._foreach_abs + torch.max(a) to torch._foreach_norm(weights, ord=math.inf)

(profiler trace screenshots: before and after)
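
For reference, a minimal sketch of the amax computation change described above (illustrative only; `weights` stands for the list of FSDP-sharded float8 weights):

```
import math
from typing import List

import torch


def compute_amaxes_sketch(weights: List[torch.Tensor]) -> List[torch.Tensor]:
    # Before: one abs + one max reduction per weight,
    # i.e. torch.max(torch.abs(w)) per tensor (torch._foreach_abs + torch.max).
    # After: the inf-norm equals max(abs(w)), and _foreach_norm computes it
    # for all weights in a single fused foreach call.
    return torch._foreach_norm(weights, ord=math.inf)
```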

@facebook-github-bot added the CLA Signed label May 21, 2024
@weifengpy weifengpy marked this pull request as draft May 23, 2024 23:16
requires_grad=tensor.requires_grad,
)

def __init__(self, tensor: torch.Tensor, mm_config: ScaledMMConfig):
self._tensor = tensor
self._mm_config = mm_config
# Optional cache for pre-computed fp8 data/scale
self._fp8_data: Optional[torch.Tensor] = None
Contributor:
One major requirement for tensor subclasses that I don't think is respected here: __tensor_flatten__ and __tensor_unflatten__ must properly convey every inner tensor on the subclass.

So when we call __tensor_flatten__ on this subclass, if either of _fp8_data/scale/amax are set to valid tensors, they need to be returned there (and similarly __tensor_unflatten__ needs to handle them as extra args)
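
For illustration, a minimal sketch of what that requirement looks like here, using the single optional `_precomputed_scale` cache the PR eventually settled on; the extra constructor argument in `__tensor_unflatten__` is an assumption of this sketch, not the code as it stood at this point:

```
def __tensor_flatten__(self):
    # Report every inner tensor that is currently set, including optional
    # caches, so that flatten/unflatten round-trips the subclass losslessly.
    if self._precomputed_scale is not None:
        return ["_tensor", "_precomputed_scale"], self._mm_config
    return ["_tensor"], self._mm_config

@staticmethod
def __tensor_unflatten__(inner_tensors, flatten_spec, outer_size, outer_stride):
    mm_config = flatten_spec
    return WeightWithDynamicFloat8CastTensor(
        inner_tensors["_tensor"],
        mm_config,
        # Present only if it was set at flatten time (hypothetical kwarg).
        precomputed_scale=inner_tensors.get("_precomputed_scale", None),
    )
```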

Contributor Author (weifengpy):

thanks for pointing this out! This saves me a lot of debugging time. I can give it a try by including _fp8_data/scale/amax in __tensor_flatten__ and __tensor_unflatten__

Contributor Author (weifengpy):

torch.compile works after patching pytorch/pytorch#127431
will compare traces in 2nd PR

weifengpy and others added 8 commits May 30, 2024 00:30
@weifengpy weifengpy changed the title [DO NOT LAND] precast after optimizer.step and dump profiler traces [FSDP2] pre-compute amax after optimizer.step for dynamic scaling Jun 6, 2024
def compute_amaxes(weights: List[DTensor]):
max_weights = torch._foreach_norm(weights, ord=math.inf)
amax_tensor = torch.vstack(max_weights)
amax_tensor = torch.clamp(amax_tensor, EPS) # R
Contributor Author (@weifengpy, Jun 6, 2024):

torch.clamp calls all_reduce here. I avoid clamping again in amax_to_scale by passing clamp_amax=False.

Contributor:

So you are relying on torch.clamp to run the all-reduce implicitly from changing sharding from partial to replicate?

If this fragments the code, could we just all-reduce the amax tensor and then leave the clamp to amax_to_scale? I agree the current way is faster since we are doing one clamp for all amaxes, but in case float8 folks are not happy with this fragmentation, this seems like another way.
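
Concretely, the alternative described above might look like this (a hedged sketch; it assumes the vstacked amax DTensor carries a Partial placement, which is what makes the redistribute an all-reduce):

```
import math

import torch
from torch.distributed._tensor import Replicate


def compute_amaxes_explicit(weights):  # weights: List[DTensor] of sharded fp8 weights
    max_weights = torch._foreach_norm(weights, ord=math.inf)  # one amax per weight
    amax_tensor = torch.vstack(max_weights)  # assumed Partial placement
    # Make the communication explicit instead of relying on torch.clamp's
    # implicit Partial -> Replicate redistribution:
    amax_tensor = amax_tensor.redistribute(placements=[Replicate()])  # one all-reduce
    # ...and leave the clamping to amax_to_scale(), as suggested above.
    return amax_tensor
```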

Contributor Author (weifengpy):

thanks for the suggestions. I can collect feedback from float8 folks if they have a preference

Contributor:

can we just comment with what is going on? I think it's fine as long as the code is easy to understand and there is no magic.

Contributor:

agreed

@@ -190,9 +191,20 @@ def __repr__(self):
return f"WeightWithDynamicFloat8CastTensor(tensor={self._tensor}, mm_config={self._mm_config})"

def fsdp_pre_all_gather(self, mesh):
Contributor Author (weifengpy):

If _pre_computed_amax is set, we skip tensor_to_amax and go directly to amax_to_scale.

@weifengpy weifengpy marked this pull request as ready for review June 6, 2024 08:21
@weifengpy weifengpy requested review from vkuzo, awgu and drisspg June 6, 2024 08:21
@awgu (Contributor) left a comment:

This seems reasonable to me! I want to check with float8 folks on the amax_to_scale change.

@vkuzo (Contributor) commented Jun 6, 2024:

nice! Can we include the intended user API in the PR summary?

@@ -151,6 +151,7 @@ def __new__(cls, tensor: torch.Tensor, mm_config: ScaledMMConfig):
def __init__(self, tensor: torch.Tensor, mm_config: ScaledMMConfig):
self._tensor = tensor
self._mm_config = mm_config
self._pre_computed_amax = None
Contributor:

does this need to be added to __tensor_flatten__?

can we add some comments on intended usage of this?

Contributor (@drisspg, Jun 7, 2024):

+1 on adding to flatten/unflatten and comments on intended usage

Contributor Author (weifengpy):

done
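
For illustration, the requested usage comment presumably ends up looking something like this (a hedged sketch of the annotated field, not the exact text that was added):

```
# Optional cache for FSDP2 float8 all-gather with dynamic scaling:
# populated by the precompute helper after optimizer.step and consumed by
# fsdp_pre_all_gather, so the cast there needs no extra per-parameter all-reduce.
self._pre_computed_amax: Optional[torch.Tensor] = None
```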

@@ -322,3 +328,34 @@ def inner_func():
for child in fp8_layers:
# Set a flag to signal amaxes/scales are ready
child.amax_and_scale_synced = True


def precompute_float8_amax(module: nn.Module) -> None:
Contributor:

can we put this in distributed_utils.py?

I think the function name should include that this is intended for FSDP2 with float8 all-gather

Contributor Author (weifengpy):

moving to fsdp_utils.py according to PR #310

Contributor Author (weifengpy):

indicated FSDP in the name by renaming it to precompute_float8_amax_for_fsdp

weights: List[DTensor] = [float8_linear.weight for float8_linear in float8_linears]

def compute_amaxes(weights: List[DTensor]):
max_weights = torch._foreach_norm(weights, ord=math.inf)
Contributor:

maybe add a comment that this is equivalent to max(abs(w))?

Contributor Author (weifengpy):

done

"""
scale = torch.empty_like(amax, dtype=torch.float32)
if float8_dtype in FP8_TYPES:
res = torch.finfo(float8_dtype).max / torch.clamp(amax, min=EPS)
if clamp_amax:
Contributor:

nit: I think if you put this on a separate line,
`amax = clamp(amax, eps) if clamp_amax else amax`
it makes the logic a little easier to follow.
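
Spelled out, the nit amounts to something like this (a simplified sketch of only the float8 branch, not the library's full amax_to_scale):

```
import torch
from float8_experimental.float8_utils import EPS  # same EPS the module already uses


def amax_to_scale_sketch(amax: torch.Tensor, float8_dtype: torch.dtype, clamp_amax: bool = True) -> torch.Tensor:
    # Putting the conditional clamp on its own line keeps the scale math
    # identical in both branches.
    amax = torch.clamp(amax, min=EPS) if clamp_amax else amax
    return torch.finfo(float8_dtype).max / amax.to(torch.float32)
```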

@facebook-github-bot (Contributor):
@weifengpy has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.

@@ -130,20 +146,38 @@ def unwrap(t):
)

def __tensor_flatten__(self):
return ["_tensor"], self._mm_config
if self._precomputed_amax:
return ["_tensor", "_precomputed_amax"], self._mm_config
Contributor:

does having Optional[torch.Tensor] as a subclass field work with torch.compile? Or do we not care about torch.compile in this code path?

Contributor Author (@weifengpy, Jul 10, 2024):

torch.compile assumes every tensor returned from __tensor_flatten__ is not None, so I added an if/else to make torch.compile work. I verified it in pytorch/pytorch#129457

):
"""Converts the amax value of a tensor to the fp8 scale.
Args:
amax: The amax value of the tensor.
float8_dtype: The float8 dtype.
orig_dtype: The original dtype of the tensor.
clamp_amax: default is True. False for FSDP fp8 all-gather since FSDP applied `torch.clamp` during pre-compute after optimizer.step
Contributor:

this is a bit confusing. How about precomputing the scale instead so we don't have to have gotchas like this?

Contributor Author (weifengpy):

good suggestion! I changed the API to precompute the scale, and it shows another 9% speedup in the unit test vs. precomputing amax

fsdp_pre_all_gather is also greatly simplified because of using self._precomputed_scale
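
Roughly, the simplification reads like this; `cast_with_scale` and `cast_dynamically` are placeholder names for the library's float8 cast helpers, and the return shape follows FSDP2's pre-all-gather extension contract, so treat it as a sketch rather than the merged code:

```
def fsdp_pre_all_gather(self, mesh):
    if self._precomputed_scale is not None:
        # Scale was computed (and all-reduced) once for all weights after
        # optimizer.step, so no communication is needed here.
        float8_tensor = cast_with_scale(
            self._tensor, self._precomputed_scale, self._mm_config
        )
    else:
        # Original path: derive amax/scale on the fly, which is what issued
        # one small all-reduce per parameter before this PR.
        float8_tensor = cast_dynamically(self._tensor, self._mm_config)
    return (float8_tensor._data,), (float8_tensor._scale,)
```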

@weifengpy weifengpy marked this pull request as draft July 10, 2024 21:15
weifengpy and others added 5 commits July 10, 2024 14:16
@weifengpy weifengpy changed the title [FSDP2] precompute amax after optimizer.step for dynamic scaling [FSDP2] precompute scale after optimizer.step for dynamic scaling Jul 10, 2024
@weifengpy weifengpy marked this pull request as ready for review July 10, 2024 23:43
@facebook-github-bot (Contributor):

@weifengpy has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.

@weifengpy weifengpy requested a review from vkuzo July 10, 2024 23:43
README.md Outdated
y.sum().backward()
optimizer.step()

# specific to fsdp2 + float8 with dynamic scaling
Contributor:

should we say that this is specific to FSDP2 with float8 all-gather turned on? Also, maybe we can show how to turn that on, since I don't think it's documented in the README yet? Can be a follow-up PR.

Contributor Author (weifengpy):

> should we say that this is specific to FSDP2 with float8 all-gather turned on?

will change it in this PR

> maybe we can show how to turn that on, since I don't think it's documented in the README yet

good catch. will polish README again after landing changes in torchtitan to turn on/off fp8 all-gather

for m in module.modules()
):
raise NotImplementedError("Only supports delayed scaling")
float8_linears: List[Float8Linear] = [
Contributor:

is this expensive for real models? if yes, maybe we can offer option to precompute this?

Contributor:

My intuition is that this should be pretty fast as the number of nn.Modules in the model is usually at most in the thousands and this is pure Python overhead. @weifengpy you can check the traces you have if you see any noticeable gaps from this.

Contributor Author (@weifengpy, Jul 11, 2024):

just checked the profiler traces: it's roughly 0.15ms of CPU overhead (about 5% of precompute_float8_dynamic_scale_for_fsdp and a tiny portion of one training loop), and no CUDA kernels are launched.

thus I am keeping it as is for now for simplicity.

(profiler trace screenshot)

]
weights: List[DTensor] = [float8_linear.weight for float8_linear in float8_linears]

def compute_scales(weights: List[DTensor]):
Contributor:

optional nit: maybe move outside to prevent nested functions?

Contributor:

curious what is the downside of nested functions

@weifengpy By the way, this was originally a nested function just so that we could try to torch.compile it effectively in the scales = compute_scales(weights) line. Does it still need to be a separate function for torch.compile reasons? If so, we should probably add a comment before the def compute_scales mentioning that it is separate for torch.compile; otherwise, we can consider inlining the function.

Contributor Author (@weifengpy, Jul 11, 2024):

I will remove the nested function to make the code easier to read. I profiled the unit test and precompute_float8_scale_for_fsdp takes 1.9ms, which is a tiny portion of the overall training loop, so there is no obvious reason to speed it up with torch.compile yet. I can bring back the nested function if we need torch.compile again.

(profiler trace screenshot)
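
Putting the thread together, the inlined helper ends up roughly like this (a sketch assembled from the snippets quoted in this review; the import paths, the torch.float8_e4m3fn target dtype, and the exact module filtering are assumptions, not the verbatim merged code):

```
import math
from typing import List

import torch
import torch.nn as nn
from torch.distributed._tensor import DTensor

from float8_experimental.float8_linear import Float8Linear  # assumed import path
from float8_experimental.float8_utils import EPS


@torch.no_grad()
def precompute_float8_dynamic_scale_for_fsdp(module: nn.Module) -> None:
    # Collect FSDP-managed float8 linears (their weights are DTensors).
    float8_linears: List[Float8Linear] = [
        m
        for m in module.modules()
        if isinstance(m, Float8Linear) and isinstance(m.weight, DTensor)
    ]
    weights: List[DTensor] = [m.weight for m in float8_linears]
    if not weights:
        return

    # inf-norm == max(abs(w)); one fused foreach call computes every amax.
    max_weights = torch._foreach_norm(weights, ord=math.inf)
    amax_tensor = torch.vstack(max_weights)
    # clamp on the stacked DTensor is what triggers the single all-reduce
    # (Partial -> Replicate), replacing N small all-reduces.
    amax_tensor = torch.clamp(amax_tensor, EPS)
    scale_tensor = torch.finfo(torch.float8_e4m3fn).max / amax_tensor
    scales = torch.split(scale_tensor, 1)  # one scale per weight

    for scale, float8_linear in zip(scales, float8_linears):
        float8_linear.weight._local_tensor._precomputed_scale = scale._local_tensor
```

This is the helper that the training-loop snippet in the PR summary calls right after optim.step().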

@vkuzo (Contributor) left a comment:

looks great! had some final comments. thanks for doing this!

from float8_experimental.float8_utils import EPS


def precompute_float8_scale_for_fsdp(module: nn.Module) -> None:
Contributor:

Should we add a @torch.no_grad() decorator on this?

Contributor Author (weifengpy):

good catch. adding @torch.no_grad()

Comment on lines 15 to 16
Calculate scale for all float8 parameters after optimizer step
It performs a single all-reduce instead of many all-reduces for each parameter
Contributor:

suggestion:

Suggested change:
- Calculate scale for all float8 parameters after optimizer step
- It performs a single all-reduce instead of many all-reduces for each parameter
+ Calculate scale dynamically for all float8 parameters.
+ This should be run after the optimizer step. It performs a single all-reduce to compute the
+ amaxes for all float8 weights.

"""
Calculate scale for all float8 parameters after optimizer step
It performs a single all-reduce instead of many all-reduces for each parameter
Exmaple usage:
Contributor:

nit (typo):

Suggested change:
- Exmaple usage:
+ Example usage:

@vkuzo I assume that there are no docs builds for float8_experimental, so this example is for users who will read the code itself?

Otherwise, we might need to check the formatting -- I recall the format for examples being a bit different.

float8_linear.weight._local_tensor._precomputed_scale = scale._local_tensor
else:
warnings.warn(
"Calling precompute_float8_weights without any weights using FSDP fp8 all-gather!"
Contributor:

function name in the warning needs to be updated

I am okay with not including this warning by the way. This was also to help debugging to make sure we actually found weights.

Contributor Author (weifengpy):

got you. I am removing the warnings for simplicity

@weifengpy weifengpy marked this pull request as draft July 11, 2024 20:55
@weifengpy weifengpy marked this pull request as ready for review July 11, 2024 21:53
@facebook-github-bot (Contributor):

@weifengpy has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.

@facebook-github-bot (Contributor):

@weifengpy merged this pull request in 6cba2ae.

weifengpy added a commit to pytorch/torchtitan that referenced this pull request Jul 16, 2024
we have landed fp8 all-gather optimizations in float8_experimental
pytorch-labs/float8_experimental#266

this PR proposes the torchtitan changes and also includes fp8 in CI
```
from float8_experimental.fsdp_utils import precompute_float8_dynamic_scale_for_fsdp
# inside the training loop
model(input).sum().backward()
optim.step()
precompute_float8_dynamic_scale_for_fsdp(model)
```

FSDP2 fp8 all-gather are added to CI
```
CONFIG_FILE="./train_configs/debug_model.toml" ./run_llama_train.sh --training.enable_fp8_linear
CONFIG_FILE="./train_configs/debug_model.toml" ./run_llama_train.sh --training.enable_fp8_linear --training.enable_fsdp_fp8_all_gather
CONFIG_FILE="./train_configs/debug_model.toml" ./run_llama_train.sh --training.enable_fp8_linear --training.enable_fsdp_fp8_all_gather --training.precompute_float8_dynamic_scale_for_fsdp
```

TP fp8 all-gather is tested locally. Will add it to CI after
uploading a new tokenizer with vocab size 2560 (divisible by 16)
```
CONFIG_FILE="./train_configs/llama3_8b.toml" NGPU=4 ./run_llama_train.sh --training.enable_fp8_linear --training.data_parallel_degree 1 --training.tensor_parallel_degree 4
CONFIG_FILE="./train_configs/llama3_8b.toml" NGPU=4 ./run_llama_train.sh --training.enable_fp8_linear --training.data_parallel_degree 2 --training.tensor_parallel_degree 2
```

precompute scales after optimizer.step
<img width="319" alt="Screenshot 2024-07-12 at 5 11 14 PM"
src="https://github.com/user-attachments/assets/1c55bd89-9183-42ca-9445-23f3b95e0817">

FSDP2 pre-all-gather does not have any small all-reduces
<img width="794" alt="Screenshot 2024-07-12 at 5 13 04 PM"
src="https://github.com/user-attachments/assets/1a00dc70-a8ca-4ce1-a93c-316f22efdb08">

TODO
* upload tokenizer with vocab size 2560 to enable CI on TP fp8
all-gather
* torch.compile complains about fp8
* add delayed scaling and brainstorm about best config option to express
fp8
* compare perf between delayed scaling and dynamic scaling
https://github.com/pytorch-labs/float8_experimental/pull/312/files
Labels: CLA Signed, Merged
6 participants