diff --git a/benchmarks/kernels/benchmark_mixtral_moe_rocm.py b/benchmarks/kernels/benchmark_mixtral_moe_rocm.py old mode 100755 new mode 100644 index 63080eaf2f11c..ba5f77e4463ec --- a/benchmarks/kernels/benchmark_mixtral_moe_rocm.py +++ b/benchmarks/kernels/benchmark_mixtral_moe_rocm.py @@ -2,184 +2,57 @@ import json import os import sys +import time import torch +import torch.distributed as dist +import torch.multiprocessing as mp import torch.nn.functional as F -import triton import triton.language as tl +from natsort import natsorted from tqdm import tqdm +from tuning_utils import (get_full_tuning_space, prune_configs, + union_of_list_of_dicts) -import vllm._moe_C as moe_kernels -from vllm._C import ops -from vllm.model_executor.layers.fused_moe import (get_config_file_name, +from vllm import _custom_ops as ops +from vllm.model_executor.layers.fused_moe import (fused_topk, + get_config_file_name, invoke_fused_moe_kernel, moe_align_block_size) def main(args): - os.environ["HIP_VISIBLE_DEVICES"] = args.GPUID - os.environ["HIP_FORCE_DEV_KERNARG"] = "1" - os.environ["DEBUG_CLR_GRAPH_PACKET_CAPTURE"] = "1" - os.environ["OPTIMIZE_EPILOGUE"] = "1" - - for bs in [ - 1, - 2, - 4, - 8, - 16, - 24, - 32, - 48, - 64, - 96, - 128, - 256, - 512, - 1024, - 1536, - 2048, - 3072, - 4096, - ]: - run_grid(bs, model=args.model, TP=args.TP) - - -## Utilize method from rocm/Triton tuning script -def get_full_tuning_space(): - configs = [] - - block_mn_range = [16, 32, 64, 128, 256] - block_k_range = [16, 32, 64, 128, 256] - # split_k_range = [1] #, 2, 4, 5, 6, 8, 10, 12, 16, 18, 24] - num_warps_range = [1, 2, 4, 8] - group_m_range = [1, 4, 8, 16, 32] - # For now we see better perf with num_stages=0 for all gemm configs we care - # But keep this explicit so that we do not forget we may need to set it to - # other values in the future - num_stage_range = [0] - waves_per_eu_range = [0] - matrix_instr_nonkdim_range = [16, 32] - kpack_range = [1, 2] - - for block_m in block_mn_range: - for block_n in block_mn_range: - for block_k in block_k_range: - for num_warps in num_warps_range: - for group_m in group_m_range: - # for split_k in split_k_range: - for num_stages in num_stage_range: - for waves_per_eu in waves_per_eu_range: - for (matrix_instr_nonkdim - ) in matrix_instr_nonkdim_range: - for kpack in kpack_range: - configs.append({ - "BLOCK_SIZE_M": block_m, - "BLOCK_SIZE_N": block_n, - "BLOCK_SIZE_K": block_k, - "GROUP_SIZE_M": group_m, - "num_warps": num_warps, - "num_stages": num_stages, - "waves_per_eu": waves_per_eu, - "matrix_instr_nonkdim": - matrix_instr_nonkdim, - "kpack": kpack, - }) - - return configs - - -## Utilize method from rocm/Triton tuning script -def prune_configs(M, N, K, configs): - pruned_configs = [] - elemBytes_a = 2 # [DV Note] Hard-coded for float16 (2 bytes) - elemBytes_b = 2 # [DV Note] Hard-coded for float16 (2 bytes) - - mfma = 16 if M < 32 or N < 32 else 32 - - # TODO (zhanglx): figure out the boundary between large and small gemms - large_gemm = False - if M >= 2048 and N >= 2048: - large_gemm = True - - for config in configs: - BLOCK_SIZE_M = config.get("BLOCK_SIZE_M") - BLOCK_SIZE_N = config.get("BLOCK_SIZE_N") - BLOCK_SIZE_K = config.get("BLOCK_SIZE_K") - num_warps = config.get("num_warps") - matrix_instr_nonkdim = config.get("matrix_instr_nonkdim") - # kpack = config.get("kpack") - if matrix_instr_nonkdim > mfma: - continue - if mfma == 4 and BLOCK_SIZE_K < 64: - continue - # some layouts could not work properly in case - # number elements per thread is less 1 - if 
BLOCK_SIZE_M * BLOCK_SIZE_N < 64: - continue - SPLIT_K = 1 # config.get("SPLIT_K") - GROUP_M = config.get("GROUP_SIZE_M") - if (matrix_instr_nonkdim > BLOCK_SIZE_M - or matrix_instr_nonkdim > BLOCK_SIZE_N): - continue - if matrix_instr_nonkdim >= M and matrix_instr_nonkdim != BLOCK_SIZE_M: - continue - if matrix_instr_nonkdim >= N and matrix_instr_nonkdim != BLOCK_SIZE_N: - continue - # Skip BLOCK_SIZE that is too large compare to M/N - # unless BLOCK_SIZE is already small enough - if M * 2 < BLOCK_SIZE_M and BLOCK_SIZE_M != 16: - continue - if N * 2 < BLOCK_SIZE_N and BLOCK_SIZE_N != 16: - continue - # skip large split_k when not necessary - if SPLIT_K != 1 and not need_split_k(M, N, K): - continue - # skip split_k that leads to EVEN_K = false - leap = SPLIT_K * BLOCK_SIZE_K - modv = K % leap - if modv != 0: - continue - # skip large GROUP_M - if GROUP_M * BLOCK_SIZE_M > M and GROUP_M != 1: - continue - # out of shared memory resource - # TODO (zhanglx): This does not consider the LDS usage in the epilogue - LDS = (BLOCK_SIZE_K * BLOCK_SIZE_M * elemBytes_a + - BLOCK_SIZE_K * BLOCK_SIZE_N * elemBytes_b) - if LDS > 65536: - continue - # Skip small block sizes and num_warps for large gemm - # For fp16 and f8, we want to only use BLOCK_SIZE >= 64 - if large_gemm: - if BLOCK_SIZE_M < 64 or BLOCK_SIZE_N < 64: - continue - if BLOCK_SIZE_K < 64: - continue - if num_warps < 4: - continue - - pruned_configs.append(config) - - return pruned_configs - - -def union_of_list_of_dicts(l1, l2): - result = [] - temp_list = l1.copy() - temp_list.extend(l2) - for myDict in temp_list: - if myDict not in result: - result.append(myDict) - - return result - + world_size = args.numGPU + start_time = time.time() + try: + mp.spawn(wrapper, args=(args, ), nprocs=world_size, join=True) + except Exception as e: + print(f"An error occurred during multiprocessing: {e}") + end_time = time.time() + print(f"Total time taken: {end_time - start_time:.2f} seconds") + + +def wrapper(rank, args): + dist.init_process_group("nccl", world_size=args.numGPU, rank=rank) + device_id = rank + + batches = [ + 1, 2, 4, 8, 16, 24, 32, 48, 64, 96, 128, 256, 512, 1024, 1536, 2048, + 3072, 4096 + ] + try: + for i in range(device_id, len(batches), args.numGPU): + tune_batch(batches[i], args) + except Exception as e: + print(f"An error occurred on device {device_id}: {e}") -def need_split_k(SIZE_M, SIZE_N, SIZE_K): - return (SIZE_M < 64 or SIZE_N < 64) and SIZE_K > 1024 +def tune_batch(bs, args): + model = args.model + TP = args.modelTP + use_fp8 = args.use_fp8 + device_id = torch.distributed.get_rank() -def run_grid(bs, model, TP): if model == '8x7B': d_model = 4096 model_intermediate_size = 14336 @@ -194,33 +67,33 @@ def run_grid(bs, model, TP): tp_size = TP num_calls = 100 - num_warmup_trials = 1 - num_trials = 1 - - full_configs = get_full_tuning_space() + full_configs = get_full_tuning_space(use_fp8) M1 = bs * 2 N1 = model_intermediate_size * 2 // tp_size K1 = d_model - prune_configs_1 = prune_configs(M1, N1, K1, full_configs) + prune_configs_1 = prune_configs(M1, N1, K1, full_configs, use_fp8) M2 = bs * 2 N2 = d_model K2 = model_intermediate_size // tp_size - prune_configs_2 = prune_configs(M2, N2, K2, full_configs) + prune_configs_2 = prune_configs(M2, N2, K2, full_configs, use_fp8) configs = union_of_list_of_dicts(prune_configs_1, prune_configs_2) - print(f"{bs=} || {len(full_configs)=} | {len(prune_configs_1)=} | \ - {len(prune_configs_2)=} | {len(configs)=}") best_config = None best_time_us = 1e20 - for config in tqdm(configs): - # 
warmup - try: - for _ in range(num_warmup_trials): + progress_bar = tqdm(total=len(configs), + desc=f"bs={bs:4d} device={device_id}", + position=device_id) + + with torch.cuda.device(device_id): + for config in configs: + progress_bar.update(1) + # warmup + try: run_timing( - num_calls=num_calls, + num_calls=5, bs=bs, d_model=d_model, num_total_experts=num_total_experts, @@ -228,12 +101,13 @@ def run_grid(bs, model, TP): tp_size=tp_size, model_intermediate_size=model_intermediate_size, config=config, + use_fp8_w8a8=use_fp8, ) - except triton.runtime.autotuner.OutOfResources: - continue + except Exception as e: + print(f"Error during warmup: {e}") + continue - # benchmark - for _ in range(num_trials): + # benchmark kernel_dur_ms = run_timing( num_calls=num_calls, bs=bs, @@ -243,37 +117,38 @@ def run_grid(bs, model, TP): tp_size=tp_size, model_intermediate_size=model_intermediate_size, config=config, + use_fp8_w8a8=use_fp8, ) kernel_dur_us = 1000 * kernel_dur_ms - # model_dur_ms = kernel_dur_ms * num_layers if kernel_dur_us < best_time_us: best_config = config best_time_us = kernel_dur_us - # print(f'{kernel_dur_us=:.1f} {model_dur_ms=:.1f}' - # f' {bs=} {tp_size=} {top_k=} {num_total_experts=} ' - # f'{d_model=} {model_intermediate_size=} {num_layers=}') - - # print("best_time_us", best_time_us) - # print("best_config", best_config) - - # holds Dict[str, Dict[str, int]] + config_dtype = "fp8_w8a8" if use_fp8 else None filename = get_config_file_name(num_total_experts, model_intermediate_size // tp_size, - dtype=None) + dtype=config_dtype) print(f"writing config to file {filename}") existing_content = {} if os.path.exists(filename): with open(filename, "r") as f: existing_content = json.load(f) existing_content[str(bs)] = best_config + existing_content = sort_json(existing_content) with open(filename, "w") as f: json.dump(existing_content, f, indent=4) f.write("\n") +def sort_json(json_file): + return { + k: v + for k, v in natsorted(json_file.items(), key=lambda item: item[0]) + } + + def run_timing( num_calls: int, bs: int, @@ -283,13 +158,17 @@ def run_timing( tp_size: int, model_intermediate_size: int, config, + use_fp8_w8a8: bool, ) -> float: shard_intermediate_size = model_intermediate_size // tp_size + device_ = "cuda" + dtype_ = torch.float16 + hidden_states = torch.rand( (bs, d_model), - device="cuda", - dtype=torch.float16, + device=device_, + dtype=dtype_, ) w1 = torch.rand( @@ -306,13 +185,29 @@ def run_timing( gating_output = F.softmax( torch.rand( - # (num_calls, bs, num_total_experts), # THIS (bs, num_total_experts), device=hidden_states.device, dtype=torch.float32, ), dim=-1, ) + w1_scale = None + w2_scale = None + a1_scale = None + a2_scale = None + + if use_fp8_w8a8: + w1_scale = torch.randn(num_total_experts, + dtype=torch.float32, + device=device_) + w2_scale = torch.randn(num_total_experts, + dtype=torch.float32, + device=device_) + a1_scale = torch.randn(1, dtype=torch.float32, device=device_) + a2_scale = torch.randn(1, dtype=torch.float32, device=device_) + + w1 = w1.to(torch.float8_e4m3fnuz) + w2 = w2.to(torch.float8_e4m3fnuz) ###### Stuff from fused moe ###### @@ -328,28 +223,9 @@ def run_timing( ] M, _ = hidden_states.shape E, N, _ = w1.shape - topk_ = 2 - topk_weights = torch.empty(M, - topk_, - dtype=torch.float32, - device=hidden_states.device) - topk_ids = torch.empty(M, - topk_, - dtype=torch.int32, - device=hidden_states.device) - token_expert_indicies = torch.empty(M, - topk_, - dtype=torch.int32, - device=hidden_states.device) - moe_kernels.topk_softmax( - 
topk_weights, - topk_ids, - token_expert_indicies, - gating_output.float(), # TODO(woosuk): Optimize this. - ) - del token_expert_indicies # Not used. Will be used in the future. - topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True) + topk_weights, topk_ids = fused_topk(hidden_states, gating_output, top_k, + True) intermediate_cache1 = torch.empty( (M, topk_ids.shape[1], N), @@ -376,13 +252,13 @@ def run_timing( end_event = torch.cuda.Event(enable_timing=True) start_event.record() - for i in range(num_calls): + for _ in range(num_calls): invoke_fused_moe_kernel( hidden_states, w1, intermediate_cache1, - None, # a1_scale - None, # w1_scale + a1_scale, + w1_scale, topk_weights, topk_ids, sorted_token_ids, @@ -393,8 +269,8 @@ def run_timing( config, compute_type=(tl.bfloat16 if hidden_states.dtype == torch.bfloat16 else tl.float16), - use_fp8=False, - ) + use_fp8_w8a8=use_fp8_w8a8, + use_int8_w8a16=False) ops.silu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, N)) @@ -402,8 +278,8 @@ def run_timing( intermediate_cache2, w2, intermediate_cache3, - None, # a2_scale - None, # w2_scale + a2_scale, + w2_scale, topk_weights, topk_ids, sorted_token_ids, @@ -414,8 +290,8 @@ def run_timing( config, compute_type=(tl.bfloat16 if hidden_states.dtype == torch.bfloat16 else tl.float16), - use_fp8=False, - ) + use_fp8_w8a8=use_fp8_w8a8, + use_int8_w8a16=False) end_event.record() end_event.synchronize() @@ -427,28 +303,40 @@ def run_timing( if __name__ == "__main__": parser = argparse.ArgumentParser( prog="benchmark_mixtral_moe_rocm", - description="Tune the fused_moe kernel for mixtral.") + description="Distributed tuning script for the fused_moe kernel.") + parser.add_argument('--model', + type=str, + choices=['8x7B', '8x22B'], + help='The Mixtral model to benchmark') parser.add_argument( - "--TP", + "--modelTP", type=int, choices=[8, 4, 2, 1], - help="Specify the TP value that the actual model will run on", + help="Specify the TP value that the model will actually run on", required=True, ) parser.add_argument( - "--GPUID", - type=str, - help="This script uses single GPU. Specify the GPU to use for tuning", - default="0", + "--numGPU", + type=int, + choices=[8, 4, 2, 1], + help="Total number of GPUs to use for tuning", + required=True, + ) + parser.add_argument( + "--use_fp8", + action="store_true", + help="Flag to indicate whether to use FP8 tuning", ) - parser.add_argument('--model', - type=str, - choices=['8x7B', '8x22B'], - help='The Mixtral model to benchmark') - args = parser.parse_args() + if "LOCAL_RANK" not in os.environ: + print("Please use torchrun to launch this multi-gpu script. 
E.g:") + print("\ttorchrun benchmark_mixtral_moe_rocm.py", + "--model 8x7B --modelTP 4 --numGPU 2") + print("Exiting...") + exit() print(f"Running tuning for {args.model} model") - print(f"TP is set to: {args.TP}") - print(f"GPU-ID being used for tuning: {args.GPUID}") + print(f"Model TP is set to: {args.modelTP}") + print(f"GPUs being used for tuning: {args.numGPU}") + print(f"Using FP8: {args.use_fp8}") sys.exit(main(args)) diff --git a/benchmarks/kernels/tuning_utils.py b/benchmarks/kernels/tuning_utils.py new file mode 100644 index 0000000000000..5a9c84d9ce736 --- /dev/null +++ b/benchmarks/kernels/tuning_utils.py @@ -0,0 +1,136 @@ +from itertools import product + + +## Utilize method from rocm/Triton tuning script +def get_full_tuning_space(use_fp8): + configs = [] + + block_mn_range = [16, 32, 64, 128, 256] + block_k_range = [16, 32, 64, 128, 256] + if use_fp8: + block_k_range.remove(16) # BLOCK_K=16 not supported for fp8 + # split_k_range = [1] #, 2, 4, 5, 6, 8, 10, 12, 16, 18, 24] + num_warps_range = [1, 2, 4, 8] + group_m_range = [1, 4, 8, 16, 32] + # For now we see better perf with num_stages=0 for all gemm configs we care + # But keep this explicit so that we do not forget we may need to set it to + # other values in the future + num_stage_range = [0] + waves_per_eu_range = [0] + matrix_instr_nonkdim_range = [] if use_fp8 else [16, 32] + kpack_range = [] if use_fp8 else [1, 2] + + param_ranges = { + "BLOCK_SIZE_M": block_mn_range, + "BLOCK_SIZE_N": block_mn_range, + "BLOCK_SIZE_K": block_k_range, + "GROUP_SIZE_M": group_m_range, + "num_warps": num_warps_range, + "num_stages": num_stage_range, + "waves_per_eu": waves_per_eu_range, + } + + if not use_fp8: + param_ranges["matrix_instr_nonkdim"] = matrix_instr_nonkdim_range + param_ranges["kpack"] = kpack_range + + keys, values = zip(*param_ranges.items()) + for config_values in product(*values): + config = dict(zip(keys, config_values)) + configs.append(config) + + return configs + + +## Utilize method from rocm/Triton tuning script +def prune_configs(M, N, K, configs, is_fp8=False): + pruned_configs = [] + elemBytes_a = 1 if is_fp8 else 2 # Assuming fp16 or fp8 cases only + elemBytes_b = 1 if is_fp8 else 2 # Assuming fp16 or fp8 cases only + + mfma = 16 if M < 32 or N < 32 else 32 + + # TODO (zhanglx): figure out the boundary between large and small gemms + large_gemm = False + if M >= 2048 and N >= 2048: + large_gemm = True + + for config in configs: + BLOCK_SIZE_M = config.get("BLOCK_SIZE_M") + BLOCK_SIZE_N = config.get("BLOCK_SIZE_N") + BLOCK_SIZE_K = config.get("BLOCK_SIZE_K") + num_warps = config.get("num_warps") + + if not is_fp8: + matrix_instr_nonkdim = config.get("matrix_instr_nonkdim") + if matrix_instr_nonkdim > mfma: + continue + if mfma == 4 and BLOCK_SIZE_K < 64: + continue + # some layouts could not work properly in case + # number elements per thread is less 1 + if BLOCK_SIZE_M * BLOCK_SIZE_N < 64: + continue + SPLIT_K = 1 # config.get("SPLIT_K") + GROUP_M = config.get("GROUP_SIZE_M") + if not is_fp8: + if (matrix_instr_nonkdim > BLOCK_SIZE_M + or matrix_instr_nonkdim > BLOCK_SIZE_N): + continue + if (matrix_instr_nonkdim >= M + and matrix_instr_nonkdim != BLOCK_SIZE_M): + continue + if (matrix_instr_nonkdim >= N + and matrix_instr_nonkdim != BLOCK_SIZE_N): + continue + # Skip BLOCK_SIZE that is too large compare to M/N + # unless BLOCK_SIZE is already small enough + if M * 2 < BLOCK_SIZE_M and BLOCK_SIZE_M != 16: + continue + if N * 2 < BLOCK_SIZE_N and BLOCK_SIZE_N != 16: + continue + # skip large split_k when 
not necessary + if SPLIT_K != 1 and not need_split_k(M, N, K): + continue + # skip split_k that leads to EVEN_K = false + leap = SPLIT_K * BLOCK_SIZE_K + modv = K % leap + if modv != 0: + continue + # skip large GROUP_M + if GROUP_M * BLOCK_SIZE_M > M and GROUP_M != 1: + continue + # out of shared memory resource + # TODO (zhanglx): This does not consider the LDS usage in the epilogue + LDS = (BLOCK_SIZE_K * BLOCK_SIZE_M * elemBytes_a + + BLOCK_SIZE_K * BLOCK_SIZE_N * elemBytes_b) + if LDS > 65536: + continue + # Skip small block sizes and num_warps for large gemm + # For fp16 and f8, we want to only use BLOCK_SIZE >= 64 + if large_gemm: + if BLOCK_SIZE_M < 64 or BLOCK_SIZE_N < 64: + continue + if BLOCK_SIZE_K < 64: + continue + if num_warps < 4: + continue + + pruned_configs.append(config) + + return pruned_configs + + +def union_of_list_of_dicts(l1, l2): + result = [] + temp_list = l1.copy() + temp_list.extend(l2) + for myDict in temp_list: + if myDict not in result: + result.append(myDict) + + return result + + +def need_split_k(SIZE_M, SIZE_N, SIZE_K): + return (SIZE_M < 64 or SIZE_N < 64) and SIZE_K > 1024 diff --git a/requirements-rocm.txt b/requirements-rocm.txt index 7f83cbbc307a6..d417f89b249bc 100644 --- a/requirements-rocm.txt +++ b/requirements-rocm.txt @@ -11,3 +11,4 @@ peft pytest-asyncio tensorizer>=2.9.0 setuptools-scm>=8 +natsort
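
For reference, a minimal sketch of how the helpers introduced in benchmarks/kernels/tuning_utils.py are combined by tune_batch to build the candidate config list. The batch size, TP degree, and fp8 flag below are illustrative choices; the 8x7B shapes (d_model=4096, intermediate size 14336) follow the values hard-coded in the script. The script itself is meant to be launched with torchrun, e.g. "torchrun benchmark_mixtral_moe_rocm.py --model 8x7B --modelTP 4 --numGPU 2", as the LOCAL_RANK check prints.

from tuning_utils import (get_full_tuning_space, prune_configs,
                          union_of_list_of_dicts)

# Illustrative shapes: '8x7B' model, TP=4, batch size 64, fp16 tuning.
bs = 64
d_model = 4096
model_intermediate_size = 14336
tp_size = 4
use_fp8 = False

# Full Cartesian tuning space (fp8 removes BLOCK_K=16 and the
# matrix_instr_nonkdim/kpack knobs).
full_configs = get_full_tuning_space(use_fp8)

# First grouped GEMM: tokens are expanded to their top-2 experts and hit
# the fused gate/up projection, so M doubles and N covers both halves.
M1 = bs * 2
N1 = model_intermediate_size * 2 // tp_size
K1 = d_model
configs_1 = prune_configs(M1, N1, K1, full_configs, use_fp8)

# Second grouped GEMM: down projection back to the hidden size.
M2 = bs * 2
N2 = d_model
K2 = model_intermediate_size // tp_size
configs_2 = prune_configs(M2, N2, K2, full_configs, use_fp8)

# Deduplicated candidate set that tune_batch actually benchmarks.
configs = union_of_list_of_dicts(configs_1, configs_2)
print(f"{len(full_configs)=} {len(configs_1)=} "
      f"{len(configs_2)=} {len(configs)=}")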