Add explicit multiply-reduce GEMM kernel #621

Merged
8 changes: 8 additions & 0 deletions python/perf-kernels/README.md
@@ -61,3 +61,11 @@ fp32, bf16 and f8 (both e5m2 and e4m3) datatypes.
## `03-matrix-multiplication-stream-k.py`

This script contains the GEMM kernel that implements [stream-k](https://arxiv.org/abs/2301.03598)

## `multreduce_matmul_kernel.py`

Kernel that implements GEMM with explicit multiply-reduce instructions for small block sizes. Such
small block sizes aren't natively supported by the `tl.dot` operator.

Despite being numerically correct, this kernel performed worse than a corresponding GEMM kernel that
used `tl.dot` with the minimum block size of $16$.
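
For illustration only (this sketch mirrors the inner loop of the kernel added below and is not part of the README text): instead of calling `tl.dot`, each K-block contribution to the output tile is computed as a broadcast multiply followed by a sum reduction over the K axis.

```python
# Inside the K loop, instead of:  accumulator += tl.dot(a, b)
# the kernel performs an explicit multiply-reduce over the K axis:
a = tl.reshape(a, (BLOCK_SIZE_M, BLOCK_SIZE_K, 1)).to(acc_dtype)
b = tl.reshape(b, (1, BLOCK_SIZE_K, BLOCK_SIZE_N)).to(acc_dtype)
accumulator += tl.sum(a * b, axis=1)
```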
45 changes: 45 additions & 0 deletions python/perf-kernels/multreduce_matmul_kernel.py
@@ -0,0 +1,45 @@
import triton
import triton.language as tl


# Kernel that implements GEMM with explicit multiply-reduce instructions for small block sizes.
# Based on **tune_gemm** `matmul_kernel` from commit `cf44637` (see `triton-mlir` branch).
@triton.jit
def multreduce_matmul_kernel(a_ptr, b_ptr, c_ptr, bias_ptr, M, N, K, stride_am, stride_ak, stride_bk, stride_bn,
                             stride_cm, stride_cn, stride_bias, BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr,
                             BLOCK_SIZE_K: tl.constexpr, BIAS: tl.constexpr, EVEN_K: tl.constexpr):
    # Map the 1D program id to the (pid_m, pid_n) output tile.
    pid = tl.program_id(axis=0)
    num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
    pid_m = pid // num_pid_n
    pid_n = pid % num_pid_n
    offs_k = tl.arange(0, BLOCK_SIZE_K)
    offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M))
    offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N))
    a_ptrs = a_ptr + offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak
    b_ptrs = b_ptr + offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn
    if BIAS:
        bias_ptrs = bias_ptr + offs_am * stride_bias
        bias = tl.load(bias_ptrs, mask=offs_am < M, other=0.0)
    # Accumulate in int32 for int8 inputs, otherwise in fp32.
    acc_dtype = tl.float32 if a_ptr.type.element_ty != tl.int8 else tl.int32
    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=acc_dtype)
    for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
        if EVEN_K:
            a = tl.load(a_ptrs)
            b = tl.load(b_ptrs)
        else:
            a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0)
            b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0)
        # Dot product implemented as explicit multiply-reduce:
        a = tl.reshape(a, (BLOCK_SIZE_M, BLOCK_SIZE_K, 1)).to(acc_dtype)
        b = tl.reshape(b, (1, BLOCK_SIZE_K, BLOCK_SIZE_N)).to(acc_dtype)
        accumulator += tl.sum(a * b, axis=1)
        a_ptrs += BLOCK_SIZE_K * stride_ak
        b_ptrs += BLOCK_SIZE_K * stride_bk
    c = accumulator.to(c_ptr.type.element_ty)
    if BIAS:
        c += bias[:, None]
    offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
    c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
    tl.store(c_ptrs, c, mask=c_mask)
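
As a usage note (not part of this PR): a minimal host-side sketch of how the kernel might be launched from PyTorch and checked against `torch.matmul`. The wrapper name, block sizes, and import path are assumptions made for illustration; since the kernel does not mask the A/B loads along M and N, the example uses M and N that are multiples of the block sizes.

```python
from typing import Optional

import torch
import triton

from multreduce_matmul_kernel import multreduce_matmul_kernel  # assumed import path


def multreduce_matmul(a: torch.Tensor, b: torch.Tensor, bias: Optional[torch.Tensor] = None,
                      block_m: int = 4, block_n: int = 4, block_k: int = 64) -> torch.Tensor:
    # Hypothetical launch wrapper; block sizes are illustrative "small" values below
    # the minimum block size of 16 supported by tl.dot.
    M, K = a.shape
    _, N = b.shape
    c = torch.empty((M, N), device=a.device, dtype=a.dtype)
    # One program per (BLOCK_SIZE_M x BLOCK_SIZE_N) output tile.
    grid = (triton.cdiv(M, block_m) * triton.cdiv(N, block_n), )
    multreduce_matmul_kernel[grid](
        a, b, c,
        bias if bias is not None else a,  # dummy pointer; never read when BIAS=False
        M, N, K,
        a.stride(0), a.stride(1),
        b.stride(0), b.stride(1),
        c.stride(0), c.stride(1),
        bias.stride(0) if bias is not None else 0,
        BLOCK_SIZE_M=block_m, BLOCK_SIZE_N=block_n, BLOCK_SIZE_K=block_k,
        BIAS=bias is not None, EVEN_K=(K % block_k == 0),
    )
    return c


if __name__ == "__main__":
    a = torch.randn(16, 128, device="cuda", dtype=torch.float32)
    b = torch.randn(128, 16, device="cuda", dtype=torch.float32)
    torch.testing.assert_close(multreduce_matmul(a, b), a @ b, atol=1e-4, rtol=1e-4)
```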