Fuse output zero fill into grouped gemm kernel setup (#3537)
Summary:
Pull Request resolved: #3537

X-link: facebookresearch/FBGEMM#624

During E2E evaluation we found that the kernel launch overhead of filling the output with zeros and then separately initializing the grouped gemm arguments was noticeably impacting performance. This diff fuses the two steps into a single kernel. I believe the output initialization is done pretty efficiently here, but am open to feedback!
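
To make the idea concrete, here is a heavily simplified sketch of the fused pattern (an editor's illustration, not the actual FBGEMM kernel in the diff below, which fills CK KernelArguments, only zeroes rows past prepad_M, and uses paired __hip_bfloat162 stores). GroupArgs, fused_setup_and_zero_fill, and CHUNK are hypothetical names: one thread per group writes that group's argument record, and the same oversized grid zero-fills the output so the separate fill launch disappears.

#include <cstdint>
#include <hip/hip_runtime.h>
#include <hip_bf16.h>

// Illustrative per-group argument record (the real kernel fills CK KernelArguments).
struct GroupArgs {
  const std::uint8_t* a; // fp8 activations for this group
  const std::uint8_t* b; // fp8 weights for this group
  __hip_bfloat16* out;   // bf16 output for this group
  int m, n, k;
};

template <int CHUNK>
__global__ void fused_setup_and_zero_fill(
    GroupArgs* args,
    const std::uint8_t* A,
    const std::uint8_t* B,
    __hip_bfloat16* out,
    const std::int64_t* prepad_M,
    int M, int N, int K, int group_count) {
  int tid = blockIdx.x * blockDim.x + threadIdx.x;
  // The first group_count threads each write one group's argument record.
  if (tid < group_count) {
    args[tid] = {A + std::int64_t(tid) * M * K,
                 B + std::int64_t(tid) * N * K,
                 out + std::int64_t(tid) * M * N,
                 static_cast<int>(prepad_M[tid]), N, K};
  }
  // Every thread also zeroes a CHUNK-element slice of the flattened
  // [group_count, M, N] output, so no separate fill kernel is launched.
  std::int64_t total = std::int64_t(group_count) * M * N;
  for (int i = 0; i < CHUNK; ++i) {
    std::int64_t idx = std::int64_t(tid) * CHUNK + i;
    if (idx < total) {
      out[idx] = __hip_bfloat16(0);
    }
  }
}

On the host, the grid has to cover max(group_count, group_count * M * N / CHUNK) threads rather than just group_count, which is the block_factor computation in the diff; with 16 groups, M = 128, N = 256, and CHUNK = 8 that is 65536 threads, i.e. 64 blocks of 1024.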

In microbenchmarks, there does not appear to be any regression from this change.

Reviewed By: jianyuh

Differential Revision: D67777269

fbshipit-source-id: f4bf0bd7ab1ff05a7c4f56b32c5766c6be8f4b3d
jwfromm authored and facebook-github-bot committed Jan 3, 2025
1 parent 84fa740 commit 7c26339
Showing 1 changed file with 48 additions and 15 deletions.
@@ -18,6 +18,7 @@
 #include <ATen/ATen.h>
 #include <c10/hip/HIPStream.h>
+#include <hip_bf16.h>
 #include <torch/torch.h>
 
 #include "ck/ck.hpp"
@@ -184,17 +185,18 @@ __global__ void set_kernel_args_fixed_nk_kernel(
     int M,
     int N,
     int K,
-    int group_count) {
-  int group_idx = blockIdx.x * blockDim.x + threadIdx.x;
+    int group_count,
+    const int BLOCK_SIZE) {
+  int thread_idx = blockIdx.x * blockDim.x + threadIdx.x;
   // Each thread is responsible for setting up the arguments for one group.
-  if (group_idx < group_count) {
+  if (thread_idx < group_count) {
     // Compute offsets for this group.
-    int group_M = prepad_M[group_idx];
+    int group_M = prepad_M[thread_idx];
     KernelArguments kernel_group_args = {
-        XQ + (group_idx * M * K),
-        WQ + (group_idx * N * K),
-        {w_scale + (group_idx * N), x_scale + (group_idx * M)},
-        output + (group_idx * M * N),
+        XQ + (thread_idx * M * K),
+        WQ + (thread_idx * N * K),
+        {w_scale + (thread_idx * N), x_scale + (thread_idx * M)},
+        output + (thread_idx * M * N),
         group_M,
         N,
         K,
@@ -203,7 +205,35 @@
       {0, 0},
       N};
     // Write kernel args to memory.
-    kernel_args[group_idx] = kernel_group_args;
+    kernel_args[thread_idx] = kernel_group_args;
   }
+
+  // We also fuse in initialization of the output tensor.
+  // We write in chunks of 2 bfloats at a time for efficiency.
+  for (int i = 0; i < BLOCK_SIZE / 2; i++) {
+    // Figure out where in memory we are.
+    int output_offset = (thread_idx * BLOCK_SIZE) + (i * 2);
+    int current_group = output_offset / (M * N);
+    // Skip if outside of valid groups.
+    if (current_group < group_count) {
+      int nonzeros = prepad_M[current_group];
+      int current_M = output_offset / N;
+      // Only write if this block needs initialization.
+      // Avoid writing to final element if number of elements is odd.
+      if (current_M >= nonzeros && output_offset < (M * N * group_count) - 1) {
+        __hip_bfloat162* output_block =
+            reinterpret_cast<__hip_bfloat162*>(output + output_offset);
+        *output_block = __hip_bfloat162(0, 0);
+      }
+    }
+  }
+  // Handle case where there are an odd number of total elements.
+  if (((M * N * group_count) % 2) != 0 &&
+      ((M * N * group_count) - (thread_idx * BLOCK_SIZE) < BLOCK_SIZE)) {
+    // Write out the final element.
+    __hip_bfloat16* output_block =
+        reinterpret_cast<__hip_bfloat16*>(output + (M * N * group_count) - 1);
+    *output_block = __hip_bfloat16(0);
+  }
 }

@@ -262,8 +292,10 @@ void set_dynamic_kernel_args(
   }
 
   // Launch a kernel that sets kernel argument memory.
-  int const blockSize = std::min(1024, group_count);
-  int const numBlocks = (group_count + blockSize - 1) / blockSize;
+  const int BLOCK_SIZE = 8;
+  int block_factor = std::max(group_count, (group_count * M * N) / BLOCK_SIZE);
+  int blockSize = std::min(1024, block_factor);
+  int numBlocks = (block_factor + blockSize - 1) / blockSize;
   set_kernel_args_fixed_nk_kernel<<<numBlocks, blockSize, 0, stream>>>(
       reinterpret_cast<KernelArguments*>(kernel_args.data_ptr()),
       reinterpret_cast<ADataType*>(XQ[0].data_ptr()),
@@ -275,7 +307,8 @@
       M,
       N,
       K,
-      group_count);
+      group_count,
+      BLOCK_SIZE);
 }
 
 at::Tensor get_grouped_kernel_args(
@@ -380,10 +413,10 @@ std::vector<at::Tensor> f8f8bf16_rowwise_grouped(
   if (zero_start_index_M.has_value()) {
     int M = XQ[0].size(0);
     int N = WQ[0].size(0);
-    // Fill output with zeros to simplify integration. This prevents nans from
-    // showing up in the tensor.
+    // Allocate an empty output array. We will set its values to zero as part
+    // of kernel setup.
     at::Tensor Y_full =
-        at::zeros({group_count, M, N}, XQ[0].options().dtype(at::kBFloat16));
+        at::empty({group_count, M, N}, XQ[0].options().dtype(at::kBFloat16));
     // Split the output into groups.
     Y = at::unbind(Y_full, 0);
   } else {
