Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Backend] Fix predicates for device assert inside reduction/scan region #5033

Merged
merged 1 commit into from
Nov 5, 2024

Conversation

davidberard98
Copy link
Contributor

@davidberard98 davidberard98 commented Nov 1, 2024

Reductions have special handling for side effectful "combine ops" (e.g. "add" for a sum reduction). In the presence of side effects, a predicate is computed to determine whether a thread should participate in the reduction, to ensure that invalid/uninitialized data is not operated on. See #4811 for more details.

Previously, the predicate logic was incorrect for 2D reductions. This PR fixes the logic and adds a python test.

Edit: after additional discussion with @peterbell10, we removed the lanePred logic. Here's our thinking on why this is valid:

  • lanePred info is computed based entirely on the blocked layout info and properties of the reduction
  • the blocked layout won't tell you which threads do or don't have uninitialized data

Instead, it sounds like the motivation for #4811 is based on uninitialized values that can be indicated by the pred variable passed into warpReduce().

@@ -5917,7 +5917,7 @@ def test_side_effectful_reduction(device):
if device != "cuda":
pytest.skip()

@triton.jit(debug=True)
@triton.jit
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think the debug=True needs to be added as a kwarg to the invocation of the triton kernel. Previously I wasn't seeing any asserts in the ttgir

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for spotting this, I've opened #5037 to fix it

Comment on lines 166 to 174
// Predicate to ensure we don't read from invalid memory.
// definitions:
// "Lane": the strip of values that are being reduced along.
// relevant variables:
// interleave: for two consecutive elements in a lane, the difference
// between their thread ids is the interleave.
// numLanesToReduce: how many lanes we're reducing across.
// totalNumLanes: how many lanes exist in total. If the reduction
// skips some threads, totalNumLanes might not equal numLanesToReduce.
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@peterbell10 is this accurate? tbh I didn't quite understand what scenario requires a predicate - I verified that this fixes my scenario, but I don't know if it regresses the scenario you were initially targeting.

Comment on lines 178 to 180
Value laneId =
urem(udiv(threadId, i32_val(interleave)), i32_val(totalNumLanes));
Value lanePred = icmp_slt(laneId, i32_val(numLaneToReduce));
Copy link
Contributor

@peterbell10 peterbell10 Nov 1, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The definition of lane is the position of a thread within its warp, so this is a bit confusing. Would it work to do this?

Suggested change
Value laneId =
urem(udiv(threadId, i32_val(interleave)), i32_val(totalNumLanes));
Value lanePred = icmp_slt(laneId, i32_val(numLaneToReduce));
Value laneId = urem(threadId, warpSize);
Value lanePred = icmp_slt(laneId, i32_val(totalNumLanes * interleave));

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@peterbell10 thanks for the suggestion!

Instead I'm using

Value lanePred = icmp_slt(laneId, i32_val(numLaneToReduce * interleave));

since I presume that the reason for predicating is due to the difference between numLaneToReduce vs. totalNumLanes?

@davidberard98
Copy link
Contributor Author

note: the other test_side_effectful_reduction and side_effectful_scan tests are failing after #5035, but somehow not failing on the test_side_effectful_reduction_2d test added by this PR.

pytorchmergebot pushed a commit to pytorch/pytorch that referenced this pull request Nov 1, 2024
In upstream triton, triton-lang/triton#4589 introduces overflow checks. However, overflow checks likely add some overhead, and have some correctness bugs at the moment (e.g. triton-lang/triton#5033). Let's set `sanitize_overflow=False` but keep `debug=True` so that we can keep using device_assert but without the additional asserts added by `sanitize_overflow`.

Pull Request resolved: #139502
Approved by: https://github.com/bertmaher
Reductions have special handling for side effectful "combine ops" (e.g. "add" for a sum reduction). In the presence of side effects, a predicate is computed to determine whether a thread should participate in the reduction, to ensure that invalid/uninitialized data is not operated on. See [] for more details.

Previously, the predicate logic was incorrect for 2D reductions. This PR fixes the logic and adds a python test.
rahulsingh-intel pushed a commit to rahulsingh-intel/pytorch that referenced this pull request Nov 5, 2024
…9502)

In upstream triton, triton-lang/triton#4589 introduces overflow checks. However, overflow checks likely add some overhead, and have some correctness bugs at the moment (e.g. triton-lang/triton#5033). Let's set `sanitize_overflow=False` but keep `debug=True` so that we can keep using device_assert but without the additional asserts added by `sanitize_overflow`.

Pull Request resolved: pytorch#139502
Approved by: https://github.com/bertmaher
@peterbell10 peterbell10 merged commit 732aee7 into triton-lang:main Nov 5, 2024
7 checks passed
peterbell10 added a commit that referenced this pull request Nov 5, 2024
…5075)

This is a follow up to #5033 but for scan ops, and also improving the
testing as it was clearly insufficient before.
bertmaher pushed a commit that referenced this pull request Nov 5, 2024
…on (#5033)

Reductions have special handling for side effectful "combine ops" (e.g.
"add" for a sum reduction). In the presence of side effects, a predicate
is computed to determine whether a thread should participate in the
reduction, to ensure that invalid/uninitialized data is not operated on.
See #4811 for more details.

~Previously, the predicate logic was incorrect for 2D reductions. This
PR fixes the logic and adds a python test.~

Edit: after additional discussion with @peterbell10, we removed the
lanePred logic. Here's our thinking on why this is valid:
* lanePred info is computed based entirely on the blocked layout info
and properties of the reduction
* the blocked layout won't tell you which threads do or don't have
uninitialized data

Instead, it sounds like the motivation for #4811 is based on
uninitialized values that can be indicated by the `pred` variable passed
into `warpReduce()`.
bertmaher pushed a commit that referenced this pull request Nov 5, 2024
…5075)

This is a follow up to #5033 but for scan ops, and also improving the
testing as it was clearly insufficient before.
antiagainst added a commit to antiagainst/triton that referenced this pull request Nov 5, 2024
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants