Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

multi-gpu fused_moe tuning support #143

Open
wants to merge 20 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
267 changes: 62 additions & 205 deletions benchmarks/kernels/benchmark_mixtral_moe_rocm.py
100755 → 100644
Original file line number Diff line number Diff line change
Expand Up @@ -4,182 +4,41 @@
import sys

import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn.functional as F
import triton
import triton.language as tl
from tqdm import tqdm
from tuning_utils import (get_full_tuning_space, prune_configs,
union_of_list_of_dicts)

import vllm._moe_C as moe_kernels
from vllm._C import ops
from vllm import _custom_ops as ops
from vllm.model_executor.layers.fused_moe import (get_config_file_name,
invoke_fused_moe_kernel,
moe_align_block_size)


def main(args):
os.environ["HIP_VISIBLE_DEVICES"] = args.GPUID
os.environ["HIP_FORCE_DEV_KERNARG"] = "1"
os.environ["DEBUG_CLR_GRAPH_PACKET_CAPTURE"] = "1"
os.environ["OPTIMIZE_EPILOGUE"] = "1"

for bs in [
1,
2,
4,
8,
16,
24,
32,
48,
64,
96,
128,
256,
512,
1024,
1536,
2048,
3072,
4096,
]:
run_grid(bs, model=args.model, TP=args.TP)


## Utilize method from rocm/Triton tuning script
def get_full_tuning_space():
configs = []

block_mn_range = [16, 32, 64, 128, 256]
block_k_range = [16, 32, 64, 128, 256]
# split_k_range = [1] #, 2, 4, 5, 6, 8, 10, 12, 16, 18, 24]
num_warps_range = [1, 2, 4, 8]
group_m_range = [1, 4, 8, 16, 32]
# For now we see better perf with num_stages=0 for all gemm configs we care
# But keep this explicit so that we do not forget we may need to set it to
# other values in the future
num_stage_range = [0]
waves_per_eu_range = [0]
matrix_instr_nonkdim_range = [16, 32]
kpack_range = [1, 2]

for block_m in block_mn_range:
for block_n in block_mn_range:
for block_k in block_k_range:
for num_warps in num_warps_range:
for group_m in group_m_range:
# for split_k in split_k_range:
for num_stages in num_stage_range:
for waves_per_eu in waves_per_eu_range:
for (matrix_instr_nonkdim
) in matrix_instr_nonkdim_range:
for kpack in kpack_range:
configs.append({
"BLOCK_SIZE_M": block_m,
"BLOCK_SIZE_N": block_n,
"BLOCK_SIZE_K": block_k,
"GROUP_SIZE_M": group_m,
"num_warps": num_warps,
"num_stages": num_stages,
"waves_per_eu": waves_per_eu,
"matrix_instr_nonkdim":
matrix_instr_nonkdim,
"kpack": kpack,
})

return configs


## Utilize method from rocm/Triton tuning script
def prune_configs(M, N, K, configs):
pruned_configs = []
elemBytes_a = 2 # [DV Note] Hard-coded for float16 (2 bytes)
elemBytes_b = 2 # [DV Note] Hard-coded for float16 (2 bytes)

mfma = 16 if M < 32 or N < 32 else 32

# TODO (zhanglx): figure out the boundary between large and small gemms
large_gemm = False
if M >= 2048 and N >= 2048:
large_gemm = True

