From 8a3fb7e3fd6b87f09bcb4ebc61fa04d02e2000f9 Mon Sep 17 00:00:00 2001
From: Thomas Raoux
Date: Sun, 8 Sep 2024 20:19:01 -0700
Subject: [PATCH] [BACKEND] Allow backend to specify special rules for membar insertion (#4675)

With block-level operations like TMA, it is possible that some ops access
shared memory but don't require barriers. This adds a lambda that backends
can pass to explicitly skip barriers between such ops.
---
 include/triton/Analysis/Membar.h              | 55 +++++++----
 lib/Analysis/Membar.cpp                       | 20 ++--
 test/Analysis/test-membar.mlir                | 92 +++++++++++++++++++
 test/lib/Analysis/TestMembar.cpp              |  4 +-
 .../include/TritonNVIDIAGPUToLLVM/Utility.h   | 17 ++++
 .../TritonNVIDIAGPUToLLVM/TritonGPUToLLVM.cpp | 37 ++++++++
 6 files changed, 197 insertions(+), 28 deletions(-)
 create mode 100644 third_party/nvidia/include/TritonNVIDIAGPUToLLVM/Utility.h

diff --git a/include/triton/Analysis/Membar.h b/include/triton/Analysis/Membar.h
index 6475c977cd25..038b0e167563 100644
--- a/include/triton/Analysis/Membar.h
+++ b/include/triton/Analysis/Membar.h
@@ -10,30 +10,38 @@ namespace mlir {
 
 class OpBuilder;
 
+/// Callback that allows a backend to provide more information on whether a
+/// barrier is needed between two operations. Even though two operations access
+/// the same shared memory, they may not require a barrier between them.
+using MembarFilterFn = std::function<bool(Operation *, Operation *)>;
+
 struct BlockInfo {
-  using IntervalSetT = std::set<Interval<size_t>>;
+  using IntervalMapT = std::map<Interval<size_t>, std::set<Operation *>>;
 
-  IntervalSetT syncReadIntervals;
-  IntervalSetT syncWriteIntervals;
+  IntervalMapT syncReadIntervals;
+  IntervalMapT syncWriteIntervals;
 
   BlockInfo() = default;
 
   /// Unions two BlockInfo objects.
   BlockInfo &join(const BlockInfo &other) {
-    syncReadIntervals.insert(other.syncReadIntervals.begin(),
-                             other.syncReadIntervals.end());
-    syncWriteIntervals.insert(other.syncWriteIntervals.begin(),
-                              other.syncWriteIntervals.end());
+    for (auto &interval : other.syncReadIntervals)
+      syncReadIntervals[interval.first].insert(interval.second.begin(),
+                                               interval.second.end());
+    for (auto &interval : other.syncWriteIntervals)
+      syncWriteIntervals[interval.first].insert(interval.second.begin(),
+                                                interval.second.end());
     return *this;
   }
 
   /// Returns true if intervals in two BlockInfo objects are intersected.
-  bool isIntersected(const BlockInfo &other) const {
-    return /*RAW*/ isIntersected(syncWriteIntervals, other.syncReadIntervals) ||
+  bool isIntersected(const BlockInfo &other, MembarFilterFn filter) const {
+    return /*RAW*/ isIntersected(syncWriteIntervals, other.syncReadIntervals,
+                                 filter) ||
           /*WAR*/
-           isIntersected(syncReadIntervals, other.syncWriteIntervals) ||
+           isIntersected(syncReadIntervals, other.syncWriteIntervals, filter) ||
           /*WAW*/
-           isIntersected(syncWriteIntervals, other.syncWriteIntervals);
+           isIntersected(syncWriteIntervals, other.syncWriteIntervals, filter);
   }
 
   /// Clears the intervals because a barrier is inserted.
@@ -51,12 +59,17 @@ struct BlockInfo {
   bool operator!=(const BlockInfo &other) const { return !(*this == other); }
 
 private:
-  bool isIntersected(const IntervalSetT &lhsIntervalSet,
-                     const IntervalSetT &rhsIntervalSet) const {
+  bool isIntersected(const IntervalMapT &lhsIntervalSet,
+                     const IntervalMapT &rhsIntervalSet,
+                     MembarFilterFn filter) const {
     for (auto &lhs : lhsIntervalSet)
       for (auto &rhs : rhsIntervalSet)
-        if (lhs.intersects(rhs))
-          return true;
+        if (lhs.first.intersects(rhs.first))
+          for (auto lhsOp : lhs.second)
+            for (auto rhsOp : rhs.second)
+              if (!filter || !filter(lhsOp, rhsOp))
+                return true;
+
     return false;
   }
 };
@@ -81,7 +94,8 @@ class MembarAnalysis {
   /// it is considered as the problem of the operation itself but not the membar
   /// analysis.
   MembarAnalysis() = default;
-  explicit MembarAnalysis(Allocation *allocation) : allocation(allocation) {}
+  explicit MembarAnalysis(Allocation *allocation, MembarFilterFn filter)
+      : allocation(allocation), filter(filter) {}
 
   /// Runs the membar analysis to the given operation, inserts a barrier if
   /// necessary.
@@ -116,6 +130,7 @@ class MembarAnalysis {
 
 private:
   Allocation *allocation = nullptr;
+  MembarFilterFn filter = nullptr;
 };
 
 /// Postorder traversal on the callgraph to insert membar instructions
@@ -125,9 +140,10 @@ class MembarAnalysis {
 /// before and after function calls, but might be a bit conservative.
 class ModuleMembarAnalysis : public CallGraph<BlockInfo> {
 public:
-  ModuleMembarAnalysis(ModuleAllocation *moduleAllocation)
+  ModuleMembarAnalysis(ModuleAllocation *moduleAllocation,
+                       MembarFilterFn filter = nullptr)
       : CallGraph<BlockInfo>(moduleAllocation->getModuleOp()),
-        moduleAllocation(moduleAllocation) {}
+        moduleAllocation(moduleAllocation), filter(filter) {}
 
   void run() {
     walk(
@@ -138,7 +154,7 @@ class ModuleMembarAnalysis : public CallGraph<BlockInfo> {
           auto *allocation = moduleAllocation->getFuncData(funcOp);
           auto [it, inserted] = funcMap.try_emplace(funcOp, BlockInfo());
           if (inserted) {
-            MembarAnalysis analysis(allocation);
+            MembarAnalysis analysis(allocation, filter);
             analysis.run(funcMap);
           }
         });
@@ -146,6 +162,7 @@ class ModuleMembarAnalysis : public CallGraph<BlockInfo> {
 
 private:
   ModuleAllocation *moduleAllocation;
+  MembarFilterFn filter;
 };
 
 } // namespace mlir
diff --git a/lib/Analysis/Membar.cpp b/lib/Analysis/Membar.cpp
index ea6cc8fe2a61..bb106238e75b 100644
--- a/lib/Analysis/Membar.cpp
+++ b/lib/Analysis/Membar.cpp
@@ -136,11 +136,15 @@ void MembarAnalysis::update(Operation *op, BlockInfo *blockInfo,
       for (auto bufferId : allocation->getBufferIds(value)) {
         if (bufferId != Allocation::InvalidBufferId) {
           if (isa<MemoryEffects::Write>(effectInstance.getEffect()))
-            curBlockInfo.syncWriteIntervals.insert(
-                allocation->getAllocatedInterval(bufferId));
+            curBlockInfo
+                .syncWriteIntervals[allocation->getAllocatedInterval(
+                    bufferId)]
+                .insert(op);
           else if (isa<MemoryEffects::Read>(effectInstance.getEffect()))
-            curBlockInfo.syncReadIntervals.insert(
-                allocation->getAllocatedInterval(bufferId));
+            curBlockInfo
+                .syncReadIntervals[allocation->getAllocatedInterval(
+                    bufferId)]
+                .insert(op);
         }
       }
     }
@@ -161,15 +165,15 @@ void MembarAnalysis::update(Operation *op, BlockInfo *blockInfo,
           "dependencies");
     }
     auto interval = allocation->getAllocatedInterval(scratchBufferId);
-    curBlockInfo.syncWriteIntervals.insert(interval);
-    if (blockInfo->isIntersected(curBlockInfo)) {
+    curBlockInfo.syncWriteIntervals[interval].insert(op);
+    if (blockInfo->isIntersected(curBlockInfo, filter)) {
       builder->setInsertionPoint(op);
       insertBarrier(op, builder);
     }
     // Ops with a scratch buffer internally syncs read/write on shared memory
     blockInfo->sync();
-    curBlockInfo.syncReadIntervals.insert(interval);
-  } else if (blockInfo->isIntersected(curBlockInfo)) {
+    curBlockInfo.syncReadIntervals[interval].insert(op);
+  } else if (blockInfo->isIntersected(curBlockInfo, filter)) {
     builder->setInsertionPoint(op);
     insertBarrier(op, builder);
     blockInfo->sync();
diff --git a/test/Analysis/test-membar.mlir b/test/Analysis/test-membar.mlir
index 6ec8e601a8ed..c26e3a06d1f3 100644
--- a/test/Analysis/test-membar.mlir
+++ b/test/Analysis/test-membar.mlir
@@ -713,3 +713,95 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 :
     tt.return
   }
 }
+
+// -----
+
+#shared = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = true}>
+#shared1 = #triton_gpu.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0], hasLeadingOffset = false}>
+#blocked = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}>
+
+module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, triton_gpu.shared = 18944 : i32} {
+// CHECK-LABEL: tma_special_cases
+tt.func @tma_special_cases(%arg1: !tt.ptr<i8, 0>) -> (tensor<256x64xf16, #blocked>){
+  %true = arith.constant 1 : i1
+  %c0 = arith.constant 0 : i32
+  %barrier = triton_gpu.local_alloc : () -> !tt.memdesc<1xi64, #shared1, #triton_gpu.shared_memory, mutable>
+  %alloc = triton_gpu.local_alloc : () -> !tt.memdesc<256x64xf16, #shared, #triton_gpu.shared_memory, mutable>
+  // CHECK: triton_nvidia_gpu.init_barrier
+  // CHECK-NEXT: triton_nvidia_gpu.init_barrier
+  triton_nvidia_gpu.init_barrier %barrier, 1 : <1xi64, #shared1, #triton_gpu.shared_memory, mutable>
+  triton_nvidia_gpu.init_barrier %barrier, 1 : <1xi64, #shared1, #triton_gpu.shared_memory, mutable>
+
+  // CHECK-NEXT: gpu.barrier
+  // CHECK-NEXT: triton_nvidia_gpu.barrier_expect
+  // CHECK-NEXT: triton_nvidia_gpu.async_tma_copy_global_to_local
+  // CHECK-NEXT: triton_nvidia_gpu.wait_barrier
+  triton_nvidia_gpu.barrier_expect %barrier, 49152, %true : <1xi64, #shared1, #triton_gpu.shared_memory, mutable>
+  triton_nvidia_gpu.async_tma_copy_global_to_local %arg1[%c0, %c0] %alloc, %barrier, %true : <i8, 0>, <1xi64, #shared1, #triton_gpu.shared_memory, mutable> -> <256x64xf16, #shared, #triton_gpu.shared_memory, mutable>
+  triton_nvidia_gpu.wait_barrier %barrier, %c0 : <1xi64, #shared1, #triton_gpu.shared_memory, mutable>
+
+  // CHECK-NEXT: triton_nvidia_gpu.async_tma_copy_global_to_local
+  // CHECK-NEXT: triton_nvidia_gpu.barrier_expect
+  // CHECK-NEXT: triton_nvidia_gpu.wait_barrier
+  triton_nvidia_gpu.async_tma_copy_global_to_local %arg1[%c0, %c0] %alloc, %barrier, %true : <i8, 0>, <1xi64, #shared1, #triton_gpu.shared_memory, mutable> -> <256x64xf16, #shared, #triton_gpu.shared_memory, mutable>
+  triton_nvidia_gpu.barrier_expect %barrier, 49152, %true : <1xi64, #shared1, #triton_gpu.shared_memory, mutable>
+  triton_nvidia_gpu.wait_barrier %barrier, %c0 : <1xi64, #shared1, #triton_gpu.shared_memory, mutable>
+
+  // CHECK-NEXT: triton_gpu.local_load
+  %t = triton_gpu.local_load %alloc : !tt.memdesc<256x64xf16, #shared, #triton_gpu.shared_memory, mutable> -> tensor<256x64xf16, #blocked>
+
+  // CHECK-NEXT: triton_nvidia_gpu.barrier_expect
+  // CHECK-NEXT: gpu.barrier
+  // CHECK-NEXT: triton_nvidia_gpu.async_tma_copy_global_to_local
+  // CHECK-NEXT: triton_nvidia_gpu.wait_barrier
+  triton_nvidia_gpu.barrier_expect %barrier, 49152, %true : <1xi64, #shared1, #triton_gpu.shared_memory, mutable>
+  triton_nvidia_gpu.async_tma_copy_global_to_local %arg1[%c0, %c0] %alloc, %barrier, %true : <i8, 0>, <1xi64, #shared1, #triton_gpu.shared_memory, mutable> -> <256x64xf16, #shared, #triton_gpu.shared_memory, mutable>
+  triton_nvidia_gpu.wait_barrier %barrier, %c0 : <1xi64, #shared1, #triton_gpu.shared_memory, mutable>
+
+  // CHECK-NEXT: gpu.barrier
+  // CHECK-NEXT: triton_nvidia_gpu.inval_barrier
+  // CHECK-NEXT: triton_nvidia_gpu.inval_barrier
+  triton_nvidia_gpu.inval_barrier %barrier : <1xi64, #shared1, #triton_gpu.shared_memory, mutable>
+  triton_nvidia_gpu.inval_barrier %barrier : <1xi64, #shared1, #triton_gpu.shared_memory, mutable>
+
+  tt.return %t : tensor<256x64xf16, #blocked>
+}
+}
+
+// -----
+
+#shared = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = true}>
+#shared1 = #triton_gpu.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0], hasLeadingOffset = false}>
+#blocked = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}>
+
+module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, triton_gpu.shared = 18944 : i32} {
+// CHECK-LABEL: tma_special_cases_cf
+tt.func @tma_special_cases_cf(%arg1: !tt.ptr<i8, 0>, %i1 : i1, %arg2: tensor<256x64xf16, #blocked>) -> (tensor<256x64xf16, #blocked>){
+  %true = arith.constant 1 : i1
+  %c0 = arith.constant 0 : i32
+  %barrier = triton_gpu.local_alloc : () -> !tt.memdesc<1xi64, #shared1, #triton_gpu.shared_memory, mutable>
+  %alloc = triton_gpu.local_alloc : () -> !tt.memdesc<256x64xf16, #shared, #triton_gpu.shared_memory, mutable>
+  // CHECK: cf.cond_br
+  scf.if %i1 {
+    // CHECK-NOT: gpu.barrier
+    // CHECK: triton_nvidia_gpu.async_tma_copy_global_to_local
+    // CHECK-NEXT: triton_nvidia_gpu.barrier_expect
+    // CHECK-NEXT: triton_nvidia_gpu.wait_barrier
+    // CHECK-NEXT: cf.br
+    triton_nvidia_gpu.async_tma_copy_global_to_local %arg1[%c0, %c0] %alloc, %barrier, %true : <i8, 0>, <1xi64, #shared1, #triton_gpu.shared_memory, mutable> -> <256x64xf16, #shared, #triton_gpu.shared_memory, mutable>
+    triton_nvidia_gpu.barrier_expect %barrier, 49152, %true : <1xi64, #shared1, #triton_gpu.shared_memory, mutable>
+    triton_nvidia_gpu.wait_barrier %barrier, %c0 : <1xi64, #shared1, #triton_gpu.shared_memory, mutable>
+    scf.yield
+  } else {
+    // CHECK-NOT: gpu.barrier
+    // CHECK: triton_gpu.local_store
+    // CHECK-NEXT: cf.br
+    triton_gpu.local_store %arg2, %alloc : tensor<256x64xf16, #blocked> -> !tt.memdesc<256x64xf16, #shared, #triton_gpu.shared_memory, mutable>
+    scf.yield
+  }
+  // CHECK: gpu.barrier
+  // CHECK-NEXT: triton_gpu.local_load
+  %t = triton_gpu.local_load %alloc : !tt.memdesc<256x64xf16, #shared, #triton_gpu.shared_memory, mutable> -> tensor<256x64xf16, #blocked>
+  tt.return %t : tensor<256x64xf16, #blocked>
+}
+}
diff --git a/test/lib/Analysis/TestMembar.cpp b/test/lib/Analysis/TestMembar.cpp
index 5e7bbb0c80e4..25e8e2d198bb 100644
--- a/test/lib/Analysis/TestMembar.cpp
+++ b/test/lib/Analysis/TestMembar.cpp
@@ -1,3 +1,4 @@
+#include "../third_party/nvidia/include/TritonNVIDIAGPUToLLVM/Utility.h"
 #include "mlir/Conversion/SCFToControlFlow/SCFToControlFlow.h"
 #include "mlir/Dialect/GPU/IR/GPUDialect.h"
 #include "mlir/IR/Dialect.h"
@@ -25,7 +26,8 @@ struct TestMembarPass
     ModuleOp moduleOp = cast<ModuleOp>(operation);
     // Print all ops after membar pass
     ModuleAllocation allocation(moduleOp);
-    ModuleMembarAnalysis membarPass(&allocation);
+    ModuleMembarAnalysis membarPass(&allocation,
+                                    mlir::triton::NVIDIA::canSkipBarSync);
     membarPass.run();
   }
 };
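
The TestMembar.cpp hunk above shows the intended usage: a backend-specific filter is handed to ModuleMembarAnalysis as a second constructor argument. Below is a minimal sketch of the same wiring with an inline lambda. It is illustrative only and not part of this patch; the helper name runMembarWithConservativeFilter and the always-false filter are assumptions, while the headers and analysis classes are the ones touched by this change.

// Illustrative sketch only, not part of this patch.
#include "mlir/IR/BuiltinOps.h"
#include "triton/Analysis/Allocation.h"
#include "triton/Analysis/Membar.h"

// Hypothetical helper: run the membar analysis with a custom filter.
static void runMembarWithConservativeFilter(mlir::ModuleOp moduleOp) {
  mlir::ModuleAllocation allocation(moduleOp);
  // MembarFilterFn has the shape bool(Operation *, Operation *): returning
  // true means "no barrier is needed between these two operations"; returning
  // false falls back to the default conservative behavior.
  mlir::MembarFilterFn filter = [](mlir::Operation *before,
                                   mlir::Operation *after) {
    return false; // never skip a barrier; equivalent to passing nullptr
  };
  mlir::ModuleMembarAnalysis membarPass(&allocation, filter);
  membarPass.run();
}

Passing nullptr (the default) preserves the previous behavior, so existing backends are unaffected.
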
diff --git a/third_party/nvidia/include/TritonNVIDIAGPUToLLVM/Utility.h b/third_party/nvidia/include/TritonNVIDIAGPUToLLVM/Utility.h
new file mode 100644
index 000000000000..6d1c3c06a596
--- /dev/null
+++ b/third_party/nvidia/include/TritonNVIDIAGPUToLLVM/Utility.h
@@ -0,0 +1,17 @@
+#ifndef TRITONGPU_CONVERSION_TRITONNVIDIAGPUTOLLVM_UTILITY_H
+#define TRITONGPU_CONVERSION_TRITONNVIDIAGPUTOLLVM_UTILITY_H
+
+#include "mlir/IR/Operation.h"
+
+namespace mlir {
+namespace triton {
+namespace NVIDIA {
+
+/// Return true if we can skip a barrier synchronization between two operations
+/// even if they access the same shared memory.
+bool canSkipBarSync(Operation *before, Operation *after);
+} // namespace NVIDIA
+} // namespace triton
+} // namespace mlir
+
+#endif // TRITONGPU_CONVERSION_TRITONNVIDIAGPUTOLLVM_UTILITY_H
diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TritonGPUToLLVM.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TritonGPUToLLVM.cpp
index 977f5b5571b1..e6a6ebc2f125 100644
--- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TritonGPUToLLVM.cpp
+++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TritonGPUToLLVM.cpp
@@ -1,5 +1,6 @@
 #include "Dialect/NVGPU/IR/Dialect.h"
 #include "TritonNVIDIAGPUToLLVM/Passes.h"
+#include "TritonNVIDIAGPUToLLVM/Utility.h"
 #include "mlir/Analysis/DataFlowFramework.h"
 #include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h"
 #include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h"
@@ -228,5 +229,41 @@ createConvertTritonGPUToLLVMPass(int32_t computeCapability) {
   return std::make_unique<ConvertTritonGPUToLLVM>(computeCapability);
 }
 
+bool NVIDIA::canSkipBarSync(Operation *before, Operation *after) {
+  // Multiple init barriers on the same allocation would usually not happen,
+  // but this lets us avoid barriers between the inits of multiple subslices
+  // of an array of mbarriers. This is still correct even if the inits happen
+  // on the same allocation.
+  if (isa<triton::nvidia_gpu::InitBarrierOp>(before) &&
+      isa<triton::nvidia_gpu::InitBarrierOp>(after))
+    return true;
+
+  if (isa<triton::nvidia_gpu::InvalBarrierOp>(before) &&
+      isa<triton::nvidia_gpu::InvalBarrierOp>(after))
+    return true;
+
+  // Even though WaitBarrierOp, BarrierExpectOp and
+  // AsyncTMACopyGlobalToLocalOp read and write to the mbarrier allocation, it
+  // is valid for them to happen in a different order on different threads;
+  // therefore we don't need a barrier between those operations.
+  if (isa<triton::nvidia_gpu::WaitBarrierOp,
+          triton::nvidia_gpu::BarrierExpectOp,
+          triton::nvidia_gpu::AsyncTMACopyGlobalToLocalOp>(before) &&
+      isa<triton::nvidia_gpu::WaitBarrierOp,
+          triton::nvidia_gpu::BarrierExpectOp,
+          triton::nvidia_gpu::AsyncTMACopyGlobalToLocalOp>(after))
+    return true;
+
+  // An mbarrier wait is released only when the whole operation is done,
+  // therefore any thread can access the memory after the barrier even if some
+  // threads haven't reached the mbarrier wait.
+  if (isa<triton::nvidia_gpu::WaitBarrierOp>(before) &&
+      !isa<triton::nvidia_gpu::InvalBarrierOp,
+           triton::nvidia_gpu::InitBarrierOp>(after))
+    return true;
+
+  return false;
+}
+
 } // namespace triton
 } // namespace mlir
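
Note that the filter is only consulted for pairs of operations whose shared-memory intervals actually intersect, so canSkipBarSync can reason purely locally about the two ops it is given. The sketch below is a hypothetical illustration, not part of this patch, of how a backend could layer extra skip rules on top of it: only canSkipBarSync and MembarFilterFn come from this change, while the wrapper and its name are assumptions.

// Hypothetical composition sketch, not part of this patch.
#include "TritonNVIDIAGPUToLLVM/Utility.h"
#include "triton/Analysis/Membar.h"

static mlir::MembarFilterFn makeNvidiaMembarFilter() {
  return [](mlir::Operation *before, mlir::Operation *after) {
    // Start from the rules defined in canSkipBarSync above.
    if (mlir::triton::NVIDIA::canSkipBarSync(before, after))
      return true;
    // A backend could recognize further op pairs that never need a barrier
    // here before falling back to the conservative default.
    return false;
  };
}

The resulting filter can then be passed to ModuleMembarAnalysis exactly as TestMembar.cpp does above.
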