Commit
fix lit test
Summary:

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:
manman-ren committed Dec 19, 2024
1 parent 6f48edd commit 492969a
Showing 3 changed files with 26 additions and 22 deletions.
1 change: 0 additions & 1 deletion include/triton/Tools/Sys/GetEnv.hpp
@@ -31,7 +31,6 @@ inline const std::set<std::string> CACHE_INVALIDATING_ENV_VARS = {
     "TRITON_LLVM_DEBUG_ONLY",
     "USE_IR_LOC",
     "NVPTX_ENABLE_DUMP",
-    "ENABLE_BUFFER_REUSE",
     // clang-format on
 };

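CACHE_INVALIDATING_ENV_VARS lists the environment variables that feed into Triton's compilation cache key, so a flag that no longer gates anything (as ENABLE_BUFFER_REUSE after this commit) would only add a dead input to that key. A minimal sketch of the idea, using a hypothetical envCacheKey helper rather than Triton's actual caching code:

// Hypothetical sketch (not Triton's real cache code): hash the name=value of
// every cache-invalidating environment variable into a single key, so that
// toggling any of them forces recompilation. Dropping a dead flag keeps the
// key from depending on an unused variable.
#include <cstddef>
#include <cstdlib>
#include <functional>
#include <set>
#include <string>

std::size_t envCacheKey(const std::set<std::string> &vars) {
  std::size_t key = 0;
  for (const auto &name : vars) {
    const char *val = std::getenv(name.c_str());
    // Combine each variable's name and current value into the key.
    key ^= std::hash<std::string>{}(name + "=" + (val ? val : ""));
  }
  return key;
}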
21 changes: 12 additions & 9 deletions lib/Dialect/TritonGPU/Transforms/WSCodePartition.cpp
Expand Up @@ -1237,8 +1237,6 @@ bool reuseBuffers(SmallVector<Operation *> &taskTopOps,
const SmallVector<Channel *> &channels,
DenseMap<Channel *, Channel *> &mapToRepresenting,
SmallVector<scf::ForOp> &loopWithBufferReuse) {
if (!triton::tools::getBoolEnv("ENABLE_BUFFER_REUSE"))
return false;
// For the case of multiple parallel ForOps with same number of channels,
// we can try reusing the buffers across the parallel ForOps.
// One case is
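The hunk is cut off here, but the surrounding comment states the intent: when two parallel ForOps carry the same number of channels, the pass can map each channel of one loop onto a representative channel of the other so that they share buffers, and with the ENABLE_BUFFER_REUSE gate removed this analysis now always runs. A rough sketch of that pairing under simplified placeholder types (not the pass's real Channel bookkeeping or its detection logic):

// Hypothetical sketch of cross-loop channel pairing: given two loops with
// equally many channels, pair them up positionally so the second loop's
// channels reuse the first loop's buffers.
#include <cassert>
#include <cstddef>
#include <map>
#include <vector>

struct Channel; // opaque placeholder for the pass's Channel type

void mapChannelsToRepresentatives(
    const std::vector<Channel *> &firstLoopChannels,
    const std::vector<Channel *> &secondLoopChannels,
    std::map<Channel *, Channel *> &mapToRepresenting) {
  assert(firstLoopChannels.size() == secondLoopChannels.size());
  for (std::size_t i = 0; i < firstLoopChannels.size(); ++i)
    // Channel i of the second loop reuses the buffer of channel i of the
    // first loop; the first loop's channel acts as the representative.
    mapToRepresenting[secondLoopChannels[i]] = firstLoopChannels[i];
}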
@@ -1508,14 +1506,19 @@ DenseMap<Channel *, Value> createBuffer(
   MLIRContext *context = funcOp.getContext();
   OpBuilder builder(funcOp);
   builder.setInsertionPointToStart(&(funcOp.getBody().front()));
+  DenseSet<Channel *> visited;
   for (auto &item : channelsGroupedByProducers) {
-    auto c = item.first;
-    if (mapToRepresenting.count(c)) {
-      channelReuse[mapToRepresenting[c]].push_back(c);
-      LDBG("update channelReuse key " << mapToRepresenting[c] << " " << c);
-    } else {
-      channelReuse[c].push_back(c);
-      LDBG("update channelReuse key " << c << " " << c);
+    auto &channels = item.second;
+    for (auto c : channels) {
+      assert(!visited.count(c));
+      visited.insert(c);
+      if (mapToRepresenting.count(c)) {
+        channelReuse[mapToRepresenting[c]].push_back(c);
+        LDBG("update channelReuse key " << mapToRepresenting[c] << " " << c);
+      } else {
+        channelReuse[c].push_back(c);
+        LDBG("update channelReuse key " << c << " " << c);
+      }
     }
   }
   for (auto &item : channelsGroupedByProducers) {
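The rewritten loop above visits every channel in each producer group (the old code only consulted item.first) and buckets it in channelReuse, either under its representative from mapToRepresenting or under itself, with the visited set asserting that no channel is grouped twice. A standalone sketch of that grouping step, using standard containers in place of MLIR's DenseMap and DenseSet:

// Hypothetical sketch of the channelReuse grouping performed in createBuffer:
// every channel of every producer group lands in exactly one reuse bucket,
// keyed by its representative channel if it has one, otherwise by itself.
#include <cassert>
#include <map>
#include <set>
#include <vector>

struct Channel; // opaque placeholder for the pass's Channel type

void groupChannelsForReuse(
    const std::map<Channel *, std::vector<Channel *>> &channelsGroupedByProducers,
    const std::map<Channel *, Channel *> &mapToRepresenting,
    std::map<Channel *, std::vector<Channel *>> &channelReuse) {
  std::set<Channel *> visited;
  for (const auto &item : channelsGroupedByProducers) {
    for (Channel *c : item.second) {
      assert(!visited.count(c) && "each channel is grouped exactly once");
      visited.insert(c);
      auto rep = mapToRepresenting.find(c);
      // Bucket under the representative when one exists, else under c itself.
      channelReuse[rep != mapToRepresenting.end() ? rep->second : c].push_back(c);
    }
  }
}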
26 changes: 14 additions & 12 deletions test/TritonNvidiaGPU/WarpSpecialization/ws_code_partition.mlir
@@ -434,17 +434,19 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 :
 }
 }
 
-// ----- verify that we can reuse buffers between two for loops
+// -----
+
+// Verify that we can reuse buffers between two for loops
 // CHECK-LABEL: @_attn_bwd_ws
-// CHECK: triton_gpu.local_alloc {allocation.shareGroup = 0 : i32} : () -> !tt.memdesc<2x64x128xbf16
-// CHECK: triton_gpu.local_alloc {allocation.shareGroup = 1 : i32} : () -> !tt.memdesc<2x64x128xbf16
-// CHECK: triton_gpu.local_alloc {allocation.shareGroup = 0 : i32} : () -> !tt.memdesc<2x64x128xbf16
-// CHECK: triton_gpu.local_alloc {allocation.shareGroup = 1 : i32} : () -> !tt.memdesc<2x64x128xbf16
+// CHECK-DAG: triton_gpu.local_alloc {allocation.shareGroup = 0 : i32} : () -> !tt.memdesc<2x64x128xbf16
+// CHECK-DAG: triton_gpu.local_alloc {allocation.shareGroup = 1 : i32} : () -> !tt.memdesc<2x64x128xbf16
+// CHECK-DAG: triton_gpu.local_alloc {allocation.shareGroup = 0 : i32} : () -> !tt.memdesc<2x64x128xbf16
+// CHECK-DAG: triton_gpu.local_alloc {allocation.shareGroup = 1 : i32} : () -> !tt.memdesc<2x64x128xbf16
 
-// CHECK: %[[TASKID:.*]] = triton_nvidia_gpu.get_async_task_id : i32
+// CHECK: %[[TID:.*]] = triton_nvidia_gpu.get_async_task_id : i32
 // CHECK: %[[ZERO:.*]] = arith.constant 0 : i32
-// CHECK: %[[WG0:.*]] = arith.cmpi eq, %[[TASKID]], %[[ZERO]] : i32
-// CHECK: scf.if %[[WG0]]
+// CHECK: %[[TWG0:.*]] = arith.cmpi eq, %[[TID]], %[[ZERO]] : i32
+// CHECK: scf.if %[[TWG0]]
 // CHECK: triton_nvidia_gpu.reg_dealloc 40
 // CHECK: scf.if
 // CHECK: scf.yield
@@ -470,8 +472,8 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 :
 // CHECK: scf.yield {{.*}} %[[IF_IDX]]
 
 // CHECK: %[[ONE:.*]] = arith.constant 1 : i32
-// CHECK: %[[WG1:.*]] = arith.cmpi eq, %[[TASKID]], %[[ONE]] : i32
-// CHECK: scf.if %[[WG1]]
+// CHECK: %[[TWG1:.*]] = arith.cmpi eq, %[[TID]], %[[ONE]] : i32
+// CHECK: scf.if %[[TWG1]]
 // CHECK: triton_nvidia_gpu.reg_alloc 232
 // CHECK: scf.if
 // CHECK: scf.yield
@@ -497,8 +499,8 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 :
 // CHECK: scf.yield {{.*}} %[[IF_IDX_WG1]]
 
 // CHECK: %[[TWO:.*]] = arith.constant 2 : i32
-// CHECK: %[[WG2:.*]] = arith.cmpi eq, %[[TASKID]], %[[TWO]] : i32
-// CHECK: scf.if %[[WG2]]
+// CHECK: %[[TWG2:.*]] = arith.cmpi eq, %[[TID]], %[[TWO]] : i32
+// CHECK: scf.if %[[TWG2]]
 // CHECK: triton_nvidia_gpu.reg_alloc 232
 // CHECK: scf.if
 // CHECK: scf.yield
