Skip to content

Commit

Permalink
Parallelize the profile step with --ngpus
Browse files Browse the repository at this point in the history
  • Loading branch information
zhanglx13 committed Oct 14, 2023
1 parent facecc2 commit b36ec13
Showing 1 changed file with 65 additions and 29 deletions.
94 changes: 65 additions & 29 deletions scripts/amd/gemm/tune_gemm.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,15 +159,20 @@ def try_config_{configStr}(M, N, K, am, ak, bk, bn, cm, cn, dtype):
"""
return configStr, matmul_def_str

## Open a file generated_kernelMNK.py and generate
## Open {ngpus} files
## generated_kernelMNK-0.py, generated_kernelMNK-1.py, ..., generated_kernelMNK-{ngpus-1}.py
## and generate
## 1. matmul kernels of all configs
## 2. wrapper function matmul to invoke all the generated kernels
## 3. Another wraper function try_config to invoke matmul function
## 4. test_gemm to invoke
## 4.1 run try_config in parallel
## 4.2 matmul in a loop of 10 iterations
def generate_kernel(M, N, K, configs):
f_kernel = open(f'generated_kernel{M}{N}{K}.py', 'w')
def generate_kernel(M, N, K, configs, ngpus):
filenames = []
for fi in range(ngpus):
filenames.append(f"generated_kernel{M}{N}{K}-{fi}.py")
f_kernel = [open(path, 'w') for path in filenames]

### write imports
import_str = """import torch
Expand All @@ -177,20 +182,24 @@ def generate_kernel(M, N, K, configs):
import sys
import multiprocessing
"""
f_kernel.write(import_str + "\n")
for fi in range(ngpus):
f_kernel[fi].write(import_str + "\n")

### write definitions of matmul_kernel_xxx
### and matmul_xxx and try_config
with open("matmul_kernel.py") as file:
matmul_kernel_code = file.read();
idx = 0
for config in configs:
file_idx = idx % ngpus
configStr, matmul_def_str = gen_kernel_and_configStr_from_config(M, N, K, config)
## Copy the matmul_kernel with name replaced
matmul_kernel_config = matmul_kernel_code.replace("matmul_kernel", f"matmul_kernel_{configStr}")
matmul_kernel_config = matmul_kernel_config.replace("import triton.language as tl", "")
matmul_kernel_config = matmul_kernel_config.replace("import triton", "")
f_kernel.write(matmul_kernel_config + "\n\n")
f_kernel.write(matmul_def_str + "\n")
f_kernel[file_idx].write(matmul_kernel_config + "\n\n")
f_kernel[file_idx].write(matmul_def_str + "\n")
idx += 1

### write test_gemm
# pre string
Expand All @@ -206,28 +215,35 @@ def generate_kernel(M, N, K, configs):
if num_threads > 1:
"""
f_kernel.write(test_gemm_pre_str + "\n")
for fi in range(ngpus):
f_kernel[fi].write(test_gemm_pre_str + "\n")

# warm up call of all matmul functions in parallel
idx = 0
for config in configs:
configStr, _ = gen_kernel_and_configStr_from_config(M, N, K, config)
task_str = f" thread_pool.apply_async(try_config_{configStr}, args=task_args)\n"
f_kernel.write(task_str)
f_kernel[idx % ngpus].write(task_str)
idx += 1

