From 7c263392ab3dadc373c74031fa80d6c5a19ddde3 Mon Sep 17 00:00:00 2001 From: Josh Fromm Date: Fri, 3 Jan 2025 15:18:06 -0800 Subject: [PATCH] Fuse output zero fill into grouped gemm kernel setup (#3537) Summary: Pull Request resolved: https://github.com/pytorch/FBGEMM/pull/3537 X-link: https://github.com/facebookresearch/FBGEMM/pull/624 During E2E evaluation we found that the kernel launch overhead of having to fill the output with zeros, then separately initialize the grouped gemm args was noticeably impacting performance. This diff fuses the two together. I believe the output initialization is done pretty efficiently here but am open to feedback! In microbenchmarks, there does not appear to be any regression from this change. Reviewed By: jianyuh Differential Revision: D67777269 fbshipit-source-id: f4bf0bd7ab1ff05a7c4f56b32c5766c6be8f4b3d --- .../fp8_rowwise_grouped_gemm.hip | 63 ++++++++++++++----- 1 file changed, 48 insertions(+), 15 deletions(-) diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/fp8_rowwise_grouped_gemm.hip b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/fp8_rowwise_grouped_gemm.hip index 75df7f77dd..370931f29f 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/fp8_rowwise_grouped_gemm.hip +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/fp8_rowwise_grouped_gemm.hip @@ -18,6 +18,7 @@ #include #include +#include #include #include "ck/ck.hpp" @@ -184,17 +185,18 @@ __global__ void set_kernel_args_fixed_nk_kernel( int M, int N, int K, - int group_count) { - int group_idx = blockIdx.x * blockDim.x + threadIdx.x; + int group_count, + const int BLOCK_SIZE) { + int thread_idx = blockIdx.x * blockDim.x + threadIdx.x; // Each thread is responsible for setting up the arguments for one group. - if (group_idx < group_count) { + if (thread_idx < group_count) { // Compute offsets for this group. - int group_M = prepad_M[group_idx]; + int group_M = prepad_M[thread_idx]; KernelArguments kernel_group_args = { - XQ + (group_idx * M * K), - WQ + (group_idx * N * K), - {w_scale + (group_idx * N), x_scale + (group_idx * M)}, - output + (group_idx * M * N), + XQ + (thread_idx * M * K), + WQ + (thread_idx * N * K), + {w_scale + (thread_idx * N), x_scale + (thread_idx * M)}, + output + (thread_idx * M * N), group_M, N, K, @@ -203,7 +205,35 @@ __global__ void set_kernel_args_fixed_nk_kernel( {0, 0}, N}; // Write kernel args to memory. - kernel_args[group_idx] = kernel_group_args; + kernel_args[thread_idx] = kernel_group_args; + } + + // We also fuse in initialization of the output tensor. + // We write in chunks of 2 bfloats at a time for efficiency. + for (int i = 0; i < BLOCK_SIZE / 2; i++) { + // Figure out where in memory we are. + int output_offset = (thread_idx * BLOCK_SIZE) + (i * 2); + int current_group = output_offset / (M * N); + // Skip if outside of valid groups. + if (current_group < group_count) { + int nonzeros = prepad_M[current_group]; + int current_M = output_offset / N; + // Only write if this block needs initialization. + // Avoid writing to final element if number of elements is odd. + if (current_M >= nonzeros && output_offset < (M * N * group_count) - 1) { + __hip_bfloat162* output_block = + reinterpret_cast<__hip_bfloat162*>(output + output_offset); + *output_block = __hip_bfloat162(0, 0); + } + } + } + // Handle case where there are an odd number of total elements. + if (((M * N * group_count) % 2) != 0 && + ((M * N * group_count) - (thread_idx * BLOCK_SIZE) < BLOCK_SIZE)) { + // Write out the final element. + __hip_bfloat16* output_block = + reinterpret_cast<__hip_bfloat16*>(output + (M * N * group_count) - 1); + *output_block = __hip_bfloat16(0); } } @@ -262,8 +292,10 @@ void set_dynamic_kernel_args( } // Launch a kernel that sets kernel argument memory. - int const blockSize = std::min(1024, group_count); - int const numBlocks = (group_count + blockSize - 1) / blockSize; + const int BLOCK_SIZE = 8; + int block_factor = std::max(group_count, (group_count * M * N) / BLOCK_SIZE); + int blockSize = std::min(1024, block_factor); + int numBlocks = (block_factor + blockSize - 1) / blockSize; set_kernel_args_fixed_nk_kernel<<>>( reinterpret_cast(kernel_args.data_ptr()), reinterpret_cast(XQ[0].data_ptr()), @@ -275,7 +307,8 @@ void set_dynamic_kernel_args( M, N, K, - group_count); + group_count, + BLOCK_SIZE); } at::Tensor get_grouped_kernel_args( @@ -380,10 +413,10 @@ std::vector f8f8bf16_rowwise_grouped( if (zero_start_index_M.has_value()) { int M = XQ[0].size(0); int N = WQ[0].size(0); - // Fill output with zeros to simplify integration. This prevents nans from - // showing up in the tensor. + // Allocate an empty output array. We will set its values to zero as part + // of kernel setup. at::Tensor Y_full = - at::zeros({group_count, M, N}, XQ[0].options().dtype(at::kBFloat16)); + at::empty({group_count, M, N}, XQ[0].options().dtype(at::kBFloat16)); // Split the output into groups. Y = at::unbind(Y_full, 0); } else {