
add kpack and matrix_instr_nonkdim for stream-k implementation (#562)
* add kpack and matrix_instr_nonkdim

* add to-do notes for autotune
xiaohuguo2023 authored Apr 19, 2024
1 parent 4df2ddc commit b9e5290
Showing 3 changed files with 73 additions and 53 deletions.
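
For orientation before the diffs: a minimal sketch of how the two new knobs are passed through the user-facing call after this change. matmul and total_sm are assumed to be defined as in the changed scripts, and the problem size, block sizes, and tuning values below are illustrative placeholders rather than recommendations.

import torch

# Illustrative only: matmul and total_sm come from the changed scripts; the
# concrete block sizes and tuning values here are placeholders.
m, n, k = 4096, 4096, 8192
A = torch.randn(m, k, device="cuda", dtype=torch.float16)
B = torch.randn(k, n, device="cuda", dtype=torch.float16)

mfmaInstrSize = 16   # forwarded to the kernel launch as matrix_instr_nonkdim
kpack = 2            # forwarded to the kernel launch as kpack

C = matmul.apply(A, B, total_sm,        # grid: number of stream-k programs
                 64, 64, 64,            # BLK_M, BLK_N, BLK_K
                 True,                  # two_tiles
                 0, 4, 0,               # num_stages, num_warps, waves_per_eu
                 mfmaInstrSize, kpack)  # the two new arguments added in this commit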
32 changes: 19 additions & 13 deletions python/perf-kernels/03-matrix-multiplication-stream-k.py
@@ -153,7 +153,7 @@ def set_debug(debug: bool):
matmul._debug = debug

@staticmethod
def _call(a: torch.Tensor, b: torch.Tensor, total_programs_streamk: int, BLK_M: int, BLK_N: int, BLK_K: int, two_tiles: bool, num_stages: int, num_warps: int, waves_per_eu: int):
def _call(a: torch.Tensor, b: torch.Tensor, total_programs_streamk: int, BLK_M: int, BLK_N: int, BLK_K: int, two_tiles: bool, num_stages: int, num_warps: int, waves_per_eu: int, mfmaInstrSize: int, kpack: int):
device = a.device

assert a.is_contiguous() and b.is_contiguous(), "non-contiguous inputs are not supported"
@@ -167,7 +167,7 @@ def _call(a: torch.Tensor, b: torch.Tensor, total_programs_streamk: int, BLK_M:
total_blocks_M = triton.cdiv(M, BLK_M)
total_blocks_N = triton.cdiv(N, BLK_N)
iters_per_tile = triton.cdiv(K, BLK_K)
GROUP_M = 0 # 0 to disable swizzling
GROUP_M = 4 # 0 to disable swizzling
total_tiles = total_blocks_M * total_blocks_N

if total_programs_streamk > 0: # Stream-K
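
The GROUP_M change in the hunk above turns on grouped (swizzled) ordering of output tiles. As a reference, a minimal sketch of the usual grouped mapping from a linear tile id to (pid_m, pid_n), in the spirit of the Triton matmul tutorial; the function and variable names are illustrative and need not match the kernel's internals.

def swizzle_tile(tile_id, total_blocks_M, total_blocks_N, GROUP_M):
    # GROUP_M == 0: plain row-major ordering (swizzling disabled).
    if GROUP_M == 0:
        return tile_id // total_blocks_N, tile_id % total_blocks_N
    # Otherwise, GROUP_M rows of tiles are grouped together so that nearby
    # programs reuse the same columns of B, improving cache locality.
    group_size = GROUP_M * total_blocks_N          # tiles per group
    group_id = tile_id // group_size
    first_pid_m = group_id * GROUP_M
    rows_in_group = min(total_blocks_M - first_pid_m, GROUP_M)
    pid_m = first_pid_m + (tile_id % rows_in_group)
    pid_n = (tile_id % group_size) // rows_in_group
    return pid_m, pid_n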
@@ -232,6 +232,8 @@ def _call(a: torch.Tensor, b: torch.Tensor, total_programs_streamk: int, BLK_M:
num_stages=num_stages,
num_warps=num_warps,
waves_per_eu = waves_per_eu,
matrix_instr_nonkdim = mfmaInstrSize,
kpack = kpack,
)
if matmul._debug:
print(f"{kk.n_regs} registers used, {kk.n_spills} spills")
@@ -241,29 +243,33 @@ def _call(a: torch.Tensor, b: torch.Tensor, total_programs_streamk: int, BLK_M:
return c

@staticmethod
def forward(ctx, a: torch.Tensor, b: torch.Tensor, grid: int, BLK_M=128, BLK_N=128, BLK_K=32, two_tiles=True, num_stages=3, num_warps=4, waves_per_eu = 2):
return matmul._call(a=a, b=b, total_programs_streamk=grid, BLK_M=BLK_M, BLK_N=BLK_N, BLK_K=BLK_K, two_tiles=two_tiles, num_warps=num_warps, num_stages=num_stages, waves_per_eu = waves_per_eu)
def forward(ctx, a: torch.Tensor, b: torch.Tensor, grid: int, BLK_M=128, BLK_N=128, BLK_K=32, two_tiles=True, num_stages=3, num_warps=4, waves_per_eu = 2, mfmaInstrSize = 16, kpack = 1):
return matmul._call(a=a, b=b, total_programs_streamk=grid, BLK_M=BLK_M, BLK_N=BLK_N, BLK_K=BLK_K, two_tiles=two_tiles, num_warps=num_warps, num_stages=num_stages, waves_per_eu = waves_per_eu, mfmaInstrSize = mfmaInstrSize, kpack = kpack)

# ---------------------------------------------------------------------------
# Example and Benchmark
# ---------------------------------------------------------------------------

perf = lambda ms: 2 * m * n * k * 1e-12 / (ms * 1e-3)

m, n, k = 4864, 4096, 8256 # some problem size to test
#m, n, k = 4864, 4096, 8256 # some problem size to test
#m, n, k = 4096, 4096, 8192 # some problem size to test
#m, n, k = 8192, 8192, 8192 # some problem size to test
m, n, k = 6912, 768, 256 # some problem size to test
A = torch.randn(m, k, device="cuda", dtype=torch.float16)
B = torch.randn(k, n, device="cuda", dtype=torch.float16)
BLK_M = 256
BLK_N = 128
BLK_K = 32
BLK_M = 64
BLK_N = 64
BLK_K = 64
two_tiles = 'True'
num_stages = 0
num_warps = 8
num_warps = 4
waves_per_eu = 0
mfmaInstrSize = 16
kpack = 2

matmul.set_debug(True)
C = matmul.apply(A, B, total_sm, BLK_M, BLK_N, BLK_K, two_tiles, num_stages, num_warps, waves_per_eu)
C = matmul.apply(A, B, total_sm, BLK_M, BLK_N, BLK_K, two_tiles, num_stages, num_warps, waves_per_eu, mfmaInstrSize, kpack)
#exit(0)
matmul.set_debug(False)
expected = A @ B
@@ -277,13 +283,13 @@ def forward(ctx, a: torch.Tensor, b: torch.Tensor, grid: int, BLK_M=128, BLK_N=1
triton_ms = triton.testing.do_bench(lambda: torch.matmul(A, B))
print(f"PyTorch: {triton_ms:.3f} ms {perf(triton_ms):.3f} tflops")

triton_ms = triton.testing.do_bench(lambda: matmul.apply(A, B, total_sm, BLK_M, BLK_N, BLK_K, two_tiles, num_stages, num_warps, waves_per_eu))
triton_ms = triton.testing.do_bench(lambda: matmul.apply(A, B, total_sm, BLK_M, BLK_N, BLK_K, two_tiles, num_stages, num_warps, waves_per_eu, mfmaInstrSize, kpack))
print(f"hybrid stream-k (grid={total_sm}): {triton_ms:.3f} ms {perf(triton_ms):.3f} tflops")

triton_ms = triton.testing.do_bench(lambda: matmul.apply(A, B, total_sm * 2, BLK_M, BLK_N, BLK_K, two_tiles, num_stages, num_warps, waves_per_eu))
triton_ms = triton.testing.do_bench(lambda: matmul.apply(A, B, total_sm * 2, BLK_M, BLK_N, BLK_K, two_tiles, num_stages, num_warps, waves_per_eu, mfmaInstrSize, kpack))
print(f"hybrid stream-k (grid={total_sm * 2}): {triton_ms:.3f} ms {perf(triton_ms):.3f} tflops")

triton_ms = triton.testing.do_bench(lambda: matmul.apply(A, B, 0, BLK_M, BLK_N, BLK_K, two_tiles, num_stages, num_warps, waves_per_eu))
triton_ms = triton.testing.do_bench(lambda: matmul.apply(A, B, 0, BLK_M, BLK_N, BLK_K, two_tiles, num_stages, num_warps, waves_per_eu, mfmaInstrSize, kpack))
print(f"tile matmul (grid=0): {triton_ms:.3f} ms {perf(triton_ms):.3f} tflops")

exit(0)
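As a sanity check on the perf lambda used in the benchmark above: it converts a kernel time in milliseconds into TFLOPS, counting 2*M*N*K floating-point operations per GEMM. A small worked example with the new default problem size; the 0.1 ms timing is made up purely for illustration.

m, n, k = 6912, 768, 256
perf = lambda ms: 2 * m * n * k * 1e-12 / (ms * 1e-3)
print(perf(0.1))  # hypothetical 0.1 ms run -> roughly 27.2 TFLOPS
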
(second changed file: header not shown)
@@ -122,36 +122,40 @@ def get_full_tuning_space():
if not tuning_full_space:
return configs

block_mn_range = [128, 256]
block_mn_range = [64, 128, 256]
block_k_range = [16, 32, 64]
num_warps_range = [1, 2, 4, 8]
# group_m_range = [0, 1, 2, 4, 8]
group_m_range = [4, 8]
group_m_range = [0, 4, 8]
# For now we see better perf with num_stages=0 for all gemm configs we care
# But keep this explicit so that we do not forget we may need to set it to
# other values in the future
num_stage_range = [0]
waves_per_eu_range =[0]
matrix_instr_nonkdim_range = [16, 32]
kpack_range = [1, 2]

for block_m in block_mn_range:
for block_n in block_mn_range:
for block_k in block_k_range:
for num_warps in num_warps_range:
for group_m in group_m_range:
for num_waves_per_eu in waves_per_eu_range:
for num_stages in num_stage_range:
configs.append(triton.Config({'BLOCK_M': block_m, 'BLOCK_N': block_n, 'BLOCK_K': block_k, 'GROUP_M': group_m, 'waves_per_eu': num_waves_per_eu}, num_stages=num_stages, num_warps=num_warps))
for num_stages in num_stage_range:
for num_waves_per_eu in waves_per_eu_range:
for matrix_instr_nonkdim in matrix_instr_nonkdim_range:
for kpack in kpack_range:
configs.append(triton.Config({'BLOCK_M': block_m, 'BLOCK_N': block_n, 'BLOCK_K': block_k, 'GROUP_M': group_m, 'waves_per_eu': num_waves_per_eu, 'matrix_instr_nonkdim': matrix_instr_nonkdim, 'kpack': kpack}, num_stages=num_stages, num_warps=num_warps,))

return configs

# TODO: update the default autotune configurations once we have gone through the full set of performance tests.
@triton.autotune(
configs= get_full_tuning_space() if tuning_full_space else [
triton.Config({'BLOCK_M': 256, 'BLOCK_N': 128, 'BLOCK_K': 16, 'GROUP_M': 8, 'waves_per_eu': 0}, num_warps=4, num_stages=0),
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 256, 'BLOCK_K': 16, 'GROUP_M': 8, 'waves_per_eu': 2}, num_warps=4, num_stages=0),
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 16, 'GROUP_M': 4, 'waves_per_eu': 0}, num_warps=4, num_stages=0),
triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 16, 'GROUP_M': 4, 'waves_per_eu': 2}, num_warps=4, num_stages=0),
triton.Config({'BLOCK_M': 64, 'BLOCK_N': 64, 'BLOCK_K': 64, 'GROUP_M': 16, 'waves_per_eu': 0}, num_warps=4, num_stages=0),
triton.Config({'BLOCK_M': 64, 'BLOCK_N': 64, 'BLOCK_K': 16, 'GROUP_M': 0, 'waves_per_eu': 0}, num_warps=4, num_stages=4),
triton.Config({'BLOCK_M': 256, 'BLOCK_N': 128, 'BLOCK_K': 16, 'GROUP_M': 8, 'waves_per_eu': 0, 'matrix_instr_nonkdim': 16, 'kpack': 1}, num_warps=4, num_stages=0),
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 256, 'BLOCK_K': 16, 'GROUP_M': 8, 'waves_per_eu': 2, 'matrix_instr_nonkdim': 16, 'kpack': 1}, num_warps=4, num_stages=0),
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 16, 'GROUP_M': 4, 'waves_per_eu': 0, 'matrix_instr_nonkdim': 16, 'kpack': 1}, num_warps=4, num_stages=0),
triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 16, 'GROUP_M': 4, 'waves_per_eu': 2, 'matrix_instr_nonkdim': 16, 'kpack': 1}, num_warps=4, num_stages=0),
triton.Config({'BLOCK_M': 64, 'BLOCK_N': 64, 'BLOCK_K': 64, 'GROUP_M': 16, 'waves_per_eu': 0, 'matrix_instr_nonkdim': 16, 'kpack': 1}, num_warps=4, num_stages=0),
triton.Config({'BLOCK_M': 64, 'BLOCK_N': 64, 'BLOCK_K': 16, 'GROUP_M': 0, 'waves_per_eu': 0, 'matrix_instr_nonkdim': 16, 'kpack': 1}, num_warps=4, num_stages=4),
],
key=['M', 'N', 'K'],
# prune_configs_by={
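
The nested loops in get_full_tuning_space now enumerate the full Cartesian product of the tuning knobs, including the new matrix_instr_nonkdim and kpack ranges. An equivalent, more compact formulation with itertools.product is sketched below; this is an alternative, not what the file actually uses.

import itertools
import triton

def get_full_tuning_space_product():
    # Same Cartesian product as the nested loops above.
    block_mn_range = [64, 128, 256]
    block_k_range = [16, 32, 64]
    num_warps_range = [1, 2, 4, 8]
    group_m_range = [0, 4, 8]
    num_stage_range = [0]
    waves_per_eu_range = [0]
    matrix_instr_nonkdim_range = [16, 32]
    kpack_range = [1, 2]

    configs = []
    for (block_m, block_n, block_k, num_warps, group_m, num_stages,
         waves_per_eu, nonkdim, kpack) in itertools.product(
            block_mn_range, block_mn_range, block_k_range, num_warps_range,
            group_m_range, num_stage_range, waves_per_eu_range,
            matrix_instr_nonkdim_range, kpack_range):
        configs.append(triton.Config(
            {'BLOCK_M': block_m, 'BLOCK_N': block_n, 'BLOCK_K': block_k,
             'GROUP_M': group_m, 'waves_per_eu': waves_per_eu,
             'matrix_instr_nonkdim': nonkdim, 'kpack': kpack},
            num_stages=num_stages, num_warps=num_warps))
    return configs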
@@ -264,7 +268,7 @@ def set_debug(debug: bool):
matmul._debug = debug

@staticmethod
def _call(a: torch.Tensor, b: torch.Tensor, total_programs_streamk: int, BLOCK_M: int, BLOCK_N: int, BLOCK_K: int, two_tiles: bool, num_stages: int, num_warps: int, waves_per_eu: int):
def _call(a: torch.Tensor, b: torch.Tensor, total_programs_streamk: int, BLOCK_M: int, BLOCK_N: int, BLOCK_K: int, two_tiles: bool, num_stages: int, num_warps: int, waves_per_eu: int, mfmaInstrSize: int, kpack: int):
def compute_total_blocking_tiles(M, N, BLOCK_M, BLOCK_N, two_tiles, total_programs_streamk):
total_blocks_M = triton.cdiv(M, BLOCK_M)
total_blocks_N = triton.cdiv(N, BLOCK_N)
@@ -370,8 +374,8 @@ def compute_total_blocking_tiles(M, N, BLOCK_M, BLOCK_N, two_tiles, total_progra
return c

@staticmethod
def forward(ctx, a: torch.Tensor, b: torch.Tensor, grid: int, BLOCK_M=128, BLOCK_N=128, BLOCK_K=32, two_tiles=True, num_stages=3, num_warps=4, waves_per_eu = 2):
return matmul._call(a=a, b=b, total_programs_streamk=grid, BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K, two_tiles=two_tiles, num_warps=num_warps, num_stages=num_stages, waves_per_eu = waves_per_eu)
def forward(ctx, a: torch.Tensor, b: torch.Tensor, grid: int, BLOCK_M=128, BLOCK_N=128, BLOCK_K=32, two_tiles=True, num_stages=3, num_warps=4, waves_per_eu = 2, mfmaInstrSize = 16, kpack = 1):
return matmul._call(a=a, b=b, total_programs_streamk=grid, BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K, two_tiles=two_tiles, num_warps=num_warps, num_stages=num_stages, waves_per_eu = waves_per_eu, mfmaInstrSize = mfmaInstrSize, kpack = kpack)

# ---------------------------------------------------------------------------
# Example and Benchmark
@@ -380,21 +384,24 @@ def forward(ctx, a: torch.Tensor, b: torch.Tensor, grid: int, BLOCK_M=128, BLOCK
perf = lambda ms: 2 * m * n * k * 1e-12 / (ms * 1e-3)

#m, n, k = 1792, 7424, 4864 # some problem size to test
m, n, k = 8192, 8192, 8192 # some problem size to test
A = torch.randn(m, k, device="cuda", dtype=torch.float32)
B = torch.randn(k, n, device="cuda", dtype=torch.float32)
#m, n, k = 8192, 8192, 8192 # some problem size to test
m, n, k = 4096, 4096, 8192 # some problem size to test
A = torch.randn(m, k, device="cuda", dtype=torch.float16)
B = torch.randn(k, n, device="cuda", dtype=torch.float16)
#A = torch.ones((m, k), device="cuda", dtype=torch.float16)
#B = torch.ones((k, n), device="cuda", dtype=torch.float16)
BLOCK_M = 256
BLOCK_N = 128
BLOCK_K = 16
BLOCK_N = 256
BLOCK_K = 64
two_tiles = True
num_stages = 0
num_warps = 4
waves_per_eu = 2
num_warps = 8
waves_per_eu = 0
mfmaInstrSize = 16
kpack = 1

matmul.set_debug(True)
C = matmul.apply(A, B, total_sm, BLOCK_M, BLOCK_N, BLOCK_K, two_tiles, num_stages, num_warps, waves_per_eu)
C = matmul.apply(A, B, total_sm, BLOCK_M, BLOCK_N, BLOCK_K, two_tiles, num_stages, num_warps, waves_per_eu, mfmaInstrSize, kpack)
matmul.set_debug(False)
expected = A @ B

@@ -407,15 +414,15 @@ def forward(ctx, a: torch.Tensor, b: torch.Tensor, grid: int, BLOCK_M=128, BLOCK
triton_ms = triton.testing.do_bench(lambda: torch.matmul(A, B))
print(f"PyTorch: {triton_ms:.3f} ms {perf(triton_ms):.3f} tflops")

triton_ms = triton.testing.do_bench(lambda: matmul.apply(A, B, total_sm, BLOCK_M, BLOCK_N, BLOCK_K, two_tiles, num_stages, num_warps, waves_per_eu))
triton_ms = triton.testing.do_bench(lambda: matmul.apply(A, B, total_sm, BLOCK_M, BLOCK_N, BLOCK_K, two_tiles, num_stages, num_warps, waves_per_eu, mfmaInstrSize, kpack))
print(f"hybrid stream-k (grid={total_sm}): {triton_ms:.3f} ms {perf(triton_ms):.3f} tflops")
print(f'SIZE: {m},{n},{k} Best tuning config: ({streamk_gemm.get_best_config()})')

triton_ms = triton.testing.do_bench(lambda: matmul.apply(A, B, total_sm * 2, BLOCK_M, BLOCK_N, BLOCK_K, two_tiles, num_stages, num_warps, waves_per_eu))
triton_ms = triton.testing.do_bench(lambda: matmul.apply(A, B, total_sm * 2, BLOCK_M, BLOCK_N, BLOCK_K, two_tiles, num_stages, num_warps, waves_per_eu, mfmaInstrSize, kpack))
print(f"hybrid stream-k (grid={total_sm * 2}): {triton_ms:.3f} ms {perf(triton_ms):.3f} tflops")
print(f'SIZE: {m},{n},{k} Best tuning config: ({streamk_gemm.get_best_config()})')

triton_ms = triton.testing.do_bench(lambda: matmul.apply(A, B, 0, BLOCK_M, BLOCK_N, BLOCK_K, two_tiles, num_stages, num_warps, waves_per_eu))
triton_ms = triton.testing.do_bench(lambda: matmul.apply(A, B, 0, BLOCK_M, BLOCK_N, BLOCK_K, two_tiles, num_stages, num_warps, waves_per_eu, mfmaInstrSize, kpack))
print(f"tile matmul (grid=0): {triton_ms:.3f} ms {perf(triton_ms):.3f} tflops")
print(f'SIZE: {m},{n},{k} Best tuning config: ({streamk_gemm.get_best_config()})')

@@ -425,6 +432,7 @@ def forward(ctx, a: torch.Tensor, b: torch.Tensor, grid: int, BLOCK_M=128, BLOCK
# ---------------------------------------------------------------------------

# tried to reproduce the tests described in the paper
perf = lambda ms: 2 * m * n * k * 1e-12 / (ms * 1e-3)
num_samples = 1000 # 32768
step = 256
values = ((torch.logspace(torch.tensor(step).log2(), torch.tensor(8192).log2(), num_samples, base=2) / step).round() * step).unique().tolist()
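
The values expression above builds the benchmark problem sizes as unique multiples of step, spaced roughly logarithmically between step and 8192. A standalone sketch of what it produces:

import torch

step = 256
num_samples = 1000
values = ((torch.logspace(torch.tensor(step).log2(), torch.tensor(8192).log2(),
                          num_samples, base=2) / step).round() * step).unique().tolist()
print(values[:4], values[-1])  # 256.0, 512.0, 768.0, 1024.0 ... up to 8192.0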
(third changed file: diff not rendered)
