diff --git a/fbgemm_gpu/experimental/gen_ai/bench/quantize_ops.py b/fbgemm_gpu/experimental/gen_ai/bench/quantize_ops.py
index bc99b0e1c..fad22b163 100644
--- a/fbgemm_gpu/experimental/gen_ai/bench/quantize_ops.py
+++ b/fbgemm_gpu/experimental/gen_ai/bench/quantize_ops.py
@@ -632,23 +632,15 @@ def quantize_fixed_nk(self, x, w):
         # Stack inputs into groups.
         x = torch.stack(xp).contiguous()
         w = torch.stack(w).contiguous()
-        # Allocate output tensor.
-        output = torch.empty(
-            [x.shape[0], x.shape[1], w.shape[1]],
-            dtype=torch.bfloat16,
-            device=x.device,
-        )
         # View these unified tensors as lists of tensors.
         x = [xi.squeeze() for xi in x.split(1, dim=0)]
         w = [wi.squeeze() for wi in w.split(1, dim=0)]
-        output = [o.squeeze() for o in output.split(1, dim=0)]
         # Return processed tensors.
         return (
             x,
             w,
             torch.tensor(m_values).to(dtype=torch.int64, device=x[0].device),
-            output,
         )
 
     def quantize(self, x, w):
@@ -664,23 +656,18 @@ def quantize(self, x, w):
 
         if len(np.unique(n_values)) == 1 and len(np.unique(k_values)) == 1:
             return self.quantize_fixed_nk(x, w)
 
-        output = [
-            torch.empty(m, n, device=x[0].device, dtype=torch.bfloat16)
-            for m, n in zip(m_values, n_values)
-        ]
         m_values = None
-        return x, w, m_values, output
+        return x, w, m_values
 
-    def compute(self, x, w, m_values, _):
-        return torch.ops.fbgemm.bf16bf16bf16_grouped(
-            x,
-            w,
-            zero_start_index_M=m_values,
-        )
+    def compute(self, x, w, m_values):
+        if m_values is None:
+            return torch.ops.fbgemm.bf16bf16bf16_grouped(x, w)
+        else:
+            return torch.ops.fbgemm.bf16bf16bf16_grouped_dynamic(x, w, m_values)
 
     def quantize_and_compute(self, x, w):
-        x, w, m_values, output = self.quantize(x, w)
-        return self.compute(x, w, m_values, output)
+        x, w, m_values = self.quantize(x, w)
+        return self.compute(x, w, m_values)
 
     @property
     def name(self) -> str:
diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/bf16_grouped/bf16_grouped_gemm.hip b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/bf16_grouped/bf16_grouped_gemm.hip
index 8e5a18920..54a60a009 100644
--- a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/bf16_grouped/bf16_grouped_gemm.hip
+++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/bf16_grouped/bf16_grouped_gemm.hip
@@ -290,7 +290,6 @@ at::Tensor get_grouped_kernel_args(
 std::vector<at::Tensor> bf16bf16bf16_grouped(
     at::TensorList A,
     at::TensorList B,
-    std::optional<at::Tensor> zero_start_index_M = std::nullopt,
     std::optional<std::vector<at::Tensor>> output = std::nullopt) {
   // Check that input datatypes are valid.
   // First confirm that there are the same number of groups in all inputs.
@@ -333,30 +332,16 @@ std::vector<at::Tensor> bf16bf16bf16_grouped(
           Y[i].dtype() == at::kBFloat16, "Output dtype must be bfloat16.");
     }
   } else {
-    // Two modes for allocating output. When m_values is provided, we need
-    // the output tensor to be contiguous and can assume M, N, and K are the
-    // same across groups. Otherwise, we can allocate each output separately.
-    if (zero_start_index_M.has_value()) {
-      int M = A[0].size(0);
-      int N = B[0].size(0);
-      // Fill output with zeros to simplify integration. This prevents nans from
-      // showing up in the tensor.
-      at::Tensor Y_full =
-          at::zeros({group_count, M, N}, A[0].options().dtype(at::kBFloat16));
-      // Split the output into groups.
-      Y = at::unbind(Y_full, 0);
-    } else {
-      for (int i = 0; i < group_count; i++) {
-        int M = A[i].size(0);
-        int N = B[i].size(0);
-        Y.push_back(at::empty({M, N}, A[i].options().dtype(at::kBFloat16)));
-      }
+    for (int i = 0; i < group_count; i++) {
+      int M = A[i].size(0);
+      int N = B[i].size(0);
+      Y.push_back(at::empty({M, N}, A[i].options().dtype(at::kBFloat16)));
     }
   }
 
   // Prepare kernel arguments by copying them to the proper device location.
   at::Tensor kernel_args = get_grouped_kernel_args(
-      A, B, zero_start_index_M, Y);
+      A, B, std::nullopt, Y);
 
   // Perform shape lookup to find best kernel.
   // We use the largest of each shape for heuristics.
@@ -373,4 +358,62 @@ std::vector<at::Tensor> bf16bf16bf16_grouped(
   return selected_kernel(A, B, kernel_args, Y);
 }
 
+at::Tensor bf16bf16bf16_grouped_dynamic(
+    at::TensorList A,
+    at::TensorList B,
+    at::Tensor zero_start_index_M) {
+  // Check that input datatypes are valid.
+  // First confirm that there are the same number of groups in all inputs.
+  TORCH_CHECK(
+      A.size() == B.size(),
+      "A and B must have the same number of groups.");
+  int group_count = A.size();
+  // Iterate over inputs and check they are valid.
+  for (at::Tensor a : A) {
+    TORCH_CHECK(a.is_cuda() && a.is_contiguous());
+    TORCH_CHECK(a.dim() == 2, "Inputs must be 2D.");
+    TORCH_CHECK(
+        a.dtype() == at::kBFloat16,
+        "Inputs must be type bfloat16.");
+  }
+  for (at::Tensor b : B) {
+    TORCH_CHECK(b.is_cuda() && b.is_contiguous());
+    TORCH_CHECK(b.dim() == 2, "Inputs must be 2D.");
+    TORCH_CHECK(
+        b.dtype() == at::kBFloat16,
+        "Inputs must be type bfloat16.");
+  }
+
+  std::vector<at::Tensor> Y;
+  int M = A[0].size(0);
+  int N = B[0].size(0);
+  // Fill output with zeros to simplify integration. This prevents nans from
+  // showing up in the tensor.
+  at::Tensor Y_full =
+      at::zeros({group_count, M, N}, A[0].options().dtype(at::kBFloat16));
+  // Split the output into groups.
+  Y = at::unbind(Y_full, 0);
+
+  // Prepare kernel arguments by copying them to the proper device location.
+  at::Tensor kernel_args = get_grouped_kernel_args(
+      A, B, zero_start_index_M, Y);
+
+  // Perform shape lookup to find best kernel.
+  // We use the largest of each shape for heuristics.
+  int MaxM = 0;
+  int MaxN = 0;
+  int MaxK = 0;
+  for (int i = 0; i < group_count; i++) {
+    MaxM = max(MaxM, A[i].size(0));
+    MaxN = max(MaxN, B[i].size(0));
+    MaxK = max(MaxK, A[i].size(1));
+  }
+  GroupedKernel selected_kernel =
+      grouped_heuristic_dispatch(MaxM, MaxN, MaxK);
+  // Run kernel to populate output.
+  selected_kernel(A, B, kernel_args, Y);
+  // Return coalesced view of output tensor.
+  return Y_full;
+}
+
 } // namespace fbgemm_gpu
diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped.cu b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped.cu
index 803681b14..ab38fb548 100644
--- a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped.cu
+++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped.cu
@@ -261,6 +261,7 @@ template <
 std::vector<at::Tensor> bf16bf16bf16_grouped_impl(
     at::TensorList X, // BF16
     at::TensorList W, // BF16
+    std::vector<at::Tensor> output_tensor,
     std::optional<at::Tensor> zero_start_index_M) {
   int problem_count = X.size();
   TORCH_CHECK(W.size() == problem_count);
@@ -273,27 +274,6 @@ std::vector<at::Tensor> bf16bf16bf16_grouped_impl(
   using GroupedGemmConfigs = GroupedGemmBF16Args::
       GroupedGemmConfigs<TB_M, TB_N, TB_K, TBS_M, TBS_N, TBS_K, PONG>;
 
-  constexpr int AlignmentA =
-      128 /
-      cutlass::sizeof_bits<
-          GroupedGemmBF16Args::ElementInputA>::value; // Alignment of A matrix
-                                                      // in units of elements
-                                                      // (up to 16 bytes)
-
-  constexpr int AlignmentB =
-      128 /
-      cutlass::sizeof_bits<
-          GroupedGemmBF16Args::ElementInputB>::value; // Alignment of B matrix
-                                                      // in units of elements
-                                                      // (up to 16 bytes)
-
-  constexpr int AlignmentD =
-      128 /
-      cutlass::sizeof_bits<
-          GroupedGemmBF16Args::ElementOutput>::value; // Alignment of C matrix
-                                                      // in units of elements
-                                                      // (up to 16 bytes)
-
   at::Tensor output_args =
       at::empty({problem_count}, X[0].options().dtype(at::kLong));
@@ -313,28 +293,6 @@ std::vector<at::Tensor> bf16bf16bf16_grouped_impl(
   int stride_buf_offset =
       problem_count * 2 * sizeof(int64_t) + problem_shape_size;
 
-  // Two modes for allocating output. When m_values is provided, we need
-  // the output tensor to be contiguous and can assume M, N, and K are the
-  // same across groups. Otherwise, we can allocate each output separately.
-  std::vector<at::Tensor> output_tensor;
-  if (zero_start_index_M.has_value()) {
-    int M = X[0].size(0);
-    int N = W[0].size(0);
-    // Fill output with zeros to simplify integration. This prevents nans from
-    // showing up in the tensor.
-    at::Tensor output_full =
-        at::zeros({problem_count, M, N}, X[0].options().dtype(at::kBFloat16));
-    // Split the output into groups.
-    output_tensor = at::unbind(output_full, 0);
-  } else {
-    for (int i = 0; i < problem_count; i++) {
-      int M = X[i].size(0);
-      int N = W[i].size(0);
-      output_tensor.push_back(
-          at::empty({M, N}, X[i].options().dtype(at::kBFloat16)));
-    }
-  }
-
   TORCH_CHECK(
       !zero_start_index_M.has_value() ||
           zero_start_index_M->dtype() == at::kLong,
@@ -505,27 +463,58 @@ std::vector<at::Tensor> bf16bf16bf16_grouped_impl(
 std::vector<at::Tensor> dispatch_bf16_grouped_kernel(
     at::TensorList x_group, // BF16
     at::TensorList w_group, // BF16
+    std::vector<at::Tensor> output_tensor,
     std::optional<at::Tensor> zero_start_index_M) {
   KernelMode kernel = get_grouped_kernel_mode(x_group, w_group);
   if (kernel == KernelMode::Small) {
     return bf16bf16bf16_grouped_impl<64, 128, 128, 2, 1, 1, true>(
-        x_group, w_group, zero_start_index_M);
+        x_group, w_group, output_tensor, zero_start_index_M);
   } else if (kernel == KernelMode::Large) {
     return bf16bf16bf16_grouped_impl<128, 128, 128, 2, 1, 1, true>(
-        x_group, w_group, zero_start_index_M);
+        x_group, w_group, output_tensor, zero_start_index_M);
   } else {
     return bf16bf16bf16_grouped_impl<128, 128, 128, 1, 2, 1, true>(
-        x_group, w_group, zero_start_index_M);
+        x_group, w_group, output_tensor, zero_start_index_M);
   }
 }
 
 std::vector<at::Tensor> bf16bf16bf16_grouped(
     at::TensorList x_group, // BF16
     at::TensorList w_group, // BF16
-    std::optional<at::Tensor> zero_start_index_M = std::nullopt,
     std::optional<std::vector<at::Tensor>> output = std::nullopt) {
   TORCH_CHECK(!output.has_value(), "Preallocated output not yet supported.");
-  return dispatch_bf16_grouped_kernel(x_group, w_group, zero_start_index_M);
+  // Initialize output tensor.
+  int problem_count = x_group.size();
+  std::vector<at::Tensor> output_tensor;
+  for (int i = 0; i < problem_count; i++) {
+    int M = x_group[i].size(0);
+    int N = w_group[i].size(0);
+    output_tensor.push_back(
+        at::empty({M, N}, x_group[i].options().dtype(at::kBFloat16)));
+  }
+  return dispatch_bf16_grouped_kernel(
+      x_group, w_group, output_tensor, std::nullopt);
+}
+
+at::Tensor bf16bf16bf16_grouped_dynamic(
+    at::TensorList x_group, // BF16
+    at::TensorList w_group, // BF16
+    at::Tensor zero_start_index_M) {
+  std::vector<at::Tensor> output_tensor;
+  int problem_count = x_group.size();
+  int M = x_group[0].size(0);
+  int N = w_group[0].size(0);
+  // Fill output with zeros to simplify integration. This prevents nans from
+  // showing up in the tensor.
+  at::Tensor output_full = at::zeros(
+      {problem_count, M, N}, x_group[0].options().dtype(at::kBFloat16));
+  // Split the output into groups.
+  output_tensor = at::unbind(output_full, 0);
+  // Run kernel to populate output tensor.
+  dispatch_bf16_grouped_kernel(
+      x_group, w_group, output_tensor, zero_start_index_M);
+  // Return coalesced view of output.
+  return output_full;
 }
 
 #else
@@ -533,12 +522,19 @@ std::vector<at::Tensor> bf16bf16bf16_grouped(
 std::vector<at::Tensor> bf16bf16bf16_grouped(
     at::TensorList /* x_group */, // BF16
     at::TensorList /* w_group */, // BF16
-    std::optional<at::Tensor> /* zero_start_index_M */,
     std::optional<std::vector<at::Tensor>> /* output */) {
   throw std::runtime_error(
       "CUDA version is older than 12.0"); // requires CUDA>=12
 }
 
+at::Tensor bf16bf16bf16_grouped_dynamic(
+    at::TensorList /* x_group */, // BF16
+    at::TensorList /* w_group */, // BF16
+    at::Tensor /* zero_start_index_M */) {
+  throw std::runtime_error(
+      "CUDA version is older than 12.0"); // requires CUDA>=12
+}
+
 #endif
 
 } // namespace fbgemm_gpu
diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/quantize.cpp b/fbgemm_gpu/experimental/gen_ai/src/quantize/quantize.cpp
index 2039fd091..acea22f3c 100644
--- a/fbgemm_gpu/experimental/gen_ai/src/quantize/quantize.cpp
+++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/quantize.cpp
@@ -64,8 +64,11 @@ std::vector<at::Tensor> f8f8bf16_grouped(
 std::vector<at::Tensor> bf16bf16bf16_grouped(
     at::TensorList X,
     at::TensorList W,
-    std::optional<at::Tensor> zero_start_index_M = std::nullopt,
     std::optional<std::vector<at::Tensor>> output = std::nullopt);
+at::Tensor bf16bf16bf16_grouped_dynamic(
+    at::TensorList X,
+    at::TensorList W,
+    at::Tensor zero_start_index_M);
 at::Tensor f8f8bf16_rowwise(
     at::Tensor XQ,
     at::Tensor WQ,
@@ -195,7 +198,9 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
       get_f8f8bf16_rowwise_grouped_kernels);
 #endif
   m.def(
-      "bf16bf16bf16_grouped(Tensor[] X, Tensor[] W, Tensor? zero_start_index_M=None, Tensor[](a!)? output=None) -> Tensor[]");
+      "bf16bf16bf16_grouped(Tensor[] X, Tensor[] W, Tensor[](a!)? output=None) -> Tensor[]");
+  m.def(
+      "bf16bf16bf16_grouped_dynamic(Tensor[] X, Tensor[] W, Tensor zero_start_index_M) -> Tensor");
   m.def(
       "f8f8bf16_blockwise(Tensor XQ, Tensor WQ, Tensor x_scale, Tensor w_scale, int block_m=128, int block_n=128, int block_k=128) -> Tensor");
   m.def(
@@ -248,6 +253,7 @@ TORCH_LIBRARY_IMPL(fbgemm, CUDA, m) {
   m.impl("quantize_fp8_per_row", quantize_fp8_per_row);
   m.impl("quantize_fp8_per_col", quantize_fp8_per_col);
   m.impl("bf16bf16bf16_grouped", bf16bf16bf16_grouped);
+  m.impl("bf16bf16bf16_grouped_dynamic", bf16bf16bf16_grouped_dynamic);
 #ifndef USE_ROCM
   m.impl("i8i8bf16", i8i8bf16);
   m.impl("f8f8bf16", f8f8bf16);
@@ -273,6 +279,7 @@ TORCH_LIBRARY_IMPL(fbgemm, CPU, m) {
   m.impl("quantize_fp8_per_row", quantize_fp8_per_row);
   m.impl("quantize_fp8_per_col", quantize_fp8_per_col);
   m.impl("bf16bf16bf16_grouped", bf16bf16bf16_grouped);
+  m.impl("bf16bf16bf16_grouped_dynamic", bf16bf16bf16_grouped_dynamic);
 #ifndef USE_ROCM
   m.impl("i8i8bf16", i8i8bf16);
   m.impl("f8f8bf16", f8f8bf16);
@@ -474,7 +481,6 @@ std::vector<at::Tensor> f8f8bf16_grouped_meta(
 std::vector<at::Tensor> bf16bf16bf16_grouped_meta(
     at::TensorList X,
     at::TensorList W,
-    std::optional<at::Tensor> /* zero_start_index_M = std::nullopt */,
     std::optional<std::vector<at::Tensor>> /* output = std::nullopt */
 ) {
   std::vector<at::Tensor> Y;
@@ -486,6 +492,17 @@ std::vector<at::Tensor> bf16bf16bf16_grouped_meta(
   return Y;
 }
 
+at::Tensor bf16bf16bf16_grouped_dynamic_meta(
+    at::TensorList X,
+    at::TensorList W,
+    at::Tensor /* zero_start_index_M */) {
+  int G = X.size();
+  int M = X[0].size(0);
+  int N = W[0].size(0);
+  at::Tensor Y = at::empty({G, M, N}, X[0].options().dtype(at::kBFloat16));
+  return Y;
+}
+
 TORCH_LIBRARY_IMPL(fbgemm, Meta, m) {
   m.impl("f8f8bf16_blockwise", f8f8bf16_blockwise_meta);
   m.impl("f8f8bf16_tensorwise", f8f8bf16_tensorwise_meta);
@@ -495,6 +512,7 @@ TORCH_LIBRARY_IMPL(fbgemm, Meta, m) {
   m.impl("quantize_fp8_per_row", quantize_fp8_per_row_meta);
   m.impl("quantize_fp8_per_col", quantize_fp8_per_col_meta);
   m.impl("bf16bf16bf16_grouped", bf16bf16bf16_grouped_meta);
+  m.impl("bf16bf16bf16_grouped_dynamic", bf16bf16bf16_grouped_dynamic_meta);
 #ifndef USE_ROCM
   m.impl("i8i8bf16", i8i8bf16_meta);
   m.impl("f8f8bf16", f8f8bf16_meta);
diff --git a/fbgemm_gpu/experimental/gen_ai/test/quantize/quantize_test.py b/fbgemm_gpu/experimental/gen_ai/test/quantize/quantize_test.py
index d54e2be0a..e2126b5a0 100644
--- a/fbgemm_gpu/experimental/gen_ai/test/quantize/quantize_test.py
+++ b/fbgemm_gpu/experimental/gen_ai/test/quantize/quantize_test.py
@@ -827,28 +827,30 @@ def test_fp8_grouped_gemm(
         )
 
         # BF16 grouped gemm kernel
+        bf16_args = (
+            [x_group, w_group, zero_start_index_M]
+            if use_padding_zeros
+            else [x_group, w_group]
+        )
+        bf16_op = (
+            torch.ops.fbgemm.bf16bf16bf16_grouped_dynamic
+            if use_padding_zeros
+            else torch.ops.fbgemm.bf16bf16bf16_grouped
+        )
         if use_cudagraph:
             # warmup
-            torch.ops.fbgemm.bf16bf16bf16_grouped(
-                x_group,
-                w_group,
-                zero_start_index_M if use_padding_zeros else None,
-            )
+            bf16_op(*bf16_args)
             # With cudagraph
             g = torch.cuda.CUDAGraph()
             with torch.cuda.graph(g):
-                y_bf16_group = torch.ops.fbgemm.bf16bf16bf16_grouped(
-                    x_group,
-                    w_group,
-                    zero_start_index_M if use_padding_zeros else None,
-                )
+                y_bf16_group = bf16_op(*bf16_args)
             g.replay()
         else:
-            y_bf16_group = torch.ops.fbgemm.bf16bf16bf16_grouped(
-                x_group,
-                w_group,
-                zero_start_index_M if use_padding_zeros else None,
-            )
+            y_bf16_group = bf16_op(*bf16_args)
+
+        # View output as list if needed.
+        if not isinstance(y_bf16_group, (tuple, list)):
+            y_bf16_group = torch.unbind(y_bf16_group)
 
         # BF16 loopover gemm reference
         y_group_ref = []
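
Usage sketch (not part of the patch): based on the schemas registered above, the static op returns one [M, N] output per group, while the new dynamic op takes a zero_start_index_M tensor and returns a single stacked [G, M, N] tensor that can be unbound per group. The shapes, the interpretation of zero_start_index_M as the per-group count of valid rows, and the import that registers the ops are illustrative assumptions, not part of this change.

import torch
import fbgemm_gpu.experimental.gen_ai  # noqa: F401  # assumed entry point that loads the ops

# Illustrative shapes; the dynamic op assumes M, N, and K are the same across groups.
G, M, N, K = 4, 128, 256, 512
x_group = [torch.randn(M, K, device="cuda", dtype=torch.bfloat16) for _ in range(G)]
w_group = [torch.randn(N, K, device="cuda", dtype=torch.bfloat16) for _ in range(G)]

# Static variant: returns a list of per-group [M, N] tensors.
y_list = torch.ops.fbgemm.bf16bf16bf16_grouped(x_group, w_group)

# Dynamic variant: rows at or beyond zero_start_index_M[g] are assumed to be padding;
# the result is one zero-initialized [G, M, N] tensor.
zero_start_index_M = torch.full((G,), M, dtype=torch.int64, device="cuda")
y_full = torch.ops.fbgemm.bf16bf16bf16_grouped_dynamic(
    x_group, w_group, zero_start_index_M
)
y_per_group = torch.unbind(y_full)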