Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[AMD] Introduce an OptimizeLDSUsage pass #3730

Merged
merged 22 commits into from
Jul 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,10 @@ if(NOT WIN32)
find_library(TERMINFO_LIBRARY tinfo)
endif()

if(TRITON_BUILD_UT)
include(AddTritonUnitTest)
endif()

# Compiler flags
include_directories(${CMAKE_CURRENT_SOURCE_DIR}/include)
set(CMAKE_CXX_FLAGS "${CMAKE_C_FLAGS} -D__STDC_FORMAT_MACROS -fPIC -std=gnu++17")
Expand Down
1 change: 1 addition & 0 deletions bin/RegisterTritonDialects.h
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ inline void registerTritonDialects(mlir::DialectRegistry &registry) {
mlir::triton::registerConvertTritonAMDGPUToLLVM();
mlir::triton::registerConvertBuiltinFuncToLLVM();
mlir::triton::registerDecomposeUnsupportedAMDConversions();
mlir::triton::registerOptimizeAMDLDSUsage();

// TritonAMDGPUTransforms passes
mlir::registerTritonAMDGPUAccelerateMatmul();
Expand Down
39 changes: 39 additions & 0 deletions cmake/AddTritonUnitTest.cmake
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
include(${PROJECT_SOURCE_DIR}/unittest/googletest.cmake)

include(GoogleTest)
enable_testing()

function(add_triton_ut)
set(options)
set(oneValueArgs NAME)
set(multiValueArgs SRCS LIBS DEFS)
cmake_parse_arguments(_ "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN})

get_property(dialect_libs GLOBAL PROPERTY MLIR_DIALECT_LIBS)
get_property(conversion_libs GLOBAL PROPERTY MLIR_CONVERSION_LIBS)
get_property(triton_libs GLOBAL PROPERTY TRITON_LIBS)

add_test(NAME ${__NAME}
COMMAND ${__NAME})
add_executable(
${__NAME}
${__SRCS})
target_link_libraries(
${__NAME}
PRIVATE
GTest::gtest_main
${triton_libs}
${dialect_libs}
${conversion_libs}
gmock
${__LIBS})

target_compile_options(${__NAME} PRIVATE -fno-rtti)

target_compile_definitions(${__NAME} PRIVATE ${__DEFS})

# Without the TEST_DISCOVERY_TIMEOUT, the tests randomly time out on my mac
# laptop. I think the issue may be that the very first time you run a program
# it's a bit slow.
gtest_discover_tests(${__NAME} PROPERTIES TEST_DISCOVERY_TIMEOUT 60)
endfunction()
9 changes: 9 additions & 0 deletions include/triton/Analysis/Allocation.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,13 @@ class AllocationAnalysis;
SmallVector<unsigned>
getScratchConfigForCvtLayout(triton::gpu::ConvertLayoutOp op, unsigned &inVec,
unsigned &outVec);
SmallVector<unsigned> getScratchConfigForCvtLayout(RankedTensorType srcType,
RankedTensorType dstType,
unsigned &inVec,
unsigned &outVec);
SmallVector<unsigned> getRepShapeForCvtLayout(triton::gpu::ConvertLayoutOp op);
SmallVector<unsigned> getRepShapeForCvtLayout(RankedTensorType srcTy,
RankedTensorType dstTy);

} // namespace triton

Expand Down Expand Up @@ -135,6 +141,9 @@ class Allocation {
/// Returns the size of total shared memory allocated
size_t getSharedMemorySize() const { return sharedMemorySize; }

/// Returns mapping from operation to list of live LDS buffers
std::map<Operation *, SmallVector<BufferId>> getLiveBuffers();

private:
/// A class that represents a shared memory buffer
struct BufferT {
Expand Down
41 changes: 38 additions & 3 deletions lib/Analysis/Allocation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,11 @@ getCvtOrder(Attribute srcLayout, Attribute dstLayout) {
SmallVector<unsigned> getRepShapeForCvtLayout(triton::gpu::ConvertLayoutOp op) {
auto srcTy = op.getSrc().getType();
auto dstTy = op.getType();
return getRepShapeForCvtLayout(srcTy, dstTy);
}

SmallVector<unsigned> getRepShapeForCvtLayout(RankedTensorType srcTy,
RankedTensorType dstTy) {
Attribute srcLayout = srcTy.getEncoding();
Attribute dstLayout = dstTy.getEncoding();

Expand Down Expand Up @@ -92,12 +97,19 @@ SmallVector<unsigned> getRepShapeForCvtLayout(triton::gpu::ConvertLayoutOp op) {
SmallVector<unsigned>
getScratchConfigForCvtLayout(triton::gpu::ConvertLayoutOp op, unsigned &inVec,
unsigned &outVec) {
auto repShape = getRepShapeForCvtLayout(op);
auto srcTy = op.getSrc().getType();
auto dstTy = op.getType();
return getScratchConfigForCvtLayout(srcTy, dstTy, inVec, outVec);
}

SmallVector<unsigned> getScratchConfigForCvtLayout(RankedTensorType srcTy,
RankedTensorType dstTy,
unsigned &inVec,
unsigned &outVec) {
auto repShape = getRepShapeForCvtLayout(srcTy, dstTy);
if (repShape.empty())
return repShape;
auto rank = repShape.size();
auto srcTy = op.getSrc().getType();
auto dstTy = op.getType();
Attribute srcLayout = srcTy.getEncoding();
Attribute dstLayout = dstTy.getEncoding();

Expand Down Expand Up @@ -627,4 +639,27 @@ void Allocation::run(FuncAllocMapT &funcAllocMap) {
triton::AllocationAnalysis(getOperation(), &funcAllocMap, this);
}

std::map<Operation *, SmallVector<Allocation::BufferId>>
Allocation::getLiveBuffers() {
std::map<Operation *, SmallVector<BufferId>> liveBuffers;

Operation *rootOperation = getOperation();
mlir::Liveness liveness(rootOperation);
auto analyzeOperation = [&](Operation *op) -> void {
auto scratchBuffer = getBufferId(op);
if (scratchBuffer != InvalidBufferId)
liveBuffers[op].push_back(scratchBuffer);
for (auto result : op->getOpResults()) {
auto bufferId = getBufferId(result);
if (bufferId == Allocation::InvalidBufferId)
continue;
auto liveOperations = liveness.resolveLiveness(result);
for (auto depOp : liveOperations)
liveBuffers[depOp].push_back(bufferId);
}
};
rootOperation->walk(analyzeOperation);
return liveBuffers;
}

} // namespace mlir
89 changes: 89 additions & 0 deletions test/TritonGPU/amd/optimize-lds-usage.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
// RUN: triton-opt %s -split-input-file -optimize-amd-lds-usage=target-arch=gfx90a | FileCheck %s
// RUN: triton-opt %s -split-input-file -optimize-amd-lds-usage=target-arch=gfx90a -optimize-amd-lds-usage=lds-limit=32768 | FileCheck %s --check-prefix=CHECK-32KLIMIT

// Check that optimization detects overflow of LDS and decomposes layout convert so kernel fits into LDS
// CHECK-LABEL: alloc_convert_load
// CHECK-32KLIMIT-LABEL: alloc_convert_load
// CHECK: %0 = triton_gpu.local_alloc %arg0 : {{.*}}#blocked{{.*}}#shared
// CHECK: %1 = triton_gpu.convert_layout %arg1 : {{.*}}#blocked{{.*}}#blocked1
// CHECK: %2 = triton_gpu.convert_layout %1 : {{.*}}#blocked1{{.*}}#mma
// CHECK: %3 = triton_gpu.local_load %0 : {{.*}}#shared{{.*}}#triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>>
#blocked = #triton_gpu.blocked<{sizePerThread = [8, 1], threadsPerWarp = [16, 4], warpsPerCTA = [1, 8], order = [0, 1]}>
#mma = #triton_gpu.amd_mfma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [1, 8], instrShape = [32, 32], isTransposed = false}>
#shared = #triton_gpu.shared<{vec = 4, perPhase = 1, maxPhase = 16, order = [0, 1], hasLeadingOffset = false}>
module attributes {"triton_gpu.num-warps" = 8 : i32, "triton_gpu.threads-per-warp" = 64 : i32} {
tt.func public @alloc_convert_load(%arg0: tensor<128x128xf16, #blocked>, %arg1: tensor<128x128xf32, #blocked>) attributes {noinline = false} {
%1 = triton_gpu.local_alloc %arg0 : (tensor<128x128xf16, #blocked>) -> !tt.memdesc<128x128xf16, #shared, #triton_gpu.shared_memory>
%2 = triton_gpu.convert_layout %arg1 : tensor<128x128xf32, #blocked> -> tensor<128x128xf32, #mma>
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry I forgot to mention that I think this cvtOp is decomposed just because it uses more than 64 KB of LDS since padding is used. Therefore, this test does not test the functionality that a cvtOp could still be decomposed even it uses less than 64 KB LDS.

Copy link
Contributor Author

@binarman binarman Apr 26, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added new test: it uses fp16 instead of fp32, so cvt scratch buffer is x2 smaller

%3 = triton_gpu.local_load %1 : !tt.memdesc<128x128xf16, #shared, #triton_gpu.shared_memory> -> tensor<128x128xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>>
tt.return
}
}

// -----

// Check that optimization detects overflow of LDS and decomposes layout convert so kernel fits into LDS
// in case of relatively small scratch buffer
// CHECK-LABEL: alloc_convert_small_load
// CHECK-32KLIMIT-LABEL: alloc_convert_small_load
// CHECK: %0 = triton_gpu.local_alloc %arg0 : {{.*}}#blocked{{.*}}#shared
// CHECK: %1 = triton_gpu.convert_layout %arg1 : {{.*}}#blocked{{.*}}#blocked1
// CHECK: %2 = triton_gpu.convert_layout %1 : {{.*}}#blocked1{{.*}}#mma
// CHECK: %3 = triton_gpu.local_load %0 : {{.*}}#shared{{.*}}#triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>>
#blocked = #triton_gpu.blocked<{sizePerThread = [8, 1], threadsPerWarp = [16, 4], warpsPerCTA = [1, 8], order = [0, 1]}>
#mma = #triton_gpu.amd_mfma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [1, 8], instrShape = [32, 32], isTransposed = false}>
#shared = #triton_gpu.shared<{vec = 4, perPhase = 1, maxPhase = 16, order = [0, 1], hasLeadingOffset = false}>
module attributes {"triton_gpu.num-warps" = 8 : i32, "triton_gpu.threads-per-warp" = 64 : i32} {
tt.func public @alloc_convert_small_load(%arg0: tensor<128x128xf16, #blocked>, %arg1: tensor<128x128xf16, #blocked>) attributes {noinline = false} {
%1 = triton_gpu.local_alloc %arg0 : (tensor<128x128xf16, #blocked>) -> !tt.memdesc<128x128xf16, #shared, #triton_gpu.shared_memory>
%2 = triton_gpu.convert_layout %arg1 : tensor<128x128xf16, #blocked> -> tensor<128x128xf16, #mma>
%3 = triton_gpu.local_load %1 : !tt.memdesc<128x128xf16, #shared, #triton_gpu.shared_memory> -> tensor<128x128xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>>
tt.return
}
}

// -----

// Check that optimization works with 3d tensors
// in case of relatively small scratch buffer
// CHECK-LABEL: alloc_convert_3d_load
// CHECK-32KLIMIT-LABEL: alloc_convert_3d_load
// CHECK: %0 = triton_gpu.local_alloc %arg0 : {{.*}}#blocked{{.*}}#shared
// CHECK: %1 = triton_gpu.convert_layout %arg1 : {{.*}}#blocked{{.*}}#mma
// CHECK: %2 = triton_gpu.convert_layout %1 : {{.*}}#mma{{.*}}#mma1
// CHECK: %3 = triton_gpu.local_load %0 : {{.*}}#shared{{.*}}#triton_gpu.dot_op<{opIdx = 0, parent = #mma1, kWidth = 4}>>
#blocked = #triton_gpu.blocked<{sizePerThread = [1, 8, 1], threadsPerWarp = [1, 16, 4], warpsPerCTA = [1, 1, 8], order = [0, 1, 2]}>
#mma = #triton_gpu.amd_mfma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [1, 1, 8], instrShape = [32, 32], isTransposed = false}>
#shared = #triton_gpu.shared<{vec = 4, perPhase = 1, maxPhase = 16, order = [0, 1, 2], hasLeadingOffset = false}>
module attributes {"triton_gpu.num-warps" = 8 : i32, "triton_gpu.threads-per-warp" = 64 : i32} {
tt.func public @alloc_convert_3d_load(%arg0: tensor<1x128x128xf16, #blocked>, %arg1: tensor<1x128x128xf16, #blocked>) attributes {noinline = false} {
%1 = triton_gpu.local_alloc %arg0 : (tensor<1x128x128xf16, #blocked>) -> !tt.memdesc<1x128x128xf16, #shared, #triton_gpu.shared_memory>
%2 = triton_gpu.convert_layout %arg1 : tensor<1x128x128xf16, #blocked> -> tensor<1x128x128xf16, #mma>
%3 = triton_gpu.local_load %1 : !tt.memdesc<1x128x128xf16, #shared, #triton_gpu.shared_memory> -> tensor<1x128x128xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>>
tt.return
}
}

// -----

// Check that optimization triggers with custom LDS limit and do not triggers with default one
// CHECK-LABEL: alloc_convert_32k_limit
// CHECK: %0 = triton_gpu.local_alloc %arg0 : {{.*}}#blocked{{.*}}#shared
// CHECK: %1 = triton_gpu.convert_layout %arg1 : {{.*}}#blocked{{.*}}#mma
// CHECK: %2 = triton_gpu.local_load %0 : {{.*}}#shared{{.*}}#triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>>
// CHECK-32KLIMIT-LABEL: alloc_convert_32k_limit
// CHECK-32KLIMIT: %0 = triton_gpu.local_alloc %arg0 : {{.*}}#blocked{{.*}}#shared
// CHECK-32KLIMIT: %1 = triton_gpu.convert_layout %arg1 : {{.*}}#blocked{{.*}}#blocked1
// CHECK-32KLIMIT: %2 = triton_gpu.convert_layout %1 : {{.*}}#blocked1{{.*}}#mma
// CHECK-32KLIMIT: %3 = triton_gpu.local_load %0 : {{.*}}#shared{{.*}}#triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>>
#blocked = #triton_gpu.blocked<{sizePerThread = [4, 1], threadsPerWarp = [16, 4], warpsPerCTA = [1, 8], order = [0, 1]}>
#mma = #triton_gpu.amd_mfma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [1, 8], instrShape = [32, 32], isTransposed = false}>
#shared = #triton_gpu.shared<{vec = 4, perPhase = 1, maxPhase = 16, order = [0, 1], hasLeadingOffset = false}>
module attributes {"triton_gpu.num-warps" = 8 : i32, "triton_gpu.threads-per-warp" = 64 : i32} {
tt.func public @alloc_convert_32k_limit(%arg0: tensor<64x128xf16, #blocked>, %arg1: tensor<64x128xf16, #blocked>) attributes {noinline = false} {
%1 = triton_gpu.local_alloc %arg0 : (tensor<64x128xf16, #blocked>) -> !tt.memdesc<64x128xf16, #shared, #triton_gpu.shared_memory>
%2 = triton_gpu.convert_layout %arg1 : tensor<64x128xf16, #blocked> -> tensor<64x128xf16, #mma>
%3 = triton_gpu.local_load %1 : !tt.memdesc<64x128xf16, #shared, #triton_gpu.shared_memory> -> tensor<64x128xf16, #triton_gpu.dot_op<{opIdx = 0, kWidth = 4, parent = #mma}>>
tt.return
}
}
3 changes: 3 additions & 0 deletions third_party/amd/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,6 @@ add_subdirectory(lib)
if(TRITON_BUILD_PYTHON_MODULE)
add_triton_plugin(TritonAMD ${CMAKE_CURRENT_SOURCE_DIR}/python/triton_amd.cc LINK_LIBS TritonAMDGPUToLLVM TritonAMDGPUTransforms)
endif()
if(TRITON_BUILD_UT)
add_subdirectory(unittest)
endif()
7 changes: 7 additions & 0 deletions third_party/amd/backend/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,13 @@ def make_llir(src, metadata, options):
pm = ir.pass_manager(mod.context)
pm.enable_debug()
amd.passes.ttgpuir.add_decompose_unsupported_conversions(pm, options.arch)
# custom_lds_size is an experimental parameter that defines amount of LDS available
# for one thread block. Measured in bytes.
#
# If custom_lds_size = 0, pass will consider all LDS is available for one threads block,
# LDS size is determined by provided arch name.
custom_lds_size = 0
amd.passes.ttgpuir.add_optimize_lds_usage(pm, options.arch, custom_lds_size)
passes.convert.add_scf_to_cf(pm)
passes.convert.add_index_to_llvmir(pm)

Expand Down
7 changes: 7 additions & 0 deletions third_party/amd/include/TritonAMDGPUToLLVM/Passes.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,13 @@ namespace AMD {
std::unique_ptr<OperationPass<ModuleOp>>
createDecomposeUnsupportedConversionsPass(StringRef targetArch);

/// @brief Creates pass that keep LDS consumption within specified limits.
/// @param arch target architecture name, for example "gfx940"
/// @param customLDSLimit defines LDS size available for one thread block
/// zero value tells pass that whole LDS is available on a device
/// @return created pass
std::unique_ptr<OperationPass<ModuleOp>>
createOptimizeLDSUsagePass(StringRef arch, int32_t customLDSLimit = 0);
} // namespace AMD

std::unique_ptr<OperationPass<ModuleOp>>
Expand Down
12 changes: 12 additions & 0 deletions third_party/amd/include/TritonAMDGPUToLLVM/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,18 @@ def DecomposeUnsupportedAMDConversions : Pass<"decompose-unsupported-amd-convers
];
}

def OptimizeAMDLDSUsage : Pass<"optimize-amd-lds-usage", "mlir::ModuleOp"> {
let summary = "Minimize LDS usage";
let constructor = "mlir::triton::AMD::createOptimizeLDSUsagePass(\"\")";

let options = [
Option<"targetArch", "target-arch", "std::string", /*default*/"",
"gfx target device architecture, e.g., gfx942">,
Option<"customLDSLimit", "lds-limit", "int", /*default*/"0",
"custom limit of LDS consumption, if not provided, maximum LDS size is used">,
];
}

def ConvertTritonAMDGPUToLLVM : Pass<"convert-triton-amdgpu-to-llvm", "mlir::ModuleOp"> {
let summary = "Convert TritonGPU to LLVM";
let constructor = "mlir::triton::createConvertTritonAMDGPUToLLVMPass(\"\", /*ftz=*/true)";
Expand Down
2 changes: 2 additions & 0 deletions third_party/amd/lib/TritonAMDGPUToLLVM/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@ add_triton_library(TritonAMDGPUToLLVM
TargetInfo.cpp
TargetUtils.cpp
DecomposeUnsupportedConversions.cpp
OptimizeLDSUsage.cpp
OptimizeLDSUtility.cpp
SPMDOpToLLVM.cpp

DEPENDS
Expand Down
Loading
Loading