diff --git a/attention.py b/attention.py
index 545b51c..a115d7f 100644
--- a/attention.py
+++ b/attention.py
@@ -8,8 +8,8 @@ def arrangement(q, k, v, o):
-    BLOCK_SIZE_M = Symbol("BLOCK_SIZE_M", meta=True)
-    BLOCK_SIZE_N = Symbol("BLOCK_SIZE_N", meta=True)
+    BLOCK_SIZE_M = Symbol("BLOCK_SIZE_M", constexpr=True)
+    BLOCK_SIZE_N = Symbol("BLOCK_SIZE_N", constexpr=True)
 
     def arrange_q_or_o(input):
         arranged = input.tile((1, 1, BLOCK_SIZE_M, -1))
@@ -63,34 +63,11 @@ def application(q, k, v, o):
 def attention(q, k, v):
     o = torch.empty_like(q, dtype=v.dtype)
 
-    attention_kernel(q, k, v, o)
+    attention_kernel(q, k, v, o, BLOCK_SIZE_M=64, BLOCK_SIZE_N=64)
 
     return o
 
 
-@triton.autotune(
-    configs=[
-        triton.Config(
-            {"BLOCK_SIZE_M": 256, "BLOCK_SIZE_N": 128}, num_stages=4, num_warps=8
-        ),
-        triton.Config(
-            {"BLOCK_SIZE_M": 256, "BLOCK_SIZE_N": 64}, num_stages=4, num_warps=8
-        ),
-        triton.Config(
-            {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128}, num_stages=4, num_warps=4
-        ),
-        triton.Config(
-            {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 64}, num_stages=4, num_warps=4
-        ),
-        triton.Config(
-            {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64}, num_stages=4, num_warps=8
-        ),
-        triton.Config(
-            {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 32}, num_stages=4, num_warps=8
-        ),
-    ],
-    key=["EMB_DIM"],
-)
 @triton.jit
 def triton_attention_kernel(
     q_ptr,
@@ -214,6 +191,8 @@ def grid(meta):
         *o.stride(),
         SEQ_LEN=seq_len,
         EMB_DIM=emb_dim,
+        BLOCK_SIZE_M=64,
+        BLOCK_SIZE_N=64,
     )
 
     return o
diff --git a/conv2d.py b/conv2d.py
index e6740c1..f17062d 100644
--- a/conv2d.py
+++ b/conv2d.py
@@ -38,96 +38,13 @@ def conv2d(input, filter):
     output = torch.empty((n, k, p, q), device=input.device, dtype=input.dtype)
 
-    conv2d_kernel(input, filter, output)
+    conv2d_kernel(
+        input, filter, output, BLOCK_SIZE_M=64, BLOCK_SIZE_N=64, BLOCK_SIZE_K=64
+    )
 
     return output
 
 
-@triton.autotune(
-    configs=[
-        triton.Config(
-            {
-                "BLOCK_SIZE_M": 128,
-                "BLOCK_SIZE_N": 256,
-                "BLOCK_SIZE_K": 64,
-                "GROUP_SIZE_M": 8,
-            },
-            num_stages=3,
-            num_warps=8,
-        ),
-        triton.Config(
-            {
-                "BLOCK_SIZE_M": 64,
-                "BLOCK_SIZE_N": 256,
-                "BLOCK_SIZE_K": 32,
-                "GROUP_SIZE_M": 8,
-            },
-            num_stages=4,
-            num_warps=4,
-        ),
-        triton.Config(
-            {
-                "BLOCK_SIZE_M": 128,
-                "BLOCK_SIZE_N": 128,
-                "BLOCK_SIZE_K": 32,
-                "GROUP_SIZE_M": 8,
-            },
-            num_stages=4,
-            num_warps=4,
-        ),
-        triton.Config(
-            {
-                "BLOCK_SIZE_M": 128,
-                "BLOCK_SIZE_N": 64,
-                "BLOCK_SIZE_K": 32,
-                "GROUP_SIZE_M": 8,
-            },
-            num_stages=4,
-            num_warps=4,
-        ),
-        triton.Config(
-            {
-                "BLOCK_SIZE_M": 64,
-                "BLOCK_SIZE_N": 128,
-                "BLOCK_SIZE_K": 32,
-                "GROUP_SIZE_M": 8,
-            },
-            num_stages=4,
-            num_warps=4,
-        ),
-        triton.Config(
-            {
-                "BLOCK_SIZE_M": 128,
-                "BLOCK_SIZE_N": 32,
-                "BLOCK_SIZE_K": 32,
-                "GROUP_SIZE_M": 8,
-            },
-            num_stages=4,
-            num_warps=4,
-        ),
-        triton.Config(
-            {
-                "BLOCK_SIZE_M": 64,
-                "BLOCK_SIZE_N": 32,
-                "BLOCK_SIZE_K": 32,
-                "GROUP_SIZE_M": 8,
-            },
-            num_stages=5,
-            num_warps=2,
-        ),
-        triton.Config(
-            {
-                "BLOCK_SIZE_M": 32,
-                "BLOCK_SIZE_N": 64,
-                "BLOCK_SIZE_K": 32,
-                "GROUP_SIZE_M": 8,
-            },
-            num_stages=5,
-            num_warps=2,
-        ),
-    ],
-    key=["n", "c", "h", "w", "k", "r", "s"],
-)
 @triton.jit
 def triton_conv2d_kernel(
     input_ptr,
@@ -270,6 +187,10 @@ def grid(meta):
         *input.stride(),
         *filter.stride(),
         *output.stride(),
+        BLOCK_SIZE_M=64,
+        BLOCK_SIZE_N=64,
+        BLOCK_SIZE_K=64,
+        GROUP_SIZE_M=8,
     )
 
     return output
diff --git a/matmul.py b/matmul.py
index d5aa445..74cca86 100644
--- a/matmul.py
+++ b/matmul.py
@@ -7,9 +7,9 @@ def arrangement(lhs, rhs, output):
-    BLOCK_SIZE_M = Symbol("BLOCK_SIZE_M", meta=True)
-    BLOCK_SIZE_N = Symbol("BLOCK_SIZE_N", meta=True)
-    BLOCK_SIZE_K = Symbol("BLOCK_SIZE_K", meta=True)
+    BLOCK_SIZE_M = Symbol("BLOCK_SIZE_M", constexpr=True)
+    BLOCK_SIZE_N = Symbol("BLOCK_SIZE_N", constexpr=True)
+    BLOCK_SIZE_K = Symbol("BLOCK_SIZE_K", constexpr=True)
 
     output_tiled = output.tile((BLOCK_SIZE_M, BLOCK_SIZE_N))
@@ -47,96 +47,11 @@ def matmul(lhs, rhs):
         (lhs.shape[0], rhs.shape[1]), device=lhs.device, dtype=torch.float16
     )
 
-    matmul_kernel(lhs, rhs, output)
+    matmul_kernel(lhs, rhs, output, BLOCK_SIZE_M=64, BLOCK_SIZE_N=64, BLOCK_SIZE_K=64)
 
     return output
 
 
-@triton.autotune(
-    configs=[
-        triton.Config(
-            {
-                "BLOCK_SIZE_M": 128,
-                "BLOCK_SIZE_N": 256,
-                "BLOCK_SIZE_K": 64,
-                "GROUP_SIZE_M": 8,
-            },
-            num_stages=3,
-            num_warps=8,
-        ),
-        triton.Config(
-            {
-                "BLOCK_SIZE_M": 64,
-                "BLOCK_SIZE_N": 256,
-                "BLOCK_SIZE_K": 32,
-                "GROUP_SIZE_M": 8,
-            },
-            num_stages=4,
-            num_warps=4,
-        ),
-        triton.Config(
-            {
-                "BLOCK_SIZE_M": 128,
-                "BLOCK_SIZE_N": 128,
-                "BLOCK_SIZE_K": 32,
-                "GROUP_SIZE_M": 8,
-            },
-            num_stages=4,
-            num_warps=4,
-        ),
-        triton.Config(
-            {
-                "BLOCK_SIZE_M": 128,
-                "BLOCK_SIZE_N": 64,
-                "BLOCK_SIZE_K": 32,
-                "GROUP_SIZE_M": 8,
-            },
-            num_stages=4,
-            num_warps=4,
-        ),
-        triton.Config(
-            {
-                "BLOCK_SIZE_M": 64,
-                "BLOCK_SIZE_N": 128,
-                "BLOCK_SIZE_K": 32,
-                "GROUP_SIZE_M": 8,
-            },
-            num_stages=4,
-            num_warps=4,
-        ),
-        triton.Config(
-            {
-                "BLOCK_SIZE_M": 128,
-                "BLOCK_SIZE_N": 32,
-                "BLOCK_SIZE_K": 32,
-                "GROUP_SIZE_M": 8,
-            },
-            num_stages=4,
-            num_warps=4,
-        ),
-        triton.Config(
-            {
-                "BLOCK_SIZE_M": 64,
-                "BLOCK_SIZE_N": 32,
-                "BLOCK_SIZE_K": 32,
-                "GROUP_SIZE_M": 8,
-            },
-            num_stages=5,
-            num_warps=2,
-        ),
-        triton.Config(
-            {
-                "BLOCK_SIZE_M": 32,
-                "BLOCK_SIZE_N": 64,
-                "BLOCK_SIZE_K": 32,
-                "GROUP_SIZE_M": 8,
-            },
-            num_stages=5,
-            num_warps=2,
-        ),
-    ],
-    key=["m", "n", "k"],
-)
 @triton.jit
 def triton_matmul_kernel(
     lhs_ptr,
@@ -220,6 +135,10 @@ def grid(meta):
         rhs.stride(1),
         output.stride(0),
         output.stride(1),
+        BLOCK_SIZE_M=64,
+        BLOCK_SIZE_N=64,
+        BLOCK_SIZE_K=64,
+        GROUP_SIZE_M=8,
    )
 
    return output
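
As a sanity check (a minimal sketch, not part of the patch itself, assuming matmul.py is importable and a CUDA device is available), the simplified matmul entry point can be exercised against the PyTorch reference like this:

# Hypothetical smoke test for the rewritten matmul() entry point.
import torch

from matmul import matmul

lhs = torch.randn(512, 256, device="cuda", dtype=torch.float16)
rhs = torch.randn(256, 128, device="cuda", dtype=torch.float16)

# The kernel now always launches with BLOCK_SIZE_M/N/K = 64, so the first call
# no longer triggers an autotuning sweep.
out = matmul(lhs, rhs)

# Compare with the PyTorch reference using float16-friendly tolerances.
torch.testing.assert_close(out, lhs @ rhs, rtol=1e-2, atol=1e-2)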