diff --git a/qa/L0_pytorch_unittest/test.sh b/qa/L0_pytorch_unittest/test.sh
index ded45dd377..2c14664dce 100644
--- a/qa/L0_pytorch_unittest/test.sh
+++ b/qa/L0_pytorch_unittest/test.sh
@@ -8,6 +8,7 @@ set -e
 
 pip install pytest==7.2 onnxruntime==1.13.1
 pytest -v -s $TE_PATH/tests/pytorch/test_sanity.py
+pytest -v -s $TE_PATH/tests/pytorch/test_recipe.py
 pytest -v -s $TE_PATH/tests/pytorch/test_deferred_init.py
 PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 pytest -v -s $TE_PATH/tests/pytorch/test_numerics.py
 PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 pytest -v -s $TE_PATH/tests/pytorch/test_cuda_graphs.py
diff --git a/tests/pytorch/test_recipe.py b/tests/pytorch/test_recipe.py
index 6b65960ec6..92c7f26f59 100644
--- a/tests/pytorch/test_recipe.py
+++ b/tests/pytorch/test_recipe.py
@@ -9,9 +9,10 @@
 import transformer_engine.common.recipe
 import transformer_engine.pytorch as te
+import transformer_engine_extensions as tex
 from transformer_engine.pytorch.fp8 import (
     FP8GlobalStateManager,
-    amax_and_scale_update,
+    _amax_and_scale_update,
     get_default_fp8_recipe,
 )
@@ -162,3 +163,98 @@ def test_amax_and_scale_update(
             fp8_meta[backward_key].scale_inv,
             ref_scale_inv_backward,
         )
+
+    @pytest.mark.parametrize("amax_case", ["zero", "tiny", "normal", "inf", "nan"])
+    @pytest.mark.parametrize("fused_update", [True, False], ids=["fused", "non-fused"])
+    @pytest.mark.parametrize(
+        "fp8_dtype", [tex.DType.kFloat8E4M3, tex.DType.kFloat8E5M2], ids=["E4M3", "E5M2"]
+    )
+    def test_scale_update_numeric_scenarios(self, amax_case, fused_update, fp8_dtype):
+
+        if fp8_dtype == tex.DType.kFloat8E4M3:
+            fp8_format = transformer_engine.common.recipe.Format.E4M3
+            fp8_max = fp8_format.value.max_fwd
+        elif fp8_dtype == tex.DType.kFloat8E5M2:
+            fp8_format = transformer_engine.common.recipe.Format.HYBRID
+            fp8_max = fp8_format.value.max_bwd
+        else:
+            raise ValueError(f"{fp8_dtype=} is not supported")
+
+        scaling_factor_compute_algo = None
+        if fused_update:
+            scaling_factor_compute_algo = (
+                lambda amax, scale, fp8_max, recipe:
+                te.fp8._default_sf_compute(amax, scale, fp8_max, recipe.margin)
+            )
+        recipe = transformer_engine.common.recipe.DelayedScaling(
+            fp8_format=fp8_format, scaling_factor_compute_algo=scaling_factor_compute_algo
+        )
+
+        # Set up the fp8_meta dictionary
+        def setup_fp8_meta():
+            with te.fp8_autocast(enabled=True, fp8_recipe=recipe):
+                module = te.Linear(16, 16)
+                y = module(torch.zeros([16, 16], device="cuda"))
+            y.backward(torch.zeros_like(y))
+            return module.fp8_meta
+
+        fp8_meta = setup_fp8_meta()
+        forward_key = FP8GlobalStateManager.get_meta_tensor_key(forward=True)
+
+        # Replace fp8_meta[forward_key] with a fresh FP8TensorMeta for this test
+        fp8_meta[forward_key] = tex.FP8TensorMeta()
+        fp8_meta[forward_key].scale = torch.ones(1, dtype=torch.float32, device="cuda")
+        fp8_meta[forward_key].scale_inv = torch.ones(1, dtype=torch.float32, device="cuda")
+
+        # Set up the amax history for each scenario
+        if amax_case == "zero":
+            fp8_meta[forward_key].amax_history = torch.tensor([[0]], dtype=torch.float32, device="cuda")
+            expected_scale = torch.tensor([1.0], dtype=torch.float32, device="cuda")
+        elif amax_case == "tiny":
+            # Calculate the smallest amax that still yields a representable FP32 scale
+            fp32_max = torch.tensor(torch.finfo(torch.float32).max)
+            tiny_amax = fp8_max / fp32_max
+            # Make the amax smaller than that, so the naive scale would overflow to infinity
+            amax_value = tiny_amax / 2
+            fp8_meta[forward_key].amax_history = torch.tensor(
+                [[amax_value]], dtype=torch.float32, device="cuda"
+            )
+            # The expected scale is clamped to FP32_max
+            expected_scale = fp32_max.view(1).cuda()
+        elif amax_case == "normal":
+            # Add a small epsilon to avoid a zero amax
+            amax_value = torch.rand(1, dtype=torch.float32, device="cuda") + 1e-5
+            fp8_meta[forward_key].amax_history = amax_value.view(1, 1)
+            expected_scale = fp8_max / amax_value
+        elif amax_case == "inf":
+            fp8_meta[forward_key].amax_history = torch.tensor(
+                [[torch.inf]], dtype=torch.float32, device="cuda"
+            )
+            expected_scale = torch.tensor([1.0], dtype=torch.float32, device="cuda")
+        elif amax_case == "nan":
+            fp8_meta[forward_key].amax_history = torch.tensor(
+                [[torch.nan]], dtype=torch.float32, device="cuda"
+            )
+            expected_scale = torch.tensor([1.0], dtype=torch.float32, device="cuda")
+
+        if fused_update:
+            tex.fused_amax_and_scale_update_after_reduction(
+                fp8_meta[forward_key].amax_history.clone().view(-1),
+                [fp8_meta[forward_key].amax_history],
+                [fp8_meta[forward_key].scale],
+                [fp8_meta[forward_key].scale_inv],
+                recipe.amax_compute_algo,
+                fp8_dtype,
+                recipe.margin,
+            )
+        else:
+            _amax_and_scale_update(
+                fp8_meta[forward_key].amax_history,
+                fp8_meta[forward_key].scale,
+                fp8_meta[forward_key].scale_inv,
+                fp8_max,
+                recipe,
+            )
+
+        torch.testing.assert_close(fp8_meta[forward_key].scale, expected_scale)
+        torch.testing.assert_close(fp8_meta[forward_key].scale_inv, torch.reciprocal(expected_scale))
diff --git a/transformer_engine/common/recipe/delayed_scaling.cu b/transformer_engine/common/recipe/delayed_scaling.cu
index de48a53ebf..2e232f50e2 100644
--- a/transformer_engine/common/recipe/delayed_scaling.cu
+++ b/transformer_engine/common/recipe/delayed_scaling.cu
@@ -8,6 +8,7 @@
 
 #include
 #include
+#include <limits>
 
 #include "../common.h"
 #include "../util/logging.h"
@@ -151,6 +152,13 @@ kernel(const float* amax_history_ptr,
     } else {
       scale = scale_ptr[bid];
     }
+    // If the amax is so tiny that the scale overflows to infinity in FP32,
+    // clamp the scale to the FP32 maximum. The tensor's amax then maps to
+    // something below the FP8 maximum representable value rather than to it,
+    // but this is the best we can do.
+    if (isinf(scale)) {
+      scale = std::numeric_limits<float>::max();
+    }
     updated_scale_ptr[bid] = scale;
 
     // Update scale inverse
@@ -239,12 +247,30 @@ kernel_bulk(
 
     // Update scale and scale inverse
     if (tid == 0) {
+      // Computing the scaling factor requires consideration of the following scenarios:
+      // 1. amax == 0:
+      //    No action is possible; keep the previous scale (or 1).
+      // 2. 0 < amax < tiny_amax:
+      //    The amax is so tiny that the scale overflows to infinity in FP32.
+      //    Set scale = FP32_max.
+      // 3. tiny_amax <= amax < FP32_max:
+      //    Set scale = FP8_max (or scaled_max) / amax.
+      // 4. amax == inf or amax == nan:
+      //    No action is possible; keep the previous scale (or 1).
+
       float scale;
       if (isfinite(amax) && amax > 0) {
         scale = scaled_max / amax;
       } else {
         scale = p.param[bid].scale[count];
       }
+      // If the amax is so tiny that the scale overflows to infinity in FP32,
+      // clamp the scale to the FP32 maximum. The tensor's amax then maps to
+      // something below the FP8 maximum representable value rather than to it,
+      // but this is the best we can do.
+      if (isinf(scale)) {
+        scale = std::numeric_limits<float>::max();
+      }
       p.param[bid].scale[count] = scale;
       p.param[bid].scale_inv[count] = 1 / scale;
     }
diff --git a/transformer_engine/pytorch/fp8.py b/transformer_engine/pytorch/fp8.py
index 1f359d4864..b28e380473 100644
--- a/transformer_engine/pytorch/fp8.py
+++ b/transformer_engine/pytorch/fp8.py
@@ -598,11 +598,24 @@ def _default_sf_compute(
     scale: torch.Tensor,
     fp8_max: float,
     margin: int,
+    _fp32_max: float = torch.finfo(torch.float32).max,  # torch.finfo is not available inside the jitted function
 ) -> torch.Tensor:
-    """Default function to convert amax to scaling factor."""
+    """Default function to convert amax to a scaling factor.
+    Computing the scaling factor requires consideration of the following scenarios:
+    1. amax == 0:
+       No action is possible; keep the previous scale (or 1).
+    2. 0 < amax < tiny_amax:
+       The amax is so tiny that the scale overflows to infinity in FP32.
+       Set scale = FP32_max.
+    3. tiny_amax <= amax < FP32_max:
+       Set scale = FP8_max (or scaled_max) / amax.
+    4. amax == inf or amax == nan:
+       No action is possible; keep the previous scale (or 1).
+    """
     sf = (fp8_max / amax) / (2 ** margin)
     sf = torch.where(amax > 0.0, sf, scale)
     sf = torch.where(torch.isfinite(amax), sf, scale)
+    sf = torch.where(torch.isinf(sf), torch.full_like(sf, _fp32_max), sf)
     scale.copy_(sf)
     return scale
 
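
For reference, here is a minimal standalone sketch (not part of the patch) of the scaling-factor logic that the diff adds to `_default_sf_compute` and to the CUDA kernels. It assumes only PyTorch; the helper name `demo_sf_compute`, the E4M3 maximum of 448.0, and returning the new scale instead of copying it in place are illustrative choices, not the library's API. It walks the same five amax cases as the new test: zero, tiny, normal, inf, and nan.

# Standalone sketch (not part of the patch): mirrors the patched scaling-factor
# update. demo_sf_compute is a hypothetical helper, not Transformer Engine API.
import torch

def demo_sf_compute(amax, scale, fp8_max, margin=0):
    fp32_max = torch.finfo(torch.float32).max
    sf = (fp8_max / amax) / (2 ** margin)                # scenario 3: finite, non-zero amax
    sf = torch.where(amax > 0.0, sf, scale)              # scenario 1: amax == 0 keeps the old scale
    sf = torch.where(torch.isfinite(amax), sf, scale)    # scenario 4: inf/nan amax keeps the old scale
    sf = torch.where(torch.isinf(sf), torch.full_like(sf, fp32_max), sf)  # scenario 2: clamp to FP32 max
    return sf

fp8_max = 448.0             # E4M3 maximum representable value (illustrative)
prev_scale = torch.ones(1)  # previous scale, initialized to 1
for amax in (0.0, 1e-38, 0.5, float("inf"), float("nan")):
    sf = demo_sf_compute(torch.tensor([amax]), prev_scale, fp8_max)
    print(f"amax={amax:>8}: scale={sf.item():.6g}")
# zero/inf/nan amax keep the previous scale of 1; amax=1e-38 is clamped to
# ~3.40282e+38 (FP32 max) instead of overflowing to inf; amax=0.5 gives 896.

The tiny case is the one the patch changes: without the added isinf clamp, FP8_max / amax overflows FP32 and an infinite scale (with a zero scale_inv) would be written into fp8_meta.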