[BACKEND] Allow backend to specify special rules for membar insertion (triton-lang#4675)

With block-level operations such as TMA, some ops may access shared memory
without requiring barriers. This adds a lambda that backends can pass in to
explicitly skip barrier insertion between specific pairs of ops.
ThomasRaoux authored Sep 9, 2024
1 parent 63c7d4c commit 8a3fb7e
Showing 6 changed files with 197 additions and 28 deletions.
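
For context, here is a minimal sketch of how a backend opts into the new hook by passing a MembarFilterFn to ModuleMembarAnalysis, mirroring the TestMembar.cpp change below. The function name and include paths are illustrative assumptions; only the ModuleAllocation/ModuleMembarAnalysis API shown in this diff is taken as given.

#include "mlir/IR/BuiltinOps.h"
#include "triton/Analysis/Membar.h"
#include "TritonNVIDIAGPUToLLVM/Utility.h" // declares canSkipBarSync (added below)

// Sketch only: run the membar analysis over a module with a backend filter.
void runMembarWithFilter(mlir::ModuleOp moduleOp) {
  mlir::ModuleAllocation allocation(moduleOp);
  // Omit the second argument (or pass nullptr) to keep the previous,
  // fully conservative behavior.
  mlir::ModuleMembarAnalysis membarPass(&allocation,
                                        mlir::triton::NVIDIA::canSkipBarSync);
  membarPass.run();
}
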
55 changes: 36 additions & 19 deletions include/triton/Analysis/Membar.h
@@ -10,30 +10,38 @@ namespace mlir {

class OpBuilder;

/// Callback to allow the backend to provide more information on whether a
/// barrier is needed between two operations. Even though two operations access
/// the same shared memory, they may not require a barrier between them.
using MembarFilterFn = std::function<bool(Operation *, Operation *)>;

struct BlockInfo {
using IntervalSetT = std::set<Interval<size_t>>;
using IntervalMapT = std::map<Interval<size_t>, std::set<Operation *>>;

-  IntervalSetT syncReadIntervals;
-  IntervalSetT syncWriteIntervals;
+  IntervalMapT syncReadIntervals;
+  IntervalMapT syncWriteIntervals;

BlockInfo() = default;

/// Unions two BlockInfo objects.
BlockInfo &join(const BlockInfo &other) {
-    syncReadIntervals.insert(other.syncReadIntervals.begin(),
-                             other.syncReadIntervals.end());
-    syncWriteIntervals.insert(other.syncWriteIntervals.begin(),
-                              other.syncWriteIntervals.end());
+    for (auto &interval : other.syncReadIntervals)
+      syncReadIntervals[interval.first].insert(interval.second.begin(),
+                                               interval.second.end());
+    for (auto &interval : other.syncWriteIntervals)
+      syncWriteIntervals[interval.first].insert(interval.second.begin(),
+                                                interval.second.end());
return *this;
}

/// Returns true if intervals in two BlockInfo objects are intersected.
-  bool isIntersected(const BlockInfo &other) const {
-    return /*RAW*/ isIntersected(syncWriteIntervals, other.syncReadIntervals) ||
+  bool isIntersected(const BlockInfo &other, MembarFilterFn filter) const {
+    return /*RAW*/ isIntersected(syncWriteIntervals, other.syncReadIntervals,
+                                 filter) ||
            /*WAR*/
-           isIntersected(syncReadIntervals, other.syncWriteIntervals) ||
+           isIntersected(syncReadIntervals, other.syncWriteIntervals, filter) ||
            /*WAW*/
-           isIntersected(syncWriteIntervals, other.syncWriteIntervals);
+           isIntersected(syncWriteIntervals, other.syncWriteIntervals, filter);
}

/// Clears the intervals because a barrier is inserted.
@@ -51,12 +59,17 @@ struct BlockInfo {
bool operator!=(const BlockInfo &other) const { return !(*this == other); }

private:
-  bool isIntersected(const IntervalSetT &lhsIntervalSet,
-                     const IntervalSetT &rhsIntervalSet) const {
+  bool isIntersected(const IntervalMapT &lhsIntervalSet,
+                     const IntervalMapT &rhsIntervalSet,
+                     MembarFilterFn filter) const {
     for (auto &lhs : lhsIntervalSet)
       for (auto &rhs : rhsIntervalSet)
-        if (lhs.intersects(rhs))
-          return true;
+        if (lhs.first.intersects(rhs.first))
+          for (auto lhsOp : lhs.second)
+            for (auto rhsOp : rhs.second)
+              if (!filter || !filter(lhsOp, rhsOp))
+                return true;

return false;
}
};
@@ -81,7 +94,8 @@ class MembarAnalysis {
/// it is considered as the problem of the operation itself but not the membar
/// analysis.
MembarAnalysis() = default;
-  explicit MembarAnalysis(Allocation *allocation) : allocation(allocation) {}
+  explicit MembarAnalysis(Allocation *allocation, MembarFilterFn filter)
+      : allocation(allocation), filter(filter) {}

/// Runs the membar analysis to the given operation, inserts a barrier if
/// necessary.
@@ -116,6 +130,7 @@ class MembarAnalysis {

private:
Allocation *allocation = nullptr;
MembarFilterFn filter = nullptr;
};

/// Postorder traversal on the callgraph to insert membar instructions
@@ -125,9 +140,10 @@ class MembarAnalysis {
/// before and after function calls, but might be a bit conservative.
class ModuleMembarAnalysis : public CallGraph<BlockInfo> {
public:
-  ModuleMembarAnalysis(ModuleAllocation *moduleAllocation)
+  ModuleMembarAnalysis(ModuleAllocation *moduleAllocation,
+                       MembarFilterFn filter = nullptr)
       : CallGraph<BlockInfo>(moduleAllocation->getModuleOp()),
-        moduleAllocation(moduleAllocation) {}
+        moduleAllocation(moduleAllocation), filter(filter) {}

void run() {
walk<WalkOrder::PreOrder, WalkOrder::PostOrder>(
@@ -138,14 +154,15 @@ class ModuleMembarAnalysis : public CallGraph<BlockInfo> {
auto *allocation = moduleAllocation->getFuncData(funcOp);
auto [it, inserted] = funcMap.try_emplace(funcOp, BlockInfo());
if (inserted) {
-          MembarAnalysis analysis(allocation);
+          MembarAnalysis analysis(allocation, filter);
analysis.run(funcMap);
}
});
}

private:
ModuleAllocation *moduleAllocation;
MembarFilterFn filter;
};

} // namespace mlir
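
As a usage sketch (not part of this change), any callable matching MembarFilterFn works. The rule below is purely illustrative and its name is an assumption; a real backend encodes op-specific knowledge, as the NVIDIA canSkipBarSync added later in this commit does.

#include "mlir/IR/Operation.h"
#include "triton/Analysis/Membar.h"

// Illustrative only: claim that two ops of the same kind never need a barrier
// between them. Not a rule you would ship; it just shows the expected shape of
// a MembarFilterFn (return true to skip the barrier).
static bool sameOpKindFilter(mlir::Operation *before, mlir::Operation *after) {
  return before->getName() == after->getName();
}

// An inline lambda works as well, since MembarFilterFn is a std::function:
//   ModuleMembarAnalysis analysis(&allocation, sameOpKindFilter);
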
20 changes: 12 additions & 8 deletions lib/Analysis/Membar.cpp
@@ -136,11 +136,15 @@ void MembarAnalysis::update(Operation *op, BlockInfo *blockInfo,
for (auto bufferId : allocation->getBufferIds(value)) {
if (bufferId != Allocation::InvalidBufferId) {
if (isa<MemoryEffects::Write>(effectInstance.getEffect()))
-          curBlockInfo.syncWriteIntervals.insert(
-              allocation->getAllocatedInterval(bufferId));
+          curBlockInfo
+              .syncWriteIntervals[allocation->getAllocatedInterval(
+                  bufferId)]
+              .insert(op);
         else if (isa<MemoryEffects::Read>(effectInstance.getEffect()))
-          curBlockInfo.syncReadIntervals.insert(
-              allocation->getAllocatedInterval(bufferId));
+          curBlockInfo
+              .syncReadIntervals[allocation->getAllocatedInterval(
+                  bufferId)]
+              .insert(op);
}
}
}
@@ -161,15 +165,15 @@ void MembarAnalysis::update(Operation *op, BlockInfo *blockInfo,
"dependencies");
}
auto interval = allocation->getAllocatedInterval(scratchBufferId);
-    curBlockInfo.syncWriteIntervals.insert(interval);
-    if (blockInfo->isIntersected(curBlockInfo)) {
+    curBlockInfo.syncWriteIntervals[interval].insert(op);
+    if (blockInfo->isIntersected(curBlockInfo, filter)) {
       builder->setInsertionPoint(op);
       insertBarrier(op, builder);
     }
     // Ops with a scratch buffer internally syncs read/write on shared memory
     blockInfo->sync();
-    curBlockInfo.syncReadIntervals.insert(interval);
-  } else if (blockInfo->isIntersected(curBlockInfo)) {
+    curBlockInfo.syncReadIntervals[interval].insert(op);
+  } else if (blockInfo->isIntersected(curBlockInfo, filter)) {
builder->setInsertionPoint(op);
insertBarrier(op, builder);
blockInfo->sync();
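
To make the effect of the change concrete, here is a small self-contained model of the new bookkeeping: each shared-memory interval now remembers which ops touched it, and a barrier is still required if any overlapping op pair is not exempted by the filter. The names and types (Interval, Op, needsBarrier) are simplified assumptions, not the real Triton classes, which use Interval<size_t> and mlir::Operation *.

#include <cstddef>
#include <functional>
#include <map>
#include <set>
#include <string>
#include <tuple>

struct Interval {
  std::size_t start, end; // half-open [start, end)
  bool intersects(const Interval &o) const {
    return start < o.end && o.start < end;
  }
  bool operator<(const Interval &o) const {
    return std::tie(start, end) < std::tie(o.start, o.end);
  }
};

using Op = std::string; // stands in for mlir::Operation *
using FilterFn = std::function<bool(const Op &, const Op &)>;
using IntervalMap = std::map<Interval, std::set<Op>>;

// Mirrors the shape of BlockInfo::isIntersected after this commit: a barrier
// is needed only if two intervals overlap and at least one op pair touching
// them is not exempted by the filter.
bool needsBarrier(const IntervalMap &lhs, const IntervalMap &rhs,
                  const FilterFn &filter) {
  for (const auto &l : lhs)
    for (const auto &r : rhs)
      if (l.first.intersects(r.first))
        for (const auto &lop : l.second)
          for (const auto &rop : r.second)
            if (!filter || !filter(lop, rop))
              return true;
  return false;
}
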
92 changes: 92 additions & 0 deletions test/Analysis/test-membar.mlir
@@ -713,3 +713,95 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 :
tt.return
}
}

// -----

#shared = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = true}>
#shared1 = #triton_gpu.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0], hasLeadingOffset = false}>
#blocked = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}>

module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, triton_gpu.shared = 18944 : i32} {
// CHECK-LABEL: tma_special_cases
tt.func @tma_special_cases(%arg1: !tt.ptr<i8, 0>) -> (tensor<256x64xf16, #blocked>){
%true = arith.constant 1 : i1
%c0 = arith.constant 0 : i32
%barrier = triton_gpu.local_alloc : () -> !tt.memdesc<1xi64, #shared1, #triton_gpu.shared_memory, mutable>
%alloc = triton_gpu.local_alloc : () -> !tt.memdesc<256x64xf16, #shared, #triton_gpu.shared_memory, mutable>
// CHECK: triton_nvidia_gpu.init_barrier
// CHECK-NEXT: triton_nvidia_gpu.init_barrier
triton_nvidia_gpu.init_barrier %barrier, 1 : <1xi64, #shared1, #triton_gpu.shared_memory, mutable>
triton_nvidia_gpu.init_barrier %barrier, 1 : <1xi64, #shared1, #triton_gpu.shared_memory, mutable>

// CHECK-NEXT: gpu.barrier
// CHECK-NEXT: triton_nvidia_gpu.barrier_expect
// CHECK-NEXT: triton_nvidia_gpu.async_tma_copy_global_to_local
// CHECK-NEXT: triton_nvidia_gpu.wait_barrier
triton_nvidia_gpu.barrier_expect %barrier, 49152, %true : <1xi64, #shared1, #triton_gpu.shared_memory, mutable>
triton_nvidia_gpu.async_tma_copy_global_to_local %arg1[%c0, %c0] %alloc, %barrier, %true : <i8, 0>, <1xi64, #shared1, #triton_gpu.shared_memory, mutable> -> <256x64xf16, #shared, #triton_gpu.shared_memory, mutable>
triton_nvidia_gpu.wait_barrier %barrier, %c0 : <1xi64, #shared1, #triton_gpu.shared_memory, mutable>

// CHECK-NEXT: triton_nvidia_gpu.async_tma_copy_global_to_local
// CHECK-NEXT: triton_nvidia_gpu.barrier_expect
// CHECK-NEXT: triton_nvidia_gpu.wait_barrier
triton_nvidia_gpu.async_tma_copy_global_to_local %arg1[%c0, %c0] %alloc, %barrier, %true : <i8, 0>, <1xi64, #shared1, #triton_gpu.shared_memory, mutable> -> <256x64xf16, #shared, #triton_gpu.shared_memory, mutable>
triton_nvidia_gpu.barrier_expect %barrier, 49152, %true : <1xi64, #shared1, #triton_gpu.shared_memory, mutable>
triton_nvidia_gpu.wait_barrier %barrier, %c0 : <1xi64, #shared1, #triton_gpu.shared_memory, mutable>

// CHECK-NEXT: triton_gpu.local_load
%t = triton_gpu.local_load %alloc : !tt.memdesc<256x64xf16, #shared, #triton_gpu.shared_memory, mutable> -> tensor<256x64xf16, #blocked>

// CHECK-NEXT: triton_nvidia_gpu.barrier_expect
// CHECK-NEXT: gpu.barrier
// CHECK-NEXT: triton_nvidia_gpu.async_tma_copy_global_to_local
// CHECK-NEXT: triton_nvidia_gpu.wait_barrier
triton_nvidia_gpu.barrier_expect %barrier, 49152, %true : <1xi64, #shared1, #triton_gpu.shared_memory, mutable>
triton_nvidia_gpu.async_tma_copy_global_to_local %arg1[%c0, %c0] %alloc, %barrier, %true : <i8, 0>, <1xi64, #shared1, #triton_gpu.shared_memory, mutable> -> <256x64xf16, #shared, #triton_gpu.shared_memory, mutable>
triton_nvidia_gpu.wait_barrier %barrier, %c0 : <1xi64, #shared1, #triton_gpu.shared_memory, mutable>

// CHECK-NEXT: gpu.barrier
// CHECK-NEXT: triton_nvidia_gpu.inval_barrier
// CHECK-NEXT: triton_nvidia_gpu.inval_barrier
triton_nvidia_gpu.inval_barrier %barrier : <1xi64, #shared1, #triton_gpu.shared_memory, mutable>
triton_nvidia_gpu.inval_barrier %barrier : <1xi64, #shared1, #triton_gpu.shared_memory, mutable>

tt.return %t : tensor<256x64xf16, #blocked>
}
}

// -----

#shared = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = true}>
#shared1 = #triton_gpu.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0], hasLeadingOffset = false}>
#blocked = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}>

module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, triton_gpu.shared = 18944 : i32} {
// CHECK-LABEL: tma_special_cases_cf
tt.func @tma_special_cases_cf(%arg1: !tt.ptr<i8, 0>, %i1 : i1, %arg2: tensor<256x64xf16, #blocked>) -> (tensor<256x64xf16, #blocked>){
%true = arith.constant 1 : i1
%c0 = arith.constant 0 : i32
%barrier = triton_gpu.local_alloc : () -> !tt.memdesc<1xi64, #shared1, #triton_gpu.shared_memory, mutable>
%alloc = triton_gpu.local_alloc : () -> !tt.memdesc<256x64xf16, #shared, #triton_gpu.shared_memory, mutable>
// CHECK: cf.cond_br
scf.if %i1 {
// CHECK-NOT: gpu.barrier
// CHECK: triton_nvidia_gpu.async_tma_copy_global_to_local
// CHECK-NEXT: triton_nvidia_gpu.barrier_expect
// CHECK-NEXT: triton_nvidia_gpu.wait_barrier
// CHECK-NEXT: cf.br
triton_nvidia_gpu.async_tma_copy_global_to_local %arg1[%c0, %c0] %alloc, %barrier, %true : <i8, 0>, <1xi64, #shared1, #triton_gpu.shared_memory, mutable> -> <256x64xf16, #shared, #triton_gpu.shared_memory, mutable>
triton_nvidia_gpu.barrier_expect %barrier, 49152, %true : <1xi64, #shared1, #triton_gpu.shared_memory, mutable>
triton_nvidia_gpu.wait_barrier %barrier, %c0 : <1xi64, #shared1, #triton_gpu.shared_memory, mutable>
scf.yield
} else {
// CHECK-NOT: gpu.barrier
// CHECK: triton_gpu.local_store
// CHECK-NEXT: cf.br
triton_gpu.local_store %arg2, %alloc : tensor<256x64xf16, #blocked> -> !tt.memdesc<256x64xf16, #shared, #triton_gpu.shared_memory, mutable>
scf.yield
}
// CHECK: gpu.barrier
// CHECK-NEXT: triton_gpu.local_load
%t = triton_gpu.local_load %alloc : !tt.memdesc<256x64xf16, #shared, #triton_gpu.shared_memory, mutable> -> tensor<256x64xf16, #blocked>
tt.return %t : tensor<256x64xf16, #blocked>
}
}
4 changes: 3 additions & 1 deletion test/lib/Analysis/TestMembar.cpp
@@ -1,3 +1,4 @@
#include "../third_party/nvidia/include/TritonNVIDIAGPUToLLVM/Utility.h"
#include "mlir/Conversion/SCFToControlFlow/SCFToControlFlow.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/IR/Dialect.h"
@@ -25,7 +26,8 @@ struct TestMembarPass
ModuleOp moduleOp = cast<ModuleOp>(operation);
// Print all ops after membar pass
ModuleAllocation allocation(moduleOp);
-    ModuleMembarAnalysis membarPass(&allocation);
+    ModuleMembarAnalysis membarPass(&allocation,
+                                    mlir::triton::NVIDIA::canSkipBarSync);
membarPass.run();
}
};
17 changes: 17 additions & 0 deletions third_party/nvidia/include/TritonNVIDIAGPUToLLVM/Utility.h
@@ -0,0 +1,17 @@
#ifndef TRITONGPU_CONVERSION_TRITONNVIDIAGPUTOLLVM_UTILITY_H
#define TRITONGPU_CONVERSION_TRITONNVIDIAGPUTOLLVM_UTILITY_H

#include "mlir/IR/Operation.h"

namespace mlir {
namespace triton {
namespace NVIDIA {

/// Return true if we can skip a barrier synchronization between two operations
/// even if they access the same shared memory.
bool canSkipBarSync(Operation *before, Operation *after);
} // namespace NVIDIA
} // namespace triton
} // namespace mlir

#endif // TRITONGPU_CONVERSION_TRITONNVIDIAGPUTOLLVM_UTILITY_H
37 changes: 37 additions & 0 deletions third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TritonGPUToLLVM.cpp
@@ -1,5 +1,6 @@
#include "Dialect/NVGPU/IR/Dialect.h"
#include "TritonNVIDIAGPUToLLVM/Passes.h"
#include "TritonNVIDIAGPUToLLVM/Utility.h"
#include "mlir/Analysis/DataFlowFramework.h"
#include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h"
#include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h"
@@ -228,5 +229,41 @@ createConvertTritonGPUToLLVMPass(int32_t computeCapability) {
return std::make_unique<ConvertTritonGPUToLLVM>(computeCapability);
}

bool NVIDIA::canSkipBarSync(Operation *before, Operation *after) {
// Multiple init barriers on the same allocation would usually not happen, but
// this allows us to avoid barriers between multiple subslices of an array of
// mbarriers. This is still correct even if the inits happen on the same
// allocation.
if (isa<triton::nvidia_gpu::InitBarrierOp>(before) &&
isa<triton::nvidia_gpu::InitBarrierOp>(after))
return true;

if (isa<triton::nvidia_gpu::InvalBarrierOp>(before) &&
isa<triton::nvidia_gpu::InvalBarrierOp>(after))
return true;

// Even though WaitBarrierOp, AsyncTMACopyGlobalToLocalOp and BarrierExpectOp
// read and write to the mbarrier allocation, it is valid for them to happen in
// a different order on different threads, so we don't need a barrier between
// those operations.
if (isa<triton::nvidia_gpu::WaitBarrierOp,
triton::nvidia_gpu::AsyncTMACopyGlobalToLocalOp,
triton::nvidia_gpu::BarrierExpectOp>(before) &&
isa<triton::nvidia_gpu::WaitBarrierOp,
triton::nvidia_gpu::AsyncTMACopyGlobalToLocalOp,
triton::nvidia_gpu::BarrierExpectOp>(after))
return true;

// An mbarrier wait is released only when the whole operation is done, so any
// thread can access the memory after the barrier even if some threads haven't
// reached the mbarrier wait.
if (isa<triton::nvidia_gpu::AsyncTMACopyGlobalToLocalOp,
triton::nvidia_gpu::WaitBarrierOp>(before) &&
!isa<triton::nvidia_gpu::InvalBarrierOp>(after))
return true;

return false;
}

} // namespace triton
} // namespace mlir
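
Because MembarFilterFn is a plain std::function, a backend can also compose the rule added here with its own. A hedged sketch follows; combineFilters and extraRule are hypothetical names, not part of this commit, and the include paths are assumptions.

#include "TritonNVIDIAGPUToLLVM/Utility.h"
#include "triton/Analysis/Membar.h"

// Hypothetical composition: skip the barrier if either the NVIDIA rule or an
// additional backend-specific rule says it is safe to do so.
static mlir::MembarFilterFn combineFilters(mlir::MembarFilterFn extraRule) {
  return [extraRule](mlir::Operation *before, mlir::Operation *after) {
    if (mlir::triton::NVIDIA::canSkipBarSync(before, after))
      return true;
    return extraRule && extraRule(before, after);
  };
}
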
