From c2cf9730ae695e83e27cbb304d56fe1f19a06659 Mon Sep 17 00:00:00 2001 From: Vasiliy Kuznetsov Date: Mon, 24 Jun 2024 10:09:53 -0700 Subject: [PATCH] fix mx triton kernel after PyTorch triton pin change (#431) Summary: Triton pin updated recently: https://github.com/pytorch/pytorch/pull/126098 In the new triton version, functions can only access global variables of type `tl.constexpr`. Due to the current structure of the code and the fact that these constants are also used by non-triton programs, I think the best thing to do is to just stop using globals in the MX triton kernel. The PR lifts all of these constants to kernel function arguments. Test Plan: ``` pytest test/prototype/mx_formats/test_custom_cast.py ``` Reviewers: Subscribers: Tasks: Tags: --- .github/workflows/regression_test.yml | 2 +- torchao/prototype/mx_formats/custom_cast.py | 143 +++++++++++++++++--- 2 files changed, 128 insertions(+), 17 deletions(-) diff --git a/.github/workflows/regression_test.yml b/.github/workflows/regression_test.yml index 63c66e471..21ad6535a 100644 --- a/.github/workflows/regression_test.yml +++ b/.github/workflows/regression_test.yml @@ -33,7 +33,7 @@ jobs: gpu-arch-version: "12.1" - name: CUDA Nightly runs-on: linux.g5.12xlarge.nvidia.gpu - torch-spec: '--pre torch==2.5.0.dev20240620+cu121 --index-url https://download.pytorch.org/whl/nightly/cu121' + torch-spec: '--pre torch --index-url https://download.pytorch.org/whl/nightly/cu121' gpu-arch-type: "cuda" gpu-arch-version: "12.1" - name: CPU 2.2.2 diff --git a/torchao/prototype/mx_formats/custom_cast.py b/torchao/prototype/mx_formats/custom_cast.py index f14deba9f..00082acad 100644 --- a/torchao/prototype/mx_formats/custom_cast.py +++ b/torchao/prototype/mx_formats/custom_cast.py @@ -107,7 +107,19 @@ def f6_e3m2_unpacked_to_f32(x: torch.Tensor): import triton.language as tl @triton.jit - def _fp4_packed_to_bf16(x_packed): + def _fp4_packed_to_bf16( + x_packed, + sign_mask_f4, + mantissa_mask_f4, + mbits_f4_e2m1, + ebits_f4_e2m1, + f4_e2m1_exp_bias, + mbits_f32, + ebits_f32, + f32_exp_bias, + zero_bits_f32, + zero_point_five_bits_f32, + ): """ Input: a tensor of packed fp4 values Output: a tensor of bfloat16 values @@ -123,7 +135,7 @@ def _fp4_packed_to_bf16(x_packed): # output = x_unpacked.to(tl.float32) # save the sign - sign_f4 = x & SIGN_MASK_F4 + sign_f4 = x & sign_mask_f4 # set everything to positive, will add sign back at the end x_pos = x ^ sign_f4 @@ -138,25 +150,25 @@ def _fp4_packed_to_bf16(x_packed): denormal_mask = x_pos == 1 # calculate the new exponent and shift it to bits 2:9 of the result - exp_biased_f4 = x_pos >> MBITS_F4_E2M1 - exp_biased_f32 = exp_biased_f4 - F4_E2M1_EXP_BIAS + F32_EXP_BIAS - exp_biased_f32 = exp_biased_f32.to(tl.int32) << MBITS_F32 + exp_biased_f4 = x_pos >> mbits_f4_e2m1 + exp_biased_f32 = exp_biased_f4 - f4_e2m1_exp_bias + f32_exp_bias + exp_biased_f32 = exp_biased_f32.to(tl.int32) << mbits_f32 # shift the mantissa to bits 10:32 of the result - mantissa_f4 = x_pos & MANTISSA_MASK_F4 - mantissa_f32 = mantissa_f4.to(tl.int32) << (MBITS_F32 - MBITS_F4_E2M1) + mantissa_f4 = x_pos & mantissa_mask_f4 + mantissa_f32 = mantissa_f4.to(tl.int32) << (mbits_f32 - mbits_f4_e2m1) output = mantissa_f32 # combine the pieces result = exp_biased_f32 | mantissa_f32 # result[zero_mask] = ZERO_BITS_F32 - result = tl.where(zero_mask, ZERO_BITS_F32, result) + result = tl.where(zero_mask, zero_bits_f32, result) # result[denormal_mask] = ZERO_POINT_FIVE_BITS_F32 - result = tl.where(denormal_mask, ZERO_POINT_FIVE_BITS_F32, result) + result = tl.where(denormal_mask, zero_point_five_bits_f32, result) # add sign back sign_f32 = sign_f4.to(tl.int32) << ( - MBITS_F32 - MBITS_F4_E2M1 + EBITS_F32 - EBITS_F4_E2M1 + mbits_f32 - mbits_f4_e2m1 + ebits_f32 - ebits_f4_e2m1 ) result = result | sign_f32 @@ -174,6 +186,16 @@ def triton_f4_to_bf16_kernel( x_ptr, output_ptr, n_elements_in, + sign_mask_f4: tl.constexpr, + mantissa_mask_f4: tl.constexpr, + mbits_f4_e2m1: tl.constexpr, + ebits_f4_e2m1: tl.constexpr, + f4_e2m1_exp_bias: tl.constexpr, + mbits_f32: tl.constexpr, + ebits_f32: tl.constexpr, + f32_exp_bias: tl.constexpr, + zero_bits_f32: tl.constexpr, + zero_point_five_bits_f32: tl.constexpr, BLOCK_SIZE_IN: tl.constexpr, ): pid = tl.program_id(axis=0) @@ -187,7 +209,19 @@ def triton_f4_to_bf16_kernel( # packed uint8 x_packed = tl.load(x_ptr + offsets_in, mask=mask_in) - output = _fp4_packed_to_bf16(x_packed) + output = _fp4_packed_to_bf16( + x_packed, + sign_mask_f4, + mantissa_mask_f4, + mbits_f4_e2m1, + ebits_f4_e2m1, + f4_e2m1_exp_bias, + mbits_f32, + ebits_f32, + f32_exp_bias, + zero_bits_f32, + zero_point_five_bits_f32, + ) # set up output offsets block_start_out = pid * BLOCK_SIZE_OUT @@ -213,6 +247,18 @@ def triton_f4_to_scaled_bf16_kernel( output_ptr, n_elements_in, mx_block_size: tl.constexpr, + sign_mask_f4: tl.constexpr, + mantissa_mask_f4: tl.constexpr, + mbits_f4_e2m1: tl.constexpr, + ebits_f4_e2m1: tl.constexpr, + f4_e2m1_exp_bias: tl.constexpr, + mbits_f32: tl.constexpr, + ebits_f32: tl.constexpr, + f32_exp_bias: tl.constexpr, + zero_bits_f32: tl.constexpr, + zero_point_five_bits_f32: tl.constexpr, + e8m0_exponent_bias: tl.constexpr, + e8m0_exponent_nan_val: tl.constexpr, BLOCK_SIZE_IN: tl.constexpr, ): pid = tl.program_id(axis=0) @@ -227,7 +273,19 @@ def triton_f4_to_scaled_bf16_kernel( mask_in = offsets_in < n_elements_in # packed uint8 x_packed = tl.load(x_ptr + offsets_in, mask=mask_in) - output = _fp4_packed_to_bf16(x_packed) + output = _fp4_packed_to_bf16( + x_packed, + sign_mask_f4, + mantissa_mask_f4, + mbits_f4_e2m1, + ebits_f4_e2m1, + f4_e2m1_exp_bias, + mbits_f32, + ebits_f32, + f32_exp_bias, + zero_bits_f32, + zero_point_five_bits_f32, + ) # load scale block_start_s = pid * BLOCK_SIZE_S @@ -236,9 +294,9 @@ def triton_f4_to_scaled_bf16_kernel( s = tl.load(s_ptr + offsets_s, mask=mask_s) # create the scale in bf16 - s_offset = s.to(tl.int16) - E8M0_EXPONENT_BIAS + s_offset = s.to(tl.int16) - e8m0_exponent_bias s_fp = libdevice.pow(2.0, s_offset).to(tl.bfloat16) - s_fp = tl.where(s != E8M0_EXPONENT_NAN_VAL, s_fp, float("nan")) + s_fp = tl.where(s != e8m0_exponent_nan_val, s_fp, float("nan")) # multiply output by scale # TODO(later): see if manipulating the exponent instead of fp @@ -263,6 +321,16 @@ def triton_f4_to_bf16_kernel( x_ptr, output_ptr, n_elements_in, + sign_mask_f4, + mantissa_mask_f4, + mbits_f4_e2m1, + ebits_f4_e2m1, + f4_e2m1_exp_bias, + mbits_f32, + ebits_f32, + f32_exp_bias, + zero_bits_f32, + zero_point_five_bits_f32, BLOCK_SIZE_IN, ): raise AssertionError("unsupported without triton") @@ -273,6 +341,18 @@ def triton_f4_to_scaled_bf16_kernel( output_ptr, n_elements_in, mx_block_size, + sign_mask_f4, + mantissa_mask_f4, + mbits_f4_e2m1, + ebits_f4_e2m1, + f4_e2m1_exp_bias, + mbits_f32, + ebits_f32, + f32_exp_bias, + zero_bits_f32, + zero_point_five_bits_f32, + e8m0_exponent_bias, + e8m0_exponent_nan_val, BLOCK_SIZE_IN, ): raise AssertionError("unsupported without triton") @@ -294,7 +374,22 @@ def triton_f4_to_bf16(x: torch.Tensor): grid = lambda meta: ( # noqa: E731 triton.cdiv(n_elements_in, meta["BLOCK_SIZE_IN"]), ) # noqa: E731,E501 - triton_f4_to_bf16_kernel[grid](x, output, n_elements_in, BLOCK_SIZE_IN=512) + triton_f4_to_bf16_kernel[grid]( + x, + output, + n_elements_in, + sign_mask_f4=SIGN_MASK_F4, + mantissa_mask_f4=MANTISSA_MASK_F4, + mbits_f4_e2m1=MBITS_F4_E2M1, + ebits_f4_e2m1=EBITS_F4_E2M1, + f4_e2m1_exp_bias=F4_E2M1_EXP_BIAS, + mbits_f32=MBITS_F32, + ebits_f32=EBITS_F32, + f32_exp_bias=F32_EXP_BIAS, + zero_bits_f32=ZERO_BITS_F32, + zero_point_five_bits_f32=ZERO_POINT_FIVE_BITS_F32, + BLOCK_SIZE_IN=512, + ) return output @@ -318,7 +413,23 @@ def triton_f4_to_scaled_bf16( triton.cdiv(n_elements_in, meta["BLOCK_SIZE_IN"]), ) triton_f4_to_scaled_bf16_kernel[grid]( - x, s_e8m0, output, n_elements_in, mx_block_size + x, + s_e8m0, + output, + n_elements_in, + mx_block_size, + sign_mask_f4=SIGN_MASK_F4, + mantissa_mask_f4=MANTISSA_MASK_F4, + mbits_f4_e2m1=MBITS_F4_E2M1, + ebits_f4_e2m1=EBITS_F4_E2M1, + f4_e2m1_exp_bias=F4_E2M1_EXP_BIAS, + mbits_f32=MBITS_F32, + ebits_f32=EBITS_F32, + f32_exp_bias=F32_EXP_BIAS, + zero_bits_f32=ZERO_BITS_F32, + zero_point_five_bits_f32=ZERO_POINT_FIVE_BITS_F32, + e8m0_exponent_bias=E8M0_EXPONENT_BIAS, + e8m0_exponent_nan_val=E8M0_EXPONENT_NAN_VAL, ) return output