From dc5b6604c16e8612ec7c9f8ab735cf32c6cdeac6 Mon Sep 17 00:00:00 2001 From: Vinayak Gokhale Date: Tue, 30 Jul 2024 02:54:30 +0000 Subject: [PATCH 01/22] Microbenchmark for fused moe --- .../kernels/benchmark_mixtral_moe_rocm.py | 154 ++++-------------- .../layers/fused_moe/fused_moe.py | 22 ++- 2 files changed, 53 insertions(+), 123 deletions(-) diff --git a/benchmarks/kernels/benchmark_mixtral_moe_rocm.py b/benchmarks/kernels/benchmark_mixtral_moe_rocm.py index 63080eaf2f11c..dd7f0f030869a 100755 --- a/benchmarks/kernels/benchmark_mixtral_moe_rocm.py +++ b/benchmarks/kernels/benchmark_mixtral_moe_rocm.py @@ -17,30 +17,11 @@ def main(args): - os.environ["HIP_VISIBLE_DEVICES"] = args.GPUID os.environ["HIP_FORCE_DEV_KERNARG"] = "1" os.environ["DEBUG_CLR_GRAPH_PACKET_CAPTURE"] = "1" - os.environ["OPTIMIZE_EPILOGUE"] = "1" for bs in [ - 1, - 2, - 4, - 8, - 16, - 24, - 32, - 48, 64, - 96, - 128, - 256, - 512, - 1024, - 1536, - 2048, - 3072, - 4096, ]: run_grid(bs, model=args.model, TP=args.TP) @@ -49,21 +30,22 @@ def main(args): def get_full_tuning_space(): configs = [] - block_mn_range = [16, 32, 64, 128, 256] - block_k_range = [16, 32, 64, 128, 256] + block_m_range = [16] + block_n_range = [64] + block_k_range = [128] # split_k_range = [1] #, 2, 4, 5, 6, 8, 10, 12, 16, 18, 24] - num_warps_range = [1, 2, 4, 8] - group_m_range = [1, 4, 8, 16, 32] + num_warps_range = [4] + group_m_range = [1] # For now we see better perf with num_stages=0 for all gemm configs we care # But keep this explicit so that we do not forget we may need to set it to # other values in the future num_stage_range = [0] waves_per_eu_range = [0] - matrix_instr_nonkdim_range = [16, 32] - kpack_range = [1, 2] + matrix_instr_nonkdim_range = [16] + kpack_range = [2] - for block_m in block_mn_range: - for block_n in block_mn_range: + for block_m in block_m_range: + for block_n in block_n_range: for block_k in block_k_range: for num_warps in num_warps_range: for group_m in group_m_range: @@ -91,77 +73,8 @@ def get_full_tuning_space(): ## Utilize method from rocm/Triton tuning script def prune_configs(M, N, K, configs): - pruned_configs = [] - elemBytes_a = 2 # [DV Note] Hard-coded for float16 (2 bytes) - elemBytes_b = 2 # [DV Note] Hard-coded for float16 (2 bytes) - - mfma = 16 if M < 32 or N < 32 else 32 - - # TODO (zhanglx): figure out the boundary between large and small gemms - large_gemm = False - if M >= 2048 and N >= 2048: - large_gemm = True - - for config in configs: - BLOCK_SIZE_M = config.get("BLOCK_SIZE_M") - BLOCK_SIZE_N = config.get("BLOCK_SIZE_N") - BLOCK_SIZE_K = config.get("BLOCK_SIZE_K") - num_warps = config.get("num_warps") - matrix_instr_nonkdim = config.get("matrix_instr_nonkdim") - # kpack = config.get("kpack") - if matrix_instr_nonkdim > mfma: - continue - if mfma == 4 and BLOCK_SIZE_K < 64: - continue - # some layouts could not work properly in case - # number elements per thread is less 1 - if BLOCK_SIZE_M * BLOCK_SIZE_N < 64: - continue - SPLIT_K = 1 # config.get("SPLIT_K") - GROUP_M = config.get("GROUP_SIZE_M") - if (matrix_instr_nonkdim > BLOCK_SIZE_M - or matrix_instr_nonkdim > BLOCK_SIZE_N): - continue - if matrix_instr_nonkdim >= M and matrix_instr_nonkdim != BLOCK_SIZE_M: - continue - if matrix_instr_nonkdim >= N and matrix_instr_nonkdim != BLOCK_SIZE_N: - continue - # Skip BLOCK_SIZE that is too large compare to M/N - # unless BLOCK_SIZE is already small enough - if M * 2 < BLOCK_SIZE_M and BLOCK_SIZE_M != 16: - continue - if N * 2 < BLOCK_SIZE_N and BLOCK_SIZE_N != 16: - continue - # skip large 
split_k when not necessary - if SPLIT_K != 1 and not need_split_k(M, N, K): - continue - # skip split_k that leads to EVEN_K = false - leap = SPLIT_K * BLOCK_SIZE_K - modv = K % leap - if modv != 0: - continue - # skip large GROUP_M - if GROUP_M * BLOCK_SIZE_M > M and GROUP_M != 1: - continue - # out of shared memory resource - # TODO (zhanglx): This does not consider the LDS usage in the epilogue - LDS = (BLOCK_SIZE_K * BLOCK_SIZE_M * elemBytes_a + - BLOCK_SIZE_K * BLOCK_SIZE_N * elemBytes_b) - if LDS > 65536: - continue - # Skip small block sizes and num_warps for large gemm - # For fp16 and f8, we want to only use BLOCK_SIZE >= 64 - if large_gemm: - if BLOCK_SIZE_M < 64 or BLOCK_SIZE_N < 64: - continue - if BLOCK_SIZE_K < 64: - continue - if num_warps < 4: - continue - - pruned_configs.append(config) - return pruned_configs + return configs def union_of_list_of_dicts(l1, l2): @@ -195,7 +108,7 @@ def run_grid(bs, model, TP): num_calls = 100 num_warmup_trials = 1 - num_trials = 1 + num_trials = 10 full_configs = get_full_tuning_space() M1 = bs * 2 @@ -396,29 +309,34 @@ def run_timing( use_fp8=False, ) - ops.silu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, N)) - - invoke_fused_moe_kernel( - intermediate_cache2, - w2, - intermediate_cache3, - None, # a2_scale - None, # w2_scale - topk_weights, - topk_ids, - sorted_token_ids, - expert_ids, - num_tokens_post_padded, - True, - 1, - config, - compute_type=(tl.bfloat16 if hidden_states.dtype == torch.bfloat16 - else tl.float16), - use_fp8=False, - ) + # ops.silu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, N)) + + # invoke_fused_moe_kernel( + # intermediate_cache2, + # w2, + # intermediate_cache3, + # None, # a2_scale + # None, # w2_scale + # topk_weights, + # topk_ids, + # sorted_token_ids, + # expert_ids, + # num_tokens_post_padded, + # True, + # 1, + # config, + # compute_type=(tl.bfloat16 if hidden_states.dtype == torch.bfloat16 + # else tl.float16), + # use_fp8=False, + # ) end_event.record() end_event.synchronize() + print(f"config = {config}") + print(f"sorted token ids = {sorted_token_ids}") + print(f"sorted token ids shape = {sorted_token_ids.shape}") + print(f"expert ids = {expert_ids}") + print(f"num_tokens_post_padded = {num_tokens_post_padded}") dur_ms = start_event.elapsed_time(end_event) / num_calls return dur_ms diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index 7a3c6ec773358..3803a28566064 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -26,7 +26,7 @@ def fused_moe_kernel( topk_weights_ptr, sorted_token_ids_ptr, expert_ids_ptr, - num_tokens_post_padded_ptr, + num_tokens_post_padded, # Matrix dimensions N, K, @@ -98,7 +98,6 @@ def fused_moe_kernel( # and accumulate # `a_ptrs` is a block of [BLOCK_SIZE_M, BLOCK_SIZE_K] pointers # `b_ptrs` is a block of [BLOCK_SIZE_K, BLOCK_SIZE_N] pointers - num_tokens_post_padded = tl.load(num_tokens_post_padded_ptr) if pid_m * BLOCK_SIZE_M >= num_tokens_post_padded: return offs_token_id = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) @@ -132,7 +131,7 @@ def fused_moe_kernel( a_ptrs, mask=token_mask[:, None] & (offs_k[None, :] < K - k * BLOCK_SIZE_K), - other=0.0, + other=0.0 ) b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, @@ -211,7 +210,7 @@ def moe_align_block_size( device=topk_ids.device) sorted_ids.fill_(topk_ids.numel()) max_num_m_blocks = triton.cdiv(max_num_tokens_padded, block_size) - expert_ids = 
torch.empty((max_num_m_blocks, ), + expert_ids = torch.zeros((max_num_m_blocks, ), dtype=torch.int32, device=topk_ids.device) num_tokens_post_pad = torch.empty((1), @@ -251,6 +250,19 @@ def invoke_fused_moe_kernel(A: torch.Tensor, B: torch.Tensor, C: torch.Tensor, grid = lambda META: (triton.cdiv(sorted_token_ids.shape[0], META[ "BLOCK_SIZE_M"]) * triton.cdiv(B.shape[1], META["BLOCK_SIZE_N"]), ) + print("==================================================================================") + print(f"grid = {triton.cdiv(sorted_token_ids.shape[0], config['BLOCK_SIZE_M']) * triton.cdiv(B.shape[1], config['BLOCK_SIZE_N'])}") + print(f"sorted token ids shape = {sorted_token_ids.shape}") + print(f"num_token_ids_post_padded = {num_tokens_post_padded}") + print(f"config in moe = {config}") + print(f"A shape = {A.shape}") + print(f"B shape = {B.shape}") + print(f"num valid tokens = {topk_ids.numel()}") + print(f"sorted_token_ids = {sorted_token_ids[0]}") + print(f"expert_ids = {expert_ids}") + print(f"num_tokens_post_padded = {num_tokens_post_padded[0]}") + + fused_moe_kernel[grid]( A, B, @@ -260,7 +272,7 @@ def invoke_fused_moe_kernel(A: torch.Tensor, B: torch.Tensor, C: torch.Tensor, topk_weights, sorted_token_ids, expert_ids, - num_tokens_post_padded, + num_tokens_post_padded[0].item(), B.shape[1], B.shape[2], sorted_token_ids.shape[0], From d5564d3317c7e740b06dc7490eb6ba836ded1dce Mon Sep 17 00:00:00 2001 From: Vinayak Gokhale Date: Tue, 30 Jul 2024 18:18:04 +0000 Subject: [PATCH 02/22] Revert broken load --- vllm/model_executor/layers/fused_moe/fused_moe.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index 3803a28566064..1f3e70fed8c88 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -26,7 +26,7 @@ def fused_moe_kernel( topk_weights_ptr, sorted_token_ids_ptr, expert_ids_ptr, - num_tokens_post_padded, + num_tokens_post_padded_ptr, # Matrix dimensions N, K, @@ -98,6 +98,7 @@ def fused_moe_kernel( # and accumulate # `a_ptrs` is a block of [BLOCK_SIZE_M, BLOCK_SIZE_K] pointers # `b_ptrs` is a block of [BLOCK_SIZE_K, BLOCK_SIZE_N] pointers + num_tokens_post_padded = tl.load(num_tokens_post_padded_ptr) if pid_m * BLOCK_SIZE_M >= num_tokens_post_padded: return offs_token_id = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) @@ -272,7 +273,7 @@ def invoke_fused_moe_kernel(A: torch.Tensor, B: torch.Tensor, C: torch.Tensor, topk_weights, sorted_token_ids, expert_ids, - num_tokens_post_padded[0].item(), + num_tokens_post_padded, B.shape[1], B.shape[2], sorted_token_ids.shape[0], From 1934f71d89a12f0d8ca7fa45a6644464e2f2a19a Mon Sep 17 00:00:00 2001 From: Vinayak Gokhale Date: Thu, 1 Aug 2024 16:59:33 +0000 Subject: [PATCH 03/22] Add persistent kernel from Jingning --- .../kernels/benchmark_mixtral_moe_rocm.py | 22 +- .../layers/fused_moe/__init__.py | 4 +- .../layers/fused_moe/fused_moe.py | 425 +++++++++++++++++- 3 files changed, 436 insertions(+), 15 deletions(-) diff --git a/benchmarks/kernels/benchmark_mixtral_moe_rocm.py b/benchmarks/kernels/benchmark_mixtral_moe_rocm.py index dd7f0f030869a..8047eceecb149 100755 --- a/benchmarks/kernels/benchmark_mixtral_moe_rocm.py +++ b/benchmarks/kernels/benchmark_mixtral_moe_rocm.py @@ -13,6 +13,7 @@ from vllm._C import ops from vllm.model_executor.layers.fused_moe import (get_config_file_name, invoke_fused_moe_kernel, + invoke_fused_moe_persistent_kernel, 
moe_align_block_size) @@ -30,11 +31,11 @@ def main(args): def get_full_tuning_space(): configs = [] - block_m_range = [16] - block_n_range = [64] + block_m_range = [32] + block_n_range = [128] block_k_range = [128] # split_k_range = [1] #, 2, 4, 5, 6, 8, 10, 12, 16, 18, 24] - num_warps_range = [4] + num_warps_range = [8] group_m_range = [1] # For now we see better perf with num_stages=0 for all gemm configs we care # But keep this explicit so that we do not forget we may need to set it to @@ -290,7 +291,7 @@ def run_timing( start_event.record() for i in range(num_calls): - invoke_fused_moe_kernel( + invoke_fused_moe_persistent_kernel( hidden_states, w1, intermediate_cache1, @@ -332,11 +333,14 @@ def run_timing( end_event.record() end_event.synchronize() - print(f"config = {config}") - print(f"sorted token ids = {sorted_token_ids}") - print(f"sorted token ids shape = {sorted_token_ids.shape}") - print(f"expert ids = {expert_ids}") - print(f"num_tokens_post_padded = {num_tokens_post_padded}") + # print(f"intermediate 0 shape = {intermediate_cache1.shape}") + # print(f"intermediate 1 shape = {intermediate_cache2.shape}") + # print(f"intermediate 2 shape = {intermediate_cache3.shape}") + # print(f"config = {config}") + # print(f"sorted token ids = {sorted_token_ids}") + # print(f"sorted token ids shape = {sorted_token_ids.shape}") + # print(f"expert ids = {expert_ids}") + # print(f"num_tokens_post_padded = {num_tokens_post_padded}") dur_ms = start_event.elapsed_time(end_event) / num_calls return dur_ms diff --git a/vllm/model_executor/layers/fused_moe/__init__.py b/vllm/model_executor/layers/fused_moe/__init__.py index 851ed919ae3b8..5b78ffd10d5b7 100755 --- a/vllm/model_executor/layers/fused_moe/__init__.py +++ b/vllm/model_executor/layers/fused_moe/__init__.py @@ -1,8 +1,8 @@ from vllm.model_executor.layers.fused_moe.fused_moe import ( fused_experts, fused_moe, fused_topk, get_config_file_name, - invoke_fused_moe_kernel, moe_align_block_size) + invoke_fused_moe_kernel, invoke_fused_moe_persistent_kernel, moe_align_block_size) __all__ = [ "fused_moe", "fused_topk", "fused_experts", "get_config_file_name", - "invoke_fused_moe_kernel", "moe_align_block_size" + "invoke_fused_moe_kernel", "invoke_fused_moe_persistent_kernel", "moe_align_block_size" ] diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index 1f3e70fed8c88..acec248505c90 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -165,6 +165,349 @@ def fused_moe_kernel( tl.store(c_ptrs, accumulator, mask=c_mask) +@triton.jit +def fused_moe_persistent_kernel( + # Pointers to matrices + a_ptr, + b_ptr, + c_ptr, + a_scale_ptr, + b_scale_ptr, + topk_weights_ptr, + sorted_token_ids_ptr, + expert_ids_ptr, + num_tokens_post_padded_ptr, + # Matrix dimensions + N, + K, + EM, + num_valid_tokens, + # The stride variables represent how much to increase the ptr by when + # moving by 1 element in a particular dimension. E.g. `stride_am` is + # how much to increase `a_ptr` by to get the element one row down + # (A has M rows). 
+ stride_am, + stride_ak, + stride_be, + stride_bk, + stride_bn, + stride_cm, + stride_cn, + # Meta-parameters + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + NUM_SMS: tl.constexpr, + MUL_ROUTED_WEIGHT: tl.constexpr, + top_k: tl.constexpr, + compute_type: tl.constexpr, + use_fp8: tl.constexpr, +): + """ + Implements the fused computation for a Mixture of Experts (MOE) using + token and expert matrices. + This is the persistent version of the fused_moe kernel. + + Key Parameters: + - A: The input tensor representing tokens with shape (*, K), where '*' can + be any shape representing batches and K is the feature dimension of + each token. + - B: The stacked MOE weight tensor with shape (E, N, K), where E is + the number of experts, K is the input feature dimension, and N is + the output feature dimension. + - C: The output cache tensor with shape (M, topk, N), where M is the + total number of tokens post padding, topk is the number of times + each token is repeated, and N is the output feature dimension. + - sorted_token_ids: A tensor containing the sorted indices of tokens, + repeated topk times and arranged by the expert index they are + assigned to. + - expert_ids: A tensor containing the indices of the expert for each + block. It determines which expert matrix from B should be used for + each block in A. + This kernel performs the multiplication of a token by its corresponding + expert matrix as determined by `expert_ids`. The sorting of + `sorted_token_ids` by expert index and padding ensures divisibility by + BLOCK_SIZE_M, which is necessary to maintain consistency in block matrix + multiplication across different blocks processed by the same expert. + """ + # ----------------------------------------------------------- + # Simply compute how many iterations each persistent block needs to do + start_pid = tl.program_id(axis=0) + num_pid_m = tl.cdiv(EM, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + k_tiles = tl.cdiv(K, BLOCK_SIZE_K) + num_tiles = num_pid_m * num_pid_n + + tiles_per_SM = num_tiles // NUM_SMS + if start_pid < num_tiles % NUM_SMS: + tiles_per_SM += 1 + + tile_id = start_pid - NUM_SMS + ki = -1 + + offs_k_for_mask = tl.arange(0, BLOCK_SIZE_K) + + num_pid_in_group = GROUP_SIZE_M * num_pid_n + + pid_m = 0 + pid_n = 0 + offs_am = tl.arange(0, BLOCK_SIZE_M) + offs_bn = tl.arange(0, BLOCK_SIZE_N) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + offs_token = tl.zeros((BLOCK_SIZE_M,), dtype=tl.int32) + token_mask = tl.zeros((BLOCK_SIZE_M,), dtype=tl.int1) + + # load runtime constant outside the loop + num_tokens_post_padded = tl.load(num_tokens_post_padded_ptr) + if use_fp8: + a_scale = tl.load(a_scale_ptr) + b_scale = tl.load(b_scale_ptr + off_experts) + + # compute when it reaches the invalid region + pid_m = 0 + tile_id2 = start_pid - NUM_SMS + tile_counter = -1 + while pid_m * BLOCK_SIZE_M < num_tokens_post_padded: + tile_counter += 1 + tile_id2 += NUM_SMS + group_id = tile_id2 // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + ((tile_id2 % num_pid_in_group) % group_size_m) + + # print("tile_counter: ", tile_counter) + for _ in range(0, k_tiles * tile_counter): + # Increment ki or loop back to the start of each tile + ki = tl.where(ki == k_tiles - 1, 0, ki + 1) + # Prologue of each tile + if ki == 0: + tile_id += NUM_SMS + group_id = tile_id // num_pid_in_group + first_pid_m = 
group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + ((tile_id % num_pid_in_group) % group_size_m) + pid_n = (tile_id % num_pid_in_group) // group_size_m + # compute the base pointer of A, B for each tile + offs_token_id = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_token = tl.load(sorted_token_ids_ptr + offs_token_id) + token_mask = offs_token < num_valid_tokens + offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N + # Loop through K dimension + # Compute A, B pointer + offs_k = ki * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K) + a_ptrs = a_ptr + (offs_token[:, None] // top_k * stride_am + + offs_k[None, :] * stride_ak) + off_experts = tl.load(expert_ids_ptr + pid_m) + b_ptrs = (b_ptr + off_experts * stride_be + + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)) + # Load A, B and dot product + a = tl.load( + a_ptrs, + mask=token_mask[:, None] & + (offs_k_for_mask[None, :] < K - ki * BLOCK_SIZE_K), + other=0.0) + b = tl.load( + b_ptrs, + mask=offs_k_for_mask[:, None] < K - ki * BLOCK_SIZE_K, + other=0.0) + if use_fp8: + accumulator = tl.dot(a, b, acc=accumulator) + else: + accumulator += tl.dot(a, b) + + # Epilogue of each tile + if ki == k_tiles - 1: + if MUL_ROUTED_WEIGHT: + moe_weight = tl.load(topk_weights_ptr + offs_token, + mask=token_mask, + other=0) + accumulator = accumulator * moe_weight[:, None] + + if use_fp8: + accumulator = (accumulator * a_scale * b_scale).to(compute_type) + else: + accumulator = accumulator.to(compute_type) + # ----------------------------------------------------------- + # Write back the block of the output + offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = (c_ptr + stride_cm * offs_token[:, None] + + stride_cn * offs_cn[None, :]) + c_mask = token_mask[:, None] & (offs_cn[None, :] < N) + tl.store(c_ptrs, accumulator, mask=c_mask) + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), + dtype=tl.float32) + + +@triton.jit +def fused_moe_persistent_kernelV2( + # Pointers to matrices + a_ptr, + b_ptr, + c_ptr, + a_scale_ptr, + b_scale_ptr, + topk_weights_ptr, + sorted_token_ids_ptr, + expert_ids_ptr, + num_tokens_post_padded_ptr, + # Matrix dimensions + N, + K, + EM, + num_valid_tokens, + # The stride variables represent how much to increase the ptr by when + # moving by 1 element in a particular dimension. E.g. `stride_am` is + # how much to increase `a_ptr` by to get the element one row down + # (A has M rows). + stride_am, + stride_ak, + stride_be, + stride_bk, + stride_bn, + stride_cm, + stride_cn, + # Meta-parameters + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + NUM_SMS: tl.constexpr, + MUL_ROUTED_WEIGHT: tl.constexpr, + top_k: tl.constexpr, + compute_type: tl.constexpr, + use_fp8: tl.constexpr, +): + """ + Implements the fused computation for a Mixture of Experts (MOE) using + token and expert matrices. + This is the persistent version of the fused_moe kernel. + + Key Parameters: + - A: The input tensor representing tokens with shape (*, K), where '*' can + be any shape representing batches and K is the feature dimension of + each token. + - B: The stacked MOE weight tensor with shape (E, N, K), where E is + the number of experts, K is the input feature dimension, and N is + the output feature dimension. 
+ - C: The output cache tensor with shape (M, topk, N), where M is the + total number of tokens post padding, topk is the number of times + each token is repeated, and N is the output feature dimension. + - sorted_token_ids: A tensor containing the sorted indices of tokens, + repeated topk times and arranged by the expert index they are + assigned to. + - expert_ids: A tensor containing the indices of the expert for each + block. It determines which expert matrix from B should be used for + each block in A. + This kernel performs the multiplication of a token by its corresponding + expert matrix as determined by `expert_ids`. The sorting of + `sorted_token_ids` by expert index and padding ensures divisibility by + BLOCK_SIZE_M, which is necessary to maintain consistency in block matrix + multiplication across different blocks processed by the same expert. + """ + # ----------------------------------------------------------- + # Simply compute how many iterations each persistent block needs to do + start_pid = tl.program_id(axis=0) + num_pid_m = tl.cdiv(EM, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + # num_tiles = num_pid_m * num_pid_n + tile_id = start_pid + + offs_k = tl.arange(0, BLOCK_SIZE_K) + # offs_token = tl.zeros((BLOCK_SIZE_M,), dtype=tl.int32) + # token_mask = tl.zeros((BLOCK_SIZE_M,), dtype=tl.int1) + + # Load tile-invariant runtime constant + num_tokens_post_padded = tl.load(num_tokens_post_padded_ptr) + + # Compute how many tiles are outside the padding region + num_pid_in_group = GROUP_SIZE_M * num_pid_n + pid_m = 0 + tile_id2 = start_pid - NUM_SMS + num_valid_tiles = -1 + while pid_m * BLOCK_SIZE_M < num_tokens_post_padded: + num_valid_tiles += 1 + tile_id2 += NUM_SMS + group_id = tile_id2 // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + ((tile_id2 % num_pid_in_group) % group_size_m) + + # print("tile_counter: ", tile_counter) + for _ in range(0, num_valid_tiles): + if GROUP_SIZE_M == 1: + pid_m = tile_id // num_pid_n + pid_n = tile_id % num_pid_n + else: + group_id = tile_id // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + ((tile_id % num_pid_in_group) % group_size_m) + pid_n = (tile_id % num_pid_in_group) // group_size_m + # Compute the mask + offs_token_id = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_token = tl.load(sorted_token_ids_ptr + offs_token_id) + token_mask = offs_token < num_valid_tokens + # Compute the A pointer + a_ptrs = a_ptr + (offs_token[:, None] // top_k * stride_am + + offs_k[None, :] * stride_ak) + # Compute the B pointer + offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N + off_experts = tl.load(expert_ids_ptr + pid_m) + b_ptrs = (b_ptr + off_experts * stride_be + + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)) + + if use_fp8: + a_scale = tl.load(a_scale_ptr) + b_scale = tl.load(b_scale_ptr + off_experts) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): + # Load the next block of A and B, generate a mask by checking the + # K dimension. + a = tl.load( + a_ptrs, + mask=token_mask[:, None] & + (offs_k[None, :] < K - k * BLOCK_SIZE_K), + other=0.0 + ) + b = tl.load( + b_ptrs, + mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, + other=0.0 + ) + # We accumulate along the K dimension. 
+ if use_fp8: + accumulator = tl.dot(a, b, acc=accumulator) + else: + accumulator += tl.dot(a, b) + # Advance the ptrs to the next K block. + a_ptrs += BLOCK_SIZE_K * stride_ak + b_ptrs += BLOCK_SIZE_K * stride_bk + + if MUL_ROUTED_WEIGHT: + moe_weight = tl.load(topk_weights_ptr + offs_token, + mask=token_mask, + other=0) + accumulator = accumulator * moe_weight[:, None] + + if use_fp8: + accumulator = (accumulator * a_scale * b_scale).to(compute_type) + else: + accumulator = accumulator.to(compute_type) + # ----------------------------------------------------------- + # Write back the block of the output + offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = (c_ptr + stride_cm * offs_token[:, None] + + stride_cn * offs_cn[None, :]) + c_mask = token_mask[:, None] & (offs_cn[None, :] < N) + tl.store(c_ptrs, accumulator, mask=c_mask) + + # advance tile_id + tile_id += NUM_SMS + def moe_align_block_size( topk_ids: torch.Tensor, block_size: int, num_experts: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: @@ -262,7 +605,7 @@ def invoke_fused_moe_kernel(A: torch.Tensor, B: torch.Tensor, C: torch.Tensor, print(f"sorted_token_ids = {sorted_token_ids[0]}") print(f"expert_ids = {expert_ids}") print(f"num_tokens_post_padded = {num_tokens_post_padded[0]}") - + fused_moe_kernel[grid]( A, @@ -293,6 +636,76 @@ def invoke_fused_moe_kernel(A: torch.Tensor, B: torch.Tensor, C: torch.Tensor, ) +def invoke_fused_moe_persistent_kernel(A: torch.Tensor, B: torch.Tensor, C: torch.Tensor, + A_scale: Optional[torch.Tensor], + B_scale: Optional[torch.Tensor], + topk_weights: torch.Tensor, topk_ids: torch.Tensor, + sorted_token_ids: torch.Tensor, + expert_ids: torch.Tensor, + num_tokens_post_padded: torch.Tensor, + mul_routed_weight: bool, top_k: int, + config: Dict[str, Any], compute_type: tl.dtype, + use_fp8: bool) -> None: + assert topk_weights.stride(1) == 1 + assert sorted_token_ids.stride(0) == 1 + + if not use_fp8: + assert A_scale is None + assert B_scale is None + else: + A, A_scale = ops.scaled_fp8_quant(A, A_scale) + assert B_scale is not None + + NUM_SMS = torch.cuda.get_device_properties("cuda").multi_processor_count + grid = lambda META: (min( + NUM_SMS, + triton.cdiv(sorted_token_ids.shape[0], META["BLOCK_SIZE_M"]) * + triton.cdiv(B.shape[1], META["BLOCK_SIZE_N"]) + ), ) + + # print("==================================================================================") + # print(f"grid = NUM_SMS: {min(NUM_SMS, triton.cdiv(sorted_token_ids.shape[0], config['BLOCK_SIZE_M']) * triton.cdiv(B.shape[1], config['BLOCK_SIZE_N']) )}") + # print(f"sorted token ids shape = {sorted_token_ids.shape}") + # print(f"num_token_ids_post_padded = {num_tokens_post_padded}") + # print(f"config in moe = {config}") + # print(f"A shape = {A.shape}") + # print(f"B shape = {B.shape}") + # print(f"num valid tokens = {topk_ids.numel()}") + # print(f"sorted_token_ids = {sorted_token_ids[0]}") + # print(f"expert_ids = {expert_ids}") + # print(f"num_tokens_post_padded = {num_tokens_post_padded[0]}") + + + fused_moe_persistent_kernelV2[grid]( + A, + B, + C, + A_scale, + B_scale, + topk_weights, + sorted_token_ids, + expert_ids, + num_tokens_post_padded, + B.shape[1], + B.shape[2], + sorted_token_ids.shape[0], + topk_ids.numel(), + A.stride(0), + A.stride(1), + B.stride(0), + B.stride(2), + B.stride(1), + C.stride(1), + C.stride(2), + NUM_SMS=NUM_SMS, + MUL_ROUTED_WEIGHT=mul_routed_weight, + top_k=top_k, + compute_type=compute_type, + use_fp8=use_fp8, + **config, + ) + + def get_config_file_name(E: int, N: 
int, dtype: Optional[str]) -> str: device_name = torch.cuda.get_device_name().replace(" ", "_") dtype_selector = "" if not dtype else f",dtype={dtype}" @@ -408,6 +821,7 @@ def fused_experts(hidden_states: torch.Tensor, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 8, + "num_stages": 1, } if M <= E: @@ -416,6 +830,7 @@ def fused_experts(hidden_states: torch.Tensor, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 1, + "num_stages": 1, } intermediate_cache1 = torch.empty( @@ -439,7 +854,8 @@ def fused_experts(hidden_states: torch.Tensor, compute_type = (tl.bfloat16 if hidden_states.dtype == torch.bfloat16 else tl.float16) - invoke_fused_moe_kernel(hidden_states, + invoke_fused_moe_persistent_kernel(hidden_states, + # invoke_fused_moe_kernel(hidden_states, w1, intermediate_cache1, a1_scale, @@ -457,7 +873,8 @@ def fused_experts(hidden_states: torch.Tensor, ops.silu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, N)) - invoke_fused_moe_kernel(intermediate_cache2, + invoke_fused_moe_persistent_kernel(intermediate_cache2, + # invoke_fused_moe_kernel(intermediate_cache2, w2, intermediate_cache3, a2_scale, @@ -540,4 +957,4 @@ def fused_moe( w1_scale=w1_scale, w2_scale=w2_scale, a1_scale=a1_scale, - a2_scale=a2_scale) + a2_scale=a2_scale) \ No newline at end of file From e01cd34e477bdb27308ec0e296f4432f57b1c801 Mon Sep 17 00:00:00 2001 From: Vinayak Gokhale Date: Thu, 1 Aug 2024 22:10:09 +0000 Subject: [PATCH 04/22] Config issues --- .../kernels/benchmark_mixtral_moe_rocm.py | 21 +++++++++------- .../layers/fused_moe/fused_moe.py | 24 ++++++++++--------- 2 files changed, 25 insertions(+), 20 deletions(-) diff --git a/benchmarks/kernels/benchmark_mixtral_moe_rocm.py b/benchmarks/kernels/benchmark_mixtral_moe_rocm.py index 8047eceecb149..9c03d11912e36 100755 --- a/benchmarks/kernels/benchmark_mixtral_moe_rocm.py +++ b/benchmarks/kernels/benchmark_mixtral_moe_rocm.py @@ -33,7 +33,7 @@ def get_full_tuning_space(): block_m_range = [32] block_n_range = [128] - block_k_range = [128] + block_k_range = [256] # split_k_range = [1] #, 2, 4, 5, 6, 8, 10, 12, 16, 18, 24] num_warps_range = [8] group_m_range = [1] @@ -289,8 +289,10 @@ def run_timing( start_event = torch.cuda.Event(enable_timing=True) end_event = torch.cuda.Event(enable_timing=True) + print(f"Starting run...") start_event.record() for i in range(num_calls): + print(f"i = {i}, num_calls = {num_calls}") invoke_fused_moe_persistent_kernel( hidden_states, w1, @@ -309,6 +311,7 @@ def run_timing( else tl.float16), use_fp8=False, ) + print(f"i = {i}, num_calls = {num_calls}") # ops.silu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, N)) @@ -333,14 +336,14 @@ def run_timing( end_event.record() end_event.synchronize() - # print(f"intermediate 0 shape = {intermediate_cache1.shape}") - # print(f"intermediate 1 shape = {intermediate_cache2.shape}") - # print(f"intermediate 2 shape = {intermediate_cache3.shape}") - # print(f"config = {config}") - # print(f"sorted token ids = {sorted_token_ids}") - # print(f"sorted token ids shape = {sorted_token_ids.shape}") - # print(f"expert ids = {expert_ids}") - # print(f"num_tokens_post_padded = {num_tokens_post_padded}") + print(f"intermediate 0 shape = {intermediate_cache1.shape}") + print(f"intermediate 1 shape = {intermediate_cache2.shape}") + print(f"intermediate 2 shape = {intermediate_cache3.shape}") + print(f"config = {config}") + print(f"sorted token ids = {sorted_token_ids}") + print(f"sorted token ids shape = {sorted_token_ids.shape}") + print(f"expert ids = {expert_ids}") 
+ print(f"num_tokens_post_padded = {num_tokens_post_padded}") dur_ms = start_event.elapsed_time(end_event) / num_calls return dur_ms diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index acec248505c90..09913a66d662d 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -663,17 +663,18 @@ def invoke_fused_moe_persistent_kernel(A: torch.Tensor, B: torch.Tensor, C: torc triton.cdiv(B.shape[1], META["BLOCK_SIZE_N"]) ), ) - # print("==================================================================================") - # print(f"grid = NUM_SMS: {min(NUM_SMS, triton.cdiv(sorted_token_ids.shape[0], config['BLOCK_SIZE_M']) * triton.cdiv(B.shape[1], config['BLOCK_SIZE_N']) )}") - # print(f"sorted token ids shape = {sorted_token_ids.shape}") - # print(f"num_token_ids_post_padded = {num_tokens_post_padded}") - # print(f"config in moe = {config}") - # print(f"A shape = {A.shape}") - # print(f"B shape = {B.shape}") - # print(f"num valid tokens = {topk_ids.numel()}") - # print(f"sorted_token_ids = {sorted_token_ids[0]}") - # print(f"expert_ids = {expert_ids}") - # print(f"num_tokens_post_padded = {num_tokens_post_padded[0]}") + print("==================================================================================") + print(f"grid = NUM_SMS: {min(NUM_SMS, triton.cdiv(sorted_token_ids.shape[0], config['BLOCK_SIZE_M']) * triton.cdiv(B.shape[1], config['BLOCK_SIZE_N']) )}") + print(f"sorted token ids shape = {sorted_token_ids.shape}") + print(f"num_token_ids_post_padded = {num_tokens_post_padded}") + print(f"config in moe = {config}") + print(f"A shape = {A.shape}") + print(f"B shape = {B.shape}") + print(f"num valid tokens = {topk_ids.numel()}") + print(f"sorted_token_ids = {sorted_token_ids[0]}") + print(f"expert_ids = {expert_ids}") + print(f"num_tokens_post_padded = {num_tokens_post_padded[0]}") + print("Calling persistent kernel") fused_moe_persistent_kernelV2[grid]( @@ -704,6 +705,7 @@ def invoke_fused_moe_persistent_kernel(A: torch.Tensor, B: torch.Tensor, C: torc use_fp8=use_fp8, **config, ) + print("Exiting kernel") def get_config_file_name(E: int, N: int, dtype: Optional[str]) -> str: From 24317bb9a6303926613226f1375876973c74bfe5 Mon Sep 17 00:00:00 2001 From: Vinayak Gokhale Date: Fri, 2 Aug 2024 00:01:44 +0000 Subject: [PATCH 05/22] Add configs for persistent kernel --- ...=8,N=14336,device_name=AMD_Radeon_Graphics.json | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/configs/E=8,N=14336,device_name=AMD_Radeon_Graphics.json b/vllm/model_executor/layers/fused_moe/configs/E=8,N=14336,device_name=AMD_Radeon_Graphics.json index 9de6d6a479184..54eed795101d8 100644 --- a/vllm/model_executor/layers/fused_moe/configs/E=8,N=14336,device_name=AMD_Radeon_Graphics.json +++ b/vllm/model_executor/layers/fused_moe/configs/E=8,N=14336,device_name=AMD_Radeon_Graphics.json @@ -77,11 +77,11 @@ "kpack": 2 }, "48": { - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 4, - "num_warps": 4, + "num_warps": 8, "num_stages": 0, "waves_per_eu": 0, "matrix_instr_nonkdim": 16, @@ -89,10 +89,10 @@ }, "64": { "BLOCK_SIZE_M": 32, - "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, - "num_warps": 4, + "num_warps": 8, "num_stages": 0, "waves_per_eu": 0, "matrix_instr_nonkdim": 16, @@ -100,7 +100,7 @@ }, "96": { 
"BLOCK_SIZE_M": 32, - "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 8, @@ -230,4 +230,4 @@ "matrix_instr_nonkdim": 16, "kpack": 2 } -} +} \ No newline at end of file From 669614215392c33527b37416fc0dd2b7b691d85e Mon Sep 17 00:00:00 2001 From: Vinayak Gokhale Date: Tue, 6 Aug 2024 13:36:50 +0000 Subject: [PATCH 06/22] Add padding and correct block size --- .../kernels/benchmark_mixtral_moe_rocm.py | 23 ++++++++----------- .../layers/fused_moe/fused_moe.py | 2 +- 2 files changed, 11 insertions(+), 14 deletions(-) diff --git a/benchmarks/kernels/benchmark_mixtral_moe_rocm.py b/benchmarks/kernels/benchmark_mixtral_moe_rocm.py index 9c03d11912e36..d1dcf3af62384 100755 --- a/benchmarks/kernels/benchmark_mixtral_moe_rocm.py +++ b/benchmarks/kernels/benchmark_mixtral_moe_rocm.py @@ -33,7 +33,7 @@ def get_full_tuning_space(): block_m_range = [32] block_n_range = [128] - block_k_range = [256] + block_k_range = [128] # split_k_range = [1] #, 2, 4, 5, 6, 8, 10, 12, 16, 18, 24] num_warps_range = [8] group_m_range = [1] @@ -198,7 +198,7 @@ def run_timing( model_intermediate_size: int, config, ) -> float: - shard_intermediate_size = model_intermediate_size // tp_size + shard_intermediate_size = 14464-64 #model_intermediate_size // tp_size hidden_states = torch.rand( (bs, d_model), @@ -289,10 +289,8 @@ def run_timing( start_event = torch.cuda.Event(enable_timing=True) end_event = torch.cuda.Event(enable_timing=True) - print(f"Starting run...") start_event.record() for i in range(num_calls): - print(f"i = {i}, num_calls = {num_calls}") invoke_fused_moe_persistent_kernel( hidden_states, w1, @@ -311,7 +309,6 @@ def run_timing( else tl.float16), use_fp8=False, ) - print(f"i = {i}, num_calls = {num_calls}") # ops.silu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, N)) @@ -336,14 +333,14 @@ def run_timing( end_event.record() end_event.synchronize() - print(f"intermediate 0 shape = {intermediate_cache1.shape}") - print(f"intermediate 1 shape = {intermediate_cache2.shape}") - print(f"intermediate 2 shape = {intermediate_cache3.shape}") - print(f"config = {config}") - print(f"sorted token ids = {sorted_token_ids}") - print(f"sorted token ids shape = {sorted_token_ids.shape}") - print(f"expert ids = {expert_ids}") - print(f"num_tokens_post_padded = {num_tokens_post_padded}") + # print(f"intermediate 0 shape = {intermediate_cache1.shape}") + # print(f"intermediate 1 shape = {intermediate_cache2.shape}") + # print(f"intermediate 2 shape = {intermediate_cache3.shape}") + # print(f"config = {config}") + # print(f"sorted token ids = {sorted_token_ids}") + # print(f"sorted token ids shape = {sorted_token_ids.shape}") + # print(f"expert ids = {expert_ids}") + # print(f"num_tokens_post_padded = {num_tokens_post_padded}") dur_ms = start_event.elapsed_time(end_event) / num_calls return dur_ms diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index 09913a66d662d..c75ceb5a31a38 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -688,7 +688,7 @@ def invoke_fused_moe_persistent_kernel(A: torch.Tensor, B: torch.Tensor, C: torc expert_ids, num_tokens_post_padded, B.shape[1], - B.shape[2], + B.shape[2] - 128, sorted_token_ids.shape[0], topk_ids.numel(), A.stride(0), From 077cb78a36de6619a495141907acad7b72b24a1d Mon Sep 17 00:00:00 2001 From: Vinayak Gokhale Date: Thu, 8 Aug 2024 14:49:29 +0000 Subject: [PATCH 07/22] Latest perf --- 
.../kernels/benchmark_mixtral_moe_rocm.py | 6 +- .../layers/fused_moe/fused_moe.py | 56 +++++++++---------- 2 files changed, 28 insertions(+), 34 deletions(-) diff --git a/benchmarks/kernels/benchmark_mixtral_moe_rocm.py b/benchmarks/kernels/benchmark_mixtral_moe_rocm.py index d1dcf3af62384..82773ef329d1c 100755 --- a/benchmarks/kernels/benchmark_mixtral_moe_rocm.py +++ b/benchmarks/kernels/benchmark_mixtral_moe_rocm.py @@ -198,7 +198,7 @@ def run_timing( model_intermediate_size: int, config, ) -> float: - shard_intermediate_size = 14464-64 #model_intermediate_size // tp_size + shard_intermediate_size = model_intermediate_size // tp_size hidden_states = torch.rand( (bs, d_model), @@ -207,7 +207,7 @@ def run_timing( ) w1 = torch.rand( - (num_total_experts, 2 * shard_intermediate_size, d_model), + (num_total_experts, 2 * shard_intermediate_size, d_model+128), device=hidden_states.device, dtype=hidden_states.dtype, ) @@ -232,7 +232,7 @@ def run_timing( assert (hidden_states.shape[0] == gating_output.shape[0] ), "Number of tokens mismatch" - assert hidden_states.shape[1] == w1.shape[2], "Hidden size mismatch" + assert hidden_states.shape[1] == w1.shape[2] - 128, "Hidden size mismatch" assert gating_output.shape[1] == w1.shape[0], "Number of experts mismatch" assert hidden_states.is_contiguous(), "Hidden_states must be contiguous" assert w1.is_contiguous(), "Expert weights1 must be contiguous" diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index c75ceb5a31a38..1f20fce4298eb 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -468,15 +468,10 @@ def fused_moe_persistent_kernelV2( # Load the next block of A and B, generate a mask by checking the # K dimension. a = tl.load( - a_ptrs, - mask=token_mask[:, None] & - (offs_k[None, :] < K - k * BLOCK_SIZE_K), - other=0.0 + a_ptrs ) b = tl.load( - b_ptrs, - mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, - other=0.0 + b_ptrs ) # We accumulate along the K dimension. 
if use_fp8: @@ -594,17 +589,17 @@ def invoke_fused_moe_kernel(A: torch.Tensor, B: torch.Tensor, C: torch.Tensor, grid = lambda META: (triton.cdiv(sorted_token_ids.shape[0], META[ "BLOCK_SIZE_M"]) * triton.cdiv(B.shape[1], META["BLOCK_SIZE_N"]), ) - print("==================================================================================") - print(f"grid = {triton.cdiv(sorted_token_ids.shape[0], config['BLOCK_SIZE_M']) * triton.cdiv(B.shape[1], config['BLOCK_SIZE_N'])}") - print(f"sorted token ids shape = {sorted_token_ids.shape}") - print(f"num_token_ids_post_padded = {num_tokens_post_padded}") - print(f"config in moe = {config}") - print(f"A shape = {A.shape}") - print(f"B shape = {B.shape}") - print(f"num valid tokens = {topk_ids.numel()}") - print(f"sorted_token_ids = {sorted_token_ids[0]}") - print(f"expert_ids = {expert_ids}") - print(f"num_tokens_post_padded = {num_tokens_post_padded[0]}") + # print("==================================================================================") + # print(f"grid = {triton.cdiv(sorted_token_ids.shape[0], config['BLOCK_SIZE_M']) * triton.cdiv(B.shape[1], config['BLOCK_SIZE_N'])}") + # print(f"sorted token ids shape = {sorted_token_ids.shape}") + # print(f"num_token_ids_post_padded = {num_tokens_post_padded}") + # print(f"config in moe = {config}") + # print(f"A shape = {A.shape}") + # print(f"B shape = {B.shape}") + # print(f"num valid tokens = {topk_ids.numel()}") + # print(f"sorted_token_ids = {sorted_token_ids[0]}") + # print(f"expert_ids = {expert_ids}") + # print(f"num_tokens_post_padded = {num_tokens_post_padded[0]}") fused_moe_kernel[grid]( @@ -663,18 +658,18 @@ def invoke_fused_moe_persistent_kernel(A: torch.Tensor, B: torch.Tensor, C: torc triton.cdiv(B.shape[1], META["BLOCK_SIZE_N"]) ), ) - print("==================================================================================") - print(f"grid = NUM_SMS: {min(NUM_SMS, triton.cdiv(sorted_token_ids.shape[0], config['BLOCK_SIZE_M']) * triton.cdiv(B.shape[1], config['BLOCK_SIZE_N']) )}") - print(f"sorted token ids shape = {sorted_token_ids.shape}") - print(f"num_token_ids_post_padded = {num_tokens_post_padded}") - print(f"config in moe = {config}") - print(f"A shape = {A.shape}") - print(f"B shape = {B.shape}") - print(f"num valid tokens = {topk_ids.numel()}") - print(f"sorted_token_ids = {sorted_token_ids[0]}") - print(f"expert_ids = {expert_ids}") - print(f"num_tokens_post_padded = {num_tokens_post_padded[0]}") - print("Calling persistent kernel") + # print("==================================================================================") + # print(f"grid = NUM_SMS: {min(NUM_SMS, triton.cdiv(sorted_token_ids.shape[0], config['BLOCK_SIZE_M']) * triton.cdiv(B.shape[1], config['BLOCK_SIZE_N']) )}") + # print(f"sorted token ids shape = {sorted_token_ids.shape}") + # print(f"num_token_ids_post_padded = {num_tokens_post_padded}") + # print(f"config in moe = {config}") + # print(f"A shape = {A.shape}") + # print(f"B shape = {B.shape}") + # print(f"num valid tokens = {topk_ids.numel()}") + # print(f"sorted_token_ids = {sorted_token_ids[0]}") + # print(f"expert_ids = {expert_ids}") + # print(f"num_tokens_post_padded = {num_tokens_post_padded[0]}") + # print("Calling persistent kernel") fused_moe_persistent_kernelV2[grid]( @@ -705,7 +700,6 @@ def invoke_fused_moe_persistent_kernel(A: torch.Tensor, B: torch.Tensor, C: torc use_fp8=use_fp8, **config, ) - print("Exiting kernel") def get_config_file_name(E: int, N: int, dtype: Optional[str]) -> str: From 6e70d894327467130bc144d011ffa1e2aac470d8 Mon 
Sep 17 00:00:00 2001 From: Vinayak Gokhale Date: Thu, 8 Aug 2024 19:58:25 +0000 Subject: [PATCH 08/22] Add masking back --- vllm/model_executor/layers/fused_moe/fused_moe.py | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index 3a1c0bd7859c5..07a940f95e4f1 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -470,11 +470,14 @@ def fused_moe_persistent_kernelV2( # Load the next block of A and B, generate a mask by checking the # K dimension. a = tl.load( - a_ptrs - ) + a_ptrs, + mask=token_mask[:, None] & + (offs_k[None, :] < K - k * BLOCK_SIZE_K), + other=0.0) b = tl.load( - b_ptrs - ) + b_ptrs, + mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, + other=0.0) # We accumulate along the K dimension. if use_fp8: accumulator = tl.dot(a, b, acc=accumulator) @@ -685,7 +688,7 @@ def invoke_fused_moe_persistent_kernel(A: torch.Tensor, B: torch.Tensor, C: torc expert_ids, num_tokens_post_padded, B.shape[1], - B.shape[2] - 128, + B.shape[2] - padding_size, sorted_token_ids.shape[0], topk_ids.numel(), A.stride(0), @@ -956,4 +959,4 @@ def fused_moe( w1_scale=w1_scale, w2_scale=w2_scale, a1_scale=a1_scale, - a2_scale=a2_scale) \ No newline at end of file + a2_scale=a2_scale) From 2b94a753256f292b687b2fec8a8cb33ab62bea02 Mon Sep 17 00:00:00 2001 From: Vinayak Gokhale Date: Thu, 8 Aug 2024 20:07:43 +0000 Subject: [PATCH 09/22] Add accuracy test --- benchmarks/test_accuracy.py | 44 +++++++++++++++++++++++++++++++++++++ 1 file changed, 44 insertions(+) create mode 100644 benchmarks/test_accuracy.py diff --git a/benchmarks/test_accuracy.py b/benchmarks/test_accuracy.py new file mode 100644 index 0000000000000..06c1150f9f20f --- /dev/null +++ b/benchmarks/test_accuracy.py @@ -0,0 +1,44 @@ +from vllm import LLM, SamplingParams +import time + + +def main(): + llm = LLM( + '/data/AI-ModelScope/Mixtral-8x7B-Instruct-v0___1/', + tensor_parallel_size=1, + #quantization="serenity", + dtype='float16', + #swap_space=16, + #enforce_eager=True, + #kv_cache_dtype="fp8", + #quantization="fp8", + #quantized_weights_path="/quantized/quark/llama.safetensors", + #worker_use_ray=True, + #trust_remote_code=True, + #distributed_executor_backend="mp", + ) + batch_size = 5 + max_tokens = 256 + prompt = """The sun is a""" + sampling_params = SamplingParams(temperature=0, + top_p=0.95, + max_tokens=max_tokens) + + start_time = time.perf_counter() + outs = llm.generate([prompt] * batch_size, sampling_params=sampling_params) + end_time = time.perf_counter() + elapsed_time = end_time - start_time + + out_lengths = [len(x.token_ids) for out in outs for x in out.outputs] + num_tokens = sum(out_lengths) + + print( + f"{num_tokens} tokens. {num_tokens / batch_size} on average. {num_tokens / elapsed_time:.2f} tokens/s. 
{elapsed_time} seconds" + ) + for out in outs: + print("===========") + print(out.outputs[0].text) + + +if __name__ == "__main__": + main() From c406afc8f4e200f4b31bc2fb8afb94d837d34ea5 Mon Sep 17 00:00:00 2001 From: Vinayak Gokhale Date: Thu, 8 Aug 2024 22:13:56 +0000 Subject: [PATCH 10/22] Merge all changes --- vllm/envs.py | 10 + .../layers/fused_moe/__init__.py | 4 +- .../layers/fused_moe/fused_moe.py | 397 ++++-------------- vllm/model_executor/models/mixtral.py | 25 +- 4 files changed, 122 insertions(+), 314 deletions(-) diff --git a/vllm/envs.py b/vllm/envs.py index 739a4792ce078..681885e312145 100644 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -43,6 +43,8 @@ VLLM_SYNC_SERVER_ACCUM_REQUESTS: int = 1 VLLM_SYNC_SERVER_ENGINE_STEPS_BETWEEN_POLLS: int = 1 VLLM_MOE_PADDING: bool = True + VLLM_MOE_SHUFFLE: bool = False + FUSED_MOE_PERSISTENT: bool = False # The begin-* and end* here are used by the documentation generator # to extract the used env vars. @@ -246,6 +248,14 @@ # Pad the weight for moe kernel or not "VLLM_MOE_PADDING": lambda: bool(int(os.getenv("VLLM_MOE_PADDING", "1"))), + + # shuffle the weight for moe kernel or not + "VLLM_MOE_SHUFFLE": + lambda: bool(int(os.getenv("VLLM_MOE_SHUFFLE", "0"))), + + # User persistent version of fused_moe Triton kernel + "FUSED_MOE_PERSISTENT": + lambda: bool(int(os.getenv("FUSED_MOE_PERSISTENT", "0"))), } # end-env-vars-definition diff --git a/vllm/model_executor/layers/fused_moe/__init__.py b/vllm/model_executor/layers/fused_moe/__init__.py index 5b78ffd10d5b7..851ed919ae3b8 100755 --- a/vllm/model_executor/layers/fused_moe/__init__.py +++ b/vllm/model_executor/layers/fused_moe/__init__.py @@ -1,8 +1,8 @@ from vllm.model_executor.layers.fused_moe.fused_moe import ( fused_experts, fused_moe, fused_topk, get_config_file_name, - invoke_fused_moe_kernel, invoke_fused_moe_persistent_kernel, moe_align_block_size) + invoke_fused_moe_kernel, moe_align_block_size) __all__ = [ "fused_moe", "fused_topk", "fused_experts", "get_config_file_name", - "invoke_fused_moe_kernel", "invoke_fused_moe_persistent_kernel", "moe_align_block_size" + "invoke_fused_moe_kernel", "moe_align_block_size" ] diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index 07a940f95e4f1..2300edf18d41c 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -9,8 +9,8 @@ import triton.language as tl import vllm._moe_C as moe_kernels -from vllm import _custom_ops as ops from vllm import envs +from vllm import _custom_ops as ops from vllm.logger import init_logger logger = init_logger(__name__) @@ -167,6 +167,9 @@ def fused_moe_kernel( tl.store(c_ptrs, accumulator, mask=c_mask) +@triton.heuristics({ + 'EVEN_K': lambda args: args['K'] % args['BLOCK_SIZE_K'] == 0, +}) @triton.jit def fused_moe_persistent_kernel( # Pointers to matrices @@ -200,181 +203,7 @@ def fused_moe_persistent_kernel( BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, GROUP_SIZE_M: tl.constexpr, - NUM_SMS: tl.constexpr, - MUL_ROUTED_WEIGHT: tl.constexpr, - top_k: tl.constexpr, - compute_type: tl.constexpr, - use_fp8: tl.constexpr, -): - """ - Implements the fused computation for a Mixture of Experts (MOE) using - token and expert matrices. - This is the persistent version of the fused_moe kernel. - - Key Parameters: - - A: The input tensor representing tokens with shape (*, K), where '*' can - be any shape representing batches and K is the feature dimension of - each token. 
- - B: The stacked MOE weight tensor with shape (E, N, K), where E is - the number of experts, K is the input feature dimension, and N is - the output feature dimension. - - C: The output cache tensor with shape (M, topk, N), where M is the - total number of tokens post padding, topk is the number of times - each token is repeated, and N is the output feature dimension. - - sorted_token_ids: A tensor containing the sorted indices of tokens, - repeated topk times and arranged by the expert index they are - assigned to. - - expert_ids: A tensor containing the indices of the expert for each - block. It determines which expert matrix from B should be used for - each block in A. - This kernel performs the multiplication of a token by its corresponding - expert matrix as determined by `expert_ids`. The sorting of - `sorted_token_ids` by expert index and padding ensures divisibility by - BLOCK_SIZE_M, which is necessary to maintain consistency in block matrix - multiplication across different blocks processed by the same expert. - """ - # ----------------------------------------------------------- - # Simply compute how many iterations each persistent block needs to do - start_pid = tl.program_id(axis=0) - num_pid_m = tl.cdiv(EM, BLOCK_SIZE_M) - num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) - k_tiles = tl.cdiv(K, BLOCK_SIZE_K) - num_tiles = num_pid_m * num_pid_n - - tiles_per_SM = num_tiles // NUM_SMS - if start_pid < num_tiles % NUM_SMS: - tiles_per_SM += 1 - - tile_id = start_pid - NUM_SMS - ki = -1 - - offs_k_for_mask = tl.arange(0, BLOCK_SIZE_K) - - num_pid_in_group = GROUP_SIZE_M * num_pid_n - - pid_m = 0 - pid_n = 0 - offs_am = tl.arange(0, BLOCK_SIZE_M) - offs_bn = tl.arange(0, BLOCK_SIZE_N) - - accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) - offs_token = tl.zeros((BLOCK_SIZE_M,), dtype=tl.int32) - token_mask = tl.zeros((BLOCK_SIZE_M,), dtype=tl.int1) - - # load runtime constant outside the loop - num_tokens_post_padded = tl.load(num_tokens_post_padded_ptr) - if use_fp8: - a_scale = tl.load(a_scale_ptr) - b_scale = tl.load(b_scale_ptr + off_experts) - - # compute when it reaches the invalid region - pid_m = 0 - tile_id2 = start_pid - NUM_SMS - tile_counter = -1 - while pid_m * BLOCK_SIZE_M < num_tokens_post_padded: - tile_counter += 1 - tile_id2 += NUM_SMS - group_id = tile_id2 // num_pid_in_group - first_pid_m = group_id * GROUP_SIZE_M - group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) - pid_m = first_pid_m + ((tile_id2 % num_pid_in_group) % group_size_m) - - # print("tile_counter: ", tile_counter) - for _ in range(0, k_tiles * tile_counter): - # Increment ki or loop back to the start of each tile - ki = tl.where(ki == k_tiles - 1, 0, ki + 1) - # Prologue of each tile - if ki == 0: - tile_id += NUM_SMS - group_id = tile_id // num_pid_in_group - first_pid_m = group_id * GROUP_SIZE_M - group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) - pid_m = first_pid_m + ((tile_id % num_pid_in_group) % group_size_m) - pid_n = (tile_id % num_pid_in_group) // group_size_m - # compute the base pointer of A, B for each tile - offs_token_id = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) - offs_token = tl.load(sorted_token_ids_ptr + offs_token_id) - token_mask = offs_token < num_valid_tokens - offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N - # Loop through K dimension - # Compute A, B pointer - offs_k = ki * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K) - a_ptrs = a_ptr + (offs_token[:, None] // top_k * stride_am + - offs_k[None, :] * stride_ak) - 
off_experts = tl.load(expert_ids_ptr + pid_m) - b_ptrs = (b_ptr + off_experts * stride_be + - (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)) - # Load A, B and dot product - a = tl.load( - a_ptrs, - mask=token_mask[:, None] & - (offs_k_for_mask[None, :] < K - ki * BLOCK_SIZE_K), - other=0.0) - b = tl.load( - b_ptrs, - mask=offs_k_for_mask[:, None] < K - ki * BLOCK_SIZE_K, - other=0.0) - if use_fp8: - accumulator = tl.dot(a, b, acc=accumulator) - else: - accumulator += tl.dot(a, b) - - # Epilogue of each tile - if ki == k_tiles - 1: - if MUL_ROUTED_WEIGHT: - moe_weight = tl.load(topk_weights_ptr + offs_token, - mask=token_mask, - other=0) - accumulator = accumulator * moe_weight[:, None] - - if use_fp8: - accumulator = (accumulator * a_scale * b_scale).to(compute_type) - else: - accumulator = accumulator.to(compute_type) - # ----------------------------------------------------------- - # Write back the block of the output - offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) - c_ptrs = (c_ptr + stride_cm * offs_token[:, None] + - stride_cn * offs_cn[None, :]) - c_mask = token_mask[:, None] & (offs_cn[None, :] < N) - tl.store(c_ptrs, accumulator, mask=c_mask) - accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), - dtype=tl.float32) - - -@triton.jit -def fused_moe_persistent_kernelV2( - # Pointers to matrices - a_ptr, - b_ptr, - c_ptr, - a_scale_ptr, - b_scale_ptr, - topk_weights_ptr, - sorted_token_ids_ptr, - expert_ids_ptr, - num_tokens_post_padded_ptr, - # Matrix dimensions - N, - K, - EM, - num_valid_tokens, - # The stride variables represent how much to increase the ptr by when - # moving by 1 element in a particular dimension. E.g. `stride_am` is - # how much to increase `a_ptr` by to get the element one row down - # (A has M rows). - stride_am, - stride_ak, - stride_be, - stride_bk, - stride_bn, - stride_cm, - stride_cn, - # Meta-parameters - BLOCK_SIZE_M: tl.constexpr, - BLOCK_SIZE_N: tl.constexpr, - BLOCK_SIZE_K: tl.constexpr, - GROUP_SIZE_M: tl.constexpr, + EVEN_K: tl.constexpr, NUM_SMS: tl.constexpr, MUL_ROUTED_WEIGHT: tl.constexpr, top_k: tl.constexpr, @@ -436,7 +265,6 @@ def fused_moe_persistent_kernelV2( group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) pid_m = first_pid_m + ((tile_id2 % num_pid_in_group) % group_size_m) - # print("tile_counter: ", tile_counter) for _ in range(0, num_valid_tiles): if GROUP_SIZE_M == 1: pid_m = tile_id // num_pid_n @@ -469,15 +297,21 @@ def fused_moe_persistent_kernelV2( for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): # Load the next block of A and B, generate a mask by checking the # K dimension. - a = tl.load( - a_ptrs, - mask=token_mask[:, None] & - (offs_k[None, :] < K - k * BLOCK_SIZE_K), - other=0.0) - b = tl.load( - b_ptrs, - mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, - other=0.0) + if EVEN_K: + a = tl.load(a_ptrs, mask=token_mask[:, None], other=0.0) + b = tl.load(b_ptrs) + else: + a = tl.load( + a_ptrs, + mask=token_mask[:, None] & + (offs_k[None, :] < K - k * BLOCK_SIZE_K), + other=0.0 + ) + b = tl.load( + b_ptrs, + mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, + other=0.0 + ) # We accumulate along the K dimension. 
if use_fp8: accumulator = tl.dot(a, b, acc=accumulator) @@ -508,6 +342,7 @@ def fused_moe_persistent_kernelV2( # advance tile_id tile_id += NUM_SMS + def moe_align_block_size( topk_ids: torch.Tensor, block_size: int, num_experts: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: @@ -591,120 +426,73 @@ def invoke_fused_moe_kernel(A: torch.Tensor, B: torch.Tensor, C: torch.Tensor, A, A_scale = ops.scaled_fp8_quant(A, A_scale) assert B_scale is not None - grid = lambda META: (triton.cdiv(sorted_token_ids.shape[0], META[ - "BLOCK_SIZE_M"]) * triton.cdiv(B.shape[1], META["BLOCK_SIZE_N"]), ) - - # print("==================================================================================") - # print(f"grid = {triton.cdiv(sorted_token_ids.shape[0], config['BLOCK_SIZE_M']) * triton.cdiv(B.shape[1], config['BLOCK_SIZE_N'])}") - # print(f"sorted token ids shape = {sorted_token_ids.shape}") - # print(f"num_token_ids_post_padded = {num_tokens_post_padded}") - # print(f"config in moe = {config}") - # print(f"A shape = {A.shape}") - # print(f"B shape = {B.shape}") - # print(f"num valid tokens = {topk_ids.numel()}") - # print(f"sorted_token_ids = {sorted_token_ids[0]}") - # print(f"expert_ids = {expert_ids}") - # print(f"num_tokens_post_padded = {num_tokens_post_padded[0]}") - - - fused_moe_kernel[grid]( - A, - B, - C, - A_scale, - B_scale, - topk_weights, - sorted_token_ids, - expert_ids, - num_tokens_post_padded, - B.shape[1], - B.shape[2] - padding_size, - sorted_token_ids.shape[0], - topk_ids.numel(), - A.stride(0), - A.stride(1), - B.stride(0), - B.stride(2), - B.stride(1), - C.stride(1), - C.stride(2), - MUL_ROUTED_WEIGHT=mul_routed_weight, - top_k=top_k, - compute_type=compute_type, - use_fp8=use_fp8, - **config, - ) - - -def invoke_fused_moe_persistent_kernel(A: torch.Tensor, B: torch.Tensor, C: torch.Tensor, - A_scale: Optional[torch.Tensor], - B_scale: Optional[torch.Tensor], - topk_weights: torch.Tensor, topk_ids: torch.Tensor, - sorted_token_ids: torch.Tensor, - expert_ids: torch.Tensor, - num_tokens_post_padded: torch.Tensor, - mul_routed_weight: bool, top_k: int, - config: Dict[str, Any], compute_type: tl.dtype, - use_fp8: bool) -> None: - assert topk_weights.stride(1) == 1 - assert sorted_token_ids.stride(0) == 1 - - if not use_fp8: - assert A_scale is None - assert B_scale is None + if not envs.FUSED_MOE_PERSISTENT: + grid = lambda META: (triton.cdiv(sorted_token_ids.shape[0], META[ + "BLOCK_SIZE_M"]) * triton.cdiv(B.shape[1], META["BLOCK_SIZE_N"]), ) + + fused_moe_kernel[grid]( + A, + B, + C, + A_scale, + B_scale, + topk_weights, + sorted_token_ids, + expert_ids, + num_tokens_post_padded, + B.shape[1], + B.shape[2] - padding_size, + sorted_token_ids.shape[0], + topk_ids.numel(), + A.stride(0), + A.stride(1), + B.stride(0), + B.stride(2), + B.stride(1), + C.stride(1), + C.stride(2), + MUL_ROUTED_WEIGHT=mul_routed_weight, + top_k=top_k, + compute_type=compute_type, + use_fp8=use_fp8, + **config, + ) else: - A, A_scale = ops.scaled_fp8_quant(A, A_scale) - assert B_scale is not None - - NUM_SMS = torch.cuda.get_device_properties("cuda").multi_processor_count - grid = lambda META: (min( - NUM_SMS, - triton.cdiv(sorted_token_ids.shape[0], META["BLOCK_SIZE_M"]) * - triton.cdiv(B.shape[1], META["BLOCK_SIZE_N"]) - ), ) - - # print("==================================================================================") - # print(f"grid = NUM_SMS: {min(NUM_SMS, triton.cdiv(sorted_token_ids.shape[0], config['BLOCK_SIZE_M']) * triton.cdiv(B.shape[1], config['BLOCK_SIZE_N']) )}") - # print(f"sorted 
token ids shape = {sorted_token_ids.shape}") - # print(f"num_token_ids_post_padded = {num_tokens_post_padded}") - # print(f"config in moe = {config}") - # print(f"A shape = {A.shape}") - # print(f"B shape = {B.shape}") - # print(f"num valid tokens = {topk_ids.numel()}") - # print(f"sorted_token_ids = {sorted_token_ids[0]}") - # print(f"expert_ids = {expert_ids}") - # print(f"num_tokens_post_padded = {num_tokens_post_padded[0]}") - # print("Calling persistent kernel") - - - fused_moe_persistent_kernelV2[grid]( - A, - B, - C, - A_scale, - B_scale, - topk_weights, - sorted_token_ids, - expert_ids, - num_tokens_post_padded, - B.shape[1], - B.shape[2] - padding_size, - sorted_token_ids.shape[0], - topk_ids.numel(), - A.stride(0), - A.stride(1), - B.stride(0), - B.stride(2), - B.stride(1), - C.stride(1), - C.stride(2), - NUM_SMS=NUM_SMS, - MUL_ROUTED_WEIGHT=mul_routed_weight, - top_k=top_k, - compute_type=compute_type, - use_fp8=use_fp8, - **config, - ) + NUM_SMS = torch.cuda.get_device_properties("cuda").multi_processor_count + grid = lambda META: (min( + NUM_SMS, + triton.cdiv(sorted_token_ids.shape[0], META["BLOCK_SIZE_M"]) * + triton.cdiv(B.shape[1], META["BLOCK_SIZE_N"]) + ), ) + + fused_moe_persistent_kernel[grid]( + A, + B, + C, + A_scale, + B_scale, + topk_weights, + sorted_token_ids, + expert_ids, + num_tokens_post_padded, + B.shape[1], + B.shape[2] - padding_size, + sorted_token_ids.shape[0], + topk_ids.numel(), + A.stride(0), + A.stride(1), + B.stride(0), + B.stride(2), + B.stride(1), + C.stride(1), + C.stride(2), + NUM_SMS=NUM_SMS, + MUL_ROUTED_WEIGHT=mul_routed_weight, + top_k=top_k, + compute_type=compute_type, + use_fp8=use_fp8, + **config, + ) def get_config_file_name(E: int, N: int, dtype: Optional[str]) -> str: @@ -792,8 +580,7 @@ def fused_experts(hidden_states: torch.Tensor, a1_scale: Optional[torch.Tensor] = None, a2_scale: Optional[torch.Tensor] = None): # Check constraints. 
- assert hidden_states.shape[ - 1] == w1.shape[2] - padding_size, "Hidden size mismatch" + assert hidden_states.shape[1] == w1.shape[2] - padding_size, "Hidden size mismatch" assert topk_weights.shape == topk_ids.shape, "topk shape mismatch" assert hidden_states.is_contiguous(), "Hidden_states must be contiguous" assert w1.is_contiguous(), "Expert weights1 must be contiguous" @@ -809,7 +596,7 @@ def fused_experts(hidden_states: torch.Tensor, config = override_config else: # First try to load optimal config from the file - configs = get_moe_configs(E, w2.shape[2] - padding_size, + configs = get_moe_configs(E, w2.shape[2], "float8" if use_fp8 else None) if configs: @@ -856,8 +643,7 @@ def fused_experts(hidden_states: torch.Tensor, compute_type = (tl.bfloat16 if hidden_states.dtype == torch.bfloat16 else tl.float16) - invoke_fused_moe_persistent_kernel(hidden_states, - # invoke_fused_moe_kernel(hidden_states, + invoke_fused_moe_kernel(hidden_states, w1, intermediate_cache1, a1_scale, @@ -875,8 +661,7 @@ def fused_experts(hidden_states: torch.Tensor, ops.silu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, N)) - invoke_fused_moe_persistent_kernel(intermediate_cache2, - # invoke_fused_moe_kernel(intermediate_cache2, + invoke_fused_moe_kernel(intermediate_cache2, w2, intermediate_cache3, a2_scale, diff --git a/vllm/model_executor/models/mixtral.py b/vllm/model_executor/models/mixtral.py index ee9db7048f1f6..672f2f28113c4 100644 --- a/vllm/model_executor/models/mixtral.py +++ b/vllm/model_executor/models/mixtral.py @@ -183,15 +183,17 @@ def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor, def process_weights_after_loading(self): # Fp8 is the only case where we need to process after loading. if not self.use_fp8: + w13_ = permute_weight(self.w13_weight.data) + w2_ = permute_weight(self.w2_weight.data) if envs.VLLM_MOE_PADDING: - self.w13_weight = nn.Parameter(F.pad(self.w13_weight.data, - (0, 128), "constant", 0), - requires_grad=False) + w13_ = F.pad(w13_, (0, 128), "constant", 0) torch.cuda.empty_cache() - self.w2_weight = nn.Parameter(F.pad(self.w2_weight.data, - (0, 128), "constant", 0), - requires_grad=False) + w2_ = F.pad(w2_, (0, 128), "constant", 0) torch.cuda.empty_cache() + self.w13_weight = nn.Parameter(w13_, requires_grad=False) + torch.cuda.empty_cache() + self.w2_weight = nn.Parameter(w2_, requires_grad=False) + torch.cuda.empty_cache() return # If checkpoint is fp16, quantize here. 
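
The hunk below adds the `permute_weight` helper used above. When VLLM_MOE_SHUFFLE is
set, it reorders each expert's weight matrix into an MFMA-friendly tiled layout before
the optional 128-element padding is applied (this requires the row count to be a
multiple of 16 and the column count a multiple of 32). A minimal standalone sketch of
the same reshape/permute chain, with tiny illustrative shapes that are not part of the
patch:

    import torch

    E, N, K = 8, 32, 64                             # illustrative shapes only
    w = torch.randn(E, N, K, dtype=torch.float16)
    w_ = w.view(E, N // 16, 16, K // 32, 4, 8)      # 16-row N tiles, 32-wide (4x8) K tiles
    w_ = w_.permute(0, 1, 3, 4, 2, 5).contiguous()  # move the 16 N-rows inside each K tile
    w_ = w_.view(E, N, K)                           # same shape, kernel-friendly element order
    assert w_.shape == w.shape
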
@@ -603,3 +605,14 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): def all_close_1d(x: torch.Tensor) -> bool: assert len(x.shape) == 1 return all(torch.allclose(x[0], x[i]) for i in range(x.shape[0])) + +def permute_weight(x: torch.Tensor) -> torch.Tensor: + x_ = torch.clone(x) + if envs.VLLM_MOE_SHUFFLE: + x_ = x_.view(x.shape[0], + x.shape[1]//16, 16, + x.shape[2]//32, 4, 8) + x_ = x_.permute(0,1,3,4,2,5) + x_ = x_.contiguous() + x_ = x_.view(x.shape[0], x.shape[1], x.shape[2]); + return x From 5b92aa6bcbdd56233392ae9d60d73ba7b8a3a98e Mon Sep 17 00:00:00 2001 From: Vinayak Gokhale Date: Thu, 8 Aug 2024 23:04:54 +0000 Subject: [PATCH 11/22] Fixes --- vllm/model_executor/layers/fused_moe/fused_moe.py | 2 +- vllm/model_executor/models/mixtral.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index 2300edf18d41c..27962603a31d1 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -596,7 +596,7 @@ def fused_experts(hidden_states: torch.Tensor, config = override_config else: # First try to load optimal config from the file - configs = get_moe_configs(E, w2.shape[2], + configs = get_moe_configs(E, w2.shape[2] - padding_size, "float8" if use_fp8 else None) if configs: diff --git a/vllm/model_executor/models/mixtral.py b/vllm/model_executor/models/mixtral.py index 672f2f28113c4..a5a9bbf021b5c 100644 --- a/vllm/model_executor/models/mixtral.py +++ b/vllm/model_executor/models/mixtral.py @@ -607,7 +607,7 @@ def all_close_1d(x: torch.Tensor) -> bool: return all(torch.allclose(x[0], x[i]) for i in range(x.shape[0])) def permute_weight(x: torch.Tensor) -> torch.Tensor: - x_ = torch.clone(x) + x_ = x if envs.VLLM_MOE_SHUFFLE: x_ = x_.view(x.shape[0], x.shape[1]//16, 16, From 98a31f295b524fe5bb8cf054f872093073ecdfbf Mon Sep 17 00:00:00 2001 From: Vinayak Gokhale Date: Fri, 9 Aug 2024 15:13:13 +0000 Subject: [PATCH 12/22] Fix shuffling bug --- vllm/model_executor/models/mixtral.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/model_executor/models/mixtral.py b/vllm/model_executor/models/mixtral.py index a5a9bbf021b5c..2a525970028bb 100644 --- a/vllm/model_executor/models/mixtral.py +++ b/vllm/model_executor/models/mixtral.py @@ -615,4 +615,4 @@ def permute_weight(x: torch.Tensor) -> torch.Tensor: x_ = x_.permute(0,1,3,4,2,5) x_ = x_.contiguous() x_ = x_.view(x.shape[0], x.shape[1], x.shape[2]); - return x + return x_ From 75031c5daec14b6a759abba093fd9145a835e6b3 Mon Sep 17 00:00:00 2001 From: Vinayak Gokhale Date: Fri, 9 Aug 2024 16:28:20 +0000 Subject: [PATCH 13/22] Fused moe persistent benchmark --- benchmarks/kernels/benchmark_mixtral_moe_rocm.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/benchmarks/kernels/benchmark_mixtral_moe_rocm.py b/benchmarks/kernels/benchmark_mixtral_moe_rocm.py index 82773ef329d1c..77fb8d3c966a0 100755 --- a/benchmarks/kernels/benchmark_mixtral_moe_rocm.py +++ b/benchmarks/kernels/benchmark_mixtral_moe_rocm.py @@ -13,7 +13,6 @@ from vllm._C import ops from vllm.model_executor.layers.fused_moe import (get_config_file_name, invoke_fused_moe_kernel, - invoke_fused_moe_persistent_kernel, moe_align_block_size) @@ -291,7 +290,7 @@ def run_timing( start_event.record() for i in range(num_calls): - invoke_fused_moe_persistent_kernel( + invoke_fused_moe_kernel( hidden_states, w1, intermediate_cache1, @@ -307,7 +306,7 @@ 
def run_timing( config, compute_type=(tl.bfloat16 if hidden_states.dtype == torch.bfloat16 else tl.float16), - use_fp8=False, + use_fp8=False ) # ops.silu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, N)) From 7c644595a71753f6d305b6fefc60ccf9d29a64f0 Mon Sep 17 00:00:00 2001 From: Vinayak Gokhale Date: Fri, 9 Aug 2024 18:20:15 +0000 Subject: [PATCH 14/22] Add test configs for mixtral --- tests/kernels/test_moe.py | 50 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 50 insertions(+) diff --git a/tests/kernels/test_moe.py b/tests/kernels/test_moe.py index 2356b9ec18b0d..1dfda2363e8c1 100644 --- a/tests/kernels/test_moe.py +++ b/tests/kernels/test_moe.py @@ -52,6 +52,56 @@ def test_fused_moe( torch_output = torch_moe(a, w1, w2, score, topk) assert torch.allclose(triton_output, torch_output, atol=1e-2, rtol=0) +@pytest.mark.parametrize("m", [1, 64, 96, 1000, 237]) +@pytest.mark.parametrize("n", [14336]) +@pytest.mark.parametrize("k", [4096]) +@pytest.mark.parametrize("e", [8]) +@pytest.mark.parametrize("topk", [2]) +@pytest.mark.parametrize("dtype", [torch.float16]) +def test_amd_moe_1( + m: int, + n: int, + k: int, + e: int, + topk: int, + dtype: torch.dtype, +): + if n == k: + pytest.skip() + a = torch.randn((m, k), device='cuda', dtype=dtype) / 10 + w1 = torch.randn((e, 2 * n, k), device='cuda', dtype=dtype) / 10 + w2 = torch.randn((e, k, n), device='cuda', dtype=dtype) / 10 + + score = torch.randn((m, e), device='cuda', dtype=dtype) + triton_output = fused_moe(a, w1, w2, score, topk, renormalize=False) + torch_output = torch_moe(a, w1, w2, score, topk) + assert torch.allclose(triton_output, torch_output, atol=2e-2, rtol=0) + + +@pytest.mark.parametrize("m", [1, 64, 96, 1000, 237]) +@pytest.mark.parametrize("n", [4096]) +@pytest.mark.parametrize("k", [14336]) +@pytest.mark.parametrize("e", [8]) +@pytest.mark.parametrize("topk", [2]) +@pytest.mark.parametrize("dtype", [torch.float16]) +def test_amd_moe_2( + m: int, + n: int, + k: int, + e: int, + topk: int, + dtype: torch.dtype, +): + if n == k: + pytest.skip() + a = torch.randn((m, k), device='cuda', dtype=dtype) / 10 + w1 = torch.randn((e, 2 * n, k), device='cuda', dtype=dtype) / 10 + w2 = torch.randn((e, k, n), device='cuda', dtype=dtype) / 10 + + score = torch.randn((m, e), device='cuda', dtype=dtype) + triton_output = fused_moe(a, w1, w2, score, topk, renormalize=False) + torch_output = torch_moe(a, w1, w2, score, topk) + assert torch.allclose(triton_output, torch_output, atol=2e-1, rtol=0) @pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16]) From b213dfe424a22531307948c893e7f90e72ea4726 Mon Sep 17 00:00:00 2001 From: Vinayak Gokhale Date: Sat, 10 Aug 2024 14:24:34 +0000 Subject: [PATCH 15/22] Use shuffled layout in UT --- tests/kernels/test_moe.py | 17 +++++++++++++++-- 1 file changed, 15 insertions(+), 2 deletions(-) diff --git a/tests/kernels/test_moe.py b/tests/kernels/test_moe.py index 1dfda2363e8c1..00f14808cb63e 100644 --- a/tests/kernels/test_moe.py +++ b/tests/kernels/test_moe.py @@ -10,7 +10,17 @@ from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.fused_moe import fused_moe from vllm.model_executor.models.mixtral import MixtralMoE - +from vllm import envs + +def permute_weight(x: torch.Tensor) -> torch.Tensor: + x_ = x.clone() + x_ = x_.view(x.shape[0], + x.shape[1]//16, 16, + x.shape[2]//32, 4, 8) + x_ = x_.permute(0,1,3,4,2,5) + x_ = x_.contiguous() + x_ = x_.view(x.shape[0], x.shape[1], x.shape[2]); + return x_ def torch_moe(a, w1, 
w2, score, topk): B, D = a.shape @@ -71,9 +81,12 @@ def test_amd_moe_1( a = torch.randn((m, k), device='cuda', dtype=dtype) / 10 w1 = torch.randn((e, 2 * n, k), device='cuda', dtype=dtype) / 10 w2 = torch.randn((e, k, n), device='cuda', dtype=dtype) / 10 + if envs.VLLM_MOE_SHUFFLE: + w1_shuffled = permute_weight(w1.data) + w2_shuffled = permute_weight(w2.data) score = torch.randn((m, e), device='cuda', dtype=dtype) - triton_output = fused_moe(a, w1, w2, score, topk, renormalize=False) + triton_output = fused_moe(a, w1_shuffled, w2_shuffled, score, topk, renormalize=False) torch_output = torch_moe(a, w1, w2, score, topk) assert torch.allclose(triton_output, torch_output, atol=2e-2, rtol=0) From 7df58adb66116e8affa72c07c4c605d01068a0af Mon Sep 17 00:00:00 2001 From: Vinayak Gokhale Date: Sat, 10 Aug 2024 14:36:48 +0000 Subject: [PATCH 16/22] Fix test bugs --- tests/kernels/test_moe.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/tests/kernels/test_moe.py b/tests/kernels/test_moe.py index 00f14808cb63e..e2e60c2d24abe 100644 --- a/tests/kernels/test_moe.py +++ b/tests/kernels/test_moe.py @@ -62,7 +62,8 @@ def test_fused_moe( torch_output = torch_moe(a, w1, w2, score, topk) assert torch.allclose(triton_output, torch_output, atol=1e-2, rtol=0) -@pytest.mark.parametrize("m", [1, 64, 96, 1000, 237]) +#@pytest.mark.parametrize("m", [1, 64, 96, 1000, 237]) +@pytest.mark.parametrize("m", [1, 64, 96, 1000]) @pytest.mark.parametrize("n", [14336]) @pytest.mark.parametrize("k", [4096]) @pytest.mark.parametrize("e", [8]) @@ -91,7 +92,8 @@ def test_amd_moe_1( assert torch.allclose(triton_output, torch_output, atol=2e-2, rtol=0) -@pytest.mark.parametrize("m", [1, 64, 96, 1000, 237]) +#@pytest.mark.parametrize("m", [1, 64, 96, 1000, 237]) +@pytest.mark.parametrize("m", [1, 64, 96, 1000]) @pytest.mark.parametrize("n", [4096]) @pytest.mark.parametrize("k", [14336]) @pytest.mark.parametrize("e", [8]) @@ -110,9 +112,12 @@ def test_amd_moe_2( a = torch.randn((m, k), device='cuda', dtype=dtype) / 10 w1 = torch.randn((e, 2 * n, k), device='cuda', dtype=dtype) / 10 w2 = torch.randn((e, k, n), device='cuda', dtype=dtype) / 10 + if envs.VLLM_MOE_SHUFFLE: + w1_shuffled = permute_weight(w1.data) + w2_shuffled = permute_weight(w2.data) score = torch.randn((m, e), device='cuda', dtype=dtype) - triton_output = fused_moe(a, w1, w2, score, topk, renormalize=False) + triton_output = fused_moe(a, w1_shuffled, w2_shuffled, score, topk, renormalize=False) torch_output = torch_moe(a, w1, w2, score, topk) assert torch.allclose(triton_output, torch_output, atol=2e-1, rtol=0) From 1ca0ad670267f14d98d2a2e2aad23c0ffeaceed2 Mon Sep 17 00:00:00 2001 From: Alexander Efimov Date: Sat, 10 Aug 2024 16:25:49 +0000 Subject: [PATCH 17/22] use mfma 16 for all testcases, reenable moe test [dtype0-2-8-14336-4096-237] --- tests/kernels/test_moe.py | 6 ++---- .../E=8,N=14336,device_name=AMD_Radeon_Graphics.json | 4 ++-- .../configs/E=8,N=4096,device_name=AMD_Radeon_Graphics.json | 2 +- 3 files changed, 5 insertions(+), 7 deletions(-) diff --git a/tests/kernels/test_moe.py b/tests/kernels/test_moe.py index e2e60c2d24abe..8e3d707b8f149 100644 --- a/tests/kernels/test_moe.py +++ b/tests/kernels/test_moe.py @@ -62,8 +62,7 @@ def test_fused_moe( torch_output = torch_moe(a, w1, w2, score, topk) assert torch.allclose(triton_output, torch_output, atol=1e-2, rtol=0) -#@pytest.mark.parametrize("m", [1, 64, 96, 1000, 237]) -@pytest.mark.parametrize("m", [1, 64, 96, 1000]) +@pytest.mark.parametrize("m", [1, 64, 96, 1000, 237]) 
@pytest.mark.parametrize("n", [14336]) @pytest.mark.parametrize("k", [4096]) @pytest.mark.parametrize("e", [8]) @@ -92,8 +91,7 @@ def test_amd_moe_1( assert torch.allclose(triton_output, torch_output, atol=2e-2, rtol=0) -#@pytest.mark.parametrize("m", [1, 64, 96, 1000, 237]) -@pytest.mark.parametrize("m", [1, 64, 96, 1000]) +@pytest.mark.parametrize("m", [1, 64, 96, 1000, 237]) @pytest.mark.parametrize("n", [4096]) @pytest.mark.parametrize("k", [14336]) @pytest.mark.parametrize("e", [8]) diff --git a/vllm/model_executor/layers/fused_moe/configs/E=8,N=14336,device_name=AMD_Radeon_Graphics.json b/vllm/model_executor/layers/fused_moe/configs/E=8,N=14336,device_name=AMD_Radeon_Graphics.json index 54eed795101d8..64b94443c81d7 100644 --- a/vllm/model_executor/layers/fused_moe/configs/E=8,N=14336,device_name=AMD_Radeon_Graphics.json +++ b/vllm/model_executor/layers/fused_moe/configs/E=8,N=14336,device_name=AMD_Radeon_Graphics.json @@ -139,7 +139,7 @@ "num_warps": 4, "num_stages": 0, "waves_per_eu": 0, - "matrix_instr_nonkdim": 32, + "matrix_instr_nonkdim": 16, "kpack": 2 }, "1024": { @@ -230,4 +230,4 @@ "matrix_instr_nonkdim": 16, "kpack": 2 } -} \ No newline at end of file +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=8,N=4096,device_name=AMD_Radeon_Graphics.json b/vllm/model_executor/layers/fused_moe/configs/E=8,N=4096,device_name=AMD_Radeon_Graphics.json index f020993a1b615..9e8b4de93747a 100644 --- a/vllm/model_executor/layers/fused_moe/configs/E=8,N=4096,device_name=AMD_Radeon_Graphics.json +++ b/vllm/model_executor/layers/fused_moe/configs/E=8,N=4096,device_name=AMD_Radeon_Graphics.json @@ -128,7 +128,7 @@ "num_warps": 4, "num_stages": 0, "waves_per_eu": 0, - "matrix_instr_nonkdim": 32, + "matrix_instr_nonkdim": 16, "kpack": 1 }, "512": { From 0cbe89251fcef665b87073da4ddfc6586d6aa9c1 Mon Sep 17 00:00:00 2001 From: Vinayak Gokhale Date: Sat, 10 Aug 2024 18:34:14 +0000 Subject: [PATCH 18/22] Enable LDS bypass --- vllm/model_executor/layers/fused_moe/fused_moe.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index 27962603a31d1..7a348758aae06 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -456,9 +456,10 @@ def invoke_fused_moe_kernel(A: torch.Tensor, B: torch.Tensor, C: torch.Tensor, compute_type=compute_type, use_fp8=use_fp8, **config, + enable_moe_lds_bypass=True ) else: - NUM_SMS = torch.cuda.get_device_properties("cuda").multi_processor_count + NUM_SMS = torch.cuda.get_device_properties("cuda").multi_processor_count * 2 grid = lambda META: (min( NUM_SMS, triton.cdiv(sorted_token_ids.shape[0], META["BLOCK_SIZE_M"]) * @@ -492,6 +493,7 @@ def invoke_fused_moe_kernel(A: torch.Tensor, B: torch.Tensor, C: torch.Tensor, compute_type=compute_type, use_fp8=use_fp8, **config, + enable_moe_lds_bypass=True ) From c43e3e8bb6f267590ef302051647f10e9e9a0a39 Mon Sep 17 00:00:00 2001 From: Vinayak Gokhale Date: Sat, 10 Aug 2024 19:26:18 +0000 Subject: [PATCH 19/22] Add new configs --- ...14336,device_name=AMD_Radeon_Graphics.json | 40 +++++++++---------- 1 file changed, 20 insertions(+), 20 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/configs/E=8,N=14336,device_name=AMD_Radeon_Graphics.json b/vllm/model_executor/layers/fused_moe/configs/E=8,N=14336,device_name=AMD_Radeon_Graphics.json index 64b94443c81d7..b294a1c08d6f6 100644 --- 
a/vllm/model_executor/layers/fused_moe/configs/E=8,N=14336,device_name=AMD_Radeon_Graphics.json +++ b/vllm/model_executor/layers/fused_moe/configs/E=8,N=14336,device_name=AMD_Radeon_Graphics.json @@ -69,7 +69,7 @@ "BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 4, + "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 0, "waves_per_eu": 0, @@ -80,7 +80,7 @@ "BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 4, + "GROUP_SIZE_M": 1, "num_warps": 8, "num_stages": 0, "waves_per_eu": 0, @@ -146,9 +146,9 @@ "BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, - "GROUP_SIZE_M": 4, + "GROUP_SIZE_M": 1, "num_warps": 8, - "num_stages": 0, + "num_stages": 1, "waves_per_eu": 0, "matrix_instr_nonkdim": 16, "kpack": 2 @@ -159,7 +159,7 @@ "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 1, "num_warps": 8, - "num_stages": 0, + "num_stages": 1, "waves_per_eu": 0, "matrix_instr_nonkdim": 16, "kpack": 2 @@ -168,9 +168,9 @@ "BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, - "GROUP_SIZE_M": 4, + "GROUP_SIZE_M": 1, "num_warps": 8, - "num_stages": 0, + "num_stages": 1, "waves_per_eu": 0, "matrix_instr_nonkdim": 16, "kpack": 2 @@ -181,7 +181,7 @@ "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 1, "num_warps": 8, - "num_stages": 0, + "num_stages": 1, "waves_per_eu": 0, "matrix_instr_nonkdim": 16, "kpack": 2 @@ -189,10 +189,10 @@ "4096": { "BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 64, - "GROUP_SIZE_M": 4, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, "num_warps": 8, - "num_stages": 0, + "num_stages": 1, "waves_per_eu": 0, "matrix_instr_nonkdim": 16, "kpack": 2 @@ -200,10 +200,10 @@ "16384": { "BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 64, - "GROUP_SIZE_M": 4, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, "num_warps": 8, - "num_stages": 0, + "num_stages": 1, "waves_per_eu": 0, "matrix_instr_nonkdim": 16, "kpack": 2 @@ -211,10 +211,10 @@ "18432": { "BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 64, - "GROUP_SIZE_M": 4, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, "num_warps": 8, - "num_stages": 0, + "num_stages": 1, "waves_per_eu": 0, "matrix_instr_nonkdim": 16, "kpack": 2 @@ -222,10 +222,10 @@ "20480": { "BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 64, - "GROUP_SIZE_M": 4, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, "num_warps": 8, - "num_stages": 0, + "num_stages": 1, "waves_per_eu": 0, "matrix_instr_nonkdim": 16, "kpack": 2 From ce8a86e5108c872695965912373f95cce15f1829 Mon Sep 17 00:00:00 2001 From: valarLip <340077269@qq.com> Date: Sat, 10 Aug 2024 22:27:12 +0800 Subject: [PATCH 20/22] patch wvSpltK_fused_moe from https://github.com/amd-hhashemi/vllm/tree/wvSpltK_fused_moe --- benchmarks/kernels/benchmark_mixtral_moe.py | 12 +- csrc/custom/custom.cu | 56 + csrc/custom/custom_kernels.cu | 1413 +++++++++++++++++ vllm/envs.py | 9 + .../layers/fused_moe/fused_moe.py | 125 +- vllm/model_executor/models/mixtral.py | 7 + 6 files changed, 1615 insertions(+), 7 deletions(-) diff --git a/benchmarks/kernels/benchmark_mixtral_moe.py b/benchmarks/kernels/benchmark_mixtral_moe.py index 196ec8cfce88e..a8cf2211cdaa0 100644 --- a/benchmarks/kernels/benchmark_mixtral_moe.py +++ b/benchmarks/kernels/benchmark_mixtral_moe.py @@ -10,7 +10,8 @@ from vllm.model_executor.layers.fused_moe import (fused_moe, get_config_file_name) - +from vllm import envs +from torch import nn def main(model, tp_size, gpu, dtype: str): os.environ['CUDA_VISIBLE_DEVICES'] = str(gpu) @@ -154,6 +155,15 @@ def run_timing(num_calls: int, bs: int, 
d_model: int, num_total_experts: int, device=hidden_states.device, dtype=hidden_states.dtype, ) + if envs.VLLM_MOE_PADDING: + w1 = nn.Parameter(F.pad(w1.data, + (0, 128), "constant", 0), + requires_grad=False) + torch.cuda.empty_cache() + w2 = nn.Parameter(F.pad(w2, + (0, 128), "constant", 0), + requires_grad=False) + torch.cuda.empty_cache() w1_scale = None w2_scale = None diff --git a/csrc/custom/custom.cu b/csrc/custom/custom.cu index 9e92187967d47..bf196b235178e 100644 --- a/csrc/custom/custom.cu +++ b/csrc/custom/custom.cu @@ -51,6 +51,61 @@ void wvSpltK(at::Tensor in_a, at::Tensor in_b, at::Tensor out_c, const int N_in, at::cuda::getCurrentCUDAStream(), CuCount); } +void wvSpltK_fsdMoe_(void* in_a, void* in_b, void* out_c, + void* topk_weights, + void* topk_ids, + void* sorted_token_ids, + void* expert_ids, + void* num_tokens_post_padded, + const int M, const int N, const int K, const int EM, + const int num_valid_tokens, + const int stride_am, + const int stride_ak, + const int stride_be, + const int stride_bk, + const int stride_bn, + const int stride_cm, + const int stride_cn, + const int m_blck_sz, + const bool mul_routed_weight, + const int top_k, + cudaStream_t stream, const int CuCount); + +void wvSpltK_fsdMoe(at::Tensor in_a, at::Tensor in_b, at::Tensor out_c, + at::Tensor topk_weights, + at::Tensor topk_ids, + at::Tensor sorted_token_ids, + at::Tensor expert_ids, + at::Tensor num_tokens_post_padded, + const int M, const int N, const int K, const int EM, + const int num_valid_tokens, + const int stride_am, + const int stride_ak, + const int stride_be, + const int stride_bk, + const int stride_bn, + const int stride_cm, + const int stride_cn, + const int m_blck_sz, + const bool mul_routed_weight, + const int top_k, + const int CuCount) { + //int M = in_a.size(0); + //int K = in_a.size(1); + //int N = N_in; + wvSpltK_fsdMoe_(in_a.data_ptr(), in_b.data_ptr(), out_c.data_ptr(), + topk_weights.data_ptr(), + topk_ids.data_ptr(), + sorted_token_ids.data_ptr(), + expert_ids.data_ptr(), + num_tokens_post_padded.data_ptr(), + M, N, K, EM, + num_valid_tokens, + stride_am, stride_ak,stride_be,stride_bk,stride_bn,stride_cm,stride_cn, + m_blck_sz, mul_routed_weight,top_k, + at::cuda::getCurrentCUDAStream(), CuCount); +} + void LLGemmZZ(void* in_a, void* in_b, void* out_c, const int M, const int K, cudaStream_t stream, const int solidx); @@ -103,5 +158,6 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def("paged_attention_custom", &paged_attention_custom, "PagedAttention LL4Mi Custom."); m.def("wvSpltK", &wvSpltK); + m.def("wvSpltK_fsdMoe", &wvSpltK_fsdMoe); // m.def("MMCustomGPU", &MMCustomGPU); } diff --git a/csrc/custom/custom_kernels.cu b/csrc/custom/custom_kernels.cu index f03d3da5a8f9c..e55e1510ec27f 100644 --- a/csrc/custom/custom_kernels.cu +++ b/csrc/custom/custom_kernels.cu @@ -1925,6 +1925,1419 @@ __global__ void wvSpltK_hf_m4_(const int K, const int N, const DTYPE* B, #endif // defined(__HIP__MI300__) TODO: Add NAVI support + + +#undef M +#undef YTILE +#undef UNRL +#define UNRL 1 +//#define M_BLOCK 4 + +template +__global__ void +__launch_bounds__(WvPrGrp * THRDS) +wvSpltK_fsdMoe_hf_( + const DTYPE* __restrict__ A, + const DTYPE* __restrict__ B, + DTYPE* C, + const float* __restrict__ topk_weights, + const int* __restrict__ topk_ids, + const int* __restrict__ sorted_token_ids, + const int* __restrict__ expert_ids, + const int* __restrict__ num_tokens_post_padded, + const int M_in, const int N, const int K, const int E, + const int num_valid_tokens, + const int stride_am, + const 
int stride_ak, + const int stride_be, + const int stride_bk, + const int stride_bn, + const int stride_cm, + const int stride_cn, + const bool mul_routed_weight, + const int top_k, + const int CuCount + ) { + bool PCML = (K * M_in > 32*1024); + union bigType { + DTYPE h[A_CHUNK]; + float f[A_CHUNK / 2]; + float2 f2[A_CHUNK / 4]; + double d[A_CHUNK / 4]; + half8 h8; + }; + + __shared__ half s[1024 * 32]; + + uint32_t commitColumn[YTILE]; + for (uint32_t i = 0; i < YTILE; i++) { + commitColumn[i] = 1; + } + + uint32_t n = (blockIdx.x * WvPrGrp + threadIdx.y) * YTILE; + + if (n < N && (n + YTILE) >= N) { + uint32_t startColumn = N - YTILE; + for (uint32_t i = 0; i < (n - startColumn); i++) { + commitColumn[i] = 0; + } + n = startColumn; + } + + if (!PCML) { + for (uint32_t k = 0; k < min(K * M_in, 32 * 1024); + k += THRDS * WvPrGrp * A_CHUNK) { + uint32_t k_in = k + ((threadIdx.y * THRDS + threadIdx.x) * A_CHUNK); + + if (k_in >= min(K * M_in, 32 * 1024)) break; + + *((bigType*)(&s[k_in])) = *((bigType*)(&A[k_in])); + } + __syncthreads(); + } + + int YW = (YTILE * WvPrGrp); + int TWC = (THRDS * WvPrGrp * A_CHUNK); + int TUC = (THRDS * UNRL * A_CHUNK); + uint32_t kBase = 0; + //find biggest k size that fits in LDS + uint32_t kFit = (32*1024)/M_BLOCK; + //kFit = (kFit%TWC==0) ? kFit : (kFit-kFit%TWC+TWC); //round up to multiple of TUC + kFit = (kFit%TUC==0) ? kFit : (kFit-kFit%TUC); //round down to multiple of TUC + //if (kFit == 0) kFit = TUC; + kFit = min(kFit, K); + + //if (kFit < TUC) PCML = false; + + float sum[M_BLOCK][YTILE]; + + //TRITON + //offs_token_id = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + //offs_token = tl.load(sorted_token_ids_ptr + offs_token_id) + //token_mask = offs_token < num_valid_tokens + int offs_token[M_BLOCK]; + bool token_mask[M_BLOCK]; // add to A[] /top_k*k + int off_experts; // add to B[] *K*N loads + + uint32_t Nrndp = (N%YW==0) ? N : (N-N%YW+YW); // Note: All waves in the group need to stay alive to the bitter end, just in case they're needed for cooperative loading of next chunk of A[] into LDS. Such Zomby waves are prevented from doing any real work with continues in the loop below. + if (!PCML) Nrndp = N; //unless its not peicmeal + while (n < Nrndp) { + kBase = 0; + for (uint32_t e = 0; e < num_tokens_post_padded[0]; e+=M_BLOCK) { + kBase = 0; + + for (int m=0; m= K) break; + if (kOff >= kFit) break; + for (uint32_t m = 0; m < M_BLOCK; m++) { + if (!token_mask[m]) continue; + uint32_t k_in = kBase + (offs_token[m]/top_k) * K + kOff; + uint32_t k_ot = m * kFit + kOff; + *((bigType*)(&s[k_ot])) = *((bigType*)(&A[k_in])); + } + } + __syncthreads(); + } + } + + // kept alive just to participate in A[] loads + if (n >= N) continue; + +#pragma unroll + for (uint32_t k2 = 0; k2 < UNRL; k2++) { + uint32_t k = k1 + k2 * THRDS * A_CHUNK; + uint32_t k_ = k + threadIdx.x * A_CHUNK; + if (k_ >= K) break; + + // load only 1 column of weights, despite the moe-gate, made possible by expert list. 
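+      // B is stored expert-major as a flattened (E, N, K) array:
+      // off_experts*K*N selects this block's expert slab, (n + y)*K picks the
+      // weight row for output feature n + y, and k_ is this thread's
+      // A_CHUNK-wide offset along K.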
+ const half* B_ = &B[(n + 0) * K + k_ + off_experts*K*N]; + for (int y=0; y= K) break; + + // Fetch A activation matrix in interleaved fashion from LDS or memory + + for (int m = 0; m < M_BLOCK; m++) + { + if (!token_mask[m]) continue; + if (PCML) { + //bigA[m][k2] = *((const bigType*)(&(s[k_-kBase + kFit*m]))); + // skip A[] fetches for Ms that are disabled + bigA[m][k2] = *((const bigType*)(&(s[k_-kBase + m*kFit ]))); + } else { + int aidx = k_ + (offs_token[m]/top_k) * K; + if (aidx + A_CHUNK <= 32 * 1024) + bigA[m][k2] = *((const bigType*)(&(s[aidx]))); + else + bigA[m][k2] = *((const bigType*)(&(A[aidx]))); + } + } + } + + // Do the matrix multiplication in interleaved manner +#pragma unroll + for (uint32_t k2 = 0; k2 < UNRL; k2++) { + uint32_t k = k1 + k2 * THRDS * A_CHUNK; + uint32_t k_ = k + threadIdx.x * A_CHUNK; + if (k_ >= K) break; +#pragma unroll + for (uint32_t m = 0; m < M_BLOCK; m++) { + // skip compute for Ms that are disabled + if (!token_mask[m]) continue; + // Do the matrix multiplication of activation and weight matrix + // - Remember the accumulation is happening for K-split of 64! +#pragma unroll + for (uint32_t b = 0; b < A_CHUNK / 2; b++) + for (int y=0; y= N) { + uint32_t startColumn = N - YTILE; + for (uint32_t i = 0; i < (n - startColumn); i++) { + commitColumn[i] = 0; + } + n = startColumn; + } + } +} + +#define mfmaTILEn 16 +#define mfmaTILEk 4 +//#undef WvPrGrp +//#define WvPrGrp 8 +#define USEMFMA +//#define PIPELINED_33334x +//#define PIPELINED_556x +#define PIPELINED4x + +template +__global__ void +__launch_bounds__(WvPrGrp * THRDS) +wvSpltK_fsdMoe_hf_mfma16_( + const DTYPE* __restrict__ A, + const DTYPE* __restrict__ B, + DTYPE* C, + const float* __restrict__ topk_weights, + const int* __restrict__ topk_ids, + const int* __restrict__ sorted_token_ids, + const int* __restrict__ expert_ids, + const int* __restrict__ num_tokens_post_padded, + const int M_in, const int N, const int K, const int E, + const int num_valid_tokens, + const int stride_am, + const int stride_ak, + const int stride_be, + const int stride_bk, + const int stride_bn, + const int stride_cm, + const int stride_cn, + const bool mul_routed_weight, + const int top_k, + const int CuCount + ) { + +using halfCxT = __attribute__((__vector_size__(mfmaTILEn * A_CHUNK / 2 * sizeof(float)))) float; +using halfC = __attribute__((__vector_size__(A_CHUNK / 2 * sizeof(float)))) float; +using halfT = __attribute__((__vector_size__(mfmaTILEk / 2 * sizeof(float)))) float; + +bool PCML = true;//(K * M_in > 32*1024); + union bigType { + DTYPE h[A_CHUNK]; + float f[A_CHUNK / 2]; + float2 f2[A_CHUNK / 4]; + double d[A_CHUNK / 4]; + half8 h8; + int i[A_CHUNK / 2]; + long int l[A_CHUNK / 4]; + halfT hT[A_CHUNK / mfmaTILEk]; + halfC hC; + }; + union bigTypeXt{ + bigType B[mfmaTILEn]; + halfCxT hCT; + }; + + + __shared__ half s[1024 * 32]; + + uint32_t commitColumn[YTILE]; + for (uint32_t i = 0; i < YTILE; i++) { + commitColumn[i] = 1; + } + + int ETILE = (CuCount * WvPrGrp ) / (N/YTILE); // bump up etile to fill machine + if (ETILE < 1) ETILE = 1; //TODO: what is best default min ETILE? + if (M_in >= 128) ETILE = min(M_in/64, 15); // Heuristic: Add an ETILE for every 64 Ms + + const int num_tblk = num_tokens_post_padded[0] / M_BLOCK; + + // its worth spending time trying to load balance for this num_tokens... 
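+  // Load-balance the expert-tile factor: for ETILE, ETILE/2 and ETILE*2,
+  // estimate rounds = ceil(N / columns-per-round) * ceil(num_tblk / ETILE)
+  // and keep whichever candidate needs the fewest rounds on this CU count.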
+ if ((CuCount/(ETILE*2) > 0) && (num_tblk>0))// TODO: make sure all overflow/inf conditions are avoided + { + int nPrRnd0 = ((CuCount/(ETILE))*WvPrGrp)*YTILE; + int nRnds0 = (N + nPrRnd0 - 1 ) / nPrRnd0; + int tRnds0 = (num_tblk + (ETILE) - 1) / (ETILE); + int rnds0 = nRnds0 * tRnds0; + + int nPrRnd1n = ((CuCount/(ETILE/2))*WvPrGrp)*YTILE; + int nRnds1n = (N + nPrRnd1n - 1 ) / nPrRnd1n; + int tRnds1n = (num_tblk + (ETILE/2) - 1) / (ETILE/2); + int rnds1n = nRnds1n * tRnds1n; + + int nPrRnd1p = ((CuCount/(ETILE*2))*WvPrGrp)*YTILE; + int nRnds1p = (N + nPrRnd1p - 1 ) / nPrRnd1p; + int tRnds1p = (num_tblk + (ETILE*2) - 1) / (ETILE*2); + int rnds1p = nRnds1p * tRnds1p; + + int etl = ETILE; + if (rnds0 > rnds1n) { etl = ETILE/2; rnds0 = rnds1n; } + if (rnds0 > rnds1p) { etl = ETILE*2; rnds0 = rnds1p; } + ETILE = etl; + } + + uint32_t n = ((blockIdx.x/ETILE) * WvPrGrp + threadIdx.y) * YTILE; + +/* if (n < N && (n + YTILE) >= N) { + uint32_t startColumn = N - YTILE; + for (uint32_t i = 0; i < (n - startColumn); i++) { + commitColumn[i] = 0; + } + n = startColumn; + }*/ + + if (!PCML) { + for (uint32_t k = 0; k < min(K * M_in, 32 * 1024); + k += THRDS * WvPrGrp * A_CHUNK) { + uint32_t k_in = k + ((threadIdx.y * THRDS + threadIdx.x) * A_CHUNK); + + if (k_in >= min(K * M_in, 32 * 1024)) break; + + *((bigType*)(&s[k_in])) = *((bigType*)(&A[k_in])); + } + __syncthreads(); + } + + int YW = (YTILE * WvPrGrp); + int TWC = (THRDS * WvPrGrp * A_CHUNK); + int TUC = (THRDS * UNRL * A_CHUNK); + uint32_t kBase = 0; + //find biggest k size that fits in LDS + uint32_t kFit = (32*1024)/M_BLOCK; + //kFit = (kFit%TWC==0) ? kFit : (kFit-kFit%TWC+TWC); //round up to multiple of TUC + kFit = (kFit%TUC==0) ? kFit : (kFit-kFit%TUC); //round down to multiple of TUC + //if (kFit == 0) kFit = TUC; + kFit = min(kFit, K); + +#ifdef USEMFMA + using float4_ = __attribute__( (__vector_size__(4 * sizeof(float)) )) float; + float4_ sum4; +#else + float sum[M_BLOCK][YTILE]; +#endif + + //TRITON + //offs_token_id = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + //offs_token = tl.load(sorted_token_ids_ptr + offs_token_id) + //token_mask = offs_token < num_valid_tokens + uint32_t offs_token[M_BLOCK]; + bool token_mask[M_BLOCK]; // add to A[] /top_k*k + uint32_t off_experts; // add to B[] *K*N loads + + int kShfl = A_CHUNK * THRDS * ( threadIdx.y + (threadIdx.x/16)); + int kSprd = A_CHUNK * ( threadIdx.x ); + + uint32_t Nrndp = (N%YW==0) ? N : (N-N%YW+YW); // Note: All waves in the group need to stay alive to the bitter end, just in case they're needed for cooperative loading of next chunk of A[] into LDS. Such Zomby waves are prevented from doing any real work with continues in the loop below. + if (!PCML) Nrndp = N; //unless its not peicmeal + while (n < Nrndp) { + kBase = 0; + for (uint32_t e = (blockIdx.x % ETILE) * M_BLOCK; e < num_tokens_post_padded[0]; e+=M_BLOCK*ETILE) { + kBase = 0; + +#pragma unroll M_BLOCK + for (uint32_t m=0; m= K) break; + if (kOff >= kFit) break; +#ifdef USEMFMA + uint32_t k_in = kBase + (offs_token[m]/top_k) * K + kOff; + uint32_t k_ot = m * K + kOff; // yes, K should be kFit here. 
but we'lltranspose this below anyway + // Transpose A for MFMAs + uint32_t k_in_x = (k_ot / A_CHUNK) % (K / A_CHUNK); + uint32_t k_in_y = (k_ot / A_CHUNK) / (K / A_CHUNK); + uint32_t k_ot_x = (k_in_x / mfmaTILEn) * mfmaTILEn + (k_in_y % mfmaTILEn); + uint32_t k_ot_y = (k_in_y / mfmaTILEn) * mfmaTILEn + (k_in_x % mfmaTILEn); + + k_ot = (k_ot_y * (kFit / A_CHUNK) + k_ot_x) * A_CHUNK; + + *((bigType*)(&s[k_ot])) = *((bigType*)(&A[k_in])); + //} +#else + //int m = threadIdx.x % M_BLOCK; + //for (uint32_t m = 0; m < M_BLOCK; m++) { + //if (!token_mask[m]) continue; + uint32_t k_in = kBase + (offs_token[m]/top_k) * K + kOff; + uint32_t k_ot = m * kFit + kOff; + *((bigType*)(&s[k_ot])) = *((bigType*)(&A[k_in])); + //} +#endif + } + __syncthreads(); + } + } + + // kept alive just to participate in A[] loads + if (n >= N) continue; + + int k1 = k1_; + if (shflk) k1 = kBase + (((k1_-kBase) + kShfl) % kFit ); // shfl loads within this lane, to reduce temporal hotspotting + + #define StgMfma4(_LN) { \ + for (uint32_t _t = 0; _t < A_CHUNK/mfmaTILEk; _t++) { \ + sum4 = __builtin_amdgcn_mfma_f32_16x16x16f16( \ + bigB[0][k2].B[_LN].hT[_t], \ + bigA[_LN][k2].hT[_t], \ + sum4, 0, 0, 0); \ + } \ + } + + +#ifdef PIPELINED1x +#pragma unroll + for (uint32_t k2 = 0; k2 < UNRL; k2++) { + uint32_t k = k1 + k2 * THRDS * A_CHUNK; + uint32_t k_ = k + kSprd; + if (k_ >= K) break; + const half* B_ = &B[(n + 0) * K + k_ + off_experts*K*N]; + for (int y=0; y= K) break; + for (int m = 0; m < M_BLOCK; m++) + { + bigA[m][k2] = *((const bigType*)(&(s[k_-kBase + m*kFit ]))); + } + } +#pragma unroll + for (uint32_t k2 = 0; k2 < UNRL; k2++) { + uint32_t k = k1 + k2 * THRDS * A_CHUNK; + uint32_t k_ = k + kSprd; + if (k_ >= K) break; + for (int l=0; l= K) break; + const half* B_ = &B[(n + 0) * K + k_ + off_experts*K*N]; + for (int y=0; y= K) break; + for (int m = 0; m < M_BLOCK/2; m++) + { + bigA[m][k2] = *((const bigType*)(&(s[k_-kBase + m*kFit ]))); + } + //} +//#pragma unroll + //for (uint32_t k2 = 0; k2 < UNRL; k2++) { + // uint32_t k = k1 + k2 * THRDS * A_CHUNK; + // uint32_t k_ = k + threadIdx.x * A_CHUNK; + // if (k_ >= K) break; + for (int l=0; l= K) break; + const half* B_ = &B[(n + 0) * K + k_ + off_experts*K*N]; + for (int y=YTILE/2; y= K) break; + for (int m = M_BLOCK/2; m < M_BLOCK; m++) + { + bigA[m-M_BLOCK/2][k2] = *((const bigType*)(&(s[k_-kBase + m*kFit ]))); + } + //} +//#pragma unroll + //for (uint32_t k2 = 0; k2 < UNRL; k2++) { + // uint32_t k = k1 + k2 * THRDS * A_CHUNK; + // uint32_t k_ = k + threadIdx.x * A_CHUNK; + // if (k_ >= K) break; + for (int l=0; l= K) break; + const half* B_ = &B[(n + 0) * K + k_ + off_experts*K*N]; + for (int y=0; y= K) break; + for (int m = 0; m < M_BLOCK/4; m++) + { + bigA[m][k2] = *((const bigType*)(&(s[k_-kBase + m*kFit ]))); + } + //} +//#pragma unroll + //for (uint32_t k2 = 0; k2 < UNRL; k2++) { + // uint32_t k = k1 + k2 * THRDS * A_CHUNK; + // uint32_t k_ = k + threadIdx.x * A_CHUNK; + // if (k_ >= K) break; + for (int l=0; l= K) break; + const half* B_ = &B[(n + 0) * K + k_ + off_experts*K*N]; + for (int y=YTILE/4; y= K) break; + for (int m = M_BLOCK/4; m < M_BLOCK/2; m++) + { + bigA[m-M_BLOCK/4][k2] = *((const bigType*)(&(s[k_-kBase + m*kFit ]))); + } + //} +//#pragma unroll + //for (uint32_t k2 = 0; k2 < UNRL; k2++) { + // uint32_t k = k1 + k2 * THRDS * A_CHUNK; + // uint32_t k_ = k + threadIdx.x * A_CHUNK; + // if (k_ >= K) break; + for (int l=0; l= K) break; + const half* B_ = &B[(n + 0) * K + k_ + off_experts*K*N]; + for (int y=YTILE/2; y<3*YTILE/4; y++) // should 
this be M_BLOCK? + //bigB[0][k2].B[y-YTILE/2].hC = (loadnt((halfC*)(&B_[y * K]))); + bigB[0][k2].B[y-YTILE/2].hC = *(((halfC*)(&B_[y * K]))); + } +#pragma unroll + for (uint32_t k2 = 0; k2 < UNRL; k2++) { + uint32_t k = k1 + k2 * THRDS * A_CHUNK; + uint32_t k_ = k + kSprd; + if (k_ >= K) break; + for (int m = M_BLOCK/2; m < 3*M_BLOCK/4; m++) + { + bigA[m-M_BLOCK/2][k2] = *((const bigType*)(&(s[k_-kBase + m*kFit ]))); + } + //} +//#pragma unroll + //for (uint32_t k2 = 0; k2 < UNRL; k2++) { + // uint32_t k = k1 + k2 * THRDS * A_CHUNK; + // uint32_t k_ = k + threadIdx.x * A_CHUNK; + // if (k_ >= K) break; + for (int l=0; l= K) break; + const half* B_ = &B[(n + 0) * K + k_ + off_experts*K*N]; + for (int y=3*YTILE/4; y= K) break; + for (int m = 3*M_BLOCK/4; m < M_BLOCK; m++) + { + bigA[m-3*M_BLOCK/4][k2] = *((const bigType*)(&(s[k_-kBase + m*kFit ]))); + } + //} +//#pragma unroll + //for (uint32_t k2 = 0; k2 < UNRL; k2++) { + // uint32_t k = k1 + k2 * THRDS * A_CHUNK; + // uint32_t k_ = k + threadIdx.x * A_CHUNK; + // if (k_ >= K) break; + for (int l=0; l= K) break; + const half* B_ = &B[(n + 0) * K + k_ + off_experts*K*N]; + for (int y=0; y<3; y++) // should this be M_BLOCK? + //bigB[0][k2].B[y].hC = (loadnt((halfC*)(&B_[y * K]))); + bigB[0][k2].B[y].hC = *(((halfC*)(&B_[y * K]))); + } +#pragma unroll + for (uint32_t k2 = 0; k2 < UNRL; k2++) { + uint32_t k = k1 + k2 * THRDS * A_CHUNK; + uint32_t k_ = k + kSprd; + if (k_ >= K) break; + for (int m = 0; m < 3; m++) + { + bigA[m][k2] = *((const bigType*)(&(s[k_-kBase + m*kFit ]))); + } + //} +//#pragma unroll + //for (uint32_t k2 = 0; k2 < UNRL; k2++) { + // uint32_t k = k1 + k2 * THRDS * A_CHUNK; + // uint32_t k_ = k + threadIdx.x * A_CHUNK; + // if (k_ >= K) break; + for (int l=0; l<3; l++) + StgMfma4(l); + } + +///////////////////////////ROUND 2////////////////////////// +#pragma unroll + for (uint32_t k2 = 0; k2 < UNRL; k2++) { + uint32_t k = k1 + k2 * THRDS * A_CHUNK; + uint32_t k_ = k + kSprd; + if (k_ >= K) break; + const half* B_ = &B[(n + 0) * K + k_ + off_experts*K*N]; + for (int y=3; y<6; y++) // should this be M_BLOCK? + //bigB[0][k2].B[y-YTILE/4].hC = (loadnt((halfC*)(&B_[y * K]))); + bigB[0][k2].B[y-3].hC = *(((halfC*)(&B_[y * K]))); + } +#pragma unroll + for (uint32_t k2 = 0; k2 < UNRL; k2++) { + uint32_t k = k1 + k2 * THRDS * A_CHUNK; + uint32_t k_ = k + kSprd; + if (k_ >= K) break; + for (int m = 3; m < 6; m++) + { + bigA[m-3][k2] = *((const bigType*)(&(s[k_-kBase + m*kFit ]))); + } + //} +//#pragma unroll + //for (uint32_t k2 = 0; k2 < UNRL; k2++) { + // uint32_t k = k1 + k2 * THRDS * A_CHUNK; + // uint32_t k_ = k + threadIdx.x * A_CHUNK; + // if (k_ >= K) break; + for (int l=0; l<3; l++) + StgMfma4(l); + } +///////////////////////////ROUND 3////////////////////////// +#pragma unroll + for (uint32_t k2 = 0; k2 < UNRL; k2++) { + uint32_t k = k1 + k2 * THRDS * A_CHUNK; + uint32_t k_ = k + kSprd; + if (k_ >= K) break; + const half* B_ = &B[(n + 0) * K + k_ + off_experts*K*N]; + for (int y=6; y<9; y++) // should this be M_BLOCK? 
+ //bigB[0][k2].B[y-YTILE/2].hC = (loadnt((halfC*)(&B_[y * K]))); + bigB[0][k2].B[y-6].hC = *(((halfC*)(&B_[y * K]))); + } +#pragma unroll + for (uint32_t k2 = 0; k2 < UNRL; k2++) { + uint32_t k = k1 + k2 * THRDS * A_CHUNK; + uint32_t k_ = k + kSprd; + if (k_ >= K) break; + for (int m = 6; m < 9; m++) + { + bigA[m-6][k2] = *((const bigType*)(&(s[k_-kBase + m*kFit ]))); + } + //} +//#pragma unroll + //for (uint32_t k2 = 0; k2 < UNRL; k2++) { + // uint32_t k = k1 + k2 * THRDS * A_CHUNK; + // uint32_t k_ = k + threadIdx.x * A_CHUNK; + // if (k_ >= K) break; + for (int l=0; l<3; l++) + StgMfma4(l); + } + +///////////////////////////ROUND 4////////////////////////// +#pragma unroll + for (uint32_t k2 = 0; k2 < UNRL; k2++) { + uint32_t k = k1 + k2 * THRDS * A_CHUNK; + uint32_t k_ = k + kSprd; + if (k_ >= K) break; + const half* B_ = &B[(n + 0) * K + k_ + off_experts*K*N]; + for (int y=9; y<12; y++) // should this be M_BLOCK? + //bigB[0][k2].B[y-YTILE/2].hC = (loadnt((halfC*)(&B_[y * K]))); + bigB[0][k2].B[y-9].hC = *(((halfC*)(&B_[y * K]))); + } +#pragma unroll + for (uint32_t k2 = 0; k2 < UNRL; k2++) { + uint32_t k = k1 + k2 * THRDS * A_CHUNK; + uint32_t k_ = k + kSprd; + if (k_ >= K) break; + for (int m = 9; m < 12; m++) + { + bigA[m-9][k2] = *((const bigType*)(&(s[k_-kBase + m*kFit ]))); + } + //} +//#pragma unroll + //for (uint32_t k2 = 0; k2 < UNRL; k2++) { + // uint32_t k = k1 + k2 * THRDS * A_CHUNK; + // uint32_t k_ = k + threadIdx.x * A_CHUNK; + // if (k_ >= K) break; + for (int l=0; l<3; l++) + StgMfma4(l); + } + +///////////////////////////ROUND 5////////////////////////// +#pragma unroll + for (uint32_t k2 = 0; k2 < UNRL; k2++) { + uint32_t k = k1 + k2 * THRDS * A_CHUNK; + uint32_t k_ = k + kSprd; + if (k_ >= K) break; + const half* B_ = &B[(n + 0) * K + k_ + off_experts*K*N]; + for (int y=12; y<16; y++) // should this be M_BLOCK? + //bigB[0][k2].B[y-YTILE/2].hC = (loadnt((halfC*)(&B_[y * K]))); + bigB[0][k2].B[y-12].hC = *(((halfC*)(&B_[y * K]))); + } +#pragma unroll + for (uint32_t k2 = 0; k2 < UNRL; k2++) { + uint32_t k = k1 + k2 * THRDS * A_CHUNK; + uint32_t k_ = k + kSprd; + if (k_ >= K) break; + for (int m = 12; m < 16; m++) + { + bigA[m-12][k2] = *((const bigType*)(&(s[k_-kBase + m*kFit ]))); + } + //} +//#pragma unroll + //for (uint32_t k2 = 0; k2 < UNRL; k2++) { + // uint32_t k = k1 + k2 * THRDS * A_CHUNK; + // uint32_t k_ = k + threadIdx.x * A_CHUNK; + // if (k_ >= K) break; + for (int l=0; l<4; l++) + StgMfma4(l); + } + + + + +#elif defined(PIPELINED_556x) //556x + +///////////////////////////ROUND 1////////////////////////// +#pragma unroll + for (uint32_t k2 = 0; k2 < UNRL; k2++) { + uint32_t k = k1 + k2 * THRDS * A_CHUNK; + uint32_t k_ = k + kSprd; + if (k_ >= K) break; + const half* B_ = &B[(n + 0) * K + k_ + off_experts*K*N]; + for (int y=0; y<5; y++) // should this be M_BLOCK? 
+ //bigB[0][k2].B[y].hC = (loadnt((halfC*)(&B_[y * K]))); + bigB[0][k2].B[y].hC = *(((halfC*)(&B_[y * K]))); + } +#pragma unroll + for (uint32_t k2 = 0; k2 < UNRL; k2++) { + uint32_t k = k1 + k2 * THRDS * A_CHUNK; + uint32_t k_ = k + kSprd; + if (k_ >= K) break; + for (int m = 0; m < 5; m++) + { + bigA[m][k2] = *((const bigType*)(&(s[k_-kBase + m*kFit ]))); + } + //} +//#pragma unroll + //for (uint32_t k2 = 0; k2 < UNRL; k2++) { + // uint32_t k = k1 + k2 * THRDS * A_CHUNK; + // uint32_t k_ = k + threadIdx.x * A_CHUNK; + // if (k_ >= K) break; + for (int l=0; l<5; l++) + StgMfma4(l); + //} + +///////////////////////////ROUND 2////////////////////////// +//#pragma unroll + //for (uint32_t k2 = 0; k2 < UNRL; k2++) { + // uint32_t k = k1 + k2 * THRDS * A_CHUNK; + // uint32_t k_ = k + kSprd; + // if (k_ >= K) break; + const half* B_ = &B[(n + 0) * K + k_ + off_experts*K*N]; + for (int y=5; y<10; y++) // should this be M_BLOCK? + //bigB[0][k2].B[y-YTILE/4].hC = (loadnt((halfC*)(&B_[y * K]))); + bigB[0][k2].B[y-5].hC = *(((halfC*)(&B_[y * K]))); + } +#pragma unroll + for (uint32_t k2 = 0; k2 < UNRL; k2++) { + uint32_t k = k1 + k2 * THRDS * A_CHUNK; + uint32_t k_ = k + kSprd; + if (k_ >= K) break; + for (int m = 5; m < 10; m++) + { + bigA[m-5][k2] = *((const bigType*)(&(s[k_-kBase + m*kFit ]))); + } + //} +//#pragma unroll + //for (uint32_t k2 = 0; k2 < UNRL; k2++) { + // uint32_t k = k1 + k2 * THRDS * A_CHUNK; + // uint32_t k_ = k + threadIdx.x * A_CHUNK; + // if (k_ >= K) break; + for (int l=0; l<5; l++) + StgMfma4(l); + //} +///////////////////////////ROUND 3////////////////////////// + //#pragma unroll + //for (uint32_t k2 = 0; k2 < UNRL; k2++) { + // uint32_t k = k1 + k2 * THRDS * A_CHUNK; + // uint32_t k_ = k + kSprd; + // if (k_ >= K) break; + const half* B_ = &B[(n + 0) * K + k_ + off_experts*K*N]; + for (int y=10; y<16; y++) // should this be M_BLOCK? + //bigB[0][k2].B[y-YTILE/2].hC = (loadnt((halfC*)(&B_[y * K]))); + bigB[0][k2].B[y-10].hC = *(((halfC*)(&B_[y * K]))); + } +#pragma unroll + for (uint32_t k2 = 0; k2 < UNRL; k2++) { + uint32_t k = k1 + k2 * THRDS * A_CHUNK; + uint32_t k_ = k + kSprd; + if (k_ >= K) break; + for (int m = 10; m < 16; m++) + { + bigA[m-10][k2] = *((const bigType*)(&(s[k_-kBase + m*kFit ]))); + } + //} +//#pragma unroll + //for (uint32_t k2 = 0; k2 < UNRL; k2++) { + // uint32_t k = k1 + k2 * THRDS * A_CHUNK; + // uint32_t k_ = k + threadIdx.x * A_CHUNK; + // if (k_ >= K) break; + for (int l=0; l<6; l++) + StgMfma4(l); + } + +#elif defined(PIPELINED8x) //8x + +///////////////////////////ROUND 1////////////////////////// +#pragma unroll + for (uint32_t k2 = 0; k2 < UNRL; k2++) { + uint32_t k = k1 + k2 * THRDS * A_CHUNK; + uint32_t k_ = k + kSprd; + if (k_ >= K) break; + const half* B_ = &B[(n + 0) * K + k_ + off_experts*K*N]; + for (int y=0; y= K) break; + for (int m = 0; m < M_BLOCK/8; m++) + { + bigA[m][k2] = *((const bigType*)(&(s[k_-kBase + m*kFit ]))); + } + } +#pragma unroll + for (uint32_t k2 = 0; k2 < UNRL; k2++) { + uint32_t k = k1 + k2 * THRDS * A_CHUNK; + uint32_t k_ = k + kSprd; + if (k_ >= K) break; + for (int l=0; l= K) break; + const half* B_ = &B[(n + 0) * K + k_ + off_experts*K*N]; + for (int y=YTILE/8; y<2*YTILE/8; y++) // should this be M_BLOCK? 
+ //bigB[0][k2].B[y-YTILE/8].hC = (loadnt((halfC*)(&B_[y * K]))); + bigB[0][k2].B[y-YTILE/8].hC = *(((halfC*)(&B_[y * K]))); + } +#pragma unroll + for (uint32_t k2 = 0; k2 < UNRL; k2++) { + uint32_t k = k1 + k2 * THRDS * A_CHUNK; + uint32_t k_ = k + kSprd; + if (k_ >= K) break; + for (int m = M_BLOCK/8; m < 2*M_BLOCK/8; m++) + { + bigA[m-M_BLOCK/8][k2] = *((const bigType*)(&(s[k_-kBase + m*kFit ]))); + } + } +#pragma unroll + for (uint32_t k2 = 0; k2 < UNRL; k2++) { + uint32_t k = k1 + k2 * THRDS * A_CHUNK; + uint32_t k_ = k + kSprd; + if (k_ >= K) break; + for (int l=0; l= K) break; + const half* B_ = &B[(n + 0) * K + k_ + off_experts*K*N]; + for (int y=2*YTILE/8; y<3*YTILE/8; y++) // should this be M_BLOCK? + //bigB[0][k2].B[y-2*YTILE/8].hC = (loadnt((halfC*)(&B_[y * K]))); + bigB[0][k2].B[y-2*YTILE/8].hC = *(((halfC*)(&B_[y * K]))); + } +#pragma unroll + for (uint32_t k2 = 0; k2 < UNRL; k2++) { + uint32_t k = k1 + k2 * THRDS * A_CHUNK; + uint32_t k_ = k + kSprd; + if (k_ >= K) break; + for (int m = 2*M_BLOCK/8; m < 3*M_BLOCK/8; m++) + { + bigA[m-2*M_BLOCK/8][k2] = *((const bigType*)(&(s[k_-kBase + m*kFit ]))); + } + } +#pragma unroll + for (uint32_t k2 = 0; k2 < UNRL; k2++) { + uint32_t k = k1 + k2 * THRDS * A_CHUNK; + uint32_t k_ = k + kSprd; + if (k_ >= K) break; + for (int l=0; l= K) break; + const half* B_ = &B[(n + 0) * K + k_ + off_experts*K*N]; + for (int y=3*YTILE/8; y<4*YTILE/8; y++) // should this be M_BLOCK? + //bigB[0][k2].B[y-3*YTILE/8].hC = (loadnt((halfC*)(&B_[y * K]))); + bigB[0][k2].B[y-3*YTILE/8].hC = *(((halfC*)(&B_[y * K]))); + } +#pragma unroll + for (uint32_t k2 = 0; k2 < UNRL; k2++) { + uint32_t k = k1 + k2 * THRDS * A_CHUNK; + uint32_t k_ = k + kSprd; + if (k_ >= K) break; + for (int m = 3*M_BLOCK/8; m < 4*M_BLOCK/8; m++) + { + bigA[m-3*M_BLOCK/8][k2] = *((const bigType*)(&(s[k_-kBase + m*kFit ]))); + } + } +#pragma unroll + for (uint32_t k2 = 0; k2 < UNRL; k2++) { + uint32_t k = k1 + k2 * THRDS * A_CHUNK; + uint32_t k_ = k + kSprd; + if (k_ >= K) break; + for (int l=0; l= K) break; + const half* B_ = &B[(n + 0) * K + k_ + off_experts*K*N]; + for (int y=4*YTILE/8; y<5*YTILE/8; y++) // should this be M_BLOCK? + //bigB[0][k2].B[y-4*YTILE/8].hC = (loadnt((halfC*)(&B_[y * K]))); + bigB[0][k2].B[y-4*YTILE/8].hC = *(((halfC*)(&B_[y * K]))); + } +#pragma unroll + for (uint32_t k2 = 0; k2 < UNRL; k2++) { + uint32_t k = k1 + k2 * THRDS * A_CHUNK; + uint32_t k_ = k + kSprd; + if (k_ >= K) break; + for (int m = 4*M_BLOCK/8; m < 5*M_BLOCK/8; m++) + { + bigA[m-4*M_BLOCK/8][k2] = *((const bigType*)(&(s[k_-kBase + m*kFit ]))); + } + } +#pragma unroll + for (uint32_t k2 = 0; k2 < UNRL; k2++) { + uint32_t k = k1 + k2 * THRDS * A_CHUNK; + uint32_t k_ = k + kSprd; + if (k_ >= K) break; + for (int l=0; l= K) break; + const half* B_ = &B[(n + 0) * K + k_ + off_experts*K*N]; + for (int y=5*YTILE/8; y<6*YTILE/8; y++) // should this be M_BLOCK? 
+ //bigB[0][k2].B[y-5*YTILE/8].hC = (loadnt((halfC*)(&B_[y * K]))); + bigB[0][k2].B[y-5*YTILE/8].hC = *(((halfC*)(&B_[y * K]))); + } +#pragma unroll + for (uint32_t k2 = 0; k2 < UNRL; k2++) { + uint32_t k = k1 + k2 * THRDS * A_CHUNK; + uint32_t k_ = k + kSprd; + if (k_ >= K) break; + for (int m = 5*M_BLOCK/8; m < 6*M_BLOCK/8; m++) + { + bigA[m-5*M_BLOCK/8][k2] = *((const bigType*)(&(s[k_-kBase + m*kFit ]))); + } + //} +//#pragma unroll + //for (uint32_t k2 = 0; k2 < UNRL; k2++) { + // uint32_t k = k1 + k2 * THRDS * A_CHUNK; + // uint32_t k_ = k + threadIdx.x * A_CHUNK; + // if (k_ >= K) break; + for (int l=0; l= K) break; + const half* B_ = &B[(n + 0) * K + k_ + off_experts*K*N]; + for (int y=6*YTILE/8; y<7*YTILE/8; y++) // should this be M_BLOCK? + //bigB[0][k2].B[y-6*YTILE/8].hC = (loadnt((halfC*)(&B_[y * K]))); + bigB[0][k2].B[y-6*YTILE/8].hC = *(((halfC*)(&B_[y * K]))); + } +#pragma unroll + for (uint32_t k2 = 0; k2 < UNRL; k2++) { + uint32_t k = k1 + k2 * THRDS * A_CHUNK; + uint32_t k_ = k + kSprd; + if (k_ >= K) break; + for (int m = 6*M_BLOCK/8; m < 7*M_BLOCK/8; m++) + { + bigA[m-6*M_BLOCK/8][k2] = *((const bigType*)(&(s[k_-kBase + m*kFit ]))); + } + //} +//#pragma unroll + //for (uint32_t k2 = 0; k2 < UNRL; k2++) { + // uint32_t k = k1 + k2 * THRDS * A_CHUNK; + // uint32_t k_ = k + threadIdx.x * A_CHUNK; + // if (k_ >= K) break; + for (int l=0; l= K) break; + const half* B_ = &B[(n + 0) * K + k_ + off_experts*K*N]; + for (int y=7*YTILE/8; y<8*YTILE/8; y++) // should this be M_BLOCK? + //bigB[0][k2].B[y-7*YTILE/8].hC = (loadnt((halfC*)(&B_[y * K]))); + bigB[0][k2].B[y-7*YTILE/8].hC = *(((halfC*)(&B_[y * K]))); + } +#pragma unroll + for (uint32_t k2 = 0; k2 < UNRL; k2++) { + uint32_t k = k1 + k2 * THRDS * A_CHUNK; + uint32_t k_ = k + kSprd; + if (k_ >= K) break; + for (int m = 7*M_BLOCK/8; m < 8*M_BLOCK/8; m++) + { + bigA[m-7*M_BLOCK/8][k2] = *((const bigType*)(&(s[k_-kBase + m*kFit ]))); + } + //} +//#pragma unroll + //for (uint32_t k2 = 0; k2 < UNRL; k2++) { + // uint32_t k = k1 + k2 * THRDS * A_CHUNK; + // uint32_t k_ = k + threadIdx.x * A_CHUNK; + // if (k_ >= K) break; + for (int l=0; l= K) break; + + const half* B_ = &B[(n + 0) * K + k_ + off_experts*K*N]; +#ifdef USEMFMA + for (int y=0; y= K) break; + + // Fetch A activation matrix in interleaved fashion from LDS or memory + + for (int m = 0; m < M_BLOCK; m++) + { +#ifdef USEMFMA +#else + if (!token_mask[m]) continue; +#endif + if (PCML) { + //bigA[m][k2] = *((const bigType*)(&(s[k_-kBase + kFit*m]))); + // skip A[] fetches for Ms that are disabled + bigA[m][k2] = *((const bigType*)(&(s[k_-kBase + m*kFit ]))); + } else { + int aidx = k_ + (offs_token[m]/top_k) * K; + if (aidx + A_CHUNK <= 32 * 1024) + bigA[m][k2] = *((const bigType*)(&(s[aidx]))); + else + bigA[m][k2] = *((const bigType*)(&(A[aidx]))); + } + } + } + + // Do the matrix multiplication in interleaved manner +#pragma unroll + for (uint32_t k2 = 0; k2 < UNRL; k2++) { + uint32_t k = k1 + k2 * THRDS * A_CHUNK; + uint32_t k_ = k + threadIdx.x * A_CHUNK; + if (k_ >= K) break; + +#ifdef USEMFMA + bigType stgB; + for (int l=0; l= N) { + uint32_t startColumn = N - YTILE; + for (uint32_t i = 0; i < (n - startColumn); i++) { + commitColumn[i] = 0; + } + n = startColumn; + } + } +} + + + +// a = torch.randn((m, k) +// b1 = torch.randn((e, 2 * n, k) +// b2 = torch.randn((e, k, n) +// topk_weights = torch.randn((m, e), device='cuda', dtype=dtype) + +void wvSpltK_fsdMoe_(void* in_a, void* in_b, void* out_c, + void* topk_weights, + void* topk_ids, + void* sorted_token_ids, + 
void* expert_ids, + void* num_tokens_post_padded, + const int M_in, const int N_in, const int K_in, const int E, + const int num_valid_tokens, + const int stride_am, + const int stride_ak, + const int stride_be, + const int stride_bk, + const int stride_bn, + const int stride_cm, + const int stride_cn, + const int m_blck_sz, + const bool mul_routed_weight, + const int top_k, + cudaStream_t stream, const int CuCount) { + dim3 grid(CuCount); + dim3 block(THRDS, WvPrGrp); + auto* a = reinterpret_cast(in_a); + auto* b = reinterpret_cast(in_b); + auto* c = reinterpret_cast(out_c); + auto* topk_weights_ = reinterpret_cast(topk_weights); + auto* topk_ids_ = reinterpret_cast(topk_ids); + auto* sorted_token_ids_ = reinterpret_cast(sorted_token_ids); + auto* expert_ids_ = reinterpret_cast(expert_ids); + auto* num_tokens_post_padded_ = reinterpret_cast(num_tokens_post_padded); + switch (m_blck_sz) { + case 1: + wvSpltK_fsdMoe_hf_<1,4><<>>(a, b, c, topk_weights_, topk_ids_, sorted_token_ids_, expert_ids_, num_tokens_post_padded_, M_in, N_in, K_in, E, num_valid_tokens, stride_am, stride_ak, stride_be, stride_bk, stride_bn, stride_cm, stride_cn, mul_routed_weight, top_k, CuCount); + break; + case 2: + wvSpltK_fsdMoe_hf_<2,4><<>>(a, b, c, topk_weights_, topk_ids_, sorted_token_ids_, expert_ids_, num_tokens_post_padded_, M_in, N_in, K_in, E, num_valid_tokens, stride_am, stride_ak, stride_be, stride_bk, stride_bn, stride_cm, stride_cn, mul_routed_weight, top_k, CuCount); + break; + case 3: + wvSpltK_fsdMoe_hf_<3,4><<>>(a, b, c, topk_weights_, topk_ids_, sorted_token_ids_, expert_ids_, num_tokens_post_padded_, M_in, N_in, K_in, E, num_valid_tokens, stride_am, stride_ak, stride_be, stride_bk, stride_bn, stride_cm, stride_cn, mul_routed_weight, top_k, CuCount); + break; + case 4: + wvSpltK_fsdMoe_hf_<4,4><<>>(a, b, c, topk_weights_, topk_ids_, sorted_token_ids_, expert_ids_, num_tokens_post_padded_, M_in, N_in, K_in, E, num_valid_tokens, stride_am, stride_ak, stride_be, stride_bk, stride_bn, stride_cm, stride_cn, mul_routed_weight, top_k, CuCount); + break; + case 5: + wvSpltK_fsdMoe_hf_<5,4><<>>(a, b, c, topk_weights_, topk_ids_, sorted_token_ids_, expert_ids_, num_tokens_post_padded_, M_in, N_in, K_in, E, num_valid_tokens, stride_am, stride_ak, stride_be, stride_bk, stride_bn, stride_cm, stride_cn, mul_routed_weight, top_k, CuCount); + break; + case 6: + wvSpltK_fsdMoe_hf_<6,4><<>>(a, b, c, topk_weights_, topk_ids_, sorted_token_ids_, expert_ids_, num_tokens_post_padded_, M_in, N_in, K_in, E, num_valid_tokens, stride_am, stride_ak, stride_be, stride_bk, stride_bn, stride_cm, stride_cn, mul_routed_weight, top_k, CuCount); + break; + case 16: + wvSpltK_fsdMoe_hf_mfma16_<16,16><<>>(a, b, c, topk_weights_, topk_ids_, sorted_token_ids_, expert_ids_, num_tokens_post_padded_, M_in, N_in, K_in, E, num_valid_tokens, stride_am, stride_ak, stride_be, stride_bk, stride_bn, stride_cm, stride_cn, mul_routed_weight, top_k, CuCount); + break; + + } +} + void wvSpltK_(void* in_a, void* in_b, void* out_c, const int M_in, const int K_in, const int N_in, cudaStream_t stream, const int CuCount = 0) { diff --git a/vllm/envs.py b/vllm/envs.py index 681885e312145..281478283c244 100644 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -46,6 +46,8 @@ VLLM_MOE_SHUFFLE: bool = False FUSED_MOE_PERSISTENT: bool = False + VLLM_MOE_MFMASWIZZLE: bool = True + VLLM_MOE_MFMASWIZZLE_M_THRSHLD: int = 32 # The begin-* and end* here are used by the documentation generator # to extract the used env vars. 
@@ -256,6 +258,13 @@ # User persistent version of fused_moe Triton kernel "FUSED_MOE_PERSISTENT": lambda: bool(int(os.getenv("FUSED_MOE_PERSISTENT", "0"))), + + # hashem's swizzle + # Swizzle the weights for mfma ops in moe kernel, or not + "VLLM_MOE_MFMASWIZZLE": + lambda: bool(int(os.getenv("VLLM_MOE_MFMASWIZZLE", "1"))), + "VLLM_MOE_MFMASWIZZLE_M_THRSHLD": + lambda: int(os.getenv("VLLM_MOE_MFMASWIZZLE_M_THRSHLD", "32")), } # end-env-vars-definition diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index 7a348758aae06..7c009e8745006 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -13,6 +13,8 @@ from vllm import _custom_ops as ops from vllm.logger import init_logger +from vllm import _custom_C + logger = init_logger(__name__) padding_size = 128 if envs.VLLM_MOE_PADDING else 0 @@ -405,6 +407,42 @@ def moe_align_block_size( ) return sorted_ids, expert_ids, num_tokens_post_pad +def invoke_mega_fused_moe_kernel(A: torch.Tensor, B: torch.Tensor, C: torch.Tensor, + topk_weights: torch.Tensor, topk_ids: torch.Tensor, + sorted_token_ids: torch.Tensor, + expert_ids: torch.Tensor, + num_tokens_post_padded: torch.Tensor, + m_blck_sz: int, mul_routed_weight: bool, top_k: int, + use_fp8: bool) -> None: + assert topk_weights.stride(1) == 1 + assert sorted_token_ids.stride(0) == 1 + + #print("\nm=",A.shape[0],"n=",B.shape[1],"k=",B.shape[2],"e=", B.shape[0], "ml_rt:",mul_routed_weight,"tpk",top_k, "\n") + _custom_C.wvSpltK_fsdMoe(#A, B, C, B.shape[1], 80) + A, + B, + C, + topk_weights, + topk_ids, + sorted_token_ids, + expert_ids, + num_tokens_post_padded, + A.shape[0], + B.shape[1], + B.shape[2] - padding_size, + B.shape[0], + topk_ids.numel(), + A.stride(0), + A.stride(1), + B.stride(0), + B.stride(2), + B.stride(1), + C.stride(1), + C.stride(2), + m_blck_sz, + mul_routed_weight, + top_k, + 80) def invoke_fused_moe_kernel(A: torch.Tensor, B: torch.Tensor, C: torch.Tensor, A_scale: Optional[torch.Tensor], @@ -623,7 +661,7 @@ def fused_experts(hidden_states: torch.Tensor, "GROUP_SIZE_M": 1, "num_stages": 1, } - + intermediate_cache1 = torch.empty( (M, topk_ids.shape[1], N), device=hidden_states.device, @@ -640,12 +678,87 @@ def fused_experts(hidden_states: torch.Tensor, dtype=hidden_states.dtype, ) - sorted_token_ids, expert_ids, num_tokens_post_padded = moe_align_block_size( - topk_ids, config['BLOCK_SIZE_M'], E) compute_type = (tl.bfloat16 if hidden_states.dtype == torch.bfloat16 else tl.float16) + #print(hidden_states.shape) + #print(intermediate_cache2.shape) + #print("M1:", hidden_states.shape[0], "M2:", intermediate_cache2.shape[0]) + #if hidden_states.shape[0] <= 256 and hidden_states.shape[1] % 8 == 0 and intermediate_cache2.shape[0] <= 256 and not use_fp8 : + + #WVSPLTK_M_THRSHLD = 64 + #if hidden_states.shape[0] <= WVSPLTK_M_THRSHLD \ + # and hidden_states.shape[1] % 8 == 0 \ + # and intermediate_cache2.shape[0] <= WVSPLTK_M_THRSHLD \ + # and intermediate_cache2.shape[1] % 8 == 0 \ + # and not use_fp8 : + if envs.VLLM_MOE_MFMASWIZZLE and M<=envs.VLLM_MOE_MFMASWIZZLE_M_THRSHLD: + assert(compute_type == tl.float16, "Only fp16 supported for wvSplitK_mfma16x16 for now") + #m_blck_sz = -(-(M*topk_ids.shape[1]*3)//E) # target 75% of expert distribution for this M size + #if (m_blck_sz >= 12): + # m_blck_sz = 16 + + # all calls go to wvSplitK_mfma16x16 + m_blck_sz = 16 # TODO: this is for decode stage, need another for prefill + #print("M:", M, " M_BLOCK PICKED:", 
m_blck_sz) + sorted_token_ids, expert_ids, num_tokens_post_padded = moe_align_block_size( + topk_ids, m_blck_sz, E) # target 75% of expert distribution for this M size + #topk_ids, config2['BLOCK_SIZE_M'],E) + #print("\nsrtd_tkn:", sorted_token_ids) + #print("w1Shape:",w1.shape) + + #env VLLM_MOE_MFMASWIZZLE does this swizzle on init + w1_ = w1 + w2_ = w2 + if not envs.VLLM_MOE_MFMASWIZZLE : # for debug only + if m_blck_sz >= 16 : + w1_ = torch.clone(w1) + w1_ = w1_.view(w1.shape[0], w1.shape[1]//16, 16, w1.shape[2]//128, 16, 8); + w1_ = w1_.permute(0, 1, 4, 3, 2, 5) + w1_ = w1_.contiguous() + w1_ = w1_.view(w1.shape[0],w1.shape[1],w1.shape[2]); + w2_ = torch.clone(w2) + w2_ = w2_.view(w2.shape[0], w2.shape[1]//16, 16, w2.shape[2]//128, 16, 8); + w2_ = w2_.permute(0, 1, 4, 3, 2, 5) + w2_ = w2_.contiguous() + w2_ = w2_.view(w2.shape[0],w2.shape[1],w2.shape[2]); + + #print(w1_) + + invoke_mega_fused_moe_kernel(hidden_states, + w1_, + intermediate_cache1, + topk_weights, + topk_ids, + sorted_token_ids, + expert_ids, + num_tokens_post_padded, + m_blck_sz, + False, + topk_ids.shape[1], + use_fp8=use_fp8) + #print("shdr_invk1:",intermediate_cache1.view(-1, N)) + ops.silu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, N)) + #print("shdr_silu:",intermediate_cache2) + #print("shdr_silu_shape:", intermediate_cache2.shape) + #print("-----------------------------") + + invoke_mega_fused_moe_kernel(intermediate_cache2, + w2_, + intermediate_cache3, + topk_weights, + topk_ids, + sorted_token_ids, + expert_ids, + num_tokens_post_padded, + m_blck_sz, + True, + 1, + use_fp8=use_fp8) - invoke_fused_moe_kernel(hidden_states, + else: + sorted_token_ids, expert_ids, num_tokens_post_padded = moe_align_block_size( + topk_ids, config['BLOCK_SIZE_M'], E) + invoke_fused_moe_kernel(hidden_states, w1, intermediate_cache1, a1_scale, @@ -661,9 +774,9 @@ def fused_experts(hidden_states: torch.Tensor, compute_type=compute_type, use_fp8=use_fp8) - ops.silu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, N)) + ops.silu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, N)) - invoke_fused_moe_kernel(intermediate_cache2, + invoke_fused_moe_kernel(intermediate_cache2, w2, intermediate_cache3, a2_scale, diff --git a/vllm/model_executor/models/mixtral.py b/vllm/model_executor/models/mixtral.py index 2a525970028bb..f974028157a74 100644 --- a/vllm/model_executor/models/mixtral.py +++ b/vllm/model_executor/models/mixtral.py @@ -615,4 +615,11 @@ def permute_weight(x: torch.Tensor) -> torch.Tensor: x_ = x_.permute(0,1,3,4,2,5) x_ = x_.contiguous() x_ = x_.view(x.shape[0], x.shape[1], x.shape[2]); + elif envs.VLLM_MOE_MFMASWIZZLE: # hashem's swizzle + x_ = x_.view(x.shape[0], + x.shape[1]//16, 16, + x.shape[2]//128, 16, 8) + x_ = x_.permute(0,1,4,3,2,5) + x_ = x_.contiguous() + x_ = x_.view(x.shape[0], x.shape[1], x.shape[2]); return x_ From 8148b54e92624dcb3571ada4602bb2efdf841254 Mon Sep 17 00:00:00 2001 From: Douglas Lehr Date: Sun, 11 Aug 2024 02:39:02 -0500 Subject: [PATCH 21/22] Add batched prefill via VLLM_SCHED_PREFILL_COUNT To ensure we we don't run prefills repeatedly during decode, provide a mechanism to queue up a certain number of prefills before executing. VLLM_SCHED_PREFILL_COUNT will be the minimum batch count to specify before executing. One caveat, the --scheduler-delay-factor should be used to enforce a longer prefill scheduling value. This will be set to the value in VLLM_SCHED_PREFILL_COUNT, if not explicitly provided. 
The need for this exists because an uneven
number of prefills can lead to the queue never reaching
VLLM_SCHED_PREFILL_COUNT, causing the server to hang.
---
 vllm/core/scheduler.py | 17 ++++++++++++++---
 1 file changed, 14 insertions(+), 3 deletions(-)

diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py
index 7c70b1b244f7d..67daabc9b0fd3 100644
--- a/vllm/core/scheduler.py
+++ b/vllm/core/scheduler.py
@@ -23,6 +23,9 @@
 ARTIFICIAL_PREEMPTION_PROB = 0.5
 ARTIFICIAL_PREEMPTION_MAX_CNT = 500
 
+VLLM_SCHED_PREFILL_COUNT = int(
+    os.getenv("VLLM_SCHED_PREFILL_COUNT", 0))  # noqa
+
 
 class PreemptionMode(enum.Enum):
     """Preemption modes.
@@ -263,7 +266,15 @@ def __init__(
         # simple and NOT fair. It can lead to starvation of some
         # LoRAs. This should be improved in the future.
         self.lora_config = lora_config
-
+        self.prefill_timeout = 0
+
+        # Slightly hacky: if you specify a prefill batch count, the delay factor
+        # must also be set, otherwise we will always skip. It defaults to
+        # VLLM_SCHED_PREFILL_COUNT, as the two should be roughly the same.
+        # Recommended: set --scheduler-delay-factor explicitly on the command
+        # line and experiment.
+        if VLLM_SCHED_PREFILL_COUNT > 0 and self.scheduler_config.delay_factor == 0:
+            self.scheduler_config.delay_factor = VLLM_SCHED_PREFILL_COUNT
         version = "v1"
         if self.scheduler_config.use_v2_block_manager:
             version = "v2"
@@ -644,7 +655,8 @@ def _schedule_prefills(
         waiting_queue = deque([s for s in waiting_queue])
 
         leftover_waiting_sequences: Deque[SequenceGroup] = deque()
-        while self._passed_delay(time.time()) and waiting_queue:
+
+        while ((VLLM_SCHED_PREFILL_COUNT > 0 and VLLM_SCHED_PREFILL_COUNT <= len(waiting_queue)) or self._passed_delay(time.time())) and waiting_queue:
             seq_group = waiting_queue[0]
 
             waiting_seqs = seq_group.get_seqs(status=SequenceStatus.WAITING)
@@ -719,7 +731,6 @@ def _schedule_prefills(
         waiting_queue.extendleft(leftover_waiting_sequences)
         if len(seq_groups) > 0:
             self.prev_prompt = True
-
         return waiting_queue, SchedulerPrefillOutputs(
             seq_groups=seq_groups,
             ignored_seq_groups=ignored_seq_groups,

From b9f05ffe28a76f71cf6d346e2a821c1e1fce8aba Mon Sep 17 00:00:00 2001
From: carlushuang
Date: Sun, 11 Aug 2024 23:20:16 +0800
Subject: [PATCH 22/22] add script for decode

---
 .../kernels/benchmark_mixtral_moe_decode.py   | 255 ++++++++++++++++++
 1 file changed, 255 insertions(+)
 create mode 100644 benchmarks/kernels/benchmark_mixtral_moe_decode.py

diff --git a/benchmarks/kernels/benchmark_mixtral_moe_decode.py b/benchmarks/kernels/benchmark_mixtral_moe_decode.py
new file mode 100644
index 0000000000000..30f2b182738bb
--- /dev/null
+++ b/benchmarks/kernels/benchmark_mixtral_moe_decode.py
@@ -0,0 +1,255 @@
+import argparse
+import json
+import os
+import sys
+
+import torch
+import torch.nn.functional as F
+import triton
+from tqdm import tqdm
+from vllm import envs
+from torch import nn
+from vllm.model_executor.layers.fused_moe import (fused_moe,
+                                                  get_config_file_name)
+
+
+def main(model, tp_size, gpu, dtype: str):
+    os.environ['CUDA_VISIBLE_DEVICES'] = str(gpu)
+    method = fused_moe
+    # for bs in [
+    #         1, 2, 4, 8, 16, 24, 32, 48, 64, 96, 128, 256, 512, 1024, 1536,
+    #         2048, 3072, 4096
+    # ]:
+    for bs in [8, 16, 32, 64, 96, 112, 120, 128]:
+        run_grid(bs,
+                 model=model,
+                 method=method,
+                 gpu=gpu,
+                 tp_size=tp_size,
+                 dtype=dtype)
+
+
+def run_grid(bs, model, method, gpu, tp_size, dtype: str):
+    if model == '8x7B':
+        d_model = 4096
+        model_intermediate_size = 14336
+        num_layers = 32
+    elif model == '8x22B':
+        d_model = 6144
+        model_intermediate_size = 16384
+        num_layers = 56
+ 
else: + raise ValueError(f'Unsupported Mixtral model {model}') + num_total_experts = 8 + top_k = 2 + # tp_size = 2 + num_calls = 100 + + num_warmup_trials = 1 + num_trials = 1 + + configs = [] + + for block_size_n in [32, 64, 128, 256]: + for block_size_m in [16, 32, 64, 128, 256]: + for block_size_k in [64, 128, 256]: + for group_size_m in [1, 16, 32, 64]: + for num_warps in [4, 8]: + for num_stages in [2, 3, 4, 5]: + configs.append({ + "BLOCK_SIZE_M": block_size_m, + "BLOCK_SIZE_N": block_size_n, + "BLOCK_SIZE_K": block_size_k, + "GROUP_SIZE_M": group_size_m, + "num_warps": num_warps, + "num_stages": num_stages, + }) + + best_config = None + best_time_us = 1e20 + + print(f'{tp_size=} {bs=}') + + # for config in tqdm(configs): + if 1: + # warmup + try: + for _ in range(num_warmup_trials): + run_timing( + num_calls=num_calls, + bs=bs, + d_model=d_model, + num_total_experts=num_total_experts, + top_k=top_k, + tp_size=tp_size, + model_intermediate_size=model_intermediate_size, + method=method, + config=None, + dtype=dtype, + ) + except triton.runtime.autotuner.OutOfResources: + #continue + pass + + # trial + for _ in range(num_trials): + kernel_dur_ms = run_timing( + num_calls=num_calls, + bs=bs, + d_model=d_model, + num_total_experts=num_total_experts, + top_k=top_k, + tp_size=tp_size, + model_intermediate_size=model_intermediate_size, + method=method, + config=None, + dtype=dtype, + ) + + kernel_dur_us = 1000 * kernel_dur_ms + model_dur_ms = kernel_dur_ms * num_layers + + if kernel_dur_us < best_time_us: + # best_config = config + best_time_us = kernel_dur_us + tqdm.write( + f'{kernel_dur_us=:.1f} {model_dur_ms=:.1f}' + f' {bs=} {tp_size=} {top_k=} {num_total_experts=} ' + f'{d_model=} {model_intermediate_size=} {num_layers=}') + + print("best_time_us", best_time_us) + print("best_config", best_config) + + # holds Dict[str, Dict[str, int]] + # filename = get_config_file_name(num_total_experts, + # model_intermediate_size // tp_size, + # "float8" if dtype == "float8" else None) + # print(f"writing config to file {filename}") + # existing_content = {} + # if os.path.exists(filename): + # with open(filename, "r") as f: + # existing_content = json.load(f) + # existing_content[str(bs)] = best_config + # with open(filename, "w") as f: + # json.dump(existing_content, f, indent=4) + # f.write("\n") + + +def run_timing(num_calls: int, bs: int, d_model: int, num_total_experts: int, + top_k: int, tp_size: int, model_intermediate_size: int, method, + config, dtype: str) -> float: + shard_intermediate_size = model_intermediate_size // tp_size + + hidden_states = torch.rand( + (bs, d_model), + device="cuda:0", + dtype=torch.float16, + ) + + w1 = torch.rand( + (num_total_experts, 2 * shard_intermediate_size, d_model), + device=hidden_states.device, + dtype=hidden_states.dtype, + ) + + w2 = torch.rand( + (num_total_experts, d_model, shard_intermediate_size), + device=hidden_states.device, + dtype=hidden_states.dtype, + ) + + if envs.VLLM_MOE_PADDING: + w1 = nn.Parameter(F.pad(w1.data, + (0, 128), "constant", 0), + requires_grad=False) + torch.cuda.empty_cache() + w2 = nn.Parameter(F.pad(w2.data, + (0, 128), "constant", 0), + requires_grad=False) + torch.cuda.empty_cache() + + w1_scale = None + w2_scale = None + a1_scale = None + a2_scale = None + + if dtype == "float8": + w1 = w1.to(torch.float8_e4m3fn) + w2 = w2.to(torch.float8_e4m3fn) + w1_scale = torch.ones(num_total_experts, + device=hidden_states.device, + dtype=torch.float32) + w2_scale = torch.ones(num_total_experts, + device=hidden_states.device, 
+                              dtype=torch.float32)
+        a1_scale = torch.ones(1,
+                              device=hidden_states.device,
+                              dtype=torch.float32)
+        a2_scale = torch.ones(1,
+                              device=hidden_states.device,
+                              dtype=torch.float32)
+
+    gating_output = F.softmax(torch.rand(
+        (num_calls, bs, num_total_experts),
+        device=hidden_states.device,
+        dtype=torch.float32,
+    ),
+                              dim=-1)
+
+    start_event = torch.cuda.Event(enable_timing=True)
+    end_event = torch.cuda.Event(enable_timing=True)
+
+    start_event.record()
+    for i in range(num_calls):
+        hidden_states = method(
+            hidden_states=hidden_states,
+            w1=w1,
+            w2=w2,
+            w1_scale=w1_scale,
+            w2_scale=w2_scale,
+            a1_scale=a1_scale,
+            a2_scale=a2_scale,
+            gating_output=gating_output[i],
+            topk=top_k,
+            renormalize=True,
+            inplace=True,
+            override_config=config,
+            use_fp8=dtype == "float8",
+        )
+    end_event.record()
+    end_event.synchronize()
+
+
+    # torch_output = torch_moe(a, w1, w2, score, topk)
+
+    dur_ms = start_event.elapsed_time(end_event) / num_calls
+    return dur_ms
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(
+        prog='benchmark_mixtral_moe',
+        description='Benchmark and tune the fused_moe kernel',
+    )
+    parser.add_argument(
+        '--dtype',
+        type=str,
+        default='auto',
+        choices=['float8', 'float16'],
+        help='Data type used for fused_moe kernel computations',
+    )
+    parser.add_argument('--model',
+                        type=str,
+                        default='8x7B',
+                        choices=['8x7B', '8x22B'],
+                        help='The Mixtral model to benchmark')
+    parser.add_argument('--tp-size',
+                        type=int,
+                        default=2,
+                        help='Tensor parallel size')
+    parser.add_argument('--gpu',
+                        type=int,
+                        default=0,
+                        help="GPU ID for benchmarking")
+    args = parser.parse_args()
+    sys.exit(main(args.model, args.tp_size, args.gpu, args.dtype))
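
The weight swizzle gated by VLLM_MOE_MFMASWIZZLE (the view/permute/contiguous sequence that appears both in permute_weight and in the debug-only path of fused_experts) can be read as a standalone layout transform. The sketch below is illustrative only: the helper name mfma_swizzle and the toy sizes are not part of these patches, and it simply assumes expert weights of shape (E, N, K) with N a multiple of 16 and K a multiple of 128, matching the 16x16x8 tiling used above.

import torch


def mfma_swizzle(x: torch.Tensor) -> torch.Tensor:
    # Reorder (E, N, K) expert weights into the 16x16x8 tiled layout used by
    # the mfma16x16 path. This is a pure permutation of elements: same data,
    # different memory order.
    E, N, K = x.shape
    assert N % 16 == 0 and K % 128 == 0, "layout assumes 16/128 tile multiples"
    x_ = x.view(E, N // 16, 16, K // 128, 16, 8)
    x_ = x_.permute(0, 1, 4, 3, 2, 5).contiguous()
    return x_.view(E, N, K)


if __name__ == "__main__":
    # Toy sizes only; real Mixtral weights are e.g. (8, 2 * 14336 // tp, 4096).
    e, n, k = 2, 32, 256
    idx = torch.arange(e * n * k).view(e, n, k)
    swz = mfma_swizzle(idx)
    assert not torch.equal(idx, swz)  # layout actually changed
    assert torch.equal(swz.flatten().sort().values, idx.flatten())  # no elements lost

Because the transform only reorders elements, it can be applied once when the weights are loaded (which is what the patches do when VLLM_MOE_MFMASWIZZLE is set) rather than on every forward pass; the in-line clone-and-permute branch in fused_experts exists only for debugging with unswizzled checkpoints.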