[Backend] Fix predicates for device assert inside reduction/scan region #5033
@@ -5917,7 +5917,7 @@ def test_side_effectful_reduction(device):
     if device != "cuda":
         pytest.skip()

-    @triton.jit(debug=True)
+    @triton.jit
I think `debug=True` needs to be added as a kwarg to the invocation of the Triton kernel. Previously I wasn't seeing any asserts in the TTGIR.
Thanks for spotting this, I've opened #5037 to fix it
// Predicate to ensure we don't read from invalid memory.
// definitions:
//   "Lane": the strip of values that are being reduced along.
// relevant variables:
//   interleave: for two consecutive elements in a lane, the difference
//     between their thread ids is the interleave.
//   numLanesToReduce: how many lanes we're reducing across.
//   totalNumLanes: how many lanes exist in total. If the reduction
//     skips some threads, totalNumLanes might not equal numLanesToReduce.
@peterbell10 is this accurate? tbh I didn't quite understand what scenario requires a predicate - I verified that this fixes my scenario, but I don't know if it regresses the scenario you were initially targeting.
Value laneId =
    urem(udiv(threadId, i32_val(interleave)), i32_val(totalNumLanes));
Value lanePred = icmp_slt(laneId, i32_val(numLaneToReduce));
The definition of lane is the position of a thread within its warp, so this is a bit confusing. Would it work to do this?
- Value laneId =
-     urem(udiv(threadId, i32_val(interleave)), i32_val(totalNumLanes));
- Value lanePred = icmp_slt(laneId, i32_val(numLaneToReduce));
+ Value laneId = urem(threadId, warpSize);
+ Value lanePred = icmp_slt(laneId, i32_val(totalNumLanes * interleave));
@peterbell10 thanks for the suggestion!
Instead I'm using
Value lanePred = icmp_slt(laneId, i32_val(numLaneToReduce * interleave));
since I presume that the reason for predicating is due to the difference between numLaneToReduce vs. totalNumLanes?
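To see how the two predicate formulas discussed in this thread select different threads, here is a small host-side Python sketch. It is only an illustration: the parameter values and the warp size of 32 are assumptions picked so that `totalNumLanes` differs from `numLaneToReduce`, and the function names are hypothetical. It does not say which formula is correct for the backend.

```python
WARP_SIZE = 32  # assumed warp size for this illustration


def pred_strided(thread_id, interleave, total_num_lanes, num_lane_to_reduce):
    """First revision: laneId = (threadId / interleave) % totalNumLanes."""
    lane_id = (thread_id // interleave) % total_num_lanes
    return lane_id < num_lane_to_reduce


def pred_warp_lane(thread_id, interleave, num_lane_to_reduce):
    """Follow-up: laneId = threadId % warpSize, compared with numLaneToReduce * interleave."""
    lane_id = thread_id % WARP_SIZE
    return lane_id < num_lane_to_reduce * interleave


# Hypothetical parameters, chosen only so that totalNumLanes != numLaneToReduce.
interleave, total_num_lanes, num_lane_to_reduce = 2, 8, 4

# Thread ids within one warp that each predicate would leave active.
active_strided = [
    t for t in range(WARP_SIZE)
    if pred_strided(t, interleave, total_num_lanes, num_lane_to_reduce)
]
active_warp_lane = [
    t for t in range(WARP_SIZE)
    if pred_warp_lane(t, interleave, num_lane_to_reduce)
]

print(active_strided)    # threads 0..7 and 16..23
print(active_warp_lane)  # threads 0..7
```

With these assumed values the first formula keeps threads 0 to 7 and 16 to 23 active, while the second keeps only threads 0 to 7, so the two are not interchangeable in general.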
Note: the other test_side_effectful_reduction and side_effectful_scan tests are failing after #5035, but somehow the test_side_effectful_reduction_2d test added by this PR is not.
In upstream triton, triton-lang/triton#4589 introduces overflow checks. However, overflow checks likely add some overhead, and have some correctness bugs at the moment (e.g. triton-lang/triton#5033). Let's set `sanitize_overflow=False` but keep `debug=True` so that we can keep using device_assert but without the additional asserts added by `sanitize_overflow`.

Pull Request resolved: #139502
Approved by: https://github.com/bertmaher
Reductions have special handling for side effectful "combine ops" (e.g. "add" for a sum reduction). In the presence of side effects, a predicate is computed to determine whether a thread should participate in the reduction, to ensure that invalid/uninitialized data is not operated on. See [] for more details. Previously, the predicate logic was incorrect for 2D reductions. This PR fixes the logic and adds a python test.
…on (#5033)

Reductions have special handling for side effectful "combine ops" (e.g. "add" for a sum reduction). In the presence of side effects, a predicate is computed to determine whether a thread should participate in the reduction, to ensure that invalid/uninitialized data is not operated on. See #4811 for more details.

~Previously, the predicate logic was incorrect for 2D reductions. This PR fixes the logic and adds a python test.~

Edit: after additional discussion with @peterbell10, we removed the lanePred logic. Here's our thinking on why this is valid:
* lanePred info is computed based entirely on the blocked layout info and properties of the reduction
* the blocked layout won't tell you which threads do or don't have uninitialized data

Instead, it sounds like the motivation for #4811 is based on uninitialized values that can be indicated by the `pred` variable passed into `warpReduce()`.
…can region (triton-lang#5033)" This reverts commit 732aee7.
Reductions have special handling for side effectful "combine ops" (e.g. "add" for a sum reduction). In the presence of side effects, a predicate is computed to determine whether a thread should participate in the reduction, to ensure that invalid/uninitialized data is not operated on. See #4811 for more details.

~Previously, the predicate logic was incorrect for 2D reductions. This PR fixes the logic and adds a python test.~

Edit: after additional discussion with @peterbell10, we removed the lanePred logic. Here's our thinking on why this is valid:
* lanePred info is computed based entirely on the blocked layout info and properties of the reduction
* the blocked layout won't tell you which threads do or don't have uninitialized data

Instead, it sounds like the motivation for #4811 is based on uninitialized values that can be indicated by the `pred` variable passed into `warpReduce()`.
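As a rough illustration of why a per-thread `pred` matters when the combine op has side effects, the following host-side Python sketch simulates a butterfly (XOR-shuffle) reduction over a small group of threads. It is a simplified model under stated assumptions, not Triton's implementation: `checked_add`, `GARBAGE`, the 8-thread group, and this Python `warp_reduce` are hypothetical stand-ins, with a Python `assert` playing the role of a device assert.

```python
GARBAGE = 2**40  # hypothetical stand-in for uninitialized data


def checked_add(a, b, limit=2**31 - 1):
    """Side-effectful combine op: asserts on overflow, like a device assert."""
    result = a + b
    assert result <= limit, "overflow in combine op"
    return result


def warp_reduce(acc, pred, num_lane_to_reduce, combine):
    """Butterfly reduction; the combine op only runs on threads whose pred is set."""
    acc = list(acc)
    offset = num_lane_to_reduce // 2
    while offset > 0:
        # Every thread reads its partner's value (models a shuffle-xor).
        shuffled = [acc[t ^ offset] for t in range(len(acc))]
        # Predicated threads combine; the others keep their value untouched.
        acc = [
            combine(acc[t], shuffled[t]) if pred[t] else acc[t]
            for t in range(len(acc))
        ]
        offset //= 2
    return acc


# Threads 0..3 hold real data; threads 4..7 hold garbage and are masked off.
values = [1, 2, 3, 4] + [GARBAGE] * 4
pred = [t < 4 for t in range(len(values))]

result = warp_reduce(values, pred, num_lane_to_reduce=4, combine=checked_add)
print(result[:4])  # each valid thread ends up with the full sum
```

In this model, passing `pred = [True] * 8` instead makes `checked_add` run on the garbage values and trip its assertion, which is the kind of spurious failure the predicate is meant to prevent.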