This repository has been archived by the owner on Aug 7, 2024. It is now read-only.

Commit 5d293a7

delayed scaling: stop syncing weight amax values across ranks (#277)
Summary:
Pull Request resolved: #277

FSDP already ensures that each rank receives the same weight values, so
the weight amaxes are already identical on every rank and do not need to
be synced.

I checked performance before/after on the multi GPU benchmark and did not
see a significant impact on the toy model, but reducing the communication
volume is a win regardless.
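
To illustrate the distinction, here is a minimal sketch of what the `reduce_amax` knob on `tensor_to_amax` controls. The helper below is a simplified stand-in for the library function, not its actual implementation:

```python
import torch
import torch.distributed as dist


def amax_sketch(t: torch.Tensor, reduce_amax: bool) -> torch.Tensor:
    # Simplified stand-in for tensor_to_amax: amax is the max absolute value.
    amax = t.abs().max()
    if reduce_amax and dist.is_initialized():
        # Activations and gradients differ per rank, so take the max across
        # all ranks to keep numerics identical to the single-GPU case.
        dist.all_reduce(amax, op=dist.ReduceOp.MAX)
    return amax


# Weights are replicated by FSDP after all-gather, so their local amax is
# already identical on every rank and the all-reduce can be skipped:
#   w_amax = amax_sketch(w, reduce_amax=False)
#   x_amax = amax_sketch(x, reduce_amax=True)  # activations still reduce
```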

Reviewed By: drisspg

Differential Revision: D58396925

fbshipit-source-id: 9dc1253bdd49de4c1cf61843c1d778956981aa0e
vkuzo authored and facebook-github-bot committed Jun 14, 2024
1 parent 323fb48 commit 5d293a7
Showing 3 changed files with 13 additions and 11 deletions.
benchmarks/bench_multi_gpu.py (5 changes: 2 additions & 3 deletions)

@@ -25,7 +25,7 @@
 torch.manual_seed(0)

 # TODO: Add more shapes for the benchmark
-B, M, K, N = 32, 32, 32, 32
+B, M, K, N = 32, 1024, 1024, 1024
 lr = 0.01


@@ -152,8 +152,7 @@ def run_n_iterations(n, fn):
     cleanup()


-def run():
-    compile = True
+def run(compile: bool):
     base_dtype = torch.bfloat16
     WORLD_SIZE = torch.cuda.device_count()
     print(f"{base_dtype = }")
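
The `compile` flag is now supplied by the caller instead of being hard-coded. A hypothetical sketch of what such a knob typically toggles (the benchmark's actual wiring is not shown in this hunk):

```python
import torch


def maybe_compile(module: torch.nn.Module, compile: bool) -> torch.nn.Module:
    # Hypothetical helper: wrap the module with torch.compile when requested,
    # otherwise run it eagerly so the two paths can be benchmarked side by side.
    return torch.compile(module) if compile else module
```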
float8_experimental/float8_linear.py (9 changes: 7 additions & 2 deletions)

@@ -32,6 +32,7 @@ def _maybe_initialize_amaxes_scales_for_float8_cast(
     scale_fn_name,
     float8_dtype,
     is_initialized,
+    reduce_amax,
 ):
     """
     If x is about to be cast to `float8` and the amax buffers are not initialized,

@@ -41,8 +42,9 @@ def _maybe_initialize_amaxes_scales_for_float8_cast(
         return
     with torch.no_grad():
         # Note: we need to enable distributed reduction here in order
-        # to match numerics between single GPU and multi GPU code
-        new_amax = tensor_to_amax(x, reduce_amax=True)
+        # to match numerics between single GPU and multi GPU code for
+        # activations and gradients
+        new_amax = tensor_to_amax(x, reduce_amax=reduce_amax)
         cur_amax.fill_(new_amax)
         amax_history[0] = new_amax
         new_scale = amax_history_to_scale(

@@ -89,6 +91,7 @@ def backward(ctx, go):
             scale_fn_name,
             torch.float8_e5m2,
             is_amax_initialized,
+            reduce_amax=True,
         )

         fp8_amax_dL_dY.fill_(tensor_to_amax(go))

@@ -235,6 +238,7 @@ def cast_x_to_float8(
             scale_fn_name,
             torch.float8_e4m3fn,
             is_amax_initialized,
+            reduce_amax=True,
         )
         x_fp8 = Float8Tensor.to_float8(
             x,

@@ -257,6 +261,7 @@ def cast_w_to_float8(
             scale_fn_name,
             torch.float8_e4m3fn,
             is_amax_initialized,
+            reduce_amax=False,
        )

         w_fp8 = Float8Tensor.to_float8(
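
For context, the first two hunks above initialize the delayed-scaling state: the current amax, the amax history, and the derived scale. A simplified sketch of the amax-history-to-scale step; the formula below follows the usual fp8 recipe and approximates, rather than copies, the repo's `amax_history_to_scale`:

```python
import torch


def amax_history_to_scale_sketch(
    amax_history: torch.Tensor,
    float8_dtype: torch.dtype,
    scale_fn_name: str = "max",
) -> torch.Tensor:
    # Pick a representative amax from the history (here the running max),
    # then derive the scale that maps it onto the largest representable
    # value of the target float8 dtype.
    assert scale_fn_name == "max", "only the 'max' policy is sketched here"
    amax = amax_history.max()
    eps = torch.finfo(torch.float32).eps
    return torch.finfo(float8_dtype).max / torch.clamp(amax, min=eps)


# Example: a short history of activation amaxes mapped to an e4m3 scale.
scale = amax_history_to_scale_sketch(
    torch.tensor([1.5, 3.0, 2.2]), torch.float8_e4m3fn
)
```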
float8_experimental/float8_linear_utils.py (10 changes: 4 additions & 6 deletions)

@@ -176,7 +176,7 @@ def sync_float8_amax_and_scale_history(model: torch.nn.Module, fp8_layers=None)
     """
     Manages the float8 amax and scale bookkeeping. In detail, it does the
     following:
-    1. in distributed contexts, syncs amax values across workers
+    1. in distributed contexts, syncs amax values across workers for activations and gradients
     2. adds the `amax` values to history
     3. calculates the scales to be used for next iteration
     4. sets the `amax_and_scale_synced` flag on the Float8Linear modules

@@ -262,10 +262,10 @@ def inner_func():

         if dist.is_initialized():
             # Combine all the amax tensors into one tensor and reduce it
+            # Note: do not reduce the weight values, because FSDP already ensures
+            # the weight values on all ranks are the same after all-gather.
             all_amax_tensors = torch.cat(
-                fp8_amax_x_tensor_list
-                + fp8_amax_w_tensor_list
-                + fp8_amax_dL_dY_tensor_list
+                fp8_amax_x_tensor_list + fp8_amax_dL_dY_tensor_list
             )
             all_reduced_amax_tensor = all_reduce(
                 all_amax_tensors, "MAX", list(range(dist.get_world_size()))

@@ -275,13 +275,11 @@

             (
                 reduced_fp8_amax_tensor,
-                reduced_fp8_amax_w_tensor,
                 reduced_fp8_amax_dL_dY_tensor,
             ) = torch.split(all_reduced_amax_tensor, len(fp8_amax_x_tensor_list))

             for idx, child in enumerate(fp8_layers):
                 child.fp8_amax_x.copy_(reduced_fp8_amax_tensor[idx])
-                child.fp8_amax_w.copy_(reduced_fp8_amax_w_tensor[idx])
                 child.fp8_amax_dL_dY.copy_(reduced_fp8_amax_dL_dY_tensor[idx])

         # We create two stacked tensor groups, one for the amax history and one for the current scales
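
For reference, the usual placement of this sync is once per iteration, after the backward pass and before the optimizer step, so that every Float8Linear sees freshly synced amaxes and recomputed scales. A minimal training-loop sketch under that assumption; the model, loss function, optimizer, and data are placeholders:

```python
import torch
from float8_experimental.float8_linear_utils import sync_float8_amax_and_scale_history


def train_step(model, optimizer, loss_fn, x, target):
    # Assumes the model's nn.Linear layers have already been swapped for
    # Float8Linear modules (e.g. via the repo's module-swap utility).
    optimizer.zero_grad()
    loss = loss_fn(model(x), target)
    loss.backward()
    # Sync amaxes across ranks (activations and gradients only, per this
    # commit) and recompute per-tensor scales before the weight update.
    sync_float8_amax_and_scale_history(model)
    optimizer.step()
    return loss
```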
