diff --git a/.gitignore b/.gitignore
index 1a572f80a8b0..72ad33529186 100644
--- a/.gitignore
+++ b/.gitignore
@@ -65,3 +65,6 @@ docs/sg_execution_times.rst
 
 # Vim
 *.swp
+
+# macOS
+.DS_Store
diff --git a/python/tutorials/09-persistent-matmul.py b/python/tutorials/09-persistent-matmul.py
index d046fccdaae1..1ce34f9a26ab 100644
--- a/python/tutorials/09-persistent-matmul.py
+++ b/python/tutorials/09-persistent-matmul.py
@@ -360,9 +360,9 @@ def matmul_tma_persistent(a, b):
 
 
 @triton.jit(launch_metadata=_matmul_launch_metadata)
-def matmul_kernel_device_tma_persistent(a_desc_ptr, b_desc_ptr, c_desc_ptr,  #
+def matmul_kernel_device_tma_persistent(workspace_ptr,  #
+                                        tiles_per_update: tl.constexpr,  #
                                         a_ptr, b_ptr, c_ptr,  #
-                                        ready_flag,  #
                                         M, N, K,  #
                                         BLOCK_SIZE_M: tl.constexpr,  #
                                         BLOCK_SIZE_N: tl.constexpr,  #
@@ -377,24 +377,24 @@ def matmul_kernel_device_tma_persistent(a_desc_ptr, b_desc_ptr, c_desc_ptr,  #
     k_tiles = tl.cdiv(K, BLOCK_SIZE_K)
     num_tiles = num_pid_m * num_pid_n
 
-    if start_pid == 0:
-        tl.extra.cuda.experimental_device_tensormap_create2d(desc_ptr=a_desc_ptr, global_address=a_ptr,
-                                                             load_size=[BLOCK_SIZE_M, BLOCK_SIZE_K], global_size=[M, K],
-                                                             element_ty=a_ptr.dtype.element_ty)
-        tl.extra.cuda.experimental_device_tensormap_create2d(desc_ptr=b_desc_ptr, global_address=b_ptr,
-                                                             load_size=[BLOCK_SIZE_N, BLOCK_SIZE_K], global_size=[N, K],
-                                                             element_ty=b_ptr.dtype.element_ty)
-        tl.extra.cuda.experimental_device_tensormap_create2d(desc_ptr=c_desc_ptr, global_address=c_ptr,
-                                                             load_size=[BLOCK_SIZE_M, BLOCK_SIZE_N], global_size=[M, N],
-                                                             element_ty=c_ptr.dtype.element_ty)
-        tl.atomic_xchg(ready_flag, 1, sem="release")
-    else:
-        flag = tl.full([], 0, tl.int32)
-        while flag != 1:
-            flag = tl.atomic_add(ready_flag, 0, sem="acquire")
-        tl.extra.cuda.experimental_tensormap_fenceproxy_acquire(a_desc_ptr)
-        tl.extra.cuda.experimental_tensormap_fenceproxy_acquire(b_desc_ptr)
-        tl.extra.cuda.experimental_tensormap_fenceproxy_acquire(c_desc_ptr)
+    TMA_SIZE: tl.constexpr = 128
+    workspace_base = workspace_ptr + start_pid * 3 * TMA_SIZE
+    a_desc_ptr = workspace_base
+    b_desc_ptr = workspace_base + TMA_SIZE
+    c_desc_ptr = workspace_base + 2 * TMA_SIZE
+
+    tl.extra.cuda.experimental_device_tensormap_create2d(desc_ptr=a_desc_ptr, global_address=a_ptr,
+                                                         load_size=[BLOCK_SIZE_M, BLOCK_SIZE_K], global_size=[M, K],
+                                                         element_ty=a_ptr.dtype.element_ty)
+    tl.extra.cuda.experimental_device_tensormap_create2d(desc_ptr=b_desc_ptr, global_address=b_ptr,
+                                                         load_size=[BLOCK_SIZE_N, BLOCK_SIZE_K], global_size=[N, K],
+                                                         element_ty=b_ptr.dtype.element_ty)
+    tl.extra.cuda.experimental_device_tensormap_create2d(desc_ptr=c_desc_ptr, global_address=c_ptr,
+                                                         load_size=[BLOCK_SIZE_M, BLOCK_SIZE_N], global_size=[M, N],
+                                                         element_ty=c_ptr.dtype.element_ty)
+    tl.extra.cuda.experimental_tensormap_fenceproxy_acquire(a_desc_ptr)
+    tl.extra.cuda.experimental_tensormap_fenceproxy_acquire(b_desc_ptr)
+    tl.extra.cuda.experimental_tensormap_fenceproxy_acquire(c_desc_ptr)
 
     tiles_per_SM = num_tiles // NUM_SMS
     if start_pid < num_tiles % NUM_SMS:
@@ -402,6 +402,7 @@ def matmul_kernel_device_tma_persistent(a_desc_ptr, b_desc_ptr, c_desc_ptr,  #
 
     tile_id = start_pid - NUM_SMS
     ki = -1
+    ni = -1
 
     pid_m = 0
     pid_n = 0
@@ -415,6 +416,27 @@ def matmul_kernel_device_tma_persistent(a_desc_ptr, b_desc_ptr, c_desc_ptr,  #
     for _ in range(0, k_tiles * tiles_per_SM):
         ki = tl.where(ki == k_tiles - 1, 0, ki + 1)
         if ki == 0:
+            ni += 1
+
+            # Simulate a grouped gemm
+            if ni == tiles_per_update:
+                tl.extra.cuda.experimental_device_tensormap_create2d(desc_ptr=a_desc_ptr, global_address=a_ptr,
+                                                                     load_size=[BLOCK_SIZE_M,
+                                                                                BLOCK_SIZE_K], global_size=[M, K],
+                                                                     element_ty=a_ptr.dtype.element_ty)
+                tl.extra.cuda.experimental_device_tensormap_create2d(desc_ptr=b_desc_ptr, global_address=b_ptr,
+                                                                     load_size=[BLOCK_SIZE_N,
+                                                                                BLOCK_SIZE_K], global_size=[N, K],
+                                                                     element_ty=b_ptr.dtype.element_ty)
+                tl.extra.cuda.experimental_device_tensormap_create2d(desc_ptr=c_desc_ptr, global_address=c_ptr,
+                                                                     load_size=[BLOCK_SIZE_M,
+                                                                                BLOCK_SIZE_N], global_size=[M, N],
+                                                                     element_ty=c_ptr.dtype.element_ty)
+                tl.extra.cuda.experimental_tensormap_fenceproxy_acquire(a_desc_ptr)
+                tl.extra.cuda.experimental_tensormap_fenceproxy_acquire(b_desc_ptr)
+                tl.extra.cuda.experimental_tensormap_fenceproxy_acquire(c_desc_ptr)
+                ni = 0
+
             tile_id += NUM_SMS
             group_id = tile_id // num_pid_in_group
             first_pid_m = group_id * GROUP_SIZE_M
@@ -435,10 +457,11 @@ def matmul_kernel_device_tma_persistent(a_desc_ptr, b_desc_ptr, c_desc_ptr,  #
             c = accumulator.to(dtype)
 
             tl._experimental_descriptor_store(c_desc_ptr, c, [offs_am, offs_bn])
+            accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
 
 
-def matmul_device_tma_persistent(a, b):
+def matmul_device_tma_persistent(a, b, tiles_per_update):
     # Autotuner does not work with TMA. Use manual config.
     configs = {
         torch.float8_e4m3fn: {
@@ -459,15 +482,15 @@ def matmul_device_tma_persistent(a, b):
     dtype = a.dtype
 
     c = torch.zeros((M, N), device=a.device, dtype=dtype)
-    a_desc, b_desc, c_desc = [torch.empty(128, dtype=torch.uint8, device="cuda") for _ in range(3)]
-    ready_flag = torch.zeros((), dtype=torch.int32, device="cuda")
     NUM_SMS = torch.cuda.get_device_properties("cuda").multi_processor_count
+    tma_size = 128
+    workspace = torch.empty(NUM_SMS * 3 * tma_size, dtype=torch.uint8, device="cuda")
 
     grid = lambda META: (min(NUM_SMS, triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"])), )
     matmul_kernel_device_tma_persistent[grid](
-        a_desc, b_desc, c_desc,  #
+        workspace,  #
+        tiles_per_update,  #
         a, b, c,  #
-        ready_flag,  #
        M, N, K,  #
         BLOCK_SIZE_M=configs[dtype]["BLOCK_SIZE_M"],  #
         BLOCK_SIZE_N=configs[dtype]["BLOCK_SIZE_N"],  #
@@ -507,7 +530,7 @@ def torch_matmul(a, b):
     return c
 
 
-def bench(K, dtype, reps=10):
+def bench(K, dtype, tiles_per_update, reps=10):
     M = 8192
     N = 8192
     a = torch.randn((M, K), device="cuda", dtype=torch.float16).to(dtype)
@@ -535,14 +558,18 @@ def bench(K, dtype, reps=10):
             for _ in range(reps):
                 matmul_tma_persistent(a, b)
                 time.sleep(0.01)
-        for _ in range(reps):
-            matmul_device_tma_persistent(a, b)
-            time.sleep(0.01)
+        flops_str = "flops8" if dtype == torch.float8_e4m3fn else "flops"
+        with proton.scope(
+                f"matmul_kernel_device_tma_persistent M={M}, N={N}, K={K}, tiles_per_update={tiles_per_update:02}",
+                {"bytes": a.element_size() * (M * K + N * K), flops_str: 2. * M * N * K}):
+            for _ in range(reps):
+                matmul_device_tma_persistent(a, b, tiles_per_update)
+                time.sleep(0.01)
 
     proton.deactivate(0)
 
 
-def validate(M, N, K, dtype):
+def validate(M, N, K, dtype, tiles_per_update):
     a = torch.randn((M, K), device="cuda", dtype=torch.float16).to(dtype)
     b = torch.randn((K, N), device="cuda", dtype=torch.float16).to(dtype)
     b = b.T.contiguous()
@@ -552,7 +579,7 @@ def validate(M, N, K, dtype):
     naive_result = matmul(a, b.T)
     persistent_result = matmul_persistent(a, b.T)
     tma_persistent_result = matmul_tma_persistent(a, b) if supports_tma() else None
-    device_tma_persistent_result = matmul_device_tma_persistent(a, b) if supports_tma() else None
+    device_tma_persistent_result = matmul_device_tma_persistent(a, b, tiles_per_update) if supports_tma() else None
 
     if torch_result is not None:
         naive_vs_torch = "✅" if torch.allclose(naive_result.to(torch.float16), torch_result.to(torch.float16),
@@ -586,6 +613,13 @@ def validate(M, N, K, dtype):
     parser.add_argument("-K", type=int, required=False, default=512)
     parser.add_argument("--K_range", type=int, nargs=2)
    parser.add_argument("--K_step", type=int, default=512)
+    parser.add_argument(
+        "--tiles_per_update",
+        type=int,
+        default=1,
+        help=
+        "Number of output tiles calculated for each update of the tma descriptor in matmul_device_tma_persistent_kernel",
+    )
     parser.add_argument("--prec", type=str, choices=["fp8", "fp16"], default="fp16")
     args = parser.parse_args()
 
@@ -601,10 +635,10 @@ def validate(M, N, K, dtype):
 
     torch.manual_seed(0)
 
-    validate(32, 32, 32, dtype)
-    validate(8192, 8192, 512, dtype)
+    validate(32, 32, 32, dtype, args.tiles_per_update)
+    validate(8192, 8192, 512, dtype, args.tiles_per_update)
 
     proton.start("matmul", hook="triton")
    for K in range(args.K_range[0], args.K_range[1] + 1, args.K_step):
-        bench(K, dtype)
+        bench(K, dtype, args.tiles_per_update)
     proton.finalize()
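
Note on the workspace scheme: the patch drops the ready_flag handshake, in which
program 0 built three shared descriptors while every other program spun on an
atomic flag before fencing them, and instead gives each persistent program its
own slice of a global-memory workspace: the host allocates NUM_SMS * 3 * 128
bytes, and each program carves its slice into three 128-byte CUtensorMap slots
for A, B, and C. Every program now creates and fences only its own descriptors,
so no cross-program synchronization is needed. A minimal sketch of the slot
arithmetic in plain Python; the helper name and SM count are illustrative, not
part of the patch:

TMA_SIZE = 128  # bytes per CUtensorMap, matching tma_size in the host code
NUM_DESCS = 3   # one slot each for the A, B, and C descriptors


def descriptor_slots(workspace_base, start_pid):
    # Mirrors the kernel: workspace_ptr + start_pid * 3 * TMA_SIZE, followed by
    # consecutive TMA_SIZE-byte slots for a_desc_ptr, b_desc_ptr, c_desc_ptr.
    base = workspace_base + start_pid * NUM_DESCS * TMA_SIZE
    return base, base + TMA_SIZE, base + 2 * TMA_SIZE


# Slices of different programs never overlap, which is what makes the
# ready-flag handshake unnecessary:
slots = [descriptor_slots(0, pid) for pid in range(132)]  # e.g. 132 SMs on H100
assert len(set(slots)) == len(slots)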
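
Note on tiles_per_update: the new constexpr, together with the ni counter,
makes each program rebuild its three descriptors after every tiles_per_update
output tiles (the "Simulate a grouped gemm" branch), modeling the descriptor
churn of a grouped GEMM where each group needs fresh tensormaps; the
accumulator reset added after the epilogue store re-zeroes the tile for the
next output tile. A small pure-Python model of the rebuild cadence; the helper
name is made up for illustration:

def rebuild_tiles(num_tiles, tiles_per_update):
    # Follows the kernel: ni starts at -1, is incremented once per output
    # tile (ki == 0), and the descriptors are rebuilt when it reaches
    # tiles_per_update, after which it resets to 0.
    ni = -1
    rebuilds = []
    for tile in range(num_tiles):
        ni += 1
        if ni == tiles_per_update:
            rebuilds.append(tile)
            ni = 0
    return rebuilds


print(rebuild_tiles(8, 1))  # [1, 2, 3, 4, 5, 6, 7]: every tile after the first
print(rebuild_tiles(8, 3))  # [3, 6]: one rebuild per three tiles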
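
Note on the benchmark plumbing: tiles_per_update is threaded through bench()
and validate() and exposed as --tiles_per_update (default 1), and the proton
scope name embeds the value (tiles_per_update={tiles_per_update:02}), so runs
at different update frequencies appear as separate entries in the profile.
Assuming the tutorial's existing CLI, a run exercising the knob could look
like `python 09-persistent-matmul.py --prec fp8 --K_range 512 4096
--tiles_per_update 4` (the values here are illustrative).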