[Tutorial] Use per-SM descriptors in matmul tutorial (#4682)
peterbell10 committed Sep 16, 2024
1 parent 09675e5 commit a26848c
Showing 2 changed files with 71 additions and 34 deletions.
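In brief, the diff below replaces the single shared set of TMA descriptors (created by program 0 and protected by a ready_flag that every other program spun on) with a per-SM workspace: each program carves three 128-byte descriptor slots out of workspace_ptr, builds its own A/B/C descriptors, and can rebuild them every tiles_per_update output tiles to simulate a grouped GEMM. A minimal host-side sketch of that layout, with sizes taken from the diff itself (the desc_offsets helper is hypothetical, added here only for illustration):

```python
import torch

# Host-side sizing matching the kernel's per-SM layout in the diff:
# 3 descriptors per program, 128 bytes each, one set per SM.
TMA_SIZE = 128
NUM_SMS = torch.cuda.get_device_properties("cuda").multi_processor_count
workspace = torch.empty(NUM_SMS * 3 * TMA_SIZE, dtype=torch.uint8, device="cuda")


def desc_offsets(pid: int) -> tuple[int, int, int]:
    # Byte offsets of program `pid`'s A/B/C descriptor slots inside `workspace`.
    base = pid * 3 * TMA_SIZE
    return base, base + TMA_SIZE, base + 2 * TMA_SIZE
```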
.gitignore (3 additions, 0 deletions)
@@ -65,3 +65,6 @@ docs/sg_execution_times.rst
 
 # Vim
 *.swp
+
+# macOS
+.DS_Store
python/tutorials/09-persistent-matmul.py (68 additions, 34 deletions)
@@ -360,9 +360,9 @@ def matmul_tma_persistent(a, b):
 
 
 @triton.jit(launch_metadata=_matmul_launch_metadata)
-def matmul_kernel_device_tma_persistent(a_desc_ptr, b_desc_ptr, c_desc_ptr, #
+def matmul_kernel_device_tma_persistent(workspace_ptr, #
+                                        tiles_per_update: tl.constexpr, #
                                         a_ptr, b_ptr, c_ptr, #
-                                        ready_flag, #
                                         M, N, K, #
                                         BLOCK_SIZE_M: tl.constexpr, #
                                         BLOCK_SIZE_N: tl.constexpr, #
@@ -377,31 +377,32 @@ def matmul_kernel_device_tma_persistent(a_desc_ptr, b_desc_ptr, c_desc_ptr, #
     k_tiles = tl.cdiv(K, BLOCK_SIZE_K)
     num_tiles = num_pid_m * num_pid_n
 
-    if start_pid == 0:
-        tl.extra.cuda.experimental_device_tensormap_create2d(desc_ptr=a_desc_ptr, global_address=a_ptr,
-                                                             load_size=[BLOCK_SIZE_M, BLOCK_SIZE_K], global_size=[M, K],
-                                                             element_ty=a_ptr.dtype.element_ty)
-        tl.extra.cuda.experimental_device_tensormap_create2d(desc_ptr=b_desc_ptr, global_address=b_ptr,
-                                                             load_size=[BLOCK_SIZE_N, BLOCK_SIZE_K], global_size=[N, K],
-                                                             element_ty=b_ptr.dtype.element_ty)
-        tl.extra.cuda.experimental_device_tensormap_create2d(desc_ptr=c_desc_ptr, global_address=c_ptr,
-                                                             load_size=[BLOCK_SIZE_M, BLOCK_SIZE_N], global_size=[M, N],
-                                                             element_ty=c_ptr.dtype.element_ty)
-        tl.atomic_xchg(ready_flag, 1, sem="release")
-    else:
-        flag = tl.full([], 0, tl.int32)
-        while flag != 1:
-            flag = tl.atomic_add(ready_flag, 0, sem="acquire")
-        tl.extra.cuda.experimental_tensormap_fenceproxy_acquire(a_desc_ptr)
-        tl.extra.cuda.experimental_tensormap_fenceproxy_acquire(b_desc_ptr)
-        tl.extra.cuda.experimental_tensormap_fenceproxy_acquire(c_desc_ptr)
+    TMA_SIZE: tl.constexpr = 128
+    workspace_base = workspace_ptr + start_pid * 3 * TMA_SIZE
+    a_desc_ptr = workspace_base
+    b_desc_ptr = workspace_base + TMA_SIZE
+    c_desc_ptr = workspace_base + 2 * TMA_SIZE
+
+    tl.extra.cuda.experimental_device_tensormap_create2d(desc_ptr=a_desc_ptr, global_address=a_ptr,
+                                                         load_size=[BLOCK_SIZE_M, BLOCK_SIZE_K], global_size=[M, K],
+                                                         element_ty=a_ptr.dtype.element_ty)
+    tl.extra.cuda.experimental_device_tensormap_create2d(desc_ptr=b_desc_ptr, global_address=b_ptr,
+                                                         load_size=[BLOCK_SIZE_N, BLOCK_SIZE_K], global_size=[N, K],
+                                                         element_ty=b_ptr.dtype.element_ty)
+    tl.extra.cuda.experimental_device_tensormap_create2d(desc_ptr=c_desc_ptr, global_address=c_ptr,
+                                                         load_size=[BLOCK_SIZE_M, BLOCK_SIZE_N], global_size=[M, N],
+                                                         element_ty=c_ptr.dtype.element_ty)
+    tl.extra.cuda.experimental_tensormap_fenceproxy_acquire(a_desc_ptr)
+    tl.extra.cuda.experimental_tensormap_fenceproxy_acquire(b_desc_ptr)
+    tl.extra.cuda.experimental_tensormap_fenceproxy_acquire(c_desc_ptr)
 
     tiles_per_SM = num_tiles // NUM_SMS
     if start_pid < num_tiles % NUM_SMS:
         tiles_per_SM += 1
 
     tile_id = start_pid - NUM_SMS
     ki = -1
+    ni = -1
 
     pid_m = 0
     pid_n = 0
@@ -415,6 +416,27 @@ def matmul_kernel_device_tma_persistent(a_desc_ptr, b_desc_ptr, c_desc_ptr, #
     for _ in range(0, k_tiles * tiles_per_SM):
         ki = tl.where(ki == k_tiles - 1, 0, ki + 1)
         if ki == 0:
+            ni += 1
+
+            # Simulate a grouped gemm
+            if ni == tiles_per_update:
+                tl.extra.cuda.experimental_device_tensormap_create2d(desc_ptr=a_desc_ptr, global_address=a_ptr,
+                                                                     load_size=[BLOCK_SIZE_M,
+                                                                                BLOCK_SIZE_K], global_size=[M, K],
+                                                                     element_ty=a_ptr.dtype.element_ty)
+                tl.extra.cuda.experimental_device_tensormap_create2d(desc_ptr=b_desc_ptr, global_address=b_ptr,
+                                                                     load_size=[BLOCK_SIZE_N,
+                                                                                BLOCK_SIZE_K], global_size=[N, K],
+                                                                     element_ty=b_ptr.dtype.element_ty)
+                tl.extra.cuda.experimental_device_tensormap_create2d(desc_ptr=c_desc_ptr, global_address=c_ptr,
+                                                                     load_size=[BLOCK_SIZE_M,
+                                                                                BLOCK_SIZE_N], global_size=[M, N],
+                                                                     element_ty=c_ptr.dtype.element_ty)
+                tl.extra.cuda.experimental_tensormap_fenceproxy_acquire(a_desc_ptr)
+                tl.extra.cuda.experimental_tensormap_fenceproxy_acquire(b_desc_ptr)
+                tl.extra.cuda.experimental_tensormap_fenceproxy_acquire(c_desc_ptr)
+                ni = 0
+
             tile_id += NUM_SMS
             group_id = tile_id // num_pid_in_group
             first_pid_m = group_id * GROUP_SIZE_M
@@ -435,10 +457,11 @@ def matmul_kernel_device_tma_persistent(a_desc_ptr, b_desc_ptr, c_desc_ptr, #
             c = accumulator.to(dtype)
 
             tl._experimental_descriptor_store(c_desc_ptr, c, [offs_am, offs_bn])
 
+            accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
 
 
-def matmul_device_tma_persistent(a, b):
+def matmul_device_tma_persistent(a, b, tiles_per_update):
     # Autotuner does not work with TMA. Use manual config.
     configs = {
         torch.float8_e4m3fn: {
@@ -459,15 +482,15 @@ def matmul_device_tma_persistent(a, b):
     dtype = a.dtype
 
     c = torch.zeros((M, N), device=a.device, dtype=dtype)
-    a_desc, b_desc, c_desc = [torch.empty(128, dtype=torch.uint8, device="cuda") for _ in range(3)]
-    ready_flag = torch.zeros((), dtype=torch.int32, device="cuda")
     NUM_SMS = torch.cuda.get_device_properties("cuda").multi_processor_count
+    tma_size = 128
+    workspace = torch.empty(NUM_SMS * 3 * tma_size, dtype=torch.uint8, device="cuda")
 
     grid = lambda META: (min(NUM_SMS, triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"])), )
     matmul_kernel_device_tma_persistent[grid](
-        a_desc, b_desc, c_desc, #
+        workspace, #
+        tiles_per_update, #
         a, b, c, #
-        ready_flag, #
         M, N, K, #
         BLOCK_SIZE_M=configs[dtype]["BLOCK_SIZE_M"], #
         BLOCK_SIZE_N=configs[dtype]["BLOCK_SIZE_N"], #
@@ -507,7 +530,7 @@ def torch_matmul(a, b):
     return c
 
 
-def bench(K, dtype, reps=10):
+def bench(K, dtype, tiles_per_update, reps=10):
     M = 8192
     N = 8192
     a = torch.randn((M, K), device="cuda", dtype=torch.float16).to(dtype)
@@ -535,14 +558,18 @@ def bench(K, dtype, reps=10):
         for _ in range(reps):
             matmul_tma_persistent(a, b)
         time.sleep(0.01)
-    for _ in range(reps):
-        matmul_device_tma_persistent(a, b)
-    time.sleep(0.01)
+    flops_str = "flops8" if dtype == torch.float8_e4m3fn else "flops"
+    with proton.scope(
+            f"matmul_kernel_device_tma_persistent M={M}, N={N}, K={K}, tiles_per_update={tiles_per_update:02}",
+            {"bytes": a.element_size() * (M * K + N * K), flops_str: 2. * M * N * K}):
+        for _ in range(reps):
+            matmul_device_tma_persistent(a, b, tiles_per_update)
+        time.sleep(0.01)
 
     proton.deactivate(0)
 
 
-def validate(M, N, K, dtype):
+def validate(M, N, K, dtype, tiles_per_update):
     a = torch.randn((M, K), device="cuda", dtype=torch.float16).to(dtype)
     b = torch.randn((K, N), device="cuda", dtype=torch.float16).to(dtype)
     b = b.T.contiguous()
@@ -552,7 +579,7 @@ def validate(M, N, K, dtype):
     naive_result = matmul(a, b.T)
     persistent_result = matmul_persistent(a, b.T)
     tma_persistent_result = matmul_tma_persistent(a, b) if supports_tma() else None
-    device_tma_persistent_result = matmul_device_tma_persistent(a, b) if supports_tma() else None
+    device_tma_persistent_result = matmul_device_tma_persistent(a, b, tiles_per_update) if supports_tma() else None
 
     if torch_result is not None:
         naive_vs_torch = "✅" if torch.allclose(naive_result.to(torch.float16), torch_result.to(torch.float16),
@@ -586,6 +613,13 @@ def validate(M, N, K, dtype):
     parser.add_argument("-K", type=int, required=False, default=512)
     parser.add_argument("--K_range", type=int, nargs=2)
     parser.add_argument("--K_step", type=int, default=512)
+    parser.add_argument(
+        "--tiles_per_update",
+        type=int,
+        default=1,
+        help=
+        "Number of output tiles calculated for each update of the tma descriptor in matmul_device_tma_persistent_kernel",
+    )
     parser.add_argument("--prec", type=str, choices=["fp8", "fp16"], default="fp16")
     args = parser.parse_args()
 
@@ -601,10 +635,10 @@ def validate(M, N, K, dtype):
 
     torch.manual_seed(0)
 
-    validate(32, 32, 32, dtype)
-    validate(8192, 8192, 512, dtype)
+    validate(32, 32, 32, dtype, args.tiles_per_update)
+    validate(8192, 8192, 512, dtype, args.tiles_per_update)
 
     proton.start("matmul", hook="triton")
     for K in range(args.K_range[0], args.K_range[1] + 1, args.K_step):
-        bench(K, dtype)
+        bench(K, dtype, args.tiles_per_update)
     proton.finalize()
