diff --git a/include/imex/Dialect/XeTile/IR/XeTileOps.td b/include/imex/Dialect/XeTile/IR/XeTileOps.td index 55399a7a9..cf0f6eed5 100644 --- a/include/imex/Dialect/XeTile/IR/XeTileOps.td +++ b/include/imex/Dialect/XeTile/IR/XeTileOps.td @@ -447,6 +447,9 @@ def XeTile_TileMMAOp : XeTile_Op<"tile_mma", []> { mlir::Type getElementType() { return getA().getType().getElementType(); } + mlir::VectorType getOutputType() { + return getOutput().getType(); + } }]; let hasVerifier = 1; @@ -581,7 +584,7 @@ def XeTile_TransposeOp: XeTile_Op<"transpose", []> { let hasVerifier = 1; } -def XeTile_ReduceOp: XeTile_Op<"reduce", []> { +def XeTile_ReductionOp: XeTile_Op<"reduction", []> { let summary = "performs a reduction operation over a 2D vector."; let description = [{ It has the same semantics as the `vector.multi_reduction`, @@ -591,10 +594,10 @@ def XeTile_ReduceOp: XeTile_Op<"reduce", []> { let arguments = (ins Vector_CombiningKindAttr: $kind, XeTile_2DOr4DVector: $source, - DenseI64ArrayAttr: $reduction_dim); + DenseI64ArrayAttr: $reduction_dims); let results = (outs XeTile_2DOr4DVector: $result); let assemblyFormat = [{ - $kind `,` $source $reduction_dim attr-dict `:` type($source) `->` type($result) + $kind `,` $source $reduction_dims attr-dict `:` type($source) `->` type($result) }]; let hasVerifier = 1; diff --git a/include/imex/Dialect/XeTile/Transforms/Passes.h b/include/imex/Dialect/XeTile/Transforms/Passes.h index 0f1d948a4..91b002c76 100644 --- a/include/imex/Dialect/XeTile/Transforms/Passes.h +++ b/include/imex/Dialect/XeTile/Transforms/Passes.h @@ -40,6 +40,8 @@ std::unique_ptr createXeTileInitDuplicatePass(); std::unique_ptr createXeTileBlockingPass(const std::string &device = "pvc"); +std::unique_ptr +createNewXeTileBlockingPass(const std::string &device = "pvc"); std::unique_ptr createXeTileBlockAligningPass(); std::unique_ptr createXeTileWgToSgPass(); std::unique_ptr createXeTileOptimizeTransposePass(); diff --git 
a/include/imex/Dialect/XeTile/Transforms/Passes.td b/include/imex/Dialect/XeTile/Transforms/Passes.td index d0737931c..242a90a5e 100644 --- a/include/imex/Dialect/XeTile/Transforms/Passes.td +++ b/include/imex/Dialect/XeTile/Transforms/Passes.td @@ -130,5 +130,31 @@ def XeTileCanonicalization : Pass<"xetile-canonicalization", "::mlir::gpu::GPUMo ]; } +def NewXeTileBlocking : Pass<"new-xetile-blocking", "::mlir::gpu::GPUModuleOp">{ + let summary = "transform XeTile large tiles(input) into arrays of smaller " + "blocks with appropriate size, such that the operator on each " + "of the blocks can be mapped into one hardware instruction."; + + let description = [{ + This transform pass preprocesses the xetile program by decomposing large XeTile tiles + into smaller ones that can be handled by a hardware instruction. It is going to replace + the xetile-blocking pass. + }]; + + let constructor = "imex::createNewXeTileBlockingPass()"; + let dependentDialects = ["imex::xetile::XeTileDialect", + "mlir::arith::ArithDialect", + "mlir::math::MathDialect", + "mlir::gpu::GPUDialect", + "mlir::memref::MemRefDialect", + "mlir::vector::VectorDialect"]; + + let options = [ + Option<"device", "device", "std::string", + /*default=*/"\"pvc\"", + "gpu platform architecture where these ops are running"> + ]; +} + #endif // _XeTile_PASSES_TD_INCLUDED_ diff --git a/include/imex/ExecutionEngine/ImexRunnerUtils.h b/include/imex/ExecutionEngine/ImexRunnerUtils.h index b9f03023a..464ad6418 100644 --- a/include/imex/ExecutionEngine/ImexRunnerUtils.h +++ b/include/imex/ExecutionEngine/ImexRunnerUtils.h @@ -72,6 +72,11 @@ _mlir_ciface_fillResource1DRandomF16(UnrankedMemRefType *ptr, const float lower, const float upper, const bool genInt); +extern "C" IMEX_RUNNERUTILS_EXPORT void +_mlir_ciface_fillResource1DRandomF32(UnrankedMemRefType *ptr, + const float lower, const float upper, + const bool genInt); + extern "C" IMEX_RUNNERUTILS_EXPORT void _mlir_ciface_printMemrefBF16(UnrankedMemRefType *m); 
extern "C" IMEX_RUNNERUTILS_EXPORT void diff --git a/lib/Conversion/XeTileToXeGPU/XeTileOpConversion.cpp b/lib/Conversion/XeTileToXeGPU/XeTileOpConversion.cpp index d471c8f1b..83d392e54 100644 --- a/lib/Conversion/XeTileToXeGPU/XeTileOpConversion.cpp +++ b/lib/Conversion/XeTileToXeGPU/XeTileOpConversion.cpp @@ -736,15 +736,16 @@ extern llvm::SmallVector lowerInnerReductionWithVectorReduction( mlir::vector::CombiningKind kind, mlir::Location loc, mlir::Type elemTy, XeOneToNPatternRewriter &rewriter); -struct SgTileReduceOpPattern : public XeOneToNConversion { - using XeOneToNConversion::XeOneToNConversion; +struct SgTileReductionOpPattern + : public XeOneToNConversion { + using XeOneToNConversion::XeOneToNConversion; mlir::LogicalResult - matchAndRewrite(xetile::ReduceOp op, OpAdaptor adaptor, + matchAndRewrite(xetile::ReductionOp op, OpAdaptor adaptor, XeOneToNPatternRewriter &rewriter) const override { auto srcTy = op.getSource().getType(); auto elemTy = srcTy.getElementType(); - auto dims = op.getReductionDim(); + auto dims = op.getReductionDims(); // its input should be a 4D vector, and has 2 reduction dims, // otherwise run the blocking pass first. 
if (dims.size() != 2 || srcTy.getRank() != 4) @@ -1092,8 +1093,8 @@ void populateXeTileOpConversionPatterns(imex::XeOneToNTypeConverter &converter, SgTileMMAOpPattern, SgUpdateTileOffsetOpPattern, SgTransposeOpPattern, SgTransposeOpPattern, SgBroadcastOpPattern, - SgTileReduceOpPattern, SgVectorCreateMaskOpPattern>(patterns.getContext(), - converter, analysis); + SgTileReductionOpPattern, SgVectorCreateMaskOpPattern>( + patterns.getContext(), converter, analysis); patterns.insert, ElementWiseOpPattern, ElementWiseOpPattern, diff --git a/lib/Dialect/XeTile/IR/XeTileOps.cpp b/lib/Dialect/XeTile/IR/XeTileOps.cpp index d52caf164..060b93a69 100644 --- a/lib/Dialect/XeTile/IR/XeTileOps.cpp +++ b/lib/Dialect/XeTile/IR/XeTileOps.cpp @@ -859,8 +859,8 @@ mlir::LogicalResult TransposeOp::verify() { return mlir::success(); } -mlir::LogicalResult ReduceOp::verify() { - auto dims = getReductionDim(); +mlir::LogicalResult ReductionOp::verify() { + auto dims = getReductionDims(); auto resShape = getResult().getType().getShape(); for (auto i : dims) if (resShape[i] != 1) diff --git a/lib/Dialect/XeTile/Transforms/Blocking.cpp b/lib/Dialect/XeTile/Transforms/Blocking.cpp index fdefa8e76..7bbec87d1 100644 --- a/lib/Dialect/XeTile/Transforms/Blocking.cpp +++ b/lib/Dialect/XeTile/Transforms/Blocking.cpp @@ -556,15 +556,16 @@ struct VectorMultiDimReductionOpPattern } }; -struct TileReduceOpPattern - : public XeTileConversion { +struct TileReductionOpPattern + : public XeTileConversion { - using XeTileConversion::XeTileConversion; + using XeTileConversion::XeTileConversion; - TileReduceOpPattern(mlir::MLIRContext *context, - imex::XeTypeConverter &converter, - TileUsageAnalysis &analysis, - std::shared_ptr ptruArch) + TileReductionOpPattern(mlir::MLIRContext *context, + imex::XeTypeConverter &converter, + TileUsageAnalysis &analysis, + std::shared_ptr ptruArch) : XeTileConversion(context, converter, analysis) { this->uArchInterface = ptruArch; } @@ -572,13 +573,13 @@ struct 
TileReduceOpPattern std::shared_ptr uArchInterface = nullptr; mlir::LogicalResult - matchAndRewrite(xetile::ReduceOp op, OpAdaptor adaptor, + matchAndRewrite(xetile::ReductionOp op, OpAdaptor adaptor, OpPatternRewriter &rewriter) const override { auto loc = op.getLoc(); auto srcTy = op.getSource().getType(); auto elemTy = srcTy.getElementType(); auto shape = srcTy.getShape(); - auto reductionDims = op.getReductionDim(); + auto reductionDims = op.getReductionDims(); if (srcTy.getRank() != 2 || reductionDims.size() != 1) return rewriter.notifyMatchFailure( @@ -611,7 +612,7 @@ struct TileReduceOpPattern auto newSource = addPackOp(adaptor.getSource(), {blkSizes[0], blkSizes[1]}, rewriter); - auto newDest = rewriter.create( + auto newDest = rewriter.create( loc, newDestType, op.getKind(), newSource, newReductionDims); auto unpack = addUnpackOp(newDest.getResult(), rewriter); rewriter.replaceOp(op, unpack); @@ -1161,7 +1162,7 @@ void populateXeTileBlockingPatterns( VectorizableOpPattern, SCFForOpPattern, SCFYieldOpPattern, InitTileOpPattern, LoadTileOpPattern, StoreTileOpPattern, TileMMAOpPattern, UpdateTileOffsetOpPattern, - VectorMultiDimReductionOpPattern, TileReduceOpPattern, + VectorMultiDimReductionOpPattern, TileReductionOpPattern, TileBroadcastOpPattern>(patterns.getContext(), converter, analysis, ptruArch); patterns.insert, diff --git a/lib/Dialect/XeTile/Transforms/BlockingAnalysis.cpp b/lib/Dialect/XeTile/Transforms/BlockingAnalysis.cpp new file mode 100644 index 000000000..ada144241 --- /dev/null +++ b/lib/Dialect/XeTile/Transforms/BlockingAnalysis.cpp @@ -0,0 +1,778 @@ +#include +#include +#include + +#include "BlockingAnalysis.h" + +namespace llvm { +using imex::Block; +// Implementation of llvm::DenseMapInfo for Block, required for +// using Block as a value in DenseMap. 
+template <> struct DenseMapInfo { + static inline Block getEmptyKey() { + return Block(-1, -1); // the empty key + } + + static inline Block getTombstoneKey() { + return Block(-2, -2); // the tombstone key + } + + static unsigned getHashValue(const Block &b) { + return hash_combine(b[0], b[1]); + } + + static bool isEqual(const Block &lhs, const Block &rhs) { return lhs == rhs; } +}; +} // namespace llvm + +namespace imex { + +// ===------------------ Block Implementation --------------------------===// + +int64_t &Block::operator[](size_t index) { + assert(index < 2 && "Index out of bounds"); + return values[index]; +} + +const int64_t &Block::operator[](size_t index) const { + assert(index < 2 && "Index out of bounds"); + return values[index]; +} + +bool Block::operator==(Block &other) const { + return values[0] == other.values[0] && values[1] == other.values[1]; +} + +bool Block::operator==(const Block &other) const { + return values[0] == other.values[0] && values[1] == other.values[1]; +} + +void Block::print(llvm::raw_ostream &os) const { + os << "[" << values[0] << ", " << values[1] << "]"; +} + +llvm::ArrayRef Block::asArrayRef() const { + return llvm::ArrayRef(values, 2); +} + +llvm::raw_ostream &operator<<(llvm::raw_ostream &os, Block blk) { + blk.print(os); + return os; +} + +// ===------------------ BlockingRequests Implementation -----------------===// +// A class holding all blocking requests for a given mlir::Value. +// For convenience, it also tracks the UsePoint of the value. 
+class BlockingRequests { +public: + BlockingRequests() = default; + BlockingRequests(int64_t h, int64_t w, mlir::Operation *user, int64_t pos) + : BlockingRequests(h, w, UsePoint(user, pos)) {} + + BlockingRequests(int64_t h, int64_t w, UsePoint point) + : BlockingRequests(Block(h, w), point) {} + + BlockingRequests(llvm::ArrayRef shape, UsePoint point) + : BlockingRequests(shape[0], shape[1], point) { + assert(shape.size() == 2 && "Invalid block size."); // NOTE(review): runs after shape[0]/shape[1] were already read by the delegating constructor above + } + + BlockingRequests(Block block, UsePoint point); + + bool operator==(const BlockingRequests &other) const; + bool operator!=(const BlockingRequests &other) const; + + Block getDefBlock() const; + Block getUseBlock(UsePoint point) const; + + void print(llvm::raw_ostream &os) const; + + static BlockingRequests meet(const BlockingRequests &lhs, + const BlockingRequests &rhs); + + static BlockingRequests join(const BlockingRequests &lhs, + const BlockingRequests &rhs); + + // indicate that one use of the result operand + // has decided on the inner block size. 
+ bool isInitialized() const { return requests.size() != 0; } + + int64_t getNumUniqRequests() const { return getRequests().size(); } + + llvm::SmallVector getRequests() const { + llvm::SmallDenseSet reqs; + for (auto [point, block] : requests) + reqs.insert(block); + return llvm::SmallVector(reqs.begin(), reqs.end()); + } + + void updateDefBlock(Block block) { def = block; } + +private: + Block def; + llvm::DenseMap requests; +}; + +BlockingRequests::BlockingRequests(Block block, UsePoint point) { + assert(block && "Invalid block."); + requests.try_emplace(point, block); +} + +Block BlockingRequests::getDefBlock() const { + if (def) + return def; + if (requests.size()) + return (requests.begin()->second); // no def block set: fall back to the first stored request + return Block(); +} + +Block BlockingRequests::getUseBlock(UsePoint point) const { + return requests.lookup(point); +} + +void BlockingRequests::print(llvm::raw_ostream &os) const { + if (!isInitialized()) { + os << "Uninitialized"; + } else { + os << "Requests (" << requests.size() << ", " + << "def: " << def << "): "; + for (auto [i, iter] : llvm::enumerate(requests)) { + os << "(" << *(iter.first).first << ", " << (iter.first).second + << "): \n\t" << iter.second; + if (i != requests.size() - 1) + os << ", "; + } + } +} + +bool BlockingRequests::operator==(const BlockingRequests &other) const { + return requests == other.requests; // only the requests are compared; def is not +} + +bool BlockingRequests::operator!=(const BlockingRequests &other) const { + return !(*this == other); +} + +BlockingRequests BlockingRequests::meet(const BlockingRequests &lhs, + const BlockingRequests &rhs) { + return join(lhs, rhs); // meet is implemented as join +} + +BlockingRequests BlockingRequests::join(const BlockingRequests &lhs, + const BlockingRequests &rhs) { + BlockingRequests newReq; // def stays default-constructed; only requests are merged below + if (lhs.isInitialized()) { + for (auto [point, block] : lhs.requests) { + 
newReq.requests.try_emplace(point, block); + } + } + if (rhs.isInitialized()) { + for (auto [point, block] : rhs.requests) { + newReq.requests.try_emplace(point, block); // try_emplace does not overwrite: the lhs entry wins for a shared use point + } + } + return 
newReq; +} + +llvm::raw_ostream &operator<<(llvm::raw_ostream &os, + BlockingRequests requests) { + requests.print(os); + return os; +} + +// ===---------------- BlockingLattice Implementation -----------------===// +// A lattice wrapper for BlockingRequests +struct BlockingLattice : public mlir::dataflow::Lattice { + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(BlockingLattice) + using Lattice::Lattice; + + mlir::ChangeResult join(const AbstractSparseLattice &rhs) override { + return join(static_cast(rhs).getValue()); + } + + mlir::ChangeResult join(const BlockingRequests &other) { + auto &val = getValue(); + BlockingRequests newValue = BlockingRequests::join(val, other); + if (newValue == val) + return mlir::ChangeResult::NoChange; + val = newValue; // NOTE(review): newValue comes from join(), whose def is default-constructed, so a def set via updateDefBlock is not carried over - confirm intended + return mlir::ChangeResult::Change; + } +}; + +// ===----------------------BlockingAnalysisImpl ---------------------===// +class BlockingAnalysisImpl + : public mlir::dataflow::SparseBackwardDataFlowAnalysis { +public: + BlockingAnalysisImpl(mlir::DataFlowSolver &solver, + mlir::SymbolTableCollection &symbolTable, + std::shared_ptr uArch) + : SparseBackwardDataFlowAnalysis(solver, symbolTable), uArch(uArch) {} + + void visitOperation(mlir::Operation *op, + mlir::ArrayRef operands, + mlir::ArrayRef results) override; + + void visitBranchOperand(mlir::OpOperand &operand) override {} + + void visitCallOperand(mlir::OpOperand &operand) override {} + + void setToExitState(BlockingLattice *lattice) override {} + +private: + void visitPrefetchTileOp(xetile::PrefetchTileOp op, + mlir::ArrayRef operands, + mlir::ArrayRef results); + + void visitLoadTileOp(xetile::LoadTileOp op, + mlir::ArrayRef operands, + mlir::ArrayRef results); + + void visitStoreTileOp(xetile::StoreTileOp op, + mlir::ArrayRef operands, + mlir::ArrayRef results); + + void visitUpdateTileOp(xetile::UpdateTileOffsetOp op, + 
mlir::ArrayRef operands, + mlir::ArrayRef results); + + void visitTileMMAOp(xetile::TileMMAOp op, + mlir::ArrayRef operands, + 
mlir::ArrayRef results); + + void visitVectorizableOp(mlir::Operation *op, + mlir::ArrayRef operands, + mlir::ArrayRef results); + + void visitShapecastOp(mlir::vector::ShapeCastOp op, + mlir::ArrayRef operands, + mlir::ArrayRef results); + + void visitReductionOp(xetile::ReductionOp op, + mlir::ArrayRef operands, + mlir::ArrayRef results); + + void visitBroadcastOp(xetile::BroadcastOp op, + mlir::ArrayRef operands, + mlir::ArrayRef results); + + void visitTransposeOp(xetile::TransposeOp op, + mlir::ArrayRef operands, + mlir::ArrayRef results); + + int getMaxSLMBlockSize(int elemBitWidth, int height); + + template + Block getInnerBlockSize(mlir::Operation *op, mlir::Type elemTy, + llvm::ArrayRef &shape, + int memorySpace = 0); + + llvm::SmallVector + getMMASize(mlir::Type elemTy, const int APrecision, const int BPrecision, + const int CPrecision, const int DPrecision); + +private: + std::shared_ptr uArch = nullptr; +}; + +void BlockingAnalysisImpl::visitOperation( + mlir::Operation *op, mlir::ArrayRef operands, + mlir::ArrayRef results) { + + if (auto updateTileOp = mlir::dyn_cast(op)) + visitUpdateTileOp(updateTileOp, operands, results); + + if (auto prefetchOp = mlir::dyn_cast(op)) + visitPrefetchTileOp(prefetchOp, operands, results); + + if (auto loadOp = mlir::dyn_cast(op)) + visitLoadTileOp(loadOp, operands, results); + + if (auto storeOp = mlir::dyn_cast(op)) + visitStoreTileOp(storeOp, operands, results); + + if (auto tileMMAOp = mlir::dyn_cast(op)) + visitTileMMAOp(tileMMAOp, operands, results); + + if (auto reductionOp = mlir::dyn_cast(op)) + visitReductionOp(reductionOp, operands, results); + + if (auto transposeOp = mlir::dyn_cast(op)) + visitTransposeOp(transposeOp, operands, results); + + if (auto broadcastOp = mlir::dyn_cast(op)) + visitBroadcastOp(broadcastOp, operands, results); + + if (op->hasTrait()) + visitVectorizableOp(op, operands, results); + + if (auto shapecastOp = mlir::dyn_cast(op)) + visitShapecastOp(shapecastOp, operands, results); +} + 
+void BlockingAnalysisImpl::visitPrefetchTileOp( + xetile::PrefetchTileOp op, mlir::ArrayRef operands, + mlir::ArrayRef results) { + auto tileTy = op.getTile().getType(); + auto elemTy = tileTy.getElementType(); + auto shape = tileTy.getShape(); + auto memSpace = tileTy.getMemoryScopeAsInt(); + // initialized with a default size queried from the architecture + auto size = getInnerBlockSize(op, elemTy, shape, memSpace); + if (!size) + return; // do nothing if we did not get a valid block size + auto BlockingRequest = BlockingRequests(size, UsePoint(op, 0)); + propagateIfChanged(operands[0], operands[0]->join(BlockingRequest)); +} + +void BlockingAnalysisImpl::visitLoadTileOp( + xetile::LoadTileOp op, mlir::ArrayRef operands, + mlir::ArrayRef results) { + auto lattice = results[0]->getValue(); + + if (lattice.getNumUniqRequests() > 1) + op.emitWarning("multiple users requesting different blocking sizes."); + + auto tileTy = op.getSource().getType(); + auto elemTy = tileTy.getElementType(); + auto shape = tileTy.getShape(); + auto memSpace = tileTy.getMemoryScopeAsInt(); + // initialized with a default size queried from the architecture + Block block = getInnerBlockSize(op, elemTy, shape, memSpace); + + // It has users but users' requirements are not available yet. + // It is worth waiting until all users are visited. + if (!op.getValue().use_empty() && !lattice.isInitialized()) + return; + + // adjust according to user's requirements if it is available + if (lattice.isInitialized()) { + // align the height dimension if user is a transpose op, + // otherwise align the width dimension to minimize the + // in-register data movements. + bool hasTransposeUser = op.getValue().hasOneUse() && + mlir::isa(*(op->user_begin())); + + int dim = hasTransposeUser ? 
0 : 1; + for (auto rq : lattice.getRequests()) + block[dim] = std::min(block[dim], rq[dim]); + } + + if (!block) + return; // do nothing if we did not get a valid block size + + auto BlockingRequest = BlockingRequests(block, UsePoint({op, 0})); + // propagate the blocking size to its def op + propagateIfChanged(operands[0], operands[0]->join(BlockingRequest)); + + // update the def block size for the result value + BlockingRequests &def = getLatticeElement(op.getValue())->getValue(); + def.updateDefBlock(block); +} + +void BlockingAnalysisImpl::visitStoreTileOp( + xetile::StoreTileOp op, mlir::ArrayRef operands, + mlir::ArrayRef results) { + auto tileTy = op.getTile().getType(); + auto elemTy = tileTy.getElementType(); + auto shape = tileTy.getShape(); + auto memSpace = tileTy.getMemoryScopeAsInt(); + auto size = getInnerBlockSize(op, elemTy, shape, memSpace); + + if (!size) + return; // do nothing if we did not get a valid block size + + for (auto &&[i, inputOpr] : llvm::enumerate(operands)) { + auto blockingRequest = BlockingRequests(size, UsePoint(op, i)); + propagateIfChanged(inputOpr, inputOpr->join(blockingRequest)); + } +} + +void BlockingAnalysisImpl::visitUpdateTileOp( + xetile::UpdateTileOffsetOp op, mlir::ArrayRef operands, + mlir::ArrayRef results) { + auto lattice = results[0]->getValue(); + if (lattice.isInitialized()) { + auto block = lattice.getRequests()[0]; + auto request = BlockingRequests(block, UsePoint(op, 0)); + propagateIfChanged(operands[0], operands[0]->join(request)); + } +} + +void BlockingAnalysisImpl::visitTileMMAOp( + xetile::TileMMAOp op, mlir::ArrayRef operands, + mlir::ArrayRef results) { + + auto getElemBitWidth = [](mlir::VectorType vecTy) { + return vecTy.getElementType().getIntOrFloatBitWidth(); + }; + + auto C = op.getC(); + auto aPrecision = getElemBitWidth(op.getAType()); + auto bPrecision = getElemBitWidth(op.getBType()); + auto dPrecision = getElemBitWidth(op.getOutputType()); + auto cPrecision = !C ? 
dPrecision : getElemBitWidth(C.getType()); + + auto mmaSize = getMMASize(op.getElementType(), aPrecision, bPrecision, + cPrecision, dPrecision); + + auto blockSizeForA = + BlockingRequests(mmaSize[0], mmaSize[1], UsePoint({op, 0})); + auto blockSizeForB = + BlockingRequests(mmaSize[1], mmaSize[2], UsePoint({op, 1})); + + propagateIfChanged(operands[0], operands[0]->join(blockSizeForA)); + propagateIfChanged(operands[1], operands[1]->join(blockSizeForB)); + if (C) { + auto blockSizeForC = + BlockingRequests(mmaSize[0], mmaSize[2], UsePoint(op, 2)); + propagateIfChanged(operands[2], operands[2]->join(blockSizeForC)); + } + + // update the def block size for the result value + BlockingRequests &def = getLatticeElement(op.getOutput())->getValue(); + def.updateDefBlock(Block(mmaSize[0], mmaSize[2])); +} + +void BlockingAnalysisImpl::visitReductionOp( + xetile::ReductionOp op, mlir::ArrayRef operands, + mlir::ArrayRef results) { + auto srcTy = op.getSource().getType(); + auto dims = op.getReductionDims(); + // We only support reduction on 2D types now. + if (srcTy.getRank() != 2 || dims.size() != 1) + return; + + auto elemTy = srcTy.getElementType(); + auto shape = srcTy.getShape(); + // ReductionOp is special. Its blocking size is fixed to {1, + // min(subgroupSize, width)} + auto size = getInnerBlockSize(op, elemTy, shape); + if (!size) + return; // do nothing if we did not get a valid block size + + auto blockingRequest = BlockingRequests(size, UsePoint(op, 0)); + propagateIfChanged(operands[0], operands[0]->join(blockingRequest)); +} + +void BlockingAnalysisImpl::visitBroadcastOp( + xetile::BroadcastOp op, mlir::ArrayRef operands, + mlir::ArrayRef results) { + auto srcTy = op.getSource().getType(); + auto dims = op.getBroadcastDim(); + // We only support broadcast on 2D types now. + if (srcTy.getRank() != 2 || dims.size() != 1) + return; + + auto elemTy = srcTy.getElementType(); + auto shape = srcTy.getShape(); + // BroadcastOp is special. 
Its blocking size is fixed to {1, + // min(subgroupSize, width)} + auto size = getInnerBlockSize(op, elemTy, shape); + if (!size) + return; // do nothing if we did not get a valid block size + + auto blockingRequest = BlockingRequests(size, UsePoint(op, 0)); + propagateIfChanged(operands[0], operands[0]->join(blockingRequest)); +} + +void BlockingAnalysisImpl::visitTransposeOp( + xetile::TransposeOp op, mlir::ArrayRef operands, + mlir::ArrayRef results) { + + auto permutation = op.getPermutation(); + auto resType = op.getResult().getType(); + // we only support true 2D transpose now + if (resType.getRank() != 2 || permutation != mlir::ArrayRef({1, 0})) + return; + + auto lattice = results[0]->getValue(); + + // Wait for requests from users. + if (!op->use_empty() && !lattice.isInitialized()) + return; + + Block block; + + // use the default size if no users + if (op->use_empty()) { + auto srcTy = op.getVector().getType(); + auto shape = srcTy.getShape(); + block = getInnerBlockSize(op, srcTy.getElementType(), shape); + } + + // TransposeOp determines its blocking size based on requests from + // its users, by swapping the blocking size of its users. + if (lattice.isInitialized()) { + // TODO: handle multiple users + if (lattice.getNumUniqRequests() == 1) { + auto req = lattice.getRequests()[0]; + block = Block(req[1], req[0]); + } + } + + if (!block) + return; // do nothing if we did not get a valid block size + + auto request = BlockingRequests(block, UsePoint(op, 0)); + propagateIfChanged(operands[0], operands[0]->join(request)); + + // update the def block size for the result value + BlockingRequests &def = getLatticeElement(op.getResult())->getValue(); + def.updateDefBlock(Block(block[1], block[0])); +} + +void BlockingAnalysisImpl::visitVectorizableOp( + mlir::Operation *op, mlir::ArrayRef operands, + mlir::ArrayRef results) { + // Currently only supports simple elementwise math ops. 
+ if (op->getNumResults() != 1) + return; + + auto type = mlir::dyn_cast(op->getResult(0).getType()); + if (!type) + return; + + auto lattice = results[0]->getValue(); + + // Wait for requests from users. + if (!op->use_empty() && !lattice.isInitialized()) + return; + + auto elemTy = type.getElementType(); + auto shape = type.getShape(); + Block block = getInnerBlockSize(op, elemTy, shape); + + // elementwise operations are not sensitive to the block size. + // It will use the block size requested by its users. + if (lattice.isInitialized()) { + block[0] = 0; + for (auto &req : lattice.getRequests()) { + block[0] = std::max(block[0], req[0]); + block[1] = std::min(block[1], req[1]); + } + } + + // do nothing if we get an invalid block + if (!block) + return; + + // propagate the block size on its operands + for (auto &&[i, inputOpr] : llvm::enumerate(operands)) { + auto req = BlockingRequests(block, UsePoint(op, i)); + propagateIfChanged(inputOpr, inputOpr->join(req)); + } + + // update the def block size for the result value + BlockingRequests &def = getLatticeElement(op->getResult(0))->getValue(); + def.updateDefBlock(block); +} + +void BlockingAnalysisImpl::visitShapecastOp( + mlir::vector::ShapeCastOp op, mlir::ArrayRef operands, + mlir::ArrayRef results) { + auto shape = op.getSource().getType().getShape(); + if (shape.size() == 2) { + auto BlockingRequest = BlockingRequests(shape, UsePoint(op, 0)); + propagateIfChanged(operands[0], operands[0]->join(BlockingRequest)); + } +} + +int BlockingAnalysisImpl::getMaxSLMBlockSize(int elemBitWidth, int height) { + // TODO: use uArch to get max vec size? + const int lscConstraint = 512; // lsc supports up to 512 bytes per load/store + int numElems = (lscConstraint * 8) / elemBitWidth; + int width = numElems / height; + return width; +} + +// Determine the inner block size for the given operation based on the +// operand's element data type, shape, and also memory space. 
+template +Block BlockingAnalysisImpl::getInnerBlockSize( + mlir::Operation *op, mlir::Type elemTy, llvm::ArrayRef &shape, + int memorySpace) { + assert(elemTy.isIntOrFloat() && "only support int or float element type."); + + // TODO: get from uArch ? + const int64_t subgroupSize = 16; + int elemSize = elemTy.getIntOrFloatBitWidth(); + + int maxHeight = 0, minHeight = 0, maxWidth = 0, minWidth = 0; + if (mlir::isa(op) || + mlir::isa(op)) { + // for reduction and broadcast ops, we simply use + // [1, subgroupSize] as innerblock size + maxWidth = subgroupSize; + minWidth = 1; + maxHeight = 1; + minHeight = 1; + } else if (op->hasTrait()) { + // for elementwise operations, they are pretty flexible + // on the block size. But we expect its second dimension + // is subgroupSize aligned. + minWidth = 1; + minHeight = 1; + maxWidth = std::min(shape[1], subgroupSize); + maxHeight = shape[0]; + } else if (mlir::isa(op)) { + // for transpose op, we will use the original shape + // as the default size, and adjust it if it is defined + // by a load op + minWidth = 1; + minHeight = 1; + maxWidth = shape[1]; + maxHeight = shape[0]; + + // if the transpose follows a load op, and data element is 32-bit + // or 64-bit, it is expected to be folded with a load, and need to + // be aligned to hardware constraints. + auto defOp = op->getOperand(0).getDefiningOp(); + if (defOp && elemSize >= 32) { + auto params = uArch->get2DLoadConfig(defOp, elemSize, false, true); // NOTE(review): dereferenced below without the succeeded() check used on the global-memory path - confirm it cannot fail here + minHeight = params->blockHeight.min; + minWidth = params->blockWidth.min; + // to be compatible with the SIMT intrinsic, the maximum height is + // limited to 16, which is maximum supported value by SIMT intrinsic. + maxHeight = std::min(params->blockHeight.max, 16); + maxWidth = params->blockWidth.max; + } + } else if (memorySpace == 3) { + // this is supposed for load/store from/to SLM, they will use regular + // load/store instructions with chunk size. 
lsc intrinsic and hardware + // has several limits on the size per load/store. + minHeight = minWidth = 1; + // If shape[0] is divisible by subgroup size, we use regular load (with + // chunk size) with XeGPU.load_gather (maxHeight = 16). Otherwise, we + // use 1D load with XeGPU.load_nd(1d, maxHeight = 1). + maxHeight = shape[0] % subgroupSize == 0 ? subgroupSize : 1; + maxWidth = getMaxSLMBlockSize(elemSize, maxHeight); + } else { // for load/store from/to global memory + mlir::FailureOr params; + if (mlir::isa(op)) + params = uArch->get2DStoreConfig(elemSize); + if (mlir::isa(op) || + mlir::isa(op)) { + bool transpose = false; + // if its user is a transpose op, and data element is 32-bit + // or 64-bit, we will use the transpose supported size. + if (auto loadOp = mlir::dyn_cast(op)) { + auto value = loadOp.getValue(); + transpose = elemSize >= 32 && value.hasOneUse() && + mlir::isa(*(value.user_begin())); + } + params = uArch->get2DLoadConfig(op, elemSize, false, transpose); + } + if (mlir::succeeded(params)) { + maxHeight = params->blockHeight.max; + minHeight = params->blockHeight.min; + maxWidth = params->blockWidth.max; + minWidth = params->blockWidth.min; + } + } + + auto findLargestDivisorInRange = [&](int64_t v, int64_t l, int64_t h) { + for (int i = h; i >= l && i > 0; i--) { // stop before 0: the range stays [0, 0] when no config was found, and v % 0 is undefined + if (v % i == 0) + return i; + } + // irregular shape or shape is not in the supported range. 
+ return 0; + }; + + auto height = findLargestDivisorInRange(shape[0], minHeight, maxHeight); + auto width = findLargestDivisorInRange(shape[1], minWidth, maxWidth); + return Block(height, width); +} + +llvm::SmallVector +BlockingAnalysisImpl::getMMASize(mlir::Type elemTy, const int APrecision, + const int BPrecision, const int CPrecision, + const int DPrecision) { + assert(elemTy.isIntOrFloat() && "only support int or float data type."); + auto dpasParams = + uArch->getDPASConfig(APrecision, BPrecision, CPrecision, DPrecision); + return llvm::SmallVector( + {dpasParams.m, dpasParams.k, dpasParams.n}); +} + +// ===--------------------------------BlockingAnalysis---------------------------------===// + +mlir::LogicalResult BlockingAnalysis::run(mlir::Operation *op) { + mlir::SymbolTableCollection symbolTable; + // BlockingAnalysisImpl is using default initialize method + // provided by SparseBackwardDataFlowAnalysis. And this default + // initialize method relies on results of DeadCodeAnalysis to + // skip analysis on the dead code. 
+ solver.load(); + solver.load(); + solver.load(symbolTable, uArch); + target = op; + return solver.initializeAndRun(op); +} + +void BlockingAnalysis::printAnalysisResult() { + llvm::dbgs() << "\n\nBlockingAnalysis Results:\n"; + target->walk([&](mlir::Operation *op) { + if (op->getNumRegions() == 0 && op->getNumResults() == 1) { + auto resTy = op->getResult(0).getType(); + if (mlir::isa(resTy) || + mlir::isa(resTy)) { + llvm::dbgs() << "\nOp: " << *op; + for (auto [i, inputOpr] : llvm::enumerate(op->getOperands())) { + if (mlir::isa(inputOpr.getType()) || + mlir::isa(inputOpr.getType())) { + UsePoint p(op, i); + llvm::dbgs() << "\n opr[" << i << "]: " << inputOpr + << " --> blkSZ: " << getUseBlockSize(inputOpr, p); + } + } + + for (auto [i, res] : llvm::enumerate(op->getResults())) + llvm::dbgs() << "\n res[" << i << "]: " << res + << " --> blkSZ: " << getDefBlockSize(res); + llvm::dbgs() << "\n"; + } + } else if (auto forOp = mlir::dyn_cast(op)) { + llvm::dbgs() << "\nOp: " << op->getName(); + for (auto [i, arg] : llvm::enumerate(forOp.getRegionIterArgs())) + llvm::dbgs() << "\n arg[" << i << "]: " + << " --> blkSZ: " << getDefBlockSize(arg); + + for (auto [i, res] : llvm::enumerate(forOp.getResults())) + llvm::dbgs() << "\n res[" << i << "]: " + << " --> blkSZ: " << getDefBlockSize(res); + llvm::dbgs() << "\n"; + } else if (auto YieldOp = mlir::dyn_cast(op)) { + llvm::dbgs() << "\nOp: " << op->getName(); + for (auto [i, res] : llvm::enumerate(YieldOp.getResults())) + llvm::dbgs() << "\n res[" << i << "]: " << res + << " --> blkSZ: " << getDefBlockSize(res) << ", " + << getUseBlockSize(res, UsePoint(op, i)); + llvm::dbgs() << "\n"; + } else if (auto StoreOp = mlir::dyn_cast(op)) { + llvm::dbgs() << "\nOp: " << *op; + for (auto [i, inputOpr] : llvm::enumerate(op->getOperands())) { + llvm::dbgs() << "\n opr[" << i << "]: " << inputOpr << " --> blkSZ: " + << getUseBlockSize(inputOpr, UsePoint(StoreOp, i)); + } + llvm::dbgs() << "\n"; + } + }); +} + +Block 
BlockingAnalysis::getUseBlockSize(mlir::Value val, UsePoint point) const { + auto *state = solver.lookupState(val); + if (!state) + return Block(); + return state->getValue().getUseBlock(point); +} + +Block BlockingAnalysis::getDefBlockSize(mlir::Value val) const { + auto *state = solver.lookupState(val); + if (!state) + return Block(); + return state->getValue().getDefBlock(); +} + +} // namespace imex diff --git a/lib/Dialect/XeTile/Transforms/BlockingAnalysis.h b/lib/Dialect/XeTile/Transforms/BlockingAnalysis.h new file mode 100644 index 000000000..96f2249e2 --- /dev/null +++ b/lib/Dialect/XeTile/Transforms/BlockingAnalysis.h @@ -0,0 +1,68 @@ + +#ifndef IMEX_BLOCKING_ANALYSIS_H +#define IMEX_BLOCKING_ANALYSIS_H + +#include +#include +#include + +#include + +#include "imex/Utils/XeArch.h" + +namespace imex { + +/// a class representing an inner block size, provides some +/// convenient methods for manipulation. +class Block { +public: + Block() : values{0, 0} {} + + Block(int64_t h, int64_t w) : values{h, w} {} + + int64_t &operator[](size_t index); + const int64_t &operator[](size_t index) const; + + bool operator==(Block &other) const; + bool operator==(const Block &other) const; + + bool operator!=(Block &other) const { return !(*this == other); } + bool operator!=(const Block &other) const { return !(*this == other); } + + void print(llvm::raw_ostream &os) const; + + llvm::ArrayRef asArrayRef() const; + + operator bool() const { return values[0] != 0 && values[1] != 0; } + +private: + int64_t values[2]; +}; + +llvm::raw_ostream &operator<<(llvm::raw_ostream &os, Block blk); + +// A pair of operation and operand index number representing +// the use point of a value.
+typedef std::pair UsePoint; + +class BlockingAnalysis { +public: + explicit BlockingAnalysis(std::shared_ptr uArch) { + this->uArch = uArch; + }; + + mlir::LogicalResult run(mlir::Operation *op); + + Block getUseBlockSize(mlir::Value val, UsePoint point) const; + Block getDefBlockSize(mlir::Value val) const; + void printAnalysisResult(); + +private: + mlir::DataFlowSolver solver; + std::shared_ptr uArch; + mlir::Operation *target; +}; + +} // namespace imex + +#endif // IMEX_BLOCKING_ANALYSIS_H diff --git a/lib/Dialect/XeTile/Transforms/BlockingRewrite.cpp b/lib/Dialect/XeTile/Transforms/BlockingRewrite.cpp new file mode 100644 index 000000000..74604dd52 --- /dev/null +++ b/lib/Dialect/XeTile/Transforms/BlockingRewrite.cpp @@ -0,0 +1,875 @@ +//===------- BlockingRewrite.cpp --------- Blocking Pass -------*- C++ -*-===// +// +// Copyright 2024 Intel Corporation +// Part of the IMEX Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// +//===----------------------------------------------------------------------===// +/// +/// \file +/// This file contains lowering transformation for determining the problem size +/// that can be handled by an XeGPU operator (hardware instruction). XeTile +/// program can work on a bigger problem size that cannot be handled by a +/// hardware instruction. But it needs to be decomposed into smaller pieces +/// such that each piece can be handled by a hardware instruction. +///
+/// +//===----------------------------------------------------------------------===// +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include + +#include +#include +#include +#include + +#include "imex/Dialect/XeTile/Transforms/Passes.h" +#include "imex/Utils/DebugUtils.h" +#include "imex/Utils/XeArch.h" + +#include "BlockingAnalysis.h" +#include "PassDetail.h" + +using namespace mlir; +using namespace llvm; +using namespace imex; +namespace imex { +#define GEN_PASS_DECL_NEWXETILEBLOCKING +#define GEN_PASS_DEF_NEWXETILEBLOCKING +#include "imex/Dialect/XeTile/Transforms/Passes.h.inc" +} // namespace imex + +namespace imex { +namespace Blocking { + +static xetile::TileUnpackOp +addUnpackOp(mlir::Value src, mlir::ConversionPatternRewriter &rewriter) { + auto srcTy = llvm::dyn_cast_if_present(src.getType()); + assert(srcTy && srcTy.getRank() == 4); + auto shape = srcTy.getShape(); + auto grids = shape.take_front(2); + auto innerBlocks = shape.take_back(2); + llvm::SmallVector unpackShape( + {grids[0] * innerBlocks[0], grids[1] * innerBlocks[1]}); + + auto unpackTy = mlir::VectorType::get(unpackShape, srcTy.getElementType()); + return rewriter.create( + src.getLoc(), unpackTy, src, + mlir::DenseI64ArrayAttr::get(src.getContext(), innerBlocks)); +} + +static mlir::Value addPackOp(mlir::Value src, + llvm::ArrayRef targetBlkSizes, + mlir::ConversionPatternRewriter &rewriter) { + auto srcTy = mlir::dyn_cast(src.getType()); + assert(srcTy && targetBlkSizes.size() == 2); + auto shape = srcTy.getShape(); + llvm::SmallVector packShape({shape[0] / targetBlkSizes[0], + shape[1] / targetBlkSizes[1], + targetBlkSizes[0], targetBlkSizes[1]}); + + auto packTy = mlir::VectorType::get(packShape, srcTy.getElementType()); + auto packOp = rewriter.create( + src.getLoc(), packTy, src, + 
mlir::DenseI64ArrayAttr::get(src.getContext(), targetBlkSizes)); + return packOp; +} + +/// OpConversionPatternWithAnalysis is a wrapper around OpConversionPattern +/// but takes an extra AnalysisT object as an argument, such that patterns +/// can leverage the analysis results. +template +class OpConversionPatternWithAnalysis + : public mlir::OpConversionPattern { +public: + using OpPatternRewriter = typename mlir::ConversionPatternRewriter; + + OpConversionPatternWithAnalysis(mlir::MLIRContext *context, + AnalysisT &analysis) + : mlir::OpConversionPattern(context), analysis(analysis) {} + +protected: + AnalysisT &analysis; +}; + +/// OpTraitConversionPatternWithAnalysis is a wrapper around +/// OpTraitConversionPattern but takes an extra AnalysisT object as an argument, +/// such that patterns can leverage the analysis results. +template