[Frontend] [BC breaking] Always follow C semantics on % #4698

Closed · wants to merge 1 commit
4 changes: 1 addition & 3 deletions docs/python-api/triton-semantics.rst
@@ -11,9 +11,7 @@ The algorithm is as follows:

2. **Width** If both tensors are of dtypes of the same kind, and one of them is of a higher width, the other one is promoted to this dtype: ``(float32, float16) -> float32``

- 3. **Supremum** If both tensors are of the same width and signedness but different dtypes, they are both promoted to the next larger dtype. ``(float16, bfloat16) -> float32``
-
- 3.1 If both tensors are of different ``fp8`` dtypes, they are both cast to ``float16``.
+ 3. **Prefer float16** If both tensors are of the same width and signedness but different dtypes (``float16`` and ``bfloat16`` or different ``fp8`` types), they are both promoted to ``float16``. ``(float16, bfloat16) -> float16``

4. **Prefer unsigned** Otherwise (same width, different signedness), they are promoted to the unsigned dtype: ``(int32, uint32) -> uint32``

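The numbered rules visible in the hunk above (width, prefer float16, prefer unsigned) can be summarised with a small, self-contained sketch. This is only an illustrative model of those steps, not Triton's actual implementation; the dtype tables and the `promote` helper below are hypothetical names introduced for the example.

```python
# Illustrative model of the promotion rules above; not Triton's implementation.
# The WIDTH/KIND tables and promote() are hypothetical.
WIDTH = {"fp8e4nv": 8, "fp8e5": 8, "float16": 16, "bfloat16": 16,
         "float32": 32, "int32": 32, "uint32": 32}
KIND = {"fp8e4nv": "float", "fp8e5": "float", "float16": "float",
        "bfloat16": "float", "float32": "float", "int32": "int", "uint32": "uint"}

def promote(a: str, b: str) -> str:
    # 2. Width: same kind, different width -> the wider dtype wins.
    if KIND[a] == KIND[b] and WIDTH[a] != WIDTH[b]:
        return a if WIDTH[a] > WIDTH[b] else b
    # 3. Prefer float16: same width, different float dtypes -> float16.
    if a != b and KIND[a] == "float" and KIND[b] == "float":
        return "float16"
    # 4. Prefer unsigned: same width, different signedness -> the unsigned dtype.
    if {KIND[a], KIND[b]} == {"int", "uint"}:
        return a if KIND[a] == "uint" else b
    return a

assert promote("float32", "float16") == "float32"   # Width
assert promote("float16", "bfloat16") == "float16"  # Prefer float16
assert promote("int32", "uint32") == "uint32"       # Prefer unsigned
```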
34 changes: 18 additions & 16 deletions python/test/unit/language/test_core.py
@@ -346,7 +346,7 @@ def do_test(x, y, kernel_fn):
z_tri = to_triton(np.empty(SIZE, dtype=z_ref.dtype), device=device)
kernel_fn[(1, )](z_tri, x_tri, y_tri, SIZE=SIZE, num_warps=4, num_ctas=num_ctas)
err_msg = f"{expr}, {kernel_fn.__name__}"
- np.testing.assert_allclose(z_ref, to_numpy(z_tri), err_msg=err_msg, atol=3e-3, rtol=0.01)
+ np.testing.assert_allclose(z_ref, to_numpy(z_tri), err_msg=err_msg, atol=5e-3, rtol=0.01)

def get_scalar(x, dtype, low, high, filter):
# If dtype is int, don't choose a huge number for the scalar
@@ -381,8 +381,7 @@ def get_scalar(x, dtype, low, high, filter):


def _mod_operation_ill_conditioned(dtype_x, dtype_y) -> bool:
- # FIXME For large x, we are casting x to a floating point where it does not fit
- # For small y, we are computing floor(div(float(x), y)) which may not fit
+ # FIXME For large x, we are casting x to a floating point where it does not fit!!
return (dtype_x, dtype_y) in [
('int32', 'bfloat16'),
('int32', 'float16'),
@@ -391,8 +390,6 @@ def _mod_operation_ill_conditioned(dtype_x, dtype_y) -> bool:
('int64', 'float16'),
('int64', 'float32'),
('int64', 'float64'),
- ('uint16', 'bfloat16'),
- ('uint16', 'float16'),
('uint16', 'float32'),
('uint32', 'bfloat16'),
('uint32', 'float16'),
@@ -425,19 +422,23 @@ def test_dtype_codegen():
@pytest.mark.parametrize("num_ctas", num_ctas_list)
def test_bin_op(dtype_x, dtype_y, op, num_ctas, device):
expr = f'x {op} y'
- if op == '%' and dtype_x in int_dtypes + uint_dtypes and dtype_y in int_dtypes + uint_dtypes:
- # LLVM has 'numpy.fmod', not 'numpy.remainder', semantics on integer remainders.
- numpy_expr = 'np.fmod(x, y)'
- elif op in ('/', '%') and dtype_x in ('int16', 'float16', 'bfloat16') and dtype_y in ('int16', 'float16',
- 'bfloat16'):
- # Triton promotes 16-bit floating-point / and % to 32-bit because there
- # are no native div or FRem operations on float16. Since we have to
- # convert anyway, we may as well take the accuracy bump.
- numpy_expr = f'x.astype(np.float32) {op} y.astype(np.float32)'
+ np_expr_gen = (lambda x, y: f'{x} {op} {y}') if op != '%' else (lambda x, y: f'np.fmod({x}, {y})')
+
+ # Triton promotes 16-bit floating-point / and % to 32-bit because there
+ # are no native div or FRem operations on float16. Since we have to
+ # convert anyway, we may as well take the accuracy bump.
+ def promote_to_fp32(dtype_x, dtype_y):
+ return dtype_x in ('float16', 'bfloat16') and dtype_y not in ('float32', 'float64')
+
+ if op in ('/', '%') and (promote_to_fp32(dtype_x, dtype_y) or promote_to_fp32(dtype_y, dtype_x)):
+ numpy_expr = np_expr_gen('x.astype(np.float32)', 'y.astype(np.float32)')
elif (dtype_x in uint_dtypes and dtype_y in int_dtypes and _bitwidth(dtype_x) >= _bitwidth(dtype_y)):
- numpy_expr = f'x.astype(np.{dtype_x}) {op} y.astype(np.{dtype_x})'
+ numpy_expr = np_expr_gen(f'x.astype(np.{dtype_x})', f'y.astype(np.{dtype_x})')
elif (dtype_y in uint_dtypes and dtype_x in int_dtypes and _bitwidth(dtype_y) >= _bitwidth(dtype_x)):
- numpy_expr = f'x.astype(np.{dtype_y}) {op} y.astype(np.{dtype_y})'
+ numpy_expr = np_expr_gen(f'x.astype(np.{dtype_y})', f'y.astype(np.{dtype_y})')
+ elif op == '%':
+ # LLVM has 'numpy.fmod', not 'numpy.remainder', semantics on integer remainders.
+ numpy_expr = np_expr_gen('x', 'y')
else:
numpy_expr = None
if op == '%' and _mod_operation_ill_conditioned(dtype_x, dtype_y):
@@ -452,6 +453,7 @@ def test_bin_op(dtype_x, dtype_y, op, num_ctas, device):
# while Triton performs it in bfloat16
# We also skip mod when it is ill-conditioned
skip_scalar_test = ((dtype_x == "bfloat16" and "float" in dtype_y)
+ or (op in ('/', '%') and dtype_x in ("float16", "bfloat16"))
or (expr == "x % y" and dtype_x in int_dtypes + uint_dtypes and dtype_y in float_dtypes
and _mod_operation_ill_conditioned(dtype_x, "float32")))
# can't divide by zero
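The comment kept in the new `%` branch ("LLVM has 'numpy.fmod', not 'numpy.remainder', semantics on integer remainders") is the crux of the BC break: C-style remainder truncates toward zero and takes the sign of the dividend, while NumPy's `%` / `np.remainder` floors and takes the sign of the divisor. A quick NumPy check of the difference (the example values are chosen here, not taken from the test suite):

```python
import numpy as np

x = np.array([7, -7, 7, -7], dtype=np.int32)
y = np.array([3, 3, -3, -3], dtype=np.int32)

# C-style remainder: truncate toward zero, result takes the sign of the dividend.
print(np.fmod(x, y))       # [ 1 -1  1 -1]

# Floor-style remainder (NumPy's %): result takes the sign of the divisor.
print(np.remainder(x, y))  # [ 1  2 -2 -1]
```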
10 changes: 4 additions & 6 deletions python/triton/language/semantic.py
@@ -6,7 +6,6 @@

from .._C.libtriton import ir
from . import core as tl
- from . import math

T = TypeVar('T')

@@ -88,11 +87,12 @@ def computation_type_impl(a_ty: tl.dtype, a_is_scalar: bool, b_ty: tl.dtype, b_i
else:
return tl.float16
# 4) return bf16 only if both operands are of bf16
- if a_ty.is_bf16() or b_ty.is_bf16():
+ if a_ty.is_bf16() and b_ty.is_bf16():
if div_or_mod:
return tl.float32
- if a_ty.is_bf16() and b_ty.is_bf16():
+ else:
return tl.bfloat16
+ if a_ty.is_bf16() or b_ty.is_bf16():
return tl.float32
# 5) return fp16 if operands are different fp8
if a_ty.is_fp8() and b_ty.is_fp8():
@@ -305,9 +305,7 @@ def mod(input: tl.tensor | numbers.Number, other: tl.tensor | numbers.Number, bu
other_scalar_ty = other.type.scalar
# float % float
if scalar_ty.is_floating():
- # input - input.div(other, rounding_mode="floor") * other
- ret = sub(input, mul(math.floor(fdiv(input, other, False, builder), _builder=builder), other, builder), builder)
- return ret
+ return tl.tensor(builder.create_frem(input.handle, other.handle), input.type)
# % int
elif scalar_ty.is_int():
if scalar_ty.int_signedness != other_scalar_ty.int_signedness:
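In `mod`, the float path drops the floor-division based formula (`input - floor(input / other) * other`) in favour of `create_frem`, which follows LLVM `frem` / C `fmod` semantics: the result takes the sign of the dividend rather than the divisor. A plain-Python sketch of the behavioural difference (not Triton code; `floor_mod` is a rough model of the removed lowering):

```python
import math

def floor_mod(x, y):
    # Rough model of the removed lowering: x - floor(x / y) * y
    return x - math.floor(x / y) * y

x, y = -7.5, 2.0
print(floor_mod(x, y))  # 0.5  -> sign follows the divisor (old behaviour)
print(math.fmod(x, y))  # -1.5 -> C frem semantics: sign follows the dividend (new behaviour)
```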