-
Notifications
You must be signed in to change notification settings - Fork 60
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add quantile op #287
Open
CelysPr
wants to merge
32
commits into
FlagOpen:master
Choose a base branch
from
CelysPr:patch-1
base: master
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Add quantile op #287
Changes from 28 commits
Commits
Show all changes
32 commits
Select commit
Hold shift + click to select a range
0263c92
Create quantile.py
CelysPr 39d1aa1
Fixed code format in quantile.py
CelysPr 45b7c71
Add operation test and benchmark for quantile
CelysPr ff8d9ce
Update quantile.py
CelysPr c7d8d2f
Add test cases for quantile in test_general_reduction_ops.py
CelysPr ec35fa7
Add benchmark for quantile op in test_reduction_perf.py
CelysPr a3ad411
Merge branch 'patch-1' of https://github.com/CelysPr/FlagGems into pa…
CelysPr b3fabae
Merge remote-tracking branch 'origin/master' into patch-1
CelysPr 5e88940
Update test_general_reduction_ops.py
CelysPr 2f8ca20
fix
CelysPr 785c27a
p
CelysPr 802529e
add parameter "out" for quantile operator
CelysPr c73be35
add resolution for torch.float64 type
CelysPr 496f092
bug fix & more test cases for quantile unit test
CelysPr 1691b2d
generate_tensor_input supports torch.float64
CelysPr 816a37a
add default shape for quantile op
CelysPr 1c89735
Update benchmark for quantile op
CelysPr 42da97d
Update resolution for torch.float64
CelysPr 857a6f0
well
CelysPr da02de7
Merge branch 'FlagOpen:master' into patch-1
CelysPr 7149913
Merge branch 'patch-1' of https://github.com/CelysPr/FlagGems into pa…
CelysPr 3e0d7a5
fix typos
CelysPr 21bfe87
op quantile: wider range of autotune & perf test
CelysPr 1e1a0f5
fix quantile op
CelysPr 8e40979
Merge branch 'patch-1' of https://github.com/CelysPr/FlagGems into pa…
CelysPr 3e9e410
update ops/__init__.py
CelysPr 3a1a082
Merge branch 'master' into patch-1
CelysPr 66c073e
Op quantile: update devices
CelysPr f52c1e1
quantile op: remove torch.float64 type from op unit tests
CelysPr d1dceb2
quantile op: remove torch.float64 type from benchmark
CelysPr 40630db
Remove torch.float64 type from performance_utils
CelysPr 634c6c5
Update __init__.py
CelysPr File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,142 @@ | ||
import logging | ||
|
||
import torch | ||
import triton | ||
import triton.language as tl | ||
from torch import Tensor | ||
|
||
from ..runtime import torch_device_fn | ||
from ..utils import dim_compress, libentry | ||
from ..utils import triton_lang_extension as tle | ||
|
||
INTERPOLATION_METHOD = ["linear", "lower", "higher", "nearest", "midpoint"] | ||
|
||
|
||
def heur_block_q(args): | ||
return triton.next_power_of_2(min(triton.cdiv(args["Q"], 8), 16)) | ||
|
||
|
||
def heur_block_n(args): | ||
if args["N"] >= 65536: | ||
return triton.next_power_of_2(triton.cdiv(args["N"], 512)) | ||
elif args["N"] >= 4096: | ||
return triton.next_power_of_2(triton.cdiv(args["N"], 128)) | ||
elif args["N"] >= 64: | ||
return 32 | ||
elif args["N"] >= 32: | ||
return 4 | ||
else: | ||
return 1 | ||
|
||
|
||
@libentry() | ||
@triton.heuristics(values={"BLOCK_Q": heur_block_q, "BLOCK_N": heur_block_n}) | ||
@triton.jit | ||
def quantile_kernel( | ||
inp, | ||
q, | ||
out, | ||
N, | ||
M, | ||
Q, | ||
BLOCK_Q: tl.constexpr, | ||
BLOCK_N: tl.constexpr, | ||
interpolation: tl.constexpr, | ||
): | ||
pid_Q = tle.program_id(0) | ||
pid_N = tle.program_id(1) | ||
ctype = inp.dtype.element_ty | ||
|
||
offsets_Q = pid_Q * BLOCK_Q + tl.arange(0, BLOCK_Q) | ||
mask_Q = offsets_Q < Q | ||
q_ptrs = q + offsets_Q | ||
|
||
offsets_N = pid_N * BLOCK_N + tl.arange(0, BLOCK_N) | ||
mask_N = offsets_N < N | ||
|
||
out_ptrs = out + offsets_N[:, None] * Q + offsets_Q[None, :] | ||
mask_out = mask_N[:, None] & mask_Q[None, :] | ||
|
||
q_block = tl.load(q_ptrs, mask_Q, 0.0).to(ctype) * (M - 1) | ||
q_lower = tl.floor(q_block).to(tl.int32) | ||
q_upper = tl.ceil(q_block).to(tl.int32) | ||
|
||
inp_lower = tl.load( | ||
inp + offsets_N[:, None] * M + q_lower[None, :], mask_N[:, None], 0.0 | ||
) | ||
inp_upper = tl.load( | ||
inp + offsets_N[:, None] * M + q_upper[None, :], mask_N[:, None], 0.0 | ||
) | ||
|
||
if interpolation == "linear": | ||
q_frac = q_block - q_lower | ||
tl.store(out_ptrs, inp_lower + (inp_upper - inp_lower) * q_frac, mask_out) | ||
|
||
elif interpolation == "lower": | ||
tl.store(out_ptrs, inp_lower, mask_out) | ||
|
||
elif interpolation == "higher": | ||
tl.store(out_ptrs, inp_upper, mask_out) | ||
|
||
elif interpolation == "nearest": | ||
q_round = tl.extra.cuda.libdevice.rint(q_block) | ||
out_block = tl.where(q_round == q_upper, inp_upper, inp_lower) | ||
tl.store(out_ptrs, out_block, mask_out) | ||
|
||
elif interpolation == "midpoint": | ||
tl.store(out_ptrs, (inp_lower + inp_upper) / 2, mask_out) | ||
|
||
|
||
def quantile( | ||
inp, q, dim=None, keepdim=False, interpolation="linear", out=None | ||
) -> Tensor: | ||
logging.debug("GEMS QUANTILE DIM") | ||
assert torch.is_floating_point(inp) | ||
assert dim is None or isinstance(dim, int) | ||
assert isinstance(q, (float, torch.Tensor)) | ||
assert interpolation in INTERPOLATION_METHOD | ||
|
||
M = inp.numel() | ||
if isinstance(q, float): | ||
q = torch.tensor(q, device=inp.device) | ||
Q = 1 | ||
else: | ||
Q = 1 if q.numel() == 1 else len(q) | ||
|
||
assert M > 0 | ||
assert Q > 0 | ||
assert torch.all(q >= 0.0) and torch.all(q <= 1.0) | ||
|
||
if dim is None: | ||
inp = inp.ravel() | ||
dim = 0 | ||
|
||
shape = list(inp.shape) | ||
|
||
dim %= inp.ndim | ||
inp = dim_compress(inp, dim) | ||
M = shape[dim] | ||
N = inp.numel() // M | ||
|
||
inp, _ = inp.sort() # Sort the input with torch.sort() | ||
output = torch.empty(inp.shape[:-1] + (Q,), dtype=inp.dtype, device=inp.device) | ||
|
||
grid = lambda meta: ( | ||
triton.cdiv(Q, meta["BLOCK_Q"]), | ||
triton.cdiv(N, meta["BLOCK_N"]), | ||
) | ||
|
||
with torch_device_fn.device(inp.device): | ||
quantile_kernel[grid](inp, q, output, N, M, Q, interpolation=interpolation) | ||
|
||
output = output.permute( | ||
(-1,) + tuple(range(0, inp.ndim - 1)) | ||
) # Same as torch.quantile() | ||
if keepdim: | ||
output = output.unsqueeze(dim + 1) | ||
if Q == 1: | ||
output = output.squeeze(0) | ||
|
||
if out is not None: | ||
out.copy_(output) | ||
return output |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -7,6 +7,7 @@ | |
torch.float16: 1e-3, | ||
torch.float32: 1.3e-6, | ||
torch.bfloat16: 0.016, | ||
torch.float64: 1e-7, | ||
} | ||
|
||
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think we don’t need to focus on
torch.float64
accuracy for this case. In real-world model training and inference scenarios,float64
precision is rarely used.