Replace auto-tuning with hard-coding
voltjia committed Jan 10, 2025
1 parent 17a7b47 commit 0d220a5
Showing 3 changed files with 20 additions and 201 deletions.
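The change follows the same pattern in every file below: the @triton.autotune decorator and its list of triton.Config candidates are deleted, the former meta-parameters (BLOCK_SIZE_*, and GROUP_SIZE_M where present) become fixed keyword arguments supplied at the call site, and the ninetoothed Symbols switch from meta=True to constexpr=True. A minimal sketch of the Triton side of that pattern, using a made-up kernel that is not part of this commit:

import torch
import triton
import triton.language as tl


# Before this change the block size would have been picked by @triton.autotune
# from a list of triton.Config candidates; now it is a plain tl.constexpr
# parameter that the caller hard-codes.
@triton.jit
def scale_kernel(x_ptr, out_ptr, n, BLOCK_SIZE: tl.constexpr):
    pid = tl.program_id(axis=0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n
    x = tl.load(x_ptr + offsets, mask=mask)
    tl.store(out_ptr + offsets, x * 2.0, mask=mask)


def scale(x):
    out = torch.empty_like(x)
    n = x.numel()
    grid = lambda meta: (triton.cdiv(n, meta["BLOCK_SIZE"]),)
    scale_kernel[grid](x, out, n, BLOCK_SIZE=64)  # hard-coded instead of autotuned
    return out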
31 changes: 5 additions & 26 deletions attention.py
@@ -8,8 +8,8 @@


def arrangement(q, k, v, o):
BLOCK_SIZE_M = Symbol("BLOCK_SIZE_M", meta=True)
BLOCK_SIZE_N = Symbol("BLOCK_SIZE_N", meta=True)
BLOCK_SIZE_M = Symbol("BLOCK_SIZE_M", constexpr=True)
BLOCK_SIZE_N = Symbol("BLOCK_SIZE_N", constexpr=True)

def arrange_q_or_o(input):
arranged = input.tile((1, 1, BLOCK_SIZE_M, -1))
@@ -63,34 +63,11 @@ def application(q, k, v, o):
def attention(q, k, v):
o = torch.empty_like(q, dtype=v.dtype)

attention_kernel(q, k, v, o)
attention_kernel(q, k, v, o, BLOCK_SIZE_M=64, BLOCK_SIZE_N=64)

return o


@triton.autotune(
configs=[
triton.Config(
{"BLOCK_SIZE_M": 256, "BLOCK_SIZE_N": 128}, num_stages=4, num_warps=8
),
triton.Config(
{"BLOCK_SIZE_M": 256, "BLOCK_SIZE_N": 64}, num_stages=4, num_warps=8
),
triton.Config(
{"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128}, num_stages=4, num_warps=4
),
triton.Config(
{"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 64}, num_stages=4, num_warps=4
),
triton.Config(
{"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64}, num_stages=4, num_warps=8
),
triton.Config(
{"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 32}, num_stages=4, num_warps=8
),
],
key=["EMB_DIM"],
)
@triton.jit
def triton_attention_kernel(
q_ptr,
@@ -214,6 +191,8 @@ def grid(meta):
*o.stride(),
SEQ_LEN=seq_len,
EMB_DIM=emb_dim,
BLOCK_SIZE_M=64,
BLOCK_SIZE_N=64,
)

return o
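
For reference, a quick smoke test of the updated attention() wrapper might look like the sketch below. The (batch, heads, seq_len, emb_dim) layout and the comparison against PyTorch's scaled_dot_product_attention are assumptions about what the kernel computes, not something stated in this diff:

import torch

q = torch.randn(2, 4, 1024, 64, device="cuda", dtype=torch.float16)
k = torch.randn_like(q)
v = torch.randn_like(q)

out = attention(q, k, v)  # block sizes are now hard-coded to 64 inside attention()
ref = torch.nn.functional.scaled_dot_product_attention(q, k, v)
print(torch.allclose(out, ref, atol=1e-2, rtol=1e-2))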
93 changes: 7 additions & 86 deletions conv2d.py
@@ -38,96 +38,13 @@ def conv2d(input, filter):

output = torch.empty((n, k, p, q), device=input.device, dtype=input.dtype)

conv2d_kernel(input, filter, output)
conv2d_kernel(
input, filter, output, BLOCK_SIZE_M=64, BLOCK_SIZE_N=64, BLOCK_SIZE_K=64
)

return output


@triton.autotune(
configs=[
triton.Config(
{
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 8,
},
num_stages=3,
num_warps=8,
),
triton.Config(
{
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 32,
"GROUP_SIZE_M": 8,
},
num_stages=4,
num_warps=4,
),
triton.Config(
{
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 32,
"GROUP_SIZE_M": 8,
},
num_stages=4,
num_warps=4,
),
triton.Config(
{
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 32,
"GROUP_SIZE_M": 8,
},
num_stages=4,
num_warps=4,
),
triton.Config(
{
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 32,
"GROUP_SIZE_M": 8,
},
num_stages=4,
num_warps=4,
),
triton.Config(
{
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 32,
"GROUP_SIZE_M": 8,
},
num_stages=4,
num_warps=4,
),
triton.Config(
{
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 32,
"GROUP_SIZE_M": 8,
},
num_stages=5,
num_warps=2,
),
triton.Config(
{
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 32,
"GROUP_SIZE_M": 8,
},
num_stages=5,
num_warps=2,
),
],
key=["n", "c", "h", "w", "k", "r", "s"],
)
@triton.jit
def triton_conv2d_kernel(
input_ptr,
@@ -270,6 +187,10 @@ def grid(meta):
*input.stride(),
*filter.stride(),
*output.stride(),
BLOCK_SIZE_M=64,
BLOCK_SIZE_N=64,
BLOCK_SIZE_K=64,
GROUP_SIZE_M=8,
)

return output
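
A similar sketch for the updated conv2d() wrapper, assuming an NCHW input, a KCRS filter, and a stride-1, unpadded convolution (none of which is spelled out in this hunk):

import torch

input = torch.randn(8, 3, 64, 64, device="cuda", dtype=torch.float16)
filter = torch.randn(16, 3, 3, 3, device="cuda", dtype=torch.float16)

out = conv2d(input, filter)  # BLOCK_SIZE_* hard-coded to 64 inside conv2d()
ref = torch.nn.functional.conv2d(input, filter)
print(out.shape, torch.allclose(out, ref, atol=1e-2, rtol=1e-2))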
97 changes: 8 additions & 89 deletions matmul.py
@@ -7,9 +7,9 @@


def arrangement(lhs, rhs, output):
BLOCK_SIZE_M = Symbol("BLOCK_SIZE_M", meta=True)
BLOCK_SIZE_N = Symbol("BLOCK_SIZE_N", meta=True)
BLOCK_SIZE_K = Symbol("BLOCK_SIZE_K", meta=True)
BLOCK_SIZE_M = Symbol("BLOCK_SIZE_M", constexpr=True)
BLOCK_SIZE_N = Symbol("BLOCK_SIZE_N", constexpr=True)
BLOCK_SIZE_K = Symbol("BLOCK_SIZE_K", constexpr=True)

output_tiled = output.tile((BLOCK_SIZE_M, BLOCK_SIZE_N))

@@ -47,96 +47,11 @@ def matmul(lhs, rhs):
(lhs.shape[0], rhs.shape[1]), device=lhs.device, dtype=torch.float16
)

matmul_kernel(lhs, rhs, output)
matmul_kernel(lhs, rhs, output, BLOCK_SIZE_M=64, BLOCK_SIZE_N=64, BLOCK_SIZE_K=64)

return output


@triton.autotune(
configs=[
triton.Config(
{
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 8,
},
num_stages=3,
num_warps=8,
),
triton.Config(
{
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 32,
"GROUP_SIZE_M": 8,
},
num_stages=4,
num_warps=4,
),
triton.Config(
{
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 32,
"GROUP_SIZE_M": 8,
},
num_stages=4,
num_warps=4,
),
triton.Config(
{
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 32,
"GROUP_SIZE_M": 8,
},
num_stages=4,
num_warps=4,
),
triton.Config(
{
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 32,
"GROUP_SIZE_M": 8,
},
num_stages=4,
num_warps=4,
),
triton.Config(
{
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 32,
"GROUP_SIZE_M": 8,
},
num_stages=4,
num_warps=4,
),
triton.Config(
{
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 32,
"GROUP_SIZE_M": 8,
},
num_stages=5,
num_warps=2,
),
triton.Config(
{
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 32,
"GROUP_SIZE_M": 8,
},
num_stages=5,
num_warps=2,
),
],
key=["m", "n", "k"],
)
@triton.jit
def triton_matmul_kernel(
lhs_ptr,
@@ -220,6 +135,10 @@ def grid(meta):
rhs.stride(1),
output.stride(0),
output.stride(1),
BLOCK_SIZE_M=64,
BLOCK_SIZE_N=64,
BLOCK_SIZE_K=64,
GROUP_SIZE_M=8,
)

return output
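
And a sketch for matmul(), where lhs is (M, K), rhs is (K, N), and the output is float16 per the allocation above; the tolerances are a guess at acceptable float16 rounding differences:

import torch

lhs = torch.randn(512, 256, device="cuda", dtype=torch.float16)
rhs = torch.randn(256, 1024, device="cuda", dtype=torch.float16)

out = matmul(lhs, rhs)  # BLOCK_SIZE_* hard-coded to 64 inside matmul()
ref = lhs @ rhs
print(torch.allclose(out, ref, atol=1e-2, rtol=1e-2))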
