Add explicit multiply-reduce GEMM kernel
1 parent df4c4d3 · commit 855695c
Showing 2 changed files with 90 additions and 0 deletions.
@@ -0,0 +1,73 @@
import triton
import triton.language as tl

# Kernel that implements GEMM with explicit multiply-reduce instructions for small block sizes.
# Based on **tune_gemm** `matmul_kernel` from commit `cf44637` (see `triton-mlir` branch).
@triton.jit
def multreduce_matmul_kernel(a_ptr, b_ptr, c_ptr, bias_ptr, M, N, K, stride_am, stride_ak, stride_bk, stride_bn,
                             stride_cm, stride_cn, stride_bias, BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr,
                             BLOCK_SIZE_K: tl.constexpr, SPLIT_K: tl.constexpr, GROUP_SIZE_M: tl.constexpr,
                             BIAS: tl.constexpr, EVEN_K: tl.constexpr):
    pid = tl.program_id(axis=0)
    pid_z = tl.program_id(1)
    num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
    num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
    if GROUP_SIZE_M == 1:
        pid_m = pid // num_pid_n
        pid_n = pid % num_pid_n
    else:
        num_pid_in_group = GROUP_SIZE_M * num_pid_n
        group_id = pid // num_pid_in_group
        first_pid_m = group_id * GROUP_SIZE_M
        group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
        pid_m = first_pid_m + (pid % group_size_m)
        pid_n = (pid % num_pid_in_group) // group_size_m
    if SPLIT_K == 1:
        offs_k = tl.arange(0, BLOCK_SIZE_K)
    else:
        offs_k = pid_z * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K)
    offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M))
    offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N))
    a_ptrs = a_ptr + offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak
    b_ptrs = b_ptr + offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn
    if BIAS:
        bias_ptrs = bias_ptr + offs_am * stride_bias
        bias = tl.load(bias_ptrs, mask=offs_am < M, other=0.0)
    acc_dtype = tl.float32 if a_ptr.type.element_ty != tl.int8 else tl.int32
    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=acc_dtype)
    for k in range(0, tl.cdiv(K, BLOCK_SIZE_K * SPLIT_K)):
        if EVEN_K:
            a = tl.load(a_ptrs)
            b = tl.load(b_ptrs)
        else:
            # The pointers advance by BLOCK_SIZE_K * SPLIT_K each iteration, so the
            # remaining-K bound must account for SPLIT_K as well.
            a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * (BLOCK_SIZE_K * SPLIT_K), other=0.0)
            b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * (BLOCK_SIZE_K * SPLIT_K), other=0.0)
        # TODO: `min_dot_size` isn't always (16, 16, 16); it is in fact architecture-
        # and data-type-dependent. Please check this source file:
        # `${TRITON_ROOT}/third_party/amd/backend/compiler.py`
        # (look for the `min_dot_size` function).
        # Q: How can we implement this `if` statement more precisely?
        # A: Maybe we can use an approach similar to `EVEN_K`, in which the kernel caller
        #    inspects all relevant information and provides a Boolean flag for the kernel.
        if (BLOCK_SIZE_M < 16 or BLOCK_SIZE_N < 16) or BLOCK_SIZE_K < 16:
            # Explicit multiply-reduce for small block sizes.
            a = tl.reshape(a, (BLOCK_SIZE_M, BLOCK_SIZE_K, 1)).to(acc_dtype)
            b = tl.reshape(b, (1, BLOCK_SIZE_K, BLOCK_SIZE_N)).to(acc_dtype)
            accumulator += tl.sum(a * b, axis=1)
        else:
            # Triton dot product for other block sizes.
            accumulator += tl.dot(a, b)
        a_ptrs += BLOCK_SIZE_K * SPLIT_K * stride_ak
        b_ptrs += BLOCK_SIZE_K * SPLIT_K * stride_bk
    c = accumulator.to(c_ptr.type.element_ty)
    if BIAS:
        c += bias[:, None]
    offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
    c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
    if SPLIT_K == 1:
        tl.store(c_ptrs, c, mask=c_mask)
    else:
        tl.atomic_add(c_ptrs, c, mask=c_mask)
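
The small-block path above replaces `tl.dot` with a broadcasted multiply followed by a reduction over K. The standalone sketch below (not part of the commit; it only assumes PyTorch) illustrates why that is numerically the same operation as a matrix product.

# Illustration only (not from this commit): the explicit multiply-reduce used for small
# blocks computes C[m, n] = sum_k A[m, k] * B[k, n], i.e. an ordinary matrix product.
import torch

M, N, K = 4, 8, 2  # sizes below the usual 16x16x16 minimum for `tl.dot`
a = torch.randn(M, K)
b = torch.randn(K, N)

# Same broadcasting pattern as the kernel: (M, K, 1) * (1, K, N) -> (M, K, N), then reduce over K.
multreduce = (a.reshape(M, K, 1) * b.reshape(1, K, N)).sum(dim=1)
assert torch.allclose(multreduce, a @ b, atol=1e-5)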
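
The TODO in the kernel suggests letting the caller derive a flag, much like `EVEN_K`, from backend information such as `min_dot_size`. A hypothetical host-side launcher along those lines is sketched below; it is not part of the commit, and the helper name, block-size defaults, and flag computation are illustrative assumptions only.

# Hypothetical launcher (not from this commit): computes the constexpr flags on the host
# and passes them to the kernel, mirroring how `EVEN_K` is handled in tune_gemm.
import torch
import triton


def launch_multreduce_matmul(a, b, bias=None, BLOCK_SIZE_M=8, BLOCK_SIZE_N=8, BLOCK_SIZE_K=8,
                             SPLIT_K=1, GROUP_SIZE_M=1):
    M, K = a.shape
    _, N = b.shape
    # Zero-initialize C: with SPLIT_K > 1 the kernel accumulates via tl.atomic_add.
    c = torch.zeros((M, N), dtype=a.dtype, device=a.device)
    even_k = K % (BLOCK_SIZE_K * SPLIT_K) == 0
    grid = (triton.cdiv(M, BLOCK_SIZE_M) * triton.cdiv(N, BLOCK_SIZE_N), SPLIT_K)
    multreduce_matmul_kernel[grid](
        a, b, c, bias, M, N, K,
        a.stride(0), a.stride(1), b.stride(0), b.stride(1), c.stride(0), c.stride(1),
        bias.stride(0) if bias is not None else 0,
        BLOCK_SIZE_M=BLOCK_SIZE_M, BLOCK_SIZE_N=BLOCK_SIZE_N, BLOCK_SIZE_K=BLOCK_SIZE_K,
        SPLIT_K=SPLIT_K, GROUP_SIZE_M=GROUP_SIZE_M, BIAS=bias is not None, EVEN_K=even_k,
    )
    return c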