FP8 Grouped Gemm Optimization (pytorch#3655)
Summary:
Pull Request resolved: pytorch#3655

X-link: facebookresearch/FBGEMM#731

While optimizing MoE, we found that small overheads were a major bottleneck for grouped GEMM performance. This diff tackles a few of them, specifically the overhead of torch.dynamo wrapping `quantize_fp8_row` and of having to slice input tensors before calling `f8f8bf16_rowwise_grouped`.

To fix the former, we enable `triton_quantize_fp8_row` to be called directly, skipping the dynamo-compatibility wrapper. In cases where AOTI isn't needed, this removes a bit of overhead.
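
As a rough illustration, a minimal sketch of the two call paths (the import path is inferred from the file touched below; shape, device, and dtype are made up):

import torch

from fbgemm_gpu.experimental.gemm.triton_gemm.fp8_gemm import (
    quantize_fp8_row,  # torch.library custom op, dynamo/AOTI friendly
    triton_quantize_fp8_row,  # direct Triton entry point, no custom-op dispatch
)

x = torch.randn(16, 4096, device="cuda", dtype=torch.bfloat16)

# Eager-only callers can skip the custom-op wrapper and its dispatch overhead.
xq, x_scale = triton_quantize_fp8_row(x)

# Callers that need torch.compile / AOTI compatibility keep using the wrapped op.
xq, x_scale = quantize_fp8_row(x)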

To fix the latter, we templatize `f8f8bf16_rowwise_grouped_dynamic` to accept `at::Tensor` inputs instead of lists. We introduce a new wrapper called `f8f8bf16_rowwise_grouped_stacked` to maintain the behavior where `zero_start_index_M` isn't provided but the user wants a single contiguous output tensor.
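
A before/after sketch of the call site (mirroring the benchmark change further down in this diff; `xq`, `wq`, `x_scale`, `w_scale` are assumed to be the stacked FP8 tensors and per-row scales produced earlier, and `m_values` is the per-group `zero_start_index_M`):

# Before: the stacked tensors had to be sliced into per-group lists.
B = xq.shape[0]
out = torch.ops.fbgemm.f8f8bf16_rowwise_grouped_dynamic(
    [xq[i, :, :] for i in range(B)],
    [wq[i, :, :] for i in range(B)],
    [x_scale[i, :] for i in range(B)],
    [w_scale[i, :] for i in range(B)],
    zero_start_index_M=m_values,
)

# After: pass the stacked tensors directly and skip the slicing overhead.
out = torch.ops.fbgemm.f8f8bf16_rowwise_grouped_dynamic(
    xq,
    wq,
    x_scale,
    w_scale,
    zero_start_index_M=m_values,
)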

In microbenchmarks, we've found these seemingly small changes can improve TFLOPs by 2X for small workloads.

Reviewed By: jianyuh, jiawenliu64

Differential Revision: D69072529

fbshipit-source-id: b90b4d1c76bf813f94f36cd21a55118442f62b38
jwfromm authored and facebook-github-bot committed Feb 7, 2025
1 parent dced756 commit d564c8c
Showing 79 changed files with 6,720 additions and 4,714 deletions.
21 changes: 9 additions & 12 deletions fbgemm_gpu/experimental/gemm/triton_gemm/fp8_gemm.py
@@ -2447,6 +2447,13 @@ def triton_quantize_fp8_row(
         torch.Tensor: fp8 scaled tensor.
         torch.Tensor: reciprocal scale tensor per row.
     """
+    assert a.dim() <= 4, "Triton only supports up to 4 dimension input tensor."
+    a_shape = a.shape
+    while a.dim() < 4:
+        a = a.unsqueeze(0)
+    if zero_start_index_M is not None:
+        # There should be one value of zero_start_index_M per NxK matrix.
+        zero_start_index_M = zero_start_index_M.view(a.shape[0], a.shape[1])
     # Get constant values.
     pt_dtype, tl_dtype, max_fp8, eps = get_fp8_constants()
     num_rows = a.numel() // a.shape[-1]
@@ -2484,7 +2491,7 @@ def triton_quantize_fp8_row(
         USE_INT64=use_int64,
     )

-    return a_fp8, a_scale
+    return a_fp8.view(a_shape), a_scale.view(a_shape[:-1])


 @torch.library.custom_op("triton::quantize_fp8_row", mutates_args=())
@@ -2514,17 +2521,7 @@ def quantize_fp8_row(
         logger.info("Triton does not support cpu, falling back to torch ops.")
         use_triton = False
     if use_triton:
-        assert (
-            a.dim() <= 4
-        ), "Only up to 4 dimension input tensor is supported if use_triton is True"
-        a_shape = a.shape
-        while a.dim() < 4:
-            a = a.unsqueeze(0)
-        if zero_start_index_M is not None:
-            # There should be one value of zero_start_index_M per NxK matrix.
-            zero_start_index_M = zero_start_index_M.view(a.shape[0], a.shape[1])
-        a_fp8, a_scale = triton_quantize_fp8_row(a, scale_ub, zero_start_index_M)
-        return a_fp8.view(a_shape), a_scale.view(a_shape[:-1])
+        return triton_quantize_fp8_row(a, scale_ub, zero_start_index_M)
     # else use pytorch implementation.
     if not output_device:
         output_device = a.device
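As a quick illustration of the shape handling that now lives inside `triton_quantize_fp8_row` (sizes and the import path are assumptions): a lower-rank input is padded to 4D internally, quantized row-wise, and the outputs are viewed back to the caller's shape.

import torch

from fbgemm_gpu.experimental.gemm.triton_gemm.fp8_gemm import triton_quantize_fp8_row

a = torch.randn(128, 4096, device="cuda")
a_fp8, a_scale = triton_quantize_fp8_row(a)
assert a_fp8.shape == (128, 4096)  # same shape as the input
assert a_scale.shape == (128,)  # one reciprocal scale per row
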
103 changes: 0 additions & 103 deletions fbgemm_gpu/experimental/gen_ai/bench/profile_grouped_gemm.py

This file was deleted.

22 changes: 8 additions & 14 deletions fbgemm_gpu/experimental/gen_ai/bench/quantize_ops.py
@@ -19,6 +19,7 @@
     quantize_fp8_block,
     quantize_fp8_row,
     scale_fp8_row,
+    triton_quantize_fp8_row,
 )
 from tinygemm.utils import group_quantize_tensor

@@ -553,38 +554,31 @@ def preprocess(self, x, w):
     def quantize(self, x, wq, w_scale, m_values=None):
         # Handle case where inputs are explicitly grouped and non-sparse.
         if isinstance(x, (tuple, list)):
-            xq, x_scale = zip(*[quantize_fp8_row(i) for i in x])
+            xq, x_scale = zip(*[triton_quantize_fp8_row(i) for i in x])
             return xq, wq, x_scale, w_scale, m_values
         # Otherwise inputs are unified tensors and sparse.
         else:
             B = x.shape[0]
-            xq, x_scale = quantize_fp8_row(x, zero_start_index_M=m_values)
+            xq, x_scale = triton_quantize_fp8_row(x, zero_start_index_M=m_values)
             x_scale = x_scale.view(B, -1)
             return xq, wq, x_scale, w_scale, m_values

-    def compute(self, xq, wq, x_scale, w_scale, m_values, kernel_name=None):
+    def compute(self, xq, wq, x_scale, w_scale, m_values):
         if m_values is None:
             return torch.ops.fbgemm.f8f8bf16_rowwise_grouped(
                 xq,
                 wq,
                 x_scale,
                 w_scale,
-                kernel_name=kernel_name,
             )
         else:
-            # Break tensor into groups, simulates what is done e2e.
-            B = xq.shape[0]
-            xq_group = [xq[i, :, :] for i in range(B)]
-            x_scale_group = [x_scale[i, :] for i in range(B)]
-            wq_group = [wq[i, :, :] for i in range(B)]
-            w_scale_group = [w_scale[i, :] for i in range(B)]
             return torch.ops.fbgemm.f8f8bf16_rowwise_grouped_dynamic(
-                xq_group,
-                wq_group,
-                x_scale_group,
-                w_scale_group,
+                xq,
+                wq,
+                x_scale,
+                w_scale,
                 zero_start_index_M=m_values,
-                kernel_name=kernel_name,
             )

     def quantize_and_compute(self, x, wq, w_scale, m_values=None):
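For reference, a small sketch of how the sparse benchmark path above builds its inputs (group count, sizes, and the integer dtype of `zero_start_index_M` are assumptions; per the comment in `triton_quantize_fp8_row`, there is one entry per group matrix):

import torch

from fbgemm_gpu.experimental.gemm.triton_gemm.fp8_gemm import triton_quantize_fp8_row

G, M, K = 4, 64, 512
x = torch.randn(G, M, K, device="cuda", dtype=torch.bfloat16)
# One entry per group: the row index where zero padding begins.
m_values = torch.tensor([32, 48, 16, 8], device="cuda", dtype=torch.int32)

xq, x_scale = triton_quantize_fp8_row(x, zero_start_index_M=m_values)
x_scale = x_scale.view(G, -1)  # one scale per row, grouped as in the benchmark
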
