[Backend] Fix predicates for device assert inside reduction/scan region (#5033)

Reductions have special handling for side effectful "combine ops" (e.g.
"add" for a sum reduction). In the presence of side effects, a predicate
is computed to determine whether a thread should participate in the
reduction, to ensure that invalid/uninitialized data is not operated on.
See #4811 for more details.
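
For illustration, here is a minimal sketch of what a side-effectful combine op looks like from the Triton language side; the names (`checked_add`, `sum_kernel`) are hypothetical and not part of this PR. The `tl.device_assert` inside the combine op is the side effect: if threads holding uninitialized shuffle results were allowed to execute it, the assert could fire on garbage data, which is why the participation predicate exists.

```python
import triton
import triton.language as tl


@triton.jit
def checked_add(x, y):
    # Side effect inside the combine op: a device-side assertion.
    # Without the participation predicate, lanes whose registers hold
    # uninitialized data could evaluate this and trip the assert.
    tl.device_assert(x >= 0, "unexpected negative or uninitialized value")
    return x + y


@triton.jit(debug=True)  # debug=True keeps device asserts enabled
def sum_kernel(Z, X, BLOCK: tl.constexpr):
    x = tl.load(X + tl.arange(0, BLOCK))
    z = tl.reduce(x, 0, checked_add)  # reduction with a side-effectful combine op
    tl.store(Z, z)
```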

~Previously, the predicate logic was incorrect for 2D reductions. This
PR fixes the logic and adds a Python test.~

Edit: after additional discussion with @peterbell10, we removed the
lanePred logic. Here's our thinking on why this is valid:
* lanePred info is computed based entirely on the blocked layout info
and properties of the reduction
* the blocked layout won't tell you which threads do or don't have
uninitialized data

Instead, it sounds like the motivation for #4811 is based on
uninitialized values that can be indicated by the `pred` variable passed
into `warpReduce()`.
davidberard98 authored Nov 5, 2024
1 parent 038cbc5 commit 732aee7
Showing 2 changed files with 24 additions and 7 deletions.
7 changes: 0 additions & 7 deletions lib/Conversion/TritonGPUToLLVM/ReduceOpToLLVM.cpp
@@ -162,13 +162,6 @@ struct ReduceOpConversion

auto mod = op->getParentOfType<ModuleOp>();
unsigned iWarpSize = triton::gpu::TritonGPUDialect::getThreadsPerWarp(mod);
if (iWarpSize > numLaneToReduce) {
Value threadId = getThreadId(rewriter, loc);
Value warpSize = i32_val(iWarpSize);
Value laneId = urem(threadId, warpSize);
Value lanePred = icmp_slt(laneId, i32_val(numLaneToReduce));
pred = pred ? and_(pred, lanePred) : lanePred;
}

for (unsigned N = numLaneToReduce / 2; N > 0; N >>= 1) {
SmallVector<Value> shfl(acc.size());
24 changes: 24 additions & 0 deletions python/test/unit/language/test_core.py
@@ -5939,6 +5939,30 @@ def sanitize_sum_kernel(Z, X, BLOCK: tl.constexpr):
torch.testing.assert_close(Z, X.sum().to(torch.int32))


@pytest.mark.parametrize("reduce_dim", [0, 1])
def test_side_effectful_reduction_2d(device, reduce_dim):
if device != "cuda":
pytest.skip()

@triton.jit(debug=True)
def sanitize_sum_2d_kernel(Z, X, BLOCK_0: tl.constexpr, BLOCK_1: tl.constexpr, reduce_dim: tl.constexpr,
NON_REDUCE_DIM: tl.constexpr):
offsets = tl.arange(0, BLOCK_0)[:, None] * BLOCK_1 + tl.arange(0, BLOCK_1)[None, :]
vals = tl.load(X + offsets)
z = tl.reduce(vals, reduce_dim, sanitize_add)
tl.store(Z + tl.arange(0, NON_REDUCE_DIM), z)

BLOCK_0 = 16
BLOCK_1 = 32
NON_REDUCE_DIM = BLOCK_1 if reduce_dim == 0 else BLOCK_0
torch.manual_seed(42)
X = torch.randint(0, 10, [BLOCK_0, BLOCK_1], device="cuda", dtype=torch.int32)
Z = torch.zeros([NON_REDUCE_DIM], device="cuda", dtype=torch.int32)
sanitize_sum_2d_kernel[(1, )](Z, X, BLOCK_0=BLOCK_0, BLOCK_1=BLOCK_1, reduce_dim=reduce_dim,
NON_REDUCE_DIM=NON_REDUCE_DIM)
torch.testing.assert_close(Z, X.sum(reduce_dim).to(torch.int32))


def test_side_effectful_scan(device):
if device != "cuda":
pytest.skip()
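
The new test reuses `sanitize_add`, a combine op defined earlier in `test_core.py` (it is also used by the existing `sanitize_sum_kernel` test above) and therefore not visible in this diff. The sketch below is only an assumption about its general shape, not the file's actual definition:

```python
import triton
import triton.language as tl


@triton.jit
def sanitize_add(x, y):
    # Assumed sketch -- the real definition lives earlier in test_core.py.
    # The device assert is the side effect that forces the backend to
    # predicate out lanes holding uninitialized data during the reduction.
    z = x + y
    tl.device_assert(z >= x, "suspicious value encountered in combine op")
    return z
```

The 2D shapes in the test (`BLOCK_0=16`, `BLOCK_1=32`) plausibly make `numLaneToReduce` smaller than the warp size, i.e. the situation the removed `lanePred` branch used to guard.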
