[RFC]: Float8 Inference #574
Comments
I think a Triton backend for FP8 matmul could be useful. Benchmarking on a 4070 Ti SUPER, Triton is still quite a bit slower than cuBLAS (25%).
But it can still be useful, such as to support AxisWise scaling for sm89. Benchmark script (Triton matmul copied from the official tutorial; not sure why I cannot use torch.compile() to generate an FP8 Triton kernel):
import torch
import triton
import triton.language as tl


def is_cuda():
    return triton.runtime.driver.active.get_current_target().backend == "cuda"


def get_cuda_autotune_config():
    return [
        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=3,
                      num_warps=8),
        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4,
                      num_warps=4),
        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4,
                      num_warps=4),
        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4,
                      num_warps=4),
        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4,
                      num_warps=4),
        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4,
                      num_warps=4),
        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=5,
                      num_warps=2),
        triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=5,
                      num_warps=2),
        # Good config for fp8 inputs.
        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 8}, num_stages=3,
                      num_warps=8),
        triton.Config({'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 8}, num_stages=3,
                      num_warps=8),
        triton.Config({'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 8}, num_stages=4,
                      num_warps=4),
        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 8}, num_stages=4,
                      num_warps=4),
        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 8}, num_stages=4,
                      num_warps=4),
        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=4,
                      num_warps=4),
        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=4,
                      num_warps=4),
        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=4,
                      num_warps=4)
    ]


@triton.autotune(
    configs=get_cuda_autotune_config(),
    key=['M', 'N', 'K'],
)
@triton.jit
def matmul_kernel(
        # Pointers to matrices
        a_ptr, b_ptr, c_ptr,
        # Matrix dimensions
        M, N, K,
        # The stride variables represent how much to increase the ptr by when moving by 1
        # element in a particular dimension. E.g. `stride_am` is how much to increase `a_ptr`
        # by to get the element one row down (A has M rows).
        stride_am, stride_ak,
        stride_bk, stride_bn,
        stride_cm, stride_cn,
        # Meta-parameters
        BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
        GROUP_SIZE_M: tl.constexpr,
):
    """Kernel for computing the matmul C = A x B.
    A has shape (M, K), B has shape (K, N) and C has shape (M, N)
    """
    # -----------------------------------------------------------
    # Map program ids `pid` to the block of C it should compute.
    # This is done in a grouped ordering to promote L2 data reuse.
    # See the `L2 Cache Optimizations` section of the Triton matmul tutorial for details.
    pid = tl.program_id(axis=0)
    num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
    num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
    num_pid_in_group = GROUP_SIZE_M * num_pid_n
    group_id = pid // num_pid_in_group
    first_pid_m = group_id * GROUP_SIZE_M
    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
    pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
    pid_n = (pid % num_pid_in_group) // group_size_m

    # ----------------------------------------------------------
    # Create pointers for the first blocks of A and B.
    # We will advance this pointer as we move in the K direction
    # and accumulate
    # `a_ptrs` is a block of [BLOCK_SIZE_M, BLOCK_SIZE_K] pointers
    # `b_ptrs` is a block of [BLOCK_SIZE_K, BLOCK_SIZE_N] pointers
    # See the `Pointer Arithmetic` section of the Triton matmul tutorial for details.
    offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
    offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
    offs_k = tl.arange(0, BLOCK_SIZE_K)
    a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
    b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)

    # -----------------------------------------------------------
    # Iterate to compute a block of the C matrix.
    # We accumulate into a `[BLOCK_SIZE_M, BLOCK_SIZE_N]` block
    # of fp32 values for higher accuracy.
    # `accumulator` will be converted to bf16 after the loop.
    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
    for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
        # Load the next block of A and B, generate a mask by checking the K dimension.
        # If it is out of bounds, set it to 0.
        a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0)
        b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0)
        # We accumulate along the K dimension.
        accumulator = tl.dot(a, b, accumulator)
        # Advance the ptrs to the next K block.
        a_ptrs += BLOCK_SIZE_K * stride_ak
        b_ptrs += BLOCK_SIZE_K * stride_bk
    # You can fuse arbitrary activation functions here
    # while the accumulator is still in FP32!
    c = accumulator.to(tl.bfloat16)

    # -----------------------------------------------------------
    # Write back the block of the output matrix C with masks.
    offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
    c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
    tl.store(c_ptrs, c, mask=c_mask)


def matmul_triton(a, b):
    # Check constraints.
    assert a.shape[1] == b.shape[0], "Incompatible dimensions"
    assert a.is_contiguous(), "Matrix A must be contiguous"
    M, K = a.shape
    K, N = b.shape
    # Allocates output.
    c = torch.empty((M, N), device=a.device, dtype=torch.bfloat16)
    # 1D launch kernel where each block gets its own program.
    grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']), )
    matmul_kernel[grid](
        a, b, c,
        M, N, K,
        a.stride(0), a.stride(1),
        b.stride(0), b.stride(1),
        c.stride(0), c.stride(1),
    )
    return c


M, K, N = 16384, 8192, 1280
A = torch.randn(M, K).cuda().to(torch.float8_e4m3fn)
# B is stored as (N, K) so that B.T is a column-major (K, N) operand.
B = torch.randn(N, K).cuda().to(torch.float8_e4m3fn)
scale_a = torch.tensor([1.]).cuda()
scale_b = torch.tensor([1.]).cuda()
A_bf16 = A.bfloat16()
B_bf16 = B.bfloat16()

time_bf16 = triton.testing.do_bench(lambda: torch.mm(A_bf16, B_bf16.T))
print(f"CuBLAS BF16: {time_bf16:.4f}")
time_triton = triton.testing.do_bench(lambda: matmul_triton(A, B.T))
print(f"Triton FP8: {time_triton:.4f}")
time_cublas = triton.testing.do_bench(lambda: torch._scaled_mm(A, B.T, scale_a, scale_b))
print(f"CuBLAS FP8: {time_cublas:.4f}")
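To compare the three timings on a common footing, it can help to convert them to achieved throughput. The helper below is a small sketch and not part of the original script. It assumes the times come from triton.testing.do_bench, which reports milliseconds, and it counts 2*M*N*K floating-point operations per matmul. The name `matmul_tflops` and the 1.0 ms timing are placeholders chosen for illustration, not measured results.

```python
def matmul_tflops(m: int, n: int, k: int, time_ms: float) -> float:
    """Achieved TFLOP/s for an (m, k) @ (k, n) matmul that took `time_ms` milliseconds."""
    flops = 2 * m * n * k  # one multiply and one add per inner-product term
    return flops / (time_ms * 1e-3) / 1e12


# Shapes used in the benchmark script.
M, K, N = 16384, 8192, 1280

# Placeholder timing for illustration only; substitute the value returned by do_bench.
example_ms = 1.0
print(f"{matmul_tflops(M, N, K, example_ms):.1f} TFLOP/s at {example_ms} ms")
```

Because all three benchmarks use the same shapes, the ratio of two timings is also the ratio of their throughputs; the conversion mainly helps when comparing against a GPU's advertised peak.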
RFC: Float8 Inference
Objective
We want to provide an easy mechanism to utilize FP8 in inference, and see both decreased memory usage and performance gains on hardware that supports native FP8 computation. We would like the API to require minimal model rewrites. We also want it to be configurable in such a way as to provide multiple levels of scaling granularity with their own accuracy/performance trade-offs. The solution should be composable with other inference components in the PyTorch ecosystem:
This solution is targeting server-side GPU inference. It is not currently focused on supporting edge or CPU inference.
Background
Float8 inference can be used to reduce memory usage and improve computational efficiency. By using FP8 instead of higher-precision formats, we can achieve significant speedups and memory savings with minimal loss in accuracy. The memory saving is unique to float8 inference as opposed to float8 training: in inference the weights are static, so they do not need to be kept in higher precision for weight updates.
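As a rough illustration of the memory saving, an FP8 value occupies one byte while a bfloat16 value occupies two, so statically converting a weight roughly halves its storage. The sketch below is back-of-the-envelope arithmetic under two assumptions that are not stated in this RFC: the layer shape is hypothetical, and each scale is stored as a 4-byte fp32 value. The helper name `linear_weight_bytes` is made up for this example.

```python
BYTES_PER_ELEMENT = {"bfloat16": 2, "float8_e4m3fn": 1}


def linear_weight_bytes(out_features, in_features, dtype, num_fp32_scales=0):
    """Bytes needed for a Linear-style weight plus any fp32 scale factors."""
    weight_bytes = out_features * in_features * BYTES_PER_ELEMENT[dtype]
    return weight_bytes + 4 * num_fp32_scales  # assumption: scales kept in fp32


# Hypothetical layer shape, chosen only for illustration.
out_features, in_features = 4096, 4096

bf16_bytes = linear_weight_bytes(out_features, in_features, "bfloat16")
# TensorWise scaling needs a single scale for the whole weight.
fp8_bytes = linear_weight_bytes(out_features, in_features, "float8_e4m3fn", num_fp32_scales=1)

print(f"bf16: {bf16_bytes / 2**20:.2f} MiB, fp8: {fp8_bytes / 2**20:.2f} MiB")
```

Finer scaling granularities store more scales, but the overhead stays small relative to the weight itself as long as each scale covers many elements.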
Proposal
Float8InferenceLinear Module
We propose a new Float8InferenceLinear module that extends nn.Linear with Float8 quantization capabilities:
This module handles the quantization of weights and activations based on the provided configuration. It landed in PR #287. It is designed to replace a pre-trained nn.Linear module in an existing model and statically convert the weight to FP8. By default, we do this in E4M3 format.
It provides configuration options via the QuantConfig class, which encapsulates the various quantization settings:
The main configuration options are captured in the ActivationCasting enum:
_scaled_mm.
Top-level API
We propose a top-level API for quantizing models:
This function allows users to easily convert their models to use Float8 inference.
An example of how this can be used on a Hugging Face model can be found in this PR in TorchAO
Proposed Extensions
Scaling Granularity
Currently, we only support TensorWise scaling. Concretely, this is done by calculating max(abs(Tensor)) and using this value to compute the Float8Tensor scale. However, due to outlier values in activations, this can produce large quantization error. In addition, calculating a global reduction across the entire activation tensor can be relatively slow.
Therefore, we want to add the option to specify different types of scaling granularities.
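To make the outlier problem concrete, here is a framework-free sketch of TensorWise scaling. It is not the Float8Tensor implementation: `round_to_e4m3` is a hypothetical helper that simulates rounding to E4M3 (3 mantissa bits, smallest normal exponent -6, largest finite value 448) in plain Python, and the saturating clamp and the small floor on the amax are assumptions made for this example.

```python
import math

E4M3_MAX = 448.0  # largest finite value of the E4M3 ("fn") format


def round_to_e4m3(x: float) -> float:
    """Round x to the nearest E4M3 value, saturating at +/-E4M3_MAX (assumed behavior)."""
    if x == 0.0:
        return 0.0
    mag = min(abs(x), E4M3_MAX)
    _, e = math.frexp(mag)   # mag = m * 2**e with 0.5 <= m < 1
    exp = max(e - 1, -6)     # below 2**-6 the format is subnormal with a fixed step
    step = 2.0 ** (exp - 3)  # 3 mantissa bits give 8 steps per power of two
    return math.copysign(round(mag / step) * step, x)


def tensorwise_roundtrip(values):
    """Quantize with one scale derived from max(abs(values)), then dequantize."""
    amax = max(abs(v) for v in values)
    scale = E4M3_MAX / max(amax, 1e-12)  # floor avoids dividing by zero
    return [round_to_e4m3(v * scale) / scale for v in values]


acts = [0.001, -0.002, 0.0015]
print(tensorwise_roundtrip(acts))             # small values survive with a few percent error
print(tensorwise_roundtrip(acts + [1000.0]))  # one outlier pushes the small values to zero
```

In the second call the single scale is chosen to fit the outlier, so the small activations land below the smallest representable E4M3 step and are lost entirely. Computing scales over smaller groups of elements limits how far one outlier can reach.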
The scaling_granularity parameter determines how scales are computed:
- TensorWise: A single scale is computed for the entire tensor.
- AxisWise: Scales are computed along a specified axis of the tensor.
We recently added AxisWise scaling support to _scaled_mm in this PyTorch PR: #128989. I also have a worked PR stack showing how AxisWise scaling can be implemented in Float8Experimental: pytorch-labs/float8_experimental#305
We would like to continue generalizing the scaling granularity to:
- GroupWise: Similar to AxisWise, but instead of one scale per axis, we have multiple.
- BlockWise: All other forms can be seen as special cases of this. One scale per 2D tile of activation and weight.
Design Details
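The granularities listed under Scaling Granularity differ only in which elements share a scale. The sketch below computes scales for a small 2D matrix in plain Python and shows how a BlockWise computation reduces to the other forms by choosing the tile shape. The function name, the nested-list representation, and the reading of GroupWise as "several scales per row" are assumptions for illustration; this is not the _scaled_mm or Float8Tensor code path, and it assumes every tile contains a nonzero value.

```python
E4M3_MAX = 448.0  # largest finite value of the E4M3 ("fn") format


def blockwise_scales(x, block_rows, block_cols):
    """One scale per (block_rows x block_cols) tile of the 2D nested list x."""
    n_rows, n_cols = len(x), len(x[0])
    scales = []
    for r0 in range(0, n_rows, block_rows):
        row_of_scales = []
        for c0 in range(0, n_cols, block_cols):
            tile = [row[c0:c0 + block_cols] for row in x[r0:r0 + block_rows]]
            amax = max(abs(v) for tile_row in tile for v in tile_row)
            row_of_scales.append(E4M3_MAX / amax)  # assumes amax > 0
        scales.append(row_of_scales)
    return scales


x = [
    [1.0, -2.0, 4.0, 8.0],
    [0.5, 0.25, -16.0, 1.0],
]
n_rows, n_cols = len(x), len(x[0])

tensorwise = blockwise_scales(x, n_rows, n_cols)  # one tile covering the whole tensor
axiswise = blockwise_scales(x, 1, n_cols)         # one tile per row
groupwise = blockwise_scales(x, 1, 2)             # several tiles per row

print(tensorwise)  # a single scale
print(axiswise)    # one scale per row
print(groupwise)   # two scales per row
```

Smaller tiles confine the effect of the -16.0 outlier to the elements that share its tile, at the cost of storing and applying more scales.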
Tensor Subclass Usage
The implementation utilizes Float8Tensors to encapsulate the scaling as well as dispatch to _scaled_mm instead of torch.mm. This is not the only way this could be implemented. Since we do not have the autograd constraint that backpropagating grads must match the dtype of the tensor in the forward, we are free to desugar the Float8Tensor into its constituents, store them on the module, and use them in the forward. However, using the tensor subclass allows us to re-use similar components between training and inference, but it does have downsides:
Performance
Compile
As with the rest of this project, we heavily rely on the compile stack to generate efficient and fused casting code. We do actually see some performance gains on heavily compute-bound models, but in general, we require torch.compile for competitive performance.
Export
Currently, it is not possible to run torch.export + AOTI with the publicly available export APIs. However, this PR: pytorch-labs/float8_experimental#295 demonstrates that it is possible. There are plans this half for the export team to make export of nn.modules with subclasses as weights available in the public API.
Limitations and Future Work
Extend ScalingGranularity
bfloat16.
_scaled_mm only supports TensorWise and AxisWise scaling; work is needed to extend to other granularities.
Composition with other dtypes/techniques
AffineQuantized, NF4Tensor, int4_weight_only, etc. It is possible that users will want to compose different types within the same model. Work is needed here to ensure that the top-level UX is expressive enough to handle these cases.
Standardize on TorchAO APIs
Non-H100 GPU Support
_scaled_mm's TensorWise support is enabled on sm89 and on MI300x+ GPUs. However, the AxisWise kernel is based on Cutlass and is not currently supported on any GPU besides H100.
Dynamic Shapes
Other Module Support
- While Linear weights take up the majority of model size and compute, other operations can still be amenable to the compute gains from FP8.
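The hardware constraints described under Non-H100 GPU Support can be written down as a small lookup, which is one way a top-level API could reject an unsupported configuration early. This is only a sketch: the enum, the table, and `is_supported` are hypothetical names, the device keys are informal labels rather than values returned by any PyTorch call, and listing TensorWise under H100 is an assumption drawn from H100 being the baseline target.

```python
from enum import Enum


class ScalingGranularity(Enum):
    TENSORWISE = "TensorWise"
    AXISWISE = "AxisWise"


# Support table transcribed from the limitations listed in this RFC.
SCALED_MM_SUPPORT = {
    "H100": {ScalingGranularity.TENSORWISE, ScalingGranularity.AXISWISE},
    "sm89": {ScalingGranularity.TENSORWISE},
    "MI300x": {ScalingGranularity.TENSORWISE},
}


def is_supported(device: str, granularity: ScalingGranularity) -> bool:
    """Return True if the table lists `granularity` for `device`; unknown devices are unsupported."""
    return granularity in SCALED_MM_SUPPORT.get(device, set())


for device in ("H100", "sm89", "MI300x"):
    supported = [g.value for g in ScalingGranularity if is_supported(device, g)]
    print(device, supported)
```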
Examples
Open Questions
_scaled_mm kernel can support this use case. Is that a problem?
Conclusion
This RFC proposes significant enhancements to Float8 inference in PyTorch, aiming to provide a more flexible, efficient, and user-friendly framework for quantization. By supporting various scaling granularities and quantization strategies, we can cater to a wide range of use cases and potentially unlock substantial performance improvements for many models.
Additional Details
Utilizing this script: https://gist.github.com/drisspg/d7ae2134fbb6ca369c4817853c3352fa