group gemm w/o padding (#3399)
Summary:
Pull Request resolved: #3399

X-link: facebookresearch/FBGEMM#487

Just some scaffolding code for group gemm. The idea is that:
* for the router scores, we'll move the non-zero entries to the left side and compute the indices and the number of non-zeros for each local expert
* group gemm input (needs more discussion):
    * input: 3D tensor [local_expert, tokens, D]
    * input: router_nonzeros tensor - how many rows along the M dimension need to be computed for each expert
    * output: we need to pad zeros into the unused entries so the kernel works with CUDA graphs (see the usage sketch below)

We only support bf16 grouped gemm for now. FP8 grouped gemm only supports tensor-wise scaling; row-wise scaling has some limitations in CUTLASS that require further work.
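
To make the convention concrete, here is a minimal, illustrative usage sketch (not part of this diff). It assumes the experimental gen_ai ops are built and a CUDA device is available, and that importing fbgemm_gpu.experimental.gen_ai registers torch.ops.fbgemm.bf16bf16bf16_grouped; the import path and shapes are assumptions, while the op name, arguments, and int32 requirement come from the diff below.

    import torch
    import fbgemm_gpu.experimental.gen_ai  # noqa: F401  # assumed import path that registers the op

    G, M, N, K = 4, 64, 128, 256  # illustrative: G local experts, up to M tokens per expert

    # One [M, K] activation and one [N, K] weight per local expert.
    x_group = [torch.rand(M, K, dtype=torch.bfloat16, device="cuda") for _ in range(G)]
    w_group = [torch.rand(N, K, dtype=torch.bfloat16, device="cuda") for _ in range(G)]

    # In a real MoE layer this would be the per-expert count of non-zero router
    # scores after left-packing; rows at and beyond this index are zero padding.
    zero_start_index_M = torch.randint(1, M, (G,), dtype=torch.int32, device="cuda")
    for g in range(G):
        x_group[g][zero_start_index_M[g]:, :] = 0

    # Grouped GEMM; padded rows come back as zeros, so output shapes stay static
    # across iterations, which is what CUDA graph capture needs.
    y = torch.ops.fbgemm.bf16bf16bf16_grouped(x_group, w_group, zero_start_index_M)
    print(y.view(G, M, N).shape)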

Reviewed By: jiawenliu64

Differential Revision: D65260109

fbshipit-source-id: e9b60241c173af34b84d33184262776ca0b38310
xw285cornell authored and facebook-github-bot committed Nov 25, 2024
1 parent efdb2d0 commit 0505ed8
Showing 2 changed files with 81 additions and 0 deletions.
@@ -10,6 +10,7 @@
#include <ATen/cuda/CUDAContext.h>
#include <cutlass/util/device_memory.h>
#include <cutlass/util/packed_stride.hpp>
#include <torch/script.h>

// clang-format off
// The fixed ordering of the headers is required for CUTLASS 3.2+
@@ -331,6 +332,10 @@ at::Tensor bf16bf16bf16_grouped_impl(
  auto stream = at::cuda::getCurrentCUDAStream().stream();
  int64_t output_offset = 0;

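  // zero_start_index_M (optional) marks, for each group, where the zero padding
  // starts along the M dimension; when provided it must be an int32 tensor.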
  if (zero_start_index_M.has_value()) {
    TORCH_CHECK(zero_start_index_M.value().dtype() == torch::kInt32);
  }

  // Set arguments
  for (int i = 0; i < problem_count; ++i) {
    int N = W[i].size(0);
76 changes: 76 additions & 0 deletions fbgemm_gpu/experimental/gen_ai/test/quantize/quantize_test.py
@@ -887,6 +887,82 @@ def test_fp8_grouped_gemm(
                y_bf16_group_list[i], y_group_ref[i], atol=8.0e-2, rtol=8.0e-2
            )

    @unittest.skipIf(
        not torch.version.cuda, "Skip on AMD: GMM ops are not yet supported."
    )
    @settings(deadline=None)
    @given(
        G=st.sampled_from([4, 5]),
        M=st.sampled_from([2048, 3584]),
        N=st.sampled_from([1024, 6144]),
        K=st.sampled_from([512, 3584]),
        use_cudagraph=st.sampled_from([True, False]),
        use_padding_zeros=st.sampled_from([True, False]),
    )
    def test_bf16_grouped_gemm(
        self,
        G: int,
        M: int,
        N: int,
        K: int,
        use_cudagraph: bool,
        use_padding_zeros: bool,
    ) -> None:
        xs = torch.rand(size=(G, M, K), dtype=torch.bfloat16, device="cuda")
        ws = torch.rand(size=(G, N, K), dtype=torch.bfloat16, device="cuda")

        x_group = [x.squeeze() for x in xs.split(1, dim=0)]
        w_group = [w.squeeze() for w in ws.split(1, dim=0)]

        zero_start_index_M = None

        if use_padding_zeros:
            zero_start_index_M = torch.randint(
                1,
                M,
                (G,),
                dtype=torch.int,
                device="cuda",
            )
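            # Rows at and beyond zero_start_index_M[i] are zero padding; only the
            # leading rows of each group hold real tokens.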
            for i in range(len(x_group)):
                x_group[i][zero_start_index_M[i] :, :] = 0

        # BF16 grouped gemm kernel
        if use_cudagraph:
            # warmup
            torch.ops.fbgemm.bf16bf16bf16_grouped(
                x_group,
                w_group,
                zero_start_index_M if use_padding_zeros else None,
            )
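            # CUDA graph capture needs static shapes, which is why unused rows are
            # zero-padded rather than shrinking M per expert (see the commit summary).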
            # With cudagraph
            g = torch.cuda.CUDAGraph()
            with torch.cuda.graph(g):
                y_bf16_group = torch.ops.fbgemm.bf16bf16bf16_grouped(
                    x_group,
                    w_group,
                    zero_start_index_M if use_padding_zeros else None,
                )
            g.replay()
        else:
            y_bf16_group = torch.ops.fbgemm.bf16bf16bf16_grouped(
                x_group,
                w_group,
                zero_start_index_M if use_padding_zeros else None,
            )

        # BF16 loopover gemm reference
        y_group_ref = torch.bmm(xs, ws.transpose(1, 2))
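        # When padding is used, the x_group entries are views into xs, so the
        # zeroed rows are reflected here too and the padded reference rows are zero.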

        torch.testing.assert_close(
            y_group_ref, y_bf16_group.view([G, M, N]), atol=8.0e-2, rtol=8.0e-2
        )

    @unittest.skipIf(torch.version.hip, "Skip on AMD: Marlin not yet supported.")
    @settings(deadline=None)
    @given(
