
Cast bf16 to fp32 for sum/reduce #6974


Closed
wants to merge 2 commits into from

Conversation

Contributor

@dshi7 dshi7 commented May 29, 2025

This behavior used to exist, but it was removed in f9ad25e#diff-7935be9afb58d190ea8fd7a849bdc375fa48f75aff33c7bb152428feaa6860f4 by @Mogball.

Reading that PR's summary, dropping the bf16-to-fp32 promotion appears to have been unintentional.

To reduce this footgun, this PR adds a dtype argument to tl.sum (and tl.reduce) which optionally casts the input to that dtype before computing the reduction. For tl.sum, the default dtype is tl.int32 for integer inputs narrower than 32 bits and tl.float32 for float inputs narrower than 32 bits.
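
As a rough sketch (not taken from this PR's diff), the kernel below shows how the proposed argument might be used; the dtype keyword on tl.sum is the hypothetical addition described above, and the kernel and pointer names are invented for illustration.

```python
import triton
import triton.language as tl


@triton.jit
def sum_bf16_kernel(x_ptr, out_ptr, BLOCK: tl.constexpr):
    # Load one block of bf16 values (pointer names are illustrative).
    offs = tl.arange(0, BLOCK)
    x = tl.load(x_ptr + offs)
    # Today: the reduction accumulates in the input dtype (bf16), which can
    # silently lose precision over long reductions.
    s = tl.sum(x, axis=0)
    # With the dtype argument proposed here, the input would be cast first so
    # the accumulation happens in fp32 (hypothetical keyword):
    # s = tl.sum(x, axis=0, dtype=tl.float32)
    tl.store(out_ptr, s)
```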

@dshi7 dshi7 requested a review from ptillet as a code owner May 29, 2025 01:16
@dshi7 dshi7 requested review from Mogball and htyu and removed request for ptillet May 29, 2025 01:16
@ThomasRaoux
Collaborator

I believe the change was intentional; the description was unfortunately outdated. See this discussion:
#5763 (comment)

In any case, I think the current behavior is what we want: we should only upcast if the user explicitly passes float32 as the accumulator type.
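
For context, a minimal sketch of the explicit-upcast pattern that the current behavior expects from the user (kernel and pointer names are illustrative, not from this PR):

```python
import triton
import triton.language as tl


@triton.jit
def sum_bf16_explicit_upcast(x_ptr, out_ptr, BLOCK: tl.constexpr):
    offs = tl.arange(0, BLOCK)
    x = tl.load(x_ptr + offs)
    # Upcast explicitly so the accumulation runs in fp32; with the current
    # behavior this is left to the caller rather than done implicitly.
    s = tl.sum(x.to(tl.float32), axis=0)
    tl.store(out_ptr, s)
```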

@ThomasRaoux
Collaborator

Closing this for now based on the comments above; feel free to start a new discussion if you think this is still needed.

@ThomasRaoux ThomasRaoux closed this Jun 2, 2025