f_kernel.write("""
threadpool_str = """
thread_pool.close()
thread_pool.join()
""")
f_kernel.write("else:")
else:"""
for fi in range(ngpus):
f_kernel[fi].write(threadpool_str)
# call all matmul_xxx functions
idx = 0
for config in configs:
configStr, _ = gen_kernel_and_configStr_from_config(M, N, K, config)
matmul_call_str = f"""
for i in range(10):
d = matmul_{configStr}(a, b, c, M, N, K, a.stride(0), a.stride(1), b.stride(0), b.stride(1), c.stride(0), c.stride(1))"""
f_kernel.write(matmul_call_str + "\n")
f_kernel[idx % ngpus].write(matmul_call_str + "\n")
idx += 1
# post string
f_kernel.write(" return d\n")
for fi in range(ngpus):
f_kernel[fi].write(" return d\n")

### def main and call test_gemm
def_main_str = """
Expand All @@ -240,35 +256,48 @@ def main():
numThreads = args.n
"""
test_gemm_call_str = f'test_gemm({M}, {N}, {K}, torch.float16, numThreads)'
f_kernel.write(def_main_str)
f_kernel.write(test_gemm_call_str + "\n\n")
f_kernel.write("""if __name__ == '__main__':
sys.exit(main())""")
f_kernel.close()

def extract_kernel_time(M, N, K, config):
for fi in range(ngpus):
f_kernel[fi].write(def_main_str)
f_kernel[fi].write(test_gemm_call_str + "\n\n")
f_kernel[fi].write("""if __name__ == '__main__':
sys.exit(main())""")
f_kernel[fi].close()

def extract_kernel_time(M, N, K, config, gpuid):
configStr, _ = gen_kernel_and_configStr_from_config(M, N, K, config)
parse_result_cmd = f'sed -n \'/matmul_kernel_{configStr}/p\' results.csv | awk -F \',\' \'{{print $NF}}\' | tail -n1'
parse_result_cmd = f'sed -n \'/matmul_kernel_{configStr}/p\' results-{gpuid}.csv | awk -F \',\' \'{{print $NF}}\' | tail -n1'
parsed_outputs = run_bash_command(parse_result_cmd)
return config, parsed_outputs


def profile_batch_kernels(M, N, K, gpuid):
os.environ['ROCR_VISIBLE_DEVICES'] = str(gpuid)
run_bash_command(f"rocprof --stats -o results-{gpuid}.csv python generated_kernel{M}{N}{K}-{gpuid}.py")


def tune_gemm_config(M, N, K, configs, verbose=False, num_threads=16, ngpus = 1):
## Generate kernel out of all configs
generate_kernel(M, N, K, configs)
generate_kernel(M, N, K, configs, ngpus)

## remove any compiled kernel in the cache
run_bash_command("rm -rf ~/.triton/cache")

## precompile the kernels in parallel
## TODO: parameterize numThreads at this level
start_time = datetime.now()
run_bash_command(f"python generated_kernel{M}{N}{K}.py -n 32")
for fi in range(ngpus):
run_bash_command(f"python generated_kernel{M}{N}{K}-{fi}.py -n 32")
compile_time = datetime.now()
if verbose:
print(f"compile time: {compile_time - start_time}")

## profile generated kernels
run_bash_command(f"rocprof --stats python generated_kernel{M}{N}{K}.py")
running = [multiprocessing.Process(target=profile_batch_kernels, args=(M,N,K,fi)) for fi in range(ngpus)]
for p in running:
p.start()
for p in running:
p.join()

profile_time = datetime.now()
if verbose:
print(f"profile time: {profile_time - compile_time}")
Expand All @@ -278,8 +307,11 @@ def tune_gemm_config(M, N, K, configs, verbose=False, num_threads=16, ngpus = 1)
minTime = 1024 * 1024 * 1024
thread_pool = multiprocessing.Pool(processes=num_threads)
tasks = []
idx = 0
for config in configs:
tasks += [thread_pool.apply_async(extract_kernel_time, args=(M, N, K, config))]
file_idx = idx % ngpus
tasks += [thread_pool.apply_async(extract_kernel_time, args=(M, N, K, config, file_idx))]
idx += 1
thread_pool.close()
thread_pool.join()

Expand Down Expand Up @@ -372,11 +404,13 @@ def parse_args():
parser.add_argument("-m", type=int, default=0)
parser.add_argument("-n", type=int, default=0)
parser.add_argument("-k", type=int, default=0)
parser.add_argument("--ngpus", type=int, default=1, help='number of GPUs used in the profiling step')
parser.add_argument("--gemm_size_file", type=str, default="", help='yaml file to indicate matrix size')
parser.add_argument("--tuning_results_file", type=str, default=get_default_tuning_result_filename(), help='yaml file to store tuning results')
parser.add_argument("--keep", action='store_true', default=False, help='keep generated files')
parser.add_argument("--compare", action='store_true', default=False, help="Whether check result correctness")
parser.add_argument("--compare_wo_tuning", action='store_true', default=False, help="Whether check result correctness")
parser.add_argument("--time_breakdown", action='store_true', default=False, help="Show detailed time breakdown of each step during the tuning")
args = parser.parse_args()

return args
Expand All @@ -387,6 +421,7 @@ def main():
matrix_size_file = args.gemm_size_file
tuning_output_file = args.tuning_results_file
keepTmp = args.keep
ngpus = args.ngpus

mnks = []
## TODO: make it more robust to get user input
Expand Down Expand Up @@ -429,7 +464,7 @@ def main():
print(f"{size_str} nConfigs: {len(pruned_configs)}", end=" ", flush=True)

## The main tuning funtion for one gemm size
minTime, bestConfig = tune_gemm_config(M, N, K, pruned_configs, ngpus = 2, verbose=True)
minTime, bestConfig = tune_gemm_config(M, N, K, pruned_configs, ngpus = ngpus, verbose=args.time_breakdown)

## post processing the numbers
perf_tflops = lambda us: 2 * M * N * K * 1e-12 / (us * 1e-6)
Expand All @@ -451,9 +486,10 @@ def main():

## remove generated files if asked to
if not keepTmp:
os.remove(f"generated_kernel{M}{N}{K}.py")
for f in glob.glob("results.*"):
os.remove(f)
for fi in range(ngpus):
os.remove(f"generated_kernel{M}{N}{K}-{fi}.py")
for f in glob.glob(f"results-{fi}.*"):
os.remove(f)

## Check correctness if asked to
if args.compare:
Expand Down

0 comments on commit b36ec13

Please sign in to comment.