Handle the scaling factor when amax is so tiny that it leads to an infinite scale (NVIDIA#786)

* Handle the scaling factor when amax is so tiny that it leads to an infinite scale

Signed-off-by: Jinze Xue <[email protected]>

* revert formatting changes

Signed-off-by: Jinze Xue <[email protected]>

* fix comments

Signed-off-by: Jinze Xue <[email protected]>

* Apply review suggestion

Co-authored-by: Tim Moon <[email protected]>
Signed-off-by: Jinze Xue <[email protected]>

* Apply review suggestion

Co-authored-by: Tim Moon <[email protected]>
Signed-off-by: Jinze Xue <[email protected]>

* Apply review suggestion

Co-authored-by: Tim Moon <[email protected]>
Signed-off-by: Jinze Xue <[email protected]>

* apply review suggestion

Signed-off-by: Jinze Xue <[email protected]>

* add test_recipe.py to qa/L0_pytorch_unittest/test.sh; fix unittest for is_first_microbatch=False

Signed-off-by: Jinze Xue <[email protected]>

* revert changes to update_weight_scale_inv

Signed-off-by: Jinze Xue <[email protected]>

* Debug test failures

Signed-off-by: Tim Moon <[email protected]>

---------

Signed-off-by: Jinze Xue <[email protected]>
Signed-off-by: Jinze Xue <[email protected]>
Signed-off-by: Tim Moon <[email protected]>
Co-authored-by: Jinze Xue <[email protected]>
Co-authored-by: Tim Moon <[email protected]>
Co-authored-by: Tim Moon <[email protected]>
4 people authored May 1, 2024
1 parent a817868 commit 7acb5e2
Showing 4 changed files with 138 additions and 2 deletions.
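
As background for the diffs below, this is a standalone sketch (not part of the commit) of the failure mode being addressed: with delayed scaling the scale is computed as fp8_max / amax in FP32, so a sufficiently tiny amax overflows the scale to infinity and its reciprocal scale_inv collapses to zero. The E4M3 maximum of 448 is used as an illustrative value.

import torch

fp8_max = 448.0                            # FP8 E4M3 max representable value
fp32_max = torch.finfo(torch.float32).max  # ~3.4028e38

# Any amax below fp8_max / fp32_max (~1.3e-36) overflows the scale in FP32.
amax = torch.tensor([1e-37], dtype=torch.float32)
scale = fp8_max / amax
print(scale)                               # tensor([inf])

# The fix applied by this commit: fall back to the largest finite FP32 value.
scale = torch.where(torch.isinf(scale), torch.full_like(scale, fp32_max), scale)
print(scale, torch.reciprocal(scale))      # finite scale and a non-zero scale_inv
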
1 change: 1 addition & 0 deletions qa/L0_pytorch_unittest/test.sh
@@ -8,6 +8,7 @@ set -e

pip install pytest==7.2 onnxruntime==1.13.1
pytest -v -s $TE_PATH/tests/pytorch/test_sanity.py
pytest -v -s $TE_PATH/tests/pytorch/test_recipe.py
pytest -v -s $TE_PATH/tests/pytorch/test_deferred_init.py
PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 pytest -v -s $TE_PATH/tests/pytorch/test_numerics.py
PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 pytest -v -s $TE_PATH/tests/pytorch/test_cuda_graphs.py
98 changes: 97 additions & 1 deletion tests/pytorch/test_recipe.py
@@ -9,9 +9,10 @@

import transformer_engine.common.recipe
import transformer_engine.pytorch as te
import transformer_engine_extensions as tex
from transformer_engine.pytorch.fp8 import (
    FP8GlobalStateManager,
    amax_and_scale_update,
    _amax_and_scale_update,
    get_default_fp8_recipe,
)

@@ -162,3 +163,98 @@ def test_amax_and_scale_update(
            fp8_meta[backward_key].scale_inv,
            ref_scale_inv_backward,
        )

    @pytest.mark.parametrize("amax_case", ["zero", "tiny", "normal", "inf", "nan"])
    @pytest.mark.parametrize("fused_update", [True, False], ids=["fused", "non-fused"])
    @pytest.mark.parametrize(
        "fp8_dtype", [tex.DType.kFloat8E4M3, tex.DType.kFloat8E5M2], ids=["E4M3", "E5M2"]
    )
    def test_scale_update_numeric_scenarios(self, amax_case, fused_update, fp8_dtype):

        if fp8_dtype == tex.DType.kFloat8E4M3:
            fp8_format = transformer_engine.common.recipe.Format.E4M3
            fp8_max = fp8_format.value.max_fwd
        elif fp8_dtype == tex.DType.kFloat8E5M2:
            fp8_format = transformer_engine.common.recipe.Format.HYBRID
            fp8_max = fp8_format.value.max_bwd
        else:
            raise ValueError(f"{fp8_dtype=} is not supported")

        scaling_factor_compute_algo = None
        if fused_update:
            scaling_factor_compute_algo = (
                lambda amax, scale, fp8_max, recipe:
                    te.fp8._default_sf_compute(amax, scale, fp8_max, recipe.margin)
            )
        recipe = transformer_engine.common.recipe.DelayedScaling(
            fp8_format=fp8_format, scaling_factor_compute_algo=scaling_factor_compute_algo
        )

        # Setup fp8_meta dictionary
        def setup_fp8_meta():
            with te.fp8_autocast(enabled=True, fp8_recipe=recipe):
                module = te.Linear(16, 16)
                y = module(torch.zeros([16, 16], device="cuda"))
            y.backward(torch.zeros_like(y))
            return module.fp8_meta

        fp8_meta = setup_fp8_meta()
        forward_key = FP8GlobalStateManager.get_meta_tensor_key(forward=True)

        # Replace the fp8_meta[forward_key] with a new TensorMeta for test purpose
        fp8_meta[forward_key] = tex.FP8TensorMeta()
        fp8_meta[forward_key].scale = torch.ones(1, dtype=torch.float32, device="cuda")
        fp8_meta[forward_key].scale_inv = torch.ones(1, dtype=torch.float32, device="cuda")

        # test different scenarios
        if amax_case == "zero":
            fp8_meta[forward_key].amax_history = torch.tensor([[0]], dtype=torch.float32, device="cuda")
            expected_scale = torch.tensor([1.0], dtype=torch.float32, device="cuda")
        elif amax_case == "tiny":
            # calculate the minimum amax value that results in a FP32 maximum scale
            fp32_max = torch.tensor(torch.finfo(torch.float32).max)
            tiny_amax = fp8_max / fp32_max
            # make the amax less than the minimum amax so that the scale will be infinite
            amax_value = tiny_amax / 2
            fp8_meta[forward_key].amax_history = torch.tensor(
                [[amax_value]], dtype=torch.float32, device="cuda"
            )
            # expected scale is FP32_max
            expected_scale = fp32_max.view(1).cuda()
        elif amax_case == "normal":
            # plus a small epsilon to avoid zero amax
            amax_value = torch.rand(1, dtype=torch.float32, device="cuda") + 1e-5
            fp8_meta[forward_key].amax_history = amax_value.view(1, 1)
            expected_scale = fp8_max / amax_value
        elif amax_case == "inf":
            fp8_meta[forward_key].amax_history = torch.tensor(
                [[torch.inf]], dtype=torch.float32, device="cuda"
            )
            expected_scale = torch.tensor([1.0], dtype=torch.float32, device="cuda")
        elif amax_case == "nan":
            fp8_meta[forward_key].amax_history = torch.tensor(
                [[torch.nan]], dtype=torch.float32, device="cuda"
            )
            expected_scale = torch.tensor([1.0], dtype=torch.float32, device="cuda")

        if fused_update:
            tex.fused_amax_and_scale_update_after_reduction(
                fp8_meta[forward_key].amax_history.clone().view(-1),
                [fp8_meta[forward_key].amax_history],
                [fp8_meta[forward_key].scale],
                [fp8_meta[forward_key].scale_inv],
                recipe.amax_compute_algo,
                fp8_dtype,
                recipe.margin,
            )
        else:
            _amax_and_scale_update(
                fp8_meta[forward_key].amax_history,
                fp8_meta[forward_key].scale,
                fp8_meta[forward_key].scale_inv,
                fp8_max,
                recipe,
            )

        torch.testing.assert_close(fp8_meta[forward_key].scale, expected_scale)
        torch.testing.assert_close(fp8_meta[forward_key].scale_inv, torch.reciprocal(expected_scale))
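
The "tiny" case above derives its threshold from fp8_max / FP32_max. As a quick numeric check (a standalone sketch; the maxima 448 for E4M3 and 57344 for E5M2, which the test reads from the recipe Format enums, are assumed here as literals):

import torch

fp32_max = torch.finfo(torch.float32).max
for name, fp8_max in [("E4M3 max_fwd", 448.0), ("HYBRID/E5M2 max_bwd", 57344.0)]:
    tiny_amax = fp8_max / fp32_max  # smallest amax whose scale still fits in FP32
    amax = torch.tensor([tiny_amax / 2], dtype=torch.float32)
    print(name, tiny_amax, (fp8_max / amax).item())  # the unclamped scale comes out as inf
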
26 changes: 26 additions & 0 deletions transformer_engine/common/recipe/delayed_scaling.cu
@@ -8,6 +8,7 @@

#include <cmath>
#include <string>
#include <limits>

#include "../common.h"
#include "../util/logging.h"
@@ -151,6 +152,13 @@ kernel(const float* amax_history_ptr,
    } else {
      scale = scale_ptr[bid];
    }
    // If the amax is so tiny that the scale becomes infinite in FP32,
    // set the scale to the max value of FP32. In this case, the tensor's
    // amax won't get mapped to the FP8 max representable, but rather to
    // something below it; this is the best we can do.
    if (isinf(scale)) {
      scale = std::numeric_limits<float>::max();
    }
    updated_scale_ptr[bid] = scale;

// Update scale inverse
@@ -239,12 +247,30 @@ kernel_bulk(

    // Update scale and scale inverse
    if (tid == 0) {
      // Computing the scaling factor requires consideration of the following scenarios:
      // 1. amax == 0:
      //    No action is possible; set scale to the previous scale (or 1).
      // 2. 0 < amax < tiny_amax:
      //    The amax is so tiny that the scale becomes infinite in FP32.
      //    Set scale = FP32_max.
      // 3. tiny_amax <= amax < FP32_max:
      //    Set scale = FP8_max (or scaled_max) / amax.
      // 4. amax == inf or amax == nan:
      //    No action is possible; set scale to the previous scale (or 1).

      float scale;
      if (isfinite(amax) && amax > 0) {
        scale = scaled_max / amax;
      } else {
        scale = p.param[bid].scale[count];
      }
      // If the amax is so tiny that the scale becomes infinite in FP32,
      // set the scale to the max value of FP32. In this case, the tensor's
      // amax won't get mapped to the FP8 max representable, but rather to
      // something below it; this is the best we can do.
      if (isinf(scale)) {
        scale = std::numeric_limits<float>::max();
      }
      p.param[bid].scale[count] = scale;
      p.param[bid].scale_inv[count] = 1 / scale;
    }
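
A side effect worth noting (standalone sketch, not part of the diff): the kernels also write scale_inv = 1 / scale, and clamping the scale to FP32_max keeps scale_inv at a tiny but non-zero subnormal, whereas an unclamped infinite scale would zero it out.

import torch

fp32_max = torch.finfo(torch.float32).max
clamped = torch.tensor([fp32_max], dtype=torch.float32)
unclamped = torch.tensor([float("inf")], dtype=torch.float32)
print(torch.reciprocal(clamped))    # ~2.94e-39, subnormal but non-zero
print(torch.reciprocal(unclamped))  # tensor([0.])
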
15 changes: 14 additions & 1 deletion transformer_engine/pytorch/fp8.py
@@ -598,11 +598,24 @@ def _default_sf_compute(
    scale: torch.Tensor,
    fp8_max: float,
    margin: int,
    _fp32_max: float = torch.finfo(torch.float32).max,  # finfo not available in jitter
) -> torch.Tensor:
"""Default function to convert amax to scaling factor."""
"""Default function to convert amax to scaling factor.
Computing the scaling factor requires consideration of the following scenarios:
1. amax == 0:
No action is possible, set scale to the previous scale (or 1).
2. 0 < amax < tiny_amax
The amax is too tiny that the scale becomes infinite in FP32.
Set scale = FP32_max
3. tiny_amax <= amax < FP32_max:
Set scale = FP8_max (or scaled_max) / amax
4. When amax == inf or amax == nan:
No action is possible, set scale to the previous scale (or 1).
"""
    sf = (fp8_max / amax) / (2 ** margin)
    sf = torch.where(amax > 0.0, sf, scale)
    sf = torch.where(torch.isfinite(amax), sf, scale)
    sf = torch.where(torch.isinf(sf), torch.full_like(sf, _fp32_max), sf)
    scale.copy_(sf)
    return scale
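
For reference, a minimal usage sketch of the updated helper (assuming Transformer Engine is installed; the amax values, fp8_max=448, and margin=0 below are illustrative, one amax per documented scenario):

import torch
from transformer_engine.pytorch.fp8 import _default_sf_compute

amax = torch.tensor([0.0, 1e-37, 2.0, float("inf"), float("nan")])  # zero / tiny / normal / inf / nan
scale = torch.ones(5)                                               # previous scales
_default_sf_compute(amax, scale, 448.0, 0)                          # updates scale in place and returns it
print(scale)  # approximately [1.0, 3.4028e+38, 224.0, 1.0, 1.0]
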

