Skip to content

Commit

Permalink
[Triton][Allocation] Enable getScratchValueSize specialization
Browse files Browse the repository at this point in the history
Allow passing a functor to the `ModuleAllocation` constructor to override
`getScratchValueSize` in the allocation analysis. This parameter defaults
to `nullptr`.

Signed-off-by: victor-eds <[email protected]>
  • Loading branch information
victor-eds committed Nov 7, 2024
1 parent 1070ca2 commit ac2bafc
Show file tree
Hide file tree
Showing 4 changed files with 76 additions and 8 deletions.
15 changes: 12 additions & 3 deletions include/triton/Analysis/Allocation.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,12 @@ namespace mlir {
namespace triton {
class AllocationAnalysis;

/// Sentinel scratch size (the largest `unsigned` value). A scratch-size
/// callback returns it to signal that it provides no size for an operation.
inline constexpr unsigned invalidAllocationSize = ~0U;

/// Callback to allow backends to specify target-specific scratch sizes for some
/// operations.
///
/// The allocation analysis invokes the callback (when one is set) on each
/// operation it visits. If the result differs from `invalidAllocationSize`, it
/// is used as the size of the operation's scratch buffer; otherwise the
/// analysis falls back to its built-in scratch size computation.
using AllocationAnalysisScratchSizeFn = std::function<unsigned(Operation *)>;

// To convert a tensor from one layout to another, we need to allocate a
// temporary buffer (i.e., scratch buffer) in shared memory. The conversion may
// require multiple iterations, with each iteration involving multiple
Expand Down Expand Up @@ -102,7 +108,8 @@ class Allocation {
explicit Allocation(Operation *operation) : operation(operation) {}

/// Runs allocation analysis on the given top-level operation.
void run(FuncAllocMapT &funcAllocMap);
void run(FuncAllocMapT &funcAllocMap,
triton::AllocationAnalysisScratchSizeFn scratchSizeGetter);

/// Returns the operation this analysis was constructed from.
Operation *getOperation() const { return operation; }
Expand Down Expand Up @@ -250,7 +257,9 @@ class ModuleAllocation : public CallGraph<Allocation> {
public:
using FuncOffsetMapT = DenseMap<FunctionOpInterface, Value>;

explicit ModuleAllocation(ModuleOp moduleOp)
ModuleAllocation(
ModuleOp moduleOp,
triton::AllocationAnalysisScratchSizeFn scratchSizeGetter = nullptr)
: CallGraph<Allocation>(moduleOp) {
walk<WalkOrder::PreOrder, WalkOrder::PostOrder>(
// Pre-order edge walk callback
Expand All @@ -259,7 +268,7 @@ class ModuleAllocation : public CallGraph<Allocation> {
[&](FunctionOpInterface funcOp) {
auto [iter, inserted] = funcMap.try_emplace(funcOp, funcOp);
if (inserted)
iter->second.run(funcMap);
iter->second.run(funcMap, scratchSizeGetter);
});
}

Expand Down
22 changes: 18 additions & 4 deletions lib/Analysis/Allocation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -131,9 +131,10 @@ class AllocationAnalysis {
public:
AllocationAnalysis(Operation *operation,
Allocation::FuncAllocMapT *funcAllocMap,
Allocation *allocation)
Allocation *allocation,
AllocationAnalysisScratchSizeFn scratchSizeGetter)
: operation(operation), funcAllocMap(funcAllocMap),
allocation(allocation) {
allocation(allocation), scratchSizeGetter(scratchSizeGetter) {
run();
}

Expand Down Expand Up @@ -277,6 +278,15 @@ class AllocationAnalysis {
// Get the alloc values
operation->walk<WalkOrder::PreOrder>([&](Operation *op) {
getExplicitValueSize(op);
if (scratchSizeGetter) {
constexpr size_t scratchAlignment = 128;
unsigned bytes = scratchSizeGetter(op);
if (bytes != invalidAllocationSize) {
maybeAddScratchBuffer<BufferT::BufferKind::Scratch>(op, bytes,
scratchAlignment);
return;
}
}
getScratchValueSize(op);
});
// Get the alias values
Expand Down Expand Up @@ -556,12 +566,16 @@ class AllocationAnalysis {
Allocation::FuncAllocMapT *funcAllocMap;
Allocation *allocation;
BufferRangeMapT bufferRange;
AllocationAnalysisScratchSizeFn scratchSizeGetter;
};

} // namespace triton

void Allocation::run(FuncAllocMapT &funcAllocMap) {
triton::AllocationAnalysis(getOperation(), &funcAllocMap, this);
/// Runs the allocation analysis for the operation this `Allocation` wraps.
///
/// \param funcAllocMap per-function allocation map handed to the analysis.
/// \param scratchSizeGetter optional callback forwarded to the analysis so it
///        can supply custom scratch sizes; may be empty.
void Allocation::run(
    FuncAllocMapT &funcAllocMap,
    triton::AllocationAnalysisScratchSizeFn scratchSizeGetter) {
  Operation *const root = getOperation();
  FuncAllocMapT *const allocMapPtr = &funcAllocMap;
  // The analysis does its work from its constructor, so building a temporary
  // is all that is needed here.
  triton::AllocationAnalysis(root, allocMapPtr, this, scratchSizeGetter);
}

std::map<Operation *, SmallVector<Allocation::BufferId>>
Expand Down
8 changes: 8 additions & 0 deletions test/Analysis/test-allocation.mlir
Original file line number Diff line number Diff line change
@@ -1,4 +1,12 @@
// RUN: triton-opt %s -split-input-file --mlir-disable-threading -test-print-allocation 2>&1 | FileCheck %s
// RUN: triton-opt %s -split-input-file --mlir-disable-threading -test-print-allocation="get-scratch-size-function=ValidConstant" 2>&1 | FileCheck %s --check-prefix=CHECK-128
// RUN: triton-opt %s -split-input-file --mlir-disable-threading -test-print-allocation="get-scratch-size-function=Invalid" 2>&1 | FileCheck %s

// Check that no line reports a scratch size other than 128 and that at least one line reports a size of 128.

// CHECK-128-NOT: scratch offset = {{.*}}, size = {{^(128)}}
// CHECK-128: scratch offset = {{.*}}, size = 128
// CHECK-128-NOT: scratch offset = {{.*}}, size = {{^(128)}}

#AL = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}>
#sliceAd0 = #triton_gpu.slice<{dim = 0, parent = #AL}>
Expand Down
39 changes: 38 additions & 1 deletion test/lib/Analysis/TestAllocation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,21 +5,48 @@ using namespace mlir;

namespace {

/// Test scratch-size callback: reports a fixed size of 128 for every
/// operation, ignoring its argument.
unsigned getScratchSize128(Operation * /*op*/) {
  constexpr unsigned fixedScratchSize = 128;
  return fixedScratchSize;
}
/// Test scratch-size callback: always reports `invalidAllocationSize`,
/// ignoring its argument.
unsigned getScratchSizeInvalid(Operation * /*op*/) {
  using mlir::triton::invalidAllocationSize;
  return invalidAllocationSize;
}

/// Selects which scratch-size callback the test pass hands to
/// `ModuleAllocation`.
enum class GetScratchSizeFunction {
  /// Pass no callback at all.
  None = 0,
  /// Use the callback that reports 128 for every operation.
  ValidConstant = 1,
  /// Use the callback that always reports `invalidAllocationSize`.
  Invalid = 2,
};

struct TestAllocationPass
: public PassWrapper<TestAllocationPass, OperationPass<ModuleOp>> {

MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestAllocationPass);

TestAllocationPass() = default;
// Copy constructor: copies only the PassWrapper base. The option member is not
// copied from `other`; it is set up again by its default member initializer.
// NOTE(review): presumably needed because pass options cannot be
// copy-constructed -- confirm against MLIR's Pass::Option.
TestAllocationPass(const TestAllocationPass &other)
: PassWrapper<TestAllocationPass, OperationPass<ModuleOp>>(other) {}

/// Returns the command-line argument string of this pass.
StringRef getArgument() const final {
  return "test-print-allocation";
}
/// Returns the one-line description string of this pass.
StringRef getDescription() const final {
  return "print the result of the allocation pass";
}

/// Builds the module allocation for the current module, using the scratch-size
/// callback selected by the `get-scratch-size-function` option.
ModuleAllocation getModuleAllocation() {
  ModuleOp moduleOp = getOperation();
  const GetScratchSizeFunction selected = getScratchSizeFunction;
  if (selected == GetScratchSizeFunction::None) {
    // No callback requested: rely on the constructor's default argument.
    return {moduleOp};
  }
  if (selected == GetScratchSizeFunction::ValidConstant)
    return {moduleOp, getScratchSize128};
  if (selected == GetScratchSizeFunction::Invalid)
    return {moduleOp, getScratchSizeInvalid};
  llvm_unreachable("Unhandled case");
}

void runOnOperation() override {
auto &os = llvm::errs();
ModuleOp moduleOp = getOperation();
// Converting to std::string removes the quotes from opName.
ModuleAllocation moduleAllocation(moduleOp);
ModuleAllocation moduleAllocation = getModuleAllocation();
moduleOp.walk([&](triton::FuncOp funcOp) {
auto opName = SymbolTable::getSymbolName(funcOp).getValue().str();
os << opName << "\n";
Expand Down Expand Up @@ -48,6 +75,16 @@ struct TestAllocationPass
os << "size = " << allocation->getSharedMemorySize() << "\n";
});
}

// Pass option `get-scratch-size-function`: selects which scratch-size callback
// getModuleAllocation() passes to ModuleAllocation. Accepts `None`,
// `ValidConstant` or `Invalid`; initialized to `None`.
Option<GetScratchSizeFunction> getScratchSizeFunction{
*this, "get-scratch-size-function",
llvm::cl::desc("Custom scratch size function to use"),
llvm::cl::init(GetScratchSizeFunction::None),
llvm::cl::values(
clEnumValN(GetScratchSizeFunction::None, "None", "None (default)"),
clEnumValN(GetScratchSizeFunction::ValidConstant, "ValidConstant",
"ValidConstant"),
clEnumValN(GetScratchSizeFunction::Invalid, "Invalid", "Invalid"))};
};

} // namespace
Expand Down

0 comments on commit ac2bafc

Please sign in to comment.