Triton type promotion semantics for floating dtypes are very non-intuitive #4697

Open
lezcano opened this issue Sep 10, 2024 · 1 comment

lezcano commented Sep 10, 2024

I recently wrote some documentation describing the triton promotion semantics. It turns out that part of it was not entirely correct. The reality is much murkier.

The type promotion semantics for fp16 and bf16, both with each other and with fp8 dtypes, are incredibly odd. The rules are applied in the following order of precedence:

  1. If the operation is a division or a mod, and an input is of fp16 or bf16, the result is of type fp32 (??!!). e.g. fp16 / fp16 = fp32, bf16 % bf16 = fp32.
  2. If both inputs are of the same dtype, then perform the computation in that dtype (fair)
  3. If you mix fp16 and bf16, it will return fp16 (fair)
  4. If you mix bf16 and any fp8 it will return fp32 ??!!

I propose changing them (the semantics of binary ops of floating types) to:

  1. Preserve dtype: If both inputs are of the same dtype, perform the computation in that dtype. With this, fp16 / fp16 = fp16 as expected. For mod and div, we emulate this in software by upcasting to fp32 and downcasting the result to the input dtype.
  2. Width: A narrower dtype is cast to the wider dtype. This way fp8 x bf16 returns bf16.
  3. Prefer float16: If two different fp dtypes have the same width, the operation is performed in fp16. This covers the current semantics, where fp8 x fp8 returns fp16 and `fp16 x bf16` returns `fp16`.
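
The three proposed rules can be sketched as a small pure-Python function. This is only an illustration of the proposal, not Triton code: the function name `proposed_promotion`, the `WIDTH` table, and the string dtype labels (including the fp8 names) are assumptions made for the example.

```python
# Bit widths for a few floating dtypes. The string labels are illustrative
# stand-ins, not Triton dtype objects.
WIDTH = {"fp8e4nv": 8, "fp8e5": 8, "fp16": 16, "bf16": 16, "fp32": 32, "fp64": 64}


def proposed_promotion(a: str, b: str) -> str:
    """Result dtype of a binary op under the proposed rules (hypothetical helper)."""
    # 1. Preserve dtype: same inputs, same output, regardless of the operator.
    if a == b:
        return a
    # 2. Width: the narrower dtype is cast to the wider one.
    if WIDTH[a] != WIDTH[b]:
        return a if WIDTH[a] > WIDTH[b] else b
    # 3. Prefer float16: different dtypes of the same width compute in fp16.
    return "fp16"


print(proposed_promotion("fp16", "fp16"))     # fp16, also for / and %
print(proposed_promotion("fp8e5", "bf16"))    # bf16
print(proposed_promotion("fp8e4nv", "fp8e5")) # fp16
print(proposed_promotion("fp16", "bf16"))     # fp16
```

Note that the operator does not appear in the signature: under the proposal, division and mod no longer change the result dtype.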

For reference, the current implementation lives at

    # 3 ) if one operand is half, the other is implicitly converted to half
    # unless we're doing / or %, which do not exist natively in PTX for fp16.
    # Supported PTX op: add, sub, mul, fma, neg, abs, min, max, tanh, ex2, setp
    if a_ty.is_fp16() or b_ty.is_fp16():
        if div_or_mod:
            return tl.float32
        else:
            return tl.float16
    # 4) return bf16 only if both operands are of bf16
    if a_ty.is_bf16() or b_ty.is_bf16():
        if div_or_mod:
            return tl.float32
        if a_ty.is_bf16() and b_ty.is_bf16():
            return tl.bfloat16
        return tl.float32
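
To make the surprising cases easy to reproduce without a GPU, here is a standalone model of just the two branches quoted above. It is a sketch under assumptions: dtypes are plain strings rather than `tl` dtype objects, the name `quoted_branches` is made up, and inputs are assumed to be fp16, bf16, or some fp8 label (anything the fragment does not cover returns `None`).

```python
from typing import Optional


def quoted_branches(a: str, b: str, div_or_mod: bool) -> Optional[str]:
    """Mirror branches 3) and 4) of the quoted fragment (hypothetical helper)."""
    # 3) any fp16 operand gives fp16, except for / and %, which give fp32
    if "fp16" in (a, b):
        return "fp32" if div_or_mod else "fp16"
    # 4) bf16 is returned only when both operands are bf16 and the op is not / or %
    if "bf16" in (a, b):
        if div_or_mod:
            return "fp32"
        if a == "bf16" and b == "bf16":
            return "bf16"
        return "fp32"
    # Not covered by the quoted fragment.
    return None


print(quoted_branches("fp16", "fp16", div_or_mod=True))    # fp32
print(quoted_branches("bf16", "bf16", div_or_mod=True))    # fp32
print(quoted_branches("fp16", "bf16", div_or_mod=False))   # fp16
print(quoted_branches("bf16", "fp8e5", div_or_mod=False))  # fp32
```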

lezcano added a commit to lezcano/triton that referenced this issue Sep 10, 2024
The semantics of `%` in triton used to be type dependent (!!).

With this PR, we make `%` always follow C semantics, similar to `//`.

We update the type promotion docs fixing some inaccuracies. It is still
not entirely precise though. For a discussion of the current semantics
see triton-lang#4697

iclementine commented Sep 12, 2024

Hello, can you also check whether this PR changes the behavior described in issue #4676?

I found that `%` by 0 gives different results when it co-exists with `//` by 0 in a kernel.

brod4910 pushed a commit to brod4910/triton that referenced this issue Oct 19, 2024
The semantics of `%` in triton used to be type dependent (!!).

With this PR, we make `%` always follow C semantics, similar to `//`.

We update the type promotion docs fixing some inaccuracies. It is still
not entirely precise though. For a discussion of the current semantics
see triton-lang#4697
lezcano added a commit that referenced this issue Nov 1, 2024
Continuation of the work from @lezcano
#4698

> With this PR, we make `%` always follow C semantics, similar to `//`.
We update the type promotion docs fixing some inaccuracies. It is still
not entirely precise though. For a discussion of the current semantics
see #4697

Pretty sure all that was left were changes for the frem function to emit
`np.fmod` instead of `np.remainder` and to ignore ('uint16', 'float64')
mod computations in the tests. I believe this combination is
ill-conditioned but I could be wrong about that.

Co-authored-by: lezcano <[email protected]>
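
The difference between the two remainder conventions mentioned in the commit message can be seen with the Python standard library alone. In this sketch, `math.fmod` stands in for the C-style remainder (the convention `np.fmod` follows) and Python's `%` stands in for the floored remainder (the convention `np.remainder` follows); the helper names are made up for the example.

```python
import math


def c_style_mod(x: float, y: float) -> float:
    """Remainder that takes the sign of the dividend, as C's fmod does."""
    return math.fmod(x, y)


def floored_mod(x: float, y: float) -> float:
    """Remainder that takes the sign of the divisor, as Python's % does."""
    return x % y


# The two conventions agree when both operands have the same sign...
print(c_style_mod(7.0, 3.0), floored_mod(7.0, 3.0))    # 1.0 1.0
# ...and differ when the signs are mixed.
print(c_style_mod(-7.0, 3.0), floored_mod(-7.0, 3.0))  # -1.0 2.0
print(c_style_mod(7.0, -3.0), floored_mod(7.0, -3.0))  # 1.0 -2.0
```

A reference implementation that uses the floored convention will therefore disagree with a C-semantics `%` on every mixed-sign input, which is why the tests needed the reference function changed.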