[Feature] Auto scaling factor tuning for FP8 collective communication #140

Open · wants to merge 15 commits into base: main
fp8e4m3 for wgrad
wkcn committed Dec 7, 2023

commit 69f73ce35b6a1d48e441643d2cefab2388b54423
13 changes: 6 additions & 7 deletions msamp/megatron/optimizer/distrib_optimizer.py
@@ -547,16 +547,15 @@ def reduce_model_grads(self, args, timers): # noqa: C901
             for p in model_group:
                 g = p.main_grad
                 if g is not None and not torch.is_tensor(g):
-                    if g.qtype != WGRAD_QTYPE:
-                        raise TypeError('g.qtype != WGRAD_QTYPE: {} != {}'.format(g.qtype, WGRAD_QTYPE))
+                    if g.qtype != Dtypes.kfloat8_e4m3:
+                        raise TypeError('g.qtype != Dtypes.kfloat8_e4m3: {}'.format(g.qtype))
                     # stat overflow ratio
                     num_infs = torch.count_nonzero((g.value & 0x7f) == 126)
                     overflow_ratio = num_infs / g.numel()
-                    if args.wgrad_auto_scaling_ratio is not None:
-                        if overflow_ratio > args.wgrad_auto_scaling_ratio:
-                            g.meta.pre_scale /= 2.0
-                        else:
-                            g.meta.pre_scale *= 2.0**(1.0 / args.wgrad_auto_scaling_window)
+                    if overflow_ratio > args.wgrad_auto_scaling_ratio:
+                        g.meta.pre_scale /= 2.0
+                    else:
+                        g.meta.pre_scale *= 2.0**(1.0 / args.wgrad_auto_scaling_window)

         # synchonize pre_scale
         for model_id, model in enumerate(self.models):
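
For readers skimming the hunk above: this commit (a) requires the weight gradient to already be stored as FP8 e4m3 (Dtypes.kfloat8_e4m3), and (b) drops the `args.wgrad_auto_scaling_ratio is not None` guard so the pre-scale is tuned from the observed overflow ratio on every gradient reduction. What follows is a minimal, self-contained sketch of that tuning rule, not MS-AMP's actual API; the helper name tune_pre_scale and its bare arguments are hypothetical, while in the real code the same arithmetic runs on g.value (the raw e4m3 bytes) and g.meta.pre_scale inside reduce_model_grads.

import torch


def tune_pre_scale(fp8_bytes: torch.Tensor,
                   pre_scale: torch.Tensor,
                   auto_scaling_ratio: float,
                   auto_scaling_window: int) -> torch.Tensor:
    """Sketch of the wgrad auto-scaling rule shown in the diff above.

    fp8_bytes: raw uint8 storage of an FP8 e4m3 tensor (what the diff calls g.value).
    pre_scale: scaling factor applied before quantization (g.meta.pre_scale).
    """
    # In e4m3 the bit pattern S.1111.110 (0x7E == 126) is the largest finite
    # magnitude (±448), so elements saturated at that value count as overflows.
    num_infs = torch.count_nonzero((fp8_bytes & 0x7F) == 126)
    overflow_ratio = num_infs.float() / fp8_bytes.numel()
    if overflow_ratio > auto_scaling_ratio:
        # Too many saturated values: back off aggressively.
        pre_scale = pre_scale / 2.0
    else:
        # Otherwise recover slowly, doubling once per `auto_scaling_window` clean steps.
        pre_scale = pre_scale * 2.0 ** (1.0 / auto_scaling_window)
    return pre_scale

The asymmetric schedule (halve immediately on overflow, grow back by 2**(1/window)) mirrors common dynamic loss-scaling heuristics: it reacts quickly to saturation but needs a full window of clean reductions to regain the lost dynamic range.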