From 609ebe962fd928ca0476ec9c59eca09500e5eb99 Mon Sep 17 00:00:00 2001
From: Your Name <2535184404@qq.com>
Date: Thu, 26 Dec 2024 09:32:35 +0800
Subject: [PATCH] runtime: rename get_triton_config to get_tuned_config and
 get_heuristics_config to get_heuristic_config

Rename the two config accessors so their names say what they return:
get_tuned_config() yields the autotune config list for an op, while
get_heuristic_config() yields the heuristics table consumed by
@triton.heuristics. Update the definitions in flag_gems.runtime and
every call site under src/flag_gems/ops, including the code-generating
call sites in gather.py and scatter.py.

In the NVIDIA heuristics table, replace the remaining lambda values
with named functions for consistency; @triton.heuristics calls every
value with the kernel's argument dict, so each entry must stay callable.

---
 src/flag_gems/ops/addmm.py                     |  2 +-
 src/flag_gems/ops/all.py                       |  2 +-
 src/flag_gems/ops/amax.py                      |  2 +-
 src/flag_gems/ops/any.py                       |  2 +-
 src/flag_gems/ops/argmax.py                    |  2 +-
 src/flag_gems/ops/bmm.py                       |  4 ++--
 src/flag_gems/ops/count_nonzero.py             |  4 ++--
 src/flag_gems/ops/cross_entropy_loss.py        | 14 +++++++-------
 src/flag_gems/ops/dropout.py                   |  4 ++--
 src/flag_gems/ops/exponential_.py              |  2 +-
 src/flag_gems/ops/gather.py                    |  2 +-
 src/flag_gems/ops/index_select.py              |  2 +-
 src/flag_gems/ops/instancenorm.py              | 12 ++++++------
 src/flag_gems/ops/layernorm.py                 | 10 +++++-----
 src/flag_gems/ops/log_softmax.py               |  2 +-
 src/flag_gems/ops/masked_fill.py               |  4 ++--
 src/flag_gems/ops/masked_select.py             |  2 +-
 src/flag_gems/ops/max.py                       |  2 +-
 src/flag_gems/ops/mean.py                      |  2 +-
 src/flag_gems/ops/min.py                       |  2 +-
 src/flag_gems/ops/mm.py                        |  4 ++--
 src/flag_gems/ops/multinomial.py               |  2 +-
 src/flag_gems/ops/mv.py                        |  2 +-
 src/flag_gems/ops/nonzero.py                   |  2 +-
 src/flag_gems/ops/prod.py                      |  2 +-
 src/flag_gems/ops/rand.py                      |  2 +-
 src/flag_gems/ops/randn.py                     |  2 +-
 src/flag_gems/ops/randperm.py                  |  2 +-
 src/flag_gems/ops/scatter.py                   |  2 +-
 src/flag_gems/ops/softmax.py                   | 12 ++++++------
 src/flag_gems/ops/sum.py                       |  2 +-
 src/flag_gems/ops/triu.py                      |  4 ++--
 src/flag_gems/ops/uniform.py                   |  2 +-
 src/flag_gems/ops/upsample_bicubic2d_aa.py     |  4 ++--
 src/flag_gems/ops/upsample_nearest2d.py        |  4 ++--
 src/flag_gems/ops/var_mean.py                  |  4 ++--
 src/flag_gems/ops/vector_norm.py               | 10 +++++-----
 src/flag_gems/ops/vstack.py                    |  2 +-
 src/flag_gems/ops/weightnorm.py                | 12 ++++++------
 src/flag_gems/runtime/__init__.py              |  8 ++++----
 src/flag_gems/runtime/backend/__init__.py      |  2 +-
 .../backend/_nvidia/heuristics_config_utils.py | 24 ++++++++++++++++++------
 src/flag_gems/runtime/configloader.py          |  7 +++----
 43 files changed, 105 insertions(+), 90 deletions(-)

diff --git a/src/flag_gems/ops/addmm.py b/src/flag_gems/ops/addmm.py
index fc7785df1..35838a58b 100644
--- a/src/flag_gems/ops/addmm.py
+++ b/src/flag_gems/ops/addmm.py
@@ -12,7 +12,7 @@

 @libentry()
 @triton.autotune(
-    configs=runtime.get_triton_config("addmm"),
+    configs=runtime.get_tuned_config("addmm"),
     key=["M", "N", "K"],
 )
 @triton.jit(do_not_specialize=["alpha", "beta"])
diff --git a/src/flag_gems/ops/all.py b/src/flag_gems/ops/all.py
index 344ba617f..de4aab288 100644
--- a/src/flag_gems/ops/all.py
+++ b/src/flag_gems/ops/all.py
@@ -21,7 +21,7 @@ def reduce_all(a, b):


 @libentry()
-@triton.autotune(configs=runtime.get_triton_config("all"), key=["M", "N"])
+@triton.autotune(configs=runtime.get_tuned_config("all"), key=["M", "N"])
 @triton.jit
 def all_kernel_dim(
     inp,
diff --git a/src/flag_gems/ops/amax.py b/src/flag_gems/ops/amax.py
index 975c57e65..bfce1c2a3 100644
--- a/src/flag_gems/ops/amax.py
+++ b/src/flag_gems/ops/amax.py
@@ -42,7 +42,7 @@ def amax_kernel_2(mid, out, mid_size, BLOCK_MID: tl.constexpr):


 @libentry()
-@triton.autotune(configs=runtime.get_triton_config("amax"), key=["M", "N"])
+@triton.autotune(configs=runtime.get_tuned_config("amax"), key=["M", "N"])
 @triton.jit
 def amax_kernel(
     inp,
diff --git a/src/flag_gems/ops/any.py b/src/flag_gems/ops/any.py
index 150ba5353..e3bb9e25e 100644
--- a/src/flag_gems/ops/any.py
+++ b/src/flag_gems/ops/any.py
@@ -21,7 +21,7 @@ def reduce_any(a, b):


 @libentry()
-@triton.autotune(configs=runtime.get_triton_config("any"), key=["M", "N"])
+@triton.autotune(configs=runtime.get_tuned_config("any"), key=["M", "N"])
 @triton.jit
 def any_kernel_dim(
     inp,
diff --git a/src/flag_gems/ops/argmax.py b/src/flag_gems/ops/argmax.py
index e427a562a..c03b28081 100644
--- a/src/flag_gems/ops/argmax.py
+++ b/src/flag_gems/ops/argmax.py
@@ -47,7 +47,7 @@ def argmax_kernel_2(mid_value, mid_index, out, mid_size, BLOCK_MID: tl.constexpr


 @libentry()
-@triton.heuristics(runtime.get_heuristics_config("argmax"))
+@triton.heuristics(runtime.get_heuristic_config("argmax"))
 @triton.jit
 def argmax_kernel(
     inp,
diff --git a/src/flag_gems/ops/bmm.py b/src/flag_gems/ops/bmm.py
index 66cc1987d..65281f8ec 100644
--- a/src/flag_gems/ops/bmm.py
+++ b/src/flag_gems/ops/bmm.py
@@ -12,10 +12,10 @@

 @libentry()
 @triton.autotune(
-    configs=runtime.get_triton_config("bmm"),
+    configs=runtime.get_tuned_config("bmm"),
     key=["M", "N", "K"],
 )
-@triton.heuristics(runtime.get_heuristics_config("bmm"))
+@triton.heuristics(runtime.get_heuristic_config("bmm"))
 @triton.jit
 def bmm_kernel(
     A,
diff --git a/src/flag_gems/ops/count_nonzero.py b/src/flag_gems/ops/count_nonzero.py
index 08bd86175..76bb05ab1 100644
--- a/src/flag_gems/ops/count_nonzero.py
+++ b/src/flag_gems/ops/count_nonzero.py
@@ -23,7 +23,7 @@ def count_nonzero_kernel_1(x_ptr, out_ptr, numel, BLOCK_SIZE: tl.constexpr):


 @libentry()
-@triton.autotune(configs=runtime.get_triton_config("count_nonzero"), key=["numel"])
+@triton.autotune(configs=runtime.get_tuned_config("count_nonzero"), key=["numel"])
 @triton.jit
 def count_nonzero_kernel(x_ptr, out_ptr, N, numel, BLOCK_SIZE: tl.constexpr):
     pid_x = tle.program_id(0)
@@ -41,7 +41,7 @@ def count_nonzero_kernel(x_ptr, out_ptr, N, numel, BLOCK_SIZE: tl.constexpr):


 @libentry()
-@triton.autotune(configs=runtime.get_triton_config("count_nonzero"), key=["numel"])
+@triton.autotune(configs=runtime.get_tuned_config("count_nonzero"), key=["numel"])
 @triton.jit
 def count_nonzero_combin_kernel_1(x_ptr, out_ptr, N, numel, BLOCK_SIZE: tl.constexpr):
     pid_x = tle.program_id(0)
diff --git a/src/flag_gems/ops/cross_entropy_loss.py b/src/flag_gems/ops/cross_entropy_loss.py
index 70d56eaaf..505e9e624 100644
--- a/src/flag_gems/ops/cross_entropy_loss.py
+++ b/src/flag_gems/ops/cross_entropy_loss.py
@@ -12,7 +12,7 @@

 @libentry()
 @triton.autotune(
-    configs=runtime.get_triton_config("cross_entropy_loss"),
+    configs=runtime.get_tuned_config("cross_entropy_loss"),
     key=["C", "D"],
 )
 @triton.jit(do_not_specialize=["ignore_index"])
@@ -73,7 +73,7 @@ def celoss_indices_kernel(

 @libentry()
 @triton.autotune(
-    configs=runtime.get_triton_config("cross_entropy_loss"),
+    configs=runtime.get_tuned_config("cross_entropy_loss"),
     key=["C", "D"],
 )
 @triton.jit(do_not_specialize=["label_smoothing"])
@@ -132,7 +132,7 @@ def celoss_probability_kernel(

 @libentry()
 @triton.autotune(
-    configs=runtime.get_triton_config("cross_entropy_loss"),
+    configs=runtime.get_tuned_config("cross_entropy_loss"),
     key=["C", "D"],
 )
 @triton.jit(do_not_specialize=["ignore_index", "label_smoothing"])
@@ -213,7 +213,7 @@ def celoss_indices_smooth_kernel(

 @libentry()
 @triton.autotune(
-    configs=runtime.get_triton_config("cross_entropy_loss"),
+    configs=runtime.get_tuned_config("cross_entropy_loss"),
     key=["C", "D"],
 )
 @triton.jit(do_not_specialize=["ignore_index", "mean_num"])
@@ -284,7 +284,7 @@ def celoss_indices_bwd(

 @libentry()
 @triton.autotune(
-    configs=runtime.get_triton_config("cross_entropy_loss"),
+    configs=runtime.get_tuned_config("cross_entropy_loss"),
     key=["C", "D"],
 )
 @triton.jit(do_not_specialize=["label_smoothing", "mean_num"])
@@ -369,7 +369,7 @@ def celoss_probability_bwd(

 @libentry()
 @triton.autotune(
-    configs=runtime.get_triton_config("cross_entropy_loss"),
+    configs=runtime.get_tuned_config("cross_entropy_loss"),
     key=["C", "D"],
 )
 @triton.jit(do_not_specialize=["ignore_index", "label_smoothing", "mean_num"])
@@ -463,7 +463,7 @@ def celoss_indices_smooth_bwd(

 @libentry()
 @triton.autotune(
-    configs=runtime.get_triton_config("cross_entropy_loss_sum_and_scale"),
+    configs=runtime.get_tuned_config("cross_entropy_loss_sum_and_scale"),
     key=[
         "N",
     ],
diff --git a/src/flag_gems/ops/dropout.py b/src/flag_gems/ops/dropout.py
index 888f96b27..2bcb31a81 100644
--- a/src/flag_gems/ops/dropout.py
+++ b/src/flag_gems/ops/dropout.py
@@ -13,7 +13,7 @@
 from ..runtime import torch_device_fn


-@triton.heuristics(runtime.get_heuristics_config("dropout"))
+@triton.heuristics(runtime.get_heuristic_config("dropout"))
 @triton.jit(do_not_specialize=["p", "philox_seed", "philox_offset"])
 def dropout_forward_kernel(
     X,
@@ -65,7 +65,7 @@ def dropout_forward_kernel(
     tl.store(Y + off_3, y3, mask=off_3 < N, eviction_policy="evict_first")


-@triton.heuristics(runtime.get_heuristics_config("dropout"))
+@triton.heuristics(runtime.get_heuristic_config("dropout"))
 @triton.jit(do_not_specialize=["p", "philox_seed", "philox_offset"])
 def dropout_backward_kernel(
     DY,
diff --git a/src/flag_gems/ops/exponential_.py b/src/flag_gems/ops/exponential_.py
index 977ae097f..8d36ce7eb 100644
--- a/src/flag_gems/ops/exponential_.py
+++ b/src/flag_gems/ops/exponential_.py
@@ -13,7 +13,7 @@
 from ..runtime import torch_device_fn


-@triton.heuristics(runtime.get_heuristics_config("exponential_"))
+@triton.heuristics(runtime.get_heuristic_config("exponential_"))
 @triton.jit(do_not_specialize=["philox_seed", "philox_offset", "N"])
 def fused_exponential_kernel(
     out_ptr,
diff --git a/src/flag_gems/ops/gather.py b/src/flag_gems/ops/gather.py
index 0feb25198..c0c16124e 100644
--- a/src/flag_gems/ops/gather.py
+++ b/src/flag_gems/ops/gather.py
@@ -36,7 +36,7 @@ def generate_gather_kernel(
     code.writeline("@libentry()")
     code.writeline("@triton.heuristics(")
     with code.indent():
-        code.writeline("runtime.get_heuristics_config('gather')")
+        code.writeline("runtime.get_heuristic_config('gather')")
     code.writeline(")")
     code.writeline("@triton.jit")

diff --git a/src/flag_gems/ops/index_select.py b/src/flag_gems/ops/index_select.py
index f74d77e5d..732a5af16 100644
--- a/src/flag_gems/ops/index_select.py
+++ b/src/flag_gems/ops/index_select.py
@@ -10,7 +10,7 @@


 @libentry()
-@triton.heuristics(runtime.get_heuristics_config("index_select"))
+@triton.heuristics(runtime.get_heuristic_config("index_select"))
 @triton.jit
 def index_select_kernel(
     inp, out, M, N, index, index_len, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr
diff --git a/src/flag_gems/ops/instancenorm.py b/src/flag_gems/ops/instancenorm.py
index cf901bf70..8a6343915 100644
--- a/src/flag_gems/ops/instancenorm.py
+++ b/src/flag_gems/ops/instancenorm.py
@@ -22,7 +22,7 @@ def prev_multiple_of(a, b):

 @libentry()
 @triton.autotune(
-    configs=runtime.get_triton_config("instancenorm"),
+    configs=runtime.get_tuned_config("instancenorm"),
     key=["M", "N"],
 )
 @triton.jit(do_not_specialize=["eps"])
@@ -72,7 +72,7 @@ def instance_norm_persistent_kernel(

 @libentry()
 @triton.autotune(
-    configs=runtime.get_triton_config("instancenorm"),
+    configs=runtime.get_tuned_config("instancenorm"),
     key=["M", "N"],
 )
 @triton.jit(do_not_specialize=["eps"])
@@ -126,7 +126,7 @@ def instance_norm_persistent_kernel_multiline(

 @libentry()
 @triton.autotune(
-    configs=runtime.get_triton_config("instance_norm_loop"),
+    configs=runtime.get_tuned_config("instance_norm_loop"),
     key=["M", "N"],
 )
 @triton.jit(do_not_specialize=["eps"])
@@ -218,7 +218,7 @@ def instance_norm_loop_kernel(

 @libentry()
 @triton.autotune(
-    configs=runtime.get_triton_config("instancenorm"),
+    configs=runtime.get_tuned_config("instancenorm"),
     key=["M", "N"],
 )
 @triton.jit(do_not_specialize=["eps"])
@@ -312,7 +312,7 @@ def update_running_stats_kernel(

 @libentry()
 @triton.autotune(
-    configs=runtime.get_triton_config("instance_norm_backward"),
+    configs=runtime.get_tuned_config("instance_norm_backward"),
     key=["M", "N", "C"],
 )
 @triton.jit
@@ -379,7 +379,7 @@ def instance_norm_backward_kernel(

 @libentry()
 @triton.autotune(
-    configs=runtime.get_triton_config("instance_norm_weight_bias_backward"),
+    configs=runtime.get_tuned_config("instance_norm_weight_bias_backward"),
     key=["N", "B", "C"],
 )
 @triton.jit
diff --git a/src/flag_gems/ops/layernorm.py b/src/flag_gems/ops/layernorm.py
index 425854ccb..6e3a30ab0 100644
--- a/src/flag_gems/ops/layernorm.py
+++ b/src/flag_gems/ops/layernorm.py
@@ -20,7 +20,7 @@ def prev_multiple_of(a, b):

 @libentry()
 @triton.autotune(
-    configs=runtime.get_triton_config("layer_norm_persistent"),
+    configs=runtime.get_tuned_config("layer_norm_persistent"),
     key=["M", "N"],
 )
 @triton.jit(do_not_specialize=["eps"])
@@ -69,7 +69,7 @@ def layer_norm_persistent_kernel(

 @libentry()
 @triton.autotune(
-    configs=runtime.get_triton_config("layer_norm_persistent"),
+    configs=runtime.get_tuned_config("layer_norm_persistent"),
     key=["M", "N"],
 )
 @triton.jit(do_not_specialize=["eps"])
@@ -123,7 +123,7 @@ def layer_norm_persistent_kernel_multiline(

 @libentry()
 @triton.autotune(
-    configs=runtime.get_triton_config("layer_norm_loop"),
+    configs=runtime.get_tuned_config("layer_norm_loop"),
     key=["M", "N"],
 )
 @triton.jit(do_not_specialize=["eps"])
@@ -220,7 +220,7 @@ def layer_norm_loop_kernel(

 @libentry()
 @triton.autotune(
-    configs=runtime.get_triton_config("layer_norm_backward"),
+    configs=runtime.get_tuned_config("layer_norm_backward"),
     key=["M", "N"],
 )
 @triton.jit
@@ -288,7 +288,7 @@ def layer_norm_backward_kernel(

 @libentry()
 @triton.autotune(
-    configs=runtime.get_triton_config("weight_bias_backward"),
+    configs=runtime.get_tuned_config("weight_bias_backward"),
     key=["N"],
 )
 @triton.jit
diff --git a/src/flag_gems/ops/log_softmax.py b/src/flag_gems/ops/log_softmax.py
index 7c49a0c12..c3b331ede 100644
--- a/src/flag_gems/ops/log_softmax.py
+++ b/src/flag_gems/ops/log_softmax.py
@@ -54,7 +54,7 @@ def log_softmax_kernel(


 @libentry()
-@triton.autotune(configs=runtime.get_triton_config("log_softmax"), key=["M", "N"])
+@triton.autotune(configs=runtime.get_tuned_config("log_softmax"), key=["M", "N"])
 @triton.jit
 def log_softmax_backward_kernel(
     out_ptr,
diff --git a/src/flag_gems/ops/masked_fill.py b/src/flag_gems/ops/masked_fill.py
index f1c010e84..8a2ebded1 100644
--- a/src/flag_gems/ops/masked_fill.py
+++ b/src/flag_gems/ops/masked_fill.py
@@ -10,7 +10,7 @@


 @libentry()
-@triton.autotune(configs=runtime.get_triton_config("masked_fill"), key=["N"])
+@triton.autotune(configs=runtime.get_tuned_config("masked_fill"), key=["N"])
 @triton.jit
 def masked_fill_kernel(inp, expand_mask, value, out, N, BLOCK_SIZE: tl.constexpr):
     pid = tle.program_id(axis=0)
@@ -24,7 +24,7 @@ def masked_fill_kernel(inp, expand_mask, value, out, N, BLOCK_SIZE: tl.constexpr


 @libentry()
-@triton.autotune(configs=runtime.get_triton_config("masked_fill"), key=["N"])
+@triton.autotune(configs=runtime.get_tuned_config("masked_fill"), key=["N"])
 @triton.jit
 def masked_fill_kernel_self(inp, expand_mask, value, N, BLOCK_SIZE: tl.constexpr):
     pid = tle.program_id(axis=0)
diff --git a/src/flag_gems/ops/masked_select.py b/src/flag_gems/ops/masked_select.py
index c98932bde..db6aef58f 100644
--- a/src/flag_gems/ops/masked_select.py
+++ b/src/flag_gems/ops/masked_select.py
@@ -11,7 +11,7 @@


 @libentry()
-@triton.autotune(configs=runtime.get_triton_config("masked_select"), key=["n_elements"])
+@triton.autotune(configs=runtime.get_tuned_config("masked_select"), key=["n_elements"])
 @triton.jit
 def masked_select_kernel(
     inp_ptr,
diff --git a/src/flag_gems/ops/max.py b/src/flag_gems/ops/max.py
index 763751b3a..ba0b25d5f 100644
--- a/src/flag_gems/ops/max.py
+++ b/src/flag_gems/ops/max.py
@@ -47,7 +47,7 @@ def heur_block_n(args):

 @libentry()
 @triton.autotune(
-    configs=runtime.get_triton_config("max"),
+    configs=runtime.get_tuned_config("max"),
     key=[
         "M",
         "N",
diff --git a/src/flag_gems/ops/mean.py b/src/flag_gems/ops/mean.py
index e3aba6864..64e822e49 100644
--- a/src/flag_gems/ops/mean.py
+++ b/src/flag_gems/ops/mean.py
@@ -60,7 +60,7 @@ def mean(inp, *, dtype=None):

 @libentry()
 @triton.autotune(
-    configs=runtime.get_triton_config("mean"),
+    configs=runtime.get_tuned_config("mean"),
     key=["M", "N"],
 )
 @triton.jit
diff --git a/src/flag_gems/ops/min.py b/src/flag_gems/ops/min.py
index 64d28bdaf..9426853b7 100644
--- a/src/flag_gems/ops/min.py
+++ b/src/flag_gems/ops/min.py
@@ -47,7 +47,7 @@ def heur_block_n(args):

 @libentry()
 @triton.autotune(
-    configs=runtime.get_triton_config("min"),
+    configs=runtime.get_tuned_config("min"),
     key=[
         "M",
         "N",
diff --git a/src/flag_gems/ops/mm.py b/src/flag_gems/ops/mm.py
index 1768f2369..40a3a66a4 100644
--- a/src/flag_gems/ops/mm.py
+++ b/src/flag_gems/ops/mm.py
@@ -12,10 +12,10 @@

 @libentry()
 @libtuner(
-    configs=runtime.get_triton_config("mm"),
+    configs=runtime.get_tuned_config("mm"),
     key=["M", "N", "K"],
 )
-@triton.heuristics(runtime.get_heuristics_config("mm"))
+@triton.heuristics(runtime.get_heuristic_config("mm"))
 @triton.jit
 def mm_kernel(
     A,
diff --git a/src/flag_gems/ops/multinomial.py b/src/flag_gems/ops/multinomial.py
index 5351c1275..70d7a3c5f 100644
--- a/src/flag_gems/ops/multinomial.py
+++ b/src/flag_gems/ops/multinomial.py
@@ -11,7 +11,7 @@


 @libentry()
-@triton.heuristics(runtime.get_heuristics_config("multinomial"))
+@triton.heuristics(runtime.get_heuristic_config("multinomial"))
 @triton.jit(do_not_specialize=["K", "N", "philox_seed", "philox_offset"])
 def multinomial_with_replacement(
     cdf_ptr, out_ptr, K, N, philox_seed, philox_offset, NBLOCK: tl.constexpr
diff --git a/src/flag_gems/ops/mv.py b/src/flag_gems/ops/mv.py
index 723d94292..1220fad12 100644
--- a/src/flag_gems/ops/mv.py
+++ b/src/flag_gems/ops/mv.py
@@ -12,7 +12,7 @@

 @libentry()
 @triton.autotune(
-    configs=runtime.get_triton_config("mv"),
+    configs=runtime.get_tuned_config("mv"),
     key=["M", "N"],
 )
 @triton.jit
diff --git a/src/flag_gems/ops/nonzero.py b/src/flag_gems/ops/nonzero.py
index 0578b4fd9..2b4d20181 100644
--- a/src/flag_gems/ops/nonzero.py
+++ b/src/flag_gems/ops/nonzero.py
@@ -12,7 +12,7 @@

 @libentry()
 @triton.autotune(
-    configs=runtime.get_triton_config("nonzero"),
+    configs=runtime.get_tuned_config("nonzero"),
     key=[
         "n_elements",
     ],
diff --git a/src/flag_gems/ops/prod.py b/src/flag_gems/ops/prod.py
index ca88996dc..a2b631344 100644
--- a/src/flag_gems/ops/prod.py
+++ b/src/flag_gems/ops/prod.py
@@ -70,7 +70,7 @@ def heur_block_n(args):

 @libentry()
 @triton.autotune(
-    configs=runtime.get_triton_config("prod"),
configs=runtime.get_tuned_config("prod"), key=[ "M", "N", diff --git a/src/flag_gems/ops/rand.py b/src/flag_gems/ops/rand.py index 51ee05ae8..9cab927ac 100644 --- a/src/flag_gems/ops/rand.py +++ b/src/flag_gems/ops/rand.py @@ -16,7 +16,7 @@ device_ = device -@triton.heuristics(runtime.get_heuristics_config("rand")) +@triton.heuristics(runtime.get_heuristic_config("rand")) @triton.jit(do_not_specialize=["philox_seed", "philox_offset"]) def rand_kernel( out_ptr, diff --git a/src/flag_gems/ops/randn.py b/src/flag_gems/ops/randn.py index 6857c1a97..3ee2932e5 100644 --- a/src/flag_gems/ops/randn.py +++ b/src/flag_gems/ops/randn.py @@ -29,7 +29,7 @@ def pair_uniform_to_normal(u1, u2): device_ = device -@triton.heuristics(runtime.get_heuristics_config("randn")) +@triton.heuristics(runtime.get_heuristic_config("randn")) @triton.jit(do_not_specialize=["philox_seed", "philox_offset"]) def randn_kernel( out_ptr, diff --git a/src/flag_gems/ops/randperm.py b/src/flag_gems/ops/randperm.py index 46377d757..9757a1adc 100644 --- a/src/flag_gems/ops/randperm.py +++ b/src/flag_gems/ops/randperm.py @@ -155,7 +155,7 @@ def digit_hist_kernel( @libentry() @triton.autotune( - configs=runtime.get_triton_config("randperm"), + configs=runtime.get_tuned_config("randperm"), key=["n_elements"], ) @triton.jit diff --git a/src/flag_gems/ops/scatter.py b/src/flag_gems/ops/scatter.py index 1d34a829e..62f2b1654 100644 --- a/src/flag_gems/ops/scatter.py +++ b/src/flag_gems/ops/scatter.py @@ -38,7 +38,7 @@ def generate_scatter_kernel( # the decorators code.writeline("@libentry()") code.writeline( - '@triton.autotune(configs=runtime.get_triton_config("scatter"), key=["M", "N"])' + '@triton.autotune(configs=runtime.get_tuned_config("scatter"), key=["M", "N"])' ) code.writeline("@triton.jit") diff --git a/src/flag_gems/ops/softmax.py b/src/flag_gems/ops/softmax.py index ab979b70f..e5dbda584 100644 --- a/src/flag_gems/ops/softmax.py +++ b/src/flag_gems/ops/softmax.py @@ -11,7 +11,7 @@ @libentry() -@triton.heuristics(runtime.get_heuristics_config("softmax_non_inner")) +@triton.heuristics(runtime.get_heuristic_config("softmax_non_inner")) @triton.jit def softmax_kernel_non_inner( output_ptr, @@ -83,7 +83,7 @@ def prev_multiple_of(a, b): @libentry() -@triton.heuristics(runtime.get_heuristics_config("softmax_inner")) +@triton.heuristics(runtime.get_heuristic_config("softmax_inner")) @triton.jit def softmax_kernel_inner( output_ptr, @@ -160,14 +160,14 @@ def softmax_kernel_inner( # ------------------------ backward ------------------------------- @libentry() @triton.autotune( - configs=runtime.get_triton_config("softmax_non_inner"), + configs=runtime.get_tuned_config("softmax_non_inner"), key=[ "M", "N", "K", ], ) -@triton.heuristics(runtime.get_heuristics_config("softmax_backward_non_inner")) +@triton.heuristics(runtime.get_heuristic_config("softmax_backward_non_inner")) @triton.jit def softmax_backward_kernel_non_inner( out_ptr, @@ -220,11 +220,11 @@ def softmax_backward_kernel_non_inner( @libentry() @triton.autotune( - configs=runtime.get_triton_config("softmax_inner"), + configs=runtime.get_tuned_config("softmax_inner"), key=["M", "N"], ) @triton.heuristics( - values=runtime.get_heuristics_config("softmax_backward_inner"), + values=runtime.get_heuristic_config("softmax_backward_inner"), ) @triton.jit def softmax_backward_kernel_inner( diff --git a/src/flag_gems/ops/sum.py b/src/flag_gems/ops/sum.py index 3194bfb90..ca47c0f21 100644 --- a/src/flag_gems/ops/sum.py +++ b/src/flag_gems/ops/sum.py @@ -56,7 +56,7 @@ def 


 @libentry()
-@triton.autotune(configs=runtime.get_triton_config("sum"), key=["M", "N"])
+@triton.autotune(configs=runtime.get_tuned_config("sum"), key=["M", "N"])
 @triton.jit
 def sum_kernel(
     inp,
diff --git a/src/flag_gems/ops/triu.py b/src/flag_gems/ops/triu.py
index 429dc87ef..71716def6 100644
--- a/src/flag_gems/ops/triu.py
+++ b/src/flag_gems/ops/triu.py
@@ -11,7 +11,7 @@


 @libentry()
-@triton.autotune(configs=runtime.get_triton_config("triu"), key=["M", "N"])
+@triton.autotune(configs=runtime.get_tuned_config("triu"), key=["M", "N"])
 @triton.jit(do_not_specialize=["diagonal"])
 def triu_kernel(
     X,
@@ -40,7 +40,7 @@ def triu_kernel(

 @libentry()
 @triton.autotune(
-    configs=runtime.get_triton_config("triu_batch"),
+    configs=runtime.get_tuned_config("triu_batch"),
     key=["batch", "MN", "N", "diagonal"],
 )
 @triton.jit(do_not_specialize=["diagonal"])
diff --git a/src/flag_gems/ops/uniform.py b/src/flag_gems/ops/uniform.py
index a215cbf6a..114a552bf 100644
--- a/src/flag_gems/ops/uniform.py
+++ b/src/flag_gems/ops/uniform.py
@@ -13,7 +13,7 @@
 from ..runtime import torch_device_fn


-@triton.heuristics(runtime.get_heuristics_config("uniform"))
+@triton.heuristics(runtime.get_heuristic_config("uniform"))
 @triton.jit(do_not_specialize=["philox_seed", "philox_offset"])
 def uniform_kernel(
     out_ptr,
diff --git a/src/flag_gems/ops/upsample_bicubic2d_aa.py b/src/flag_gems/ops/upsample_bicubic2d_aa.py
index 6c44c107c..3d772473a 100644
--- a/src/flag_gems/ops/upsample_bicubic2d_aa.py
+++ b/src/flag_gems/ops/upsample_bicubic2d_aa.py
@@ -13,7 +13,7 @@


 @triton.autotune(
-    configs=runtime.get_triton_config("upsample_bicubic2d_aa"),
+    configs=runtime.get_tuned_config("upsample_bicubic2d_aa"),
     key=["N", "C", "OH", "OW"],
 )
 @triton.jit
@@ -367,7 +367,7 @@ def upsample_bicubic2d_aa_kernel(

 # upsample and downsample
 @triton.autotune(
-    configs=runtime.get_triton_config("upsample_bicubic2d_aa"),
+    configs=runtime.get_tuned_config("upsample_bicubic2d_aa"),
     key=["N", "C", "OH", "OW"],
 )
 @triton.jit
diff --git a/src/flag_gems/ops/upsample_nearest2d.py b/src/flag_gems/ops/upsample_nearest2d.py
index 8b4f70a32..ad07494d7 100644
--- a/src/flag_gems/ops/upsample_nearest2d.py
+++ b/src/flag_gems/ops/upsample_nearest2d.py
@@ -13,9 +13,9 @@


 @triton.autotune(
-    configs=runtime.get_triton_config("upsample_nearest2d"), key=["N", "C", "OH", "OW"]
+    configs=runtime.get_tuned_config("upsample_nearest2d"), key=["N", "C", "OH", "OW"]
 )
-@triton.heuristics(runtime.get_heuristics_config("upsample_nearest2d"))
+@triton.heuristics(runtime.get_heuristic_config("upsample_nearest2d"))
 @triton.jit
 def upsample_nearest2d_kernel(
     ptr_o,
diff --git a/src/flag_gems/ops/var_mean.py b/src/flag_gems/ops/var_mean.py
index 93fef7eb0..985e0c58e 100644
--- a/src/flag_gems/ops/var_mean.py
+++ b/src/flag_gems/ops/var_mean.py
@@ -22,7 +22,7 @@ def welford_func(mean_x, count_x, M_x, mean_y, count_y, M_y):


 @libentry()
-@triton.autotune(configs=runtime.get_triton_config("var_mean"), key=["M", "N"])
+@triton.autotune(configs=runtime.get_tuned_config("var_mean"), key=["M", "N"])
 @triton.jit(do_not_specialize=["correction"])
 def var_mean_welford_kernel(
     X,
@@ -99,7 +99,7 @@ def var_mean_kernel_1(


 @libentry()
-@triton.heuristics(runtime.get_heuristics_config("var_mean"))
+@triton.heuristics(runtime.get_heuristic_config("var_mean"))
 @triton.jit(do_not_specialize=["correction"])
 def var_mean_kernel_2(
     Acc,
diff --git a/src/flag_gems/ops/vector_norm.py b/src/flag_gems/ops/vector_norm.py
index 3542a18c8..bd87e640b 100644
--- a/src/flag_gems/ops/vector_norm.py
+++ b/src/flag_gems/ops/vector_norm.py
@@ -14,7 +14,7 @@


 @libentry()
-@triton.autotune(configs=runtime.get_triton_config("vector_norm"), key=["M", "N"])
+@triton.autotune(configs=runtime.get_tuned_config("vector_norm"), key=["M", "N"])
 @triton.jit
 def l2_norm_kernel(X, Out, M, N, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr):
     pid = tle.program_id(0).to(tl.int64) * BLOCK_M + tl.arange(0, BLOCK_M)[:, None]
@@ -62,7 +62,7 @@ def l2_norm_kernel_2(Mid, Out, MID_SIZE, BLOCK_MID: tl.constexpr):


 @libentry()
-@triton.autotune(configs=runtime.get_triton_config("vector_norm"), key=["M", "N"])
+@triton.autotune(configs=runtime.get_tuned_config("vector_norm"), key=["M", "N"])
 @triton.jit
 def max_norm_kernel(X, Out, M, N, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr):
     pid = tle.program_id(0).to(tl.int64) * BLOCK_M + tl.arange(0, BLOCK_M)[:, None]
@@ -110,7 +110,7 @@ def max_norm_kernel_2(Mid, Out, MID_SIZE, BLOCK_MID: tl.constexpr):


 @libentry()
-@triton.autotune(configs=runtime.get_triton_config("vector_norm"), key=["M", "N"])
+@triton.autotune(configs=runtime.get_tuned_config("vector_norm"), key=["M", "N"])
 @triton.jit
 def min_norm_kernel(X, Out, M, N, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr):
     pid = tle.program_id(0).to(tl.int64) * BLOCK_M + tl.arange(0, BLOCK_M)[:, None]
@@ -158,7 +158,7 @@ def min_norm_kernel_2(Mid, Out, MID_SIZE, BLOCK_MID: tl.constexpr):


 @libentry()
-@triton.autotune(configs=runtime.get_triton_config("vector_norm"), key=["M", "N"])
+@triton.autotune(configs=runtime.get_tuned_config("vector_norm"), key=["M", "N"])
 @triton.jit
 def l0_norm_kernel(X, Out, M, N, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr):
     pid = tle.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M)[:, None]
@@ -206,7 +206,7 @@ def l0_norm_kernel_2(Mid, Out, MID_SIZE, BLOCK_MID: tl.constexpr):


 @libentry()
-@triton.autotune(configs=runtime.get_triton_config("vector_norm"), key=["M", "N"])
+@triton.autotune(configs=runtime.get_tuned_config("vector_norm"), key=["M", "N"])
 @triton.jit(do_not_specialize=["ord"])
 def v_norm_kernel(X, Out, M, N, ord, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr):
     pid = tle.program_id(0).to(tl.int64) * BLOCK_M + tl.arange(0, BLOCK_M)[:, None]
diff --git a/src/flag_gems/ops/vstack.py b/src/flag_gems/ops/vstack.py
index 55991ec33..8e271d119 100644
--- a/src/flag_gems/ops/vstack.py
+++ b/src/flag_gems/ops/vstack.py
@@ -12,7 +12,7 @@

 @libentry()
 @triton.autotune(
-    configs=runtime.get_triton_config("vstack"),
+    configs=runtime.get_tuned_config("vstack"),
     key=[
         "max_tile_elems",
     ],
diff --git a/src/flag_gems/ops/weightnorm.py b/src/flag_gems/ops/weightnorm.py
index 8810c7b14..9c9728534 100644
--- a/src/flag_gems/ops/weightnorm.py
+++ b/src/flag_gems/ops/weightnorm.py
@@ -13,7 +13,7 @@

 @libentry()
 @triton.autotune(
-    configs=runtime.get_triton_config("weight_norm_kernel_last"), key=["M", "N"]
+    configs=runtime.get_tuned_config("weight_norm_kernel_last"), key=["M", "N"]
 )
 @triton.jit(do_not_specialize=["eps"])
 def weight_norm_kernel_last(
@@ -55,7 +55,7 @@ def weight_norm_kernel_last(

 @libentry()
 @triton.autotune(
-    configs=runtime.get_triton_config("weight_norm_kernel_first"), key=["M", "N"]
+    configs=runtime.get_tuned_config("weight_norm_kernel_first"), key=["M", "N"]
 )
 @triton.jit(do_not_specialize=["eps"])
 def weight_norm_kernel_first(
@@ -97,7 +97,7 @@ def weight_norm_kernel_first(

 @libentry()
 @triton.autotune(
-    configs=runtime.get_triton_config("weight_norm_kernel_last"), key=["M", "N"]
configs=runtime.get_tuned_config("weight_norm_kernel_last"), key=["M", "N"] ) @triton.jit(do_not_specialize=["eps"]) def weight_norm_bwd_kernel_last( @@ -149,7 +149,7 @@ def weight_norm_bwd_kernel_last( @libentry() @triton.autotune( - configs=runtime.get_triton_config("weight_norm_kernel_first"), key=["M", "N"] + configs=runtime.get_tuned_config("weight_norm_kernel_first"), key=["M", "N"] ) @triton.jit(do_not_specialize=["eps"]) def weight_norm_bwd_kernel_first( @@ -274,7 +274,7 @@ def backward(ctx, w_grad, norm_grad): @libentry() @triton.autotune( - configs=runtime.get_triton_config("weight_norm_kernel"), + configs=runtime.get_tuned_config("weight_norm_kernel"), key=["v_shape0", "v_shape1", "v_shape2"], ) @triton.jit(do_not_specialize=["eps"]) @@ -329,7 +329,7 @@ def weight_norm_except_dim_kernel( @libentry() @triton.autotune( - configs=runtime.get_triton_config("weight_norm_kernel"), + configs=runtime.get_tuned_config("weight_norm_kernel"), key=["v_shape0", "v_shape1", "v_shape2"], ) @triton.jit(do_not_specialize=["eps"]) diff --git a/src/flag_gems/runtime/__init__.py b/src/flag_gems/runtime/__init__.py index b6117ec12..3ceae865f 100644 --- a/src/flag_gems/runtime/__init__.py +++ b/src/flag_gems/runtime/__init__.py @@ -17,12 +17,12 @@ torch_backend_device = backend.get_torch_backend_device_fn() -def get_triton_config(op_name): - return config_loader.get_triton_config(op_name) +def get_tuned_config(op_name): + return config_loader.get_tuned_config(op_name) -def get_heuristics_config(op_name): +def get_heuristic_config(op_name): return config_loader.heuristics_config[op_name] -__all__ = ["commom_utils", "backend", "device", "get_triton_config"] +__all__ = ["commom_utils", "backend", "device", "get_tuned_config"] diff --git a/src/flag_gems/runtime/backend/__init__.py b/src/flag_gems/runtime/backend/__init__.py index e83252f8b..2aab08c73 100644 --- a/src/flag_gems/runtime/backend/__init__.py +++ b/src/flag_gems/runtime/backend/__init__.py @@ -124,7 +124,7 @@ def get_curent_device_unused_op(vendor_name=None): return list(vendor_module.get_unused_op()) -def get_heuristics_config(vendor_name=None): +def get_heuristic_config(vendor_name=None): global vendor_module get_vendor_module(vendor_name) return vendor_module.HEURISTICS_CONFIGS diff --git a/src/flag_gems/runtime/backend/_nvidia/heuristics_config_utils.py b/src/flag_gems/runtime/backend/_nvidia/heuristics_config_utils.py index 1493bcf1d..204776b1e 100644 --- a/src/flag_gems/runtime/backend/_nvidia/heuristics_config_utils.py +++ b/src/flag_gems/runtime/backend/_nvidia/heuristics_config_utils.py @@ -187,6 +187,14 @@ def var_mean_heur_block_n(args): return triton.next_power_of_2(args["BLOCK_NUM"]) +def upsample_nearest2d_SAME_H(args): + return args["OH"] == args["IH"] + + +def upsample_nearest2d_SAME_W(args): + return args["OW"] == args["IW"] + + HEURISTICS_CONFIGS = { "argmax": { "BLOCK_M": argmax_heur_block_m, @@ -217,8 +225,8 @@ def var_mean_heur_block_n(args): "EVEN_K": mm_heur_even_k, }, "multinomial": { - "NBLOCK": lambda args: 128, - "num_warps": lambda args: 4, + "NBLOCK": 128, + "num_warps": 4, }, "rand": { "BLOCK": rand_heur_block, @@ -252,8 +260,8 @@ def var_mean_heur_block_n(args): "num_warps": uniform_heur_num_warps, }, "upsample_nearest2d": { - "SAME_H": lambda args: args["OH"] == args["IH"], - "SAME_W": lambda args: args["OW"] == args["IW"], + "SAME_H": upsample_nearest2d_SAME_H, + "SAME_W": upsample_nearest2d_SAME_W, }, "var_mean": { "BLOCK_N": var_mean_heur_block_n, diff --git a/src/flag_gems/runtime/configloader.py 
diff --git a/src/flag_gems/runtime/configloader.py b/src/flag_gems/runtime/configloader.py
index e3693611d..69341353a 100644
--- a/src/flag_gems/runtime/configloader.py
+++ b/src/flag_gems/runtime/configloader.py
@@ -22,7 +22,6 @@ def __init__(self):
         # and is reserved from being an attr for vendor customizability
         self.primitive_yaml_config = self.get_vendor_tune_config()
         self.heuristics_config = self.get_vendor_heuristics_config()
-
         # gen_key is an identifier that indicates whether the current config needs to be generated automatically
         self.gen_key = "gen"
         # loaded_triton_config is wrapped in triton.Config according to primitive_yaml_config
@@ -36,10 +35,10 @@

     def load_all(self):
         for key in self.primitive_yaml_config:
-            self.loaded_triton_config[key] = self.get_triton_config(key)
+            self.loaded_triton_config[key] = self.get_tuned_config(key)

     def get_vendor_heuristics_config(self):
-        return backend.get_heuristics_config(self.device.vendor_name)
+        return backend.get_heuristic_config(self.device.vendor_name)

     def get_vendor_tune_config(self):
         return backend.get_tune_config(self.device.vendor_name)
@@ -108,7 +107,7 @@ def to_gen_config(self, gen_config):
             current_config,
         )

-    def get_triton_config(self, op_name):
+    def get_tuned_config(self, op_name):
         if op_name in self.loaded_triton_config:
             return self.loaded_triton_config[op_name]
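
Reviewer note, appended after the patch body (not applied by git am):
the multinomial and upsample_nearest2d entries above are kept callable
because @triton.heuristics invokes every value with a dict of the
kernel's launch arguments. A minimal standalone sketch of that contract
follows; it is illustrative only, so the kernel, the heur_block rule,
and the sizes are invented here and are not FlagGems code.

    # Each heuristics value must be a callable taking the kernel's
    # argument mapping; a bare constant such as 128 would fail at launch
    # with "TypeError: 'int' object is not callable".
    import torch
    import triton
    import triton.language as tl


    def heur_block(args):
        # args maps parameter names to launch values, e.g. args["N"] == 1000;
        # tl.arange needs a power-of-two extent, hence next_power_of_2.
        return triton.next_power_of_2(min(args["N"], 1024))


    @triton.heuristics({"BLOCK": heur_block})
    @triton.jit
    def scale_kernel(x_ptr, y_ptr, N, scale, BLOCK: tl.constexpr):
        offs = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
        mask = offs < N
        tl.store(y_ptr + offs, tl.load(x_ptr + offs, mask=mask) * scale, mask=mask)


    x = torch.randn(1000, device="cuda")
    y = torch.empty_like(x)
    # A single program suffices because heur_block picks BLOCK = 1024 >= N.
    scale_kernel[(1,)](x, y, 1000, 2.0)

Named module-level functions (as this patch uses) behave identically to
lambdas here; the rename-friendly, picklable style just matches the rest
of heuristics_config_utils.py.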