for config in configs:
BLOCK_SIZE_M = config.get("BLOCK_SIZE_M")
BLOCK_SIZE_N = config.get("BLOCK_SIZE_N")
BLOCK_SIZE_K = config.get("BLOCK_SIZE_K")
num_warps = config.get("num_warps")
matrix_instr_nonkdim = config.get("matrix_instr_nonkdim")
# kpack = config.get("kpack")
if matrix_instr_nonkdim > mfma:
continue
if mfma == 4 and BLOCK_SIZE_K < 64:
continue
# some layouts could not work properly in case
# number elements per thread is less 1
if BLOCK_SIZE_M * BLOCK_SIZE_N < 64:
continue
SPLIT_K = 1 # config.get("SPLIT_K")
GROUP_M = config.get("GROUP_SIZE_M")
if (matrix_instr_nonkdim > BLOCK_SIZE_M
or matrix_instr_nonkdim > BLOCK_SIZE_N):
continue
if matrix_instr_nonkdim >= M and matrix_instr_nonkdim != BLOCK_SIZE_M:
continue
if matrix_instr_nonkdim >= N and matrix_instr_nonkdim != BLOCK_SIZE_N:
continue
# Skip BLOCK_SIZE that is too large compare to M/N
# unless BLOCK_SIZE is already small enough
if M * 2 < BLOCK_SIZE_M and BLOCK_SIZE_M != 16:
continue
if N * 2 < BLOCK_SIZE_N and BLOCK_SIZE_N != 16:
continue
# skip large split_k when not necessary
if SPLIT_K != 1 and not need_split_k(M, N, K):
continue
# skip split_k that leads to EVEN_K = false
leap = SPLIT_K * BLOCK_SIZE_K
modv = K % leap
if modv != 0:
continue
# skip large GROUP_M
if GROUP_M * BLOCK_SIZE_M > M and GROUP_M != 1:
continue
# out of shared memory resource
# TODO (zhanglx): This does not consider the LDS usage in the epilogue
LDS = (BLOCK_SIZE_K * BLOCK_SIZE_M * elemBytes_a +
BLOCK_SIZE_K * BLOCK_SIZE_N * elemBytes_b)
if LDS > 65536:
continue
# Skip small block sizes and num_warps for large gemm
# For fp16 and f8, we want to only use BLOCK_SIZE >= 64
if large_gemm:
if BLOCK_SIZE_M < 64 or BLOCK_SIZE_N < 64:
continue
if BLOCK_SIZE_K < 64:
continue
if num_warps < 4:
continue

pruned_configs.append(config)

return pruned_configs
world_size = args.numGPU
mp.spawn(wrapper, args=(args, ), nprocs=world_size, join=False)


def union_of_list_of_dicts(l1, l2):
result = []
temp_list = l1.copy()
temp_list.extend(l2)
for myDict in temp_list:
if myDict not in result:
result.append(myDict)

return result
def wrapper(rank, args):
dist.init_process_group("nccl", world_size=args.numGPU, rank=rank)
device_id = rank

batches = [
1, 2, 4, 8, 16, 24, 32, 48, 64, 96, 128, 256, 512, 1024, 1536, 2048,
3072, 4096
]
for i in range(device_id, len(batches), args.numGPU):
tune_batch(batches[i], model=args.model, TP=args.modelTP)

def need_split_k(SIZE_M, SIZE_N, SIZE_K):
return (SIZE_M < 64 or SIZE_N < 64) and SIZE_K > 1024

def tune_batch(bs, model, TP):
device_id = torch.distributed.get_rank()

def run_grid(bs, model, TP):
if model == '8x7B':
d_model = 4096
model_intermediate_size = 14336
Expand All @@ -194,9 +53,6 @@ def run_grid(bs, model, TP):
tp_size = TP
num_calls = 100

num_warmup_trials = 1
num_trials = 1

full_configs = get_full_tuning_space()
M1 = bs * 2
N1 = model_intermediate_size * 2 // tp_size
Expand All @@ -209,16 +65,19 @@ def run_grid(bs, model, TP):
prune_configs_2 = prune_configs(M2, N2, K2, full_configs)

configs = union_of_list_of_dicts(prune_configs_1, prune_configs_2)
print(f"{bs=} || {len(full_configs)=} | {len(prune_configs_1)=} | \
{len(prune_configs_2)=} | {len(configs)=}")

best_config = None
best_time_us = 1e20

for config in tqdm(configs):
# warmup
try:
for _ in range(num_warmup_trials):
progress_bar = tqdm(total=len(configs),
desc=f"bs={bs:4d} device={device_id}",
position=device_id)

with torch.cuda.device(device_id):
for config in configs:
progress_bar.update(1)
# warmup
try:
run_timing(
num_calls=num_calls,
bs=bs,
Expand All @@ -229,11 +88,10 @@ def run_grid(bs, model, TP):
model_intermediate_size=model_intermediate_size,
config=config,
)
except triton.runtime.autotuner.OutOfResources:
continue
except triton.runtime.autotuner.OutOfResources:
continue

# benchmark
for _ in range(num_trials):
# benchmark
kernel_dur_ms = run_timing(
num_calls=num_calls,
bs=bs,
Expand All @@ -246,20 +104,11 @@ def run_grid(bs, model, TP):
)

