[AMD] Enable masked load and pointer canonicalization pass (#4638)
This PR does two things:
- We are using the new `llvm.masked{load,store}` intrinsics. This means
the backend takes responsibility for lowering the loads/stores.
- We are enabling the pointer canonicalization pass on the Triton IR. I
ran extensive testing and corrected a couple of minor issues still
present in the implementation.

The reason I am enabling both at the same time is that I saw a minor
regression with `llvm.masked{load,store}` which seems to go away when
the pointer canonicalization is used. This combination also seems to
reduce the number of VGPRs used (at least for GEMM kernels).
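
To illustrate the first point, here is a rough sketch of the intended lowering (hand-written and hedged: the SSA names `%ptr`, `%mask`, `%other`, the alignment, and the vector width are illustrative, not copied from real compiler output). Instead of a plain `llvm.load`/`llvm.store` with the predicate handled by surrounding branchy code, a predicated vectorized access becomes a single masked intrinsic that the backend can lower however it prefers, matching the updated FileCheck patterns in `test/Conversion/amd/load_store.mlir`:

```mlir
// Hypothetical before/after for one vector<4xf32> access; names are illustrative.
// Before: plain load, predicate handled by surrounding control flow.
%before = llvm.load %ptr : !llvm.ptr -> vector<4xf32>
// After: one masked intrinsic; masked-off lanes take %other as their value.
%after = llvm.intr.masked.load %ptr, %mask, %other {alignment = 16 : i32}
    : (!llvm.ptr, vector<4xi1>, vector<4xf32>) -> vector<4xf32>
// Stores follow the same pattern.
llvm.intr.masked.store %after, %ptr, %mask {alignment = 16 : i32}
    : vector<4xf32>, vector<4xi1> into !llvm.ptr
```

For the second point, conceptually (a simplified sketch of the intent rather than the exact rewrite the pass performs; layout encodings omitted and `%uniform`/`%nonuniform` are hypothetical), the canonicalization splits the offset of a tensor of pointers into a uniform scalar part, which stays attached to the scalar base pointer, and a non-uniform tensor part:

```mlir
// Hypothetical Triton IR sketch, simplified.
// Before: all pointer arithmetic happens on the tensor of pointers.
%base = tt.splat %arg0 : !tt.ptr<f32> -> tensor<256x!tt.ptr<f32>>
%ptrs = tt.addptr %base, %offsets : tensor<256x!tt.ptr<f32>>, tensor<256xi32>
// After: the uniform offset is applied to the scalar pointer first, and only
// the non-uniform remainder is kept in tensor form.
%scalar = tt.addptr %arg0, %uniform : !tt.ptr<f32>, i32
%splat = tt.splat %scalar : !tt.ptr<f32> -> tensor<256x!tt.ptr<f32>>
%ptrs2 = tt.addptr %splat, %nonuniform : tensor<256x!tt.ptr<f32>>, tensor<256xi32>
```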
giuseros authored Sep 12, 2024
1 parent 368c864 commit c238af8
Showing 7 changed files with 121 additions and 41 deletions.
2 changes: 0 additions & 2 deletions python/test/unit/language/test_line_info.py
@@ -172,12 +172,10 @@ def test_line_info(func: str):
assert (check_file_lines(file_lines, "test_line_info.py", 16))
elif func == "call":
assert (check_file_lines(file_lines, "test_line_info.py", 28))
assert (check_file_lines(file_lines, "test_line_info.py", 21))
assert (check_file_lines(file_lines, "test_line_info.py", 30))
elif func == "call_noinline":
assert (check_file_lines(file_lines, "test_line_info.py", 42))
assert (check_file_lines(file_lines, "test_line_info.py", 35))
assert (check_file_lines(file_lines, "test_line_info.py", 36))
assert (check_file_lines(file_lines, "test_line_info.py", 37))
elif func == "autotune":
assert (check_file_lines(file_lines, "test_line_info.py", 53))
4 changes: 2 additions & 2 deletions test/Conversion/amd/load_store.mlir
@@ -15,10 +15,10 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 :
%7 = tt.splat %arg1 : !tt.ptr<f32> -> tensor<256x!tt.ptr<f32>, #blocked0>
%8 = tt.addptr %7, %4 : tensor<256x!tt.ptr<f32>, #blocked0>, tensor<256xi32, #blocked0>
// Load 8 elements from A with two vectorized load instruction
// CHECK-COUNT-2: llvm.load {{.*}} : !llvm.ptr -> vector<4xf32>
// CHECK-COUNT-2: llvm.intr.masked.load {{.*}} : (!llvm.ptr, vector<4xi1>, vector<4xf32>) -> vector<4xf32>
%9 = tt.load %6 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<256x!tt.ptr<f32>, #blocked0>
// Load 8 elements from B with two vectorized load instruction
// CHECK-COUNT-2: llvm.load {{.*}} : !llvm.ptr -> vector<4xf32>
// CHECK-COUNT-2: llvm.intr.masked.load {{.*}} : (!llvm.ptr, vector<4xi1>, vector<4xf32>) -> vector<4xf32>
%10 = tt.load %8 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<256x!tt.ptr<f32>, #blocked0>
%11 = arith.addf %9, %10 : tensor<256xf32, #blocked0>
%12 = tt.splat %arg2 : !tt.ptr<f32> -> tensor<256x!tt.ptr<f32>, #blocked0>
10 changes: 5 additions & 5 deletions test/Conversion/amd/tritongpu_to_llvm.mlir
@@ -9,7 +9,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 :
// CHECK: llvm.br
// CHECK: rocdl.barrier
// CHECK: llvm.load
// CHECK: llvm.store
// CHECK: llvm.intr.masked.store
%0 = tt.atomic_rmw fadd, relaxed, gpu, %arg0, %arg2, %arg1 : (!tt.ptr<f32>, f32, i1) -> f32
tt.store %arg0, %0 : !tt.ptr<f32>
tt.return
@@ -25,10 +25,10 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 :
// CHECK: llvm.cond_br
// CHECK: llvm.atomicrmw
// CHECK: llvm.atomicrmw
// CHECK: %[[ADDR1:.*]] = llvm.extractvalue
// CHECK: %[[ADDR2:.*]] = llvm.extractvalue
// CHECK: llvm.store %{{.*}}, %[[ADDR1]]
// CHECK: llvm.store %{{.*}}, %[[ADDR2]]
// CHECK: %[[ADDR1:.*]] = llvm.addrspacecast
// CHECK: llvm.intr.masked.store %{{.*}}, %[[ADDR1]]
// CHECK: %[[ADDR2:.*]] = llvm.addrspacecast
// CHECK: llvm.intr.masked.store %{{.*}}, %[[ADDR2]]
%0 = tt.atomic_rmw fadd, relaxed, gpu, %arg0, %arg2, %arg1 : (tensor<256x!tt.ptr<f32>, #blocked0>, tensor<256xf32, #blocked0>, tensor<256xi1, #blocked0>) -> tensor<256xf32, #blocked0>
tt.store %arg0, %0 : tensor<256x!tt.ptr<f32>, #blocked0>
tt.return
2 changes: 2 additions & 0 deletions third_party/amd/backend/compiler.py
@@ -178,6 +178,8 @@ def make_ttgir(mod, metadata, options):
passes.ttgpuir.add_reduce_data_duplication(pm)
if use_new_pipeliner or options.num_stages != 0:
amd.passes.ttgpuir.add_reorder_instructions(pm)
amd.passes.ttgpuir.add_canonicalize_pointers(pm)
passes.common.add_canonicalizer(pm)
passes.common.add_cse(pm)
passes.common.add_symbol_dce(pm)
pm.run(mod)
52 changes: 24 additions & 28 deletions third_party/amd/lib/TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp
@@ -1,6 +1,11 @@
#include "PatternTritonGPUOpToLLVM.h"
#include "TargetInfo.h"
#include "Utility.h"
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/DialectConversion.h"
#include "triton/Conversion/TritonGPUToLLVM/Utility.h"
#include "triton/Dialect/TritonGPU/Transforms/Utility.h"

using namespace mlir;
@@ -86,6 +91,7 @@ Value redundantDataMask(Type valueTy, ConversionPatternRewriter &rewriter,
}
return mask;
}

// Contains some helper functions for both Load and Store conversions.
struct LoadStoreConversionBase {
explicit LoadStoreConversionBase(const AMD::TargetInfo &targetInfo,
@@ -192,7 +198,6 @@ struct LoadOpConversion : public ConvertOpToLLVMPattern<triton::LoadOp>,
auto cacheMod = op.getCache();
SmallVector<Value> loadedVals;
for (size_t vecStart = 0; vecStart < numElems; vecStart += vec) {
// TODO: optimization when ptr is GEP with constant offset
size_t in_off = 0;

const size_t maxWordWidth = std::max<size_t>(32, valueElemNBits);
@@ -218,8 +223,8 @@ struct LoadOpConversion : public ConvertOpToLLVMPattern<triton::LoadOp>,
Value v = undef(vecTy);
for (size_t s = 0; s < vec; ++s) {
Value otherElem = otherElems[vecStart + s];
Value indexVal = createIndexAttrConstant(
rewriter, loc, this->getTypeConverter()->getIndexType(), s);
Value indexVal = LLVM::createIndexConstant(
rewriter, loc, this->getTypeConverter(), s);
v = insert_element(vecTy, v, otherElem, indexVal);
}
falseVal = v;
@@ -259,6 +264,7 @@ struct StoreOpConversion : public ConvertOpToLLVMPattern<triton::StoreOp>,
ConversionPatternRewriter &rewriter) const override {
Value ptr = op.getPtr();
Value value = op.getValue();
Value mask = op.getMask();

Value llPtr = adaptor.getPtr();
Value llMask = adaptor.getMask();
@@ -281,24 +287,24 @@ struct StoreOpConversion : public ConvertOpToLLVMPattern<triton::StoreOp>,
// Determine the vectorization size
SmallVector<Value> maskElems;
if (llMask) {
Value mask = op.getMask();
maskElems = unpackLLElements(loc, llMask, rewriter);
assert(valueElems.size() == maskElems.size());

unsigned maskAlign = getMaskAlignment(mask);
vec = std::min(vec, maskAlign);
}

Value mask = redundantDataMask(valueTy, rewriter, loc, targetInfo);
const size_t dtsize =
std::max<int>(1, valueElemTy.getIntOrFloatBitWidth() / 8);
const size_t valueElemNBits = dtsize * 8;

auto cacheMod = op.getCache();
const int numVecs = elemsPerThread / vec;
Value rDataMask = redundantDataMask(valueTy, rewriter, loc, targetInfo);
for (size_t vecStart = 0; vecStart < elemsPerThread; vecStart += vec) {
// TODO: optimization when ptr is AddPtr with constant offset
size_t in_off = 0;
Value pred = mask ? and_(maskElems[vecStart], rDataMask) : rDataMask;
auto vecTy = LLVM::getFixedVectorType(valueElemTy, vec);

const size_t maxWordWidth = std::max<size_t>(32, valueElemNBits);
const size_t totalWidth = valueElemNBits * vec;
@@ -307,33 +313,23 @@ struct StoreOpConversion : public ConvertOpToLLVMPattern<triton::StoreOp>,
const size_t wordNElems = width / valueElemNBits;
assert(wordNElems * nWords * numVecs == elemsPerThread);

// TODO(Superjomn) Add cache policy fields to StoreOp.
// TODO(Superjomn) Deal with cache policy here.

Type valArgTy = IntegerType::get(ctx, width);
auto wordTy = vec_ty(valueElemTy, wordNElems);

SmallVector<std::pair<Value, std::string>> asmArgs;
for (size_t wordIdx = 0; wordIdx < nWords; ++wordIdx) {
// llWord is a width-len composition
Value llWord = undef(wordTy);
// Insert each value element to the composition
for (size_t elemIdx = 0; elemIdx < wordNElems; ++elemIdx) {
const size_t elemOffset = vecStart + wordIdx * wordNElems + elemIdx;
assert(elemOffset < valueElems.size());
Value elem = valueElems[elemOffset];
if (elem.getType().isInteger(1))
elem = sext(i8_ty, elem);
elem = bitcast(elem, valueElemTy);

llWord = insert_element(wordTy, llWord, elem, i32_val(elemIdx));
}
llWord = bitcast(llWord, valArgTy);
Value maskVal = llMask ? and_(mask, maskElems[vecStart]) : mask;
auto address = ptrElems[vecStart + wordIdx * wordNElems];
llStore(rewriter, loc, address, llWord, maskVal, cacheMod);
Value elem = valueElems[vecStart];
Value ptr = addrspacecast(ptr_ty(getContext()), ptrElems[vecStart]);

// Create the store val
Value storeVal = undef(vecTy);
for (size_t s = 0; s < vec; ++s) {
Value otherElem = valueElems[vecStart + s];
Value indexVal = createIndexAttrConstant(
rewriter, loc, this->getTypeConverter()->getIndexType(), s);
storeVal = insert_element(vecTy, storeVal, otherElem, indexVal);
}
}
llStore(rewriter, loc, ptr, storeVal, pred, cacheMod);
} // end vec
rewriter.eraseOp(op);
return success();
}
69 changes: 68 additions & 1 deletion third_party/amd/lib/TritonAMDGPUToLLVM/Utility.cpp
@@ -1,8 +1,12 @@
#include "Utility.h"
#include "PatternTritonGPUOpToLLVM.h"
#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
#include "mlir/Dialect/LLVMIR/NVVMDialect.h"
#include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
#include "mlir/IR/PatternMatch.h"
#include "triton/Conversion/TritonGPUToLLVM/TypeConverter.h"
#include "triton/Conversion/TritonGPUToLLVM/Utility.h"
#include "triton/Dialect/Triton/IR/Dialect.h"

using mlir::triton::gpu::appendOrGetExternFuncOp;
using mlir::triton::gpu::getFunctionType;
@@ -35,6 +39,35 @@ std::string mangleFunc(std::string name, Type type) {
}
return mangled;
}

// Utility function to create a constant vector mask of length `vecSize` with
// the same `pred` value
Value createVectorMaskFromPredicate(RewriterBase &rewriter, Location loc,
Value pred, int64_t vecSize) {
auto vecMaskTy = LLVM::getFixedVectorType(rewriter.getI1Type(), vecSize);
Value maskVal = undef(vecMaskTy);
for (size_t s = 0; s < vecSize; ++s) {
Value indexVal =
rewriter.create<LLVM::ConstantOp>(loc, rewriter.getI64IntegerAttr(s));
maskVal = insert_element(vecMaskTy, maskVal, pred, indexVal);
}
return maskVal;
}

// Utility function to get the number of elements of a vector or a scalar
int64_t getNumElements(Type ty) {
if (auto vecType = dyn_cast<VectorType>(ty))
return vecType.getNumElements();
return 1;
}

// Utility function to cast the given scalar or vector type to a vector type
Type castToVectorType(Type ty) {
if (isa<VectorType>(ty))
return ty;
return LLVM::getFixedVectorType(ty, 1);
}

} // namespace

namespace mlir::LLVM::AMD {
@@ -157,6 +190,25 @@ Value llGetPid(Location loc, RewriterBase &rewriter, ModuleOp moduleOp,

Value llLoad(RewriterBase &rewriter, Location loc, Value ptr, Type elemTy,
Value pred, Value falseVal, triton::CacheModifier cm) {

// Try to emit llvm.intr.masked.load if we can. In theory the backend should
// be happier because we emit less branchy code to optimize. The backend will
// lower it down however it wants at some point.
if (cm == triton::CacheModifier::CG || cm == triton::CacheModifier::NONE) {
// `llvm.intr.masked.load` only accepts vectors. If we see a scalar we need
// to bitcast to `vector<1xelemTy>` (and back)
int64_t vecSize = getNumElements(elemTy);
Type vecType = castToVectorType(elemTy);
falseVal = bitcast(falseVal, vecType);
Value maskVal = createVectorMaskFromPredicate(rewriter, loc, pred, vecSize);
bool nt = (cm == triton::CacheModifier::CG);
Value vecData = rewriter.create<LLVM::MaskedLoadOp>(
loc, vecType, ptr, maskVal, falseVal, vecSize, nt);
// If it is not a vector, remember to bitcast back to a scalar
vecData = bitcast(vecData, elemTy);
return vecData;
}

Type funcType = getFunctionType(elemTy, ValueRange({ptr, pred, falseVal}));
auto parent = ptr.getParentRegion()->getParentOfType<LLVM::LLVMFuncOp>();
auto getLoadNameRaw = [](triton::CacheModifier cm) {
@@ -173,7 +225,6 @@ Value llLoad(RewriterBase &rewriter, Location loc, Value ptr, Type elemTy,
};

auto funcName = mangleFunc(getLoadNameRaw(cm), funcType);

LLVM::LLVMFuncOp funcOp =
appendOrGetExternFuncOp(rewriter, parent, funcName, funcType);
auto loadVal =
@@ -185,6 +236,22 @@ Value llLoad(RewriterBase &rewriter, Location loc, Value ptr, Type elemTy,

void llStore(RewriterBase &rewriter, Location loc, Value ptr, Value val,
Value pred, triton::CacheModifier cm) {
// Try to emit llvm.intr.masked.store if we can. In theory the backend should
// be happier because we emit less branchy code to optimize. The backend will
// lower it down however it wants at some point.
if (cm == triton::CacheModifier::NONE) {
// `llvm.intr.masked.store` only accepts vectors. If we see a scalar we need
// to bitcast to `vector<1xelemTy>`
Type elemTy = val.getType();
int64_t vecSize = getNumElements(elemTy);
Type vecType = castToVectorType(elemTy);
val = bitcast(val, vecType);
Value maskVal = createVectorMaskFromPredicate(rewriter, loc, pred, vecSize);
auto op =
rewriter.create<LLVM::MaskedStoreOp>(loc, val, ptr, maskVal, vecSize);
return;
}

auto ctx = ptr.getContext();
Type funcType = getFunctionType(void_ty(ctx), ValueRange({ptr, val, pred}));
auto parent = ptr.getParentRegion()->getParentOfType<LLVM::LLVMFuncOp>();
@@ -16,6 +16,7 @@
#include "triton/Dialect/TritonGPU/Transforms/Utility.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/Debug.h"
#include <utility>

@@ -225,17 +226,25 @@ Value getScalarConstant(IRRewriter &rewriter, Location loc, Value expr) {
Operation *op = expr.getDefiningOp();

// Check for splatness
if (auto splatOp = dyn_cast<triton::SplatOp>(op))
if (auto splatOp = dyn_cast_or_null<triton::SplatOp>(op))
return splatOp.getSrc();

// Check for constant
DenseIntElementsAttr constVal;
if (auto constOp = dyn_cast<arith::ConstantOp>(op)) {
if (auto constOp = dyn_cast_or_null<arith::ConstantOp>(op)) {
Value val = constOp.getResult();
if (matchPattern(val, m_Constant(&constVal)) && constVal.isSplat())
return rewriter.create<arith::ConstantOp>(
loc, constVal.getSplatValue<IntegerAttr>());
}

// Check for block arguments
if (auto blockArg = dyn_cast_or_null<BlockArgument>(expr)) {
Type type = blockArg.getType();
if (!isa<RankedTensorType>(type))
return blockArg;
}

return Value();
}

@@ -318,6 +327,14 @@ PointerCanonicalizer::decomposeOffsetFromExpr(Location loc, Value expr,
return {scalarConst, tensorZero};
}

// Base case 2: block argument. Since it is not a scalar constant, it must be
// a tensor. Note that this means we won't be able to decompose across loop
// boundaries (TODO: giuseros).
if (auto blockArg = dyn_cast<BlockArgument>(expr)) {
Value scalarZero = rewriter.create<arith::ConstantIntOp>(loc, 0, bitness);
return std::make_pair(scalarZero, expr);
}

auto offsets =
llvm::TypeSwitch<Operation *, std::pair<Value, Value>>(
expr.getDefiningOp())
@@ -342,7 +359,7 @@
return decomposeOffsetFromMul(loc, expr, bitness);
})
.Default([&](Operation *op) {
// Base case 2: it is not a supported operation. We assume no
// Base case 3: it is not a supported operation. We assume no
// uniform part
Value scalarZero =
rewriter.create<arith::ConstantIntOp>(loc, 0, bitness);
