Add quantile op #287

Open · wants to merge 32 commits into master
Changes from 28 commits

Commits (32)
0263c92
Create quantile.py
CelysPr Nov 11, 2024
39d1aa1
Fixed code format in quantile.py
CelysPr Nov 12, 2024
45b7c71
Add operation test and benchmark for quantile
CelysPr Nov 19, 2024
ff8d9ce
Update quantile.py
CelysPr Nov 19, 2024
c7d8d2f
Add test cases for quantile in test_general_reduction_ops.py
CelysPr Nov 19, 2024
ec35fa7
Add benchmark for quantile op in test_reduction_perf.py
CelysPr Nov 19, 2024
a3ad411
Merge branch 'patch-1' of https://github.com/CelysPr/FlagGems into pa…
CelysPr Dec 2, 2024
b3fabae
Merge remote-tracking branch 'origin/master' into patch-1
CelysPr Dec 2, 2024
5e88940
Update test_general_reduction_ops.py
CelysPr Dec 2, 2024
2f8ca20
fix
CelysPr Dec 7, 2024
785c27a
p
CelysPr Dec 8, 2024
802529e
add parameter "out" for quantile operator
CelysPr Dec 8, 2024
c73be35
add resolution for torch.float64 type
CelysPr Dec 8, 2024
496f092
bug fix & more test cases for quantile unit test
CelysPr Dec 8, 2024
1691b2d
generate_tensor_input supports torch.float64
CelysPr Dec 9, 2024
816a37a
add default shape for quantile op
CelysPr Dec 9, 2024
1c89735
Update benchmark for quantile op
CelysPr Dec 9, 2024
42da97d
Update resolution for torch.float64
CelysPr Dec 9, 2024
857a6f0
well
CelysPr Dec 9, 2024
da02de7
Merge branch 'FlagOpen:master' into patch-1
CelysPr Dec 10, 2024
7149913
Merge branch 'patch-1' of https://github.com/CelysPr/FlagGems into pa…
CelysPr Dec 10, 2024
3e0d7a5
fix typos
CelysPr Dec 10, 2024
21bfe87
op quantile: wider range of autotune & perf test
CelysPr Dec 10, 2024
1e1a0f5
fix quantile op
CelysPr Dec 12, 2024
8e40979
Merge branch 'patch-1' of https://github.com/CelysPr/FlagGems into pa…
CelysPr Dec 12, 2024
3e9e410
update ops/__init__.py
CelysPr Dec 13, 2024
3a1a082
Merge branch 'master' into patch-1
CelysPr Dec 26, 2024
66c073e
Op quantile: update devices
CelysPr Dec 26, 2024
f52c1e1
quantile op: remove torch.float64 type from op unit tests
CelysPr Jan 6, 2025
d1dceb2
quantile op: remove torch.float64 type from benchmark
CelysPr Jan 6, 2025
40630db
Remove torch.float64 type from performance_utils
CelysPr Jan 6, 2025
634c6c5
Update __init__.py
CelysPr Jan 6, 2025
8 changes: 8 additions & 0 deletions benchmark/core_shapes.yaml
@@ -39,6 +39,14 @@ diag:
- [256, 1024]
- [1024, 1024]

quantile:
shapes:
- [1048576,] # 1024 * 1024
- [64, 64]
- [4096, 256]
- [64, 512, 128]
- [20, 8, 10000]

BlasBenchmark:
shapes:
- [2, 4096, 4096, 4096]
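
For reference, one benchmarked call for the [20, 8, 10000] entry might look like the sketch below, assuming the benchmark pairs these shapes with the quantile_input_fn added in benchmark/test_reduction_perf.py further down (which reduces over dim 0); this is illustrative, not code from the PR.

import torch

# Illustrative only: shape from the YAML entry above; q and dim=0 follow
# quantile_input_fn in benchmark/test_reduction_perf.py below.
inp = torch.randn(20, 8, 10000, dtype=torch.float32)
q = torch.tensor([0.0, 0.2, 0.4, 0.6, 0.8, 1.0], dtype=torch.float32)
out = torch.quantile(inp, q, dim=0)  # one slice per q value: shape (6, 8, 10000)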
2 changes: 1 addition & 1 deletion benchmark/performance_utils.py
@@ -424,7 +424,7 @@ def set_more_shapes(self):


def generate_tensor_input(shape, dtype, device):
- if dtype in FLOAT_DTYPES:
+ if dtype in FLOAT_DTYPES + [torch.float64]:

Collaborator:
I think we don’t need to focus on torch.float64 accuracy for this case. In real-world model training and inference scenarios, float64 precision is rarely used.

return torch.randn(shape, dtype=dtype, device=device)
elif dtype in INT_DTYPES:
return torch.randint(
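
If the float64 special case is dropped, as the comment above suggests (and as the later commit "Remove torch.float64 type from performance_utils" does), the float branch stays as in the original file. A minimal sketch, with FLOAT_DTYPES assumed to be the half/float32/bfloat16 set and the remaining branches elided as in the collapsed diff:

import torch

FLOAT_DTYPES = [torch.float16, torch.float32, torch.bfloat16]  # assumed definition

def generate_tensor_input(shape, dtype, device):
    # Float inputs: standard normal samples; no float64 special case.
    if dtype in FLOAT_DTYPES:
        return torch.randn(shape, dtype=dtype, device=device)
    # Integer (and other) branches continue as in the original file,
    # which is truncated in this diff.
    raise NotImplementedError(f"unhandled dtype: {dtype}")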
34 changes: 34 additions & 0 deletions benchmark/test_reduction_perf.py
@@ -10,6 +10,7 @@
from .performance_utils import (
Benchmark,
Config,
GenericBenchmark,
GenericBenchmark2DOnly,
generate_tensor_input,
unary_input_fn,
@@ -179,3 +180,36 @@ def count_nonzero_input_fn(shape, dtype, device):
dtypes=FLOAT_DTYPES,
)
bench.run()


class quantileBenchmark(GenericBenchmark):
def set_more_shapes(self):
more_shapes_1d = [(4,), (1024,), (65535,)]
more_shapes_2d = [(1024, 2**i) for i in range(0, 15, 3)]
more_shapes_3d = [(64, 64, 2**i) for i in range(0, 15, 3)]
return more_shapes_1d + more_shapes_2d + more_shapes_3d


def quantile_input_fn(shape, cur_dtype, device):
inp = generate_tensor_input(shape, cur_dtype, device)
q = torch.tensor([0.0, 0.2, 0.4, 0.6, 0.8, 1.0], dtype=cur_dtype, device=device)
yield inp, q, 0


@pytest.mark.parametrize(
"op_name, torch_op, input_fn, dtypes",
[
pytest.param(
"quantile",
torch.quantile,
quantile_input_fn,
[torch.float32, torch.float64],
marks=pytest.mark.quantile,
)
],
)
def test_quantile_benchmark(op_name, torch_op, input_fn, dtypes):
bench = quantileBenchmark(
input_fn=input_fn, op_name=op_name, torch_op=torch_op, dtypes=dtypes
)
bench.run()
1 change: 1 addition & 0 deletions src/flag_gems/__init__.py
@@ -157,6 +157,7 @@ def enable(lib=aten_lib, unused=None, registrar=registrar):
("any", any, Autograd.disable),
("any.dim", any_dim, Autograd.disable),
("any.dims", any_dims, Autograd.disable),
("quantile", quantile, Autograd.disable),
("log_softmax.int", log_softmax, Autograd.enable),
("outer", outer, Autograd.enable),
("cross_entropy_loss", cross_entropy_loss, Autograd.enable),
2 changes: 2 additions & 0 deletions src/flag_gems/ops/__init__.py
@@ -82,6 +82,7 @@
from .pad import constant_pad_nd, pad
from .pow import pow_scalar, pow_tensor_scalar, pow_tensor_tensor
from .prod import prod, prod_dim
from .quantile import quantile
from .rand import rand
from .rand_like import rand_like
from .randn import randn
@@ -249,6 +250,7 @@
"argmax",
"prod",
"prod_dim",
"quantile",
"var_mean",
"vector_norm",
"log_softmax",
142 changes: 142 additions & 0 deletions src/flag_gems/ops/quantile.py
@@ -0,0 +1,142 @@
import logging

import torch
import triton
import triton.language as tl
from torch import Tensor

from ..runtime import torch_device_fn
from ..utils import dim_compress, libentry
from ..utils import triton_lang_extension as tle

INTERPOLATION_METHOD = ["linear", "lower", "higher", "nearest", "midpoint"]


def heur_block_q(args):
return triton.next_power_of_2(min(triton.cdiv(args["Q"], 8), 16))


def heur_block_n(args):
if args["N"] >= 65536:
return triton.next_power_of_2(triton.cdiv(args["N"], 512))
elif args["N"] >= 4096:
return triton.next_power_of_2(triton.cdiv(args["N"], 128))
elif args["N"] >= 64:
return 32
elif args["N"] >= 32:
return 4
else:
return 1


@libentry()
@triton.heuristics(values={"BLOCK_Q": heur_block_q, "BLOCK_N": heur_block_n})
@triton.jit
def quantile_kernel(
inp,
q,
out,
N,
M,
Q,
BLOCK_Q: tl.constexpr,
BLOCK_N: tl.constexpr,
interpolation: tl.constexpr,
):
pid_Q = tle.program_id(0)
pid_N = tle.program_id(1)
ctype = inp.dtype.element_ty

offsets_Q = pid_Q * BLOCK_Q + tl.arange(0, BLOCK_Q)
mask_Q = offsets_Q < Q
q_ptrs = q + offsets_Q

offsets_N = pid_N * BLOCK_N + tl.arange(0, BLOCK_N)
mask_N = offsets_N < N

out_ptrs = out + offsets_N[:, None] * Q + offsets_Q[None, :]
mask_out = mask_N[:, None] & mask_Q[None, :]

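# Map each quantile in [0, 1] to a fractional index q * (M - 1) into the sorted
# row, then gather its floor/ceil neighbours for the interpolation modes below.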
q_block = tl.load(q_ptrs, mask_Q, 0.0).to(ctype) * (M - 1)
q_lower = tl.floor(q_block).to(tl.int32)
q_upper = tl.ceil(q_block).to(tl.int32)

inp_lower = tl.load(
inp + offsets_N[:, None] * M + q_lower[None, :], mask_N[:, None], 0.0
)
inp_upper = tl.load(
inp + offsets_N[:, None] * M + q_upper[None, :], mask_N[:, None], 0.0
)

if interpolation == "linear":
q_frac = q_block - q_lower
tl.store(out_ptrs, inp_lower + (inp_upper - inp_lower) * q_frac, mask_out)

elif interpolation == "lower":
tl.store(out_ptrs, inp_lower, mask_out)

elif interpolation == "higher":
tl.store(out_ptrs, inp_upper, mask_out)

elif interpolation == "nearest":
q_round = tl.extra.cuda.libdevice.rint(q_block)
out_block = tl.where(q_round == q_upper, inp_upper, inp_lower)
tl.store(out_ptrs, out_block, mask_out)

elif interpolation == "midpoint":
tl.store(out_ptrs, (inp_lower + inp_upper) / 2, mask_out)


def quantile(
inp, q, dim=None, keepdim=False, interpolation="linear", out=None
) -> Tensor:
logging.debug("GEMS QUANTILE DIM")
assert torch.is_floating_point(inp)
assert dim is None or isinstance(dim, int)
assert isinstance(q, (float, torch.Tensor))
assert interpolation in INTERPOLATION_METHOD

M = inp.numel()
if isinstance(q, float):
q = torch.tensor(q, device=inp.device)
Q = 1
else:
Q = 1 if q.numel() == 1 else len(q)

assert M > 0
assert Q > 0
assert torch.all(q >= 0.0) and torch.all(q <= 1.0)

if dim is None:
inp = inp.ravel()
dim = 0

shape = list(inp.shape)

dim %= inp.ndim
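# dim_compress moves the reduction dim to the last axis so that each length-M
# row can be sorted and reduced independently.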
inp = dim_compress(inp, dim)
M = shape[dim]
N = inp.numel() // M

inp, _ = inp.sort() # Sort the input with torch.sort()
output = torch.empty(inp.shape[:-1] + (Q,), dtype=inp.dtype, device=inp.device)

grid = lambda meta: (
triton.cdiv(Q, meta["BLOCK_Q"]),
triton.cdiv(N, meta["BLOCK_N"]),
)

with torch_device_fn.device(inp.device):
quantile_kernel[grid](inp, q, output, N, M, Q, interpolation=interpolation)

output = output.permute(
(-1,) + tuple(range(0, inp.ndim - 1))
) # Same as torch.quantile()
if keepdim:
output = output.unsqueeze(dim + 1)
if Q == 1:
output = output.squeeze(0)

if out is not None:
out.copy_(output)
return output
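
For intuition, the "linear" branch above is the textbook interpolation between the two order statistics around position q * (M - 1). A small self-contained sketch (the helper name is illustrative, not part of this PR) that mirrors this math on a 1-D input and can be checked against torch.quantile:

import torch

def linear_quantile_reference(x: torch.Tensor, q: torch.Tensor) -> torch.Tensor:
    # Mirror of the kernel's "linear" branch for a single row: fractional
    # position q * (M - 1) in the sorted data, interpolated between its
    # floor and ceil neighbours.
    xs, _ = x.sort()
    pos = q * (x.numel() - 1)
    lower = pos.floor().long()
    upper = pos.ceil().long()
    frac = pos - lower
    return xs[lower] + (xs[upper] - xs[lower]) * frac

x = torch.randn(1000)
q = torch.tensor([0.0, 0.25, 0.5, 0.75, 1.0])
print(linear_quantile_reference(x, q))
print(torch.quantile(x, q))  # agrees up to floating-point rounding

In the unit tests below, the op itself is exercised through the flag_gems.use_gems() context, so torch.quantile dispatches to this Triton kernel.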
1 change: 1 addition & 0 deletions src/flag_gems/testing/__init__.py
@@ -7,6 +7,7 @@
torch.float16: 1e-3,
torch.float32: 1.3e-6,
torch.bfloat16: 0.016,
torch.float64: 1e-7,
}


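
For context, a per-dtype table like this is typically consumed as a relative tolerance when comparing a flag_gems result against the PyTorch reference. The helper below is a hypothetical sketch of that pattern only; the gems_assert_close used in the tests may scale tolerances differently (e.g. by the reduction size passed as reduce_dim).

import torch

# Hypothetical sketch; mirrors the table added above, not the actual
# flag_gems.testing implementation.
RESOLUTION = {
    torch.float16: 1e-3,
    torch.float32: 1.3e-6,
    torch.bfloat16: 0.016,
    torch.float64: 1e-7,
}

def assert_close_for_dtype(actual, expected, dtype):
    torch.testing.assert_close(actual, expected, rtol=RESOLUTION[dtype], atol=1e-4)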
59 changes: 59 additions & 0 deletions tests/test_general_reduction_ops.py
@@ -295,3 +295,62 @@ def test_accuracy_sum_dim(shape, dim, keepdim, dtype):
if dim == []:
_dim = inp.numel()
gems_assert_close(res_out, ref_out, dtype, reduce_dim=_dim)


QUANTILE_SHAPES = REDUCTION_SMALL_SHAPES + [(10, 64, 196), (65535, 1)]
QUANTILE_FLOAT_DTYPES = [torch.float32, torch.float64]
QUANTILE_Q = (
[(0.2, 0.5, 0.8)]
if QUICK_MODE
else [(0.4), (0.0, 0.2, 0.5, 0.8, 1.0), (0.662, 0.8, 0.104, 0.99, 0.347, 0.255)]
)
QUANTILE_INTERPOLATION = (
["linear"] if QUICK_MODE else ["linear", "lower", "higher", "nearest", "midpoint"]
)


@pytest.mark.quantile
@pytest.mark.parametrize("shape", QUANTILE_SHAPES)
@pytest.mark.parametrize("dtype", QUANTILE_FLOAT_DTYPES)
@pytest.mark.parametrize("q", QUANTILE_Q)
@pytest.mark.parametrize("interpolation", QUANTILE_INTERPOLATION)
def test_accuracy_quantile_without_dim(shape, dtype, q, interpolation):
inp = torch.randn(shape, dtype=dtype, device=flag_gems.device)
ref_inp = to_reference(inp)
q = torch.tensor(q, dtype=dtype, device=inp.device)
ref_q = to_reference(q)

ref_out = torch.quantile(ref_inp, ref_q, interpolation=interpolation)
with flag_gems.use_gems():
res_out = torch.quantile(inp, q, interpolation=interpolation)

gems_assert_close(res_out, ref_out, dtype, reduce_dim=inp.numel())


@pytest.mark.quantile
@pytest.mark.parametrize("shape", QUANTILE_SHAPES)
@pytest.mark.parametrize("keepdim, dim", KEEPDIM_DIM)
@pytest.mark.parametrize("dtype", QUANTILE_FLOAT_DTYPES)
@pytest.mark.parametrize("q", QUANTILE_Q)
@pytest.mark.parametrize("interpolation", QUANTILE_INTERPOLATION)
def test_accuracy_quantile_dim(shape, dim, keepdim, dtype, q, interpolation):
inp = torch.randn(shape, dtype=dtype, device=flag_gems.device)
ref_inp = to_reference(inp)
q = torch.tensor(q, dtype=dtype, device=inp.device)
ref_q = to_reference(q)

ref_out = torch.quantile(
ref_inp, ref_q, dim=dim, keepdim=keepdim, interpolation=interpolation
)
with flag_gems.use_gems():
res_out = torch.quantile(
inp, q, dim=dim, keepdim=keepdim, interpolation=interpolation
)

if isinstance(dim, int):
dim = [dim]
dim = [d % inp.ndim for d in dim]
_dim = 1
for d in dim:
_dim *= shape[d]
gems_assert_close(res_out, ref_out, dtype, reduce_dim=_dim)