Commit

fix quantile op
CelysPr committed Dec 12, 2024
1 parent 21bfe87 commit 1e1a0f5
Showing 5 changed files with 37 additions and 299 deletions.
1 change: 1 addition & 0 deletions benchmark/test_reduction_perf.py
@@ -156,6 +156,7 @@ def set_more_shapes(self):
 def quantile_input_fn(shape, cur_dtype, device):
     inp = generate_tensor_input(shape, cur_dtype, device)
     q = torch.tensor([0.0, 0.2, 0.4, 0.6, 0.8, 1.0], dtype=cur_dtype, device=device)
+    print(shape)
     yield inp, q, 0
 
 
1 change: 1 addition & 0 deletions src/flag_gems/__init__.py
@@ -148,6 +148,7 @@ def enable(lib=aten_lib, unused=None):
         ("any", any, Autograd.disable),
         ("any.dim", any_dim, Autograd.disable),
         ("any.dims", any_dims, Autograd.disable),
+        ("quantile", quantile, Autograd.disable),
         ("log_softmax.int", log_softmax, Autograd.enable),
         ("outer", outer, Autograd.enable),
         ("cross_entropy_loss", cross_entropy_loss, Autograd.enable),
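With this entry in the registration table, an eager call to torch.quantile routes to the Gems kernel once the library is enabled. A minimal usage sketch (shapes are illustrative and a CUDA device is assumed; flag_gems.enable() is the library's standard entry point):

import torch
import flag_gems

flag_gems.enable()  # patch the aten ops listed above, including "quantile"

x = torch.randn(8, 1024, device="cuda")
q = torch.tensor([0.25, 0.5, 0.75], device="cuda")
# Dispatches to the Triton implementation in src/flag_gems/ops/quantile.py
out = torch.quantile(x, q, dim=1, interpolation="linear")
print(out.shape)  # torch.Size([3, 8])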
123 changes: 34 additions & 89 deletions src/flag_gems/ops/quantile.py
@@ -10,91 +10,27 @@
 INTERPOLATION_METHOD = ["linear", "lower", "higher", "nearest", "midpoint"]
 
 
-def cfggen(one_dim=False):
-    block_q = [2**i for i in range(0, 6)]
-    warp = [1, 2, 4, 8, 16, 32]
-    configs = []
-    for q in block_q:
-        for w in warp:
-            configs.append(triton.Config({"BLOCK_Q": q}, num_warps=w))
-    if not one_dim:
-        block_n = [2**i for i in range(1, 16)]
-        for c in configs:
-            for n in block_n:
-                c["BLOCK_N"] = n
-
-    return configs
+def heur_block_q(args):
+    return triton.next_power_of_2(min(triton.cdiv(args["Q"], 8), 16))
 
 
-@libentry()
-@triton.autotune(configs=cfggen(True), key=["M", "Q"])
-@triton.jit
-def quantile_kernel_1d(
-    inp, q, out, M, Q, BLOCK_Q: tl.constexpr, interpolation: tl.constexpr
-):
-    pid = tl.program_id(0)
-    ctype = inp.dtype.element_ty
-
-    offsets = pid * BLOCK_Q + tl.arange(0, BLOCK_Q)
-    mask = offsets < Q
-    q_ptrs = q + offsets
-    out_ptrs = out + offsets
-
-    q_block = tl.load(q_ptrs, mask, 0.0).to(ctype) * (M - 1)
-    q_lower = tl.floor(q_block).to(tl.int32)
-    q_upper = tl.ceil(q_block).to(tl.int32)
-    inp_lower = tl.load(inp + q_lower)
-    inp_upper = tl.load(inp + q_upper)
-
-    if interpolation == "linear":
-        q_frac = q_block - q_lower
-        tl.store(out_ptrs, inp_lower + (inp_upper - inp_lower) * q_frac, mask)
-
-    elif interpolation == "lower":
-        tl.store(out_ptrs, inp_lower, mask)
-
-    elif interpolation == "higher":
-        tl.store(out_ptrs, inp_upper, mask)
-
-    elif interpolation == "nearest":
-        q_near = tl.where(q_block - q_lower > q_upper - q_block, inp_upper, inp_lower)
-        tl.store(out_ptrs, q_near, mask)
-
-    elif interpolation == "midpoint":
-        tl.store(out_ptrs, (inp_lower + inp_upper) / 2, mask)
-
-
-def quantile(inp, q, *, interpolation="linear", out=None) -> Tensor:
-    logging.debug("GEMS QUANTILE")
-    assert torch.is_floating_point(inp)
-    assert isinstance(q, (float, torch.Tensor))
-    assert interpolation in INTERPOLATION_METHOD
-
-    M = inp.numel()
-    if isinstance(q, float):
-        q = torch.tensor(q, device=inp.device)
-    Q = len(q)
-
-    assert M > 0
-    assert Q > 0
-    assert torch.all(q >= 0.0) and torch.all(q <= 1.0)
-
-    inp, _ = inp.sort()  # Sort the input with torch.sort()
-    output = torch.empty(q.shape, dtype=inp.dtype, device=inp.device)
-    grid = lambda meta: [triton.cdiv(Q, meta["BLOCK_Q"])]
-
-    with torch.cuda.device(inp.device):
-        quantile_kernel_1d[grid](inp, q, output, M, Q, interpolation=interpolation)
-
-    if out is not None:
-        out.copy_(output)
-    return output
+def heur_block_n(args):
+    if args["N"] >= 65536:
+        return triton.next_power_of_2(triton.cdiv(args["N"], 512))
+    elif args["N"] >= 4096:
+        return triton.next_power_of_2(triton.cdiv(args["N"], 128))
+    elif args["N"] >= 64:
+        return 32
+    elif args["N"] >= 32:
+        return 4
+    else:
+        return 1
 
 
 @libentry()
-@triton.autotune(configs=cfggen(), key=["N", "M", "Q"])
+@triton.heuristics(values={"BLOCK_Q": heur_block_q, "BLOCK_N": heur_block_n})
 @triton.jit
-def quantile_kernel_2d(
+def quantile_kernel(
     inp,
     q,
     out,
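The commit swaps autotuning for @triton.heuristics: BLOCK_Q is derived from the number of quantiles Q and BLOCK_N from the row count N, so no tuning sweep runs on first call. A quick sketch of the values the formulas above produce (plain Python, same arithmetic as the heuristics):

import triton

def heur_block_q(args):
    return triton.next_power_of_2(min(triton.cdiv(args["Q"], 8), 16))

def heur_block_n(args):
    n = args["N"]
    if n >= 65536:
        return triton.next_power_of_2(triton.cdiv(n, 512))
    elif n >= 4096:
        return triton.next_power_of_2(triton.cdiv(n, 128))
    elif n >= 64:
        return 32
    elif n >= 32:
        return 4
    return 1

print([heur_block_q({"Q": q}) for q in (1, 6, 64, 200)])        # [1, 1, 8, 16]
print([heur_block_n({"N": n}) for n in (8, 48, 1024, 1 << 20)])  # [1, 4, 32, 2048]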
@@ -123,8 +59,12 @@ def quantile_kernel_2d(
     q_lower = tl.floor(q_block).to(tl.int32)
     q_upper = tl.ceil(q_block).to(tl.int32)
 
-    inp_lower = tl.load(inp + offsets_N[:, None] * M + q_lower[None, :])
-    inp_upper = tl.load(inp + offsets_N[:, None] * M + q_upper[None, :])
+    inp_lower = tl.load(
+        inp + offsets_N[:, None] * M + q_lower[None, :], mask_N[:, None], 0.0
+    )
+    inp_upper = tl.load(
+        inp + offsets_N[:, None] * M + q_upper[None, :], mask_N[:, None], 0.0
+    )
 
     if interpolation == "linear":
         q_frac = q_block - q_lower
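The unmasked loads could read past the end of inp for rows in the last BLOCK_N tile whenever N is not a multiple of BLOCK_N; passing mask_N with a 0.0 fill keeps those lanes in bounds. A standalone sketch of the same tl.load(ptr, mask, other) pattern (hypothetical demo kernel, not part of this commit):

import torch
import triton
import triton.language as tl

@triton.jit
def masked_load_demo(x_ptr, out_ptr, n, BLOCK: tl.constexpr):
    offs = tl.arange(0, BLOCK)
    mask = offs < n
    val = tl.load(x_ptr + offs, mask, 0.0)  # out-of-range lanes read 0.0
    tl.store(out_ptr + offs, val, mask)

x = torch.arange(5, device="cuda", dtype=torch.float32)
out = torch.empty_like(x)
masked_load_demo[(1,)](x, out, x.numel(), BLOCK=8)  # BLOCK > n, loads stay in bounds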
@@ -137,15 +77,16 @@ def quantile_kernel_2d(
         tl.store(out_ptrs, inp_upper, mask_out)
 
     elif interpolation == "nearest":
-        q_near = tl.where(q_block - q_lower > q_upper - q_block, inp_upper, inp_lower)
-        tl.store(out_ptrs, q_near, mask_out)
+        q_round = tl.extra.cuda.libdevice.rint(q_block)
+        out_block = tl.where(q_round == q_upper, inp_upper, inp_lower)
+        tl.store(out_ptrs, out_block, mask_out)
 
     elif interpolation == "midpoint":
         tl.store(out_ptrs, (inp_lower + inp_upper) / 2, mask_out)
 
 
-def quantile_dim(
-    inp, q, dim=None, keepdim=False, *, interpolation="linear", out=None
+def quantile(
+    inp, q, dim=None, keepdim=False, interpolation="linear", out=None
 ) -> Tensor:
     logging.debug("GEMS QUANTILE DIM")
     assert torch.is_floating_point(inp)
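For "nearest", rounding via rint (round half to even) replaces the old tie-break that always picked the lower index, matching torch.quantile at exact .5 fractional ranks. A small illustration of the boundary case:

import torch

x = torch.tensor([0.0, 1.0, 2.0])
# q = 0.75 gives fractional rank 0.75 * (3 - 1) = 1.5, an exact tie.
# rint(1.5) rounds to the even index 2, as torch.quantile does;
# the removed tl.where branch would have kept the lower index 1.
print(torch.quantile(x, 0.75, interpolation="nearest"))  # tensor(2.)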
@@ -156,7 +97,9 @@ def quantile_dim(
     M = inp.numel()
     if isinstance(q, float):
         q = torch.tensor(q, device=inp.device)
-    Q = len(q)
+        Q = 1
+    else:
+        Q = 1 if q.numel() == 1 else len(q)
 
     assert M > 0
     assert Q > 0
@@ -176,19 +119,21 @@ def quantile_dim(
     inp, _ = inp.sort()  # Sort the input with torch.sort()
     output = torch.empty(inp.shape[:-1] + (Q,), dtype=inp.dtype, device=inp.device)
 
-    grid = lambda meta: [
+    grid = lambda meta: (
         triton.cdiv(Q, meta["BLOCK_Q"]),
         triton.cdiv(N, meta["BLOCK_N"]),
-    ]
+    )
 
     with torch.cuda.device(inp.device):
-        quantile_kernel_2d[grid](inp, q, output, N, M, Q, interpolation=interpolation)
+        quantile_kernel[grid](inp, q, output, N, M, Q, interpolation=interpolation)
 
     output = output.permute(
         (-1,) + tuple(range(0, inp.ndim - 1))
     )  # Same as torch.quantile()
     if keepdim:
         output = output.unsqueeze(dim + 1)
+    if Q == 1:
+        output = output.squeeze(0)
 
     if out is not None:
         out.copy_(output)
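With the Q computation and the final squeeze above, a scalar q now comes back with the quantile axis dropped, while a tensor q keeps a leading axis of length Q, mirroring torch.quantile. A sketch assuming the import path matches the file layout in this diff:

import torch
from flag_gems.ops.quantile import quantile

x = torch.randn(4, 1024, device="cuda")
print(quantile(x, 0.5, dim=1).shape)  # torch.Size([4]) -- Q == 1 is squeezed
qs = torch.tensor([0.25, 0.75], device="cuda")
print(quantile(x, qs, dim=1).shape)   # torch.Size([2, 4])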