Commit

update
Guo-Peilin committed Oct 13, 2023
1 parent 3aaaf99 commit 407981e
Showing 4 changed files with 73 additions and 30 deletions.
@@ -39,9 +39,9 @@ transform.sequence failures(propagate) {
// 1. use register to cache the result of ldmatrix
// 2. use register to cache the result of mma's accumulation result
// 3. store the final result from reg to smem and to gmem
-// 4. use padding for output smem matrix to avoid bank conflict
+// 4. use padding for output smem matrix to avoid bank conflict`
%mma = transform.structured.match ops{["nvgpu.mma.sync"]} in %5 : (!transform.any_op) -> !transform.any_op
-transform.disc.move_data_to_register %mma by block_mn_shape = [128, 128] smem_padding = 8 bytes = 2: (!transform.any_op) -> ()
+transform.disc.move_data_to_register %mma by block_mn_shape = [128, 128] smem_padding = 8 : (!transform.any_op) -> ()
transform.disc.apply_licm %5 : !transform.any_op
transform.disc.apply_dce %5 : !transform.any_op
transform.disc.apply_cse %5 : !transform.any_op
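
The padding referred to in comment 4 works by widening each row of the output tile in shared memory so that consecutive rows start in different banks. A minimal sketch of the kind of allocation this implies for block_mn_shape = [128, 128] and smem_padding = 8 (the f16 element type and the workgroup address-space annotation are illustrative assumptions, not taken from the script):

// Hypothetical padded output tile: 8 extra f16 elements per 128-element row,
// so row n+1 starts 272 bytes after row n rather than 256, and a warp writing
// down a column no longer hits the same shared-memory banks on every access.
%c_smem = memref.alloc() : memref<128x136xf16, #gpu.address_space<workgroup>>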
@@ -4506,7 +4506,7 @@ transform_dialect::DISCExpandTransferRWToMemrefCopy::applyToOne(
/*dstIndices*/ ValueRange{offsetX, offsetY}, src,
/*srcIndices*/ ValueRange{offsetX, offsetY},
/*dstElements*/ b.getIndexAttr(kThreadCopyBytes / 2),
-/*srcElements*/ nullptr, // TODO: add support for dynamic shape
+/*srcElements*/ srcElements,
/*bypassL1*/ b.getUnitAttr());
tokens.push_back(token);
}
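
Passing a concrete srcElements value to nvgpu.device_async_copy (rather than nullptr) selects the predicated, zero-filling form of cp.async: when srcElements is smaller than dstElements, only srcElements elements are read from global memory and the remainder of the destination line is filled with zeros, which is how out-of-bounds reads on dynamically shaped inputs can be guarded. Roughly, the emitted op looks like the sketch below (operand names and memref shapes are illustrative, and the exact assembly syntax may differ across MLIR versions):

%token = nvgpu.device_async_copy %src[%i, %j], %dst[%k, %l], 8, %src_elements {bypassL1}
    : memref<?x?xf16> to memref<128x128xf16, #gpu.address_space<workgroup>>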
@@ -4973,10 +4973,6 @@ transform_dialect::DISCMoveDataToRegister::applyToOne(
ldMatrixB.getRes().getType(), FloatAttr::get(elementType, 0.0)));
Value matrixCReg = b.create<arith::ConstantOp>(DenseElementsAttr::get(
mma.getRes().getType(), FloatAttr::get(elementType, 0.0)));
-Value matrixBReg1 = b.create<arith::ConstantOp>(DenseElementsAttr::get(
-ldMatrixB.getRes().getType(), FloatAttr::get(elementType, 0.0)));
-Value matrixCReg1 = b.create<arith::ConstantOp>(DenseElementsAttr::get(
-mma.getRes().getType(), FloatAttr::get(elementType, 0.0)));

SmallVector<Type> outputTypes(2, b.getIndexType());
SmallVector<Value> shape;
@@ -5091,6 +5087,15 @@ transform_dialect::DISCMoveDataToRegister::applyToOne(
auto mmaMLoopIV = newMmaMLoop.getInductionVar();
auto mmaMShapeValue = b.create<arith::ConstantIndexOp>(mmaMShape);
auto mIndexValue = b.create<arith::DivUIOp>(mmaMLoopIV, mmaMShapeValue);
+// load A from reg
+for (unsigned i = 0; i < 4; ++i) {
+matrixAReg = b.create<vector::InsertOp>(
+b.create<vector::LoadOp>(
+VectorType::get({vector_width}, elementType), aWarpRegAlloc,
+ValueRange{mIndexValue, kIndexValue,
+b.create<arith::ConstantIndexOp>(i), zeroIndex}),
+matrixAReg, (int64_t[]){i});
+}
auto newMmaNLoop = b.create<scf::ForOp>(
zeroIndex, b.create<arith::ConstantIndexOp>(warpNShape),
b.create<arith::ConstantIndexOp>(mmaNShape));
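
The trip counts of these register loops follow from the per-lane fragment layout of nvgpu.mma.sync. Assuming the common 16x8x16 f16 configuration, each lane holds four 2xf16 fragments of A, two of B, and two of the accumulator, which is why A is assembled with four vector.insert ops (i < 4) while B uses two (i < 2). In that configuration the mma op has roughly this form (value names are illustrative):

%d = nvgpu.mma.sync(%a, %b, %c) {mmaShape = [16, 8, 16]}
    : (vector<4x2xf16>, vector<2x2xf16>, vector<2x2xf16>) -> vector<2x2xf16>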
@@ -5107,15 +5112,6 @@ transform_dialect::DISCMoveDataToRegister::applyToOne(
b.create<arith::ConstantIndexOp>(i), zeroIndex}),
matrixCReg, (int64_t[]){i});
}
-// load A from reg
-for (unsigned i = 0; i < 4; ++i) {
-matrixAReg = b.create<vector::InsertOp>(
-b.create<vector::LoadOp>(
-VectorType::get({vector_width}, elementType), aWarpRegAlloc,
-ValueRange{mIndexValue, kIndexValue,
-b.create<arith::ConstantIndexOp>(i), zeroIndex}),
-matrixAReg, (int64_t[]){i});
-}
// load B from reg
for (unsigned i = 0; i < 2; ++i) {
matrixBReg = b.create<vector::InsertOp>(
@@ -5139,6 +5135,7 @@ transform_dialect::DISCMoveDataToRegister::applyToOne(
(void)mlir::loopUnrollByFactor(newMmaNLoop, warpNShape / mmaNShape);
(void)mlir::loopUnrollByFactor(newMmaMLoop, warpMShape / mmaMShape);
(void)mlir::loopUnrollByFactor(newMmaKLoop, warpKShape / mmaKShape);

// no longer need origin mmaMLoop
for (auto user : mmaMLoop->getUsers()) user->erase();
mmaMLoop.erase();
@@ -5196,6 +5193,7 @@ transform_dialect::DISCMoveDataToRegister::applyToOne(
Value cWarpSmemCol = b.create<arith::AddIOp>(
cWarpSmemColOffset, b.create<affine::AffineApplyOp>(
cWarpSmemColMap, ValueRange{mmaNLoopIV, laneId}));
+// TODO: store 4xf16 rather than 2x2xf16
b.create<vector::StoreOp>(
b.create<vector::LoadOp>(
VectorType::get({vector_width}, elementType), cWarpRegAlloc,
@@ -5232,10 +5230,7 @@ transform_dialect::DISCMoveDataToRegister::applyToOne(

// expand genericOp to loop
b.setInsertionPoint(genericOp);
-// TODO: theoretically, each thread should copy 16 bytes to achieve best
-// performance, however there is bug when lowering mlir code to ptx code
-// during converting load/store-vector if we set vector_width to 16
-int64_t cpBytesPerThread = getBytes();
+int64_t cpBytesPerThread = 16; // 128 bits
int64_t cpElementsPerThread = cpBytesPerThread / 2;
int64_t cpThreadsPerRow = blockNShape / cpElementsPerThread;
int64_t cpRowsPerBlock = 128 * cpElementsPerThread / blockNShape;
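
As a worked example of the copy geometry: with 16-byte stores of 2-byte f16 elements, each thread moves cpElementsPerThread = 16 / 2 = 8 elements, so for the blockNShape = 128 tile used in the transform script above cpThreadsPerRow = 128 / 8 = 16, and with the 128 threads assumed in the expression above cpRowsPerBlock = 128 * 8 / 128 = 8 rows are written per iteration; the smem-to-gmem loop further below is therefore unrolled by blockMShape / cpRowsPerBlock = 128 / 8 = 16.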
@@ -5256,12 +5251,17 @@ transform_dialect::DISCMoveDataToRegister::applyToOne(
auto offsetX = b.create<arith::AddIOp>(
iv, b.create<arith::DivUIOp>(
threadId, b.create<arith::ConstantIndexOp>(cpThreadsPerRow)));
-// TODO: enable mask store
-b.create<vector::StoreOp>(
-b.create<vector::LoadOp>(
-VectorType::get({cpBytesPerThread / 2}, elementType), cBlockSmemAlloc,
-ValueRange{offsetX, offsetY}),
-outputSubView, ValueRange{offsetX, offsetY});
+auto dimM = b.create<memref::DimOp>(outputSubView.getSource(),
+b.create<arith::ConstantIndexOp>(0));
+auto dimN = b.create<memref::DimOp>(outputSubView.getSource(),
+b.create<arith::ConstantIndexOp>(1));
+// TODO: enable mask vector store
+auto vec8xf16 = b.create<vector::LoadOp>(
+VectorType::get({cpElementsPerThread}, elementType), cBlockSmemAlloc,
+ValueRange{offsetX, offsetY});
+auto vecStore = b.create<vector::StoreOp>(vec8xf16, outputSubView,
+ValueRange{offsetX, offsetY});
+vecStore->setAttr("alignment", IntegerAttr::get(b.getI32Type(), 16));
b.setInsertionPointAfter(smemToGmemLoop);
b.create<gpu::BarrierOp>();
(void)mlir::loopUnrollByFactor(smemToGmemLoop, blockMShape / cpRowsPerBlock);
@@ -5276,7 +5276,7 @@ transform_dialect::DISCMoveDataToRegister::applyToOne(
"greedy pattern applicatin failed");
}

-// Delete any transfer_write op
+// Delete any remain transfer_write op
Operation* transferWrite = nullptr;
parallelOp->walk([&](memref::AllocOp alloc) {
if (llvm::hasSingleElement(alloc->getUsers())) {
@@ -1668,16 +1668,14 @@ def DISCMoveDataToRegister : Op<Transform_Dialect, "disc.move_data_to_register",

let arguments = (ins TransformHandleTypeInterface:$target,
DenseI64ArrayAttr:$block_mn_shape,
-I64Attr:$smem_padding,
-I64Attr:$bytes);
+I64Attr:$smem_padding);
let results = (outs);

let assemblyFormat = [{
$target
attr-dict
`by` `block_mn_shape` `=` $block_mn_shape
`smem_padding` `=` $smem_padding
-`bytes` `=` $bytes
`:` functional-type(operands, results)
}];
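
This assemblyFormat corresponds to the textual form used in the transform script earlier, for example:

transform.disc.move_data_to_register %mma by block_mn_shape = [128, 128] smem_padding = 8 : (!transform.any_op) -> ()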

@@ -43,6 +43,7 @@
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/MemRef/Transforms/Passes.h"
#include "mlir/Dialect/MemRef/Transforms/Transforms.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h"
#include "mlir/Transforms/DialectConversion.h"
@@ -62,6 +63,48 @@ namespace {
/// Import the GPU Ops to NVVM Patterns.
#include "GPUToNVVM.cpp.inc"

+/// Conversion vector.store with align attribute to llvm.store
+class VectorStoreWithAlignToLLVMPattern
+: public ConvertOpToLLVMPattern<vector::StoreOp> {
+using ConvertOpToLLVMPattern<vector::StoreOp>::ConvertOpToLLVMPattern;
+
+LogicalResult matchAndRewrite(
+vector::StoreOp storeOp, OpAdaptor adaptor,
+ConversionPatternRewriter& rewriter) const override {
+// Only 1-D vectors can be lowered to LLVM.
+VectorType vectorTy = storeOp.getVectorType();
+if (vectorTy.getRank() > 1) return failure();
+auto alignAttr = storeOp->getAttrOfType<IntegerAttr>("alignment");
+if (!alignAttr) return failure();
+storeOp.dump();
+unsigned align = alignAttr.getInt();
+
+auto loc = storeOp->getLoc();
+MemRefType memRefTy = storeOp.getMemRefType();
+
+// Resolve address.
+auto vtype = cast<VectorType>(
+this->typeConverter->convertType(storeOp.getVectorType()));
+Value dataPtr = this->getStridedElementPtr(loc, memRefTy, adaptor.getBase(),
+adaptor.getIndices(), rewriter);
+// Casts a strided element pointer to a vector pointer. The vector pointer
+// will be in the same address space as the incoming memref type.
+Value ptr;
+if ((*this->getTypeConverter()).useOpaquePointers()) {
+ptr = dataPtr;
+} else {
+unsigned addressSpace =
+*(*this->getTypeConverter()).getMemRefAddressSpace(memRefTy);
+auto pType = LLVM::LLVMPointerType::get(vtype, addressSpace);
+ptr = rewriter.create<LLVM::BitcastOp>(loc, pType, dataPtr);
+}
+
+rewriter.replaceOpWithNewOp<LLVM::StoreOp>(
+storeOp, adaptor.getValueToStore(), ptr, align);
+return success();
+}
+};
+
/// A pass that replaces all occurrences of GPU device operations with their
/// corresponding NVVM equivalent.
///
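
The pattern only fires for 1-D vector.store ops that carry the alignment attribute (the one attached to the smem-to-gmem store earlier); it resolves the strided element pointer the same way the stock lowering does and then emits an llvm.store with the requested alignment. A rough before/after sketch, with value names illustrative and the address computation from getStridedElementPtr elided:

// input
vector.store %v, %out[%i, %j] {alignment = 16 : i32} : memref<?x?xf16>, vector<8xf16>
// after conversion, assuming opaque pointers
llvm.store %v, %ptr {alignment = 16 : i64} : vector<8xf16>, !llvm.ptr

Registering the pattern with PatternBenefit 3, as done further down, lets it take precedence over the default vector-to-LLVM store lowering for such stores.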
@@ -125,6 +168,8 @@ struct DiscLowerGpuOpsToNVVMOpsPass
llvmPatterns.add<GenericAtomicRMWOpLoweringWithBitcast>(
converter, /* PatternBenefit */ 3);
llvmPatterns.add<RemoveUselessUnrealizedConversionCastOp>(converter);
+llvmPatterns.add<VectorStoreWithAlignToLLVMPattern>(converter,
+/* PatternBenefit */ 3);
arith::populateArithToLLVMConversionPatterns(converter, llvmPatterns);
cf::populateControlFlowToLLVMConversionPatterns(converter, llvmPatterns);
populateVectorToLLVMConversionPatterns(converter, llvmPatterns);