diff --git a/docs/python-api/triton-semantics.rst b/docs/python-api/triton-semantics.rst index 178fb94bbae7..b298274ad5eb 100644 --- a/docs/python-api/triton-semantics.rst +++ b/docs/python-api/triton-semantics.rst @@ -11,9 +11,7 @@ The algorithm is as follows: 2. **Width** If both tensors are of dtypes of the same kind, and one of them is of a higher width, the other one is promoted to this dtype: ``(float32, float16) -> float32`` -3. **Supremum** If both tensors are of the same width and signedness but different dtypes, they are both promoted to the next larger dtype. ``(float16, bfloat16) -> float32`` - - 3.1 If both tensors are of different ``fp8`` dtypes, they are both cast to ``float16``. +3. **Prefer float16** If both tensors are of the same width and signedness but different dtypes (``float16`` and ``bfloat16`` or different ``fp8`` types), they are both promoted to ``float16``. ``(float16, bfloat16) -> float16`` 4. **Prefer unsigned** Otherwise (same width, different signedness), they are promoted to the unsigned dtype: ``(int32, uint32) -> uint32`` diff --git a/python/test/unit/language/test_core.py b/python/test/unit/language/test_core.py index af8f414ad31c..1adb76be9c90 100644 --- a/python/test/unit/language/test_core.py +++ b/python/test/unit/language/test_core.py @@ -346,7 +346,7 @@ def do_test(x, y, kernel_fn): z_tri = to_triton(np.empty(SIZE, dtype=z_ref.dtype), device=device) kernel_fn[(1, )](z_tri, x_tri, y_tri, SIZE=SIZE, num_warps=4, num_ctas=num_ctas) err_msg = f"{expr}, {kernel_fn.__name__}" - np.testing.assert_allclose(z_ref, to_numpy(z_tri), err_msg=err_msg, atol=3e-3, rtol=0.01) + np.testing.assert_allclose(z_ref, to_numpy(z_tri), err_msg=err_msg, atol=5e-3, rtol=0.01) def get_scalar(x, dtype, low, high, filter): # If dtype is int, don't choose a huge number for the scalar @@ -381,8 +381,7 @@ def get_scalar(x, dtype, low, high, filter): def _mod_operation_ill_conditioned(dtype_x, dtype_y) -> bool: - # FIXME For large x, we are casting x to a floating point where it does not fit - # For small y, we are computing floor(div(float(x), y)) which may not fit + # FIXME For large x, we are casting x to a floating point where it does not fit!! return (dtype_x, dtype_y) in [ ('int32', 'bfloat16'), ('int32', 'float16'), @@ -391,8 +390,6 @@ def _mod_operation_ill_conditioned(dtype_x, dtype_y) -> bool: ('int64', 'float16'), ('int64', 'float32'), ('int64', 'float64'), - ('uint16', 'bfloat16'), - ('uint16', 'float16'), ('uint16', 'float32'), ('uint32', 'bfloat16'), ('uint32', 'float16'), @@ -425,19 +422,23 @@ def test_dtype_codegen(): @pytest.mark.parametrize("num_ctas", num_ctas_list) def test_bin_op(dtype_x, dtype_y, op, num_ctas, device): expr = f'x {op} y' - if op == '%' and dtype_x in int_dtypes + uint_dtypes and dtype_y in int_dtypes + uint_dtypes: - # LLVM has 'numpy.fmod', not 'numpy.remainder', semantics on integer remainders. - numpy_expr = 'np.fmod(x, y)' - elif op in ('/', '%') and dtype_x in ('int16', 'float16', 'bfloat16') and dtype_y in ('int16', 'float16', - 'bfloat16'): - # Triton promotes 16-bit floating-point / and % to 32-bit because there - # are no native div or FRem operations on float16. Since we have to - # convert anyway, we may as well take the accuracy bump. - numpy_expr = f'x.astype(np.float32) {op} y.astype(np.float32)' + np_expr_gen = (lambda x, y: f'{x} {op} {y}') if op != '%' else (lambda x, y: f'np.fmod({x}, {y})') + + # Triton promotes 16-bit floating-point / and % to 32-bit because there + # are no native div or FRem operations on float16. Since we have to + # convert anyway, we may as well take the accuracy bump. + def promote_to_fp32(dtype_x, dtype_y): + return dtype_x in ('float16', 'bfloat16') and dtype_y not in ('float32', 'float64') + + if op in ('/', '%') and (promote_to_fp32(dtype_x, dtype_y) or promote_to_fp32(dtype_y, dtype_x)): + numpy_expr = np_expr_gen('x.astype(np.float32)', 'y.astype(np.float32)') elif (dtype_x in uint_dtypes and dtype_y in int_dtypes and _bitwidth(dtype_x) >= _bitwidth(dtype_y)): - numpy_expr = f'x.astype(np.{dtype_x}) {op} y.astype(np.{dtype_x})' + numpy_expr = np_expr_gen(f'x.astype(np.{dtype_x})', f'y.astype(np.{dtype_x})') elif (dtype_y in uint_dtypes and dtype_x in int_dtypes and _bitwidth(dtype_y) >= _bitwidth(dtype_x)): - numpy_expr = f'x.astype(np.{dtype_y}) {op} y.astype(np.{dtype_y})' + numpy_expr = np_expr_gen(f'x.astype(np.{dtype_y})', f'y.astype(np.{dtype_y})') + elif op == '%': + # LLVM has 'numpy.fmod', not 'numpy.remainder', semantics on integer remainders. + numpy_expr = np_expr_gen('x', 'y') else: numpy_expr = None if op == '%' and _mod_operation_ill_conditioned(dtype_x, dtype_y): @@ -452,6 +453,7 @@ def test_bin_op(dtype_x, dtype_y, op, num_ctas, device): # while Triton performs it in bfloat16 # We also skip mod when it is ill-conditioned skip_scalar_test = ((dtype_x == "bfloat16" and "float" in dtype_y) + or (op in ('/', '%') and dtype_x in ("float16", "bfloat16")) or (expr == "x % y" and dtype_x in int_dtypes + uint_dtypes and dtype_y in float_dtypes and _mod_operation_ill_conditioned(dtype_x, "float32"))) # can't divide by zero diff --git a/python/triton/language/semantic.py b/python/triton/language/semantic.py index 188e3279a80f..583e5795db3f 100644 --- a/python/triton/language/semantic.py +++ b/python/triton/language/semantic.py @@ -6,7 +6,6 @@ from .._C.libtriton import ir from . import core as tl -from . import math T = TypeVar('T') @@ -88,11 +87,12 @@ def computation_type_impl(a_ty: tl.dtype, a_is_scalar: bool, b_ty: tl.dtype, b_i else: return tl.float16 # 4) return bf16 only if both operands are of bf16 - if a_ty.is_bf16() or b_ty.is_bf16(): + if a_ty.is_bf16() and b_ty.is_bf16(): if div_or_mod: return tl.float32 - if a_ty.is_bf16() and b_ty.is_bf16(): + else: return tl.bfloat16 + if a_ty.is_bf16() or b_ty.is_bf16(): return tl.float32 # 5) return fp16 if operands are different fp8 if a_ty.is_fp8() and b_ty.is_fp8(): @@ -305,9 +305,7 @@ def mod(input: tl.tensor | numbers.Number, other: tl.tensor | numbers.Number, bu other_scalar_ty = other.type.scalar # float % float if scalar_ty.is_floating(): - # input - input.div(other, rounding_mode="floor") * other - ret = sub(input, mul(math.floor(fdiv(input, other, False, builder), _builder=builder), other, builder), builder) - return ret + return tl.tensor(builder.create_frem(input.handle, other.handle), input.type) # % int elif scalar_ty.is_int(): if scalar_ty.int_signedness != other_scalar_ty.int_signedness: