
Commit a368b9a

jwfromm authored and facebook-github-bot committed
Split bf16 grouped gemm into dynamic and static versions. (#3544)
Summary:
X-link: facebookresearch/FBGEMM#629

This diff splits bf16 grouped gemm into static and dynamic versions, mirroring the fp8 change in D67810956. This allows a bit more performance depending on the use case.

Reviewed By: jianyuh, jiawenliu64

Differential Revision: D67813627
1 parent fd04e6a commit a368b9a

5 files changed: 154 additions, 108 deletions
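For orientation, the updated op schemas in quantize.cpp leave two call patterns on the Python side. The sketch below is illustrative rather than part of the diff: the group count and shapes are made up, and the import path for loading the extension is an assumption; only the two op names and their signatures come from this commit.

import torch
import fbgemm_gpu.experimental.gen_ai  # noqa: F401 -- assumed import path that loads the gen_ai ops

G, M, N, K = 4, 128, 256, 512  # illustrative sizes; weights are laid out as [N, K]
A = [torch.randn(M, K, device="cuda", dtype=torch.bfloat16) for _ in range(G)]
B = [torch.randn(N, K, device="cuda", dtype=torch.bfloat16) for _ in range(G)]

# Static variant: per-group shapes may differ; returns a list of [M_i, N_i] tensors.
Y_list = torch.ops.fbgemm.bf16bf16bf16_grouped(A, B)

# Dynamic variant: all groups share M/N/K. zero_start_index_M gives the number of
# valid rows per group, and the result is a single zero-padded [G, M, N] tensor.
m_sizes = torch.tensor([128, 64, 96, 32], dtype=torch.int64, device="cuda")
Y_full = torch.ops.fbgemm.bf16bf16bf16_grouped_dynamic(A, B, m_sizes)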

fbgemm_gpu/experimental/gen_ai/bench/quantize_ops.py

Lines changed: 8 additions & 21 deletions
@@ -632,23 +632,15 @@ def quantize_fixed_nk(self, x, w):
         # Stack inputs into groups.
         x = torch.stack(xp).contiguous()
         w = torch.stack(w).contiguous()
-        # Allocate output tensor.
-        output = torch.empty(
-            [x.shape[0], x.shape[1], w.shape[1]],
-            dtype=torch.bfloat16,
-            device=x.device,
-        )
         # View these unified tensors as lists of tensors.
         x = [xi.squeeze() for xi in x.split(1, dim=0)]
         w = [wi.squeeze() for wi in w.split(1, dim=0)]
-        output = [o.squeeze() for o in output.split(1, dim=0)]

         # Return processed tensors.
         return (
             x,
             w,
             torch.tensor(m_values).to(dtype=torch.int64, device=x[0].device),
-            output,
         )

     def quantize(self, x, w):
@@ -664,23 +656,18 @@ def quantize(self, x, w):
         if len(np.unique(n_values)) == 1 and len(np.unique(k_values)) == 1:
             return self.quantize_fixed_nk(x, w)

-        output = [
-            torch.empty(m, n, device=x[0].device, dtype=torch.bfloat16)
-            for m, n in zip(m_values, n_values)
-        ]
         m_values = None
-        return x, w, m_values, output
+        return x, w, m_values

-    def compute(self, x, w, m_values, _):
-        return torch.ops.fbgemm.bf16bf16bf16_grouped(
-            x,
-            w,
-            zero_start_index_M=m_values,
-        )
+    def compute(self, x, w, m_values):
+        if m_values is None:
+            return torch.ops.fbgemm.bf16bf16bf16_grouped(x, w)
+        else:
+            return torch.ops.fbgemm.bf16bf16bf16_grouped_dynamic(x, w, m_values)

     def quantize_and_compute(self, x, w):
-        x, w, m_values, output = self.quantize(x, w)
-        return self.compute(x, w, m_values, output)
+        x, w, m_values = self.quantize(x, w)
+        return self.compute(x, w, m_values)

     @property
     def name(self) -> str:
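The benchmark change above is the user-visible effect of the split: quantize() no longer preallocates output, and compute() routes to the dynamic op only when quantize_fixed_nk() produced m_values. A minimal sketch of the fixed-NK preparation that feeds the dynamic path follows; the padding helper and its name are illustrative and not part of the diff.

import torch

def pad_groups_to_fixed_m(x_list, max_m):
    # Record the true number of rows per group; this becomes zero_start_index_M.
    m_values = [xi.shape[0] for xi in x_list]
    # Zero-pad each [m_i, K] activation up to [max_m, K].
    xp = [
        torch.nn.functional.pad(xi, (0, 0, 0, max_m - xi.shape[0]), value=0.0)
        for xi in x_list
    ]
    # Stack, then view the unified tensor as a list of equally shaped groups,
    # matching what quantize_fixed_nk hands to bf16bf16bf16_grouped_dynamic.
    x = torch.stack(xp).contiguous()
    x = [xi.squeeze() for xi in x.split(1, dim=0)]
    m_tensor = torch.tensor(m_values, dtype=torch.int64, device=x[0].device)
    return x, m_tensor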

fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/bf16_grouped/bf16_grouped_gemm.hip

Lines changed: 63 additions & 20 deletions
@@ -290,7 +290,6 @@ at::Tensor get_grouped_kernel_args(
 std::vector<at::Tensor> bf16bf16bf16_grouped(
     at::TensorList A,
     at::TensorList B,
-    std::optional<at::Tensor> zero_start_index_M = std::nullopt,
     std::optional<std::vector<at::Tensor>> output = std::nullopt) {
   // Check that input datatypes are valid.
   // First confirm that there are the same number of groups in all inputs.
@@ -333,30 +332,16 @@ std::vector<at::Tensor> bf16bf16bf16_grouped(
           Y[i].dtype() == at::kBFloat16, "Output dtype must be bfloat16.");
     }
   } else {
-    // Two modes for allocating output. When m_values is provided, we need
-    // the output tensor to be contiguous and can assume M, N, and K are the
-    // same across groups. Otherwise, we can allocate each output separately.
-    if (zero_start_index_M.has_value()) {
-      int M = A[0].size(0);
-      int N = B[0].size(0);
-      // Fill output with zeros to simplify integration. This prevents nans
-      // from showing up in the tensor.
-      at::Tensor Y_full =
-          at::zeros({group_count, M, N}, A[0].options().dtype(at::kBFloat16));
-      // Split the output into groups.
-      Y = at::unbind(Y_full, 0);
-    } else {
-      for (int i = 0; i < group_count; i++) {
-        int M = A[i].size(0);
-        int N = B[i].size(0);
-        Y.push_back(at::empty({M, N}, A[i].options().dtype(at::kBFloat16)));
-      }
+    for (int i = 0; i < group_count; i++) {
+      int M = A[i].size(0);
+      int N = B[i].size(0);
+      Y.push_back(at::empty({M, N}, A[i].options().dtype(at::kBFloat16)));
     }
   }

   // Prepare kernel arguments by copying them to the proper device location.
   at::Tensor kernel_args = get_grouped_kernel_args(
-      A, B, zero_start_index_M, Y);
+      A, B, std::nullopt, Y);

   // Perform shape lookup to find best kernel.
   // We use the largest of each shape for heuristics.
@@ -373,4 +358,62 @@ std::vector<at::Tensor> bf16bf16bf16_grouped(
   return selected_kernel(A, B, kernel_args, Y);
 }

+at::Tensor bf16bf16bf16_grouped_dynamic(
+    at::TensorList A,
+    at::TensorList B,
+    at::Tensor zero_start_index_M) {
+  // Check that input datatypes are valid.
+  // First confirm that there are the same number of groups in all inputs.
+  TORCH_CHECK(
+      A.size() == B.size(),
+      "A and B must have the same number of groups.");
+  int group_count = A.size();
+  // Iterate over inputs and check they are valid.
+  for (at::Tensor a : A) {
+    TORCH_CHECK(a.is_cuda() && a.is_contiguous());
+    TORCH_CHECK(a.dim() == 2, "Inputs must be 2D.");
+    TORCH_CHECK(
+        a.dtype() == at::kBFloat16,
+        "Inputs must be type bfloat16.");
+  }
+  for (at::Tensor b : B) {
+    TORCH_CHECK(b.is_cuda() && b.is_contiguous());
+    TORCH_CHECK(b.dim() == 2, "Inputs must be 2D.");
+    TORCH_CHECK(
+        b.dtype() == at::kBFloat16,
+        "Inputs must be type bfloat16.");
+  }
+
+  std::vector<at::Tensor> Y;
+  int M = A[0].size(0);
+  int N = B[0].size(0);
+  // Fill output with zeros to simplify integration. This prevents nans from
+  // showing up in the tensor.
+  at::Tensor Y_full =
+      at::zeros({group_count, M, N}, A[0].options().dtype(at::kBFloat16));
+  // Split the output into groups.
+  Y = at::unbind(Y_full, 0);
+
+  // Prepare kernel arguments by copying them to the proper device location.
+  at::Tensor kernel_args = get_grouped_kernel_args(
+      A, B, zero_start_index_M, Y);
+
+  // Perform shape lookup to find best kernel.
+  // We use the largest of each shape for heuristics.
+  int MaxM = 0;
+  int MaxN = 0;
+  int MaxK = 0;
+  for (int i = 0; i < group_count; i++) {
+    MaxM = max(MaxM, A[i].size(0));
+    MaxN = max(MaxN, B[i].size(0));
+    MaxK = max(MaxK, A[i].size(1));
+  }
+  GroupedKernel selected_kernel =
+      grouped_heuristic_dispatch(MaxM, MaxN, MaxK);
+  // Run kernel to populate output.
+  selected_kernel(A, B, kernel_args, Y);
+  // Return coalesced view of output tensor.
+  return Y_full;
+}
+
 } // namespace fbgemm_gpu
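With this split, both backends return the dynamic result as one coalesced [group_count, M, N] tensor whose padded rows stay zero, since the output is zero-filled before the kernel runs. If a caller wants per-group views trimmed to the valid rows, a sketch like the following works; the function name and slicing are illustrative, not part of the commit.

import torch

def unpad_dynamic_output(y_full: torch.Tensor, m_sizes: torch.Tensor):
    # y_full: [G, M, N] output of bf16bf16bf16_grouped_dynamic.
    # m_sizes: int64 [G] zero_start_index_M; rows at or beyond m_sizes[g]
    # were never written by the kernel and remain zero.
    return [y_full[g, : int(m_sizes[g])] for g in range(y_full.shape[0])]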

fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped.cu

Lines changed: 45 additions & 49 deletions
@@ -261,6 +261,7 @@ template <
 std::vector<at::Tensor> bf16bf16bf16_grouped_impl(
     at::TensorList X, // BF16
     at::TensorList W, // BF16
+    std::vector<at::Tensor> output_tensor,
     std::optional<at::Tensor> zero_start_index_M) {
   int problem_count = X.size();
   TORCH_CHECK(W.size() == problem_count);
@@ -273,27 +274,6 @@ std::vector<at::Tensor> bf16bf16bf16_grouped_impl(
   using GroupedGemmConfigs = GroupedGemmBF16Args::
       GroupedGemmConfigs<TB_M, TB_N, TB_K, TBS_M, TBS_N, TBS_K, PONG>;

-  constexpr int AlignmentA =
-      128 /
-      cutlass::sizeof_bits<
-          GroupedGemmBF16Args::ElementInputA>::value; // Alignment of A matrix
-                                                      // in units of elements
-                                                      // (up to 16 bytes)
-
-  constexpr int AlignmentB =
-      128 /
-      cutlass::sizeof_bits<
-          GroupedGemmBF16Args::ElementInputB>::value; // Alignment of B matrix
-                                                      // in units of elements
-                                                      // (up to 16 bytes)
-
-  constexpr int AlignmentD =
-      128 /
-      cutlass::sizeof_bits<
-          GroupedGemmBF16Args::ElementOutput>::value; // Alignment of C matrix
-                                                      // in units of elements
-                                                      // (up to 16 bytes)
-
   at::Tensor output_args =
       at::empty({problem_count}, X[0].options().dtype(at::kLong));

@@ -313,28 +293,6 @@ std::vector<at::Tensor> bf16bf16bf16_grouped_impl(
   int stride_buf_offset =
       problem_count * 2 * sizeof(int64_t) + problem_shape_size;

-  // Two modes for allocating output. When m_values is provided, we need
-  // the output tensor to be contiguous and can assume M, N, and K are the
-  // same across groups. Otherwise, we can allocate each output separately.
-  std::vector<at::Tensor> output_tensor;
-  if (zero_start_index_M.has_value()) {
-    int M = X[0].size(0);
-    int N = W[0].size(0);
-    // Fill output with zeros to simplify integration. This prevents nans from
-    // showing up in the tensor.
-    at::Tensor output_full =
-        at::zeros({problem_count, M, N}, X[0].options().dtype(at::kBFloat16));
-    // Split the output into groups.
-    output_tensor = at::unbind(output_full, 0);
-  } else {
-    for (int i = 0; i < problem_count; i++) {
-      int M = X[i].size(0);
-      int N = W[i].size(0);
-      output_tensor.push_back(
-          at::empty({M, N}, X[i].options().dtype(at::kBFloat16)));
-    }
-  }
-
   TORCH_CHECK(
       !zero_start_index_M.has_value() ||
           zero_start_index_M->dtype() == at::kLong,
@@ -505,40 +463,78 @@ std::vector<at::Tensor> bf16bf16bf16_grouped_impl(
 std::vector<at::Tensor> dispatch_bf16_grouped_kernel(
     at::TensorList x_group, // BF16
     at::TensorList w_group, // BF16
+    std::vector<at::Tensor> output_tensor,
     std::optional<at::Tensor> zero_start_index_M) {
   KernelMode kernel = get_grouped_kernel_mode(x_group, w_group);
   if (kernel == KernelMode::Small) {
     return bf16bf16bf16_grouped_impl<64, 128, 128, 2, 1, 1, true>(
-        x_group, w_group, zero_start_index_M);
+        x_group, w_group, output_tensor, zero_start_index_M);
   } else if (kernel == KernelMode::Large) {
     return bf16bf16bf16_grouped_impl<128, 128, 128, 2, 1, 1, true>(
-        x_group, w_group, zero_start_index_M);
+        x_group, w_group, output_tensor, zero_start_index_M);
   } else {
     return bf16bf16bf16_grouped_impl<128, 128, 128, 1, 2, 1, true>(
-        x_group, w_group, zero_start_index_M);
+        x_group, w_group, output_tensor, zero_start_index_M);
   }
 }

 std::vector<at::Tensor> bf16bf16bf16_grouped(
     at::TensorList x_group, // BF16
     at::TensorList w_group, // BF16
-    std::optional<at::Tensor> zero_start_index_M = std::nullopt,
     std::optional<std::vector<at::Tensor>> output = std::nullopt) {
   TORCH_CHECK(!output.has_value(), "Preallocated output not yet supported.");
-  return dispatch_bf16_grouped_kernel(x_group, w_group, zero_start_index_M);
+  // Initialize output tensor.
+  int problem_count = x_group.size();
+  std::vector<at::Tensor> output_tensor;
+  for (int i = 0; i < problem_count; i++) {
+    int M = x_group[i].size(0);
+    int N = w_group[i].size(0);
+    output_tensor.push_back(
+        at::empty({M, N}, x_group[i].options().dtype(at::kBFloat16)));
+  }
+  return dispatch_bf16_grouped_kernel(
+      x_group, w_group, output_tensor, std::nullopt);
+}
+
+at::Tensor bf16bf16bf16_grouped_dynamic(
+    at::TensorList x_group, // BF16
+    at::TensorList w_group, // BF16
+    at::Tensor zero_start_index_M) {
+  std::vector<at::Tensor> output_tensor;
+  int problem_count = x_group.size();
+  int M = x_group[0].size(0);
+  int N = w_group[0].size(0);
+  // Fill output with zeros to simplify integration. This prevents nans from
+  // showing up in the tensor.
+  at::Tensor output_full = at::zeros(
+      {problem_count, M, N}, x_group[0].options().dtype(at::kBFloat16));
+  // Split the output into groups.
+  output_tensor = at::unbind(output_full, 0);
+  // Run kernel to populate output tensor.
+  dispatch_bf16_grouped_kernel(
+      x_group, w_group, output_tensor, zero_start_index_M);
+  // Return coalesced view of output.
+  return output_full;
 }

 #else

 std::vector<at::Tensor> bf16bf16bf16_grouped(
     at::TensorList /* x_group */, // BF16
     at::TensorList /* w_group */, // BF16
-    std::optional<at::Tensor> /* zero_start_index_M */,
     std::optional<std::vector<at::Tensor>> /* output */) {
   throw std::runtime_error(
       "CUDA version is older than 12.0"); // requires CUDA>=12
 }

+at::Tensor bf16bf16bf16_grouped_dynamic(
+    at::TensorList /* x_group */, // BF16
+    at::TensorList /* w_group */, // BF16
+    at::Tensor /* zero_start_index_M */) {
+  throw std::runtime_error(
+      "CUDA version is older than 12.0"); // requires CUDA>=12
+}
+
 #endif

 } // namespace fbgemm_gpu

fbgemm_gpu/experimental/gen_ai/src/quantize/quantize.cpp

Lines changed: 21 additions & 3 deletions
@@ -64,8 +64,11 @@ std::vector<at::Tensor> f8f8bf16_grouped(
 std::vector<at::Tensor> bf16bf16bf16_grouped(
     at::TensorList X,
     at::TensorList W,
-    std::optional<at::Tensor> zero_start_index_M = std::nullopt,
     std::optional<std::vector<at::Tensor>> output = std::nullopt);
+at::Tensor bf16bf16bf16_grouped_dynamic(
+    at::TensorList X,
+    at::TensorList W,
+    at::Tensor zero_start_index_M);
 at::Tensor f8f8bf16_rowwise(
     at::Tensor XQ,
     at::Tensor WQ,
@@ -195,7 +198,9 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
       get_f8f8bf16_rowwise_grouped_kernels);
 #endif
   m.def(
-      "bf16bf16bf16_grouped(Tensor[] X, Tensor[] W, Tensor? zero_start_index_M=None, Tensor[](a!)? output=None) -> Tensor[]");
+      "bf16bf16bf16_grouped(Tensor[] X, Tensor[] W, Tensor[](a!)? output=None) -> Tensor[]");
+  m.def(
+      "bf16bf16bf16_grouped_dynamic(Tensor[] X, Tensor[] W, Tensor zero_start_index_M) -> Tensor");
   m.def(
       "f8f8bf16_blockwise(Tensor XQ, Tensor WQ, Tensor x_scale, Tensor w_scale, int block_m=128, int block_n=128, int block_k=128) -> Tensor");
   m.def(
@@ -248,6 +253,7 @@ TORCH_LIBRARY_IMPL(fbgemm, CUDA, m) {
   m.impl("quantize_fp8_per_row", quantize_fp8_per_row);
   m.impl("quantize_fp8_per_col", quantize_fp8_per_col);
   m.impl("bf16bf16bf16_grouped", bf16bf16bf16_grouped);
+  m.impl("bf16bf16bf16_grouped_dynamic", bf16bf16bf16_grouped_dynamic);
 #ifndef USE_ROCM
   m.impl("i8i8bf16", i8i8bf16);
   m.impl("f8f8bf16", f8f8bf16);
@@ -273,6 +279,7 @@ TORCH_LIBRARY_IMPL(fbgemm, CPU, m) {
   m.impl("quantize_fp8_per_row", quantize_fp8_per_row);
   m.impl("quantize_fp8_per_col", quantize_fp8_per_col);
   m.impl("bf16bf16bf16_grouped", bf16bf16bf16_grouped);
+  m.impl("bf16bf16bf16_grouped_dynamic", bf16bf16bf16_grouped_dynamic);
 #ifndef USE_ROCM
   m.impl("i8i8bf16", i8i8bf16);
   m.impl("f8f8bf16", f8f8bf16);
@@ -474,7 +481,6 @@ std::vector<at::Tensor> f8f8bf16_grouped_meta(
 std::vector<at::Tensor> bf16bf16bf16_grouped_meta(
     at::TensorList X,
     at::TensorList W,
-    std::optional<at::Tensor> /* zero_start_index_M = std::nullopt */,
     std::optional<std::vector<at::Tensor>> /* output = std::nullopt */
 ) {
   std::vector<at::Tensor> Y;
@@ -486,6 +492,17 @@ std::vector<at::Tensor> bf16bf16bf16_grouped_meta(
   return Y;
 }

+at::Tensor bf16bf16bf16_grouped_dynamic_meta(
+    at::TensorList X,
+    at::TensorList W,
+    at::Tensor /* zero_start_index_M = std::nullopt */) {
+  int G = X.size();
+  int M = X[0].size(0);
+  int N = W[0].size(0);
+  at::Tensor Y = at::empty({G, M, N}, X[0].options().dtype(at::kBFloat16));
+  return Y;
+}
+
 TORCH_LIBRARY_IMPL(fbgemm, Meta, m) {
   m.impl("f8f8bf16_blockwise", f8f8bf16_blockwise_meta);
   m.impl("f8f8bf16_tensorwise", f8f8bf16_tensorwise_meta);
@@ -495,6 +512,7 @@ TORCH_LIBRARY_IMPL(fbgemm, Meta, m) {
   m.impl("quantize_fp8_per_row", quantize_fp8_per_row_meta);
   m.impl("quantize_fp8_per_col", quantize_fp8_per_col_meta);
   m.impl("bf16bf16bf16_grouped", bf16bf16bf16_grouped_meta);
+  m.impl("bf16bf16bf16_grouped_dynamic", bf16bf16bf16_grouped_dynamic_meta);
 #ifndef USE_ROCM
   m.impl("i8i8bf16", i8i8bf16_meta);
   m.impl("f8f8bf16", f8f8bf16_meta);