diff --git a/.gitignore b/.gitignore index f491b21f43..850b352d31 100644 --- a/.gitignore +++ b/.gitignore @@ -38,3 +38,4 @@ downloads/ .pytest_cache/ compile_commands.json .nfs +tensor_dumps/ \ No newline at end of file diff --git a/qa/L0_cppunittest/test.sh b/qa/L0_cppunittest/test.sh old mode 100644 new mode 100755 diff --git a/tests/cpp/operator/CMakeLists.txt b/tests/cpp/operator/CMakeLists.txt index ce78fcaae2..6785dbf6f4 100644 --- a/tests/cpp/operator/CMakeLists.txt +++ b/tests/cpp/operator/CMakeLists.txt @@ -4,6 +4,7 @@ add_executable(test_operator test_cast.cu + test_cast_current_scaling.cu test_cast_dbias.cu test_cast_dbias_dgelu.cu test_cast_gated_swiglu.cu @@ -13,6 +14,7 @@ add_executable(test_operator test_dequantize_mxfp8.cu test_transpose.cu test_cast_transpose.cu + test_cast_transpose_current_scaling.cu test_cast_transpose_dbias.cu test_cast_transpose_dbias_dgelu.cu test_cast_transpose_dgeglu.cu diff --git a/tests/cpp/operator/test_cast.cu b/tests/cpp/operator/test_cast.cu index f57d1f035d..81c975b0a8 100644 --- a/tests/cpp/operator/test_cast.cu +++ b/tests/cpp/operator/test_cast.cu @@ -35,6 +35,8 @@ void compute_ref(const InputType *data, OutputType *output_c, *amax = current_max; } + +// delayed tensor scaling test template void performTest(const std::vector& shape) { using namespace test; @@ -55,6 +57,7 @@ void performTest(const std::vector& shape) { nvte_quantize(input.data(), output_c.data(), 0); float ref_amax; + compute_ref(input.rowwise_cpu_dptr(), ref_output_c.get(), full_size, &ref_amax, output_c.scale()); @@ -105,6 +108,7 @@ TEST_P(CastTestSuite, TestCast) { TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(input_type, InputType, TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(output_type, OutputType, + // delayed tensor scaling performTest(size); ); ); diff --git a/tests/cpp/operator/test_cast_current_scaling.cu b/tests/cpp/operator/test_cast_current_scaling.cu new file mode 100644 index 0000000000..18325d6daf --- /dev/null +++ b/tests/cpp/operator/test_cast_current_scaling.cu @@ -0,0 +1,214 @@ +/************************************************************************* + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. 
+ ************************************************************************/ + +#include +#include +#include +#include +#include + +#include +#include +#include + +#include +#include +#include "../test_common.h" + +using namespace transformer_engine; + +namespace { + +template +void compute_ref(const InputType *data, OutputType *output_c, + const size_t size, + float *amax, float scale) { + using compute_t = float; + compute_t current_max = -1e100; + for (size_t i = 0; i < size; ++i) { + compute_t current = static_cast(data[i]); + current_max = fmaxf(current_max, fabsf(current)); + output_c[i] = OutputType(scale * current); + } +} + + +template +void compute_amax_scale_ref(const InputType *data, + const size_t size, + float *amax_ptr, float *scale_ptr, float* scale_inv_ptr, + float max_fp8, float epsilon) { + using compute_t = float; + compute_t current_max = -1e100; + for (size_t i = 0; i < size; ++i) { + compute_t current = static_cast(data[i]); + current_max = fmaxf(current_max, fabsf(current)); + } + *amax_ptr = current_max; + + // compute scale from amax + float clamp_amax = current_max; + if (current_max <= epsilon){ + clamp_amax = epsilon; + } + + float scale = 1.f; + float scale_inv = 1.f; + + if (isinf(clamp_amax) || clamp_amax == 0.f) { + *scale_ptr = scale; + *scale_inv_ptr = scale_inv; + return; + } + + // use ieee_div in CPU + scale = max_fp8 / clamp_amax; + + // The amax is too small that the scale becoming infinite in FP32. In other word, + // the scale is not representable in FP32. + if (isinf(scale)) { + scale = std::numeric_limits::max(); + } + + if (isnan(scale)) { + scale = 1.f; + } + + scale_inv = 1.0f / scale; + + *scale_ptr = scale; + *scale_inv_ptr = scale_inv; +} + +// current tensor scaling test +template +void performTest(const std::vector& shape) { + using namespace test; + + const size_t full_size = product(shape); + + DType itype = TypeInfo::dtype; + DType otype = TypeInfo::dtype; + + bool is_out_fp8 = isFp8Type(otype); + + // find out max fp8 value + float max_fp8; + if (is_out_fp8){ + switch (otype) { + case DType::kFloat8E5M2: { + max_fp8 = Quantized_Limits::max(); + } break; + case DType::kFloat8E4M3: { + max_fp8 = Quantized_Limits::max(); + } break; + default: + NVTE_ERROR("Invalid type."); + } + } + + Tensor input("input", shape, itype); + Tensor output_c("output_c", shape, otype, true, false); + + std::unique_ptr ref_output_c = std::make_unique(full_size); + + fillUniform(&input); + + // compute amax + float amax_to_check = 0.0f; + if (is_out_fp8){ + nvte_compute_amax(input.data(), output_c.data(), 0); + QuantizationConfigWrapper config; + nvte_compute_scale_from_amax(output_c.data(), config, 0); + // avoid atomic amax update in cuda cast kernels because of current per-tensor scaling + amax_to_check = output_c.amax(); + output_c.set_tensor_amax_nullptr(); + } + nvte_quantize(input.data(), output_c.data(), 0); + + float ref_amax; + float ref_scale; + float ref_scale_inv; + if (is_out_fp8){ + compute_amax_scale_ref(input.rowwise_cpu_dptr(), + full_size, &ref_amax, &ref_scale, &ref_scale_inv, max_fp8, 0.0f); + } + + compute_ref(input.rowwise_cpu_dptr(), ref_output_c.get(), + full_size, nullptr, is_out_fp8 ? 
output_c.scale() : 1.0f ); + + cudaDeviceSynchronize(); + + auto err = cudaGetLastError(); + ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err); + if (isFp8Type(otype)) { + auto [atol_fp32, rtol_fp32] = getTolerances(DType::kFloat32); + compareResults("amax", amax_to_check, ref_amax, 0.0f, rtol_fp32); + compareResults("scale", output_c.scale(), ref_scale, 0.0f, rtol_fp32); + compareResults("scale_inv", output_c.rowwise_scale_inv(), ref_scale_inv, 0.0f, rtol_fp32); + } + auto [atol, rtol] = getTolerances(otype); + compareResults("output_c", output_c, ref_output_c.get(), true, 0.0f, rtol); +} + +std::vector> test_cases = { + {16}, + {16000}, + {128, 128}, + {256, 256}, + {768, 1024}, + {256, 65536}, + {2048, 12288}, + {65536, 128}, + {65536, 160}, + {16384, 1616}, + {1, 128}, + {1, 1296}, + {1, 16}, + {5, 160}, + {5, 4, 3, 160}, + {217, 256}, +}; +} // namespace + +class CastCSTestSuite : public ::testing::TestWithParam>> {}; + +TEST_P(CastCSTestSuite, TestCastCS) { + using namespace transformer_engine; + using namespace test; + + const DType input_type = std::get<0>(GetParam()); + const DType output_type = std::get<1>(GetParam()); + const auto size = std::get<2>(GetParam()); + + TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(input_type, InputType, + TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(output_type, OutputType, + // current tensor scaling + performTest(size); + ); + ); +} + + + +INSTANTIATE_TEST_SUITE_P( + OperatorTest, + CastCSTestSuite, + ::testing::Combine( + ::testing::Values(DType::kFloat32, DType::kBFloat16, DType::kFloat16), + ::testing::Values(DType::kFloat8E4M3, DType::kFloat8E5M2), + ::testing::ValuesIn(test_cases)), + [](const testing::TestParamInfo& info) { + std::string name = test::typeName(std::get<0>(info.param)) + "X" + + test::typeName(std::get<1>(info.param)); + const auto& shape = std::get<2>(info.param); + for ( const auto& s: shape) { + name += "X" + std::to_string(s); + } + return name; + }); diff --git a/tests/cpp/operator/test_cast_transpose.cu b/tests/cpp/operator/test_cast_transpose.cu index 830682eec3..380ae96190 100644 --- a/tests/cpp/operator/test_cast_transpose.cu +++ b/tests/cpp/operator/test_cast_transpose.cu @@ -38,6 +38,8 @@ void compute_ref(const InputType *data, OutputType *output_c, OutputType *output *amax = current_max; } + +// delayed tensor scaling test template void performTest(const size_t N, const size_t H) { using namespace test; @@ -75,6 +77,7 @@ void performTest(const size_t N, const size_t H) { compareResults("output_t", output, ref_output_t.get(), false, atol, rtol); } + std::vector> test_cases = {{2048, 12288}, {768, 1024}, {256, 65536}, @@ -101,6 +104,7 @@ TEST_P(CTTestSuite, TestCastTranspose) { TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(input_type, InputType, TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(output_type, OutputType, + // delayed tensor scaling performTest(size.first, size.second); ); ); diff --git a/tests/cpp/operator/test_cast_transpose_current_scaling.cu b/tests/cpp/operator/test_cast_transpose_current_scaling.cu new file mode 100644 index 0000000000..267970b34f --- /dev/null +++ b/tests/cpp/operator/test_cast_transpose_current_scaling.cu @@ -0,0 +1,210 @@ +/************************************************************************* + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. 
+ ************************************************************************/ + +#include +#include +#include +#include +#include + +#include +#include +#include + +#include +#include +#include "../test_common.h" + +using namespace transformer_engine; + +namespace { + +template +void compute_ref(const InputType *data, OutputType *output_c, OutputType *output_t, + const size_t N, const size_t H, + float *amax, float scale) { + using compute_t = float; + compute_t current_max = -1e100; + for (size_t i = 0; i < N; ++i) { + for (size_t j = 0; j < H; ++j) { + compute_t current = static_cast(data[i * H + j]); + current_max = fmaxf(current_max, fabsf(current)); + output_c[i * H + j] = OutputType(scale * current); + output_t[j * N + i] = OutputType(scale * current); + } + } +} + +template +void compute_amax_scale_ref(const InputType *data, + const size_t N, const size_t H, + float *amax_ptr, float *scale_ptr, float* scale_inv_ptr, + float max_fp8, float epsilon) { + using compute_t = float; + compute_t current_max = -1e100; + for (size_t i = 0; i < N; ++i) { + for (size_t j = 0; j < H; ++j) { + compute_t current = static_cast(data[i * H + j]); + current_max = fmaxf(current_max, fabsf(current)); + } + } + *amax_ptr = current_max; + + // compute scale from amax + float clamp_amax = current_max; + if (current_max <= epsilon){ + clamp_amax = epsilon; + } + + float scale = 1.f; + float scale_inv = 1.f; + + if (isinf(clamp_amax) || clamp_amax == 0.f) { + *scale_ptr = scale; + *scale_inv_ptr = scale_inv; + return; + } + + // use ieee_div in CPU + scale = max_fp8 / clamp_amax; + + // The amax is too small that the scale becoming infinite in FP32. In other word, + // the scale is not representable in FP32. + if (isinf(scale)) { + scale = std::numeric_limits::max(); + } + + if (isnan(scale)) { + scale = 1.f; + } + + scale_inv = 1.0f / scale; + + *scale_ptr = scale; + *scale_inv_ptr = scale_inv; +} + +// current tensor scaling test +template +void performTest(const size_t N, const size_t H) { + using namespace test; + + DType itype = TypeInfo::dtype; + DType otype = TypeInfo::dtype; + + bool is_out_fp8 = isFp8Type(otype); + + // find out max fp8 value + float max_fp8; + if (is_out_fp8){ + switch (otype) { + case DType::kFloat8E5M2: { + max_fp8 = Quantized_Limits::max(); + } break; + case DType::kFloat8E4M3: { + max_fp8 = Quantized_Limits::max(); + } break; + default: + NVTE_ERROR("Invalid type."); + } + } + + Tensor input("input", { N, H }, itype); + Tensor output("output", { N, H }, otype, true, true); + + std::unique_ptr ref_output_c = std::make_unique(N * H); + std::unique_ptr ref_output_t = std::make_unique(N * H); + + fillUniform(&input); + + // compute amax + float amax_to_check = 0.0f; + if (is_out_fp8){ + nvte_compute_amax(input.data(), output.data(), 0); + QuantizationConfigWrapper config; + nvte_compute_scale_from_amax(output.data(), config, 0); + // avoid atomic amax update in cuda cast kernels because of current per-tensor scaling + amax_to_check = output.amax(); + output.set_tensor_amax_nullptr(); + } + nvte_quantize(input.data(), output.data(), 0); + + float ref_amax; + float ref_scale; + float ref_scale_inv; + if (is_out_fp8){ + compute_amax_scale_ref(input.rowwise_cpu_dptr(), + N, H, &ref_amax, &ref_scale, &ref_scale_inv, max_fp8, 0.0f); + } + + compute_ref(input.rowwise_cpu_dptr(), ref_output_c.get(), + ref_output_t.get(), N, H, nullptr, + is_out_fp8 ? 
output.scale() : 1.0f ); + + cudaDeviceSynchronize(); + + auto err = cudaGetLastError(); + ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err); + if (isFp8Type(otype)) { + auto [atol_fp32, rtol_fp32] = getTolerances(DType::kFloat32); + compareResults("amax", amax_to_check, ref_amax, 0.0f, rtol_fp32); + compareResults("scale", output.scale(), ref_scale, 0.0f, rtol_fp32); + compareResults("scale_inv", output.rowwise_scale_inv(), ref_scale_inv, 0.0f, rtol_fp32); + compareResults("scale_inv_columnwise", output.columnwise_cpu_scale_inv_ptr()[0], ref_scale_inv, 0.0f, rtol_fp32); + } + auto [atol, rtol] = getTolerances(otype); + compareResults("output_c", output, ref_output_c.get(), true, 0.0f, rtol); + compareResults("output_t", output, ref_output_t.get(), false, 0.0f, rtol); +} + +std::vector> test_cases = {{2048, 12288}, + {768, 1024}, + {256, 65536}, + {65536, 128}, + {256, 256}, + {120, 2080}, + {8, 8}, + {1, 3221}, // Prime 456 + {2333, 1}, // Prime 345 + {1481, 677}}; // Primes 234, 123 +} // namespace + +class CTCSTestSuite : public ::testing::TestWithParam>> {}; + +TEST_P(CTCSTestSuite, TestCastTransposeCS) { + using namespace transformer_engine; + using namespace test; + + const DType input_type = std::get<0>(GetParam()); + const DType output_type = std::get<1>(GetParam()); + const auto size = std::get<2>(GetParam()); + + TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(input_type, InputType, + TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(output_type, OutputType, + // current tensor scaling + performTest(size.first, size.second); + ); + ); +} + + + +INSTANTIATE_TEST_SUITE_P( + OperatorTest, + CTCSTestSuite, + ::testing::Combine( + ::testing::Values(DType::kFloat32, DType::kBFloat16, DType::kFloat16), + ::testing::ValuesIn(test::all_fp_types), + ::testing::ValuesIn(test_cases)), + [](const testing::TestParamInfo& info) { + std::string name = test::typeName(std::get<0>(info.param)) + "X" + + test::typeName(std::get<1>(info.param)) + "X" + + std::to_string(std::get<2>(info.param).first) + "X" + + std::to_string(std::get<2>(info.param).second); + return name; + }); diff --git a/tests/cpp/test_common.cu b/tests/cpp/test_common.cu index ec4a9bdbb7..24aff83d8a 100644 --- a/tests/cpp/test_common.cu +++ b/tests/cpp/test_common.cu @@ -103,10 +103,6 @@ size_t DIVUP(const size_t &x, const size_t &y){ return (((x) + ((y)-1)) / (y)); } -inline bool is_tensor_scaling(const NVTEScalingMode &mode) { - return mode == NVTE_DELAYED_TENSOR_SCALING; -} - struct scale_inv_meta { std::vector shape; DType type; @@ -233,7 +229,7 @@ Tensor::Tensor(const std::string& name, tensor_.set_columnwise_data(dptr_columnwise, type, columnwise_shape); if (isFp8Type(type)) { - if (is_tensor_scaling(scaling_mode)) { + if (scaling_mode == NVTE_DELAYED_TENSOR_SCALING) { cudaMalloc((void**)&amax, sizeof(float)); // NOLINT(*) cudaMemset(amax, 0, sizeof(float)); cudaMalloc((void**)&scale, sizeof(float)); // NOLINT(*) @@ -296,11 +292,13 @@ void Tensor::to_cpu() const { cudaMemcpyDeviceToHost); } if (isFp8Type(dtype())) { - if (is_tensor_scaling(tensor_.scaling_mode())) { - cudaMemcpy(amax_cpu_data_.get(), - tensor_.amax(), - sizeof(float), - cudaMemcpyDeviceToHost); + if (tensor_.scaling_mode() == NVTE_DELAYED_TENSOR_SCALING) { + if (tensor_.amax() != nullptr){ + cudaMemcpy(amax_cpu_data_.get(), + tensor_.amax(), + sizeof(float), + cudaMemcpyDeviceToHost); + } cudaMemcpy(scale_cpu_data_.get(), tensor_.scale(), sizeof(float), @@ -336,9 +334,11 @@ void Tensor::from_cpu() const { cpu_data_columnwise_.get(), size, cudaMemcpyHostToDevice); } if 
(isFp8Type(dtype())) { - if (is_tensor_scaling(tensor_.scaling_mode())) { - cudaMemcpy(tensor_.amax(), amax_cpu_data_.get(), sizeof(float), - cudaMemcpyHostToDevice); + if (tensor_.scaling_mode() == NVTE_DELAYED_TENSOR_SCALING) { + if (tensor_.amax() != nullptr){ + cudaMemcpy(tensor_.amax(), amax_cpu_data_.get(), sizeof(float), + cudaMemcpyHostToDevice); + } cudaMemcpy(tensor_.scale(), scale_cpu_data_.get(), sizeof(float), cudaMemcpyHostToDevice); } @@ -361,7 +361,7 @@ void Tensor::from_cpu() const { void Tensor::set_scale(float scale) { if (isFp8Type(dtype())) { NVTE_CHECK(scale_cpu_data_); - if (is_tensor_scaling(tensor_.scaling_mode())) { + if (tensor_.scaling_mode() == NVTE_DELAYED_TENSOR_SCALING) { *scale_cpu_data_ = scale; from_cpu(); } diff --git a/tests/cpp/test_common.h b/tests/cpp/test_common.h index dc515ccb8e..4352056ddb 100644 --- a/tests/cpp/test_common.h +++ b/tests/cpp/test_common.h @@ -256,6 +256,10 @@ class Tensor { return columnwise_; } + void set_tensor_amax_nullptr(){ + tensor_.set_amax(nullptr, DType::kFloat32, tensor_.defaultShape); + } + void to_cpu() const; void from_cpu() const; void set_scale(float scale); diff --git a/tests/pytorch/distributed/run_numerics.py b/tests/pytorch/distributed/run_numerics.py index 2d301e3151..e2e78b72b1 100644 --- a/tests/pytorch/distributed/run_numerics.py +++ b/tests/pytorch/distributed/run_numerics.py @@ -14,13 +14,15 @@ import torch from torch import nn import torch.distributed as dist - +import transformer_engine_torch as tex from transformer_engine.common.recipe import ( MXFP8BlockScaling, DelayedScaling, + Float8CurrentScaling, Format, Recipe, ) +from transformer_engine.pytorch.tensor.float8_tensor import Float8CurrentScalingQuantizer from run_layer_with_overlap import _compare_tensors SEQ_LEN, BATCH_SIZE = 16, 16 @@ -45,6 +47,8 @@ def quantization_recipe() -> Recipe: ) if QUANTIZATION == "mxfp8": return MXFP8BlockScaling() + if QUANTIZATION == "fp8_cs": + return Float8CurrentScaling() return te.fp8.get_default_fp8_recipe() @@ -88,6 +92,7 @@ def main(argv=None, namespace=None): HIDDEN_SIZE = 128 test_dict = [ + test_quantizer, test_linear, test_layernorm, test_layernorm_linear, @@ -152,7 +157,12 @@ def dist_print(msg, src=None, end="\n", error=False): def _get_tolerances(dtype): - if QUANTIZATION is not None: + # loose tolerances for fp8_cs because of sequence parallel & amax reduction + # so that each rank has a different scale_inv for computing Y when we have + # row parallel & sequence parallel, because we do the all_gather in backward pass + if QUANTIZATION == "fp8_cs": + return {"rtol": 0.4, "atol": 0.25} + elif QUANTIZATION is not None: return {"rtol": 0.125, "atol": 0.0625} if dtype == torch.float16: @@ -293,6 +303,98 @@ def _alloc_main_grad(model_single_node, model_distributed): param.main_grad = torch.zeros_like(param, dtype=torch.float32) +############################################### +# Quantizer # +############################################### +def _construct_quantizer(quantizer_class, fp8_dtype, device, tp_group, tp_size): + """ + quantizer is the reference quantizer on a single GPU. + quantizer_dist is the distributed quantizer to be tested on multiple GPUs. 
+ """ + if quantizer_class == Float8CurrentScalingQuantizer: + quantizer_dist = quantizer_class( + fp8_dtype=fp8_dtype, + device=device, + with_amax_reduction=True, + amax_reduction_group=tp_group, + amax_reduction_size=tp_size, + ) + quantizer = quantizer_class( + fp8_dtype=fp8_dtype, + device=device, + with_amax_reduction=False, + ) + return quantizer, quantizer_dist + else: + raise ValueError(f"Unsupported quantizer class: {quantizer_class}") + + +def _shard_tensor(x, world_size, axis): + split_size = x.size()[axis] // world_size + split_tensor = torch.split(x, split_size, axis) + out = [] + for tensor in split_tensor: + out.append(tensor.detach().clone().requires_grad_(x.requires_grad).cuda()) + return out + + +@run_distributed_test() +def _test_quantizer(input_dtype, fp8_dtype): + """Test the quantizer under distributed settings. + + Args: + input_dtype (torch.dtype): The data type of the input. + fp8_dtype (tex.DType): The data type of the fp8. + """ + + M, N = WORLD_SIZE * BATCH_SIZE, HIDDEN_SIZE + + # high precision input + x_hp_cpu = torch.randn((M, N), device="cpu").to(input_dtype) + # set one element of the input to a very large value, which doesn't live in rank 0 after the split + # to test the amax reduction on purpose + x_hp_cpu[M - 1, N - 1] = 1e4 + # rank 0 takes the full copy and quantize with GPU 0 for verification + if WORLD_RANK == 0: + x_hp_rank0 = x_hp_cpu.clone().detach().requires_grad_(True).to("cuda") + x_hp_local_rank = _shard_tensor(x_hp_cpu, WORLD_SIZE, 0)[WORLD_RANK] + + # Create quantizers + quantizer, quantizer_dist = _construct_quantizer( + Float8CurrentScalingQuantizer, fp8_dtype, x_hp_local_rank.device, NCCL_WORLD, WORLD_SIZE + ) + + # quantize the input + if WORLD_RANK == 0: + x_fp8_single = quantizer(x_hp_rank0) + + # multi-GPU quantizer + x_fp8_dist = quantizer_dist(x_hp_local_rank) + + # check scale_inv with zero tolerance + if WORLD_RANK == 0: + torch.testing.assert_close( + x_fp8_single._scale_inv, x_fp8_dist._scale_inv, rtol=0.0, atol=0.0 + ) + + +def test_quantizer(): + """ + Run quantizer tests with various configurations. + Currently only check fp8_cs because it needs to do amax reduction in the quantizer. 
+ """ + # skip this test for other quantization schemes + if QUANTIZATION != "fp8_cs": + return + + input_dtypes = [torch.float32, torch.bfloat16] + fp8_dtypes = [tex.DType.kFloat8E4M3, tex.DType.kFloat8E5M2] + + for input_dtype in input_dtypes: + for fp8_dtype in fp8_dtypes: + _test_quantizer(input_dtype, fp8_dtype) + + ############################################ # Linear # ############################################ @@ -339,6 +441,11 @@ def _test_linear(parallel_mode=None, sequence_parallel=False, **kwargs): torch.empty((WORLD_SIZE * BATCH_SIZE, HIDDEN_SIZE)).cuda().to(params_dtype) ) input_distributed = torch.randn((BATCH_SIZE, HIDDEN_SIZE)).cuda().to(params_dtype) + # when quantization is fp8_cs, we need to trigger corner cases to see if amax reduction is working + if QUANTIZATION == "fp8_cs": + input_distributed = torch.clamp(input_distributed, min=-10, max=10) + if WORLD_RANK == WORLD_SIZE - 1: + input_distributed[BATCH_SIZE - 1, HIDDEN_SIZE - 1] = 11 input_single_node = _gather(input_distributed, dim=0).detach() else: input_distributed = input_single_node.clone() @@ -501,6 +608,12 @@ def _test_layernorm_linear(parallel_mode=None, sequence_parallel=False, **kwargs # Duplicate input for sequence parallelism input_single_node = torch.empty((WORLD_SIZE * SEQ_LEN, HIDDEN_SIZE)).cuda().to(params_dtype) input_distributed = torch.randn((SEQ_LEN, HIDDEN_SIZE)).cuda().to(params_dtype) + # make the last element of the input a large value to test the amax reduction on purpose + # when quantization is fp8_cs, we need to trigger corner cases to see if amax reduction is working + if QUANTIZATION == "fp8_cs": + input_distributed = torch.clamp(input_distributed, min=-10, max=10) + if WORLD_RANK == WORLD_SIZE - 1: + input_distributed[SEQ_LEN - 1, HIDDEN_SIZE - 1] = 11 input_single_node = _gather(input_distributed).detach() else: input_distributed = input_single_node.clone() @@ -599,6 +712,12 @@ def _test_layernorm_mlp(set_parallel_mode=None, sequence_parallel=False, **kwarg # Duplicate input for sequence parallelism input_single_node = torch.empty((WORLD_SIZE * SEQ_LEN, HIDDEN_SIZE)).cuda().to(params_dtype) input_distributed = torch.randn((SEQ_LEN, HIDDEN_SIZE)).cuda().to(params_dtype) + # make the last element of the input a large value to test the amax reduction on purpose + # when quantization is fp8_cs, we need to trigger corner cases to see if amax reduction is working + if QUANTIZATION == "fp8_cs": + input_distributed = torch.clamp(input_distributed, min=-10, max=10) + if WORLD_RANK == WORLD_SIZE - 1: + input_distributed[SEQ_LEN - 1, HIDDEN_SIZE - 1] = 11 input_single_node = _gather(input_distributed).detach() else: input_distributed = input_single_node.clone() @@ -651,6 +770,7 @@ def test_layernorm_mlp(): {"return_bias": True}, {"return_layernorm_output": True}, ] + for kwargs in kwargs_list: for set_parallel_mode in [True]: for sequence_parallel in [False, True]: @@ -745,6 +865,7 @@ def test_transformer_layer(): {"fuse_qkv_params": True}, {"activation": "relu"}, ] + for kwargs in kwargs_list: for sequence_parallel in [False, True]: _test_transformer_layer_parallel(sequence_parallel, **kwargs) diff --git a/tests/pytorch/distributed/test_numerics.py b/tests/pytorch/distributed/test_numerics.py index 7be9cd01ae..b4e2b680b3 100644 --- a/tests/pytorch/distributed/test_numerics.py +++ b/tests/pytorch/distributed/test_numerics.py @@ -48,10 +48,12 @@ def _run_test(quantization): all_boolean = [True, False] -@pytest.mark.parametrize("quantization", [None, "fp8", "mxfp8"]) 
+@pytest.mark.parametrize("quantization", [None, "fp8", "mxfp8", "fp8_cs"]) def test_distributed(quantization): if quantization == "fp8" and not fp8_available: pytest.skip(reason_for_no_fp8) + if quantization == "fp8_cs" and not fp8_available: + pytest.skip(fp8_available) if quantization == "mxfp8" and not mxfp8_available: pytest.skip(reason_for_no_mxfp8) _run_test(quantization) diff --git a/tests/pytorch/references/ref_per_tensor_cs.py b/tests/pytorch/references/ref_per_tensor_cs.py new file mode 100644 index 0000000000..1895b31d78 --- /dev/null +++ b/tests/pytorch/references/ref_per_tensor_cs.py @@ -0,0 +1,105 @@ +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +import torch +import transformer_engine_torch as tex + +from transformer_engine.pytorch.constants import TE_DType_To_Torch + + +# compute amax and scale +def _ref_compute_amax_scale(x, quant_dtype, eps, pow_2_scales): + x_fp32 = x.to(torch.float32) + amax = torch.amax(torch.abs(x_fp32)).view(1) + assert amax.dtype == torch.float, "amax must be a float tensor." + fp8_max = torch.finfo(quant_dtype).max + # Clamping amax to avoid division by small numbers + amax = torch.max(amax, torch.tensor(eps)) + + # Compute scale factor + scale = torch.div(fp8_max, amax) + # Note frexp doesn't give back inf for exponent with an inf input + # We take care of inf before pow_2_scales + # option1: set scale to fp32 max when scale is inf + scale = torch.where(scale == torch.inf, torch.finfo(torch.float32).max, scale) + # option2: when scale is inf, set scale to 1 + scale = torch.where(scale == torch.inf, 1.0, scale) + if pow_2_scales: + # Calculate rounded down exponent + _, exp = torch.frexp(scale) + # Positive numbers are always returned as mant, exp with + # a mantissa in [0.5, 1.0). Because a normal float has a mantissa with + # hidden bit in [1.0, 2.0), the exponent will be off by exactly one because + # of the shift. Subnormal and zero cases need not be considered because + # the smallest possible result of fp8_max / amax is still normal. + exp = exp - 1 + # No subnormals and zero. + assert (exp > -127).all() + # TODO: If/when adding a URM option an option is to cap to 126 + # rather than allowing the full range of FP32 (2 - 2^23) x 2^127 + # addresses cases where adding a mantissa overflows into inf scales. + # Not necessary currently without additional scale smudging options. + unity = torch.tensor([1.0], device=exp.device) + torch.ldexp(unity, exp, out=scale) + # Case where amax is inf. The frexp, ldexp logic changes 0.0 scales + # Return 0.0 for 0.0 scale for consistency with non-pow2 scale + # calculation. 
+ scale = torch.where(amax == float("inf"), 0.0, scale) + + # Handle overflow cases for amax zero causing NaN + scale = torch.where(amax == 0, 1.0, scale) + # Compute scale_inv + scale_inv = torch.reciprocal(scale) + + return scale, scale_inv, amax + + +def _multi_dim_transpose(tensor): + # Get the number of dimensions + dims = list(range(len(tensor.shape))) + + if len(dims) <= 1: + return tensor + + # circular shift of shapes + new_order = [] + new_order.append(dims[-1]) + for i in range(len(dims) - 1): + new_order.append(dims[i]) + + # Permute the tensor according to the new order + output_tensor = tensor.permute(new_order).contiguous() + + return output_tensor + + +# current scaling reference quantization +def ref_per_tensor_cs_cast( + tensor: torch.Tensor, + fp8_dtype: tex.DType = tex.DType.kFloat8E4M3, + return_transpose: bool = False, + force_pow_2_scales: bool = False, + amax_epsilon: float = 0.0, +) -> torch.Tensor: + + quant_dtype_torch = TE_DType_To_Torch[fp8_dtype] + scale, scale_inv, _ = _ref_compute_amax_scale( + tensor, + quant_dtype_torch, + amax_epsilon, + force_pow_2_scales, + ) + + qx = (tensor.float() * scale).to(quant_dtype_torch) + sx = scale_inv + qx_t = None + sx_t = None + + if tensor.shape == torch.Size([]): + qx = qx.view([]) + + if return_transpose: + qx_t = _multi_dim_transpose(qx) + sx_t = sx + return qx, sx, qx_t, sx_t diff --git a/tests/pytorch/test_float8_current_scaling_exact.py b/tests/pytorch/test_float8_current_scaling_exact.py new file mode 100644 index 0000000000..9741b1258c --- /dev/null +++ b/tests/pytorch/test_float8_current_scaling_exact.py @@ -0,0 +1,802 @@ +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +import pathlib +import os +import torch +import pytest + +import transformer_engine.pytorch as te +import transformer_engine_torch as tex + +import transformer_engine_torch as tex +from transformer_engine.pytorch.fp8 import FP8GlobalStateManager +from transformer_engine.common.recipe import Float8CurrentScaling +from transformer_engine.pytorch.fp8 import fp8_autocast, get_fp8_torch_dtype + + +# read env variable NVTE_TEST_FLOAT8_CURRENT_SCALING_EXACT_TENSOR_DUMP_DIR to override the default tensor dump directory +TENSOR_DUMP_DIR = pathlib.Path(__file__).resolve().parent.parent.parent / "tensor_dumps" +tensor_dump_dir_env = os.getenv("NVTE_TEST_FLOAT8_CURRENT_SCALING_EXACT_TENSOR_DUMP_DIR") +if tensor_dump_dir_env is not None: + TENSOR_DUMP_DIR = pathlib.Path(tensor_dump_dir_env) + + +# Check if FP8 is supported +fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available() + + +class GetRecipes: + + @staticmethod + def none(): + return None + + @staticmethod + def fp8_per_tensor_current_scaling_default(): + # return default configs + return Float8CurrentScaling() + + +# base class for validating current_scaling x linear layer +class TestFP8RecipeLinearBase: + @staticmethod + def _prepare_data( + batch_size, hidden_size, out_size, use_bias=True, seed=0, dtype=torch.float32 + ): + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + x = torch.randn((batch_size, hidden_size), dtype=dtype, device="cuda") + w = torch.randn((out_size, hidden_size), dtype=dtype, device="cuda") + bias = torch.randn((out_size), dtype=dtype, device="cuda") if use_bias else None + gradient = torch.randn((batch_size, out_size), dtype=dtype, device="cuda") + + return x, w, bias, gradient + + @staticmethod + def _shard_tensor(x, world_size, axis): + split_size = x.size()[axis] // 
world_size + split_tensor = torch.split(x, split_size, axis) + out = [] + for tensor in split_tensor: + out.append(tensor.detach().clone().requires_grad_(x.requires_grad)) + return out + + @staticmethod + def _gather_tensor(local, world_size, tp_group, concat_dim): + out_list = [torch.zeros_like(local) for _ in range(world_size)] + torch.distributed.all_gather(out_list, local, tp_group) + return torch.cat(out_list, dim=concat_dim) + + @staticmethod + def _all_reduce_tensor(local, world_size, tp_group): + if world_size == 1: + return local + handle = torch.distributed.all_reduce(local, group=tp_group, async_op=False) + return local + + @staticmethod + def _get_sum_abs_error(a, b): + return torch.sum(torch.abs(a - b)) + + @staticmethod + def _get_mean_abs_relative_error(a, b): + return torch.mean(torch.abs((a - b) / b)) + + @staticmethod + def _load_golden_tensor_values(a, b): + return torch.sum(torch.abs(a - b)) + + @staticmethod + def _check_golden_tensor_dumps(dump_dir, get_recipe, dims, input_dtype, use_bias): + recipe = get_recipe() + batch_size, hidden_size, out_size = dims + fp8_type_x = get_fp8_torch_dtype(recipe, fprop_tensor=True) + fp8_type_w = get_fp8_torch_dtype(recipe, fprop_tensor=True) + fp8_type_g = get_fp8_torch_dtype(recipe, fprop_tensor=False) + + # Expected tensor names based on the naming template + scaling_type = ( # Assuming the scaling type is PER_TENSOR for this example + "ScalingType.PER_TENSOR" + ) + current_seed = torch.initial_seed() # Get the current seed + + expected_tensor_names = { + "y": f"y_{scaling_type}_{batch_size}_{hidden_size}_{out_size}_{current_seed}_{input_dtype}_{fp8_type_x}_{fp8_type_w}_{fp8_type_g}.pt", + "dgrad": f"dgrad_{scaling_type}_{batch_size}_{hidden_size}_{out_size}_{current_seed}_{input_dtype}_{fp8_type_x}_{fp8_type_w}_{fp8_type_g}.pt", + "wgrad": f"wgrad_{scaling_type}_{batch_size}_{hidden_size}_{out_size}_{current_seed}_{input_dtype}_{fp8_type_x}_{fp8_type_w}_{fp8_type_g}.pt", + "bgrad": f"bgrad_{scaling_type}_{batch_size}_{hidden_size}_{out_size}_{current_seed}_{input_dtype}_{fp8_type_x}_{fp8_type_w}_{fp8_type_g}.pt", + } + + if not use_bias: + expected_tensor_names.pop("bgrad") + + # Check if all expected tensors are in the tensor dumps directory + tensor_map = {} + for tensor_key, tensor_name in expected_tensor_names.items(): + tensor_path = dump_dir / tensor_name + if not os.path.exists(tensor_path): + print(f"Missing tensor: {tensor_name}") + return None + + # Load the tensor + tensor_map[tensor_key] = torch.load(tensor_path) + return tensor_map + + @classmethod + def run_linear_preprocess_parallel( + cls, + x, + w, + bias, + gradient, + parallel_mode=None, + sequence_parallel=False, + tp_size=1, + rank=0, + ): + if tp_size > 1: + if parallel_mode == "column": + # split w in N dim, which should be axis 0 + w = cls._shard_tensor(w, tp_size, 0)[rank] + bias = cls._shard_tensor(bias, tp_size, 0)[rank] if bias is not None else None + # split gradient in N dim, which should be axis 1 + gradient = cls._shard_tensor(gradient, tp_size, 1)[rank] + if sequence_parallel: + # split x in M dim, which should be axis 0 + x = cls._shard_tensor(x, tp_size, 0)[rank] + # row parallel, split x in k dim, which should be axis 1, split w in k dim, should be axis 1 + if parallel_mode == "row": + # split x in K dim, which should be axis 1 + x = cls._shard_tensor(x, tp_size, 1)[rank] + # split w in K dim, which should be axis 1 + w = cls._shard_tensor(w, tp_size, 1)[rank] + if sequence_parallel: + # split gradient in M dim, which should be axis 0 + 
gradient = cls._shard_tensor(gradient, tp_size, 0)[rank] + return x, w, bias, gradient + + @classmethod + def run_linear_postprocess_parallel( + cls, + y_q, + dgrad, + wgrad, + bgrad, + parallel_mode, + sequence_parallel, + tp_size, + tp_group, + ): + if tp_size > 1: + if parallel_mode == "column": + # gather y_q in N dim, which should be axis 1 + y_q = cls._gather_tensor(y_q, tp_size, tp_group, 1) + # gather wgrad in N dim, which should be axis 0 + wgrad = cls._gather_tensor(wgrad, tp_size, tp_group, 0) + # gather bgrad in N dim, which should be axis 0 + bgrad = ( + cls._gather_tensor(bgrad, tp_size, tp_group, 0) if bgrad is not None else None + ) + if sequence_parallel: + # gather dgrad in M dim, which should be axis 0 + dgrad = cls._gather_tensor(dgrad, tp_size, tp_group, 0) + if parallel_mode == "row": + # gather dgrad in K dim, which should be axis 1 + dgrad = cls._gather_tensor(dgrad, tp_size, tp_group, 1) + # gather wgrad in K dim, which should be axis 1 + wgrad = cls._gather_tensor(wgrad, tp_size, tp_group, 1) + if sequence_parallel: + # gather y_q in M dim, which should be axis 0 + y_q = cls._gather_tensor(y_q, tp_size, tp_group, 0) + # we need to sum bias gradient when using TP + SP + bgrad = ( + cls._all_reduce_tensor(bgrad, tp_size, tp_group) + if bgrad is not None + else None + ) + + return y_q, dgrad, wgrad, bgrad + + @classmethod + def run_linear_one_step( + cls, layer, x, gradient, is_first_microbatch=None, fuse_wgrad_accumulation=False + ): + # reset gradients + layer.zero_grad() + x.grad = None + + # Forward pass + if isinstance(layer, te.Linear): + # Kitchen Linear + y_q = layer.forward(x, is_first_microbatch=is_first_microbatch) + else: + # the default torch.nn.Linear + y_q = layer(x) + + # Backward pass + y_q.backward(gradient) + + # Collect gradients + dgrad = x.grad + bgrad = ( + layer._parameters["bias"].grad + if layer._parameters.get("bias", None) is not None + else None + ) + assert "weight" in layer._parameters + if fuse_wgrad_accumulation: + wgrad = layer._parameters["weight"].main_grad + assert layer._parameters["weight"].grad is None + else: + wgrad = layer._parameters["weight"].grad + + return y_q, dgrad, wgrad, bgrad + + @classmethod + def run_linear_multiple_steps( + cls, + layer, + x, + gradient, + run_num_steps, + enable_weight_cache, + fuse_wgrad_accumulation=False, + ): + """ + Run multiple steps of linear layer and collect results. + """ + + y_q_list, dgrad_list, wgrad_list = [], [], [] + bgrad_list = [] if layer._parameters.get("bias", None) is not None else None + + for i in range(run_num_steps): + x_i = (x + i).clone().detach().requires_grad_(True) + # run_linear_one_step + y_q, dgrad, wgrad, bgrad = cls.run_linear_one_step( + layer, + x_i, + gradient, + is_first_microbatch=(i == 0) if enable_weight_cache else None, + fuse_wgrad_accumulation=fuse_wgrad_accumulation, + ) + + # Collect results + y_q_list.append(y_q.detach().clone()) + dgrad_list.append(dgrad.detach().clone()) + wgrad_list.append(wgrad.detach().clone()) + if bgrad_list is not None and bgrad is not None: + bgrad_list.append(bgrad.detach().clone()) + + @classmethod + def run_linear( + cls, + x, + w, + bias, + gradient, + parallel_mode=None, + sequence_parallel=False, + tp_group=None, + tp_size=1, + rank=0, + run_num_steps=1, + enable_weight_cache=False, + fuse_wgrad_accumulation=False, + ): + """ + If Model parallel, split inputs for a given rank and return the gathered output and gradients, so that they can be compared with + the reference single GPU run. 
+ """ + # clone inputs and move to current device + # w has shape [N, K], x has shape [M, K], gradient has shape [M, N] + x = x.clone().detach().requires_grad_(True).to("cuda") + w = w.clone().detach().to("cuda") + gradient = gradient.clone().detach().to("cuda") + bias = bias.clone().detach().to("cuda") if bias is not None else None + in_features = x.shape[1] + out_features = w.shape[0] + + # If Model parallel: split inputs for a given rank + x, w, bias, gradient = cls.run_linear_preprocess_parallel( + x, w, bias, gradient, parallel_mode, sequence_parallel, tp_size, rank + ) + + # set data types + params_dtype = x.dtype + + # Create linear layer and copy weights + layer = te.Linear( + in_features, + out_features, + bias=bias is not None, + params_dtype=params_dtype, + parallel_mode=parallel_mode, + sequence_parallel=sequence_parallel, + tp_group=tp_group, + tp_size=tp_size, + fuse_wgrad_accumulation=fuse_wgrad_accumulation, + ) + + layer = layer.to("cuda") + + with torch.no_grad(): + layer.weight.copy_(w) + if bias is not None: + layer.bias.copy_(bias) + + if fuse_wgrad_accumulation: + assert ( + run_num_steps > 1 + ), "Fused weight gradient accumulation requires run_num_steps > 1" + layer.weight.main_grad = torch.zeros_like(layer.weight) + + # Run one step or multiple steps + if run_num_steps == 1: + y_q, dgrad, wgrad, bgrad = cls.run_linear_one_step(layer, x, gradient) + else: + y_q, dgrad, wgrad, bgrad = cls.run_linear_multiple_steps( + layer, + x, + gradient, + run_num_steps, + enable_weight_cache, + fuse_wgrad_accumulation, + ) + + # If Model parallel: gather output and gradients from all ranks + y_q, dgrad, wgrad, bgrad = cls.run_linear_postprocess_parallel( + y_q, + dgrad, + wgrad, + bgrad, + parallel_mode, + sequence_parallel, + tp_size, + tp_group, + ) + + return y_q, dgrad, wgrad, bgrad + + def compare_recipe( + self, + recipe1, + recipe2, + batch_size, + hidden_size, + out_size, + use_bias, + seed, + dtype, + y_error=0.0, + dgrad_error=0.0, + wgrad_error=0.0, + bgrad_error=0.0, + recipe1_golden_tensors=None, + recipe2_golden_tensors=None, + ): + x, w, bias, gradient = self._prepare_data( + batch_size, hidden_size, out_size, use_bias, seed=seed, dtype=dtype + ) + + # recipe1 + using_fp8_recipe = recipe1 != GetRecipes.none + if using_fp8_recipe: + with fp8_autocast(enabled=True, fp8_recipe=recipe1()): + y_q_ref, dgrad_ref, wgrad_ref, bgrad_ref = self.run_linear(x, w, bias, gradient) + else: + y_q_ref, dgrad_ref, wgrad_ref, bgrad_ref = self.run_linear(x, w, bias, gradient) + + # recipe2 + using_fp8_recipe = recipe2 != GetRecipes.none + if using_fp8_recipe: + with fp8_autocast(enabled=True, fp8_recipe=recipe2()): + y_q, dgrad, wgrad, bgrad = self.run_linear(x, w, bias, gradient) + else: + y_q, dgrad, wgrad, bgrad = self.run_linear(x, w, bias, gradient) + + # Compare results (mean abs relative error) + assert ( + self._get_mean_abs_relative_error(y_q, y_q_ref).item() < y_error + ), "y and y_ref has too large mean abs relative error" + assert ( + self._get_mean_abs_relative_error(dgrad, dgrad_ref) < dgrad_error + ), "dgrad and dgrad_ref has too large mean abs relative error" + assert ( + self._get_mean_abs_relative_error(wgrad, wgrad_ref).item() < wgrad_error + ), "wgrad and wgrad_ref has too large mean abs relative error" + if use_bias: + assert ( + self._get_mean_abs_relative_error(bgrad, bgrad_ref).item() < bgrad_error + ), "bgrad and bgrad_ref has too large mean abs relative error" + + # enforce zero tolerance check when we can find golden tensor value dump + if 
recipe2_golden_tensors is not None: + torch.testing.assert_close( + y_q.float(), recipe2_golden_tensors["y"].float(), atol=0, rtol=0.0 + ) + torch.testing.assert_close(dgrad, recipe2_golden_tensors["dgrad"], atol=0.0, rtol=0.0) + torch.testing.assert_close(wgrad, recipe2_golden_tensors["wgrad"], atol=0.0, rtol=0.0) + if use_bias: + torch.testing.assert_close( + bgrad, recipe2_golden_tensors["bgrad"], atol=0.0, rtol=0.0 + ) + + +class TestFP8RecipeLayerNormLinearBase(TestFP8RecipeLinearBase): + + @staticmethod + def _check_golden_tensor_dumps( + dump_dir, get_recipe, dims, input_dtype, use_bias, normalization + ): + recipe = get_recipe() + batch_size, hidden_size, out_size = dims + fp8_type_x = get_fp8_torch_dtype(recipe, fprop_tensor=True) + fp8_type_w = get_fp8_torch_dtype(recipe, fprop_tensor=True) + fp8_type_g = get_fp8_torch_dtype(recipe, fprop_tensor=False) + + # Expected tensor names based on the naming template + scaling_type = ( # Assuming the scaling type is PER_TENSOR for this example + "ScalingType.PER_TENSOR" + ) + current_seed = torch.initial_seed() # Get the current seed + + expected_tensor_names = { + "y": f"y_{normalization}_{scaling_type}_{batch_size}_{hidden_size}_{out_size}_{current_seed}_{input_dtype}_{fp8_type_x}_{fp8_type_w}_{fp8_type_g}.pt", + "ln_out": f"ln_out_{normalization}_{scaling_type}_{batch_size}_{hidden_size}_{out_size}_{current_seed}_{input_dtype}_{fp8_type_x}_{fp8_type_w}_{fp8_type_g}.pt", + "dgrad": f"dgrad_{normalization}_{scaling_type}_{batch_size}_{hidden_size}_{out_size}_{current_seed}_{input_dtype}_{fp8_type_x}_{fp8_type_w}_{fp8_type_g}.pt", + "wgrad": f"wgrad_{normalization}_{scaling_type}_{batch_size}_{hidden_size}_{out_size}_{current_seed}_{input_dtype}_{fp8_type_x}_{fp8_type_w}_{fp8_type_g}.pt", + "bgrad": f"bgrad_{normalization}_{scaling_type}_{batch_size}_{hidden_size}_{out_size}_{current_seed}_{input_dtype}_{fp8_type_x}_{fp8_type_w}_{fp8_type_g}.pt", + } + + if not use_bias: + expected_tensor_names.pop("bgrad") + + # Check if all expected tensors are in the tensor dumps directory + tensor_map = {} + for tensor_key, tensor_name in expected_tensor_names.items(): + tensor_path = dump_dir / tensor_name + if not os.path.exists(tensor_path): + print(f"Missing tensor: {tensor_name}") + return None + + # Load the tensor + tensor_map[tensor_key] = torch.load(tensor_path) + return tensor_map + + @classmethod + def run_linear_one_step(cls, layer, x, gradient, is_first_microbatch=None): + # reset gradients + layer.zero_grad() + x.grad = None + + # Forward pass + y_q, ln_out = layer.forward(x, is_first_microbatch=is_first_microbatch) + + # Backward pass + y_q.backward(gradient) + + # Collect gradients + dgrad = x.grad + + parameters = layer._parameters + + # bias and weight gradients + bgrad = parameters["bias"].grad if parameters.get("bias", None) is not None else None + assert "weight" in parameters + wgrad = parameters["weight"].grad + + return y_q, ln_out, dgrad, wgrad, bgrad + + @classmethod + def run_linear_multiple_steps( + cls, layer, x, gradient, run_num_steps, enable_weight_cache, fuse_wgrad_accumulation=False + ): + # raise error, no test case for multiple steps for now + raise NotImplementedError("LayerNormLinear does not support test multiple steps for now") + + @classmethod + def run_layernorm_linear( + cls, + x, + w, + bias, + gradient, + parallel_mode=None, + sequence_parallel=False, + tp_group=None, + tp_size=1, + rank=0, + run_num_steps=1, + enable_weight_cache=False, + LayerNormLinearClass=te.LayerNormLinear, + 
normalization="LayerNorm", + ): + """ + If Model parallel, split inputs for a given rank and return the gathered output and gradients, so that they can be compared with + the reference single GPU run. + """ + # clone inputs and move to current device + # w has shape [N, K], x has shape [M, K], gradient has shape [M, N] + x = x.clone().detach().requires_grad_(True).to("cuda") + w = w.clone().detach().to("cuda") + gradient = gradient.clone().detach().to("cuda") + bias = bias.clone().detach().to("cuda") if bias is not None else None + in_features = x.shape[1] + out_features = w.shape[0] + + # If Model parallel: split inputs for a given rank + x, w, bias, gradient = cls.run_linear_preprocess_parallel( + x, w, bias, gradient, parallel_mode, sequence_parallel, tp_size, rank + ) + + # set data types + params_dtype = x.dtype + + # Create linear layer and copy weights + layer = LayerNormLinearClass( + in_features, + out_features, + bias=bias is not None, + params_dtype=params_dtype, + parallel_mode=parallel_mode, + sequence_parallel=sequence_parallel, + tp_group=tp_group, + tp_size=tp_size, + normalization=normalization, + return_layernorm_output=True, + ) + + layer = layer.to("cuda") + + # Copy weights + # kitchen_linear has different parameter names + with torch.no_grad(): + layer.weight.copy_(w) + if bias is not None: + layer.bias.copy_(bias) + + # Run one step + y_q, ln_out, dgrad, wgrad, bgrad = cls.run_linear_one_step(layer, x, gradient) + + # If Model parallel: gather output and gradients from all ranks + y_q, dgrad, wgrad, bgrad = cls.run_linear_postprocess_parallel( + y_q, + dgrad, + wgrad, + bgrad, + parallel_mode, + sequence_parallel, + tp_size, + tp_group, + ) + + return y_q, ln_out, dgrad, wgrad, bgrad + + def compare_recipe( + self, + recipe1, + recipe2, + batch_size, + hidden_size, + out_size, + use_bias, + seed, + dtype, + y_error=0.0, + ln_out_error=0.0, + dgrad_error=0.0, + wgrad_error=0.0, + bgrad_error=0.0, + normalization="LayerNorm", + LayerNormLinearClass1=te.LayerNormLinear, + LayerNormLinearClass2=te.LayerNormLinear, + recipe1_golden_tensors=None, + recipe2_golden_tensors=None, + ): + x, w, bias, gradient = self._prepare_data( + batch_size, hidden_size, out_size, use_bias, seed=seed, dtype=dtype + ) + + # recipe1 + using_fp8_recipe = recipe1 != GetRecipes.none + if using_fp8_recipe: + with fp8_autocast(enabled=True, fp8_recipe=recipe1()): + y_q_ref, ln_out_ref, dgrad_ref, wgrad_ref, bgrad_ref = self.run_layernorm_linear( + x, + w, + bias, + gradient, + normalization=normalization, + LayerNormLinearClass=LayerNormLinearClass1, + ) + else: + y_q_ref, ln_out_ref, dgrad_ref, wgrad_ref, bgrad_ref = self.run_layernorm_linear( + x, + w, + bias, + gradient, + normalization=normalization, + LayerNormLinearClass=LayerNormLinearClass1, + ) + + # recipe2 + using_fp8_recipe = recipe2 != GetRecipes.none + if using_fp8_recipe: + with fp8_autocast(enabled=True, fp8_recipe=recipe2()): + y_q, ln_out, dgrad, wgrad, bgrad = self.run_layernorm_linear( + x, + w, + bias, + gradient, + normalization=normalization, + LayerNormLinearClass=LayerNormLinearClass2, + ) + else: + y_q, ln_out, dgrad, wgrad, bgrad = self.run_layernorm_linear( + x, + w, + bias, + gradient, + normalization=normalization, + LayerNormLinearClass=LayerNormLinearClass2, + ) + + # Compare results (mean abs relative error) + assert ( + self._get_mean_abs_relative_error(y_q, y_q_ref).item() < y_error + ), "y and y_ref has too large mean abs relative error" + assert ( + self._get_mean_abs_relative_error(ln_out, ln_out_ref).item() < 
ln_out_error + ), "ln_out and ln_out_ref has too large mean abs relative error" + assert ( + self._get_mean_abs_relative_error(dgrad, dgrad_ref) < dgrad_error + ), "dgrad and dgrad_ref has too large mean abs relative error" + assert ( + self._get_mean_abs_relative_error(wgrad, wgrad_ref).item() < wgrad_error + ), "wgrad and wgrad_ref has too large mean abs relative error" + if use_bias: + assert ( + self._get_mean_abs_relative_error(bgrad, bgrad_ref).item() < bgrad_error + ), "bgrad and bgrad_ref has too large mean abs relative error" + + # enforce zero tolerance check when we can find golden tensor value dump + if recipe2_golden_tensors is not None: + torch.testing.assert_close( + y_q.float(), recipe2_golden_tensors["y"].float(), atol=0, rtol=0.0 + ) + torch.testing.assert_close(ln_out, recipe2_golden_tensors["ln_out"], atol=0.0, rtol=0.0) + torch.testing.assert_close(dgrad, recipe2_golden_tensors["dgrad"], atol=0.0, rtol=0.0) + torch.testing.assert_close(wgrad, recipe2_golden_tensors["wgrad"], atol=0.0, rtol=0.0) + if use_bias: + torch.testing.assert_close( + bgrad, recipe2_golden_tensors["bgrad"], atol=0.0, rtol=0.0 + ) + + +# FP8 per tesnor current scaling +@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8) +class TestFP8CurrentScalingRecipeLinear(TestFP8RecipeLinearBase): + + @staticmethod + def setup_class(cls) -> None: + # Configure RNG + seed = 1234 + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + + @pytest.mark.parametrize( + "batch_size, hidden_size, out_size", + [ + (16, 256, 128), + ], + ) + @pytest.mark.parametrize("dtype", [torch.bfloat16], ids=["bf16"]) + @pytest.mark.parametrize( + "recipe1, recipe2", + [ + (GetRecipes.none, GetRecipes.fp8_per_tensor_current_scaling_default), + ], + ) + def test_fp8_current_scaling_with_linear_module( + self, + recipe1, + recipe2, + batch_size, + hidden_size, + out_size, + dtype, + use_bias=True, + ): + fp8_zero_tolerance_tensor_dumps_recipe2 = None + # check tensor dumps dir, if the dir exists, then read files to get y, dgrad, wgrad, bgrad + # if we cannot get all four tensors, then still set the tensor dump to None + tensor_map = self._check_golden_tensor_dumps( + TENSOR_DUMP_DIR, recipe2, (batch_size, hidden_size, out_size), dtype, use_bias + ) + if tensor_map is not None: + fp8_zero_tolerance_tensor_dumps_recipe2 = tensor_map + + self.compare_recipe( + recipe1, + recipe2, + batch_size, + hidden_size, + out_size, + use_bias, + seed=torch.initial_seed(), + dtype=dtype, + y_error=0.5, + dgrad_error=1, + wgrad_error=1, + bgrad_error=0.5, + recipe1_golden_tensors=None, + recipe2_golden_tensors=fp8_zero_tolerance_tensor_dumps_recipe2, + ) + + +@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8) +class TestFP8CurrentScalingRecipeLayerNormLinear(TestFP8RecipeLayerNormLinearBase): + + @staticmethod + def setup_class(cls) -> None: + # Configure RNG + seed = 1234 + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + + @pytest.mark.parametrize( + "batch_size, hidden_size, out_size", + [ + (16, 256, 128), + ], + ) + @pytest.mark.parametrize("dtype", [torch.bfloat16], ids=["bf16"]) + @pytest.mark.parametrize( + "recipe1, recipe2", + [ + (GetRecipes.none, GetRecipes.fp8_per_tensor_current_scaling_default), + ], + ) + def test_fp8_current_scaling_with_layernorm_linear_module( + self, + recipe1, + recipe2, + batch_size, + hidden_size, + out_size, + dtype, + use_bias=True, + ): + fp8_zero_tolerance_tensor_dumps_recipe2 = None + # check tensor dumps dir, if the dir exists, then read files to get y, dgrad, wgrad, 
bgrad + # if we cannot get all four tensors, then still set the tensor dump to None + tensor_map = self._check_golden_tensor_dumps( + TENSOR_DUMP_DIR, + recipe2, + (batch_size, hidden_size, out_size), + dtype, + use_bias, + "LayerNorm", + ) + if tensor_map is not None: + fp8_zero_tolerance_tensor_dumps_recipe2 = tensor_map + + self.compare_recipe( + recipe1, + recipe2, + batch_size, + hidden_size, + out_size, + use_bias, + seed=torch.initial_seed(), + dtype=dtype, + y_error=0.5, + ln_out_error=0.5, + dgrad_error=1, + wgrad_error=1, + bgrad_error=0.5, + recipe1_golden_tensors=None, + recipe2_golden_tensors=fp8_zero_tolerance_tensor_dumps_recipe2, + ) diff --git a/tests/pytorch/test_float8tensor.py b/tests/pytorch/test_float8tensor.py index 9d01527ac5..42600e3099 100644 --- a/tests/pytorch/test_float8tensor.py +++ b/tests/pytorch/test_float8tensor.py @@ -12,9 +12,17 @@ import transformer_engine.common.recipe import transformer_engine.pytorch as te from transformer_engine.pytorch.fp8 import FP8GlobalStateManager -from transformer_engine.pytorch.tensor.float8_tensor import Float8Quantizer, Float8Tensor +from transformer_engine.pytorch.tensor.float8_tensor import ( + Float8Quantizer, + Float8Tensor, + Float8CurrentScalingQuantizer, +) +from transformer_engine.pytorch.constants import TE_DType, TE_DType_To_Torch +from transformer_engine.pytorch.utils import non_tn_fp8_gemm_supported import transformer_engine_torch as tex +from references.ref_per_tensor_cs import ref_per_tensor_cs_cast + # PyTorch tensor dtypes _dtypes: List[torch.dtype] = [torch.float32, torch.float16, torch.bfloat16] # TE FP8 dtypes @@ -42,6 +50,7 @@ def _to_list(x: Union[Iterable, Any]) -> List: fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available() +# delayed scaling def to_float8( tensor: torch.Tensor, fp8_dtype: tex.DType = tex.DType.kFloat8E4M3, @@ -56,6 +65,29 @@ def to_float8( return quantizer(tensor.cuda()) +# current scaling +def to_float8_CS( + tensor: torch.Tensor, + fp8_dtype: tex.DType = tex.DType.kFloat8E4M3, + return_transpose: bool = False, + force_pow_2_scales: bool = False, + amax_epsilon: float = 0.0, +) -> Float8Tensor: + """Cast tensor to FP8""" + tensor = tensor.cuda() + quantizer = Float8CurrentScalingQuantizer( + fp8_dtype=fp8_dtype, + device=tensor.device, + force_pow_2_scales=force_pow_2_scales, + amax_epsilon=amax_epsilon, + ) + if return_transpose: + quantizer.set_usage(rowwise=True, columnwise=True) + else: + quantizer.set_usage(rowwise=True, columnwise=False) + return quantizer(tensor) + + @pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8) class TestFloat8Tensor: @@ -310,3 +342,89 @@ def test_set_data(self): assert x.size() == y.size() assert x.dtype == y.dtype assert x.device == y.device + + +@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8) +class TestCurrentScalingFloat8Tensor: + + @staticmethod + def setup_class(cls) -> None: + # Configure RNG + seed = 1234 + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + + @pytest.mark.parametrize("fp8_dtype", _fp8_dtypes) + @pytest.mark.parametrize("dtype", _dtypes) + @pytest.mark.parametrize( + "dims", [[], 1, 311, [7, 11], [7, 5, 3], [2, 3, 5, 3], [128, 128], [611, 782]] + ) + @pytest.mark.parametrize("return_transpose", [True, False], ids=str) + @pytest.mark.parametrize("force_pow_2_scales", [True, False], ids=str) + @pytest.mark.parametrize("amax_epsilon", [0.0, 1e-6], ids=str) + def test_quantize( + self, + fp8_dtype: tex.DType, + dtype: torch.dtype, + dims: DimsType, + return_transpose: bool, + 
force_pow_2_scales: bool, + amax_epsilon: float, + ) -> None: + """Check numerical error when casting to FP8""" + + # Skip invalid configurations + if non_tn_fp8_gemm_supported() and return_transpose: + pytest.skip("FP8 transpose is neither needed nor supported on current system") + + # Initialize random high precision data + device = "cuda" + x_hp = 2 * torch.rand(_to_list(dims), dtype=dtype, device=device) - 1 + + # Cast to FP8 and back + x_fp8 = to_float8_CS( + x_hp, + fp8_dtype=fp8_dtype, + return_transpose=return_transpose, + force_pow_2_scales=force_pow_2_scales, + amax_epsilon=amax_epsilon, + ) + + # get reference implementation of current scaling + x_fp8_ref, sx_ref, x_fp8_t_ref, _ = ref_per_tensor_cs_cast( + x_hp, + fp8_dtype=fp8_dtype, + return_transpose=return_transpose, + force_pow_2_scales=force_pow_2_scales, + amax_epsilon=amax_epsilon, + ) + + torch.testing.assert_close(x_fp8._data, x_fp8_ref.view(torch.uint8), atol=0.0, rtol=0.0) + torch.testing.assert_close(x_fp8._scale_inv, sx_ref, atol=0.0, rtol=0.0) + if return_transpose: + torch.testing.assert_close( + x_fp8._transpose, x_fp8_t_ref.view(torch.uint8), atol=0.0, rtol=0.0 + ) + + @pytest.mark.parametrize("fp8_dtype", [tex.DType.kFloat8E4M3], ids=str) + @pytest.mark.parametrize("dtype", [torch.bfloat16], ids=str) + @pytest.mark.parametrize("dims", [[], 1, 311, [7, 11], [7, 5, 3], [2, 3, 5, 3]]) + def test_quantize_dequantize( + self, fp8_dtype: tex.DType, dtype: torch.dtype, dims: DimsType + ) -> None: + """Check numerical error when casting to FP8 and back""" + + # Initialize random high precision data + device = "cuda" + x_hp = 2 * torch.rand(_to_list(dims), dtype=dtype, device=device) - 1 + + # Cast to FP8 and back + x_fp8 = to_float8_CS(x_hp, fp8_dtype=fp8_dtype) + x_fp8_dequantized = x_fp8.dequantize() + + # Check results + torch.testing.assert_close(x_fp8_dequantized, x_hp, **_tols[fp8_dtype]) + + # Make sure we are not trivially passing the test + with pytest.raises(AssertionError): + torch.testing.assert_close(x_fp8_dequantized, -x_hp, **_tols[fp8_dtype]) diff --git a/tests/pytorch/test_numerics.py b/tests/pytorch/test_numerics.py index a72ba097a1..5bec7f7c7f 100644 --- a/tests/pytorch/test_numerics.py +++ b/tests/pytorch/test_numerics.py @@ -100,6 +100,7 @@ def __init__(self, hidden_size, eps, num_attention_heads, embed, num_layers, seq fp8_recipes = [ recipe.MXFP8BlockScaling(), recipe.DelayedScaling(), + recipe.Float8CurrentScaling(), ] @@ -670,6 +671,8 @@ def test_gpt_full_activation_recompute( pytest.skip(reason_for_no_fp8) if recipe.mxfp8() and not mxfp8_available: pytest.skip(reason_for_no_mxfp8) + if fp8 and recipe.float8_current_scaling(): + pytest.skip("Float8 Current Scaling unsupported for full recompute.") config = model_configs[model] @@ -1482,6 +1485,8 @@ def test_grouped_linear_accuracy( pytest.skip(reason_for_no_mxfp8) if fp8 and recipe.mxfp8(): # TODO(ksivamani): debug mismatches pytest.skip("MXFP8 unsupported for grouped linear.") + if fp8 and recipe.float8_current_scaling(): + pytest.skip("Float8 Current Scaling unsupported for grouped linear.") config = model_configs[model] if config.seq_len % 16 != 0 and fp8: @@ -1675,6 +1680,8 @@ def test_padding_grouped_linear_accuracy( pytest.skip(reason_for_no_mxfp8) if fp8 and recipe.mxfp8(): # TODO(ksivamani): debug mismatches pytest.skip("MXFP8 unsupported for grouped linear.") + if fp8 and recipe.float8_current_scaling(): + pytest.skip("Float8 Current Scaling unsupported for grouped linear.") config = model_configs[model] if config.seq_len % 16 != 0 
and fp8: diff --git a/tests/pytorch/test_recipe.py b/tests/pytorch/test_recipe.py index dcac5f1500..30989bec61 100644 --- a/tests/pytorch/test_recipe.py +++ b/tests/pytorch/test_recipe.py @@ -23,6 +23,7 @@ fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available() +# FP8 per tensor delayed scaling @pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8) class TestFP8Recipe: diff --git a/transformer_engine/common/CMakeLists.txt b/transformer_engine/common/CMakeLists.txt index 0a2abb6e4e..007618ad57 100644 --- a/transformer_engine/common/CMakeLists.txt +++ b/transformer_engine/common/CMakeLists.txt @@ -86,6 +86,7 @@ list(APPEND transformer_engine_SOURCES fused_softmax/scaled_upper_triang_masked_softmax.cu fused_softmax/scaled_aligned_causal_masked_softmax.cu fused_rope/fused_rope.cu + recipe/current_scaling.cu recipe/delayed_scaling.cu comm_gemm_overlap/userbuffers/ipcsocket.cc comm_gemm_overlap/userbuffers/userbuffers-host.cpp diff --git a/transformer_engine/common/common.h b/transformer_engine/common/common.h index 46eb248156..4163505db6 100644 --- a/transformer_engine/common/common.h +++ b/transformer_engine/common/common.h @@ -29,6 +29,18 @@ namespace transformer_engine { +inline bool is_tensor_scaling(const NVTEScalingMode &mode) { + return mode == NVTE_DELAYED_TENSOR_SCALING; +} + +inline bool is_block_scaling(const NVTEScalingMode &mode) { return !is_tensor_scaling(mode); } + +inline bool is_delayed_tensor_scaling(const NVTEScalingMode &mode) { + return mode == NVTE_DELAYED_TENSOR_SCALING; +} + +inline bool is_mxfp_scaling(const NVTEScalingMode &mode) { return mode == NVTE_MXFP8_1D_SCALING; } + inline size_t product(const std::vector &shape, const size_t begin, const size_t end) { NVTE_CHECK(begin <= end && end <= shape.size(), "Attempted to access entries ", begin, " to ", end, " in a vector with ", shape.size(), " entries"); @@ -132,7 +144,7 @@ struct Tensor { if (!has_data() && has_columnwise_data()) { const auto &data_shape = columnwise_data.shape; if (data_shape.empty()) return 1; - if (scaling_mode == NVTE_DELAYED_TENSOR_SCALING) { + if (is_tensor_scaling(scaling_mode)) { return product(data_shape, 1, data_shape.size()); } else { return product(data_shape, 0, data_shape.size() - 1); @@ -152,7 +164,7 @@ struct Tensor { if (!has_data() && has_columnwise_data()) { const auto &data_shape = columnwise_data.shape; if (data_shape.empty()) return 1; - if (scaling_mode == NVTE_DELAYED_TENSOR_SCALING) { + if (is_tensor_scaling(scaling_mode)) { return data_shape.front(); } else { return data_shape.back(); @@ -164,6 +176,16 @@ struct Tensor { } }; +struct QuantizationConfig { + bool force_pow_2_scales = false; + float amax_epsilon = 0.0f; + + static constexpr size_t attr_sizes[] = { + sizeof(bool), // force_pow_2_scales + sizeof(float) // amax_epsilon + }; +}; + template constexpr T DIVUP(const T &x, const T &y) { return (((x) + ((y)-1)) / (y)); @@ -396,6 +418,15 @@ struct TypeInfo { } \ } +#define TRANSFORMER_ENGINE_SWITCH_CONDITION(CONDITION, FLAG, ...) 
\ + if (CONDITION) { \ + constexpr bool FLAG = true; \ + { __VA_ARGS__ } \ + } else { \ + constexpr bool FLAG = false; \ + { __VA_ARGS__ } \ + } + //////////////////////////////////////////////////////////////////////////////////////////////////// inline int log2_ceil(int value) { @@ -449,20 +480,6 @@ bool is_fp8_dtype(const DType t); std::string to_string(const DType type); std::string to_string(const NVTEScalingMode &type); -inline bool is_tensor_scaling(const NVTEScalingMode &mode) { - return mode == NVTE_DELAYED_TENSOR_SCALING; -} - -inline bool is_block_scaling(const NVTEScalingMode &mode) { - return mode != NVTE_DELAYED_TENSOR_SCALING; -} - -inline bool is_delayed_tensor_scaling(const NVTEScalingMode &mode) { - return is_tensor_scaling(mode); -} - -inline bool is_mxfp_scaling(const NVTEScalingMode &mode) { return mode == NVTE_MXFP8_1D_SCALING; } - /*! \brief Update a tensor's FP8 scale-inverse * * The FP8 scale-inverse (dequantization scaling factor) is updated diff --git a/transformer_engine/common/include/transformer_engine/recipe.h b/transformer_engine/common/include/transformer_engine/recipe.h index b30a6e1338..44614bbe6b 100644 --- a/transformer_engine/common/include/transformer_engine/recipe.h +++ b/transformer_engine/common/include/transformer_engine/recipe.h @@ -73,6 +73,29 @@ void nvte_delayed_scaling_recipe_amax_and_scale_update_after_reduction( std::vector scales, const char* amax_compute_algo, NVTEDType fp8_dtype, float margin, cudaStream_t stream); +/*! \brief Compute an FP8 tensor's amax. + * + * The amax (maximum absolute value) of the input tensor is computed + * and written to the amax buffer of the output tensor. + * + * \param[in] input Input tensor. Must be unquantized. + * \param[in,out] output Output tensor. Must be an FP8 tensor with per-tensor scaling. + * \param[in] stream CUDA stream used for the operation. + */ +void nvte_compute_amax(const NVTETensor input, NVTETensor output, cudaStream_t stream); + +/*! \brief Update an FP8 tensor's scale based on its amax. + * + * This is only supported for FP8 tensors with per-tensor scaling. + * Options are primarily intended for FP8 current-scaling recipes. + * + * \param[in,out] output FP8 tensor with per-tensor scaling. + * \param[in] config Quantization configuration. + * \param[in] stream CUDA stream used for the operation. + */ +void nvte_compute_scale_from_amax(NVTETensor output, const NVTEQuantizationConfig config, + cudaStream_t stream); + #ifdef __cplusplus } // extern "C" #endif diff --git a/transformer_engine/common/include/transformer_engine/transformer_engine.h b/transformer_engine/common/include/transformer_engine/transformer_engine.h index e393dbffc4..e91f3c4836 100644 --- a/transformer_engine/common/include/transformer_engine/transformer_engine.h +++ b/transformer_engine/common/include/transformer_engine/transformer_engine.h @@ -68,11 +68,14 @@ enum NVTETensorParam { }; /*! \enum NVTEScalingMode - * \brief Granularity of scaling: + * \brief Tensor data format. */ enum NVTEScalingMode { - /*! Single scale per tensor, computed in delayed manner. - Used also for high precision data, without scaling */ + /*! Either an unquantized tensor or an FP8 tensor with per-tensor scaling + * + * Not necessary used for delayed tensor scaling. The unintuitive + * name reflects legacy usage. + */ NVTE_DELAYED_TENSOR_SCALING = 0, /*! 
Single scale per block of 32 elements consecutive in either rowwise or columnwise direction */ @@ -266,6 +269,57 @@ void nvte_tensor_pack_create(NVTETensorPack *pack); */ void nvte_tensor_pack_destroy(NVTETensorPack *pack); +/*! \brief Configuration for tensor quantization. */ +typedef void *NVTEQuantizationConfig; + +/*! \enum NVTEQuantizationConfigAttribute + * \brief Type of option for tensor quantization. + */ +enum NVTEQuantizationConfigAttribute { + /*! Whether to force power of 2 scales */ + kNVTEQuantizationConfigForcePow2Scales = 0, + /*! Small value to add to amax for numerical stability */ + kNVTEQuantizationConfigAmaxEpsilon = 1, + kNVTEQuantizationConfigNumAttributes +}; + +/*! \brief Create a new quantization config. + * \return A new quantization config. + */ +NVTEQuantizationConfig nvte_create_quantization_config(); + +/*! \brief Query an option in quantization config. + * + * \param[in] config Quantization config. + * \param[in] attr Option type. + * \param[out] buf Memory address to write option value. Ignored if + * NULL. + * \param[in] size_in_bytes Size of buf. + * \param[out] size_written Number of bytes that have been written to + * buf. If buf is NULL, then the number of + * bytes that would have been written. + */ +void nvte_get_quantization_config_attribute(NVTEQuantizationConfig config, + NVTEQuantizationConfigAttribute attr, void *buf, + size_t size_in_bytes, size_t *size_written); + +/*! \brief Set an option in quantization config. + * + * \param[in] config Quantization config. + * \param[in] attr Option type. + * \param[out] buf Memory address to read option value. + * \param[in] size_in_bytes Size of buf. + */ +void nvte_set_quantization_config_attribute(NVTEQuantizationConfig config, + NVTEQuantizationConfigAttribute attr, const void *buf, + size_t size_in_bytes); + +/*! \brief Destroy a quantization config. + * + * \param[in] config Config to be destroyed. + */ +void nvte_destroy_quantization_config(NVTEQuantizationConfig config); + #ifdef __cplusplus } // extern "C" @@ -610,6 +664,58 @@ class TensorWrapper { NVTETensor tensor_ = nullptr; }; +/*! \struct QuantizationConfigWrapper + * \brief C++ wrapper for NVTEQuantizationConfigWrapper. + */ +class QuantizationConfigWrapper { + public: + QuantizationConfigWrapper() : config_{nvte_create_quantization_config()} {} + + QuantizationConfigWrapper(const QuantizationConfigWrapper &) = delete; + QuantizationConfigWrapper &operator=(const QuantizationConfigWrapper &) = delete; + + QuantizationConfigWrapper(QuantizationConfigWrapper &&other) : config_{other.config_} { + other.config_ = nullptr; + } + QuantizationConfigWrapper &operator=(QuantizationConfigWrapper &&other) { + if (config_ != nullptr) { + nvte_destroy_quantization_config(config_); + } + config_ = other.config_; + other.config_ = nullptr; + return *this; + } + + ~QuantizationConfigWrapper() { + if (config_ != nullptr) { + nvte_destroy_quantization_config(config_); + config_ = nullptr; + } + } + + /*! \brief Get the underlying NVTEQuantizationConfig. + * + * \return NVTEQuantizationConfig held by this QuantizationConfigWrapper. + */ + operator NVTEQuantizationConfig() const noexcept { return config_; } + + /*! \brief Set whether to force power of 2 scales */ + void set_force_pow_2_scales(bool force_pow_2_scales) { + nvte_set_quantization_config_attribute(config_, kNVTEQuantizationConfigForcePow2Scales, + &force_pow_2_scales, sizeof(bool)); + } + + /*! 
\brief Set small value to add to amax */
+  void set_amax_epsilon(float amax_epsilon) {
+    nvte_set_quantization_config_attribute(config_, kNVTEQuantizationConfigAmaxEpsilon,
+                                            &amax_epsilon, sizeof(float));
+  }
+
+ private:
+  /*! \brief Wrapped NVTEQuantizationConfig. */
+  NVTEQuantizationConfig config_ = nullptr;
+};
+
 }  // namespace transformer_engine

 #endif  // __cplusplus
diff --git a/transformer_engine/common/recipe/__init__.py b/transformer_engine/common/recipe/__init__.py
index 0bce83d98f..937383d5ec 100644
--- a/transformer_engine/common/recipe/__init__.py
+++ b/transformer_engine/common/recipe/__init__.py
@@ -39,6 +39,27 @@ class Format(Enum):
     HYBRID = _FormatHelper(max_fwd=E4M3.max_fwd, max_bwd=E5M2.max_bwd)

+@dataclass(frozen=True)
+class MMParams:
+    """For PyTorch, as an example, _scaled_mm's use_fast_accum = (not use_split_accumulator).
+    Whether to apply the split accumulator: enabling it increases accuracy but impacts GEMM
+    performance, so it is only enabled for certain GEMMs.
+    """
+
+    use_split_accumulator: bool = True
+
+
+@dataclass(frozen=True)
+class QParams:
+    """Quantization parameters.
+    power_2_scale: use power of 2 scale parameter
+    amax_epsilon: optional minimum value of abs max
+    """
+
+    power_2_scale: bool = False
+    amax_epsilon: float = 0.0
+
+
 class Recipe:
     """
     Base recipe class.
@@ -52,6 +73,10 @@ def delayed(self):
         """Whether the given recipe is delayed scaling."""
         return isinstance(self, DelayedScaling)

+    def float8_current_scaling(self):
+        """Whether the given recipe is (per-tensor) current scaling."""
+        return isinstance(self, Float8CurrentScaling)
+

 @dataclass()
 class DelayedScaling(Recipe):
@@ -161,6 +186,75 @@ def __repr__(self) -> str:
         )

+
+@dataclass()
+class Float8CurrentScaling(Recipe):
+    """
+    Use the per-tensor current scaling factor strategy.
+
+    Parameters
+    ----------
+    fp8_format : {Format.E4M3, Format.HYBRID}, default = Format.HYBRID
+                 Controls the FP8 data format used during forward and backward
+                 pass.
+    fp8_quant_fwd_inp: QParams, default QParams{power_2_scale=False, amax_epsilon=0.0}
+                 used for quantization of input tensor x
+    fp8_quant_fwd_weight: QParams, default QParams{power_2_scale=False, amax_epsilon=0.0}
+                 used for quantization of weight tensor w
+    fp8_quant_bwd_grad: QParams, default QParams{power_2_scale=False, amax_epsilon=0.0}
+                 used for quantization of gradient tensor dY
+    fp8_gemm_fprop: MMParams, default MMParams.use_split_accumulator=False
+                 used for calculating output y in forward pass
+    fp8_gemm_dgrad: MMParams, default MMParams.use_split_accumulator=True
+                 used for calculating dgrad in backward pass
+    fp8_gemm_wgrad: MMParams, default MMParams.use_split_accumulator=True
+                 used for calculating wgrad in backward pass
+    fp8_dpa: bool, default = `False`
+                 Whether to enable FP8 dot product attention (DPA). When the model is placed in an
+                 `fp8_autocast(enabled=True)` region and `fp8_dpa` is set to `True`, DPA casts the
+                 inputs from higher precision to FP8, performs attention in FP8, and casts tensors
+                 back to higher precision as outputs. FP8 DPA currently is only supported in the
+                 `FusedAttention` backend.
+    fp8_mha: bool, default = `False`
+                 Whether to enable FP8 multi-head attention (MHA). When `True`, it removes the casting
+                 operations mentioned above at the DPA boundaries. Currently only standard MHA modules
+                 i.e. `LayerNormLinear/Linear + DPA + Linear`, are supported for this feature.
When + `fp8_mha = False, fp8_dpa = True`, a typical MHA module works as + `LayerNormLinear (BF16 output) -> (cast to FP8 ) FP8 DPA (cast to BF16) -> Linear`. + When `fp8_mha = True, fp8_dpa = True`, it becomes + `LayerNormLinear (FP8 output) -> FP8 DPA -> Linear`. + + Notes + ----- + * `fp8_dpa` and `fp8_mha` are Beta features, and their API and functionality are + subject to change in future Transformer Engine releases. + """ + + fp8_format: Format = Format.HYBRID + fp8_quant_fwd_inp = QParams(power_2_scale=False, amax_epsilon=0.0) + fp8_quant_fwd_weight = QParams(power_2_scale=False, amax_epsilon=0.0) + fp8_quant_bwd_grad = QParams(power_2_scale=False, amax_epsilon=0.0) + fp8_gemm_fprop: MMParams = MMParams(use_split_accumulator=False) + fp8_gemm_dgrad: MMParams = MMParams(use_split_accumulator=True) + fp8_gemm_wgrad: MMParams = MMParams(use_split_accumulator=True) + fp8_dpa: bool = False + fp8_mha: bool = False + + def __post_init__(self) -> None: + assert self.fp8_format != Format.E5M2, "Pure E5M2 training is not supported." + + def __repr__(self) -> str: + return ( + f"format={str(self.fp8_format).split('.')[1]}, " + f"fp8_quant_fwd_inp={self.fp8_quant_fwd_inp}, " + f"fp8_quant_fwd_weight={self.fp8_quant_fwd_weight}, " + f"fp8_quant_bwd_grad={self.fp8_quant_bwd_grad}, " + f"fp8_gemm_fprop={self.fp8_gemm_fprop}, " + f"fp8_gemm_dgrad={self.fp8_gemm_dgrad}, " + f"fp8_gemm_wgrad={self.fp8_gemm_wgrad}, " + f"fp8_dpa={self.fp8_dpa}, " + f"fp8_mha={self.fp8_mha}" + ) + + @dataclass() class MXFP8BlockScaling(Recipe): """ diff --git a/transformer_engine/common/recipe/current_scaling.cu b/transformer_engine/common/recipe/current_scaling.cu new file mode 100644 index 0000000000..3a25d71a3b --- /dev/null +++ b/transformer_engine/common/recipe/current_scaling.cu @@ -0,0 +1,237 @@ +/************************************************************************* + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. 
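For orientation, the Float8CurrentScaling recipe defined above is consumed the same way as the other entries in the fp8_recipes list added to tests/pytorch/test_numerics.py, i.e. it is passed to te.fp8_autocast. A minimal usage sketch (layer sizes, dtypes and the variant recipe are illustrative, and an FP8-capable GPU is assumed):

import torch
import transformer_engine.pytorch as te
from transformer_engine.common import recipe

# Defaults: HYBRID format (E4M3 fwd / E5M2 bwd), split accumulator on for dgrad/wgrad GEMMs,
# no power-of-2 scales, amax_epsilon = 0.0.
cs_recipe = recipe.Float8CurrentScaling()

# MMParams fields are regular dataclass fields, so GEMM accumulation can be tuned per instance.
cs_recipe_accurate_fprop = recipe.Float8CurrentScaling(
    fp8_gemm_fprop=recipe.MMParams(use_split_accumulator=True),
)

layer = te.Linear(1024, 1024, bias=True, params_dtype=torch.bfloat16)
x = torch.randn(32, 1024, device="cuda", dtype=torch.bfloat16, requires_grad=True)

with te.fp8_autocast(enabled=True, fp8_recipe=cs_recipe):
    y = layer(x)
y.sum().backward()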
+ ************************************************************************/ + +#include + +#include +#include +#include + +#include "../common.h" +#include "../util/logging.h" +#include "../util/vectorized_pointwise.h" + +namespace transformer_engine { +namespace { + +constexpr int amax_kernel_threads = 512; + +template +__launch_bounds__(amax_kernel_threads) __global__ + void amax_kernel(const InputType *input, float *amax, const size_t N, + const size_t num_aligned_elements) { + VectorizedLoader loader(input, N); + InputType max = 0.f; + const int warp_id = threadIdx.x / THREADS_PER_WARP; + const size_t M = num_aligned_elements; + + for (size_t tid = blockIdx.x * blockDim.x + threadIdx.x; tid < M; tid += gridDim.x * blockDim.x) { + loader.load(tid, N); +#pragma unroll + for (int i = 0; i < nvec; ++i) { + const InputType val = static_cast(loader.separate()[i]); + __builtin_assume(max >= InputType{0.f}); + if constexpr (std::is_same_v) { +#if __CUDA_ARCH__ >= 800 + max = __hmax(__habs(val), max); +#else // Turing + max = static_cast<__nv_bfloat16>( + fmaxf(fabsf(static_cast(val)), static_cast(max))); +#endif + } else if constexpr (std::is_same_v) { + max = __hmax(__habs(val), max); + } else { + max = fmaxf(fabsf(val), max); + } + } + } + + // Reduce amax over block + max = reduce_max(max, warp_id); + if (threadIdx.x == 0) { + atomicMaxFloat(amax, max); + } +} + +template +void launch_amax_kernel(const InputType *input, float *amax, const size_t N, cudaStream_t stream) { + // Zero out amax so we can update with atomic max + cudaMemsetAsync(amax, 0, sizeof(float), stream); + + // Return immediately if tensor is empty + if (N == 0) { + return; + } + + // Figure out alignment + auto align = CheckAlignment(N, nvec, input); + size_t num_aligned_elements = get_num_aligned_elements(input, N, nvec, sizeof(InputType)); + + // Figure out CUDA blocks + constexpr size_t threads = amax_kernel_threads; + size_t num_blocks = DIVUP(num_aligned_elements, threads); + constexpr size_t max_blocks = 65535; + num_blocks = std::min(num_blocks, max_blocks); + + // Launch kernel + switch (align) { + case Alignment::SAME_ALIGNED: + amax_kernel + <<>>(input, amax, N, num_aligned_elements); + break; + case Alignment::SAME_UNALIGNED: + amax_kernel + <<>>(input, amax, N, num_aligned_elements); + break; + case Alignment::DIFFERENT: { + // This case is a logic error, since there is only one pointer (input) + // in the alignment check. Still safe to process without vectorization. 
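Numerically, amax_kernel / launch_amax_kernel above reduce max(|x|) over the whole tensor into a single FP32 value that was zeroed beforehand, so an empty tensor leaves amax at 0. A host-side reference sketch of the same quantity (the function name is illustrative):

import torch

def reference_amax(x: torch.Tensor) -> torch.Tensor:
    """Host-side stand-in for launch_amax_kernel: amax = max(|x|) as one FP32 scalar."""
    amax = torch.zeros((), dtype=torch.float32, device=x.device)  # kernel zeroes amax first
    if x.numel() > 0:
        amax = x.detach().abs().amax().to(torch.float32)
    return amax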
+ amax_kernel<1, true, InputType><<>>(input, amax, N, N); + break; + } + } + + // Check results + NVTE_CHECK_CUDA(cudaGetLastError()); +} + +} // namespace +} // namespace transformer_engine + +void nvte_compute_amax(const NVTETensor input_, const NVTETensor output_, cudaStream_t stream) { + NVTE_API_CALL(nvte_compute_amax); + using namespace transformer_engine; + + // Check input tensor + NVTE_CHECK(input_ != nullptr, "Invalid input tensor (got NULL)"); + const auto &input = *reinterpret_cast(input_); + NVTE_CHECK(input.scaling_mode == NVTE_DELAYED_TENSOR_SCALING, + "Input tensor for amax computation must unquantized, " + "but got scaling_mode=", + to_string(input.scaling_mode)); + NVTE_CHECK(!is_fp8_dtype(input.data.dtype), + "Input tensor for amax computation must be unquantized, but got dtype=", + to_string(input.data.dtype)); + NVTE_CHECK(input.data.dptr != nullptr, "Input tensor for amax computation has no data"); + CheckInputTensor(input, "input_compute_amax"); + + // Check output tensor + NVTE_CHECK(output_ != nullptr, "Invalid output tensor (got NULL)"); + auto &output = *reinterpret_cast(output_); + NVTE_CHECK(output.scaling_mode == NVTE_DELAYED_TENSOR_SCALING, + "Output tensor for amax computation must be FP8 tensor with per-tensor scaling, " + "but got scaling_mode=", + to_string(output.scaling_mode)); + NVTE_CHECK(output.amax.numel() == 1, + "Output tensor for amax computation has invalid amax tensor " + "(expected 1 entry, got shape=", + output.amax.shape, ")"); + NVTE_CHECK(output.amax.dptr != nullptr, + "Output tensor for amax computation has amax tensor without data"); + NVTE_CHECK(output.amax.dtype == DType::kFloat32, + "Output tensor for amax computation has invalid amax tensor " + "(expected FP32, got dtype=", + to_string(output.amax.dtype), ")"); + CheckOutputTensor(output, "output_compute_amax"); + + // Compute amax + TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT( + input.data.dtype, IType, constexpr int nvec = 32 / sizeof(IType); + launch_amax_kernel(reinterpret_cast(input.data.dptr), + reinterpret_cast(output.amax.dptr), input.data.numel(), + stream);); // NOLINT(*) +} + +namespace transformer_engine { +namespace { + +__global__ void compute_scale_from_amax_kernel(const float *amax_ptr, float *scale_ptr, + const float max_fp8, const bool force_pow_2_scales, + const float epsilon) { + float amax = *amax_ptr; + if (amax < epsilon) { + amax = epsilon; + } + + float scale = 1.f; + + if (isinf(amax) || amax == 0.f) { + *scale_ptr = scale; + return; + } + + scale = max_fp8 / amax; + + // The amax is too small that the scale becoming infinite in FP32. In other word, + // the scale is not representable in FP32. + if (isinf(scale)) { + // use fp32 max to represent the scale + scale = std::numeric_limits::max(); + } + + if (isnan(scale)) { + scale = 1.f; + } + + if (force_pow_2_scales) { + uint32_t scale_bits = *reinterpret_cast(&scale); + scale_bits &= 0xFF800000; + // If the exponent was zero, we have a logic error. 
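The branches of compute_scale_from_amax_kernel can be mirrored on the host; the sketch below (function name and FP32_MAX constant are illustrative) follows the same logic, where the 0xFF800000 mask keeps only the sign and exponent bits, i.e. rounds the scale down to a power of two:

import math
import struct

FP32_MAX = 3.4028234663852886e38  # std::numeric_limits<float>::max()

def reference_scale_from_amax(amax: float, max_fp8: float,
                              force_pow_2_scales: bool = False,
                              epsilon: float = 0.0) -> float:
    """Host-side mirror of compute_scale_from_amax_kernel."""
    amax = max(amax, epsilon)              # clamp tiny amaxes to the configured epsilon
    scale = 1.0
    if math.isinf(amax) or amax == 0.0:
        return scale
    scale = max_fp8 / amax
    if math.isinf(scale):                  # amax too small: scale not representable in FP32
        scale = FP32_MAX
    if math.isnan(scale):
        scale = 1.0
    if force_pow_2_scales:
        # Keep sign + exponent bits only, i.e. round the scale down to a power of two.
        bits = struct.unpack("<I", struct.pack("<f", scale))[0] & 0xFF800000
        scale = struct.unpack("<f", struct.pack("<I", bits))[0]
    return scale

# E4M3 example: the max representable magnitude is 448, so an amax of 3.2 gives scale 140.0.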
+ __builtin_assume(scale_bits != 0); + __builtin_assume(scale_bits != 0x80000000); + scale = *reinterpret_cast(&scale_bits); + } + + *scale_ptr = scale; +} + +} // namespace +} // namespace transformer_engine + +void nvte_compute_scale_from_amax(NVTETensor output_, const NVTEQuantizationConfig config_, + cudaStream_t stream) { + NVTE_API_CALL(nvte_compute_scale_from_amax); + using namespace transformer_engine; + + // Check output tensor + NVTE_CHECK(output_ != nullptr, "Invalid output tensor (got NULL)"); + auto &output = *reinterpret_cast(output_); + NVTE_CHECK(output.scaling_mode == NVTE_DELAYED_TENSOR_SCALING, + "Tensor must be FP8 tensor with per-tensor scaling, " + "but got scaling_mode=", + to_string(output.scaling_mode)); + NVTE_CHECK(is_fp8_dtype(output.data.dtype), + "Tensor must be FP8, but got dtype=", to_string(output.data.dtype)); + NVTE_CHECK(output.amax.numel() == 1, + "Tensor has invalid amax tensor (expected 1 entry, got shape=", output.amax.shape, + ")"); + NVTE_CHECK(output.amax.dptr != nullptr, "Tensor has amax tensor without data"); + NVTE_CHECK(output.amax.dtype == DType::kFloat32, + "Tensor has invalid amax tensor (expected FP32, got dtype=", + to_string(output.amax.dtype), ")"); + NVTE_CHECK(output.scale.numel() == 1, + "Tensor has invalid scale tensor (expected 1 entry, got shape=", output.scale.shape, + ")"); + NVTE_CHECK(output.scale.dptr != nullptr, "Tensor has scale tensor without data"); + NVTE_CHECK(output.scale.dtype == DType::kFloat32, + "Tensor has invalid scale tensor (expected FP32, got dtype=", + to_string(output.scale.dtype), ")"); + + // Check config + NVTE_CHECK(config_ != nullptr, "Invalid config (got NULL)"); + const auto &config = *reinterpret_cast(config_); + + // Maximum FP8 value + float max_fp8 = 0.f; + TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY(output.data.dtype, DType, + max_fp8 = Quantized_Limits::max_norm;); + + // Update scale + compute_scale_from_amax_kernel<<<1, 1>>>(reinterpret_cast(output.amax.dptr), + reinterpret_cast(output.scale.dptr), max_fp8, + config.force_pow_2_scales, config.amax_epsilon); + NVTE_CHECK_CUDA(cudaGetLastError()); +} diff --git a/transformer_engine/common/transformer_engine.cpp b/transformer_engine/common/transformer_engine.cpp index 54d5b0b5bf..23f272d5d5 100644 --- a/transformer_engine/common/transformer_engine.cpp +++ b/transformer_engine/common/transformer_engine.cpp @@ -6,6 +6,7 @@ #include +#include #include #include "common.h" @@ -150,8 +151,7 @@ void CheckOutputTensor(const Tensor &t, const std::string &name, bool allow_empt const DType type = t.dtype(); if (is_fp8_dtype(type)) { // FP8 output needs to have scale, scale_inv and (if delayed scaling) amax - if (t.scaling_mode == NVTE_DELAYED_TENSOR_SCALING) { - NVTE_CHECK(t.amax.dptr != nullptr, "FP8 output ", name, " must have amax tensor"); + if (t.scaling_mode == NVTE_DELAYED_TENSOR_SCALING && t.amax.dptr != nullptr) { NVTE_CHECK(t.amax.dtype == DType::kFloat32, "Invalid amax dtype (expected ", to_string(DType::kFloat32), ", got ", to_string(t.amax.dtype), ")"); NVTE_CHECK(product(t.amax.shape) == 1, "Invalid shape of amax in output ", name, @@ -410,3 +410,79 @@ void nvte_zero_tensor(const NVTETensor tensor, cudaStream_t stream) { cudaMemsetAsync(t.amax.dptr, 0, sizeof(float), stream); } } + +NVTEQuantizationConfig nvte_create_quantization_config() { + return new transformer_engine::QuantizationConfig; +} + +void nvte_get_quantization_config_attribute(NVTEQuantizationConfig config, + NVTEQuantizationConfigAttribute attr, void *buf, + size_t size_in_bytes, 
size_t *size_written) { + // Write attribute size + NVTE_CHECK(attr < kNVTEQuantizationConfigNumAttributes, + "Invalid NVTEQuantizationConfigAttribute (got ", static_cast(attr), ")"); + NVTE_CHECK(size_written != nullptr, "Invalid size_written (got NULL)"); + const auto &attr_size = transformer_engine::QuantizationConfig::attr_sizes[attr]; + *size_written = attr_size; + + // Return immediately if buffer is not provided + if (buf == nullptr) { + return; + } + + // Check buffer size + NVTE_CHECK(size_in_bytes >= attr_size, + "Buffer is too small for quantization config attribute " + "(attribute ", + static_cast(attr), " needs ", attr_size, " bytes, but buffer has ", size_in_bytes, + " bytes)"); + + // Write to buffer + NVTE_CHECK(config != nullptr, "Invalid NVTEQuantizationConfig (got NULL)"); + const auto &config_ = *reinterpret_cast(config); + switch (attr) { + case kNVTEQuantizationConfigForcePow2Scales: + std::memcpy(buf, &config_.force_pow_2_scales, attr_size); + break; + case kNVTEQuantizationConfigAmaxEpsilon: + std::memcpy(buf, &config_.amax_epsilon, attr_size); + break; + default: + NVTE_ERROR("Unsupported NVTEQuantizationConfigAttribute (got ", static_cast(attr), ")"); + } +} + +void nvte_set_quantization_config_attribute(NVTEQuantizationConfig config, + NVTEQuantizationConfigAttribute attr, const void *buf, + size_t size_in_bytes) { + // Check attribute and buffer + NVTE_CHECK(attr < kNVTEQuantizationConfigNumAttributes, + "Invalid NVTEQuantizationConfigAttribute (got ", static_cast(attr), ")"); + const auto &attr_size = transformer_engine::QuantizationConfig::attr_sizes[attr]; + NVTE_CHECK(size_in_bytes >= attr_size, + "Buffer is too small for quantization config attribute " + "(attribute ", + static_cast(attr), " needs ", attr_size, " bytes, but buffer has ", size_in_bytes, + " bytes)"); + NVTE_CHECK(buf != nullptr, "Invalid buffer (got NULL)"); + + // Read from buffer + NVTE_CHECK(config != nullptr, "Invalid NVTEQuantizationConfig (got NULL)"); + auto &config_ = *reinterpret_cast(config); + switch (attr) { + case kNVTEQuantizationConfigForcePow2Scales: + std::memcpy(&config_.force_pow_2_scales, buf, attr_size); + break; + case kNVTEQuantizationConfigAmaxEpsilon: + std::memcpy(&config_.amax_epsilon, buf, attr_size); + break; + default: + NVTE_ERROR("Unsupported NVTEQuantizationConfigAttribute (got ", static_cast(attr), ")"); + } +} + +void nvte_destroy_quantization_config(NVTEQuantizationConfig config) { + if (config != nullptr) { + delete reinterpret_cast(config); + } +} diff --git a/transformer_engine/common/transpose/cast_transpose.cu b/transformer_engine/common/transpose/cast_transpose.cu index 4cdb39b70a..7f3b9fb302 100644 --- a/transformer_engine/common/transpose/cast_transpose.cu +++ b/transformer_engine/common/transpose/cast_transpose.cu @@ -249,7 +249,9 @@ void cast_transpose(const Tensor &input, const Tensor &noop, Tensor *output_, cu input.dtype(), InputType, TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT( output.dtype(), OutputType, - if (is_delayed_tensor_scaling(output.scaling_mode)) { + if (is_tensor_scaling(output.scaling_mode)) { + // delayed scaling and current scaling are two variants of per-tensor scaling + constexpr const char *itype_name = TypeInfo::name; constexpr const char *otype_name = TypeInfo::name; constexpr size_t itype_size = sizeof(InputType); @@ -323,6 +325,7 @@ void cast_transpose(const Tensor &input, const Tensor &noop, Tensor *output_, cu constexpr size_t col_tile_size = store_size / otype_size * THREADS_PER_WARP; const int num_blocks = 
(DIVUP(row_length, row_tile_size) * DIVUP(num_rows, col_tile_size)); + cast_transpose_general_kernel <<>>( static_cast(input.data.dptr), diff --git a/transformer_engine/common/util/cast_kernels.cuh b/transformer_engine/common/util/cast_kernels.cuh index b4b86fe708..ba2890ada3 100644 --- a/transformer_engine/common/util/cast_kernels.cuh +++ b/transformer_engine/common/util/cast_kernels.cuh @@ -1054,8 +1054,7 @@ void CastVectorizedUnaryKernelLauncher(const Tensor &input, const Tensor *noop, input.data.dtype, IType, TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT( output->data.dtype, OType, - if (!is_fp8_dtype(output->data.dtype) || - is_delayed_tensor_scaling(output->scaling_mode)) { + if (!is_fp8_dtype(output->data.dtype) || is_tensor_scaling(output->scaling_mode)) { constexpr int nvec = 32 / sizeof(IType); VectorizedUnaryKernelLauncher( reinterpret_cast(input.data.dptr), @@ -1079,8 +1078,7 @@ void CastVectorizedUnaryGradKernelLauncher(const Tensor &grad, const Tensor *inp input->data.dtype, IType, TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT( output->data.dtype, OType, - if (!is_fp8_dtype(output->data.dtype) || - is_delayed_tensor_scaling(output->scaling_mode)) { + if (!is_fp8_dtype(output->data.dtype) || is_tensor_scaling(output->scaling_mode)) { constexpr int nvec = 32 / sizeof(IType); VectorizedUnaryGradKernelLauncher( reinterpret_cast(grad.data.dptr), @@ -1164,14 +1162,22 @@ template scaling_mode) || IS_DBIAS) { - NVTE_ERROR("Not implemented scaling mode: " + to_string(output->scaling_mode) + + if (!is_tensor_scaling(output->scaling_mode) || IS_DBIAS) { + // zhongboz: should we just ignore IS_ACT here? + NVTE_ERROR("Not implemented scaling mode or fusion: " + to_string(output->scaling_mode) + " on GPU with compute capability < 10.0."); } - if (!IS_DACT) { - CastVectorizedUnaryKernelLauncher(input, noop, output, stream); - } else { - CastVectorizedUnaryGradKernelLauncher(input, act_input, output, stream); + switch (output->scaling_mode) { + case NVTE_DELAYED_TENSOR_SCALING: { + if (!IS_DACT) { + CastVectorizedUnaryKernelLauncher(input, noop, output, stream); + } else { + CastVectorizedUnaryGradKernelLauncher(input, act_input, output, stream); + } + break; + } + default: + NVTE_ERROR("Not implemented scaling mode: " + to_string(output->scaling_mode) + "."); } } diff --git a/transformer_engine/common/utils.cuh b/transformer_engine/common/utils.cuh index 63ce369892..227b3aaa48 100644 --- a/transformer_engine/common/utils.cuh +++ b/transformer_engine/common/utils.cuh @@ -844,7 +844,7 @@ __device__ __forceinline__ compute_t reduce_max(const compute_t m, const int war staging[warpid] = my_warp_max; } __syncthreads(); - compute_t result = 0; + compute_t result = 0.f; if (warpid == 0) { const float my_max = threadIdx.x < num_warps ? 
staging[threadIdx.x] : 0; result = warp_reduce_max(my_max); diff --git a/transformer_engine/pytorch/constants.py b/transformer_engine/pytorch/constants.py index ff475caf21..543b1181cb 100644 --- a/transformer_engine/pytorch/constants.py +++ b/transformer_engine/pytorch/constants.py @@ -24,6 +24,16 @@ torch.bfloat16: tex.DType.kBFloat16, } +TE_DType_To_Torch = { + tex.DType.kByte: torch.uint8, + tex.DType.kFloat8E4M3: torch.float8_e4m3fn, + tex.DType.kFloat8E5M2: torch.float8_e5m2, + tex.DType.kInt32: torch.int32, + tex.DType.kFloat32: torch.float32, + tex.DType.kFloat16: torch.half, + tex.DType.kBFloat16: torch.bfloat16, +} + AttnMaskTypes = ( "no_mask", "padding", diff --git a/transformer_engine/pytorch/csrc/common.cpp b/transformer_engine/pytorch/csrc/common.cpp index 5775fe381d..23137a1003 100644 --- a/transformer_engine/pytorch/csrc/common.cpp +++ b/transformer_engine/pytorch/csrc/common.cpp @@ -46,15 +46,22 @@ transformer_engine::DType getTransformerEngineFP8Type(bool e4m3_if_hybrid, TensorWrapper makeTransformerEngineTensor(py::handle tensor, py::handle quantizer) { NVTE_CHECK(!tensor.is_none(), "Tensor is not allocated!"); std::unique_ptr my_quantizer = convert_quantizer(quantizer); + // check for both quantizer & tensor type: + // mxfp8 tensor -> mxfp8 quantizer + // float8 tensor -> delayed scaling quantizer OR current scaling quantizer + // also during dequantize, the quantizer param is unknown -> so quantizer is NoneQuantizer for (auto [check_type, check_quantizer_type, create_tensor, _] : detail::custom_types_converters) { if (check_type(tensor.ptr())) { - NVTE_CHECK(quantizer.is_none() || check_quantizer_type(quantizer.ptr()), - "Unexpected quantization params type."); + if (!(quantizer.is_none() || check_quantizer_type(quantizer.ptr()))) { + continue; + } auto x = create_tensor(tensor, my_quantizer.get()); return x; } } + NVTE_CHECK(dynamic_cast(my_quantizer.get()) != nullptr, + "Unexpected quantization params type."); // Regular pyTorch tensor at::Tensor torch_tensor = tensor.cast(); diff --git a/transformer_engine/pytorch/csrc/common.h b/transformer_engine/pytorch/csrc/common.h index 40245cf2d9..980b2dff13 100644 --- a/transformer_engine/pytorch/csrc/common.h +++ b/transformer_engine/pytorch/csrc/common.h @@ -50,6 +50,9 @@ namespace transformer_engine::pytorch { +// in python we have: dist_group_type = torch.distributed.ProcessGroup +using dist_group_type = c10d::ProcessGroup; + // Each tensor here is shape (N, ) holding all scaling // data for a single FP8 block, e.g. 
LayerNormLinear class FP8TensorMeta { @@ -136,6 +139,29 @@ class Float8Quantizer : public Quantizer { std::optional rowwise_data = std::nullopt) const override; }; +class Float8CurrentScalingQuantizer : public Quantizer { + public: + at::Tensor scale; + at::Tensor scale_inv; + at::Tensor amax; + DType dtype; + bool with_amax_reduction; + c10::intrusive_ptr amax_reduction_group; + int amax_reduction_size; + bool force_pow_2_scales = false; + float amax_epsilon = 0.0; + + explicit Float8CurrentScalingQuantizer(const py::handle& quantizer); + + NVTEScalingMode get_scaling_mode() const override { return NVTE_DELAYED_TENSOR_SCALING; } + + void set_quantization_params(TensorWrapper* tensor) const override; + + std::pair create_tensor( + const std::vector& shape, DType dtype, + std::optional rowwise_data = std::nullopt) const override; +}; + class MXFP8Quantizer : public Quantizer { public: DType dtype; diff --git a/transformer_engine/pytorch/csrc/extensions/activation.cpp b/transformer_engine/pytorch/csrc/extensions/activation.cpp index 7ce33ee77b..1ef6f5258d 100644 --- a/transformer_engine/pytorch/csrc/extensions/activation.cpp +++ b/transformer_engine/pytorch/csrc/extensions/activation.cpp @@ -4,6 +4,7 @@ * See LICENSE for license information. ************************************************************************/ +#include "common.h" #include "extensions.h" #include "pybind.h" @@ -24,7 +25,35 @@ py::object activation_helper(const at::Tensor& input, py::handle quantizer, int auto [te_output, out] = my_quantizer->create_tensor(input_shape, GetTransformerEngineDType(fake_tensor_type)); - act_func(te_input.data(), te_output.data(), at::cuda::getCurrentCUDAStream()); + // for current scaling, we need to compute amax first and then quantize + // because cache cannot fit in the entire tensor to compute amax and quantize + // the quantizer should not need amax reduction, no process group needed here + if (detail::IsFloat8CurrentScalingQuantizers(quantizer.ptr())) { + // activation function might change the input data range, we need to first call the activation function + // and then find the amax and scale of that and then do the quantization + // get a NoneQuantizer to calculate amax of activation output + auto my_quantizer_none = std::make_unique(py::none()); + auto [te_output_act, out_act] = + my_quantizer_none->create_tensor(input_shape, GetTransformerEngineDType(fake_tensor_type)); + act_func(te_input.data(), te_output_act.data(), at::cuda::getCurrentCUDAStream()); + // use te_output_act as input to the compute amax and find the amax of activated tensor + nvte_compute_amax(te_output_act.data(), te_output.data(), at::cuda::getCurrentCUDAStream()); + // my_quantizer here has to be a Float8CurrentScalingQuantizer + auto my_quantizer_cs = static_cast(my_quantizer.get()); + if (my_quantizer_cs->with_amax_reduction) { + NVTE_ERROR( + "per-tensor current scaling amax reduction is not supported in activation functions."); + } + QuantizationConfigWrapper quant_config; + quant_config.set_force_pow_2_scales(my_quantizer_cs->force_pow_2_scales); + quant_config.set_amax_epsilon(my_quantizer_cs->amax_epsilon); + nvte_compute_scale_from_amax(te_output.data(), quant_config, at::cuda::getCurrentCUDAStream()); + // set amax ptr to null in te_output TensorWrapper to avoid atomic amax updates in kernel + te_output.set_amax(nullptr, DType::kFloat32, te_output.defaultShape); + nvte_quantize(te_output_act.data(), te_output.data(), at::cuda::getCurrentCUDAStream()); + } else { + act_func(te_input.data(), 
te_output.data(), at::cuda::getCurrentCUDAStream()); + } return out; } diff --git a/transformer_engine/pytorch/csrc/extensions/cast.cpp b/transformer_engine/pytorch/csrc/extensions/cast.cpp index 66dafdaafb..2c3ccff154 100644 --- a/transformer_engine/pytorch/csrc/extensions/cast.cpp +++ b/transformer_engine/pytorch/csrc/extensions/cast.cpp @@ -45,6 +45,29 @@ py::object quantize(const at::Tensor& tensor, py::handle quantizer, const py::ob } if (te_output.numel() == 0) return out; + + if (detail::IsFloat8CurrentScalingQuantizers(quantizer.ptr())) { + // my_quantizer here has to be a Float8CurrentScalingQuantizer + auto my_quantizer_cs = static_cast(my_quantizer.get()); + nvte_compute_amax(te_input.data(), te_output.data(), at::cuda::getCurrentCUDAStream()); + // check if we need to do amax reudction (depending on model parallel configs) + if (my_quantizer_cs->with_amax_reduction) { + c10::intrusive_ptr process_group_ptr = my_quantizer_cs->amax_reduction_group; + // construct torch tesnor from NVTEBasicTensor without reallocating memory + at::Tensor& amax_tensor_torch = my_quantizer_cs->amax; + std::vector tensors = {amax_tensor_torch}; + // allreduce amax tensor + c10d::AllreduceOptions allreduce_opts; + allreduce_opts.reduceOp = c10d::ReduceOp::MAX; + process_group_ptr->allreduce(tensors, allreduce_opts)->wait(); + } + QuantizationConfigWrapper quant_config; + quant_config.set_force_pow_2_scales(my_quantizer_cs->force_pow_2_scales); + quant_config.set_amax_epsilon(my_quantizer_cs->amax_epsilon); + nvte_compute_scale_from_amax(te_output.data(), quant_config, at::cuda::getCurrentCUDAStream()); + // set amax ptr to null in te_output TensorWrapper to avoid atomic amax updates in kernel + te_output.set_amax(nullptr, DType::kFloat32, te_output.defaultShape); + } nvte_quantize_noop(te_input.data(), te_output.data(), te_noop.data(), at::cuda::getCurrentCUDAStream()); diff --git a/transformer_engine/pytorch/csrc/extensions/pybind.cpp b/transformer_engine/pytorch/csrc/extensions/pybind.cpp index 0604847235..3e944c0fdd 100644 --- a/transformer_engine/pytorch/csrc/extensions/pybind.cpp +++ b/transformer_engine/pytorch/csrc/extensions/pybind.cpp @@ -24,6 +24,7 @@ namespace transformer_engine::pytorch { PyTypeObject *Float8TensorPythonClass = nullptr; /// TODO Remove PyTypeObject *Float8TensorBasePythonClass = nullptr; PyTypeObject *Float8QuantizerClass = nullptr; +PyTypeObject *Float8CurrentScalingQuantizerClass = nullptr; PyTypeObject *MXFP8TensorPythonClass = nullptr; /// TODO Remove PyTypeObject *MXFP8TensorBasePythonClass = nullptr; PyTypeObject *MXFP8QuantizerClass = nullptr; @@ -33,6 +34,8 @@ void init_float8_extension() { auto fp8_module = py::module_::import("transformer_engine.pytorch.tensor.float8_tensor"); Float8QuantizerClass = reinterpret_cast(PyObject_GetAttrString(fp8_module.ptr(), "Float8Quantizer")); + Float8CurrentScalingQuantizerClass = reinterpret_cast( + PyObject_GetAttrString(fp8_module.ptr(), "Float8CurrentScalingQuantizer")); Float8TensorPythonClass = reinterpret_cast(PyObject_GetAttrString(fp8_module.ptr(), "Float8Tensor")); auto fp8_base_module = diff --git a/transformer_engine/pytorch/csrc/extensions/quantizer.cpp b/transformer_engine/pytorch/csrc/extensions/quantizer.cpp index effeb8cb4d..427bf294d3 100644 --- a/transformer_engine/pytorch/csrc/extensions/quantizer.cpp +++ b/transformer_engine/pytorch/csrc/extensions/quantizer.cpp @@ -140,6 +140,123 @@ std::pair Float8Quantizer::create_tensor( return {std::move(tensor), std::move(ret)}; } 
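On the Python side, the quantize() path above is driven entirely by the attributes of Float8CurrentScalingQuantizer, which the C++ constructor that follows reads back (scale, amax, with_amax_reduction, force_pow_2_scales, amax_epsilon). A usage sketch mirroring the to_float8_CS helper from the tests (shape and dtype are illustrative):

import torch
import transformer_engine_torch as tex
from transformer_engine.pytorch.tensor.float8_tensor import Float8CurrentScalingQuantizer

x = torch.randn(128, 256, device="cuda", dtype=torch.bfloat16)

quantizer = Float8CurrentScalingQuantizer(
    fp8_dtype=tex.DType.kFloat8E4M3,
    device=x.device,
    force_pow_2_scales=False,  # forwarded to nvte_compute_scale_from_amax via the config
    amax_epsilon=0.0,
)
quantizer.set_usage(rowwise=True, columnwise=False)  # no transpose buffer needed here

x_fp8 = quantizer(x)        # compute amax -> derive scale -> cast, on the current stream
x_rt = x_fp8.dequantize()   # round-trips back to high precision within FP8 tolerance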
+Float8CurrentScalingQuantizer::Float8CurrentScalingQuantizer(const py::handle& quantizer) + : Quantizer(quantizer) { + const at::Tensor& scale = quantizer.attr("scale").cast(); + const at::Tensor& amax = quantizer.attr("amax").cast(); + const DType type = quantizer.attr("dtype").cast(); + // For current scaling, need several other components: + // 1. with_amax_reduction: bool + // 2. amax_reduction_group: torch.distributed.ProcessGroup or None + // 3. amax_reduction_size: int + const bool with_amax_reduction = quantizer.attr("with_amax_reduction").cast(); + const py::object amax_reduction_group_obj = quantizer.attr("amax_reduction_group"); + const c10::intrusive_ptr amax_reduction_group = + amax_reduction_group_obj.is_none() + ? nullptr + : amax_reduction_group_obj.cast>(); + const int amax_reduction_size = quantizer.attr("amax_reduction_size").cast(); + + this->amax = amax; + this->scale = scale; + this->dtype = type; + this->with_amax_reduction = with_amax_reduction; + this->amax_reduction_group = amax_reduction_group; + this->amax_reduction_size = amax_reduction_size; + + // fp8 current scaling specific quantization params + this->force_pow_2_scales = quantizer.attr("force_pow_2_scales").cast(); + this->amax_epsilon = quantizer.attr("amax_epsilon").cast(); +} + +void Float8CurrentScalingQuantizer::set_quantization_params(TensorWrapper* tensor) const { + // transfer amax and scale pointer from quantizer to output tensor (only as gpu buffer, no meaningful data in them) + tensor->set_scale(scale.data_ptr(), GetTransformerEngineDType(scale.scalar_type()), + getTensorShape(scale)); + at::TensorOptions opts = opts.dtype(torch::kFloat32).device(torch::kCUDA); + tensor->set_amax(amax.data_ptr(), GetTransformerEngineDType(amax.scalar_type()), + getTensorShape(amax)); + // quantize output and its transpose + auto rowwise_data = tensor->get_rowwise_data(); + rowwise_data.dtype = static_cast(dtype); + + auto columnwise_data = tensor->get_columnwise_data(); + columnwise_data.dtype = static_cast(dtype); + + tensor->set_rowwise_data(rowwise_data.data_ptr, static_cast(rowwise_data.dtype), + rowwise_data.shape); + tensor->set_columnwise_data(columnwise_data.data_ptr, static_cast(columnwise_data.dtype), + columnwise_data.shape); +} + +std::pair Float8CurrentScalingQuantizer::create_tensor( + const std::vector& shape, DType dtype, std::optional rowwise_data) const { + using namespace pybind11::literals; + std::vector rowwise_torch_shape; + std::vector columnwise_torch_shape; + std::vector scale_inv_torch_shape = {1}; // Shape of 1 element for scale_inv + + if (!shape.empty()) { + columnwise_torch_shape.emplace_back(static_cast(shape.back())); + } + for (size_t i = 0; i < shape.size(); ++i) { + if (i < shape.size() - 1) { + columnwise_torch_shape.emplace_back(static_cast(shape[i])); + } + rowwise_torch_shape.emplace_back(static_cast(shape[i])); + } + at::TensorOptions opts; + opts = opts.dtype(torch::kUInt8).device(torch::kCUDA); + at::Tensor data; + if (rowwise_usage) { + if (rowwise_data.has_value()) { + data = std::move(*rowwise_data); + } else { + data = at::empty(rowwise_torch_shape, opts); + } + } + const py::object py_data = rowwise_usage ? py::cast(data) : py::none(); + at::Tensor columnwise_data; + bool create_transpose = columnwise_usage && !non_tn_fp8_gemm_supported(); + if (create_transpose) { + columnwise_data = at::empty(columnwise_torch_shape, opts); + } + const py::object py_columnwise_data = create_transpose ? 
py::cast(columnwise_data) : py::none(); + + //unlike delayed scaling, in current scaling, scale is not known, so scale_inv should be empty buffer + opts = opts.dtype(torch::kFloat32).device(torch::kCUDA); + at::Tensor scale_inv = at::empty(scale_inv_torch_shape, opts); + + py::object ret; + if (internal) { + py::handle Float8TensorClass(reinterpret_cast(Float8TensorBasePythonClass)); + ret = Float8TensorClass("data"_a = py_data, "fp8_scale_inv"_a = scale_inv, + "fp8_dtype"_a = this->dtype, "data_transpose"_a = py_columnwise_data, + "quantizer"_a = this->quantizer); + } else { + py::handle Float8TensorClass(reinterpret_cast(Float8TensorPythonClass)); + ret = Float8TensorClass("shape"_a = rowwise_torch_shape, "dtype"_a = GetATenDType(dtype), + "data"_a = py_data, "fp8_scale_inv"_a = scale_inv, + "fp8_dtype"_a = this->dtype, "data_transpose"_a = py_columnwise_data, + "quantizer"_a = this->quantizer); + } + TensorWrapper tensor(this->get_scaling_mode()); + if (rowwise_usage) { + tensor.set_rowwise_data(data.data_ptr(), this->dtype, shape); + tensor.set_rowwise_scale_inv(scale_inv.data_ptr(), DType::kFloat32, std::vector{1}); + } + if (create_transpose) { + std::vector transposed_shape; + for (auto s : columnwise_torch_shape) { + transposed_shape.emplace_back(static_cast(s)); + } + tensor.set_columnwise_data(columnwise_data.data_ptr(), this->dtype, transposed_shape); + tensor.set_columnwise_scale_inv(scale_inv.data_ptr(), DType::kFloat32, std::vector{1}); + } + this->set_quantization_params(&tensor); + return {std::move(tensor), std::move(ret)}; +} + MXFP8Quantizer::MXFP8Quantizer(const py::handle& quantizer) : Quantizer(quantizer) { this->dtype = quantizer.attr("dtype").cast(); } diff --git a/transformer_engine/pytorch/csrc/extensions/swizzle.cpp b/transformer_engine/pytorch/csrc/extensions/swizzle.cpp index 316e6515bf..b127b5d75b 100644 --- a/transformer_engine/pytorch/csrc/extensions/swizzle.cpp +++ b/transformer_engine/pytorch/csrc/extensions/swizzle.cpp @@ -12,7 +12,7 @@ void swizzle_scaling_factors(transformer_engine::TensorWrapper& input, bool roww if (input.scaling_mode() == NVTE_INVALID_SCALING) { NVTE_ERROR("Invalid scaling mode for swizzle."); - } else if (input.scaling_mode() == NVTE_DELAYED_TENSOR_SCALING) { + } else if (input.scaling_mode() != NVTE_MXFP8_1D_SCALING) { return; } diff --git a/transformer_engine/pytorch/csrc/extensions/type_converters.cpp b/transformer_engine/pytorch/csrc/extensions/type_converters.cpp index d2607e4ed0..27d5869704 100644 --- a/transformer_engine/pytorch/csrc/extensions/type_converters.cpp +++ b/transformer_engine/pytorch/csrc/extensions/type_converters.cpp @@ -23,7 +23,8 @@ TensorWrapper NVTETensorFromFloat8Tensor(py::handle tensor, Quantizer *quantizer if (transpose_valid) { transpose = tensor.attr("_transpose").cast>(); } - + // In the case of being called under tex.dequantize, the quantizer will be NoneQuantizer + // whose scaling mode is defaulted to NVTE_DELAYED_TENSOR_SCALING auto ret = TensorWrapper(quantizer->get_scaling_mode()); ret.set_rowwise_data(data.data_ptr(), dtype, shape); diff --git a/transformer_engine/pytorch/csrc/pybind.h b/transformer_engine/pytorch/csrc/pybind.h index 0679528b94..b0f55d7598 100644 --- a/transformer_engine/pytorch/csrc/pybind.h +++ b/transformer_engine/pytorch/csrc/pybind.h @@ -21,6 +21,7 @@ namespace transformer_engine::pytorch { extern PyTypeObject *Float8TensorPythonClass; extern PyTypeObject *Float8TensorBasePythonClass; extern PyTypeObject *Float8QuantizerClass; +extern PyTypeObject 
*Float8CurrentScalingQuantizerClass; extern PyTypeObject *MXFP8TensorPythonClass; extern PyTypeObject *MXFP8TensorBasePythonClass; extern PyTypeObject *MXFP8QuantizerClass; @@ -33,13 +34,17 @@ void init_mxfp8_extension(); namespace detail { -inline bool IsFloat8QParams(PyObject *obj) { return Py_TYPE(obj) == Float8QuantizerClass; } +inline bool IsFloat8Quantizers(PyObject *obj) { return Py_TYPE(obj) == Float8QuantizerClass; } + +inline bool IsFloat8CurrentScalingQuantizers(PyObject *obj) { + return Py_TYPE(obj) == Float8CurrentScalingQuantizerClass; +} inline bool IsFloat8Tensor(PyObject *obj) { return Py_TYPE(obj) == Float8TensorPythonClass || Py_TYPE(obj) == Float8TensorBasePythonClass; } -inline bool IsMXFP8QParams(PyObject *obj) { return Py_TYPE(obj) == MXFP8QuantizerClass; } +inline bool IsMXFP8Quantizers(PyObject *obj) { return Py_TYPE(obj) == MXFP8QuantizerClass; } inline bool IsMXFP8Tensor(PyObject *obj) { return Py_TYPE(obj) == MXFP8TensorPythonClass || Py_TYPE(obj) == MXFP8TensorBasePythonClass; @@ -61,9 +66,11 @@ inline bool IsFloatingPointType(at::ScalarType type) { } constexpr std::array custom_types_converters = { - std::make_tuple(IsFloat8Tensor, IsFloat8QParams, NVTETensorFromFloat8Tensor, + std::make_tuple(IsFloat8Tensor, IsFloat8Quantizers, NVTETensorFromFloat8Tensor, CreateQuantizer), - std::make_tuple(IsMXFP8Tensor, IsMXFP8QParams, NVTETensorFromMXFP8Tensor, + std::make_tuple(IsFloat8Tensor, IsFloat8CurrentScalingQuantizers, NVTETensorFromFloat8Tensor, + CreateQuantizer), + std::make_tuple(IsMXFP8Tensor, IsMXFP8Quantizers, NVTETensorFromMXFP8Tensor, CreateQuantizer)}; } // namespace detail diff --git a/transformer_engine/pytorch/distributed.py b/transformer_engine/pytorch/distributed.py index fe023208d1..c1fc15968b 100644 --- a/transformer_engine/pytorch/distributed.py +++ b/transformer_engine/pytorch/distributed.py @@ -21,7 +21,7 @@ from .utils import safely_set_viewless_tensor_data from .constants import dist_group_type from .fp8 import FP8GlobalStateManager -from .tensor.float8_tensor import Float8Quantizer, Float8Tensor +from .tensor.float8_tensor import Float8Quantizer, Float8Tensor, Float8CurrentScalingQuantizer from .tensor.mxfp8_tensor import MXFP8Quantizer, MXFP8Tensor from .tensor.quantized_tensor import QuantizedTensor, Quantizer from .tensor._internal.float8_tensor_base import Float8TensorBase @@ -859,7 +859,10 @@ def _all_gather_fp8( # Quantize input tensor if needed if not isinstance(input_, Float8TensorBase): - assert isinstance(quantizer, Float8Quantizer) + assert isinstance(quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer)) + # we cannot directly gather the transposed fp8 tensor + # so we need to disable columnwise usage for the quantizer + # and then set it back to the original value after quantizing init_columnwise_usage = quantizer.columnwise_usage quantizer.set_usage(columnwise=False) input_ = quantizer(input_) @@ -867,7 +870,7 @@ def _all_gather_fp8( # Construct output tensor out: Float8TensorBase - if isinstance(quantizer, Float8Quantizer): + if isinstance(quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer)): dtype = torch.float32 device = "cuda" if isinstance(input_, Float8Tensor): @@ -885,6 +888,9 @@ def _all_gather_fp8( out._transpose_invalid = True else: raise RuntimeError("FP8TensorBase is not supported yet without Quantizer") + # For delayed scaling, scale_inv is from history, so we can pass it from input_ to out + # For current scaling, scale_inv is from doing amax reduction in C++ code, so each rank should have same 
scale_inv, + # so we can just pass it from input_ to out out._scale_inv = input_._scale_inv # Perform communication @@ -999,8 +1005,10 @@ def gather_along_first_dim( out_shape = list(input_.size()) out_shape[0] *= world_size - # FP8 case - if isinstance(input_, Float8TensorBase) or isinstance(quantizer, Float8Quantizer): + # FP8 case: delayed scaling or current scaling + if isinstance(input_, Float8TensorBase) or isinstance( + quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer) + ): return _all_gather_fp8( input_, process_group, diff --git a/transformer_engine/pytorch/fp8.py b/transformer_engine/pytorch/fp8.py index f788368112..87298c2ec7 100644 --- a/transformer_engine/pytorch/fp8.py +++ b/transformer_engine/pytorch/fp8.py @@ -13,7 +13,13 @@ import torch import transformer_engine_torch as tex -from transformer_engine.common.recipe import Recipe, DelayedScaling, Format, MXFP8BlockScaling +from transformer_engine.common.recipe import ( + Recipe, + DelayedScaling, + Format, + MXFP8BlockScaling, + Float8CurrentScaling, +) from .constants import dist_group_type from .utils import get_device_compute_capability @@ -198,6 +204,8 @@ def add_fp8_tensors_to_global_buffer( fp8_meta: Dict[str, Any], ) -> None: """ + Delayed scaling only. + The amax reduction process happens completely outside the FP8 modules. To participate in the reduction, the only role played by a module is to call this function in order to append it's FP8 tensor into a global @@ -211,7 +219,8 @@ def add_fp8_tensors_to_global_buffer( wrapper. For non CG case, it's called from within the module. """ - if fp8_meta["recipe"].mxfp8(): + # delayed scaling only function, noop for any other recipe + if not fp8_meta["recipe"].delayed(): return # Every module must call this function exactly once since @@ -326,7 +335,8 @@ def reduce_and_update_fp8_tensors( cls, forward: bool = True, ) -> None: - """Concatenate, reduce, and split amaxes in the global buffer.""" + """Delayed scaling only. Concatenate, reduce, and split amaxes in the global buffer.""" + # global_amax_buffer should only be non-empty for fp8 delayed scaling for buffer_key, amax_buffer in cls.global_amax_buffer.items(): # Check for forward or backward reduction. fwd_update, autocast_key = cls.split_key_in_buffer(buffer_key) @@ -426,6 +436,8 @@ def fp8_autocast_exit(cls, enabled: bool, _graph: bool) -> None: # FP8 weight modules are reduced at the end of the optimizer # step after the weight amax is populated. if enabled and cls.FP8_AUTOCAST_DEPTH == 0 and not _graph and torch.is_grad_enabled(): + # delayed scaling only function, for other recipes (current scaling with any granularity), + # this is noop for other recipes because cls.global_amax_buffer is empty list cls.reduce_and_update_fp8_tensors(forward=True) @classmethod @@ -434,7 +446,8 @@ def copy_forward_fp8_meta_tensors_for_recompute(cls, fp8_meta: Dict[str, Any]) - to ensure both forward steps are numerically same. """ - if fp8_meta["recipe"].mxfp8(): + # delayed scaling only function, noop for any other recipe + if not fp8_meta["recipe"].delayed(): return buffer_position_key = "global_fp8_buffer_pos_fwd_recompute" @@ -459,8 +472,8 @@ def get_old_fp8_meta_tensors_for_recompute(cls, fp8_meta: Dict[str, Any]) -> Non """Switch to the copied scaling factors and amaxes from phase 1 forward for indentical numerical outputs. 
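The reason _all_gather_fp8 can reuse one scale_inv across ranks under current scaling is the MAX all-reduce of the amax performed in quantize() (cast.cpp above) when with_amax_reduction is set: every rank then derives the same scale and scale_inv. A conceptual stand-in (helper name is illustrative):

from typing import Optional

import torch
import torch.distributed as dist

def reduced_amax(x: torch.Tensor, group: Optional[dist.ProcessGroup] = None) -> torch.Tensor:
    """Conceptual stand-in for the amax all-reduce done before the scale is computed."""
    amax = x.detach().abs().amax().to(torch.float32)
    if dist.is_initialized():
        # MAX-reduce so every rank computes the same scale, hence the same scale_inv.
        dist.all_reduce(amax, op=dist.ReduceOp.MAX, group=group)
    return amax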
""" - - if fp8_meta["recipe"].mxfp8(): + # delayed scaling only function, noop for any other recipe + if not fp8_meta["recipe"].delayed(): return # Store updated amaxes and scales from phase 1 post forward. @@ -478,8 +491,8 @@ def get_old_fp8_meta_tensors_for_recompute(cls, fp8_meta: Dict[str, Any]) -> Non @staticmethod def restore_fp8_meta_tensors(fp8_meta: Dict[str, Any]) -> None: """Restore latest scaling factors and amaxes after recompute forward run.""" - - if fp8_meta["recipe"].mxfp8(): + # delayed scaling only function, noop for any other recipe + if not fp8_meta["recipe"].delayed(): return fp8_meta["scaling_fwd"].amax_history.copy_(fp8_meta["updated_amax_history_fwd"]) @@ -743,6 +756,8 @@ def create( cls = DelayedScalingRecipeState elif recipe.mxfp8(): cls = MXFP8BlockScalingRecipeState + elif recipe.float8_current_scaling(): + cls = Float8CurrentScalingRecipeState else: raise ValueError("{recipe.__class__.__name__} is not supported") return cls( @@ -813,6 +828,45 @@ def make_quantizers(self) -> list: ] +class Float8CurrentScalingRecipeState(RecipeState): + """Configuration for Per-tensor current scaling quantization. + + Per-tensor current quantization does not require state. + + """ + + recipe: Float8CurrentScaling + mode: str + dtype: tex.DType + device: torch.device + + def __init__( + self, + recipe: Float8CurrentScaling, + *, + mode: str, + num_quantizers: int = 1, + device: Optional[torch.device] = None, + ) -> None: + self.recipe = recipe + self.mode = mode + self.num_quantizers = num_quantizers + self.dtype = get_fp8_te_dtype(recipe, mode == "forward") + + # Allocate buffers + if device is None: + device = torch.device("cuda") + self.device = device + + def make_quantizers(self) -> list: + from .tensor.float8_tensor import Float8CurrentScalingQuantizer + + return [ + Float8CurrentScalingQuantizer(self.dtype, device=self.device) + for i in range(self.num_quantizers) + ] + + class MXFP8BlockScalingRecipeState(RecipeState): """Configuration for MXFP8 quantization. diff --git a/transformer_engine/pytorch/module/base.py b/transformer_engine/pytorch/module/base.py index 84326f58ea..a44e209d36 100644 --- a/transformer_engine/pytorch/module/base.py +++ b/transformer_engine/pytorch/module/base.py @@ -21,6 +21,7 @@ from ..fp8 import ( MXFP8BlockScalingRecipeState, DelayedScalingRecipeState, + Float8CurrentScalingRecipeState, FP8GlobalStateManager, RecipeState, ) @@ -34,6 +35,7 @@ from ..tensor import QuantizedTensor, Quantizer from ..tensor._internal.float8_tensor_base import Float8TensorBase from ..tensor._internal.mxfp8_tensor_base import MXFP8TensorBase +from ..tensor.float8_tensor import Float8CurrentScalingQuantizer __all__ = ["initialize_ub", "destroy_ub"] @@ -430,7 +432,10 @@ def __setattr__(self, name: str, value: Any) -> None: super().__setattr__(name, value) def adjust_amax_history_length(self, length: int, fwd: Optional[bool] = None) -> None: - """Increase or decrease size of amax history based on given `length`. + """ + Delayed scaling only. + + Increase or decrease size of amax history based on given `length`. .. warning:: This changes the underlying amax memory location. @@ -489,6 +494,10 @@ def set_meta_tensor(self, fwd: bool, recipe: Recipe) -> None: return if recipe.mxfp8() and isinstance(recipe_state, MXFP8BlockScalingRecipeState): return + if recipe.float8_current_scaling() and isinstance( + recipe_state, Float8CurrentScalingRecipeState + ): + return # Max. 
number of fp8 tensors per GEMM = 3 (input, weight, output) for fwd and # 2 (grad_output and grad_input) for bwd @@ -851,6 +860,9 @@ def grad_output_preprocess( if ctx.use_bias: if isinstance(grad_output, (QuantizedTensor, Float8TensorBase, MXFP8TensorBase)): grad_bias = grad_output.dequantize().view(-1, grad_output.shape[-1]).sum(dim=0) + elif isinstance(quantizer, Float8CurrentScalingQuantizer): + # FP8 current scaling does not support fused cast + dbias + grad_bias = grad_output.view(-1, grad_output.shape[-1]).sum(dim=0) else: grad_bias, grad_output = tex.bgrad_quantize(grad_output, quantizer) if not isinstance(grad_output, (QuantizedTensor, Float8TensorBase, MXFP8TensorBase)): diff --git a/transformer_engine/pytorch/module/grouped_linear.py b/transformer_engine/pytorch/module/grouped_linear.py index 10b21f25c6..8bf420ab0e 100644 --- a/transformer_engine/pytorch/module/grouped_linear.py +++ b/transformer_engine/pytorch/module/grouped_linear.py @@ -88,6 +88,9 @@ def forward( # TODO Support MXFP8 # pylint: disable=fixme if fp8 and FP8GlobalStateManager.get_fp8_recipe().mxfp8(): raise NotImplementedError("GroupedLinear does not yet support MXFP8") + # TODO Support Float8 Current Scaling # pylint: disable=fixme + if fp8 and FP8GlobalStateManager.get_fp8_recipe().float8_current_scaling(): + raise NotImplementedError("GroupedLinear does not yet support Float8 Current Scaling") # Make sure input dimensions are compatible in_features = weights[0].shape[-1] diff --git a/transformer_engine/pytorch/module/layernorm_linear.py b/transformer_engine/pytorch/module/layernorm_linear.py index 2608fedeb1..7571b17c1f 100644 --- a/transformer_engine/pytorch/module/layernorm_linear.py +++ b/transformer_engine/pytorch/module/layernorm_linear.py @@ -14,6 +14,7 @@ import transformer_engine_torch as tex +from transformer_engine.common.recipe import Recipe from .base import ( get_workspace, get_ub, @@ -55,8 +56,8 @@ restore_from_saved, ) from ..tensor.mxfp8_tensor import MXFP8Quantizer +from ..tensor.float8_tensor import Float8CurrentScalingQuantizer from ..cpu_offload import is_cpu_offload_enabled, set_offloading_param - from ..cpp_extensions import ( general_gemm, ) @@ -159,6 +160,11 @@ def forward( # Configure quantizer for normalization output with_quantized_norm = fp8 and not return_layernorm_output + # for Float8CurrentScalingQuantizer, layernorm/rmsnorm has not been fused with quantizer + # so we need to set with_quantized_norm to False + if isinstance(input_quantizer, Float8CurrentScalingQuantizer): + with_quantized_norm = False + if with_quantized_norm: if with_input_all_gather: input_quantizer.set_usage(rowwise=True, columnwise=False) @@ -210,6 +216,10 @@ def forward( with_quantized_all_gather = False if fp8: input_quantizer.set_usage(rowwise=True, columnwise=False) + # ln_out in this has two possibilities: + # 1. in FP8 low precision, the cast was done by fusing quantization into layernorm kernel + # 2. 
in high precision, then we need to cast it and then gather in FP8 + # the output ln_out_total will be in FP8, and it's a full tensor ln_out_total, _ = gather_along_first_dim( ln_out, tp_group, @@ -290,6 +300,12 @@ def forward( ln_out_total = ub_obj.get_buffer(input_quantizer) nvtx_range_push(f"{nvtx_label}.gemm") + fprop_gemm_use_split_accumulator = _2X_ACC_FPROP + if fp8: + recipe = FP8GlobalStateManager.get_fp8_recipe() + if hasattr(recipe, "fp8_gemm_fprop"): + fprop_gemm_use_split_accumulator = recipe.fp8_gemm_fprop.use_split_accumulator + out, *_, rs_out = general_gemm( weightmat, ln_out_total, @@ -297,7 +313,7 @@ def forward( quantization_params=output_quantizer, out_dtype=activation_dtype, bias=bias, - use_split_accumulator=_2X_ACC_FPROP, + use_split_accumulator=fprop_gemm_use_split_accumulator, ub=ub_obj, ub_type=ub_type, extra_output=rs_out, @@ -359,6 +375,7 @@ def forward( ctx.weight = weight ctx.activation_dtype = activation_dtype ctx.fp8 = fp8 + ctx.fp8_recipe = FP8GlobalStateManager.get_fp8_recipe() if fp8 else None ctx.fuse_wgrad_accumulation = fuse_wgrad_accumulation ctx.cpu_offloading = cpu_offloading ctx.is_first_microbatch = is_first_microbatch @@ -431,11 +448,12 @@ def backward( ctx.ub_bulk_wgrad, ] ) - and not FP8GlobalStateManager.get_fp8_recipe().delayed() + and (ctx.fp8_recipe is not None) ): - raise NotImplementedError( - "Comm+GEMM overlap is only supported with FP8 delayed scaling" - ) + if not ctx.fp8_recipe.delayed(): + raise NotImplementedError( + "Comm+GEMM overlap is only supported with FP8 delayed scaling" + ) saved_tensors = ctx.saved_tensors ( # pylint: disable=unbalanced-tuple-unpacking @@ -572,6 +590,12 @@ def backward( ctx.grad_input_quantizer.set_usage(rowwise=True, columnwise=False) nvtx_range_push(f"{nvtx_label}.dgrad_gemm") + dgrad_gemm_use_split_accumulator = _2X_ACC_DGRAD + if ctx.fp8: + recipe = ctx.fp8_recipe + if hasattr(recipe, "fp8_gemm_dgrad"): + dgrad_gemm_use_split_accumulator = recipe.fp8_gemm_dgrad.use_split_accumulator + dgrad, *_ = general_gemm( weight, grad_output, @@ -581,7 +605,7 @@ def backward( quantization_params=ctx.grad_input_quantizer, out=dgrad_bulk, out_dtype=ctx.activation_dtype, - use_split_accumulator=_2X_ACC_DGRAD, + use_split_accumulator=dgrad_gemm_use_split_accumulator, ub=ub_obj_dgrad, ub_type=ub_type_dgrad, extra_output=rs_out, @@ -643,6 +667,14 @@ def backward( # wgrad GEMM # Note: Fuse with bgrad computation if needed nvtx_range_push(f"{nvtx_label}.wgrad_gemm") + wgrad_gemm_use_split_accumulator = _2X_ACC_WGRAD + if ctx.fp8: + recipe = ctx.fp8_recipe + if hasattr(recipe, "fp8_gemm_wgrad"): + wgrad_gemm_use_split_accumulator = ( + recipe.fp8_gemm_wgrad.use_split_accumulator + ) + wgrad, grad_bias_, *_, rs_out = general_gemm( ln_out_total, grad_output, @@ -654,7 +686,7 @@ def backward( ), bias=(bias if (grad_bias is None and not ctx.fp8) else None), out=main_grad if ctx.fuse_wgrad_accumulation else None, - use_split_accumulator=_2X_ACC_WGRAD, + use_split_accumulator=wgrad_gemm_use_split_accumulator, accumulate=accumulate_wgrad_into_param_main_grad, ub=ub_obj_wgrad, ub_type=ub_type_wgrad, @@ -1139,6 +1171,16 @@ def __init__( self.bwd_ln_sm_margin = int(os.getenv("NVTE_BWD_LAYERNORM_SM_MARGIN", "0")) self.inf_ln_sm_margin = int(os.getenv("NVTE_INF_LAYERNORM_SM_MARGIN", "0")) + def set_meta_tensor(self, fwd: bool, recipe: Recipe) -> None: + """Init scales and amaxes for fwd | bwd.""" + super().set_meta_tensor(fwd, recipe) + + # customize quantizers based on each recipe & layer configs + recipe = 
FP8GlobalStateManager.get_fp8_recipe() + if recipe.float8_current_scaling(): + self._customize_quantizers_float8_current_scaling(fwd, recipe) + # elif other recipes (mxfp8, etc) + def reset_layer_norm_parameters(self) -> None: """Init LN params""" warnings.warn( @@ -1332,3 +1374,44 @@ def _get_quantizers(self, fp8_output): grad_output_quantizer, grad_input_quantizer, ) + + def _customize_quantizers_float8_current_scaling(self, fwd: bool, recipe: Recipe) -> None: + """Customize quantizers based on current scaling recipe + layernorm_linear.""" + assert ( + recipe.float8_current_scaling() + ), "current scaling recipe quantizer customization here" + if fwd: + # set configs about amax epsilon and power_2_scale + self.quantizers["scaling_fwd"][ + tex.FP8FwdTensors.GEMM1_INPUT + ].force_pow_2_scales = recipe.fp8_quant_fwd_inp.power_2_scale + self.quantizers["scaling_fwd"][ + tex.FP8FwdTensors.GEMM1_INPUT + ].amax_epsilon = recipe.fp8_quant_fwd_inp.amax_epsilon + # also set weight quantizer with same amax_epsilon & power_2_scale + self.quantizers["scaling_fwd"][ + tex.FP8FwdTensors.GEMM1_WEIGHT + ].force_pow_2_scales = recipe.fp8_quant_fwd_weight.power_2_scale + self.quantizers["scaling_fwd"][ + tex.FP8FwdTensors.GEMM1_WEIGHT + ].amax_epsilon = recipe.fp8_quant_fwd_weight.amax_epsilon + # parallel related + if self.sequence_parallel and self.parallel_mode == "column": + # set input_quantizer with amax reduction TP group + self.quantizers["scaling_fwd"][ + tex.FP8FwdTensors.GEMM1_INPUT + ].with_amax_reduction = True + self.quantizers["scaling_fwd"][ + tex.FP8FwdTensors.GEMM1_INPUT + ].amax_reduction_group = self.tp_group + self.quantizers["scaling_fwd"][ + tex.FP8FwdTensors.GEMM1_INPUT + ].amax_reduction_size = self.tp_size + else: + # set grad_output_quantizer with amax epsilon and power_2_scale (no amax reduction here) + self.quantizers["scaling_bwd"][ + tex.FP8BwdTensors.GRAD_OUTPUT1 + ].force_pow_2_scales = recipe.fp8_quant_bwd_grad.power_2_scale + self.quantizers["scaling_bwd"][ + tex.FP8BwdTensors.GRAD_OUTPUT1 + ].amax_epsilon = recipe.fp8_quant_bwd_grad.amax_epsilon diff --git a/transformer_engine/pytorch/module/layernorm_mlp.py b/transformer_engine/pytorch/module/layernorm_mlp.py index f4ee0a1155..9bb76cb391 100644 --- a/transformer_engine/pytorch/module/layernorm_mlp.py +++ b/transformer_engine/pytorch/module/layernorm_mlp.py @@ -15,6 +15,7 @@ import transformer_engine_torch as tex +from transformer_engine.common.recipe import Recipe from .base import ( get_workspace, _ub_communicators, @@ -59,7 +60,7 @@ from ..tensor.mxfp8_tensor import MXFP8Quantizer from ._common import apply_normalization, _fix_gathered_fp8_transpose from ..cpu_offload import is_cpu_offload_enabled, set_offloading_param - +from ..tensor.float8_tensor import Float8CurrentScalingQuantizer from ..tensor.quantized_tensor import ( QuantizedTensor, Quantizer, @@ -73,17 +74,53 @@ __all__ = ["LayerNormMLP"] -def _act_func(activation: str): - funcs = { - "gelu": (tex.gelu, tex.dgelu, tex.dbias_dgelu), - "relu": (tex.relu, tex.drelu, tex.dbias_drelu), +def _get_act_func_supported_list(recipe: Optional[Recipe] = None): + if recipe is None: + # bf16 (recipe is None): [tex.dbias_dgelu, tex.dbias_drelu, tex.dbias_dqgelu, tex.dbias_dsrelu] + return { + "gelu": (tex.gelu, tex.dgelu, tex.dbias_dgelu), + "relu": (tex.relu, tex.drelu, tex.dbias_drelu), + "geglu": (tex.geglu, tex.dgeglu, None), + "reglu": (tex.reglu, tex.dreglu, None), + "swiglu": (tex.swiglu, tex.dswiglu, None), + "qgelu": (tex.qgelu, tex.dqgelu, tex.dbias_dqgelu), + 
"qgeglu": (tex.qgeglu, tex.dqgeglu, None), + "srelu": (tex.srelu, tex.dsrelu, tex.dbias_dsrelu), + } + if recipe.delayed() or recipe.mxfp8(): + # Delayed scaling, fusion supported list: [tex.dbias_dgelu, tex.dbias_drelu, tex.dbias_dqgelu, tex.dbias_dsrelu] + # MXFP8: [tex.dbias_dgelu, tex.dbias_drelu, tex.dbias_dqgelu, tex.dbias_dsrelu] + return { + "gelu": (tex.gelu, tex.dgelu, tex.dbias_dgelu), + "relu": (tex.relu, tex.drelu, tex.dbias_drelu), + "geglu": (tex.geglu, tex.dgeglu, None), + "reglu": (tex.reglu, tex.dreglu, None), + "swiglu": (tex.swiglu, tex.dswiglu, None), + "qgelu": (tex.qgelu, tex.dqgelu, tex.dbias_dqgelu), + "qgeglu": (tex.qgeglu, tex.dqgeglu, None), + "srelu": (tex.srelu, tex.dsrelu, tex.dbias_dsrelu), + } + # no activation fusion written yet + # Per-tensor current scaling: [] + return { + "gelu": (tex.gelu, tex.dgelu, None), + "relu": (tex.relu, tex.drelu, None), "geglu": (tex.geglu, tex.dgeglu, None), "reglu": (tex.reglu, tex.dreglu, None), "swiglu": (tex.swiglu, tex.dswiglu, None), - "qgelu": (tex.qgelu, tex.dqgelu, tex.dbias_dqgelu), + "qgelu": (tex.qgelu, tex.dqgelu, None), "qgeglu": (tex.qgeglu, tex.dqgeglu, None), - "srelu": (tex.srelu, tex.dsrelu, tex.dbias_dsrelu), + "srelu": (tex.srelu, tex.dsrelu, None), } + + +def _act_func(activation: str, recipe: Optional[Recipe] = None): + # based on each quantization mode, we have different kernel fusion supported: + # bf16 (recipe is None): [tex.dbias_dgelu, tex.dbias_drelu, tex.dbias_dqgelu, tex.dbias_dsrelu] + # Delayed scaling, fusion supported list: [tex.dbias_dgelu, tex.dbias_drelu, tex.dbias_dqgelu, tex.dbias_dsrelu] + # MXFP8: [tex.dbias_dgelu, tex.dbias_drelu, tex.dbias_dqgelu, tex.dbias_dsrelu] + # Per-tensor current scaling: [] + funcs = _get_act_func_supported_list(recipe) if activation not in funcs: raise NotImplementedError("Activation type " + activation + " is not supported!") return funcs[activation] @@ -161,7 +198,9 @@ def forward( "Comm+GEMM overlap is only supported with FP8 delayed scaling" ) - activation_func = _act_func(activation)[0] + activation_func = _act_func( + activation, FP8GlobalStateManager.get_fp8_recipe() if fp8 else None + )[0] device = inp.device # Cast for native AMP @@ -175,6 +214,8 @@ def forward( # for return_layernorm_output: layernorm output = High precision, then cast to FP8 # high precision layernorm output and output of the linear are returned with_quantized_norm = fp8 and not return_layernorm_output + if isinstance(fc1_input_quantizer, Float8CurrentScalingQuantizer): + with_quantized_norm = False tp_world_size = get_distributed_world_size(tp_group) ub_overlap_ag = ub_overlap_ag and is_grad_enabled and not return_layernorm_output @@ -220,6 +261,8 @@ def forward( zero_centered_gamma, ) + ln_out_return = ln_out if return_layernorm_output else None + # Prepare GEMM input # Note: Cast to expected dtype and perform tensor-parallel communication ln_out_gathered = False @@ -229,6 +272,10 @@ def forward( with_quantized_all_gather = False if fp8: fc1_input_quantizer.set_usage(rowwise=True, columnwise=False) + # ln_out in this has two possibilities: + # 1. in FP8 low precision, the cast was done by fusing quantization into layernorm kernel + # 2. 
in high precision, then we need to cast it and then gather in FP8 + # the output ln_out_total will be in FP8, and it's a full tensor ln_out_total, _ = gather_along_first_dim( ln_out, tp_group, @@ -240,26 +287,19 @@ def forward( if ub_overlap_ag: ln_out_total = ub_obj_lnout.get_buffer(fc1_input_quantizer, False) else: + if fp8: + if not isinstance(ln_out, QuantizedTensor): + fc1_input_quantizer.set_usage( + rowwise=True, columnwise=backwards_needs_fc1_input + ) + ln_out = fc1_input_quantizer(ln_out) + elif backwards_needs_fc1_input: + ln_out.update_usage(rowwise_usage=True, columnwise_usage=True) + # here ln_out is in FP8 low precision, the cast was either done by fc1_input_quantizer + # or fused into the layernorm kernel + # ln_out_total represents the full fp8 tensor, in this case, it's the same as ln_out ln_out_total = ln_out - # If residual connection is after LN, we need `ln_out` - # tensor in higher precision, this comes at the cost - # of an extra fp8 cast. - ln_out_return = None - if return_layernorm_output: - ln_out_return = ln_out_total if return_layernorm_output_gathered else ln_out - if fp8 and not with_quantized_all_gather: - ln_out_total = fc1_input_quantizer(ln_out_total) - if ln_out_gathered: - rank = torch.distributed.get_rank(tp_group) - slice_start = rank * ln_out.size(0) - slice_end = (rank + 1) * ln_out.size(0) - ln_out = ln_out_total[ - slice_start:slice_end, ... - ] # TODO(pgadzinski) - check this # pylint: disable=fixme - else: - ln_out = ln_out_total - # Cast weights to expected dtype fc1_weight_final = fc1_weight fc2_weight_final = fc2_weight @@ -459,6 +499,7 @@ def forward( ctx.save_for_backward(*tensors_to_save) ctx.tensor_objects = tensor_objects + ctx.fp8_recipe = FP8GlobalStateManager.get_fp8_recipe() if fp8 else None ctx.grad_fc1_output_quantizer = grad_fc1_output_quantizer ctx.grad_fc2_output_quantizer = grad_fc2_output_quantizer ctx.grad_input_quantizer = grad_input_quantizer @@ -546,11 +587,12 @@ def backward( ctx.ub_bulk_wgrad, ] ) - and not FP8GlobalStateManager.get_fp8_recipe().delayed() + and (ctx.fp8_recipe is not None) ): - raise NotImplementedError( - "Comm+GEMM overlap is only supported with FP8 delayed scaling" - ) + if not ctx.fp8_recipe.delayed(): + raise NotImplementedError( + "Comm+GEMM overlap is only supported with FP8 delayed scaling" + ) saved_tensors = ctx.saved_tensors ( # pylint: disable=unbalanced-tuple-unpacking @@ -733,22 +775,36 @@ def backward( fc1_bias_grad, dact = bgrad_dgelu_fused(fc2_dgrad, fc1_out_without_bias, fc1_bias) if ctx.grad_fc1_output_quantizer is not None: dact = ctx.grad_fc1_output_quantizer(dact) - elif _act_func(ctx.activation)[2] is not None and ctx.fp8: + elif ( + _act_func(ctx.activation, ctx.fp8_recipe if ctx.fp8 else None)[2] is not None + and ctx.fp8 + ): # Fusion: gemm, bias + gelu + quantize - dbias_dact_quantize_func = _act_func(ctx.activation)[2] + dbias_dact_quantize_func = _act_func( + ctx.activation, ctx.fp8_recipe if ctx.fp8 else None + )[2] fc1_bias_grad, dact = dbias_dact_quantize_func( fc2_dgrad, fc1_out.to(ctx.activation_dtype), ctx.grad_fc1_output_quantizer ) # quantize bgrad gelu fused else: # Fusion: gemm + gelu, if not fc2_dgrad_gemm_gelu_fusion: - activation_func_bwd = _act_func(ctx.activation)[1] + activation_func_bwd = _act_func( + ctx.activation, ctx.fp8_recipe if ctx.fp8 else None + )[1] dact = activation_func_bwd( fc2_dgrad, fc1_out.to(ctx.activation_dtype), None ) # activation in high precision if ctx.fp8: - fc1_bias_grad, dact = tex.bgrad_quantize(dact, ctx.grad_fc1_output_quantizer) 
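# Illustrative sketch (hedged, not part of the patch): how the recipe-aware _act_func()
# tuple is consumed at this step of the backward pass. Under Float8CurrentScaling the
# fused dbias + dact + quantize entry is None, so the fallback computes dact in high
# precision, reduces the bias gradient separately, and only then casts with the
# current-scaling quantizer. The names fc2_dgrad, fc1_out, and
# ctx.grad_fc1_output_quantizer refer to the surrounding backward code.
#
#     act_fn, dact_fn, fused_dbias_dact_fn = _act_func(ctx.activation, ctx.fp8_recipe)
#     if ctx.fp8 and fused_dbias_dact_fn is None:
#         dact = dact_fn(fc2_dgrad, fc1_out.to(ctx.activation_dtype), None)  # high-precision dact
#         fc1_bias_grad = dact.view(-1, dact.shape[-1]).sum(dim=0)           # unfused bias gradient
#         dact = ctx.grad_fc1_output_quantizer(dact)                         # FP8 current-scaling cast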
+ # TODO zhongboz: per-tensor current scaling has no bgrad fusion for now + if isinstance(ctx.grad_fc1_output_quantizer, Float8CurrentScalingQuantizer): + fc1_bias_grad = dact.view(-1, dact.shape[-1]).sum(dim=0) + dact = ctx.grad_fc1_output_quantizer(dact) + else: + fc1_bias_grad, dact = tex.bgrad_quantize( + dact, ctx.grad_fc1_output_quantizer + ) else: fuse_gemm_and_bias_fc1_wgrad = ( True # fc1_bias_grad is computed later, fused with wgrad gemm for the FC1 @@ -1286,6 +1342,15 @@ def __init__( self.bwd_ln_sm_margin = int(os.getenv("NVTE_BWD_LAYERNORM_SM_MARGIN", "0")) self.inf_ln_sm_margin = int(os.getenv("NVTE_INF_LAYERNORM_SM_MARGIN", "0")) + def set_meta_tensor(self, fwd: bool, recipe: Recipe) -> None: + """Init scales and amaxes for fwd | bwd.""" + super().set_meta_tensor(fwd, recipe) + + # customize quantizers based on each recipe & layer configs + if FP8GlobalStateManager.get_fp8_recipe().float8_current_scaling(): + self._customize_quantizers_float8_current_scaling(fwd, recipe) + # elif for other recipes (mxfp8, etc.) + def reset_layer_norm_parameters(self) -> None: """Init LN params""" warnings.warn( @@ -1494,3 +1559,76 @@ def _get_quantizers(self): grad_fc2_output_quantizer, grad_input_quantizer, ) + + def _customize_quantizers_float8_current_scaling(self, fwd: bool, recipe: Recipe) -> None: + """Customize quantizers based on current scaling recipe + layernorm_mlp.""" + assert ( + recipe.float8_current_scaling() + ), "current scaling recipe quantizer customization here" + if fwd: + # fc1_input_quantizer: set configs about amax epsilon and power_2_scale + self.quantizers["scaling_fwd"][ + tex.FP8FwdTensors.GEMM1_INPUT + ].force_pow_2_scales = recipe.fp8_quant_fwd_inp.power_2_scale + self.quantizers["scaling_fwd"][ + tex.FP8FwdTensors.GEMM1_INPUT + ].amax_epsilon = recipe.fp8_quant_fwd_inp.amax_epsilon + # fc2_input_quantizer + self.quantizers["scaling_fwd"][ + tex.FP8FwdTensors.GEMM2_INPUT + ].force_pow_2_scales = recipe.fp8_quant_fwd_inp.power_2_scale + self.quantizers["scaling_fwd"][ + tex.FP8FwdTensors.GEMM2_INPUT + ].amax_epsilon = recipe.fp8_quant_fwd_inp.amax_epsilon + # fc1_weight_quantizer: also set numerical configs about weight + self.quantizers["scaling_fwd"][ + tex.FP8FwdTensors.GEMM1_WEIGHT + ].force_pow_2_scales = recipe.fp8_quant_fwd_weight.power_2_scale + self.quantizers["scaling_fwd"][ + tex.FP8FwdTensors.GEMM1_WEIGHT + ].amax_epsilon = recipe.fp8_quant_fwd_weight.amax_epsilon + # fc2_weight_quantizer + self.quantizers["scaling_fwd"][ + tex.FP8FwdTensors.GEMM2_WEIGHT + ].force_pow_2_scales = recipe.fp8_quant_fwd_weight.power_2_scale + self.quantizers["scaling_fwd"][ + tex.FP8FwdTensors.GEMM2_WEIGHT + ].amax_epsilon = recipe.fp8_quant_fwd_weight.amax_epsilon + # parallel related + if self.sequence_parallel and self.set_parallel_mode: + # fc1_input_quantizer: customize input_quantizer with amax reduction TP group, column parallel + sequence parallel here + self.quantizers["scaling_fwd"][ + tex.FP8FwdTensors.GEMM1_INPUT + ].with_amax_reduction = True + self.quantizers["scaling_fwd"][ + tex.FP8FwdTensors.GEMM1_INPUT + ].amax_reduction_group = self.tp_group + self.quantizers["scaling_fwd"][ + tex.FP8FwdTensors.GEMM1_INPUT + ].amax_reduction_size = self.tp_size + else: + # grad_fc2_output_quantizer: set configs about amax epsilon and power_2_scale for grad_fc2_output_quantizer + self.quantizers["scaling_bwd"][ + tex.FP8BwdTensors.GRAD_OUTPUT1 + ].force_pow_2_scales = recipe.fp8_quant_bwd_grad.power_2_scale + self.quantizers["scaling_bwd"][ + 
tex.FP8BwdTensors.GRAD_OUTPUT1 + ].amax_epsilon = recipe.fp8_quant_bwd_grad.amax_epsilon + # grad_fc1_output_quantizer: also set numerical configs for grad_fc1_output_quantizer + self.quantizers["scaling_bwd"][ + tex.FP8BwdTensors.GRAD_INPUT1 + ].force_pow_2_scales = recipe.fp8_quant_bwd_grad.power_2_scale + self.quantizers["scaling_bwd"][ + tex.FP8BwdTensors.GRAD_INPUT1 + ].amax_epsilon = recipe.fp8_quant_bwd_grad.amax_epsilon + if self.sequence_parallel and self.set_parallel_mode: + # grad_fc2_output_quantizer: customize grad_output_quantizer with amax reduction TP group, row parallel + sequence parallel here + self.quantizers["scaling_bwd"][ + tex.FP8BwdTensors.GRAD_OUTPUT1 + ].with_amax_reduction = True + self.quantizers["scaling_bwd"][ + tex.FP8BwdTensors.GRAD_OUTPUT1 + ].amax_reduction_group = self.tp_group + self.quantizers["scaling_bwd"][ + tex.FP8BwdTensors.GRAD_OUTPUT1 + ].amax_reduction_size = self.tp_size diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index f07cfb487b..675a8f929b 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -11,6 +11,7 @@ import transformer_engine_torch as tex +from transformer_engine.common.recipe import Recipe from .base import ( get_workspace, get_ub, @@ -228,6 +229,12 @@ def forward( inputmat_total = ub_obj.get_buffer(input_quantizer) nvtx_range_push(f"{nvtx_label}.gemm") + fprop_gemm_use_split_accumulator = _2X_ACC_FPROP + if fp8: + recipe = FP8GlobalStateManager.get_fp8_recipe() + if hasattr(recipe, "fp8_gemm_fprop"): + fprop_gemm_use_split_accumulator = recipe.fp8_gemm_fprop.use_split_accumulator + out, *_, rs_out = general_gemm( weightmat, inputmat_total, @@ -235,7 +242,7 @@ def forward( quantization_params=output_quantizer, out_dtype=out_dtype, bias=bias, - use_split_accumulator=_2X_ACC_FPROP, + use_split_accumulator=fprop_gemm_use_split_accumulator, ub=ub_obj, ub_type=ub_type, extra_output=rs_out, @@ -277,6 +284,7 @@ def forward( ctx.tensor_objects = tensor_objects ctx.activation_dtype = activation_dtype + ctx.fp8_recipe = FP8GlobalStateManager.get_fp8_recipe() if fp8 else None ctx.fp8 = fp8 ctx.input_quantizer = input_quantizer ctx.grad_output_quantizer = grad_output_quantizer @@ -344,11 +352,12 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], ctx.ub_bulk_wgrad, ] ) - and not FP8GlobalStateManager.get_fp8_recipe().delayed() + and (ctx.fp8_recipe is not None) ): - raise NotImplementedError( - "Comm+GEMM overlap is only supported with FP8 delayed scaling" - ) + if not ctx.fp8_recipe.delayed(): + raise NotImplementedError( + "Comm+GEMM overlap is only supported with FP8 delayed scaling" + ) saved_tensors = ctx.saved_tensors inputmat, weight_fp8, weight, bias = ( # pylint: disable=unbalanced-tuple-unpacking @@ -483,6 +492,14 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], # dgrad GEMM nvtx_range_push(f"{nvtx_label}.dgrad_gemm") + dgrad_gemm_use_split_accumulator = _2X_ACC_DGRAD + if ctx.fp8: + recipe = ctx.fp8_recipe + if hasattr(recipe, "fp8_gemm_dgrad"): + dgrad_gemm_use_split_accumulator = ( + recipe.fp8_gemm_dgrad.use_split_accumulator + ) + dgrad, *_, rs_out = general_gemm( weight_fp8, grad_output, @@ -492,7 +509,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], quantization_params=ctx.grad_input_quantizer, out=dgrad_bulk, out_dtype=ctx.activation_dtype, - use_split_accumulator=_2X_ACC_DGRAD, + 
use_split_accumulator=dgrad_gemm_use_split_accumulator, ub=ub_obj_dgrad, ub_type=ub_type_dgrad, extra_output=rs_out, @@ -551,6 +568,14 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], # wgrad GEMM # Note: Fuse with bgrad computation if needed nvtx_range_push(f"{nvtx_label}.wgrad_gemm") + wgrad_gemm_use_split_accumulator = _2X_ACC_WGRAD + if ctx.fp8: + recipe = ctx.fp8_recipe + if hasattr(recipe, "fp8_gemm_wgrad"): + wgrad_gemm_use_split_accumulator = ( + recipe.fp8_gemm_wgrad.use_split_accumulator + ) + wgrad, grad_bias_, _, rs_out = general_gemm( inputmat_total, grad_output, @@ -562,7 +587,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], ), bias=(bias if (grad_bias is None and not ctx.fp8) else None), out=main_grad if ctx.fuse_wgrad_accumulation else None, - use_split_accumulator=_2X_ACC_WGRAD, + use_split_accumulator=wgrad_gemm_use_split_accumulator, accumulate=accumulate_wgrad_into_param_main_grad, ub=ub_obj_wgrad, ub_type=ub_type_wgrad, @@ -955,6 +980,16 @@ def __init__( else: self.gemm_bias_unfused_add = False + def set_meta_tensor(self, fwd: bool, recipe: Recipe) -> None: + """Init scales and amaxes for fwd | bwd.""" + super().set_meta_tensor(fwd, recipe) + + # customize quantizers based on each recipe & layer configs + recipe = FP8GlobalStateManager.get_fp8_recipe() + if recipe.float8_current_scaling(): + self._customize_quantizers_float8_current_scaling(fwd, recipe) + # elif for other recipes (mxfp8, etc.) + def reset_parameters(self, defer_init=False): super().reset_parameters(defer_init=defer_init) @@ -1118,3 +1153,56 @@ def _get_quantizers(self, fp8_output, fp8_grad): grad_output_quantizer, grad_input_quantizer, ) + + def _customize_quantizers_float8_current_scaling(self, fwd: bool, recipe: Recipe) -> None: + """Customize quantizers based on current scaling recipe + linear.""" + assert ( + recipe.float8_current_scaling() + ), "current scaling recipe quantizer customization here" + if fwd: + # set configs about amax epsilon and power_2_scale + self.quantizers["scaling_fwd"][ + tex.FP8FwdTensors.GEMM1_INPUT + ].force_pow_2_scales = recipe.fp8_quant_fwd_inp.power_2_scale + self.quantizers["scaling_fwd"][ + tex.FP8FwdTensors.GEMM1_INPUT + ].amax_epsilon = recipe.fp8_quant_fwd_inp.amax_epsilon + # also set weight quantizer with same amax_epsilon & power_2_scale + self.quantizers["scaling_fwd"][ + tex.FP8FwdTensors.GEMM1_WEIGHT + ].force_pow_2_scales = recipe.fp8_quant_fwd_weight.power_2_scale + self.quantizers["scaling_fwd"][ + tex.FP8FwdTensors.GEMM1_WEIGHT + ].amax_epsilon = recipe.fp8_quant_fwd_weight.amax_epsilon + # paralle related + if self.sequence_parallel and self.parallel_mode == "column": + # customize input_quantizer with amax reduction TP group + self.quantizers["scaling_fwd"][ + tex.FP8FwdTensors.GEMM1_INPUT + ].with_amax_reduction = True + self.quantizers["scaling_fwd"][ + tex.FP8FwdTensors.GEMM1_INPUT + ].amax_reduction_group = self.tp_group + self.quantizers["scaling_fwd"][ + tex.FP8FwdTensors.GEMM1_INPUT + ].amax_reduction_size = self.tp_size + else: + # set grad_output_quantizer with amax epsilon and power_2_scale + self.quantizers["scaling_bwd"][ + tex.FP8BwdTensors.GRAD_OUTPUT1 + ].force_pow_2_scales = recipe.fp8_quant_bwd_grad.power_2_scale + self.quantizers["scaling_bwd"][ + tex.FP8BwdTensors.GRAD_OUTPUT1 + ].amax_epsilon = recipe.fp8_quant_bwd_grad.amax_epsilon + # parallel related + if self.sequence_parallel and self.parallel_mode == "row": + # customize grad_output_quantizer with amax 
reduction TP group + self.quantizers["scaling_bwd"][ + tex.FP8BwdTensors.GRAD_OUTPUT1 + ].with_amax_reduction = True + self.quantizers["scaling_bwd"][ + tex.FP8BwdTensors.GRAD_OUTPUT1 + ].amax_reduction_group = self.tp_group + self.quantizers["scaling_bwd"][ + tex.FP8BwdTensors.GRAD_OUTPUT1 + ].amax_reduction_size = self.tp_size diff --git a/transformer_engine/pytorch/tensor/float8_tensor.py b/transformer_engine/pytorch/tensor/float8_tensor.py index 5944039cf0..178401f6a6 100644 --- a/transformer_engine/pytorch/tensor/float8_tensor.py +++ b/transformer_engine/pytorch/tensor/float8_tensor.py @@ -14,6 +14,7 @@ from ..utils import devices_match, non_tn_fp8_gemm_supported from ._internal.float8_tensor_base import Float8TensorBase, _FromFloat8Func from .quantized_tensor import QuantizedTensor, Quantizer, _IdentityFunc +from ..constants import dist_group_type aten = torch.ops.aten @@ -166,6 +167,167 @@ def create_tensor_from_data( ) + +class Float8CurrentScalingQuantizer(Quantizer): + """Builder class for FP8 tensors with per-tensor current scaling + + High-precision tensors (e.g. in FP32 or BF16) are quantized by + multiplying with a scaling factor and casting to FP8. The max-abs + value ("amax") in the tensor is computed directly by scanning the input + high-precision tensor, without the need for an amax history window. + + Unlike delayed scaling, scale and amax tensors are not needed to initialize the + quantizer, because they are simply GPU buffers that are filled by the current + scaling quantization kernels rather than holding values taken from a delayed scaling + history window. Therefore, a device parameter is needed for tensor allocation. + + Both Float8CurrentScalingQuantizer and Float8Quantizer produce Float8Tensor, + because both implement per-tensor scaling, i.e. one scaling factor per tensor. 
+ + """ + + """Scaling factor to multiply when quantizing to FP8""" + scale: torch.Tensor + """Max-abs value from last FP8 cast""" + amax: torch.Tensor + """FP8 datatype""" + dtype: TE_DType + """amax reduction options""" + with_amax_reduction: bool + amax_reduction_group: Optional[dist_group_type] + amax_reduction_size: Optional[int] + """Options about how to quantize the tensor""" + force_pow_2_scales: bool + amax_epsilon: float + + def __init__( + self, + fp8_dtype: TE_DType, + device: torch.device, + *, + rowwise: bool = True, + columnwise: bool = True, + with_amax_reduction: bool = False, + amax_reduction_group: Optional[dist_group_type] = None, + amax_reduction_size: Optional[int] = 1, + force_pow_2_scales: bool = False, + amax_epsilon: float = 0.0, + ) -> None: + super().__init__(rowwise=rowwise, columnwise=columnwise) + self.scale = torch.empty(1, dtype=torch.float32, device=device) + self.amax = torch.empty(1, dtype=torch.float32, device=device) + self.dtype = fp8_dtype + self.with_amax_reduction = with_amax_reduction + self.amax_reduction_group = amax_reduction_group + self.amax_reduction_size = amax_reduction_size + self.force_pow_2_scales = force_pow_2_scales + self.amax_epsilon = amax_epsilon + + def update_quantized( + self, + src: torch.Tensor, + dst: QuantizedTensor, + *, + noop_flag: Optional[torch.Tensor] = None, + ) -> QuantizedTensor: + if not isinstance(dst, Float8Tensor): + raise ValueError("Float8CurrentScalingQuantizer can only update Float8Tensor") + + # Make sure input is in expected format + if not devices_match(src.device, dst.device): + src = src.to(device=dst.device) + if not src.is_contiguous(): + src = src.contiguous() + + # Launch cast kernel + tex.quantize(src, self, dst, noop_flag) + + # Update FP8 dtype + dst._fp8_dtype = self.dtype + + return dst + + def make_empty( + self, + shape: Iterable[int], + *, + dtype: torch.dtype = torch.float32, + device: Optional[torch.device] = None, + requires_grad: bool = False, + ) -> Float8Tensor: + + # Canonicalize tensor attributes + if device is None: + device = torch.device("cuda") + + # Allocate FP8 data + data = torch.empty(shape, dtype=torch.uint8, device=device) + + # Allocate FP8 data transpose if needed + data_transpose = None + if self.columnwise_usage: + inner_dim = data.size(-1) + data_transpose = torch.empty( + inner_dim, + data.numel() // inner_dim, + dtype=torch.uint8, + device=device, + ) + + # Construct FP8 tensor + return Float8Tensor( + shape=shape, + dtype=dtype, + data=data, + fp8_scale_inv=torch.empty(1, dtype=torch.float32, device=device), + fp8_dtype=self.dtype, + requires_grad=requires_grad, + data_transpose=data_transpose, + quantizer=self, + ) + + def calibrate(self, tensor: torch.Tensor) -> None: + # current scaling don't need to calibrate + return + + def create_tensor_from_data( + self, + data: torch.Tensor, + fake_dtype=torch.float32, + requires_grad: bool = False, + internal: bool = False, + ): + """ + Create Float8Tensor from raw uint8 data, unlike delayed scaling, + self.scale doesn't mean anything, so we are simply creating empty scale_inv + """ + assert data.dtype in [ + torch.uint8, + torch.float8_e4m3fn, + torch.float8_e4m3fnuz, + torch.float8_e5m2, + torch.float8_e5m2fnuz, + ] + if internal: + return Float8TensorBase( + data=data, + fp8_scale_inv=torch.empty(1, dtype=torch.float32, device=data.device), + fp8_dtype=self.dtype, + requires_grad=requires_grad, + data_transpose=None, + quantizer=self, + ) + return Float8Tensor( + shape=data.shape, + dtype=fake_dtype, + data=data, + 
fp8_scale_inv=torch.empty(1, dtype=torch.float32, device=data.device), + fp8_dtype=self.dtype, + requires_grad=requires_grad, + data_transpose=None, + quantizer=self, + ) + + class Float8Tensor(Float8TensorBase, QuantizedTensor): """Experimental tensor class with FP8 data @@ -192,7 +354,7 @@ class Float8Tensor(Float8TensorBase, QuantizedTensor): FP8 format. data_transpose: torch.Tensor, optional FP8 transpose data in a uint8 tensor - quantizer: Float8Quantizer, optional + quantizer: Float8Quantizer or Float8CurrentScalingQuantizer, optional Builder class for FP8 tensors """ @@ -229,10 +391,9 @@ def _get_quantizer(self) -> Quantizer: """ if self._quantizer is not None: return self._quantizer - return Float8Quantizer( - scale=torch.reciprocal(self._scale_inv), - amax=torch.empty(1, dtype=torch.float32, device=self.device), - fp8_dtype=self._fp8_dtype, + # The tensor's quantizer is no longer necessarily a Float8Quantizer (delayed scaling), so we cannot reconstruct one here + raise ValueError( + "Float8Tensor's quantizer is None; cannot infer a quantizer from the tensor alone" ) def quantize_(
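# Illustrative sketch (hedged, not part of the patch): how the Float8CurrentScalingQuantizer
# options introduced above (amax_epsilon, force_pow_2_scales, with_amax_reduction /
# amax_reduction_group) conceptually shape the scale computed from the current tensor's
# amax. This is a plain-PyTorch reference, not the nvte_compute_amax /
# nvte_compute_scale_from_amax kernels; the helper name `current_scaling_scale` is
# hypothetical.

import torch
import torch.distributed as dist


def current_scaling_scale(
    x: torch.Tensor,
    fp8_max: float,                 # e.g. torch.finfo(torch.float8_e4m3fn).max
    amax_epsilon: float = 0.0,
    force_pow_2_scales: bool = False,
    amax_reduction_group=None,      # optional torch.distributed process group
) -> torch.Tensor:
    # Per-tensor amax is taken directly from the current tensor (no history window).
    amax = x.abs().amax().to(torch.float32)
    # With sequence parallelism each rank only sees a shard, so the amax is reduced
    # across the tensor-parallel group to get one consistent per-tensor scale.
    if amax_reduction_group is not None:
        dist.all_reduce(amax, op=dist.ReduceOp.MAX, group=amax_reduction_group)
    # Guard tiny amax values with the configured epsilon.
    clamped = torch.clamp(amax, min=amax_epsilon)
    # Degenerate amax (zero or inf) falls back to a scale of 1.
    scale = torch.where(
        (clamped == 0.0) | torch.isinf(clamped),
        torch.ones_like(clamped),
        fp8_max / clamped,
    )
    # If the division overflowed in FP32, clip to the largest representable value.
    scale = torch.where(
        torch.isinf(scale), torch.full_like(scale, torch.finfo(torch.float32).max), scale
    )
    if force_pow_2_scales:
        # Round down to a power of two so that scale and 1/scale round-trip exactly.
        scale = torch.exp2(torch.floor(torch.log2(scale)))
    return scale


# Usage note (single GPU, no amax reduction): quantizing a BF16 activation to E4M3 would
# roughly amount to s = current_scaling_scale(x, torch.finfo(torch.float8_e4m3fn).max),
# x_fp8 = (x * s).to(torch.float8_e4m3fn), with 1/s stored as the tensor's scale_inv.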