Commit

modify
Galaxy1458 committed Dec 26, 2024
1 parent ffdf37c commit 609ebe9
Showing 43 changed files with 97 additions and 90 deletions.
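Every hunk below makes the same mechanical rename in the kernel decorators: the helper that supplies autotuning configs changes from runtime.get_triton_config(name) to runtime.get_tuned_config(name), and the heuristics helper changes from runtime.get_heuristics_config(name) to runtime.get_heuristic_config(name). As a rough illustration of how an op module calls the renamed autotune helper, here is a minimal sketch; the top-level import paths, kernel body, and block-size parameter names are assumptions for illustration only (the real op files use relative imports such as "from ..runtime import ...", and their kernel bodies are untouched by this commit).

import triton
import triton.language as tl

from flag_gems import runtime          # assumed import path
from flag_gems.utils import libentry   # assumed import path


@libentry()
@triton.autotune(
    configs=runtime.get_tuned_config("addmm"),  # renamed from get_triton_config
    key=["M", "N", "K"],
)
@triton.jit(do_not_specialize=["alpha", "beta"])
def addmm_sketch_kernel(
    C, OUT, alpha, beta, M, N, K,
    BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
):
    # Placeholder body so the sketch is self-contained; it just scales a strip
    # of C. The block-size names are assumed to match the tuned "addmm" config.
    pid = tl.program_id(0)
    offs = pid * BLOCK_M + tl.arange(0, BLOCK_M)
    mask = offs < M * N
    vals = tl.load(C + offs, mask=mask, other=0.0)
    tl.store(OUT + offs, alpha * vals + beta, mask=mask)

Only the config-lookup helper names change; the decorator stacks and tuning keys in the hunks below stay the same.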
2 changes: 1 addition & 1 deletion src/flag_gems/ops/addmm.py
@@ -12,7 +12,7 @@
 
 @libentry()
 @triton.autotune(
-    configs=runtime.get_triton_config("addmm"),
+    configs=runtime.get_tuned_config("addmm"),
     key=["M", "N", "K"],
 )
 @triton.jit(do_not_specialize=["alpha", "beta"])
2 changes: 1 addition & 1 deletion src/flag_gems/ops/all.py
@@ -21,7 +21,7 @@ def reduce_all(a, b):
 
 
 @libentry()
-@triton.autotune(configs=runtime.get_triton_config("all"), key=["M", "N"])
+@triton.autotune(configs=runtime.get_tuned_config("all"), key=["M", "N"])
 @triton.jit
 def all_kernel_dim(
     inp,
2 changes: 1 addition & 1 deletion src/flag_gems/ops/amax.py
@@ -42,7 +42,7 @@ def amax_kernel_2(mid, out, mid_size, BLOCK_MID: tl.constexpr):
 
 
 @libentry()
-@triton.autotune(configs=runtime.get_triton_config("amax"), key=["M", "N"])
+@triton.autotune(configs=runtime.get_tuned_config("amax"), key=["M", "N"])
 @triton.jit
 def amax_kernel(
     inp,
2 changes: 1 addition & 1 deletion src/flag_gems/ops/any.py
@@ -21,7 +21,7 @@ def reduce_any(a, b):
 
 
 @libentry()
-@triton.autotune(configs=runtime.get_triton_config("any"), key=["M", "N"])
+@triton.autotune(configs=runtime.get_tuned_config("any"), key=["M", "N"])
 @triton.jit
 def any_kernel_dim(
     inp,
2 changes: 1 addition & 1 deletion src/flag_gems/ops/argmax.py
@@ -47,7 +47,7 @@ def argmax_kernel_2(mid_value, mid_index, out, mid_size, BLOCK_MID: tl.constexpr
 
 
 @libentry()
-@triton.heuristics(runtime.get_heuristics_config("argmax"))
+@triton.heuristics(runtime.get_heuristic_config("argmax"))
 @triton.jit
 def argmax_kernel(
     inp,
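The heuristics rename follows the same pattern as in the argmax hunk above: @triton.heuristics now receives runtime.get_heuristic_config(name) instead of runtime.get_heuristics_config(name). A minimal sketch under the same assumptions as the earlier one (import paths, kernel body, and block-size parameter names are illustrative, not taken from this commit):

import triton
import triton.language as tl

from flag_gems import runtime          # assumed import path
from flag_gems.utils import libentry   # assumed import path


@libentry()
@triton.heuristics(runtime.get_heuristic_config("argmax"))  # renamed from get_heuristics_config
@triton.jit
def argmax_sketch_kernel(
    inp, out_index, M, N, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr
):
    # Placeholder row-wise argmax over a BLOCK_M x BLOCK_N tile; the real
    # argmax kernel body is unchanged by this commit.
    pid = tl.program_id(0)
    row_ids = pid * BLOCK_M + tl.arange(0, BLOCK_M)
    col_ids = tl.arange(0, BLOCK_N)
    mask = (row_ids[:, None] < M) & (col_ids[None, :] < N)
    vals = tl.load(inp + row_ids[:, None] * N + col_ids[None, :],
                   mask=mask, other=-float("inf"))
    idx = tl.argmax(vals, axis=1)
    tl.store(out_index + row_ids, idx, mask=row_ids < M)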
4 changes: 2 additions & 2 deletions src/flag_gems/ops/bmm.py
@@ -12,10 +12,10 @@
 
 @libentry()
 @triton.autotune(
-    configs=runtime.get_triton_config("bmm"),
+    configs=runtime.get_tuned_config("bmm"),
     key=["M", "N", "K"],
 )
-@triton.heuristics(runtime.get_heuristics_config("bmm"))
+@triton.heuristics(runtime.get_heuristic_config("bmm"))
 @triton.jit
 def bmm_kernel(
     A,
4 changes: 2 additions & 2 deletions src/flag_gems/ops/count_nonzero.py
@@ -23,7 +23,7 @@ def count_nonzero_kernel_1(x_ptr, out_ptr, numel, BLOCK_SIZE: tl.constexpr):
 
 
 @libentry()
-@triton.autotune(configs=runtime.get_triton_config("count_nonzero"), key=["numel"])
+@triton.autotune(configs=runtime.get_tuned_config("count_nonzero"), key=["numel"])
 @triton.jit
 def count_nonzero_kernel(x_ptr, out_ptr, N, numel, BLOCK_SIZE: tl.constexpr):
     pid_x = tle.program_id(0)
@@ -41,7 +41,7 @@ def count_nonzero_kernel(x_ptr, out_ptr, N, numel, BLOCK_SIZE: tl.constexpr):
 
 
 @libentry()
-@triton.autotune(configs=runtime.get_triton_config("count_nonzero"), key=["numel"])
+@triton.autotune(configs=runtime.get_tuned_config("count_nonzero"), key=["numel"])
 @triton.jit
 def count_nonzero_combin_kernel_1(x_ptr, out_ptr, N, numel, BLOCK_SIZE: tl.constexpr):
     pid_x = tle.program_id(0)
14 changes: 7 additions & 7 deletions src/flag_gems/ops/cross_entropy_loss.py
@@ -12,7 +12,7 @@
 
 @libentry()
 @triton.autotune(
-    configs=runtime.get_triton_config("cross_entropy_loss"),
+    configs=runtime.get_tuned_config("cross_entropy_loss"),
     key=["C", "D"],
 )
 @triton.jit(do_not_specialize=["ignore_index"])
@@ -73,7 +73,7 @@ def celoss_indices_kernel(
 
 @libentry()
 @triton.autotune(
-    configs=runtime.get_triton_config("cross_entropy_loss"),
+    configs=runtime.get_tuned_config("cross_entropy_loss"),
     key=["C", "D"],
 )
 @triton.jit(do_not_specialize=["label_smoothing"])
@@ -132,7 +132,7 @@ def celoss_probability_kernel(
 
 @libentry()
 @triton.autotune(
-    configs=runtime.get_triton_config("cross_entropy_loss"),
+    configs=runtime.get_tuned_config("cross_entropy_loss"),
     key=["C", "D"],
 )
 @triton.jit(do_not_specialize=["ignore_index", "label_smoothing"])
@@ -213,7 +213,7 @@ def celoss_indices_smooth_kernel(
 
 @libentry()
 @triton.autotune(
-    configs=runtime.get_triton_config("cross_entropy_loss"),
+    configs=runtime.get_tuned_config("cross_entropy_loss"),
     key=["C", "D"],
 )
 @triton.jit(do_not_specialize=["ignore_index", "mean_num"])
@@ -284,7 +284,7 @@ def celoss_indices_bwd(
 
 @libentry()
 @triton.autotune(
-    configs=runtime.get_triton_config("cross_entropy_loss"),
+    configs=runtime.get_tuned_config("cross_entropy_loss"),
     key=["C", "D"],
 )
 @triton.jit(do_not_specialize=["label_smoothing", "mean_num"])
@@ -369,7 +369,7 @@ def celoss_probability_bwd(
 
 @libentry()
 @triton.autotune(
-    configs=runtime.get_triton_config("cross_entropy_loss"),
+    configs=runtime.get_tuned_config("cross_entropy_loss"),
     key=["C", "D"],
 )
 @triton.jit(do_not_specialize=["ignore_index", "label_smoothing", "mean_num"])
@@ -463,7 +463,7 @@ def celoss_indices_smooth_bwd(
 
 @libentry()
 @triton.autotune(
-    configs=runtime.get_triton_config("cross_entropy_loss_sum_and_scale"),
+    configs=runtime.get_tuned_config("cross_entropy_loss_sum_and_scale"),
     key=[
         "N",
     ],
4 changes: 2 additions & 2 deletions src/flag_gems/ops/dropout.py
@@ -13,7 +13,7 @@
 from ..runtime import torch_device_fn
 
 
-@triton.heuristics(runtime.get_heuristics_config("dropout"))
+@triton.heuristics(runtime.get_heuristic_config("dropout"))
 @triton.jit(do_not_specialize=["p", "philox_seed", "philox_offset"])
 def dropout_forward_kernel(
     X,
@@ -65,7 +65,7 @@ def dropout_forward_kernel(
     tl.store(Y + off_3, y3, mask=off_3 < N, eviction_policy="evict_first")
 
 
-@triton.heuristics(runtime.get_heuristics_config("dropout"))
+@triton.heuristics(runtime.get_heuristic_config("dropout"))
 @triton.jit(do_not_specialize=["p", "philox_seed", "philox_offset"])
 def dropout_backward_kernel(
     DY,
2 changes: 1 addition & 1 deletion src/flag_gems/ops/exponential_.py
@@ -13,7 +13,7 @@
 from ..runtime import torch_device_fn
 
 
-@triton.heuristics(runtime.get_heuristics_config("exponential_"))
+@triton.heuristics(runtime.get_heuristic_config("exponential_"))
 @triton.jit(do_not_specialize=["philox_seed", "philox_offset", "N"])
 def fused_exponential_kernel(
     out_ptr,
2 changes: 1 addition & 1 deletion src/flag_gems/ops/gather.py
@@ -36,7 +36,7 @@ def generate_gather_kernel(
     code.writeline("@libentry()")
     code.writeline("@triton.heuristics(")
     with code.indent():
-        code.writeline("runtime.get_heuristics_config('gather')")
+        code.writeline("runtime.get_heuristic_config('gather')")
     code.writeline(")")
     code.writeline("@triton.jit")
 
2 changes: 1 addition & 1 deletion src/flag_gems/ops/index_select.py
@@ -10,7 +10,7 @@
 
 
 @libentry()
-@triton.heuristics(runtime.get_heuristics_config("index_select"))
+@triton.heuristics(runtime.get_heuristic_config("index_select"))
 @triton.jit
 def index_select_kernel(
     inp, out, M, N, index, index_len, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr
12 changes: 6 additions & 6 deletions src/flag_gems/ops/instancenorm.py
@@ -22,7 +22,7 @@ def prev_multiple_of(a, b):
 
 @libentry()
 @triton.autotune(
-    configs=runtime.get_triton_config("instancenorm"),
+    configs=runtime.get_tuned_config("instancenorm"),
     key=["M", "N"],
 )
 @triton.jit(do_not_specialize=["eps"])
@@ -72,7 +72,7 @@ def instance_norm_persistent_kernel(
 
 @libentry()
 @triton.autotune(
-    configs=runtime.get_triton_config("instancenorm"),
+    configs=runtime.get_tuned_config("instancenorm"),
     key=["M", "N"],
 )
 @triton.jit(do_not_specialize=["eps"])
@@ -126,7 +126,7 @@ def instance_norm_persistent_kernel_multiline(
 
 @libentry()
 @triton.autotune(
-    configs=runtime.get_triton_config("instance_norm_loop"),
+    configs=runtime.get_tuned_config("instance_norm_loop"),
     key=["M", "N"],
 )
 @triton.jit(do_not_specialize=["eps"])
@@ -218,7 +218,7 @@ def instance_norm_loop_kernel(
 
 @libentry()
 @triton.autotune(
-    configs=runtime.get_triton_config("instancenorm"),
+    configs=runtime.get_tuned_config("instancenorm"),
     key=["M", "N"],
 )
 @triton.jit(do_not_specialize=["eps"])
@@ -312,7 +312,7 @@ def update_running_stats_kernel(
 
 @libentry()
 @triton.autotune(
-    configs=runtime.get_triton_config("instance_norm_backward"),
+    configs=runtime.get_tuned_config("instance_norm_backward"),
     key=["M", "N", "C"],
 )
 @triton.jit
@@ -379,7 +379,7 @@ def instance_norm_backward_kernel(
 
 @libentry()
 @triton.autotune(
-    configs=runtime.get_triton_config("instance_norm_weight_bias_backward"),
+    configs=runtime.get_tuned_config("instance_norm_weight_bias_backward"),
     key=["N", "B", "C"],
 )
 @triton.jit
10 changes: 5 additions & 5 deletions src/flag_gems/ops/layernorm.py
@@ -20,7 +20,7 @@ def prev_multiple_of(a, b):
 
 @libentry()
 @triton.autotune(
-    configs=runtime.get_triton_config("layer_norm_persistent"),
+    configs=runtime.get_tuned_config("layer_norm_persistent"),
     key=["M", "N"],
 )
 @triton.jit(do_not_specialize=["eps"])
@@ -69,7 +69,7 @@ def layer_norm_persistent_kernel(
 
 @libentry()
 @triton.autotune(
-    configs=runtime.get_triton_config("layer_norm_persistent"),
+    configs=runtime.get_tuned_config("layer_norm_persistent"),
     key=["M", "N"],
 )
 @triton.jit(do_not_specialize=["eps"])
@@ -123,7 +123,7 @@ def layer_norm_persistent_kernel_multiline(
 
 @libentry()
 @triton.autotune(
-    configs=runtime.get_triton_config("layer_norm_loop"),
+    configs=runtime.get_tuned_config("layer_norm_loop"),
     key=["M", "N"],
 )
 @triton.jit(do_not_specialize=["eps"])
@@ -220,7 +220,7 @@ def layer_norm_loop_kernel(
 
 @libentry()
 @triton.autotune(
-    configs=runtime.get_triton_config("layer_norm_backward"),
+    configs=runtime.get_tuned_config("layer_norm_backward"),
     key=["M", "N"],
 )
 @triton.jit
@@ -288,7 +288,7 @@ def layer_norm_backward_kernel(
 
 @libentry()
 @triton.autotune(
-    configs=runtime.get_triton_config("weight_bias_backward"),
+    configs=runtime.get_tuned_config("weight_bias_backward"),
     key=["N"],
 )
 @triton.jit
2 changes: 1 addition & 1 deletion src/flag_gems/ops/log_softmax.py
@@ -54,7 +54,7 @@ def log_softmax_kernel(
 
 
 @libentry()
-@triton.autotune(configs=runtime.get_triton_config("log_softmax"), key=["M", "N"])
+@triton.autotune(configs=runtime.get_tuned_config("log_softmax"), key=["M", "N"])
 @triton.jit
 def log_softmax_backward_kernel(
     out_ptr,
4 changes: 2 additions & 2 deletions src/flag_gems/ops/masked_fill.py
@@ -10,7 +10,7 @@
 
 
 @libentry()
-@triton.autotune(configs=runtime.get_triton_config("masked_fill"), key=["N"])
+@triton.autotune(configs=runtime.get_tuned_config("masked_fill"), key=["N"])
 @triton.jit
 def masked_fill_kernel(inp, expand_mask, value, out, N, BLOCK_SIZE: tl.constexpr):
     pid = tle.program_id(axis=0)
@@ -24,7 +24,7 @@ def masked_fill_kernel(inp, expand_mask, value, out, N, BLOCK_SIZE: tl.constexpr
 
 
 @libentry()
-@triton.autotune(configs=runtime.get_triton_config("masked_fill"), key=["N"])
+@triton.autotune(configs=runtime.get_tuned_config("masked_fill"), key=["N"])
 @triton.jit
 def masked_fill_kernel_self(inp, expand_mask, value, N, BLOCK_SIZE: tl.constexpr):
     pid = tle.program_id(axis=0)
2 changes: 1 addition & 1 deletion src/flag_gems/ops/masked_select.py
@@ -11,7 +11,7 @@
 
 
 @libentry()
-@triton.autotune(configs=runtime.get_triton_config("masked_select"), key=["n_elements"])
+@triton.autotune(configs=runtime.get_tuned_config("masked_select"), key=["n_elements"])
 @triton.jit
 def masked_select_kernel(
     inp_ptr,
2 changes: 1 addition & 1 deletion src/flag_gems/ops/max.py
@@ -47,7 +47,7 @@ def heur_block_n(args):
 
 @libentry()
 @triton.autotune(
-    configs=runtime.get_triton_config("max"),
+    configs=runtime.get_tuned_config("max"),
     key=[
         "M",
         "N",
2 changes: 1 addition & 1 deletion src/flag_gems/ops/mean.py
@@ -60,7 +60,7 @@ def mean(inp, *, dtype=None):
 
 @libentry()
 @triton.autotune(
-    configs=runtime.get_triton_config("mean"),
+    configs=runtime.get_tuned_config("mean"),
     key=["M", "N"],
 )
 @triton.jit
2 changes: 1 addition & 1 deletion src/flag_gems/ops/min.py
@@ -47,7 +47,7 @@ def heur_block_n(args):
 
 @libentry()
 @triton.autotune(
-    configs=runtime.get_triton_config("min"),
+    configs=runtime.get_tuned_config("min"),
     key=[
         "M",
         "N",
4 changes: 2 additions & 2 deletions src/flag_gems/ops/mm.py
@@ -12,10 +12,10 @@
 
 @libentry()
 @libtuner(
-    configs=runtime.get_triton_config("mm"),
+    configs=runtime.get_tuned_config("mm"),
     key=["M", "N", "K"],
 )
-@triton.heuristics(runtime.get_heuristics_config("mm"))
+@triton.heuristics(runtime.get_heuristic_config("mm"))
 @triton.jit
 def mm_kernel(
     A,
2 changes: 1 addition & 1 deletion src/flag_gems/ops/multinomial.py
@@ -11,7 +11,7 @@
 
 
 @libentry()
-@triton.heuristics(runtime.get_heuristics_config("multinomial"))
+@triton.heuristics(runtime.get_heuristic_config("multinomial"))
 @triton.jit(do_not_specialize=["K", "N", "philox_seed", "philox_offset"])
 def multinomial_with_replacement(
     cdf_ptr, out_ptr, K, N, philox_seed, philox_offset, NBLOCK: tl.constexpr
2 changes: 1 addition & 1 deletion src/flag_gems/ops/mv.py
@@ -12,7 +12,7 @@
 
 @libentry()
 @triton.autotune(
-    configs=runtime.get_triton_config("mv"),
+    configs=runtime.get_tuned_config("mv"),
     key=["M", "N"],
 )
 @triton.jit
2 changes: 1 addition & 1 deletion src/flag_gems/ops/nonzero.py
@@ -12,7 +12,7 @@
 
 @libentry()
 @triton.autotune(
-    configs=runtime.get_triton_config("nonzero"),
+    configs=runtime.get_tuned_config("nonzero"),
     key=[
         "n_elements",
     ],
2 changes: 1 addition & 1 deletion src/flag_gems/ops/prod.py
@@ -70,7 +70,7 @@ def heur_block_n(args):
 
 @libentry()
 @triton.autotune(
-    configs=runtime.get_triton_config("prod"),
+    configs=runtime.get_tuned_config("prod"),
     key=[
         "M",
         "N",