[GEMM] [Tuning] Option to try different initialization strategies #486

Merged: 9 commits, Jan 27, 2024
59 changes: 40 additions & 19 deletions scripts/amd/gemm/tune_gemm.py
@@ -1,3 +1,4 @@
# fp8
import argparse
import sys
import yaml
@@ -221,7 +222,7 @@ def generated_kernel_name(M, N, K, gpu_id):
# 4. test_gemm to invoke
# 4.1 run try_config in parallel
# 4.2 matmul in a loop of 10 iterations
def generate_kernel(M, N, K, col_a, col_b, dtype_a, dtype_b, dtype_c, configs, jobs, run_bench):
def generate_kernel(M, N, K, col_a, col_b, dtype_a, dtype_b, dtype_c, init_type, configs, jobs, run_bench):
filenames = []
for i in range(jobs):
filenames.append(generated_kernel_name(M, N, K, i))
@@ -259,8 +260,8 @@ def generate_kernel(M, N, K, col_a, col_b, dtype_a, dtype_b, dtype_c, configs, j
# pre string
test_gemm_pre_str = f"""def test_gemm(M, N, K, num_threads):
thread_pool = multiprocessing.Pool(processes=num_threads)
a, a_fp16 = gen_input(M, K, '{dtype_a}', {col_a}, 1, device='cuda')
b, b_fp16 = gen_input(K, N, '{dtype_b}', {col_b}, 2, device='cuda')
a, a_fp16 = gen_input(M, K, '{dtype_a}', {col_a}, 1, '{init_type}', device='cuda')
b, b_fp16 = gen_input(K, N, '{dtype_b}', {col_b}, 2, '{init_type}', device='cuda')
c = torch.zeros((M, N), device=a.device, dtype={tl_to_torch_types[name_to_tl_types[dtype_c]]})
task_args = (M, N, K,
a.stride(0), a.stride(1),
@@ -359,9 +360,9 @@ def profile_batch_kernels(M, N, K, gpuid, gpus, jobs, verbose):
jobId += ngpus


def tune_gemm_config(M, N, K, col_a, col_b, dtype_a, dtype_b, dtype_c, configs, run_bench, jobs, verbose=0, num_threads=16, gpus=[0]):
def tune_gemm_config(M, N, K, col_a, col_b, dtype_a, dtype_b, dtype_c, init_type, configs, run_bench, jobs, verbose=0, num_threads=16, gpus=[0]):
# Generate kernel out of all configs
generate_kernel(M, N, K, col_a, col_b, dtype_a, dtype_b, dtype_c, configs, jobs, run_bench)
generate_kernel(M, N, K, col_a, col_b, dtype_a, dtype_b, dtype_c, init_type, configs, jobs, run_bench)

# remove any compiled kernel in the cache
run_bash_command("rm -rf ~/.triton/cache")
@@ -418,7 +419,7 @@ def tune_gemm_config(M, N, K, col_a, col_b, dtype_a, dtype_b, dtype_c, configs,
print(f"post procesing time: {post_time}", flush=True)
return minTime, bestConfig, compile_time, profile_time, post_time

def gen_input(M, N, ty_name, needTrans, seed, device='cuda'):
def gen_input(M, N, ty_name, needTrans, seed, init_type, device='cuda'):
d_type = name_to_tl_types[ty_name]
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
@@ -431,10 +432,24 @@ def copy_kernel(input_ptr, output_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
output = input
tl.store(output_ptr + offsets, output, mask=mask)

def init_by_size_and_type(size, dtype, init_type):
if init_type == 'hpl':
return torch.empty(size, device='cuda', dtype=dtype).uniform_(-0.5, 0.5)
# This init type sets element i of every row to sin(i)
elif init_type == 'trig_float':
M, N = size
return torch.arange(0, N).expand(M,-1).sin().to(dtype=dtype, device='cuda')
elif init_type == 'zeros':
return torch.zeros(size, dtype=dtype, device='cuda')
elif init_type == "randn":
temp = torch.randn(size, dtype=dtype, device='cuda')
return temp
else:
raise ValueError("Bad matrix initialization type.")

+ raw_data = init_by_size_and_type((N,M) if needTrans else (M,N), torch.float32, init_type)
  if needTrans:
-     raw_data = torch.randn((N, M), dtype=torch.float32, device='cuda').T
- else:
-     raw_data = torch.randn((M, N), dtype=torch.float32, device='cuda')
+     raw_data = raw_data.T
if (d_type == tl.float8e4b8 and TORCH_HAS_FP8E4B8) or \
(d_type == tl.float8e5b16 and TORCH_HAS_FP8E5B16) or not d_type.is_fp8():
input = raw_data.to(tl_to_torch_types[d_type])
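
For reference, here is a minimal standalone sketch of the four initialization strategies this hunk introduces. It mirrors init_by_size_and_type from the diff but simplifies the dtype/device plumbing and uses CPU tensors so it runs without a GPU; it is an illustration, not the function added by this PR.

```python
import torch

def init_by_size_and_type(size, dtype, init_type):
    # Standalone sketch of the init strategies (CPU, simplified).
    if init_type == 'hpl':
        # HPL-style: uniform values in [-0.5, 0.5)
        return torch.empty(size, dtype=dtype).uniform_(-0.5, 0.5)
    elif init_type == 'trig_float':
        # Element i of every row is sin(i), so values stay in [-1, 1]
        M, N = size
        return torch.arange(0, N, dtype=torch.float32).expand(M, -1).sin().to(dtype)
    elif init_type == 'zeros':
        return torch.zeros(size, dtype=dtype)
    elif init_type == 'randn':
        # Standard normal, the script's default
        return torch.randn(size, dtype=dtype)
    raise ValueError("Bad matrix initialization type.")

for t in ('hpl', 'trig_float', 'zeros', 'randn'):
    x = init_by_size_and_type((2, 4), torch.float32, t)
    print(f"{t:10s} min={x.min().item():+.3f} max={x.max().item():+.3f}")
```

One practical difference worth noting: 'zeros' and 'trig_float' are fully deterministic, while 'hpl' and 'randn' depend on the seeds set in gen_input.
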
@@ -481,14 +496,14 @@ def matmul(a, b, c, block_m, block_n, block_k, group_m, split_k, num_warps, num_
return c


def test_correctness(M, N, K, col_a, col_b, dtype_a, dtype_b, dtype_c, config, verbose):
def test_correctness(M, N, K, col_a, col_b, dtype_a, dtype_b, dtype_c, init_type, config, verbose):
block_m, block_n, block_k, group_m, split_k, num_warps, num_stages, waves_per_eu, mfmaInstrSize = read_config(config)

torch.manual_seed(0)
#a = torch.randn((M, K), device='cuda', dtype=datatype)
#b = torch.randn((K, N), device='cuda', dtype=datatype)
a, a_fp16 = gen_input(M, K, dtype_a, col_a, 1, device='cuda')
b, b_fp16 = gen_input(K, N, dtype_b, col_b, 2, device='cuda')
a, a_fp16 = gen_input(M, K, dtype_a, col_a, 1, init_type, device='cuda')
b, b_fp16 = gen_input(K, N, dtype_b, col_b, 2, init_type, device='cuda')
# Allocates output.
c = torch.zeros((M, N), device=a.device, dtype=tl_to_torch_types[name_to_tl_types[dtype_c]])
triton_output = matmul(a, b, c, block_m, block_n, block_k, group_m, split_k, num_warps, num_stages, waves_per_eu, mfmaInstrSize)
@@ -544,6 +559,7 @@ def parse_args():
parser.add_argument("--verbose", action='store_true', default=False, help="enables time_breakdown and additional logging messages")
parser.add_argument("--num_threads", type=int, default=16, help="number of threads to use for kernel compilation and post processing")
parser.add_argument("--jobs", type=int, default=1, help="number of generated files")
parser.add_argument("--init_type", type=str, default='randn', help="Initialization type for input matrices (default uniform rand [0, 1.0)])")
args = parser.parse_args()

return args
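
For context, a possible invocation exercising the new option could look like `python tune_gemm.py -m 4096 -n 4096 -k 4096 --init_type hpl --compare`. The `-m/-n/-k` and `--compare` spellings are assumed here from `args.m`, `args.n`, `args.k`, and `args.compare`; check parse_args() for the exact flag names.
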
@@ -643,24 +659,25 @@ def main():

mnks = []
# TODO: make it more robust to get user input
init_type = args.init_type
if matrix_size_file == "" or not os.path.isfile(matrix_size_file):
M = args.m
N = args.n
K = args.k
col_a = args.col_a
col_b = args.col_b
mnks = [(M, N, K, col_a, col_b, None)]
mnks = [(M, N, K, col_a, col_b, init_type, None)]
else:
with open(matrix_size_file) as file:
matrix_sizes = yaml.safe_load(file)
for item in matrix_sizes:
M, N, K, col_a, col_b, item = process_item(item)
mnks.append((M, N, K, col_a, col_b, item))
mnks.append((M, N, K, col_a, col_b, init_type, item))

# Check correctness from given configs
if args.compare_wo_tuning:
for (M, N, K, col_a, col_b, myConfig) in mnks:
test_correctness(M, N, K, col_a, col_b, dtype_a, dtype_b, dtype_c, item, True)
for (M, N, K, col_a, col_b, init_type, myConfig) in mnks:
test_correctness(M, N, K, col_a, col_b, dtype_a, dtype_b, dtype_c, init_type, myConfig, True)
return

configs_full = get_full_tuning_space()
@@ -673,7 +690,7 @@
print(f"Tuning starts at: {start_time}", flush=True)
f_results = open(tuning_output_file, 'w')

for (M, N, K, col_a, col_b, myConfig) in mnks:
for (M, N, K, col_a, col_b, init_type, myConfig) in mnks:

Review comment: I think we can remove init_type from mnks. All GEMM sizes in a single tuning/benchmarking session can share the same init_type.
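
A rough, hypothetical sketch of that simplification (not part of this diff): keep init_type as a single session-wide value, drop it from the per-size tuples, and pass it to tune_gemm_config / test_correctness directly at the call sites.

```python
# Hypothetical restructuring following the review suggestion; sizes are made up.
init_type = 'hpl'  # in the script this would come from args.init_type

# Per-size entries no longer carry init_type.
mnks = [
    (1024, 1024, 1024, False, False, None),
    (4096, 4096, 4096, False, False, None),
]

for (M, N, K, col_a, col_b, myConfig) in mnks:
    # The same session-wide init_type would be handed to tune_gemm_config /
    # test_correctness here; print stands in for those calls.
    print(M, N, K, col_a, col_b, init_type, myConfig)
```
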

start_local_time = datetime.now()
# Obtain a pruned tuning space according to gemm size
# If running benchmark, use the provided config
@@ -691,7 +708,11 @@
verbose_level = 1
if args.verbose:
verbose_level = 2
minTime, bestConfig, compile_time, profile_time, post_time = tune_gemm_config(M, N, K, col_a, col_b, dtype_a, dtype_b, dtype_c, pruned_configs, run_bench, jobs, num_threads=args.num_threads, gpus=gpus, verbose=verbose_level)
minTime, bestConfig, compile_time, profile_time, post_time = tune_gemm_config(
M, N, K, col_a, col_b, dtype_a,
dtype_b, dtype_c, init_type, pruned_configs,
run_bench, jobs, num_threads=args.num_threads, gpus=gpus,
verbose=verbose_level)

# post processing the numbers
perf_tflops = lambda us: 2 * M * N * K * 1e-12 / (us * 1e-6)
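As a sanity check on that formula (illustrative numbers, not from this PR): for M = N = K = 4096 and a measured time of 300 us, 2*M*N*K = 1.37e11 FLOPs, so 1.37e11 * 1e-12 / (300 * 1e-6) ≈ 458 TFLOPS.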
@@ -727,7 +748,7 @@ def main():
# Check correctness if asked to
if args.compare:
print("correctness: ", end=" ", flush=True)
test_correctness(M, N, K, col_a, col_b, dtype_a, dtype_b, dtype_c, bestConfig, False)
test_correctness(M, N, K, col_a, col_b, dtype_a, dtype_b, dtype_c, init_type, bestConfig, False)
elif not run_bench:
print("", flush=True)
