Skip to content

Commit

Permalink
op quantile: wider range of autotune & perf test
Browse files Browse the repository at this point in the history
  • Loading branch information
CelysPr committed Dec 10, 2024
1 parent 3e0d7a5 commit 21bfe87
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 14 deletions.
6 changes: 3 additions & 3 deletions benchmark/test_reduction_perf.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,9 +147,9 @@ def test_generic_reduction_benchmark(op_name, torch_op, input_fn, dtypes):

class quantileBenchmark(GenericBenchmark):
def set_more_shapes(self):
more_shapes_1d = [(4,), (1024,)]
more_shapes_2d = [(1024, 2**i) for i in range(0, 10)]
more_shapes_3d = [(64, 64, 2**i) for i in range(0, 7)]
more_shapes_1d = [(4,), (1024,), (65535)]
more_shapes_2d = [(1024, 2**i) for i in range(0, 15, 3)]
more_shapes_3d = [(64, 64, 2**i) for i in range(0, 15, 3)]
return more_shapes_1d + more_shapes_2d + more_shapes_3d


Expand Down
24 changes: 13 additions & 11 deletions src/flag_gems/ops/quantile.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,24 +3,26 @@
import torch
import triton
import triton.language as tl
from torch import Tensor, tensor
from torch import Tensor

from ..utils import dim_compress, libentry

INTERPOLATION_METHOD = ["linear", "lower", "higher", "nearest", "midpoint"]


def cfggen(one_dim=False):
block_q = tensor([1, 2, 4, 8], dtype=torch.int32)
if one_dim:
configs = [triton.Config({"BLOCK_Q": q.item()}, num_warps=4) for q in block_q]
else:
block_n = tensor([2**i for i in range(6, 11)], dtype=torch.int32)
x, y = torch.meshgrid(block_n, block_q, indexing="ij")
configs = [
triton.Config({"BLOCK_Q": q.item(), "BLOCK_N": n.item()}, num_warps=4)
for n, q in zip(x.ravel(), y.ravel())
]
block_q = [2**1 for i in range(0, 6)]
warp = [1, 2, 4, 8, 16, 32]
configs = []
for q in block_q:
for w in warp:
configs.append(triton.Config({"BLOCK_Q": q}, num_warps=w))
if not one_dim:
block_n = [2**i for i in range(1, 16)]
for c in configs:
for n in block_n:
c["BLOCK_N"] = n

return configs


Expand Down

0 comments on commit 21bfe87

Please sign in to comment.