Split bf16 grouped gemm into dynamic and static versions. (#3544)
Summary:

X-link: facebookresearch/FBGEMM#629

This diff splits bf16 grouped gemm into static and dynamic variants, in the same way we did for fp8 in D67810956. Specializing the two paths allows a bit more performance, depending on the use case.

Reviewed By: jianyuh, jiawenliu64

Differential Revision: D67813627
jwfromm authored and facebook-github-bot committed Jan 5, 2025
1 parent fd04e6a commit a368b9a
Showing 5 changed files with 154 additions and 108 deletions.
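
For orientation, here is a minimal usage sketch of the two variants whose schemas are registered below. It is illustrative only: the module import path and all shapes are assumptions on my part, not taken from this diff, and it requires a CUDA build of the gen_ai extension.

import torch
import fbgemm_gpu.experimental.gen_ai  # noqa: F401 -- assumed import path that loads the custom ops

G, K, N = 4, 256, 128
# Static variant: per-group M may differ; returns one [M_i, N] output per group.
X = [torch.randn(64 * (i + 1), K, device="cuda", dtype=torch.bfloat16) for i in range(G)]
W = [torch.randn(N, K, device="cuda", dtype=torch.bfloat16) for _ in range(G)]
Y_list = torch.ops.fbgemm.bf16bf16bf16_grouped(X, W)

# Dynamic variant: every group is padded to the same M; zero_start_index_M holds the
# number of valid rows per group, and the result is a single contiguous [G, M, N]
# tensor whose padded rows are zero-filled.
M = 256
Xd = [torch.randn(M, K, device="cuda", dtype=torch.bfloat16) for _ in range(G)]
m_values = torch.tensor([64, 128, 256, 32], dtype=torch.int64, device="cuda")
Y_full = torch.ops.fbgemm.bf16bf16bf16_grouped_dynamic(Xd, W, m_values)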
29 changes: 8 additions & 21 deletions fbgemm_gpu/experimental/gen_ai/bench/quantize_ops.py
@@ -632,23 +632,15 @@ def quantize_fixed_nk(self, x, w):
# Stack inputs into groups.
x = torch.stack(xp).contiguous()
w = torch.stack(w).contiguous()
# Allocate output tensor.
output = torch.empty(
[x.shape[0], x.shape[1], w.shape[1]],
dtype=torch.bfloat16,
device=x.device,
)
# View these unified tensors as lists of tensors.
x = [xi.squeeze() for xi in x.split(1, dim=0)]
w = [wi.squeeze() for wi in w.split(1, dim=0)]
output = [o.squeeze() for o in output.split(1, dim=0)]

# Return processed tensors.
return (
x,
w,
torch.tensor(m_values).to(dtype=torch.int64, device=x[0].device),
output,
)

def quantize(self, x, w):
@@ -664,23 +656,18 @@ def quantize(self, x, w):
if len(np.unique(n_values)) == 1 and len(np.unique(k_values)) == 1:
return self.quantize_fixed_nk(x, w)

output = [
torch.empty(m, n, device=x[0].device, dtype=torch.bfloat16)
for m, n in zip(m_values, n_values)
]
m_values = None
return x, w, m_values, output
return x, w, m_values

def compute(self, x, w, m_values, _):
return torch.ops.fbgemm.bf16bf16bf16_grouped(
x,
w,
zero_start_index_M=m_values,
)
def compute(self, x, w, m_values):
if m_values is None:
return torch.ops.fbgemm.bf16bf16bf16_grouped(x, w)
else:
return torch.ops.fbgemm.bf16bf16bf16_grouped_dynamic(x, w, m_values)

def quantize_and_compute(self, x, w):
x, w, m_values, output = self.quantize(x, w)
return self.compute(x, w, m_values, output)
x, w, m_values = self.quantize(x, w)
return self.compute(x, w, m_values)

@property
def name(self) -> str:
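The fixed-NK path above stacks already-padded activations (xp, built outside this hunk) and forwards m_values as zero_start_index_M. As a rough sketch of that preparation, assuming only torch (the helper below is mine, not part of the benchmark):

import torch

def pad_for_dynamic(x_list):
    # Pad each [m_i, K] activation up to the largest m_i so every group shares one
    # shape, and record the valid row counts the dynamic op expects.
    m_values = [x.shape[0] for x in x_list]
    max_m = max(m_values)
    padded = [torch.nn.functional.pad(x, (0, 0, 0, max_m - x.shape[0])) for x in x_list]
    stacked = torch.stack(padded).contiguous()
    groups = [xi.squeeze(0) for xi in stacked.split(1, dim=0)]
    zero_start_index_M = torch.tensor(m_values, dtype=torch.int64, device=x_list[0].device)
    return groups, zero_start_index_M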
@@ -290,7 +290,6 @@ at::Tensor get_grouped_kernel_args(
std::vector<at::Tensor> bf16bf16bf16_grouped(
at::TensorList A,
at::TensorList B,
std::optional<at::Tensor> zero_start_index_M = std::nullopt,
std::optional<std::vector<at::Tensor>> output = std::nullopt) {
// Check that input datatypes are valid.
// First confirm that there are the same number of groups in all inputs.
@@ -333,30 +332,16 @@ std::vector<at::Tensor> bf16bf16bf16_grouped(
Y[i].dtype() == at::kBFloat16, "Output dtype must be bfloat16.");
}
} else {
// Two modes for allocating output. When m_values is provided, we need
// the output tensor to be contiguous and can assume M, N, and K are the
// same across groups. Otherwise, we can allocate each output separately.
if (zero_start_index_M.has_value()) {
int M = A[0].size(0);
int N = B[0].size(0);
// Fill output with zeros to simplify integration. This prevents nans from
// showing up in the tensor.
at::Tensor Y_full =
at::zeros({group_count, M, N}, A[0].options().dtype(at::kBFloat16));
// Split the output into groups.
Y = at::unbind(Y_full, 0);
} else {
for (int i = 0; i < group_count; i++) {
int M = A[i].size(0);
int N = B[i].size(0);
Y.push_back(at::empty({M, N}, A[i].options().dtype(at::kBFloat16)));
}
for (int i = 0; i < group_count; i++) {
int M = A[i].size(0);
int N = B[i].size(0);
Y.push_back(at::empty({M, N}, A[i].options().dtype(at::kBFloat16)));
}
}

// Prepare kernel arguments by copying them to the proper device location.
at::Tensor kernel_args = get_grouped_kernel_args(
A, B, zero_start_index_M, Y);
A, B, std::nullopt, Y);

// Perform shape lookup to find best kernel.
// We use the largest of each shape for heuristics.
@@ -373,4 +358,62 @@ std::vector<at::Tensor> bf16bf16bf16_grouped(
return selected_kernel(A, B, kernel_args, Y);
}

at::Tensor bf16bf16bf16_grouped_dynamic(
at::TensorList A,
at::TensorList B,
at::Tensor zero_start_index_M) {
// Check that input datatypes are valid.
// First confirm that there are the same number of groups in all inputs.
TORCH_CHECK(
A.size() == B.size(),
"A and B must have the same number of groups.");
int group_count = A.size();
// Iterate over inputs and check they are valid.
for (at::Tensor a : A) {
TORCH_CHECK(a.is_cuda() && a.is_contiguous());
TORCH_CHECK(a.dim() == 2, "Inputs must be 2D.");
TORCH_CHECK(
a.dtype() == at::kBFloat16,
"Inputs must be type bfloat16.");
}
for (at::Tensor b : B) {
TORCH_CHECK(b.is_cuda() && b.is_contiguous());
TORCH_CHECK(b.dim() == 2, "Inputs must be 2D.");
TORCH_CHECK(
b.dtype() == at::kBFloat16,
"Inputs must be type bfloat16.");
}

std::vector<at::Tensor> Y;
int M = A[0].size(0);
int N = B[0].size(0);
// Fill output with zeros to simplify integration. This prevents nans from
// showing up in the tensor.
at::Tensor Y_full =
at::zeros({group_count, M, N}, A[0].options().dtype(at::kBFloat16));
// Split the output into groups.
Y = at::unbind(Y_full, 0);

// Prepare kernel arguments by copying them to the proper device location.
at::Tensor kernel_args = get_grouped_kernel_args(
A, B, zero_start_index_M, Y);

// Perform shape lookup to find best kernel.
// We use the largest of each shape for heuristics.
int MaxM = 0;
int MaxN = 0;
int MaxK = 0;
for (int i = 0; i < group_count; i++) {
MaxM = max(MaxM, A[i].size(0));
MaxN = max(MaxN, B[i].size(0));
MaxK = max(MaxK, A[i].size(1));
}
GroupedKernel selected_kernel =
grouped_heuristic_dispatch(MaxM, MaxN, MaxK);
// Run kernel to populate output.
selected_kernel(A, B, kernel_args, Y);
// Return coalesced view of output tensor.
return Y_full;
}

} // namespace fbgemm_gpu
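
To summarize the allocation split introduced above: the static entry point allocates one independent [M_i, N_i] output per group, while the dynamic entry point zero-fills one contiguous [G, M, N] buffer and hands the kernel views into it. A torch-level sketch of that difference (illustrative only, not the actual C++ implementation):

import torch

def alloc_static(A, B):
    # One independently sized [M_i, N_i] output per group.
    return [
        torch.empty(a.shape[0], b.shape[0], dtype=torch.bfloat16, device=a.device)
        for a, b in zip(A, B)
    ]

def alloc_dynamic(A, B):
    # One contiguous zero-filled [G, M, N] buffer, split into per-group views.
    # Zero-filling matters because rows past each group's zero_start_index_M are
    # never written by the kernel and would otherwise hold uninitialized values/NaNs.
    g, m, n = len(A), A[0].shape[0], B[0].shape[0]
    y_full = torch.zeros(g, m, n, dtype=torch.bfloat16, device=A[0].device)
    return y_full, list(torch.unbind(y_full, 0))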
@@ -261,6 +261,7 @@ template <
std::vector<at::Tensor> bf16bf16bf16_grouped_impl(
at::TensorList X, // BF16
at::TensorList W, // BF16
std::vector<at::Tensor> output_tensor,
std::optional<at::Tensor> zero_start_index_M) {
int problem_count = X.size();
TORCH_CHECK(W.size() == problem_count);
@@ -273,27 +274,6 @@ std::vector<at::Tensor> bf16bf16bf16_grouped_impl(
using GroupedGemmConfigs = GroupedGemmBF16Args::
GroupedGemmConfigs<TB_M, TB_N, TB_K, TBS_M, TBS_N, TBS_K, PONG>;

constexpr int AlignmentA =
128 /
cutlass::sizeof_bits<
GroupedGemmBF16Args::ElementInputA>::value; // Alignment of A matrix
// in units of elements
// (up to 16 bytes)

constexpr int AlignmentB =
128 /
cutlass::sizeof_bits<
GroupedGemmBF16Args::ElementInputB>::value; // Alignment of B matrix
// in units of elements
// (up to 16 bytes)

constexpr int AlignmentD =
128 /
cutlass::sizeof_bits<
GroupedGemmBF16Args::ElementOutput>::value; // Alignment of C matrix
// in units of elements
// (up to 16 bytes)

at::Tensor output_args =
at::empty({problem_count}, X[0].options().dtype(at::kLong));

@@ -313,28 +293,6 @@ std::vector<at::Tensor> bf16bf16bf16_grouped_impl(
int stride_buf_offset =
problem_count * 2 * sizeof(int64_t) + problem_shape_size;

// Two modes for allocating output. When m_values is provided, we need
// the output tensor to be contiguous and can assume M, N, and K are the
// same across groups. Otherwise, we can allocate each output separately.
std::vector<at::Tensor> output_tensor;
if (zero_start_index_M.has_value()) {
int M = X[0].size(0);
int N = W[0].size(0);
// Fill output with zeros to simplify integration. This prevents nans from
// showing up in the tensor.
at::Tensor output_full =
at::zeros({problem_count, M, N}, X[0].options().dtype(at::kBFloat16));
// Split the output into groups.
output_tensor = at::unbind(output_full, 0);
} else {
for (int i = 0; i < problem_count; i++) {
int M = X[i].size(0);
int N = W[i].size(0);
output_tensor.push_back(
at::empty({M, N}, X[i].options().dtype(at::kBFloat16)));
}
}

TORCH_CHECK(
!zero_start_index_M.has_value() ||
zero_start_index_M->dtype() == at::kLong,
@@ -505,40 +463,78 @@ std::vector<at::Tensor> bf16bf16bf16_grouped_impl(
std::vector<at::Tensor> dispatch_bf16_grouped_kernel(
at::TensorList x_group, // BF16
at::TensorList w_group, // BF16
std::vector<at::Tensor> output_tensor,
std::optional<at::Tensor> zero_start_index_M) {
KernelMode kernel = get_grouped_kernel_mode(x_group, w_group);
if (kernel == KernelMode::Small) {
return bf16bf16bf16_grouped_impl<64, 128, 128, 2, 1, 1, true>(
x_group, w_group, zero_start_index_M);
x_group, w_group, output_tensor, zero_start_index_M);
} else if (kernel == KernelMode::Large) {
return bf16bf16bf16_grouped_impl<128, 128, 128, 2, 1, 1, true>(
x_group, w_group, zero_start_index_M);
x_group, w_group, output_tensor, zero_start_index_M);
} else {
return bf16bf16bf16_grouped_impl<128, 128, 128, 1, 2, 1, true>(
x_group, w_group, zero_start_index_M);
x_group, w_group, output_tensor, zero_start_index_M);
}
}

std::vector<at::Tensor> bf16bf16bf16_grouped(
at::TensorList x_group, // BF16
at::TensorList w_group, // BF16
std::optional<at::Tensor> zero_start_index_M = std::nullopt,
std::optional<std::vector<at::Tensor>> output = std::nullopt) {
TORCH_CHECK(!output.has_value(), "Preallocated output not yet supported.");
return dispatch_bf16_grouped_kernel(x_group, w_group, zero_start_index_M);
// Initialize output tensor.
int problem_count = x_group.size();
std::vector<at::Tensor> output_tensor;
for (int i = 0; i < problem_count; i++) {
int M = x_group[i].size(0);
int N = w_group[i].size(0);
output_tensor.push_back(
at::empty({M, N}, x_group[i].options().dtype(at::kBFloat16)));
}
return dispatch_bf16_grouped_kernel(
x_group, w_group, output_tensor, std::nullopt);
}

at::Tensor bf16bf16bf16_grouped_dynamic(
at::TensorList x_group, // BF16
at::TensorList w_group, // BF16
at::Tensor zero_start_index_M) {
std::vector<at::Tensor> output_tensor;
int problem_count = x_group.size();
int M = x_group[0].size(0);
int N = w_group[0].size(0);
// Fill output with zeros to simplify integration. This prevents nans from
// showing up in the tensor.
at::Tensor output_full = at::zeros(
{problem_count, M, N}, x_group[0].options().dtype(at::kBFloat16));
// Split the output into groups.
output_tensor = at::unbind(output_full, 0);
// Run kernel to populate output tensor.
dispatch_bf16_grouped_kernel(
x_group, w_group, output_tensor, zero_start_index_M);
// Return coalesced view of output.
return output_full;
}

#else

std::vector<at::Tensor> bf16bf16bf16_grouped(
at::TensorList /* x_group */, // BF16
at::TensorList /* w_group */, // BF16
std::optional<at::Tensor> /* zero_start_index_M */,
std::optional<std::vector<at::Tensor>> /* output */) {
throw std::runtime_error(
"CUDA version is older than 12.0"); // requires CUDA>=12
}

at::Tensor bf16bf16bf16_grouped_dynamic(
at::TensorList /* x_group */, // BF16
at::TensorList /* w_group */, // BF16
at::Tensor /* zero_start_index_M */) {
throw std::runtime_error(
"CUDA version is older than 12.0"); // requires CUDA>=12
}

#endif

} // namespace fbgemm_gpu
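
Both backends pick a kernel from the largest shapes across the groups (grouped_heuristic_dispatch in the file above, get_grouped_kernel_mode here). A rough Python sketch of that style of selection follows; the real thresholds are not part of this diff, so the cutoffs below are placeholders only.

def pick_tile_config(A, B, small_cutoff=128, large_cutoff=2048):
    # Illustrative heuristic only: compute Max{M,N,K} over the groups and map the
    # result onto the three bf16bf16bf16_grouped_impl instantiations dispatched above.
    max_m = max(a.shape[0] for a in A)
    max_n = max(b.shape[0] for b in B)
    max_k = max(a.shape[1] for a in A)
    largest = max(max_m, max_n, max_k)
    if largest <= small_cutoff:
        return (64, 128, 128, 2, 1, 1)   # "Small" tile / cluster shape
    if largest >= large_cutoff:
        return (128, 128, 128, 2, 1, 1)  # "Large"
    return (128, 128, 128, 1, 2, 1)      # default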
24 changes: 21 additions & 3 deletions fbgemm_gpu/experimental/gen_ai/src/quantize/quantize.cpp
@@ -64,8 +64,11 @@ std::vector<at::Tensor> f8f8bf16_grouped(
std::vector<at::Tensor> bf16bf16bf16_grouped(
at::TensorList X,
at::TensorList W,
std::optional<at::Tensor> zero_start_index_M = std::nullopt,
std::optional<std::vector<at::Tensor>> output = std::nullopt);
at::Tensor bf16bf16bf16_grouped_dynamic(
at::TensorList X,
at::TensorList W,
at::Tensor zero_start_index_M);
at::Tensor f8f8bf16_rowwise(
at::Tensor XQ,
at::Tensor WQ,
@@ -195,7 +198,9 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
get_f8f8bf16_rowwise_grouped_kernels);
#endif
m.def(
"bf16bf16bf16_grouped(Tensor[] X, Tensor[] W, Tensor? zero_start_index_M=None, Tensor[](a!)? output=None) -> Tensor[]");
"bf16bf16bf16_grouped(Tensor[] X, Tensor[] W, Tensor[](a!)? output=None) -> Tensor[]");
m.def(
"bf16bf16bf16_grouped_dynamic(Tensor[] X, Tensor[] W, Tensor zero_start_index_M) -> Tensor");
m.def(
"f8f8bf16_blockwise(Tensor XQ, Tensor WQ, Tensor x_scale, Tensor w_scale, int block_m=128, int block_n=128, int block_k=128) -> Tensor");
m.def(
@@ -248,6 +253,7 @@ TORCH_LIBRARY_IMPL(fbgemm, CUDA, m) {
m.impl("quantize_fp8_per_row", quantize_fp8_per_row);
m.impl("quantize_fp8_per_col", quantize_fp8_per_col);
m.impl("bf16bf16bf16_grouped", bf16bf16bf16_grouped);
m.impl("bf16bf16bf16_grouped_dynamic", bf16bf16bf16_grouped_dynamic);
#ifndef USE_ROCM
m.impl("i8i8bf16", i8i8bf16);
m.impl("f8f8bf16", f8f8bf16);
@@ -273,6 +279,7 @@ TORCH_LIBRARY_IMPL(fbgemm, CPU, m) {
m.impl("quantize_fp8_per_row", quantize_fp8_per_row);
m.impl("quantize_fp8_per_col", quantize_fp8_per_col);
m.impl("bf16bf16bf16_grouped", bf16bf16bf16_grouped);
m.impl("bf16bf16bf16_grouped_dyanmic", bf16bf16bf16_grouped_dynamic);
#ifndef USE_ROCM
m.impl("i8i8bf16", i8i8bf16);
m.impl("f8f8bf16", f8f8bf16);
@@ -474,7 +481,6 @@ std::vector<at::Tensor> bf16bf16bf16_grouped_meta(
std::vector<at::Tensor> bf16bf16bf16_grouped_meta(
at::TensorList X,
at::TensorList W,
std::optional<at::Tensor> /* zero_start_index_M = std::nullopt */,
std::optional<std::vector<at::Tensor>> /* output = std::nullopt */
) {
std::vector<at::Tensor> Y;
Expand All @@ -486,6 +492,17 @@ std::vector<at::Tensor> bf16bf16bf16_grouped_meta(
return Y;
}

at::Tensor bf16bf16bf16_grouped_dynamic_meta(
at::TensorList X,
at::TensorList W,
at::Tensor /* zero_start_index_M */) {
int G = X.size();
int M = X[0].size(0);
int N = W[0].size(0);
at::Tensor Y = at::empty({G, M, N}, X[0].options().dtype(at::kBFloat16));
return Y;
}

TORCH_LIBRARY_IMPL(fbgemm, Meta, m) {
m.impl("f8f8bf16_blockwise", f8f8bf16_blockwise_meta);
m.impl("f8f8bf16_tensorwise", f8f8bf16_tensorwise_meta);
@@ -495,6 +512,7 @@ TORCH_LIBRARY_IMPL(fbgemm, Meta, m) {
m.impl("quantize_fp8_per_row", quantize_fp8_per_row_meta);
m.impl("quantize_fp8_per_col", quantize_fp8_per_col_meta);
m.impl("bf16bf16bf16_grouped", bf16bf16bf16_grouped_meta);
m.impl("bf16bf16bf16_grouped_dynamic", bf16bf16bf16_grouped_dynamic_meta);
#ifndef USE_ROCM
m.impl("i8i8bf16", i8i8bf16_meta);
m.impl("f8f8bf16", f8f8bf16_meta);
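
The Meta registrations above drive shape inference under tracing and torch.compile. A hedged check of what they imply, assuming the ops are loaded (for example via the assumed import fbgemm_gpu.experimental.gen_ai) and called with meta-device tensors; shapes are illustrative:

import torch

G, M, N, K = 4, 64, 128, 256
X = [torch.empty(M, K, dtype=torch.bfloat16, device="meta") for _ in range(G)]
W = [torch.empty(N, K, dtype=torch.bfloat16, device="meta") for _ in range(G)]
Y_list = torch.ops.fbgemm.bf16bf16bf16_grouped(X, W)
assert [tuple(y.shape) for y in Y_list] == [(M, N)] * G

m_values = torch.empty(G, dtype=torch.int64, device="meta")
Y_full = torch.ops.fbgemm.bf16bf16bf16_grouped_dynamic(X, W, m_values)
assert tuple(Y_full.shape) == (G, M, N)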