[Frontend] [BC breaking] Always follow C semantics on %
The semantics of `%` in Triton used to be type dependent (!!).

With this PR, we make `%` always follow C semantics, similar to `//`.

We update the type promotion docs, fixing some inaccuracies. They are still
not entirely precise, though. For a discussion of the current semantics,
see triton-lang#4697
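
For illustration (a NumPy sketch, not part of the change): C semantics truncate the quotient, so the remainder takes the sign of the dividend, whereas the old float path used floored semantics:

```python
import numpy as np

# C-style remainder (truncated division): sign follows the dividend.
np.fmod(-7, 3)       # -1 -- what `%` computes after this PR
# Floored remainder (Python-style): sign follows the divisor.
np.remainder(-7, 3)  #  2 -- what the old float path computed
```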
lezcano committed Sep 10, 2024
1 parent 8a3fb7e commit 673ec9b
Showing 3 changed files with 23 additions and 25 deletions.
4 changes: 1 addition & 3 deletions docs/python-api/triton-semantics.rst
@@ -11,9 +11,7 @@ The algorithm is as follows:
 
 2. **Width** If both tensors are of dtypes of the same kind, and one of them is of a higher width, the other one is promoted to this dtype: ``(float32, float16) -> float32``
 
-3. **Supremum** If both tensors are of the same width and signedness but different dtypes, they are both promoted to the next larger dtype. ``(float16, bfloat16) -> float32``
-
-3.1 If both tensors are of different ``fp8`` dtypes, they are both cast to ``float16``.
+3. **Prefer float16** If both tensors are of the same width and signedness but different dtypes (``float16`` and ``bfloat16`` or different ``fp8`` types), they are both promoted to ``float16``. ``(float16, bfloat16) -> float16``
 
 4. **Prefer unsigned** Otherwise (same width, different signedness), they are promoted to the unsigned dtype: ``(int32, uint32) -> uint32``
 
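A hypothetical kernel illustrating rule 3 (a sketch, not from this commit: `add_kernel` and its buffers are made up, and it assumes a CUDA device):

```python
import torch
import triton
import triton.language as tl

@triton.jit
def add_kernel(x_ptr, y_ptr, out_ptr, BLOCK: tl.constexpr):
    offs = tl.arange(0, BLOCK)
    x = tl.load(x_ptr + offs)        # float16
    y = tl.load(y_ptr + offs)        # bfloat16
    tl.store(out_ptr + offs, x + y)  # x + y is computed in float16 per rule 3

x = torch.randn(128, device="cuda", dtype=torch.float16)
y = torch.randn(128, device="cuda", dtype=torch.bfloat16)
out = torch.empty(128, device="cuda", dtype=torch.float16)
add_kernel[(1, )](x, y, out, BLOCK=128)
```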
34 changes: 18 additions & 16 deletions python/test/unit/language/test_core.py
Expand Up @@ -346,7 +346,7 @@ def do_test(x, y, kernel_fn):
z_tri = to_triton(np.empty(SIZE, dtype=z_ref.dtype), device=device)
kernel_fn[(1, )](z_tri, x_tri, y_tri, SIZE=SIZE, num_warps=4, num_ctas=num_ctas)
err_msg = f"{expr}, {kernel_fn.__name__}"
np.testing.assert_allclose(z_ref, to_numpy(z_tri), err_msg=err_msg, atol=3e-3, rtol=0.01)
np.testing.assert_allclose(z_ref, to_numpy(z_tri), err_msg=err_msg, atol=5e-3, rtol=0.01)

def get_scalar(x, dtype, low, high, filter):
# If dtype is int, don't choose a huge number for the scalar
@@ -381,8 +381,7 @@ def get_scalar(x, dtype, low, high, filter):
 
 
 def _mod_operation_ill_conditioned(dtype_x, dtype_y) -> bool:
-    # FIXME For large x, we are casting x to a floating point where it does not fit
-    # For small y, we are computing floor(div(float(x), y)) which may not fit
+    # FIXME For large x, we are casting x to a floating point where it does not fit!!
     return (dtype_x, dtype_y) in [
         ('int32', 'bfloat16'),
         ('int32', 'float16'),
@@ -391,8 +390,6 @@ def _mod_operation_ill_conditioned(dtype_x, dtype_y) -> bool:
         ('int64', 'float16'),
         ('int64', 'float32'),
         ('int64', 'float64'),
-        ('uint16', 'bfloat16'),
-        ('uint16', 'float16'),
         ('uint16', 'float32'),
         ('uint32', 'bfloat16'),
         ('uint32', 'float16'),
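The FIXME is easy to reproduce in plain NumPy; for instance, float32's 24-bit significand cannot represent every int32:

```python
import numpy as np

x = np.int32(2**24 + 1)             # 16777217 fits in int32
np.float32(x)                       # 16777216.0 -- the +1 is rounded away
np.float32(x) == np.float32(2**24)  # True, so x % y via float32 is ill-conditioned
```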
@@ -425,19 +422,23 @@ def test_dtype_codegen():
 @pytest.mark.parametrize("num_ctas", num_ctas_list)
 def test_bin_op(dtype_x, dtype_y, op, num_ctas, device):
     expr = f'x {op} y'
-    if op == '%' and dtype_x in int_dtypes + uint_dtypes and dtype_y in int_dtypes + uint_dtypes:
-        # LLVM has 'numpy.fmod', not 'numpy.remainder', semantics on integer remainders.
-        numpy_expr = 'np.fmod(x, y)'
-    elif op in ('/', '%') and dtype_x in ('int16', 'float16', 'bfloat16') and dtype_y in ('int16', 'float16',
-                                                                                          'bfloat16'):
-        # Triton promotes 16-bit floating-point / and % to 32-bit because there
-        # are no native div or FRem operations on float16. Since we have to
-        # convert anyway, we may as well take the accuracy bump.
-        numpy_expr = f'x.astype(np.float32) {op} y.astype(np.float32)'
+    np_expr_gen = (lambda x, y: f'{x} {op} {y}') if op != '%' else (lambda x, y: f'np.fmod({x}, {y})')
+
+    # Triton promotes 16-bit floating-point / and % to 32-bit because there
+    # are no native div or FRem operations on float16. Since we have to
+    # convert anyway, we may as well take the accuracy bump.
+    def promote_to_fp32(dtype_x, dtype_y):
+        return dtype_x in ('float16', 'bfloat16') and dtype_y not in ('float32', 'float64')
+
+    if op in ('/', '%') and (promote_to_fp32(dtype_x, dtype_y) or promote_to_fp32(dtype_y, dtype_x)):
+        numpy_expr = np_expr_gen('x.astype(np.float32)', 'y.astype(np.float32)')
     elif (dtype_x in uint_dtypes and dtype_y in int_dtypes and _bitwidth(dtype_x) >= _bitwidth(dtype_y)):
-        numpy_expr = f'x.astype(np.{dtype_x}) {op} y.astype(np.{dtype_x})'
+        numpy_expr = np_expr_gen(f'x.astype(np.{dtype_x})', f'y.astype(np.{dtype_x})')
     elif (dtype_y in uint_dtypes and dtype_x in int_dtypes and _bitwidth(dtype_y) >= _bitwidth(dtype_x)):
-        numpy_expr = f'x.astype(np.{dtype_y}) {op} y.astype(np.{dtype_y})'
+        numpy_expr = np_expr_gen(f'x.astype(np.{dtype_y})', f'y.astype(np.{dtype_y})')
+    elif op == '%':
+        # LLVM has 'numpy.fmod', not 'numpy.remainder', semantics on integer remainders.
+        numpy_expr = np_expr_gen('x', 'y')
     else:
         numpy_expr = None
     if op == '%' and _mod_operation_ill_conditioned(dtype_x, dtype_y):
@@ -452,6 +453,7 @@ def test_bin_op(dtype_x, dtype_y, op, num_ctas, device):
     # while Triton performs it in bfloat16
     # We also skip mod when it is ill-conditioned
     skip_scalar_test = ((dtype_x == "bfloat16" and "float" in dtype_y)
+                        or (op in ('/', '%') and dtype_x in ("float16", "bfloat16"))
                         or (expr == "x % y" and dtype_x in int_dtypes + uint_dtypes and dtype_y in float_dtypes
                             and _mod_operation_ill_conditioned(dtype_x, "float32")))
     # can't divide by zero
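For reference, the `np_expr_gen` helper above just builds the NumPy reference expression as a string; a standalone trace (with `op` hard-coded for this sketch) looks like:

```python
op = '%'
np_expr_gen = (lambda x, y: f'{x} {op} {y}') if op != '%' else (lambda x, y: f'np.fmod({x}, {y})')

np_expr_gen('x', 'y')
# 'np.fmod(x, y)'
np_expr_gen('x.astype(np.float32)', 'y.astype(np.float32)')
# 'np.fmod(x.astype(np.float32), y.astype(np.float32))'
```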
10 changes: 4 additions & 6 deletions python/triton/language/semantic.py
@@ -6,7 +6,6 @@
 
 from .._C.libtriton import ir
 from . import core as tl
-from . import math
 
 T = TypeVar('T')
 
@@ -88,11 +87,12 @@ def computation_type_impl(a_ty: tl.dtype, a_is_scalar: bool, b_ty: tl.dtype, b_i
         else:
             return tl.float16
     # 4) return bf16 only if both operands are of bf16
-    if a_ty.is_bf16() or b_ty.is_bf16():
+    if a_ty.is_bf16() and b_ty.is_bf16():
         if div_or_mod:
             return tl.float32
-        if a_ty.is_bf16() and b_ty.is_bf16():
+        else:
             return tl.bfloat16
+    if a_ty.is_bf16() or b_ty.is_bf16():
         return tl.float32
     # 5) return fp16 if operands are different fp8
     if a_ty.is_fp8() and b_ty.is_fp8():
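Condensed, the bf16 branch now behaves as below (a standalone sketch for illustration only; `bf16_result_type` is not a real helper, and earlier rules may intercept some dtype pairs first):

```python
import triton.language as tl

def bf16_result_type(a_ty, b_ty, div_or_mod):
    # Both bf16: keep bf16, except for / and %, which have no native bf16 ops.
    if a_ty.is_bf16() and b_ty.is_bf16():
        return tl.float32 if div_or_mod else tl.bfloat16
    # bf16 mixed with anything else that reaches this point: compute in fp32.
    if a_ty.is_bf16() or b_ty.is_bf16():
        return tl.float32
    return None  # handled by the later rules

assert bf16_result_type(tl.bfloat16, tl.bfloat16, div_or_mod=False) == tl.bfloat16
assert bf16_result_type(tl.bfloat16, tl.bfloat16, div_or_mod=True) == tl.float32
```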
@@ -305,9 +305,7 @@ def mod(input: tl.tensor | numbers.Number, other: tl.tensor | numbers.Number, bu
     other_scalar_ty = other.type.scalar
     # float % float
     if scalar_ty.is_floating():
-        # input - input.div(other, rounding_mode="floor") * other
-        ret = sub(input, mul(math.floor(fdiv(input, other, False, builder), _builder=builder), other, builder), builder)
-        return ret
+        return tl.tensor(builder.create_frem(input.handle, other.handle), input.type)
     # % int
     elif scalar_ty.is_int():
         if scalar_ty.int_signedness != other_scalar_ty.int_signedness:
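The behavioural difference is easy to see with Python's math module (`math.fmod` matches frem's C semantics):

```python
import math

x, y = -7.5, 3.0
math.fmod(x, y)            # -1.5 -- frem / C semantics (sign of x), the new behaviour
x - math.floor(x / y) * y  #  1.5 -- the old floor-based remainder (sign of y)
```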