kernel_dur_us = 1000 * kernel_dur_ms
# model_dur_ms = kernel_dur_ms * num_layers

if kernel_dur_us < best_time_us:
best_config = config
best_time_us = kernel_dur_us

# print(f'{kernel_dur_us=:.1f} {model_dur_ms=:.1f}'
# f' {bs=} {tp_size=} {top_k=} {num_total_experts=} '
# f'{d_model=} {model_intermediate_size=} {num_layers=}')

# print("best_time_us", best_time_us)
# print("best_config", best_config)

# holds Dict[str, Dict[str, int]]
filename = get_config_file_name(num_total_experts,
model_intermediate_size // tp_size,
dtype=None)
Expand All @@ -286,10 +135,13 @@ def run_timing(
) -> float:
shard_intermediate_size = model_intermediate_size // tp_size

device_ = "cuda"
dtype_ = torch.float16

hidden_states = torch.rand(
(bs, d_model),
device="cuda",
dtype=torch.float16,
device=device_,
dtype=dtype_,
)

w1 = torch.rand(
Expand All @@ -306,7 +158,6 @@ def run_timing(

gating_output = F.softmax(
torch.rand(
# (num_calls, bs, num_total_experts), # THIS
(bs, num_total_experts),
device=hidden_states.device,
dtype=torch.float32,
Expand Down Expand Up @@ -341,11 +192,11 @@ def run_timing(
topk_,
dtype=torch.int32,
device=hidden_states.device)
moe_kernels.topk_softmax(
ops.topk_softmax(
topk_weights,
topk_ids,
token_expert_indicies,
gating_output.float(), # TODO(woosuk): Optimize this.
gating_output.float(),
)
del token_expert_indicies # Not used. Will be used in the future.

Expand Down Expand Up @@ -376,7 +227,7 @@ def run_timing(
end_event = torch.cuda.Event(enable_timing=True)

start_event.record()
for i in range(num_calls):
for _ in range(num_calls):
invoke_fused_moe_kernel(
hidden_states,
w1,
Expand All @@ -393,8 +244,8 @@ def run_timing(
config,
compute_type=(tl.bfloat16 if hidden_states.dtype == torch.bfloat16
else tl.float16),
use_fp8=False,
)
use_fp8_w8a8=False,
use_int8_w8a16=False)
divakar-amd marked this conversation as resolved.
Show resolved Hide resolved

ops.silu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, N))

Expand All @@ -414,8 +265,8 @@ def run_timing(
config,
compute_type=(tl.bfloat16 if hidden_states.dtype == torch.bfloat16
else tl.float16),
use_fp8=False,
)
use_fp8_w8a8=False,
use_int8_w8a16=False)

end_event.record()
end_event.synchronize()
Expand All @@ -427,28 +278,34 @@ def run_timing(
if __name__ == "__main__":
parser = argparse.ArgumentParser(
prog="benchmark_mixtral_moe_rocm",
description="Tune the fused_moe kernel for mixtral.")
description="Distributed tuning script for the fused_moe kernel.")
parser.add_argument('--model',
type=str,
choices=['8x7B', '8x22B'],
help='The Mixtral model to benchmark')
parser.add_argument(
"--TP",
"--modelTP",
type=int,
choices=[8, 4, 2, 1],
help="Specify the TP value that the actual model will run on",
help="Specify the TP value that the model will actually run on",
required=True,
)
parser.add_argument(
"--GPUID",
type=str,
help="This script uses single GPU. Specify the GPU to use for tuning",
default="0",
"--numGPU",
type=int,
choices=[8, 4, 2, 1],
help="Total number of GPUs to use for tuning",
required=True,
)
parser.add_argument('--model',
type=str,
choices=['8x7B', '8x22B'],
help='The Mixtral model to benchmark')

args = parser.parse_args()
if "LOCAL_RANK" not in os.environ:
print("Please use torchrun to launch this multi-gpu script. E.g:")
print("\ttorchrun benchmark_mixtral_moe_rocm.py",
"--model 8x7B --modelTP 4 --numGPU 2")
print("Exiting...")
exit()

print(f"Running tuning for {args.model} model")
print(f"TP is set to: {args.TP}")
print(f"GPU-ID being used for tuning: {args.GPUID}")
print(f"Model TP is set to: {args.modelTP}")
print(f"GPUs being used for tuning: {args.numGPU}")
sys.exit(main(args))
Loading
Loading