diff --git a/include/gc/Analysis/VectorBasedFusionAnalysis.h b/include/gc/Analysis/VectorBasedFusionAnalysis.h new file mode 100644 index 000000000..1925d4fcb --- /dev/null +++ b/include/gc/Analysis/VectorBasedFusionAnalysis.h @@ -0,0 +1,308 @@ +//===-- VectorBasedFusionAnalysis.h - vector fusion analysis ----*- C++ -*-===// +// +// This file is licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_ANALYSIS_VECTORBASEDFUSIONANALYSIS_H +#define MLIR_ANALYSIS_VECTORBASEDFUSIONANALYSIS_H + +#include "gc/Dialect/Linalgx/LinalgxOps.h" +#include "gc/Dialect/Linalgx/Utils.h" +#include "gc/Dialect/Microkernel/MicrokernelOps.h" +#include "gc/Transforms/Passes.h" +#include "gc/Transforms/Utils/VectorUtils.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/Dialect/Traits.h" +#include "mlir/Dialect/Vector/IR/VectorOps.h" +#include "mlir/IR/BuiltinDialect.h" +#include "mlir/IR/BuiltinTypes.h" +#include "llvm/ADT/TypeSwitch.h" +#include + +namespace mlir { +namespace gc { + +/// record hardware information +struct HardWareInfo { + size_t vectorWidth = 0; +}; + +/// Vector type conversion helper class +class TypeHelper { +private: + HardWareInfo info; + +public: + TypeHelper() = default; + TypeHelper(HardWareInfo info) : info(info) {} + /// get current hardware information + HardWareInfo &getHardwareInfo() { return this->info; } + /// use \param info to set hardware information + void setHardWareInfo(HardWareInfo &info) { this->info = info; } + /// get vector \param type max loop step according to hardware information + int getDataTypeValidSteps(VectorType type); + /// get vector \param type an even for loop step + int generateValidSteps(int steps, VectorType type); + /// get vector \param type an even for loop step when shape dimension is + /// shapeDim + int generateValidSteps(int steps, VectorType type, int shapeDim); + /// get vector \param type max simd length according to hardware + /// information + int getDataTypeMAXSIMDLength(VectorType type); + /// get operation's vector type + VectorType getVectorzedType(Operation *op, uint32_t loopStep = 0); +}; + +/// operation return kind, which is used to determine whether the operation +/// need to return it's result in current for loop +enum class ReturnTypeKind { + RT_Both, + RT_OutGroup, + RT_InGroup, +}; + +/// Base class of vector-based fusion. 
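+/// It owns the function being processed, a TypeHelper configured with the
+/// hardware vector width, and the IRRewriter that is shared with the later
+/// fusion and loop-generation stages.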
+class VectorFusionBase { + +private: + /// current function IR + func::FuncOp func; + /// Type helper class, can help us to get operation type + TypeHelper typehelper; + /// IR rewriter + IRRewriter *rewriter; + +public: + VectorFusionBase(func::FuncOp &func, HardWareInfo &info, IRRewriter *rewriter) + : func(func), typehelper(info), rewriter(rewriter) {} + VectorFusionBase(VectorFusionBase &base, IRRewriter *rewriter) + : func(base.getFunction()), typehelper(base.getHardwareInfo()), + rewriter(rewriter) {} + + /// get current function IR + func::FuncOp &getFunction() { return func; } + /// get current hardware info + HardWareInfo &getHardwareInfo() noexcept { + return typehelper.getHardwareInfo(); + } + TypeHelper &getTypeHelper() noexcept { return typehelper; } + IRRewriter *getRewriter() noexcept { return rewriter; } + void setRewriter(IRRewriter *rewriter) noexcept { this->rewriter = rewriter; } +}; + +/// Group operation fusion strategy class. +/// 1. Classify operaions: +/// classify the operations into : +/// a. reorder, transpose. Reorder(or transpose) dim may bring data +/// dependency. +/// b. elemenwise. Those operations can be fused into a common for loop. +/// c. broadcast. Need to analysis broadcast dim and the data +/// dependency. +/// d. reduction. Need to analysis broadcast dim and the +/// data dependency. +/// Same group operations have no data dependencies. They can be fused into a +/// common for loop body. + +/// Using queue to store the operation order. In order to ensure that +/// subsequent moves to the operation will not cause semantic changes. +class GroupOperationFusion : public VectorFusionBase { +private: + /// operation groups, operations in each group can generate a common for + /// loop + SmallVector, 8> opGroups; + /// group max vectorize steps + SmallVector groupMaxSteps; + /// vector type which has bigest rank in current operation group + llvm::SmallDenseMap groupBigestRankVectorType; + /// query current operation in which group, return group index + DenseMap opGroupIndexMap; + /// can fused into prev operation which axis position + DenseMap opAnchorPos; + /// record some operations which not need to No need to judge whether can be + /// fused + std::queue notNeedToJudgeOps; + /// analysis the operation's operands and results + SmallVector>, 8> + groupOpResults; + /// store loop iteration args for each of operation group + SmallVector, 8> groupOpInitArgs; + // store read and write operations permutation maps in order to convenient + // to replace loop induction var + DenseMap opPermuationMap; + /// record operation operand original operate value + DenseMap operandOriginalValue; + +public: + GroupOperationFusion(func::FuncOp &func, HardWareInfo &info, + IRRewriter *rewriter) + : VectorFusionBase(func, info, rewriter) {} + + GroupOperationFusion(GroupOperationFusion &strategy, IRRewriter *rewriter) + : VectorFusionBase(strategy.getFunction(), strategy.getHardwareInfo(), + rewriter), + opGroups(strategy.opGroups), groupMaxSteps(strategy.groupMaxSteps), + opGroupIndexMap(strategy.opGroupIndexMap), + opAnchorPos(strategy.opAnchorPos){}; + + GroupOperationFusion(GroupOperationFusion &&strategy, IRRewriter *rewriter) + : VectorFusionBase(strategy.getFunction(), strategy.getHardwareInfo(), + rewriter), + opGroups(std::move(strategy.opGroups)), + groupMaxSteps(std::move(strategy.groupMaxSteps)), + groupBigestRankVectorType( + std::move(strategy.getGroupBiggestRankVectorType())), + opGroupIndexMap(std::move(strategy.opGroupIndexMap)), + 
opAnchorPos(std::move(strategy.opAnchorPos)){}; + + GroupOperationFusion &operator=(GroupOperationFusion &fusion) { + this->getOpGroups() = fusion.getOpGroups(); + this->getGroupMaxSteps() = fusion.getGroupMaxSteps(); + this->getGroupBiggestRankVectorType() = + fusion.getGroupBiggestRankVectorType(); + this->getOpGroupIndexMap() = fusion.getOpGroupIndexMap(); + this->getOpAnchorPos() = fusion.getOpAnchorPos(); + this->notNeedToJudgeOps = fusion.notNeedToJudgeOps; + this->getGroupOpResults() = fusion.getGroupOpResults(); + this->getGroupOpInitArgs() = fusion.getGroupOpInitArgs(); + this->getOpPermuationMap() = fusion.getOpPermuationMap(); + this->getOperandOriginalValue() = fusion.getOperandOriginalValue(); + this->getFunction() = fusion.getFunction(); + this->getHardwareInfo() = fusion.getHardwareInfo(); + this->getTypeHelper() = fusion.getTypeHelper(); + this->setRewriter(fusion.getRewriter()); + return *this; + }; + + /// Get the map which contains each group vector type which has biggest + /// rank. + llvm::SmallDenseMap & + getGroupBiggestRankVectorType() noexcept { + return groupBigestRankVectorType; + }; + /// Get the operation group obtained by fusion strategy analysis + SmallVector, 8> &getOpGroups() noexcept { + return opGroups; + } + /// Get the operation belong to which group index map + DenseMap &getOpGroupIndexMap() noexcept { + return opGroupIndexMap; + } + /// Get the map contains max steps of each group + SmallVector &getGroupMaxSteps() noexcept { + return groupMaxSteps; + } + /// Get the map contains anchor position of each operation + DenseMap &getOpAnchorPos() noexcept { + return opAnchorPos; + } + /// get current operation group results + SmallVector>, 8> & + getGroupOpResults() noexcept { + return groupOpResults; + } + + SmallVector, 8> &getGroupOpInitArgs() noexcept { + return groupOpInitArgs; + } + + DenseMap &getOpPermuationMap() noexcept { + return opPermuationMap; + } + + DenseMap &getOperandOriginalValue() noexcept { + return operandOriginalValue; + } + /// set operation groups + void setGroupOpResults( + const SmallVector< + llvm::MapVector>, 8> + &results) { + groupOpResults = std::move(results); + } + + void setGroupOpIterArgs( + const SmallVector, 8> &initArgs) noexcept { + groupOpInitArgs = std::move(initArgs); + } + + void setPermutationMap(const DenseMap &map) noexcept { + opPermuationMap = std::move(map); + } + /// Do fusion strategy + void classifyOperations(); + + /// Whether two operations have compatible vector shapes + bool isCompatibleVectorType(Operation *op1, Operation *op2); + + /// update bigest vector type for last operation group + void updateGroupBigestVectorType(VectorType vectorType); + + /// Check whether the operation can fuse with previous operation + bool isNeedNewGroup(Operation *op); + + /// Add Operation \p op into current last group or a new Group + /// \p op must has valid value, can't be nullptr + void addOperationToGroup(Operation *op); + + /// get next operation in current operation group + template + Operation *getNextTargetOperationInCurrentGroup(Operation *curOp, + const size_t grpIdx); + + /// run the vector-based fusion strategy + void run(); +}; + +template +Operation *GroupOperationFusion::getNextTargetOperationInCurrentGroup( + Operation *curOp, const size_t grpIdx) { + std::queue tmpOpQueue(getOpGroups()[grpIdx]); + if (isa(curOp)) + return curOp; + + while (!tmpOpQueue.empty()) { + auto frontOp = tmpOpQueue.front(); + tmpOpQueue.pop(); + if (not isa(frontOp)) + continue; + for (auto x : frontOp->getOperands()) + if 
(x.getDefiningOp() == curOp) + return frontOp; + } + return nullptr; +} + +/// Analysis each operation group class. +/// Currently it will run vector-base fusion, analysis empty group and each +/// operation group's max vectorized step. +class GroupOperationAnalysis { +private: + /// vector-based fusion related data + GroupOperationFusion fusionStrategy; + IRRewriter *rewriter; + +public: + GroupOperationAnalysis(func::FuncOp &func, HardWareInfo &info, + IRRewriter *rewriter) + : fusionStrategy(func, info, rewriter), rewriter(rewriter) {} + /// remove the useless operation, due to it result is not require by other + /// operation + void analysisEmptyGroup(); + /// get each operation in each group maximum support vectorization length + void analysisGroupMaxSteps(); + /// get fusion strategy + GroupOperationFusion &getGroupOperationFusion() { return fusionStrategy; } + /// running the vector-based fusion + void run() { fusionStrategy.run(); } + /// get current function rewriter + IRRewriter *getRewriter() { return rewriter; } +}; +} // namespace gc +} // namespace mlir + +#endif \ No newline at end of file diff --git a/include/gc/Transforms/Passes.td b/include/gc/Transforms/Passes.td index fb5581bb5..3de3d26a1 100644 --- a/include/gc/Transforms/Passes.td +++ b/include/gc/Transforms/Passes.td @@ -188,6 +188,21 @@ def MergeNestedForall : Pass<"merge-nested-forall"> { let dependentDialects = ["scf::SCFDialect"]; } +def CPUPhysicalRegisterPass : Pass<"CPU-physical-register-pass", "func::FuncOp"> { + let summary = "Lower operation to cpu pysical register size."; + let description = [{ + Physical register size lowering pass. + }]; + let dependentDialects = [ + "::mlir::func::FuncDialect", + "::mlir::math::MathDialect", + "::mlir::arith::ArithDialect", + "::mlir::tensor::TensorDialect", + "::mlir::vector::VectorDialect", + "::mlir::scf::SCFDialect", + ]; +} + def FoldTensorOperation : Pass<"fold-tensor-operation"> { let summary = "Fold some tensor operation"; let description = [{ diff --git a/include/gc/Transforms/TilingVector.h b/include/gc/Transforms/TilingVector.h new file mode 100644 index 000000000..90cca7101 --- /dev/null +++ b/include/gc/Transforms/TilingVector.h @@ -0,0 +1,658 @@ +//===- TilingVector.h - Tiling large vector to small vector -----*- C++ -*-===// +// +// This file is licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +#ifndef GC_PASSES_TILINGVECTOR_H +#define GC_PASSES_TILINGVECTOR_H + +#include "gc/Analysis//VectorBasedFusionAnalysis.h" +#include "gc/Analysis/TargetDescriptionAnalysis.h" +#include "mlir/Dialect/Affine/IR/AffineOps.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/Linalg/IR/Linalg.h" +#include "mlir/Dialect/Linalg/Transforms/Transforms.h" +#include "mlir/Dialect/Math/IR/Math.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/Dialect/Tensor/Transforms/Transforms.h" +#include "mlir/Dialect/Vector/IR/VectorOps.h" +#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h" +#include "mlir/Dialect/Vector/Transforms/Passes.h" +#include "mlir/IR/Visitors.h" +#include "mlir/Transforms/CSE.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "mlir/Transforms/LoopInvariantCodeMotionUtils.h" + +namespace mlir { +namespace gc { + +/// get fusion kind +/// Has two kind: +/// 1. OperationGroup: +/// The operation is converted into physical registers through our fusion +/// strategy. +/// 2. Operations:(TODO:) +/// The user ensures that there is no data dependency between operations, +/// and we directly convert the operations into physical register sizes. +enum CanonicalizerKind { GroupOperations, Operations }; + +/// To avoid too many parameters in function when generate for loop +struct GenerateLoopHelper { + /// anchor id + size_t anchorIdx = 0; + /// group id + size_t groupIdx = 0; + /// for loop results + ValueRange forResults; + /// for loop block + Block *forBlock; + /// loop iteration args index map + DenseMap currentLoopStateIdxMap; + /// loop iteration args + ValueRange loopIterArgs; + /// next loop anchor yield results + SmallVector nextAnchorResults; + /// next loop anchor yield results index map + DenseMap nextAnchorResultsIdxMap; + /// next loop anchor yield results original result map + DenseMap nextAnchorResultOrignalResultMap; + /// original result with next anchor result map + DenseMap orignalResultNextAnchorResultMap; + /// loop induction variables + SmallVector inductionVars; + /// original operand with loop args map + DenseMap originalOperandLoopArgsMap; + /// loop args with original operand map + DenseMap loopArgsOriginalOperandMap; + /// candidate operation queue + std::queue *candidateOps; + /// moved operation queue + std::queue *movedOps; + /// record operation's correct loop indice, due to some operation like + /// reduce may need to reorder loop indice + DenseMap> indiceLoopMap; + GenerateLoopHelper() = default; + GenerateLoopHelper(const size_t groupIdx) noexcept { + this->groupIdx = groupIdx; + } + GenerateLoopHelper(const size_t groupIdx, const size_t anchorIdx) noexcept { + this->groupIdx = groupIdx; + this->anchorIdx = anchorIdx; + } + /// clear next anchor results related data + void clearNextAnchorResults(); + /// set next anchor results related data + void setNextAnchorResults(SmallVector ¤tAnchorResults, + DenseMap ¤tResultMap, + DenseMap ¤tResultIdxMap); + /// set next anchor iteration args + void setNextAnchorArgs(DenseMap &nextAnchorArgsIdxMap, + SmallVector &nextAnchorArgs); + /// set id of for loop anchor + void setAnchorId(const size_t anchorId) noexcept; + /// Before perform processing previous operation, we need to update some data + void updateDataBeforePreOpMove(ArrayRef loopstate, + std::queue &candidateQueue, + std::queue 
&movedQueue); + /// After previous operation movement, we need to update some data + void updateDataAfterPreOpMove(DenseMap &nextAnchorArgsIdxMap, + SmallVector &nextAnchorArgs); + /// Before perform processing previous operation, we need to update some data + void updateDataBeforePostOpMove( + ArrayRef iterArgs, DenseMap ¤tLoopStateIdxMap, + DenseMap ¤toriginalArgsMap, + DenseMap ¤tArgsOriginalMap, ValueRange forResults, + Block *forBlock, std::queue &movedQueue, size_t anchorId); + /// After previous operation movement, we need to update some data + void updateDataAfterPostOpMove(size_t anchorId, + DenseMap &nextAnchorArgsIdxMap, + SmallVector &nextAnchorArgs); + + /// update loop iteration args data + void updateCurrentArgsStatus(DenseMap ¤tArgsIdxMap, + SmallVector ¤tArgs, + DenseMap &originalArgsMap, + DenseMap &argsOriginalMap); +}; + +//===----------------------------------------------------------------------===// +// vectorize operation class +//===----------------------------------------------------------------------===// +class MultiReductionCanonicalizer; +class BroadcastCanonicalizer; +class TransposeCanonicalizer; +class ShapeCastCanonicalizer; + +// fixed extraction trait +template struct SpecialOpTraits; +template <> struct SpecialOpTraits { + using DerivedSpecialT = MultiReductionCanonicalizer; +}; +template <> struct SpecialOpTraits { + using DerivedSpecialT = BroadcastCanonicalizer; +}; +template <> struct SpecialOpTraits { + using DerivedSpecialT = TransposeCanonicalizer; +}; +template <> struct SpecialOpTraits { + using DerivedSpecialT = ShapeCastCanonicalizer; +}; + +/// base class of special operation +template class SpecialOperationCanonicalizer { + using DerivedT = typename SpecialOpTraits::DerivedSpecialT; + +private: + /// store current special operation + SmallVector candidateRdOps; + /// vectorize step + size_t vectorStep = 1; + +public: + enum class SpecialOperationKind { + OP_MultiDimReduction, + OP_Broadcast, + OP_Transpose, + OP_ShapeCast + }; + +private: + const SpecialOperationKind kind; + +public: + SpecialOperationCanonicalizer() = default; + SpecialOperationCanonicalizer(const SmallVector &candidateRdOps, + SpecialOperationKind kind) + : candidateRdOps(candidateRdOps), kind(kind) {} + SpecialOperationCanonicalizer(const SmallVector &candidateRdOps, + SpecialOperationKind kind, size_t step) + : candidateRdOps(candidateRdOps), vectorStep(step), kind(kind) {} + SmallVector &getCandidateOps(); + virtual ~SpecialOperationCanonicalizer() {} + /// call derived speical operation init information methods + void prepareSpecialOperationInfo() { + static_cast(this)->prepareSpecialInfo(); + } + /// get kind of speical operation + SpecialOperationKind getKind() noexcept { return kind; } + /// set current operation group vectorize step + void setVectorStep(size_t step) noexcept { vectorStep = step; } + /// get current operation group vectorize step + size_t getVectorStep() noexcept { return vectorStep; } +}; + +enum class MultiReduceOpAxisKind { Reduction, Parallel }; +/// Help to vectorize reduction operation +class MultiReductionCanonicalizer + : public SpecialOperationCanonicalizer { +private: + /// reduction parallel axis and reduction axis + SmallVector reductionAxis, parallelAxis; + /// operations before reduction operation and operations after reduction + /// operation + std::queue prevOps, postOps, accRelatedOps, sourceRelatedOps; + bool haslastDimReduction = false; + bool isStandaloneOp = false; + /// empty reduction means that all the reduction axis is 1 + 
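+  /// (e.g. a multi_reduction whose reduction dims all have static size 1)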
bool isEmptyReduction = true; + /// vector type rank + int64_t typeRank = -1; + /// record original operation result + SetVector originalOpResults; + /// vector type of source operation and accumulate operation + VectorType sourceType, accType; + /// for loop yield result index map + llvm::SmallDenseMap resultIdxMap; + +public: + MultiReductionCanonicalizer( + const SmallVector &candidateRdOps, + size_t steps = 1) + : SpecialOperationCanonicalizer( + candidateRdOps, SpecialOperationKind::OP_MultiDimReduction, steps) { + isStandaloneOp = candidateRdOps.size() == 1; + }; + virtual ~MultiReductionCanonicalizer() noexcept {}; + /// get reduction vector type, we use source operation type as reduction + /// vector type + int64_t getTypeRank(); + /// get reduction operation reduction and parallel axis + void getReductionAxisAndParallelAxis(); + /// whether last dim is reduction axis + bool hasLastDimReduction(); + /// whether only reduction operation in current operation group + bool getIsStandAloneOp() noexcept { return isStandaloneOp; } + /// get whether last dim is reduction axis + bool getHasLastDimReduction() noexcept { return haslastDimReduction; } + /// initialize to get reduction axis + void initReductionAxis(); + /// initialize to get parallel axis + void initParallelAxis(); + /// get reduction axis + SmallVector &getReductionAxis() noexcept { + return reductionAxis; + }; + /// get parallel axis + SmallVector &getParallelAxis() noexcept { return parallelAxis; }; + /// get prev operation in current operation group + std::queue &getPrevOps() noexcept { return prevOps; } + /// get post operation in current operation group + std::queue &getPostOps() noexcept { return postOps; } + /// get accumulate operation in reduction operation + std::queue &getAccRelatedOps() noexcept { return accRelatedOps; } + /// get source operation in reduction operation + std::queue &getSourceRelatedOps() noexcept { + return sourceRelatedOps; + } + /// get reduction operation original result + SetVector &getOriginalOpResults() noexcept { + return originalOpResults; + } + /// get source operation vector type + VectorType getSourceType() noexcept { return sourceType; }; + /// get accumulate operation vector type + VectorType getAccType() noexcept { return accType; }; + /// get result index map + llvm::SmallDenseMap &getResultIdxMap() noexcept { + return resultIdxMap; + } + /// set result index map + void setResultIdxMap(const llvm::SmallDenseMap &map) { + resultIdxMap = map; + } + + /// initalize parallel, reduction axis, reduction operation type and whether + /// last dim is reduction axis + void prepareSpecialInfo(); + + static bool classof(SpecialOperationCanonicalizer *canonicalizer) { + return canonicalizer->getKind() == + SpecialOperationKind::OP_MultiDimReduction; + } +}; + +class BroadcastCanonicalizer + : public SpecialOperationCanonicalizer { +private: +public: + BroadcastCanonicalizer( + const SmallVector &candidateBcOps, + size_t steps = 1) + : SpecialOperationCanonicalizer( + candidateBcOps, SpecialOperationKind::OP_Broadcast, steps){}; + virtual ~BroadcastCanonicalizer() noexcept {} + void prepareSpecialInfo(){}; + static bool classof(SpecialOperationCanonicalizer *canonicalizer) { + return canonicalizer->getKind() == SpecialOperationKind::OP_Broadcast; + } +}; + +class TransposeCanonicalizer + : public SpecialOperationCanonicalizer { +private: + /// first and second transpose axis + size_t firstTpIdx = 0, secondTpIdx = 0; + +public: + TransposeCanonicalizer( + const llvm::SmallVector &candidateTpOps, + 
size_t steps = 1) + : SpecialOperationCanonicalizer( + candidateTpOps, SpecialOperationKind::OP_Transpose, steps){}; + virtual ~TransposeCanonicalizer() noexcept {} + void prepareSpecialInfo(){}; + static bool classof(SpecialOperationCanonicalizer *canonicalizer) { + return canonicalizer->getKind() == SpecialOperationKind::OP_Transpose; + } + enum TRANSPOSE_KERNEL { + KERNEL_16X16 = 16, + }; + /// get first transpose axis + size_t getFirstTpIdx() noexcept { return firstTpIdx; } + /// get second transpose axis + size_t getSecondTpIdx() noexcept { return secondTpIdx; } + /// whether transpose on two dimensions + bool isTwoDTranspose(); + /// whether transpose on all dimension size is one + bool isTransposeOnAllOneDim(); + /// whether transpose on last dimension + bool transposeOnLastDim(); +}; + +class ShapeCastCanonicalizer + : public SpecialOperationCanonicalizer { +public: + ShapeCastCanonicalizer( + const SmallVector &candidateScOps, + size_t steps = 1) + : SpecialOperationCanonicalizer( + candidateScOps, SpecialOperationKind::OP_ShapeCast, steps){}; + virtual ~ShapeCastCanonicalizer() {} + void prepareSpecialInfo() {} + static bool classof(SpecialOperationCanonicalizer *canonicalizer) { + return canonicalizer->getKind() == SpecialOperationKind::OP_ShapeCast; + } + /// whether store and load on last dimension + bool isReadWriteOnLastDim(); +}; + +/// generate for loop for each operation. +class ForLoopGenerator { +private: + GroupOperationFusion vectorBasedFusion; + IRRewriter *rewriter; + +public: + ForLoopGenerator(GroupOperationFusion &fusion, IRRewriter *rewriter) + : vectorBasedFusion(fusion, rewriter), rewriter(rewriter) {} + + virtual ~ForLoopGenerator() noexcept {} + + IRRewriter *getRewriter() noexcept { return rewriter; } + + void setVectorBaseFusion(GroupOperationFusion &vectorBasedFusion) { + this->vectorBasedFusion = vectorBasedFusion; + }; + + /// clear current group operation + void clearCurrentOperationGroup(size_t grpIdx); + + /// prepare for loop iteration args + void prepareForLoopArgs(const size_t grpIdx, GenerateLoopHelper &loopHelper); + + /// replace original operation result with corresponding for loop result + void replaceOpUsersWithForLoopResult( + scf::ForOp forOp, int grpIdx, SmallVector &nextAnchorResults, + DenseMap &nextAnchorResultsIdxMap, + DenseMap &forResultOrignalResultMap); + + /// mark which operation need to set correct for loop var idx + /// due to sometimes we need to chage for loop order like reduce operation. + void getCurrentGroupIndiceLoopMap( + DenseMap> &indiceLoopMap, + const size_t groupId, Operation *op, + const DenseMap &setIdxMap = DenseMap({})); + + // get methods + GroupOperationFusion &getVectorBasedFusion() noexcept { + return vectorBasedFusion; + } + /// rewrite operation as vectorize IR in current operation group + void + rewriteOperationAsVectorize(OpBuilder &rewriter, size_t groupId, + const std::queue *queue = nullptr, + const size_t vectorizeStep = 0); + /// Reimplementation of writing a tensor from a constant of denseElementattr. 
+ void createNewConstantOp(Operation *srcOp, + vector::TransferWriteOp *transferWriteOp, + size_t groupSteps); + // Generate elementwise operation for loop + mlir::FailureOr + generateVectorizedForLoop(const size_t groupId, IRRewriter &rewriter, + const VectorType vectorType); + scf::ForOp constructNestedForOp(const size_t groupIdx, OpBuilder &b, + const Location &loc, ArrayRef dims, + GenerateLoopHelper &loopHelper); + /// move operations in \param queue to current loop anchor + void moveOperationsToCurrentForBody(const OpBuilder &b, + std::queue &queue, + GenerateLoopHelper &loopHelperParam); + + /// Set correct operand with loop args for the operation + void setOperationCorrectOperand( + Operation *op, const DenseMap &opPermuationMap, + GenerateLoopHelper &loopHelperParam); + + /// Get current anchor return retults + void getResultInCurrentOps(const size_t anchorIdx, const size_t groupId, + const std::queue &ops, + SmallVector &results, + DenseMap &nextAnchorResultsIdxMap, + DenseMap &forResultOrignalResultMap); + /// Get next anchor's iteration loop args + void getInitArgsToNextAnchor(llvm::DenseMap &nextAnchorArgsIdxMap, + llvm::SmallVector &nextAnchorArgs, + GenerateLoopHelper &loopHelperParam); + /// Get operation should appear in current loop anchor + void getOperationInCurrentAnchor(const size_t anchorIdx, + std::queue &fromQueue, + std::queue &toQueue); + /// Get current loop operation result + void generateLoopResults(OpBuilder &b, const Location &loc, + GenerateLoopHelper &loopHelperParam, + DenseMap &nextOperandIdxMap); + + /// Move post operations in current operation group to the for loop body + void movePostOpToCurrentAnchor(OpBuilder &b, + GenerateLoopHelper &loopHelperParam); + + /// Move previous operations in current operation group to the for loop body + void movePreOpToCurrentAnchor(OpBuilder &b, + DenseMap &nextLoopStateIdxMap, + SmallVector &nextAnchorArgs, + GenerateLoopHelper &loopHelperParam); + + /// replace moved operation result used by current post operations with for + /// loop result + void replaceOperationsWithForLoopResult( + IRRewriter &rewrite, const std::queue &movingOperations, + GenerateLoopHelper &loopHelperParam); + + /// rectify indice for transfer_write operation + /// e.g.: vector.transfer_write"(%16, %9, %c0, %c0), the first %c0 should use + /// original indice not create by us + void rectifyWriteOperationIndice(vector::TransferWriteOp *originalWriteOp, + SmallVectorImpl &writeVars); + /// rectify indice for transfer_read operation, like broadcast operation + /// fusion by transfer_read , but the transfer_read operation is in innermost + /// for loop body, we must set correct for loop var. 
e.g.: + /// vector.transfer_read"(%16, %9, %c0), the first %c0 should use correct for + /// innermost loop iter vars + void rectifyReadOperationIndice(vector::TransferReadOp *originalReadOp, + VectorType loopType, + ArrayRef inductionVars, + SmallVectorImpl &readVars); + + /// rectify each group operand use for loop result + void rectifyGroupOperands(size_t currentGroupId, Value originalResult, + Value forResult); +}; + +class LoopGeneratorImpl : public ForLoopGenerator { + +private: + SmallVector multiRdCanonicalizers; + SmallVector broadcastCanonicalizers; + SmallVector transposeCanonicalizers; + SmallVector shapeCastCanonicalizers; + +public: + LoopGeneratorImpl(GroupOperationFusion &fusion, IRRewriter *rewriter) + : ForLoopGenerator(fusion, rewriter){}; + + virtual ~LoopGeneratorImpl() noexcept {}; + + SmallVector & + getMultiRdCanonicalizers() noexcept { + return multiRdCanonicalizers; + } + + SmallVector & + getBroadcastCanonicalizers() noexcept { + return broadcastCanonicalizers; + } + + SmallVector & + getTransposeCanonicalizers() noexcept { + return transposeCanonicalizers; + } + + SmallVector & + getShapeCastCanonicalizers() noexcept { + return shapeCastCanonicalizers; + } + /// clear special operation canonicalizer container + void clearSpecialOperationCanonicalizers(); + + /// add a dummy special canonicalizer + void dummyInitSpecialOperation(size_t steps); + + /// initialize all the speical operation canonicalizer + void initSpeicalOperationCanonicalizers(); + + /// generate for loop for current special operation use \param generateFunc + template + void processSpecialOperation( + T &canonicalizers, const std::function &generateFunc); + // Canonicalize special operation + void canonicalizeSpecialOperation(); + + /// whether \param grpIdx operation group has special operation + bool isGroupHasSpecialOperation(const size_t grpIdx); + + // multireduction forloop methods + scf::ForOp generateMultiReductionForLoop(const size_t grpIdx); + + /// reduction operation reduction axis for loop + scf::ForOp reductionAxisGenerateForLoop(OpBuilder &opBuilder, + const size_t reductionIdx, + GenerateLoopHelper &loopHelperParam); + void rectifyParallelIndice(GenerateLoopHelper &loopHelperParam, Location loc); + /// reduction operation parallel axis for loop + scf::ForOp parallelAxisGenerateForLoop(OpBuilder &opBuilder, + GenerateLoopHelper &loopHelperParam); + /// ensure accumulate operation appear in parallel loop, inorder to have + /// correct reduce fusion + void ensureAccInParallelLoop(GenerateLoopHelper &loopHelperParam, + ArrayRef parallelAxis, + Value multiReductionAcc, + DenseMap &nextAnchorArgsIdxMap, + SmallVector &nextAnchorArgs); + + /// Rearrange the current opIR to facilitate the generation of the correct + /// reduction IR + void rearrageMultiReductionIR( + const size_t grpIdx, + DenseMap> &indiceLoopMap); + + /// generate for loop for transpose operation + scf::ForOp generateTransposeForLoop(const size_t grpIdx); + /// shuffle instruction optimize for transpose operation + scf::ForOp generateTransposeForLoopWithLastDim( + OpBuilder &opBuilder, const int tpSteps, const Location &loc, + Operation *successorWriteOp, GenerateLoopHelper &loopHelperParam); + + /// generate transpose operation for loop of simple data movement + scf::ForOp + generateTransposeScalarDataMovement(OpBuilder &opBuilder, const Location &loc, + DenseMap &tpAxisMap, + GenerateLoopHelper &loopHelperParam); + + /// generate shapecast operation for loop + scf::ForOp generateShapeCastForLoop(const size_t grpIdx); + 
/// generate simple data movement for loop + scf::ForOp generateShapeCastReadWriteLoop( + OpBuilder &opBuilder, const size_t grpIdx, const size_t forDimIdx, + const size_t steps, const Location &loc, + SmallVector &inductionVars, ValueRange iterArgs); + + /// vectorize operations in current operation group + void generateGroupOpVectorizedIR(const int idx); +}; + +/// group operation fusion implementation class +class GroupOperationFusionImpl : public GroupOperationAnalysis { +private: + /// In which tensor is the result of the source operation stored, and the + /// result of transfer_write. + DenseMap> srcOpCanoniclizedMap; + /// have visited operations + DenseMap visitedOperation; + +public: + virtual ~GroupOperationFusionImpl() = default; + GroupOperationFusionImpl(func::FuncOp &func, HardWareInfo &info, + IRRewriter *rewriter) + : GroupOperationAnalysis(func, info, rewriter) {} + + void broadcastFromElements(Operation *op, size_t grpIdx); + void scalarOperandFromElements(); + + /// Generate emtpy tensor and write operations for operations that need to + /// return their results, and generate read operations for operations that + /// need to read parameters from the block. + void canonicalizeEachOperationGroup(); + + void specialOperationRectify(DenseMap &visitedOperation); + /// update operation result kind + void updateReturnResultKind(Operation *sourceOp, size_t sourceOpGid, + ReturnTypeKind rtKind); + + /// process the operation which need to return result + /// \param *op current operation + void GroupOperationReturnResultProcess(size_t sourceOpGid, + Operation *sourceOp, Operation *op, + size_t operandIdx, + bool inSameGroupNeedReturn); + /// source operation write it's result to a tensor + void makeSourceOpWriteResultToTensor(Operation *sourceOp, size_t sourceOpGid, + ReturnTypeKind rtKind); + /// analysis constant operation and replace it with a new constant operation + void replaceConstantOpAsNewOp(Operation *op, Operation *sourceOp, + size_t operandIdx); + /// replace \param op in \param grpIdx operation group with \param replacedOp + void removeOpInCurrentGroups(size_t grpIdx, Operation *op, + Operation *replacedOp); + /// update operation in grpIdx group related information + void updateOpGroupInfo(size_t grpIdx); + /// make a transfer_read operation and read the producer operation result + Value + canonicalizeCurrentOperation(Operation *op, const Value &transferReadOperand, + size_t operandIdx, + vector::TransferReadOp *srcReadOp = nullptr); + /// update \param opGid operation group + void updateOpOperandResultInGroups(size_t opGid, Operation *op, + const Value &init = Value(), + const Value &result = Value()); + + /// make emtpy tensor and write the operation result to the tensor + void generateEmptyTensorAndWrite( + Operation *sourceOp, + llvm::DenseMap> + &srcOpCanoniclizedMap, + size_t anchorPos, ReturnTypeKind retKind, + DenseMap &visitedOperation); + + /// make a transfer_read operation + Operation * + createTransferReadOpBefore(Operation *op, const Value &operand, + vector::TransferReadOp *srcReadOp = nullptr); +}; +/// Vectorize vector operation with target machines max simd length. 
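+/// Example usage (illustrative sketch; assumes `func` is the func::FuncOp
+/// being lowered and a 512-bit SIMD width, both of which are assumptions):
+/// ```
+///   HardWareInfo info;
+///   info.vectorWidth = 512;
+///   IRRewriter rewriter(func.getContext());
+///   VectorOperationCanonicalizer canonicalizer(func, info, &rewriter);
+///   canonicalizer.run();
+/// ```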
+class VectorOperationCanonicalizer { +private: + GroupOperationFusionImpl fusion; + LoopGeneratorImpl loopGenerator; + CanonicalizerKind kind; + func::FuncOp func; + IRRewriter *rewriter; + +public: + VectorOperationCanonicalizer( + func::FuncOp &func, HardWareInfo &info, IRRewriter *rewriter, + CanonicalizerKind kind = CanonicalizerKind::GroupOperations) + : fusion(func, info, rewriter), + loopGenerator(fusion.getGroupOperationFusion(), rewriter), kind(kind), + rewriter(rewriter) {} + virtual ~VectorOperationCanonicalizer() = default; + /// run the vector canonicalizer for the IR + void run(); + /// get current funtion rewriter + IRRewriter *getRewriter() noexcept { return rewriter; } +}; +} // namespace gc +} // namespace mlir +#endif \ No newline at end of file diff --git a/include/gc/Transforms/Utils/NumericUtils.h b/include/gc/Transforms/Utils/NumericUtils.h new file mode 100644 index 000000000..f47d9dace --- /dev/null +++ b/include/gc/Transforms/Utils/NumericUtils.h @@ -0,0 +1,34 @@ +//===-- NumericUtils.h - numeric utilities ----------------------*- C++ -*-===// +// +// This file is licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef GC_TRANSFORMS_UTILS_NUMERICUTILS_H +#define GC_TRANSFORMS_UTILS_NUMERICUTILS_H +#include "mlir/IR/TypeUtilities.h" +#include "mlir/IR/Types.h" +#include +#include +#include + +namespace mlir { +namespace gc { + +union Float32Bits { + uint32_t u; + float f; +}; +uint16_t float2half(float floatValue); +float half2float(uint16_t halfValue); +uint16_t float2bfloat(float floatValue); +float bfloat2float(uint16_t bfloatBits); +std::variant numeric_limits_minimum(Type type); +std::variant numericLimitsMaximum(Type type); + +} // namespace gc +} // namespace mlir + +#endif \ No newline at end of file diff --git a/include/gc/Transforms/Utils/VectorUtils.h b/include/gc/Transforms/Utils/VectorUtils.h new file mode 100644 index 000000000..89341bb82 --- /dev/null +++ b/include/gc/Transforms/Utils/VectorUtils.h @@ -0,0 +1,186 @@ +//===-- VectorUtils.h - vector utilities ------------------------*- C++ -*-===// +// +// This file is licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef GC_TRANSFORMS_UTILS_VECTORUTILS_H +#define GC_TRANSFORMS_UTILS_VECTORUTILS_H +#include "gc/Transforms/Utils/NumericUtils.h" +#include "mlir/Dialect/Affine/IR/AffineOps.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/Dialect/Vector/IR/VectorOps.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/TypeUtilities.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "llvm/ADT/TypeSwitch.h" +#include "llvm/Support/Debug.h" +#include +#include +#include +#include + +namespace mlir { +namespace gc { + +enum class OPPRIORITY : uint8_t { + FIRST = 0, + SECOND, + THIRD, + LAST, + OTHERS = 255, +}; +/// Need to move some operations like extract_slice or insert_slice. +/// Because those operation may interpret our analysis result. 
e.g.: +/// ``` +/// clang-format off +/// %21 = vector.transfer_read %18[%c0, %c0], %cst {in_bounds = [true, true]} : +/// tensor<16x16xf32>, vector<16x16xf32> %22 = arith.addf %21, %20 : +/// vector<16x16xf32> %23 = vector.transfer_write %22, %extracted_slice_12[%c0, +/// %c0] {in_bounds = [true, true]} : vector<16x16xf32>, tensor<16x16xf32> +/// %inserted_slice_13 = tensor.insert_slice %18 into %arg14[%arg13, 0] [16, 16] +/// [1, 1] : tensor<16x16xf32> into tensor<32x16xf32> %extracted_slice_14 = +/// tensor.extract_slice %arg16[%arg13, 0] [16, 16] [1, 1] : tensor<32x16xf32> +/// to tensor<16x16xf32> %24 = vector.transfer_read %cst_0[%c0, %c0], %cst +/// {in_bounds = [true, true]} : tensor<16x16xf32>, vector<16x16xf32> %25 = +/// arith.maximumf %22, %24 : vector<16x16xf32> %26 = vector.transfer_write %25, +/// %extracted_slice_14[%c0, %c0] {in_bounds = [true, true]} : +/// vector<16x16xf32>, tensor<16x16xf32> %inserted_slice_15 = +/// tensor.insert_slice %23 into %arg15[%arg13, 0] [16, 16] [1, 1] : +/// tensor<16x16xf32> into tensor<32x16xf32> %inserted_slice_16 = +/// tensor.insert_slice %26 into %arg16[%arg13, 0] [16, 16] [1, 1] : +/// tensor<16x16xf32> into tensor<32x16xf32> clang-format on +/// ``` +/// The maximumf and addf operation can be a same group, but the extract_slice +/// operation interpret us. +/// The move operation(extra_slice) will check its parameters. In order to +/// ensure that it does not affect the correctness of the result, we will only +/// move the moved op after the op to which the parameters belong to. If it's +/// operand is all the block argument, we will move it to the begining of the +/// block. +/// insert_slice just move them to the privious of the first operation which +/// use it. +void moveOpsFrontOrBack(func::FuncOp *func, IRRewriter &rewriter, + OPPRIORITY start, OPPRIORITY end); + +/// build a constant operation of index type +Value makeIndexArithConstantOp(OpBuilder &opBuilder, const Location &loc, + int64_t x); + +/// find the original tensor +Value findOriginalTensor(Value writeTensor, Block *block); +/// get operation read or write tensor +mlir::FailureOr getOperationOperateTensor(Operation *op); + +/// set correct operand for the operation +void setOperationCorrectOperand( + Operation *op, ValueRange iterArgs, DenseMap &operandIdxMap, + DenseMap &originalOperandLoopArgsMap, + ArrayRef inductionVars, + DenseMap &opPermuationMap); + +/// Get vector type of the operation \param op +/// \param isPrevOp whether the operation is a previous operation, if it is not +/// prev-op, may need to use result vectortype +/// default will return the opeation result type +mlir::FailureOr getOperationVectorType(Operation *op, + bool isPrevOp = true); + +/// select nearest even step +int getNearestVectorStep(const int step); + +/// get operation vector type +/// \param isPrevOp whether the operation is a previous operation, if it is not +/// prev-op, may need to use result vectortype +/// default will return the opeation result type +mlir::FailureOr getOperationMaxVectorType(Operation *op); + +template +T getInitValForReduce(vector::CombiningKind kind, Type t) { + T result; + Type t1 = getElementTypeOrSelf(t); + + switch (kind) { + case vector::CombiningKind::ADD: + if (t1.isIntOrIndex()) + result = 0; + else if (isa(t1)) + result = 0.0f; + else + llvm_unreachable("invalid value types for ADD reduction"); + break; + case vector::CombiningKind::MAXNUMF: + case vector::CombiningKind::MAXIMUMF: + if (not isa(t1)) + llvm_unreachable("Expected float values."); + 
result = std::get(numeric_limits_minimum(t)); + break; + case vector::CombiningKind::MINNUMF: + case vector::CombiningKind::MINIMUMF: + if (not isa(t1)) + llvm_unreachable("Expected float values."); + result = std::get(numericLimitsMaximum(t)); + break; + case vector::CombiningKind::MAXSI: + case vector::CombiningKind::MAXUI: + if (not t1.isIntOrIndex()) + llvm_unreachable("Expected int or index values."); + result = std::get(numeric_limits_minimum(t)); + break; + case vector::CombiningKind::MINSI: + case vector::CombiningKind::MINUI: + if (not t1.isIntOrIndex()) + llvm_unreachable("Expected int or index values."); + result = std::get(numericLimitsMaximum(t)); + break; + case vector::CombiningKind::MUL: + if (t1.isIntOrIndex()) + result = 1; + else if (isa(t1)) + result = 1.f; + else + llvm_unreachable("invalid value types for MUL reduction"); + break; + default: + llvm_unreachable("unsupported reduction kind"); + }; + return result; +} + +template +void getSameBlockTargetOp(Operation *op, + std::queue &candidateOps) { + if (isa(op)) { + candidateOps.push(op); + return; + } + auto getSameBlockSrcOp = [](Operation *trackSrcOp, + std::queue &trackOps, + std::queue &candidateOps) { + for (Value opd : trackSrcOp->getOperands()) { + if (isa(opd) or + opd.getDefiningOp()->getBlock() != trackSrcOp->getBlock()) + continue; + if (isa(opd.getDefiningOp())) + candidateOps.push(opd.getDefiningOp()); + else + trackOps.push(opd.getDefiningOp()); + } + }; + + std::queue trackOps; + getSameBlockSrcOp(op, trackOps, candidateOps); + while (not trackOps.empty()) { + Operation *cadidateOp = trackOps.front(); + trackOps.pop(); + getSameBlockSrcOp(cadidateOp, trackOps, candidateOps); + } +} + +} // namespace gc +} // namespace mlir + +#endif \ No newline at end of file diff --git a/lib/gc/Analysis/CMakeLists.txt b/lib/gc/Analysis/CMakeLists.txt index d7160f350..b11eb3bd4 100644 --- a/lib/gc/Analysis/CMakeLists.txt +++ b/lib/gc/Analysis/CMakeLists.txt @@ -5,6 +5,7 @@ gc_set_mlir_link_components(MLIR_LINK_COMPONENTS gc_add_mlir_library(GcAnalysis TargetDescriptionAnalysis.cpp MatmulConfigAnalysis.cpp + VectorBasedFusionAnalysis.cpp DEPENDS GraphCompilerPassIncGen diff --git a/lib/gc/Analysis/VectorBasedFusionAnalysis.cpp b/lib/gc/Analysis/VectorBasedFusionAnalysis.cpp new file mode 100644 index 000000000..23ff5faa5 --- /dev/null +++ b/lib/gc/Analysis/VectorBasedFusionAnalysis.cpp @@ -0,0 +1,598 @@ +//===- VectorBasedFusionAnalysis.cpp - analysis vector ops ------*- C++ -*-===// +// +// This file is licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +#include "gc/Analysis/VectorBasedFusionAnalysis.h" + +namespace mlir { +namespace gc { + +#define DEBUG_TYPE "vector-operation-analysis" +#define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ") +#define SAFE_EXPAND(X) X +#define LDBG(X) LLVM_DEBUG(DBGS() << SAFE_EXPAND(X) << "\n") + +#define ARITH_CAST_OPERATIONS \ + arith::ExtFOp, arith::ExtSIOp, arith::ExtUIOp, arith::BitcastOp, \ + arith::FPToSIOp, arith::FPToUIOp, arith::SIToFPOp, arith::UIToFPOp, \ + arith::TruncFOp, arith::TruncIOp + +#define NOT_NEED_TO_PROCESS_OP \ + linalg::BatchReduceMatmulOp, linalg::MatmulOp, linalg::BatchMatmulOp, \ + linalg::BatchMatmulTransposeAOp, linalg::BatchMatmulTransposeBOp, \ + linalg::MatmulTransposeAOp, linalg::MatmulTransposeBOp, \ + linalg::QuantizedBatchMatmulOp, linalg::QuantizedMatmulOp, \ + tensor::CollapseShapeOp, tensor::ExpandShapeOp, tensor::ExtractSliceOp, \ + tensor::InsertSliceOp, microkernel::BrgemmOp + +static inline bool isNotNeedToProcessOp(Operation *op) { + return isa(op) or linalgx::isMatmulOp(op); +} + +static inline bool isSpecialOp(Operation *op) { + return isa( + op); +} + +static inline bool isReadOrWriteOperation(Operation *op) { + return isa(op); +} + +/// which axis do the shape cast in source shape a +void shapeCastSourceAxis(ArrayRef a, ArrayRef b, + SmallVector &res) { + unsigned rankA = a.size(); + unsigned rankB = b.size(); + if (rankA >= rankB) + llvm_unreachable("May be invalid shape cast operation."); + + auto isOne = [](int64_t v) { return v == 1; }; + + // Special-case for n-D to 0-d shape cast. 'b' must be all ones to be shape + // casted to a 0-d vector. + if (rankA == 0 && all_of(b, isOne)) { + for (size_t i = 0; i < a.size(); i++) + res.emplace_back(i); + return; + } + + unsigned i = 0; + unsigned j = 0; + while (i < rankA && j < rankB) { + int64_t dimA = a[i]; + int64_t dimB = 1; + int64_t bAxisBegin = j; + while (dimB < dimA && j < rankB) + dimB *= b[j++]; + if (dimA != dimB) { + llvm_unreachable(" Invalid shape cast operation."); + break; + } + if (bAxisBegin != j) { + res.emplace_back(i); + } + ++i; + + // Handle the case when trailing dimensions are of size 1. + // Include them into the contiguous sequence. 
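+    // e.g. casting 32 <-> 8x4x1x1: once 8*4 has matched 32, the remaining
+    // size-1 trailing dims are consumed here.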
+ if (i < rankA && all_of(a.slice(i), isOne)) + i = rankA; + if (j < rankB && all_of(b.slice(j), isOne)) + j = rankB; + } + if (i != rankA or j != rankB) + llvm_unreachable("Invalid shapecast operation."); +} + +bool isScalar(Type type) { + if (not type) + llvm_unreachable("Not a valid type"); + if (auto vecType = dyn_cast(type)) + return false; + if (auto tensorType = dyn_cast(type)) + return false; + return true; +} + +void getSrcBroadcastDim(const ShapedType &input, const ShapedType &output, + SmallVector &bcAxis) { + auto inputShape = input.getShape(); + auto outputShape = output.getShape(); + // following auto_broadcast semantics + const size_t input_rank = inputShape.size(); + const size_t output_rank = outputShape.size(); + if (output_rank < input_rank) + llvm_unreachable("Incorrect input or output shape for broadcast op."); + const size_t offset = output_rank - input_rank; + for (size_t i = 0; i < input_rank; ++i) { + if (inputShape[i] == outputShape[i + offset] || + (ShapedType::isDynamic(inputShape[i]) && + ShapedType::isDynamic(outputShape[i + offset]))) { + bcAxis.emplace_back(i); + } + } + if (bcAxis.empty()) + bcAxis.emplace_back(-1); +} + +void getOperationDataAxis(Operation *op, SmallVector &dataAxis) { + return TypeSwitch(op) + .Case( + [&](vector::MultiDimReductionOp multiReductionOp) { + auto rdDimsRange = multiReductionOp.getReductionDims(); + dataAxis.assign(rdDimsRange.begin(), rdDimsRange.end()); + return; + }) + .Case([&](vector::ShapeCastOp shapeCastOp) { + auto srcType = shapeCastOp.getSourceVectorType(); + auto dstType = shapeCastOp.getResultVectorType(); + auto srcShape = srcType.getShape(); + auto dstShape = dstType.getShape(); + if (srcShape.size() < dstShape.size()) + shapeCastSourceAxis(srcShape, dstShape, dataAxis); + else + shapeCastSourceAxis(dstShape, srcShape, dataAxis); + + return; + }) + .Case([&](vector::BroadcastOp broadcastOp) { + auto srcType = broadcastOp.getSourceType(); + auto dstType = broadcastOp.getResultVectorType(); + if (isScalar(srcType)) { + dataAxis.emplace_back(0); + } else { + auto inputType = mlir::cast(srcType); + auto outputType = mlir::cast(dstType); + getSrcBroadcastDim(inputType, outputType, dataAxis); + } + return; + }) + .Case([&](vector::TransposeOp transposeOp) { + auto perm = transposeOp.getPermutation(); + int start = 0; + for (auto x : perm) { + if (x != start) { + dataAxis.emplace_back(x); + } + start++; + } + return; + }) + .Default([&](Operation *op) { + // default is last axis + dataAxis.emplace_back( + cast(op->getResultTypes()[0]).getRank() - 1); + return; + }); +} + +static inline bool hasSameAxis(ArrayRef dims1, + ArrayRef dims2) { + DenseSet checkSet(dims2.begin(), dims2.end()); + return llvm::any_of(dims1, + [&checkSet](int64_t x) { return checkSet.contains(x); }); +} + +/// whether op2 use op1 result +/// Currently we just enable this function for write and read operation +template || + std::is_same_v, + T>> +static bool isOperationsHasDefUseRelation(Operation *op1, Operation *op2) { + return llvm::any_of(op2->getOperands(), + [&op1](Value opd) { return opd.getDefiningOp() == op1; }); +} + +/// whether two operation has data dependency +/// op1 default is previous operation, op2 default is current operation +bool hasDataDependency(Operation *op1, Operation *op2) { + if (!isSpecialOp(op1) and !isSpecialOp(op2)) + return false; + + if (isReadOrWriteOperation(op1) or isReadOrWriteOperation(op2)) { + // if op1 is read the value and pass it to op2, it is not data dependency + if (isOperationsHasDefUseRelation(op1, 
op2)) + return false; + } + + // broadcast only fuse with post-op + if (isa(op2)) + return true; + + // only special operation may cause data dependency + if (!isSpecialOp(op1)) + return hasDataDependency(op2, op1); + + auto res = + TypeSwitch(op1) + .Case([&](vector::ShapeCastOp shapeCastOp) { + SmallVector dims1, dims2; + getOperationDataAxis(op1, dims1); + getOperationDataAxis(op2, dims2); + if (!isSpecialOp(op2)) + return hasSameAxis(dims1, dims2); + + return true; + }) + .Case( + [&](vector::MultiDimReductionOp multiReductionOp) { + SmallVector dims2, reductionDims, parallelDims; + getOperationDataAxis(op1, reductionDims); + getOperationDataAxis(op2, dims2); + DenseSet checkSet(dims2.begin(), dims2.end()); + auto op2VectorType = getOperationVectorType(op2); + if (!isSpecialOp(op2)) { + // all reduction axis should be op2's data axis + bool reduceDependent = false; + for (auto x : reductionDims) { + if (!checkSet.contains(x)) { + reduceDependent = true; + break; + } + } + if (!reduceDependent) + return false; + + // all parallel axis should equal to op2's axis + checkSet.clear(); + checkSet.insert(reductionDims.begin(), reductionDims.end()); + auto rdRank = + multiReductionOp.getSourceVectorType().getRank(); + for (auto i = 0; i < rdRank; i++) + if (not checkSet.contains(i)) + parallelDims.emplace_back(i); + + checkSet.clear(); + checkSet.insert(parallelDims.begin(), parallelDims.end()); + auto rank = op2VectorType->getRank(); + for (auto i = 0; i < rank; i++) + if (!checkSet.contains(i)) + return true; + + return false; + } + + return true; + }) + .Case([&](vector::BroadcastOp broadcastOp) { + if (isSpecialOp(op2)) + return true; + + return !OpTrait::util::staticallyKnownBroadcastable( + getOperationVectorType(op1, false)->getShape(), + getOperationVectorType(op2)->getShape()); + }) + .Case( + [&](vector::TransposeOp transposeOp) { return true; }) + .Default([&](Operation *op) { return false; }); + + return res; +} + +/// Get the operation which is not a read-write in current queue +/// \param [in, out] op +Operation *getNotReadWriteOperaiton(std::queue &tmpQ) { + Operation *op = nullptr; + while (!tmpQ.empty()) { + Operation *cur = tmpQ.front(); + tmpQ.pop(); + if (isReadOrWriteOperation(cur)) + continue; + + op = cur; + } + return op; +} + +/// operation should not contain for loop +bool is_innermost_operation(Operation *op) { + bool inner_most = true; + op->walk([&inner_most](Operation *p) { + if (isa(p)) { + inner_most = false; + return WalkResult::interrupt(); + } + return WalkResult::advance(); + }); + return inner_most; +} + +/// whether operate on last dimension +bool isLastDim(const AffineExpr &expr, const size_t rank) { + return isa(expr) && + dyn_cast(expr).getPosition() == rank - 1; +} + +bool isReadWriteOnLastDim(Operation *op) { + if (isReadOrWriteOperation(op)) { + AffineMap permutationMap = + dyn_cast(op) + ? cast(op).getPermutationMap() + : cast(op).getPermutationMap(); + int64_t rank = + dyn_cast(op) + ? cast(op->getOperand(0).getType()).getRank() + : cast(op->getOperand(1).getType()).getRank(); + ArrayRef dimExpr = permutationMap.getResults(); + bool find = false; + for (const auto &expr : dimExpr) + if (isLastDim(expr, rank)) { + find = true; + break; + } + + return find; + } + llvm::llvm_unreachable_internal( + "The operation is not a read or write operation."); + return false; +} + +// Filter out the operations that can be vectorized. We are only interested in +// operations that do not contain any for loops(innermost IR). 
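+// Constant ops, ops without a vector result type, and transfer_read /
+// transfer_write ops that do not access the last dimension are rejected as
+// well.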
+[[nodiscard]] bool filterOperation(Operation *op) { + if (!is_innermost_operation(op)) { + return false; + } + + // We are only interested about the operation in vector dialect + if (failed(getOperationVectorType(op))) { + return false; + } + + // We don't need to vectorize the constant operation + if (isa(op)) { + return false; + } + + if (isReadOrWriteOperation(op) and !isReadWriteOnLastDim(op)) { + return false; + } + + return true; +} + +VectorType TypeHelper::getVectorzedType(Operation *op, uint32_t loopStep) { + // Check that the operation type can be broken + // down into a loop. + mlir::FailureOr baseType = getOperationVectorType(op); + if (failed(baseType)) { + llvm_unreachable("Failed to get vector type for operation"); + return VectorType(); + } + auto vectorizedType = baseType.value(); + if (loopStep == 0) + loopStep = getDataTypeValidSteps(vectorizedType); + + return VectorType::get({loopStep}, vectorizedType.getElementType()); +} + +int TypeHelper::generateValidSteps(int steps, VectorType type, int shapeDim) { + if (shapeDim & 1) + return 1; + auto typebits = type.getElementTypeBitWidth(); + if (shapeDim >= steps) + return steps * typebits >= 128 ? steps : 1; + int evenStep = getNearestVectorStep(shapeDim); + return evenStep * typebits >= 128 ? evenStep : 1; +} + +int TypeHelper::generateValidSteps(int steps, VectorType type) { + // TODO: support odd shape using mask load store + if (type.getShape().back() & 1) + return 1; + auto typebits = type.getElementTypeBitWidth(); + if (type.getShape().back() >= steps) + return steps * typebits >= 128 ? steps : 1; + int evenStep = getNearestVectorStep(type.getShape().back()); + return evenStep * typebits >= 128 ? evenStep : 1; +} + +// Get the maximum number of current data types that a register can hold +[[nodiscard]] int TypeHelper::getDataTypeMAXSIMDLength(VectorType type) { + auto typebits = type.getElementTypeBitWidth(); + return info.vectorWidth / typebits; +} + +/// Get a appropriate for loop step for current vector type +[[nodiscard]] int TypeHelper::getDataTypeValidSteps(VectorType type) { + return generateValidSteps(getDataTypeMAXSIMDLength(type), type); +} + +/// default op1 is previous operation +bool GroupOperationFusion::isCompatibleVectorType(Operation *op1, + Operation *op2) { + // only lower to vector pass can produce read operation. 
In general two read + // operation is compatible + if (isa(op1) and isa(op2)) { + return true; + } + + mlir::FailureOr type1 = getOperationVectorType(op1, true); + mlir::FailureOr type2 = getOperationVectorType(op2, false); + // some operation has two different operands type like multireduction, we need + // to check whether compitable with accumulate vector + VectorType suppleType; + if (failed(type1) || failed(type2)) + return false; + + auto sp1 = type1.value(); + auto sp2 = type2.value(); + + auto isCompatible = [](VectorType sp1, VectorType sp2) { + bool isCompatible = true; + auto min_rank = std::min(sp1.getRank(), sp2.getRank()); + // from front to back + for (long i = 0; i < min_rank; i++) { + if (sp1.getDimSize(i) != sp2.getDimSize(i)) { + isCompatible = false; + break; + } + } + return isCompatible; + }; + + bool result; + result = isCompatible(sp1, sp2); + // operand check only happen on later operation is op2 + // TODO: may need to support other similar operation like multireduction has + // two different operands type + if (isa(op2)) { + suppleType = cast(op2->getOperandTypes()[1]); + result |= isCompatible(suppleType, sp1); + } + + return result; +} + +void GroupOperationFusion::updateGroupBigestVectorType(VectorType vectorType) { + int64_t rank = vectorType.getRank(); + llvm::SmallDenseMap &groupVectorType = + getGroupBiggestRankVectorType(); + + if (groupVectorType.contains(opGroups.size() - 1)) { + VectorType bigestType = groupVectorType[opGroups.size() - 1]; + if (bigestType.getRank() < rank) + groupVectorType[opGroups.size() - 1] = vectorType; + + return; + } + + groupVectorType[opGroups.size() - 1] = vectorType; +} + +void GroupOperationFusion::addOperationToGroup(Operation *op) { + if (not op) + llvm_unreachable("Op can't be NULL."); + VectorType vectorType = getOperationMaxVectorType(op).value(); + if (isNeedNewGroup(op)) + opGroups.emplace_back(std::queue()); + + if (not isa(op)) { + updateGroupBigestVectorType(vectorType); + while (not notNeedToJudgeOps.empty()) { + auto cur = notNeedToJudgeOps.front(); + notNeedToJudgeOps.pop(); + opGroupIndexMap[cur] = opGroups.size() - 1; + opGroups.back().push(cur); + } + opGroups.back().push(op); + opGroupIndexMap[op] = opGroups.size() - 1; + } + opAnchorPos[op] = getOperationMaxVectorType(op)->getRank() - 1; +} + +// We classify the operations we are interested in after filtering. Operations +// of in the same group have no data dependencies. Those operations can generate +// a same outter for loop. +void GroupOperationFusion::classifyOperations() { + // dummpy + if (opGroups.empty()) + opGroups.emplace_back(std::queue()); + + func::FuncOp func = getFunction(); + + func->walk([&](Operation *op) { + if (filterOperation(op)) { + addOperationToGroup(op); + return WalkResult::advance(); + } + if (isNotNeedToProcessOp(op) and !opGroups.back().empty()) + opGroups.emplace_back(std::queue()); + + return WalkResult::advance(); + }); + // init operations results and initialization args + groupOpResults.clear(); + groupOpInitArgs.clear(); + for (size_t i = 0; i < opGroups.size(); i++) { + groupOpResults.emplace_back( + llvm::MapVector>()); + groupOpInitArgs.emplace_back(SetVector()); + } +} + +void GroupOperationFusion::run() { classifyOperations(); } + +bool GroupOperationFusion::isNeedNewGroup(Operation *op) { + if (isa(op)) { + notNeedToJudgeOps.push(op); + return false; + } + // 1. check previous operation + if (!opGroups.back().empty()) { + // We only care about the calculation operation. 
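+    // getNotReadWriteOperaiton skips transfer_read/transfer_write ops and
+    // returns the last remaining compute op in the group (or nullptr).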
+ std::queue tmpQ(opGroups.back()); + Operation *prevOp = nullptr; + prevOp = getNotReadWriteOperaiton(tmpQ); + if (!prevOp) { + // if previous operation is not in the same block, we need to create a + // group + return opGroups.back().back()->getParentOp() != op->getParentOp() or + isSpecialOp(op); + } + + if (prevOp->getParentOp() != op->getParentOp()) + return true; + + // special operation need to check data dependency axis + if (hasDataDependency(prevOp, op)) + return true; + + // previous operation vector type is not compatible with current operation + if (!isCompatibleVectorType(prevOp, op)) + return true; + } + return false; +} + +void GroupOperationAnalysis::analysisEmptyGroup() { + SmallVector, 8> &opGroups = + fusionStrategy.getOpGroups(); + SmallVector>, 8> + &groupOpResults = fusionStrategy.getGroupOpResults(); + for (auto [idx, grp] : llvm::enumerate(opGroups)) { + if (grp.empty()) + continue; + if (groupOpResults[idx].empty()) + std::queue().swap(grp); + } +} + +void GroupOperationAnalysis::analysisGroupMaxSteps() { + auto &opGroups = fusionStrategy.getOpGroups(); + + for (auto [idx, grp] : llvm::enumerate(opGroups)) { + + uint32_t steps = std::numeric_limits::max(); + + llvm::SmallVector &grpSteps = + fusionStrategy.getGroupMaxSteps(); + while (idx + 1 > grpSteps.size()) + grpSteps.emplace_back(steps); + + std::queue tmpQueue(grp); + auto calculateOpSteps = [&](Type type) { + auto opType = dyn_cast(type); + if (opType) + steps = std::min(steps, (uint32_t)fusionStrategy.getTypeHelper() + .getDataTypeValidSteps(opType)); + }; + while (!tmpQueue.empty()) { + auto op = tmpQueue.front(); + tmpQueue.pop(); + if (isa(op)) + calculateOpSteps(op->getOperandTypes()[0]); + + calculateOpSteps(getOperationVectorType(op).value()); + } + grpSteps[idx] = steps; + } +} +} // namespace gc +} // namespace mlir \ No newline at end of file diff --git a/lib/gc/Transforms/CMakeLists.txt b/lib/gc/Transforms/CMakeLists.txt index 2d10ed88f..863dae5cd 100644 --- a/lib/gc/Transforms/CMakeLists.txt +++ b/lib/gc/Transforms/CMakeLists.txt @@ -26,6 +26,7 @@ gc_add_mlir_library(GcPasses MergeAllocTickBased.cpp FoldTensorOperation.cpp LowerToTileVector.cpp + CPUPhysicalRegisterPass.cpp DEPENDS GraphCompilerPassIncGen diff --git a/lib/gc/Transforms/CPUPhysicalRegisterPass.cpp b/lib/gc/Transforms/CPUPhysicalRegisterPass.cpp new file mode 100644 index 000000000..b8246d17a --- /dev/null +++ b/lib/gc/Transforms/CPUPhysicalRegisterPass.cpp @@ -0,0 +1,2959 @@ +//===- CPUPhysicalRegisterPass.cpp - tiling as physical vector --*- C++ -*-===// +// +// This file is licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +#include "gc/Transforms/TilingVector.h" + +namespace mlir { +namespace gc { +#define GEN_PASS_DEF_CPUPHYSICALREGISTERPASS +#include "gc/Transforms/Passes.h.inc" +#define DEBUG_TYPE "lower-to-physical-register-pass" + +#define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ") +#define SAFE_EXPAND(X) X +#define LDBG(X) LLVM_DEBUG(DBGS() << SAFE_EXPAND(X) << "\n") + +#define ARITH_CAST_OPERATIONS \ + arith::ExtFOp, arith::ExtSIOp, arith::ExtUIOp, arith::BitcastOp, \ + arith::FPToSIOp, arith::FPToUIOp, arith::SIToFPOp, arith::UIToFPOp, \ + arith::TruncFOp, arith::TruncIOp + +/// TODO: remove it in the future +bool enableDebugPrinter = true; +bool disableSpecialOp = false; + +void printQueue(const std::queue &opQueue) { + auto tempQ(opQueue); + while (!tempQ.empty()) { + auto cur = tempQ.front(); + LDBG(*cur); + tempQ.pop(); + } +} + +void printGroupOps(SmallVector, 8> &opGroups) { + for (auto [idx, grp] : llvm::enumerate(opGroups)) { + LDBG("group id: " << idx); + if (grp.empty()) + continue; + + LDBG("__________________ group start_____________"); + printQueue(grp); + LDBG("__________________ group end_______________"); + } +} + +static inline bool isBroadcastOp(Operation *op) { + return isa_and_nonnull(op); +} + +static inline bool isReadOrWriteOperation(Operation *op) { + return isa(op); +} + +/// Get the index position of the first element that is true +static size_t getFirstTrueIndex(ArrayRef ararys) { + for (size_t i = 0; i < ararys.size(); i++) + if (!ararys[i]) + return i; + + return -1; +} + +static inline bool isSpecialOp(Operation *op) { + return isa( + op); +} + +/// whether operation is a not support operation +bool isNotSupportOperation(Operation *op) { + return isa(op); +} + +/// whether the vector operation is operate on dynamic shape +bool hasDynamicShape(Operation *op) { + if (failed(getOperationVectorType(op))) { + return false; + } + auto isDynamicShapedType = [](Value x) { + if (auto type = dyn_cast(x.getType())) + if (ShapedType::isDynamicShape(type.getShape())) + return true; + + return false; + }; + // Check operands data type. + if (llvm::any_of(op->getOperands(), [&isDynamicShapedType](Value x) { + return isDynamicShapedType(x); + })) + return true; + + // Check results data type. 
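+  // Illustrative example (hypothetical types): an operation that reads from a
+  // tensor<?x32xf32>, or whose result's shaped type has a dynamic dimension,
+  // is reported as dynamic here, so the walk in hasNotSupportOperation()
+  // below interrupts for the enclosing function.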
+ if (llvm::any_of(op->getResults(), [&isDynamicShapedType](OpResult x) { + return isDynamicShapedType(x); + })) + return true; + + return false; +} + +// TODO: Need to support these operations in the future +bool hasNotSupportOperation(func::FuncOp *func) { + auto walkRes = func->walk([](Operation *op) { + if (isNotSupportOperation(op)) { + LDBG("Operation do not support yet : " << *op << "\n"); + return WalkResult::interrupt(); + } + if (hasDynamicShape(op)) { + LDBG("Operation has dynamic shape: " << *op << "\n"); + return WalkResult::interrupt(); + } + return WalkResult::advance(); + }); + return walkRes != WalkResult::advance(); +} + +void GenerateLoopHelper::setNextAnchorArgs( + DenseMap &nextAnchorArgsIdxMap, + SmallVector &nextAnchorArgs) { + currentLoopStateIdxMap = nextAnchorArgsIdxMap; + loopIterArgs = nextAnchorArgs; +} + +void GenerateLoopHelper::clearNextAnchorResults() { + nextAnchorResults.clear(); + nextAnchorResultsIdxMap.clear(); + nextAnchorResultOrignalResultMap.clear(); +} + +void GenerateLoopHelper::setAnchorId(size_t anchorId) noexcept { + anchorIdx = anchorId; +} + +void GenerateLoopHelper::updateDataBeforePreOpMove( + ArrayRef loopState, std::queue &candidateQueue, + std::queue &movedQueue) { + loopIterArgs = loopState; + candidateOps = &candidateQueue; + movedOps = &movedQueue; +} + +void GenerateLoopHelper::updateDataAfterPreOpMove( + DenseMap &nextAnchorArgsIdxMap, + SmallVector &nextAnchorArgs) { + setNextAnchorArgs(nextAnchorArgsIdxMap, nextAnchorArgs); +} + +void GenerateLoopHelper::updateDataBeforePostOpMove( + ArrayRef iterArgs, DenseMap ¤tLoopStateIdxMap, + DenseMap ¤toriginalArgsMap, + DenseMap ¤tArgsOriginalMap, ValueRange forResults, + Block *forBlock, std::queue &movedQueue, size_t anchorId) { + this->originalOperandLoopArgsMap = currentoriginalArgsMap; + this->loopArgsOriginalOperandMap = currentArgsOriginalMap; + this->forResults = forResults; + this->forBlock = forBlock; + this->anchorIdx = anchorId; + this->currentLoopStateIdxMap = currentLoopStateIdxMap; + this->loopIterArgs = iterArgs; + this->movedOps = &movedQueue; +} + +void GenerateLoopHelper::updateDataAfterPostOpMove( + size_t anchorId, DenseMap &nextAnchorArgsIdxMap, + SmallVector &nextAnchorArgs) { + setAnchorId(anchorId); + setNextAnchorArgs(nextAnchorArgsIdxMap, nextAnchorArgs); +} + +void GenerateLoopHelper::setNextAnchorResults( + SmallVector ¤tAnchorResults, + DenseMap ¤tResultMap, + DenseMap ¤tResultIdxMap) { + nextAnchorResults = std::move(currentAnchorResults); + nextAnchorResultOrignalResultMap = std::move(currentResultMap); + nextAnchorResultsIdxMap = std::move(currentResultIdxMap); +} + +void GenerateLoopHelper::updateCurrentArgsStatus( + DenseMap ¤tArgsIdxMap, SmallVector ¤tArgs, + DenseMap &originalArgsMap, + DenseMap &argsOriginalMap) { + setNextAnchorArgs(currentArgsIdxMap, currentArgs); + originalOperandLoopArgsMap = originalArgsMap; + loopArgsOriginalOperandMap = argsOriginalMap; +} + +/// get float or integer dense attribute +/// \param [in,out] attr +template +void getConstantDenseAttr(TypedAttr &attr, VectorType type, + DenseElementsAttr denseAttr) { + using APX = std::conditional_t, + APFloat, APInt>; + attr = T::get(type, denseAttr.getSplatValue()); +} + +/// Create a new arith constant operation according to the dense element attr +FailureOr createArithSplatConstantOp(IRRewriter &rewriter, + const Location &loc, + DenseElementsAttr valueType, + VectorType newOperandType) { + if (not valueType.isSplat()) + return failure(); + + TypedAttr attr; + if 
(isa(newOperandType.getElementType())) + getConstantDenseAttr(attr, newOperandType, valueType); + else + getConstantDenseAttr(attr, newOperandType, valueType); + + return rewriter.create(loc, attr)->getResults()[0]; +} + +/// whether the operation result need to be returned +/// \param anchorIdx resuilt produce operation anchor position +/// \param retType resuilt return type +bool needReturnResult(std::pair &retType, + size_t anchorIdx) { + return retType.first != ReturnTypeKind::RT_InGroup or + retType.second < anchorIdx; +} + +// Since we rewrite transfer_read and transfer_write, the `permutationmap` must +// be changed. +void setOpVectorizationPermutationMap(Operation *op, OpBuilder &rewriter, + const ShapedType &tensorType, + const AffineMap &permutationMap) { + auto dimExpr = permutationMap.getResults(); + auto lastDim = dyn_cast(dimExpr.back()); + if (not isa(lastDim)) + llvm_unreachable("Must be AffineDimExpr."); + + SmallVector affineExprs(1, lastDim); + auto destAffineMap = AffineMap::get(tensorType.getRank(), 0, affineExprs, + rewriter.getContext()); + SmallVector inBounds(1, true); + if (isa(op)) { + auto transferWriteOp = cast(op); + transferWriteOp.setPermutationMap(destAffineMap); + transferWriteOp.setInBoundsAttr(rewriter.getBoolArrayAttr(inBounds)); + } else if (isa(op)) { + auto transferReadOp = cast(op); + transferReadOp.setPermutationMap(destAffineMap); + transferReadOp.setInBoundsAttr(rewriter.getBoolArrayAttr(inBounds)); + } +} + +// scf.for yield helper function +scf::YieldOp maybeYieldValue(OpBuilder &b, Location loc, ValueRange value) { + bool hasRetVal = !value.empty(); + if (hasRetVal) + return b.create(loc, value); + else + return b.create(loc); +} + +Operation *createTensorEmptyBefore(Operation *op, IRRewriter &rewriter) { + OpBuilder::InsertionGuard g(rewriter); + auto rtType = cast(op->getResultTypes()[0]); + Block *block = op->getBlock(); + rewriter.setInsertionPoint(block, block->getOperations().begin()); + + SmallVector shapes; + SmallVector dynDims; + for (unsigned i = 0; i < rtType.getRank(); i++) { + shapes.push_back(rtType.getDimSize(i)); + if (rtType.isDynamicDim(i)) + dynDims.push_back( + rewriter.create(op->getLoc(), op->getResult(0), i)); + } + auto emtpyOp = rewriter.create( + op->getLoc(), rtType.getShape(), rtType.getElementType(), dynDims); + return emtpyOp; +} + +/// get the tensor that operation should write into +Value getOperationResultTensor(Operation *op, + DenseMap &visitedOperation, + IRRewriter &rewriter) { + OpBuilder::InsertionGuard g(rewriter); + + OpResult result = op->getResults()[0]; + for (Operation *x : result.getUsers()) { + if (not isa(x)) + continue; + + Value sourceTensor = x->getOperands()[1]; + Operation *srcOp = sourceTensor.getDefiningOp(); + if (not visitedOperation.contains(srcOp)) + continue; + + size_t pos = visitedOperation[srcOp]; + if (pos > visitedOperation[op]) + continue; + + return sourceTensor; + } + LDBG("Result not write back to tensor."); + + return createTensorEmptyBefore(op, rewriter)->getResults()[0]; +} + +Operation *createTransferWriteOpAfter(Operation *op, const Value &dest, + IRRewriter &rewriter) { + OpBuilder::InsertionGuard g(rewriter); + + auto rtType = cast(op->getResultTypes()[0]); + int64_t rank = rtType.getRank(); + auto dstType = cast(dest.getType()); + rewriter.setInsertionPoint(op); + + auto zero = rewriter.create(op->getLoc(), 0); + + rewriter.setInsertionPointAfter(op); + SmallVector inBoundsVal(rank, true); + + SmallVector shapes; + SmallVector dynDims; + for (unsigned i = 0; i < 
rtType.getRank(); i++) { + shapes.push_back(rtType.getDimSize(i)); + if (rtType.isDynamicDim(i)) + dynDims.push_back( + rewriter.create(op->getLoc(), op->getResult(0), i)); + } + return rewriter.create( + op->getLoc(), + /*vector=*/op->getResult(0), + /*source=*/dest, + /*indices=*/SmallVector(dstType.getRank(), zero), + /*inBounds=*/inBoundsVal); +} + +Operation *GroupOperationFusionImpl::createTransferReadOpBefore( + Operation *op, const Value &operand, vector::TransferReadOp *srcReadOp) { + IRRewriter &rewriter = *getRewriter(); + OpBuilder::InsertionGuard g(rewriter); + rewriter.setInsertionPoint(op); + auto operandType = cast(operand.getType()); + auto zero = rewriter.create(op->getLoc(), 0); + auto padValue = rewriter.create( + op->getLoc(), rewriter.getZeroAttr(operandType.getElementType())); + + if (srcReadOp) { + auto resultType = cast(srcReadOp->getType()); + SmallVector inBoundsVal(resultType.getRank(), true); + auto srcReadOpAffineMap = srcReadOp->getPermutationMap(); + // result of read operation should be same as operand + auto t = rewriter.create( + op->getLoc(), + /*vectorType=*/ + VectorType::get(resultType.getShape(), resultType.getElementType()), + /*source=*/operand, + /*indices=*/SmallVector(operandType.getRank(), zero), + /**affinemap*/ srcReadOpAffineMap, + /*inBounds=*/inBoundsVal); + DenseMap &permutationMap = + getGroupOperationFusion().getOpPermuationMap(); + permutationMap[t] = srcReadOpAffineMap; + getGroupOperationFusion().getOpAnchorPos()[t] = + t.getVectorType().getRank() - 1; + + return t; + } + SmallVector inBoundsVal(operandType.getRank(), true); + auto t = rewriter.create( + op->getLoc(), + /*vectorType=*/ + VectorType::get(operandType.getShape(), operandType.getElementType()), + /*source=*/operand, + /*indices=*/SmallVector(operandType.getRank(), zero), + /**affinemap*/ padValue, + /*inBounds=*/inBoundsVal); + DenseMap &permutationMap = + getGroupOperationFusion().getOpPermuationMap(); + permutationMap[t] = t.getPermutationMap(); + getGroupOperationFusion().getOpAnchorPos()[t] = + t.getVectorType().getRank() - 1; + + return t; +} + +// canonicalizing operation as tensor empty and transfer write the operation +// result into the empty tensor +[[nodiscard]] std::pair +canonicalizeSourceOperation(Operation *op, + DenseMap &visitedOperation, + IRRewriter &rewriter) { + auto resultTensor = getOperationResultTensor(op, visitedOperation, rewriter); + auto writeOp = createTransferWriteOpAfter(op, resultTensor, rewriter); + return std::make_pair(resultTensor, writeOp->getResults()[0]); +} + +[[nodiscard]] Value GroupOperationFusionImpl::canonicalizeCurrentOperation( + Operation *op, const Value &transferReadOperand, size_t operandIdx, + vector::TransferReadOp *srcReadOp) { + // transfer_read operation + auto readOp = createTransferReadOpBefore(op, transferReadOperand, srcReadOp); + op->setOperand(operandIdx, readOp->getResults()[0]); + return readOp->getResults()[0]; +} + +// __________________________________ +// Speical operations canonicalization +// __________________________________ + +//===----------------------------------------------------------------------===// +// MultiReduce Operation +//===----------------------------------------------------------------------===// + +void getOpSourceOps(Operation *op, DenseSet &srcOps) { + SmallVector srcOperands = op->getOperands(); + std::deque srcOperandsQueue(srcOperands.begin(), srcOperands.end()); + DenseSet visited; + visited.insert(op); + while (!srcOperandsQueue.empty()) { + Value accOperand = 
srcOperandsQueue.front(); + srcOperandsQueue.pop_front(); + Operation *accOperandOp = accOperand.getDefiningOp(); + if (!accOperandOp or visited.count(accOperandOp)) + continue; + + visited.insert(accOperandOp); + srcOps.insert(accOperandOp); + auto accOperandOperands = accOperandOp->getOperands(); + srcOperandsQueue.insert(srcOperandsQueue.end(), accOperandOperands.begin(), + accOperandOperands.end()); + } +} + +bool isSrcRelated(const DenseSet &srcOps, Operation *op) { + return srcOps.count(op); +} + +void getPrevOps(std::queue &prevOps, + std::queue &opQueue, Operation *currentOp) { + while (!opQueue.empty() && currentOp != opQueue.front()) { + prevOps.push(opQueue.front()); + opQueue.pop(); + } +} + +void getPostOps(std::queue &postOps, + std::queue &opQueue, Operation *currentOp) { + // pop multireduction op + if (currentOp != opQueue.front()) + llvm_unreachable( + "Current operation is not the front operation of the operation queue."); + + opQueue.pop(); + while (!opQueue.empty()) { + postOps.push(opQueue.front()); + opQueue.pop(); + } +} + +void getReductionInitAttr(vector::MultiDimReductionOp &multiReductionOp, + Attribute &initValueAttr) { + auto vecType = multiReductionOp.getSourceVectorType(); + auto resultElementType = vecType.getElementType(); + if (isa(resultElementType)) + initValueAttr = FloatAttr::get( + resultElementType, + getInitValForReduce(multiReductionOp.getKind(), vecType)); + else + initValueAttr = IntegerAttr::get( + resultElementType, + getInitValForReduce(multiReductionOp.getKind(), vecType)); +} + +/// get multi_reduction operation accumulate value source related operations +/// \param srcOp accumulate value source operation +void classifyAccRelatedOps(std::queue &accRelatedOps, + std::queue &sourceRelatedOps, + Operation *srcOp, std::queue &prevOps) { + DenseSet srcOpsSet; + getOpSourceOps(srcOp, srcOpsSet); + while (!prevOps.empty()) { + auto op = prevOps.front(); + prevOps.pop(); + if (isSrcRelated(srcOpsSet, op) or op == srcOp) + accRelatedOps.push(op); + else + sourceRelatedOps.push(op); + } +} + +void ForLoopGenerator::moveOperationsToCurrentForBody( + const OpBuilder &b, std::queue &opQueue, + GenerateLoopHelper &loopHelperParam) { + auto &opPermuationMap = getVectorBasedFusion().getOpPermuationMap(); + auto tmpQ(opQueue); + while (!tmpQ.empty()) { + auto x = tmpQ.front(); + tmpQ.pop(); + x->moveBefore(b.getBlock(), b.getBlock()->end()); + // check operation type to set correct operand + setOperationCorrectOperand(x, opPermuationMap, loopHelperParam); + } +} + +void ForLoopGenerator::getResultInCurrentOps( + const size_t anchorIdx, const size_t groupId, + const std::queue &ops, SmallVector &results, + DenseMap &nextAnchorResultsIdxMap, + DenseMap &forResultOrignalResultMap) { + auto tmpQ(ops); + llvm::MapVector> &groupResults = + getVectorBasedFusion().getGroupOpResults()[groupId]; + while (!tmpQ.empty()) { + Operation *cur = tmpQ.front(); + tmpQ.pop(); + auto curResult = cur->getResults()[0]; + if (groupResults.contains(curResult)) { + std::pair retType = groupResults[curResult]; + if (needReturnResult(retType, anchorIdx)) { + results.emplace_back(curResult); + nextAnchorResultsIdxMap[curResult] = results.size() - 1; + forResultOrignalResultMap[curResult] = curResult; + } + } + } +} + +/// update loop args related status +/// \param nextAnchorArgsIdxMap anchor args index map +/// \param nextOriginalOperandMap original value to next loop args map +/// \param nextOperandOriginalMap next loop args to original value map +void 
updateCurrentArgsStatus(ValueRange loopState, const size_t loopStateIdx, + SmallVector &nextAnchorArgs, + Value originalValue, + DenseMap &nextAnchorArgsIdxMap, + DenseMap &nextOriginalOperandMap, + DenseMap &nextOperandOriginalMap) { + Value currentArgs = loopState[loopStateIdx]; + nextAnchorArgs.emplace_back(currentArgs); + nextAnchorArgsIdxMap[currentArgs] = nextAnchorArgs.size() - 1; + nextOriginalOperandMap[originalValue] = currentArgs; + nextOperandOriginalMap[currentArgs] = originalValue; +} + +void ForLoopGenerator::getInitArgsToNextAnchor( + DenseMap &nextAnchorArgsIdxMap, + SmallVector &nextAnchorArgs, + GenerateLoopHelper &loopHelperParam) { + DenseMap &opAnchorPos = + getVectorBasedFusion().getOpAnchorPos(); + SetVector &opInitArgs = + getVectorBasedFusion().getGroupOpInitArgs()[loopHelperParam.groupIdx]; + + DenseSet visited; + // find the next anchor arguments + std::queue tmpQ(*loopHelperParam.candidateOps); + DenseMap nextOriginalOperandMap, nextOperandOriginalMap; + + while (!tmpQ.empty()) { + Operation *cur = tmpQ.front(); + tmpQ.pop(); + auto curOperands = cur->getOperands(); + for (auto x : curOperands) { + if (!visited.contains(x) and opInitArgs.contains(x) and + opAnchorPos[cur] > loopHelperParam.anchorIdx) { + if (not loopHelperParam.originalOperandLoopArgsMap.contains(x)) + llvm_unreachable("Must contains current value."); + int loopStateIdx = loopHelperParam.currentLoopStateIdxMap + [loopHelperParam.originalOperandLoopArgsMap[x]]; + updateCurrentArgsStatus(loopHelperParam.loopIterArgs, loopStateIdx, + nextAnchorArgs, x, nextAnchorArgsIdxMap, + nextOriginalOperandMap, nextOperandOriginalMap); + visited.insert(x); + } + } + } + loopHelperParam.originalOperandLoopArgsMap = + std::move(nextOriginalOperandMap); + loopHelperParam.loopArgsOriginalOperandMap = + std::move(nextOperandOriginalMap); +} + +void ForLoopGenerator::getOperationInCurrentAnchor( + const size_t anchorIdx, std::queue &fromQueue, + std::queue &toQueue) { + while (!fromQueue.empty()) { + Operation *curOp = fromQueue.front(); + if (anchorIdx == getVectorBasedFusion().getOpAnchorPos()[curOp]) { + toQueue.push(curOp); + fromQueue.pop(); + continue; + } + break; + } +} + +void ForLoopGenerator::replaceOperationsWithForLoopResult( + IRRewriter &rewrite, const std::queue &movingOperations, + GenerateLoopHelper &loopHelperParam) { + auto tmpQ(movingOperations); + DenseSet operationOperands; + while (!tmpQ.empty()) { + auto curOp = tmpQ.front(); + tmpQ.pop(); + for (auto x : curOp->getOperands()) + operationOperands.insert(x); + } + auto replaceIfFn = [&](OpOperand &use) { + return operationOperands.contains(use.get()); + }; + for (auto [nxtForResult, nextLoopResult] : + zip(loopHelperParam.forResults, loopHelperParam.nextAnchorResults)) { + Value originalResult = + loopHelperParam.nextAnchorResultOrignalResultMap[nextLoopResult]; + + rewrite.replaceOpUsesWithIf(originalResult.getDefiningOp(), nxtForResult, + replaceIfFn); + } +} + +/// \param [in,out] nextLoopStateIdxMap +/// \param [in,out] nextAnchorArgs +void ForLoopGenerator::movePreOpToCurrentAnchor( + OpBuilder &b, DenseMap &nextLoopStateIdxMap, + SmallVector &nextAnchorArgs, + GenerateLoopHelper &loopHelperParam) { + + // 1. get operations in current anchor position + std::queue movingOperation; + getOperationInCurrentAnchor(loopHelperParam.anchorIdx, + *loopHelperParam.candidateOps, movingOperation); + + // 2. rewrite operation as vectorize IR + rewriteOperationAsVectorize(b, loopHelperParam.groupIdx, &movingOperation); + + // 3. 
move opeartions to current for block + moveOperationsToCurrentForBody(b, movingOperation, loopHelperParam); + + // 4. get next anchor args + getInitArgsToNextAnchor(nextLoopStateIdxMap, nextAnchorArgs, loopHelperParam); + + // 5. move operations to moved queue + while (!movingOperation.empty()) { + loopHelperParam.movedOps->push(movingOperation.front()); + movingOperation.pop(); + } +} + +void ForLoopGenerator::movePostOpToCurrentAnchor( + OpBuilder &b, GenerateLoopHelper &loopHelperParam) { + OpBuilder::InsertionGuard g(b); + + std::queue movingOperations; + // 1. get post-op to current loop bod + getOperationInCurrentAnchor(loopHelperParam.anchorIdx, + *loopHelperParam.candidateOps, movingOperations); + // 2. rewrite operation as vectorize IR + rewriteOperationAsVectorize(b, loopHelperParam.groupIdx, &movingOperations); + + // 3. move opeartions to current for block + moveOperationsToCurrentForBody(b, movingOperations, loopHelperParam); + + // 4. replace correct for loop result to post-op + replaceOperationsWithForLoopResult(*getRewriter(), movingOperations, + loopHelperParam); + + // 5. move operations to moved queue + while (!movingOperations.empty()) { + loopHelperParam.movedOps->push(movingOperations.front()); + movingOperations.pop(); + } +} + +void ForLoopGenerator::generateLoopResults( + OpBuilder &b, const Location &loc, GenerateLoopHelper &loopHelperParam, + DenseMap &nextOperandIdxMap) { + OpBuilder::InsertionGuard g(b); + SmallVector results; + DenseMap currentResultMap; + getResultInCurrentOps(loopHelperParam.anchorIdx, loopHelperParam.groupIdx, + *loopHelperParam.movedOps, results, + loopHelperParam.nextAnchorResultsIdxMap, + currentResultMap); + + llvm::MapVector> &groupResults = + getVectorBasedFusion().getGroupOpResults()[loopHelperParam.groupIdx]; + // check for yield results whether need to return to next anchor + for (auto [idx, forResult] : + llvm::enumerate(loopHelperParam.nextAnchorResults)) { + Value originalResult = + loopHelperParam.nextAnchorResultOrignalResultMap[forResult]; + + if (groupResults.contains(originalResult)) { + std::pair resultType = + groupResults[originalResult]; + if (needReturnResult(resultType, loopHelperParam.anchorIdx)) { + results.emplace_back(loopHelperParam.forResults[idx]); + currentResultMap[loopHelperParam.forResults[idx]] = originalResult; + } + } + } + + loopHelperParam.nextAnchorResults.clear(); + loopHelperParam.nextAnchorResultsIdxMap.clear(); + // reduction operation due to special process results size will be zero + if (not results.empty()) + for (Value x : loopHelperParam.loopIterArgs) { + loopHelperParam.nextAnchorResults.emplace_back( + results[nextOperandIdxMap[x]]); + loopHelperParam.nextAnchorResultsIdxMap[results[nextOperandIdxMap[x]]] = + loopHelperParam.nextAnchorResults.size() - 1; + } + + loopHelperParam.nextAnchorResultOrignalResultMap = + std::move(currentResultMap); +} + +void updateLoopArgsData(Value val, Value originalVal, + SmallVector &argsArray, + DenseMap &anchorArgsIdxMap, + DenseMap &originalOperandLoopArgsMap, + DenseMap &loopArgsOriginalOperandMap) { + argsArray.emplace_back(val); + anchorArgsIdxMap[val] = argsArray.size() - 1; + loopArgsOriginalOperandMap[val] = originalVal; + originalOperandLoopArgsMap[originalVal] = val; +} + +void LoopGeneratorImpl::rectifyParallelIndice( + GenerateLoopHelper &loopHelperParam, Location loc) { + OpBuilder::InsertionGuard g(*getRewriter()); + MultiReductionCanonicalizer rdCanonicalizer = + getMultiRdCanonicalizers()[loopHelperParam.groupIdx]; + auto &multireductionOp = 
rdCanonicalizer.getCandidateOps()[0]; + SmallVector &reductionAxis = rdCanonicalizer.getReductionAxis(); + + // rectify indice of read from source operand + std::queue candidateOps; + getSameBlockTargetOp( + multireductionOp.getSource().getDefiningOp(), candidateOps); + while (not candidateOps.empty()) { + auto sourceReadOp = candidateOps.front(); + candidateOps.pop(); + getRewriter()->setInsertionPoint(sourceReadOp); + AffineExpr outterParallel, innerParallel; + bindDims(multireductionOp->getContext(), outterParallel, innerParallel); + + Value op = + loopHelperParam.inductionVars[loopHelperParam.inductionVars.size() - + reductionAxis.size() - 2]; + Value ip = + loopHelperParam.inductionVars[loopHelperParam.inductionVars.size() - + reductionAxis.size() - 1]; + Value newIndice = getRewriter()->createOrFold( + loc, (outterParallel + innerParallel), ValueRange{op, ip}); + int parallelSize = rdCanonicalizer.getParallelAxis().size(); + int readIndiceOffset = + 1 + rdCanonicalizer.getParallelAxis()[parallelSize - 1]; + sourceReadOp->setOperand(readIndiceOffset, newIndice); + } +} + +scf::ForOp LoopGeneratorImpl::reductionAxisGenerateForLoop( + OpBuilder &opBuilder, const size_t reductionIdx, + GenerateLoopHelper &loopHelperParam) { + + MultiReductionCanonicalizer rdCanonicalizer = + getMultiRdCanonicalizers()[loopHelperParam.groupIdx]; + auto &multireductionOp = rdCanonicalizer.getCandidateOps()[0]; + GroupOperationFusion &fusionStrategy = getVectorBasedFusion(); + + SmallVector, 8> &opGroups = + fusionStrategy.getOpGroups(); + std::queue &opQueue = opGroups[loopHelperParam.groupIdx]; + + const auto loc = multireductionOp->getLoc(); + SmallVector &reductionAxis = rdCanonicalizer.getReductionAxis(); + VectorType vectorType = rdCanonicalizer.getSourceType(); + auto tpHelper = fusionStrategy.getTypeHelper(); + + int loopStep = tpHelper.generateValidSteps( + fusionStrategy.getTypeHelper().getDataTypeMAXSIMDLength(vectorType), + vectorType, vectorType.getShape()[reductionAxis[reductionIdx]]); + bool isLastDimReduction = rdCanonicalizer.getHasLastDimReduction(); + loopStep = (reductionIdx == reductionAxis.size() - 1 && isLastDimReduction) + ? loopStep + : 1; + + Value zero = makeIndexArithConstantOp(opBuilder, loc, 0); + Value forSteps = makeIndexArithConstantOp(opBuilder, loc, loopStep); + Value numIter = makeIndexArithConstantOp( + opBuilder, loc, vectorType.getShape()[reductionAxis[reductionIdx]]); + scf::ForOp forOp = opBuilder.create( + loc, zero, numIter, forSteps, loopHelperParam.loopIterArgs, + [&](OpBuilder &b, Location loc, Value iv, ValueRange loopState) { + loopHelperParam.inductionVars.emplace_back(iv); + size_t currentAnchorId = loopHelperParam.anchorIdx; + SmallVector tmpArgs(loopState); + Value originalRetVal = multireductionOp->getResults()[0]; + + if (reductionIdx < reductionAxis.size() - 1) { + + // 1. 
move pre-Op to current body + DenseMap nextAnchorArgsIdxMap; + SmallVector nextAnchorArgs; + std::queue movedOperation; + DenseMap currentoriginalArgsMap = + loopHelperParam.originalOperandLoopArgsMap; + DenseMap currentArgsOriginalMap = + loopHelperParam.loopArgsOriginalOperandMap; + DenseMap currentArgsIdxMap = + loopHelperParam.currentLoopStateIdxMap; + DenseMap originalArgsMap, argsOriginalMap; + loopHelperParam.updateDataBeforePreOpMove(tmpArgs, opQueue, + movedOperation); + movePreOpToCurrentAnchor(b, nextAnchorArgsIdxMap, nextAnchorArgs, + loopHelperParam); + loopHelperParam.updateDataAfterPreOpMove(nextAnchorArgsIdxMap, + nextAnchorArgs); + + // replace reduction init args + if (currentoriginalArgsMap.contains(multireductionOp.getAcc())) { + size_t accValIdx = currentArgsIdxMap + [currentoriginalArgsMap[multireductionOp.getAcc()]]; + updateCurrentArgsStatus( + loopState, accValIdx, nextAnchorArgs, multireductionOp.getAcc(), + nextAnchorArgsIdxMap, originalArgsMap, argsOriginalMap); + loopHelperParam.updateCurrentArgsStatus( + nextAnchorArgsIdxMap, nextAnchorArgs, originalArgsMap, + argsOriginalMap); + } + + loopHelperParam.anchorIdx += 1; + // 2. generate next for loop + scf::ForOp nxtFor = reductionAxisGenerateForLoop(b, reductionIdx + 1, + loopHelperParam); + loopHelperParam.anchorIdx -= 1; + + loopHelperParam.updateDataBeforePostOpMove( + tmpArgs, currentArgsIdxMap, currentoriginalArgsMap, + currentArgsOriginalMap, nxtFor->getResults(), b.getBlock(), + movedOperation, currentAnchorId); + // 3. move postOp to current body + movePostOpToCurrentAnchor(b, loopHelperParam); + + // 4. generate loop results + generateLoopResults(b, loc, loopHelperParam, nextAnchorArgsIdxMap); + + // reduction must return accumulate + if (loopHelperParam.orignalResultNextAnchorResultMap.contains( + originalRetVal)) { + Value lastForResult = + loopHelperParam + .orignalResultNextAnchorResultMap[originalRetVal]; + size_t retIdx = nextAnchorArgsIdxMap + [loopHelperParam + .nextAnchorResultOrignalResultMap[lastForResult]]; + Value forRes = nxtFor->getResults()[retIdx]; + // accumulate for loop iter args must be last, so we just put the + // reduction result as the last result + updateLoopArgsData( + forRes, originalRetVal, loopHelperParam.nextAnchorResults, + loopHelperParam.nextAnchorResultsIdxMap, + loopHelperParam.orignalResultNextAnchorResultMap, + loopHelperParam.nextAnchorResultOrignalResultMap); + } + + maybeYieldValue(b, loc, loopHelperParam.nextAnchorResults); + + } else if (reductionIdx == reductionAxis.size() - 1) { + std::queue movingOperation; + + while (!opQueue.empty()) { + Operation *curOp = opQueue.front(); + opQueue.pop(); + if (isa(curOp)) + break; + + movingOperation.push(curOp); + } + // remove all the multi_reduction operation + while (!opQueue.empty()) { + Operation *curOp = opQueue.front(); + if (isa(curOp)) { + opQueue.pop(); + continue; + } + break; + } + + rewriteOperationAsVectorize(b, loopHelperParam.groupIdx, + &movingOperation, + isLastDimReduction ? 
loopStep : 0); + loopHelperParam.loopIterArgs = loopState; + moveOperationsToCurrentForBody(b, movingOperation, loopHelperParam); + if (isLastDimReduction) + rectifyParallelIndice(loopHelperParam, loc); + loopHelperParam.movedOps = &movingOperation; + loopHelperParam.candidateOps = &opQueue; + + int accValIdx = + loopHelperParam.currentLoopStateIdxMap + [loopHelperParam + .originalOperandLoopArgsMap[multireductionOp.getAcc()]]; + + Value reductionResult = makeArithReduction( + b, loc, multireductionOp.getKind(), multireductionOp.getSource(), + loopState[accValIdx]); + + loopHelperParam.updateDataBeforePostOpMove( + tmpArgs, loopHelperParam.currentLoopStateIdxMap, + loopHelperParam.originalOperandLoopArgsMap, + loopHelperParam.loopArgsOriginalOperandMap, ValueRange(), + b.getBlock(), movingOperation, currentAnchorId); + + movePostOpToCurrentAnchor(b, loopHelperParam); + + loopHelperParam.nextAnchorResults.clear(); + updateLoopArgsData(reductionResult, originalRetVal, + loopHelperParam.nextAnchorResults, + loopHelperParam.nextAnchorResultsIdxMap, + loopHelperParam.orignalResultNextAnchorResultMap, + loopHelperParam.nextAnchorResultOrignalResultMap); + getResultInCurrentOps( + loopHelperParam.anchorIdx, loopHelperParam.groupIdx, + movingOperation, loopHelperParam.nextAnchorResults, + loopHelperParam.nextAnchorResultsIdxMap, + loopHelperParam.nextAnchorResultOrignalResultMap); + maybeYieldValue(b, loc, loopHelperParam.nextAnchorResults); + } + }); + + return forOp; +} + +void LoopGeneratorImpl::ensureAccInParallelLoop( + GenerateLoopHelper &loopHelperParam, ArrayRef parallelAxis, + Value multiReductionAcc, DenseMap &nextAnchorArgsIdxMap, + SmallVector &nextAnchorArgs) { + if (loopHelperParam.anchorIdx == parallelAxis.size() - 1) { + // Ensure accumalate expression appear in this parallel anchor + // position. If it not appear in current anchor, we must move it in + // here. + // 1. delete it in operation queue + // 2. move it in current movedqueue + DenseSet argsSet(nextAnchorArgs.begin(), nextAnchorArgs.end()); + std::queue checkAccQueue(*loopHelperParam.movedOps); + Value accInitVal; + while (!checkAccQueue.empty()) { + Operation *cur = checkAccQueue.front(); + checkAccQueue.pop(); + bool ok = false; + for (auto x : cur->getResults()) { + if (x == multiReductionAcc) { + accInitVal = x; + ok = true; + break; + } + } + if (ok) + break; + } + if (accInitVal) { + // we put initVal at last for loop args + if (!argsSet.contains(accInitVal)) { + nextAnchorArgs.emplace_back(accInitVal); + nextAnchorArgsIdxMap[accInitVal] = nextAnchorArgs.size() - 1; + loopHelperParam.loopArgsOriginalOperandMap[accInitVal] = + multiReductionAcc; + loopHelperParam.originalOperandLoopArgsMap[multiReductionAcc] = + accInitVal; + } + loopHelperParam.loopIterArgs = nextAnchorArgs; + loopHelperParam.nextAnchorResultsIdxMap = nextAnchorArgsIdxMap; + } else { + llvm::llvm_unreachable_internal("Wrong accumualte source value. Because " + "acc value must appear in here."); + } + } +} + +/// Generate for loop for parallel axis of `vector.multi_reduction`. 
+/// This function also call reduction axis for loop +scf::ForOp LoopGeneratorImpl::parallelAxisGenerateForLoop( + OpBuilder &opBuilder, GenerateLoopHelper &loopHelperParam) { + OpBuilder::InsertionGuard g(opBuilder); + MultiReductionCanonicalizer &rdCanonicalizer = + getMultiRdCanonicalizers()[loopHelperParam.groupIdx]; + vector::MultiDimReductionOp &multiReductionOp = + rdCanonicalizer.getCandidateOps()[0]; + VectorType vectorType = rdCanonicalizer.getSourceType(); + GroupOperationFusion &fusionStrategy = getVectorBasedFusion(); + + SmallVector ¶llelAxis = rdCanonicalizer.getParallelAxis(); + const Location &loc = multiReductionOp.getLoc(); + Value zero = makeIndexArithConstantOp(opBuilder, loc, 0); + size_t grpMaxStep = + getVectorBasedFusion().getGroupMaxSteps()[loopHelperParam.groupIdx]; + size_t actualStep = + (loopHelperParam.anchorIdx == parallelAxis.size() - 1 ? grpMaxStep : 1); + Value forSteps = makeIndexArithConstantOp(opBuilder, loc, actualStep); + + // last dim reduction need to a generate dim=16 loop for fused with pre-op + int dimSize = 0; + if (loopHelperParam.anchorIdx == parallelAxis.size()) + dimSize = + getVectorBasedFusion().getGroupMaxSteps()[loopHelperParam.groupIdx]; + else + dimSize = vectorType.getShape()[parallelAxis[loopHelperParam.anchorIdx]]; + + Value numIter = makeIndexArithConstantOp(opBuilder, loc, dimSize); + // Create a loop and move vectorized operation into loops. + return opBuilder.create( + loc, zero, numIter, forSteps, loopHelperParam.loopIterArgs, + [&](OpBuilder &b, Location loc, Value iv, ValueRange loopState) { + loopHelperParam.inductionVars.emplace_back(iv); + + DenseMap &opIndexMap = + fusionStrategy.getOpGroupIndexMap(); + + if (not opIndexMap.contains(multiReductionOp)) + llvm_unreachable("Must constains multireduction operation."); + + size_t opIndex = opIndexMap[multiReductionOp]; + SmallVector, 8> &opGroups = + fusionStrategy.getOpGroups(); + std::queue &opQueue = opGroups[opIndex]; + Value multiReductionAcc = multiReductionOp.getAcc(); + + if (loopHelperParam.anchorIdx < parallelAxis.size()) { + // 1. move pre-Op to current body + DenseMap nextAnchorArgsIdxMap; + SmallVector nextAnchorArgs; + std::queue movedQueue; + DenseMap currentOriginalOperandMap = + loopHelperParam.originalOperandLoopArgsMap; + DenseMap currentOperandOriginalMap = + loopHelperParam.loopArgsOriginalOperandMap; + DenseMap currentLoopStateIdxMap = + loopHelperParam.currentLoopStateIdxMap; + SmallVector tmpArgs(loopState); + loopHelperParam.updateDataBeforePreOpMove(tmpArgs, opQueue, + movedQueue); + movePreOpToCurrentAnchor(b, nextAnchorArgsIdxMap, nextAnchorArgs, + loopHelperParam); + loopHelperParam.updateDataAfterPreOpMove(nextAnchorArgsIdxMap, + nextAnchorArgs); + ensureAccInParallelLoop(loopHelperParam, parallelAxis, + multiReductionAcc, nextAnchorArgsIdxMap, + nextAnchorArgs); + scf::ForOp nxtFor; + // 2. generate next for loop + bool useParallelLoop = + rdCanonicalizer.hasLastDimReduction() or + loopHelperParam.anchorIdx < parallelAxis.size() - 1; + loopHelperParam.anchorIdx += 1; + if (useParallelLoop) { + nxtFor = parallelAxisGenerateForLoop(b, loopHelperParam); + } else { + nxtFor = reductionAxisGenerateForLoop(b, 0, loopHelperParam); + } + loopHelperParam.anchorIdx -= 1; + + // 3. 
move postOp to current body + loopHelperParam.updateDataBeforePostOpMove( + tmpArgs, currentLoopStateIdxMap, currentOriginalOperandMap, + currentOperandOriginalMap, nxtFor->getResults(), + nxtFor->getBlock(), movedQueue, loopHelperParam.anchorIdx); + movePostOpToCurrentAnchor(b, loopHelperParam); + // 4. generate loop results + generateLoopResults(b, loc, loopHelperParam, nextAnchorArgsIdxMap); + maybeYieldValue(b, loc, loopHelperParam.nextAnchorResults); + + } else if (loopHelperParam.anchorIdx == parallelAxis.size()) { + + DenseMap tmpOriginOperandLoopArgsMap = + loopHelperParam.originalOperandLoopArgsMap; + DenseMap tmpLoopArgsOriginalOperandMap = + loopHelperParam.loopArgsOriginalOperandMap; + + // get accumualte value + Attribute initValueAttr; + getReductionInitAttr(multiReductionOp, initValueAttr); + SmallVector &reductionAxis = + rdCanonicalizer.getReductionAxis(); + TypeHelper tpHelper = fusionStrategy.getTypeHelper(); + int loopStep = tpHelper.generateValidSteps( + tpHelper.getDataTypeMAXSIMDLength(vectorType), vectorType, + vectorType.getShape()[reductionAxis[reductionAxis.size() - 1]]); + auto accVal = b.create( + loc, DenseElementsAttr::get( + fusionStrategy.getTypeHelper().getVectorzedType( + multiReductionOp, loopStep), + {initValueAttr})); + + // put accumulte val at first for loop args + DenseMap localAnchorArgsIdxMap; + DenseMap localOriginalOperandLoopArgsMap, + localLoopArgsOriginalOperandMap; + SmallVector argsArray; + updateLoopArgsData( + accVal, multiReductionAcc, argsArray, localAnchorArgsIdxMap, + localOriginalOperandLoopArgsMap, localLoopArgsOriginalOperandMap); + + size_t accLoopStateIdx = + loopHelperParam.currentLoopStateIdxMap + [loopHelperParam + .originalOperandLoopArgsMap[multiReductionAcc]]; + for (auto [idx, x] : llvm::enumerate(loopState)) { + if (idx == accLoopStateIdx) + continue; + updateLoopArgsData(x, + loopHelperParam.loopArgsOriginalOperandMap + [loopHelperParam.loopIterArgs[idx]], + argsArray, localAnchorArgsIdxMap, + localOriginalOperandLoopArgsMap, + localLoopArgsOriginalOperandMap); + } + loopHelperParam.updateCurrentArgsStatus( + localAnchorArgsIdxMap, argsArray, localOriginalOperandLoopArgsMap, + localLoopArgsOriginalOperandMap); + DenseMap originalResultForResultMap; + auto nxtFor = reductionAxisGenerateForLoop(b, 0, loopHelperParam); + + // insert accumulate value to original vector + Value nxtForAccVal = + originalResultForResultMap[multiReductionOp->getResults()[0]]; + size_t accIdx = loopHelperParam.nextAnchorResultsIdxMap[nxtForAccVal]; + auto accRes = nxtFor->getResults()[accIdx]; + + Operation *reductionOp = b.create( + loc, multiReductionOp.getKind(), accRes); + auto insertOp = b.create( + loc, reductionOp->getResult(0), loopState[accLoopStateIdx], iv); + + // generate loop result + SmallVector currentAnchorResults(loopState.size()); + DenseMap currentResultMap; + DenseMap currentResultIdxMap; + + currentAnchorResults[accLoopStateIdx] = insertOp->getResults()[0]; + // reduce axis for loop first result we has already processed above + currentResultMap[insertOp->getResults()[0]] = + multiReductionOp->getResults()[0]; + currentResultIdxMap[insertOp->getResults()[0]] = accLoopStateIdx; + for (auto [idx, x] : + llvm::enumerate(loopHelperParam.nextAnchorResults)) { + if (loopHelperParam.nextAnchorResultOrignalResultMap[x] == + multiReductionOp->getResults()[0]) + continue; + + Value originalResult = + loopHelperParam.nextAnchorResultOrignalResultMap[x]; + size_t itrIdx = loopHelperParam.currentLoopStateIdxMap + 
[tmpOriginOperandLoopArgsMap[originalResult]]; + currentAnchorResults[itrIdx] = nxtFor->getResults()[idx]; + currentResultIdxMap[nxtFor->getResults()[idx]] = itrIdx; + currentResultMap[nxtFor->getResults()[idx]] = originalResult; + } + loopHelperParam.clearNextAnchorResults(); + loopHelperParam.setNextAnchorResults( + currentAnchorResults, currentResultMap, currentResultIdxMap); + maybeYieldValue(b, loc, loopHelperParam.nextAnchorResults); + } + }); +} + +scf::ForOp LoopGeneratorImpl::generateTransposeForLoopWithLastDim( + OpBuilder &opBuilder, const int tpSteps, const Location &loc, + Operation *successorWriteOp, GenerateLoopHelper &loopHelperParam) { + auto &tpCanonicalizer = + getTransposeCanonicalizers()[loopHelperParam.groupIdx]; + vector::TransposeOp &tpOp = tpCanonicalizer.getCandidateOps()[0]; + VectorType vtType = tpOp.getVector().getType(); + size_t rank = vtType.getRank(); + + auto zero = makeIndexArithConstantOp(opBuilder, loc, 0); + bool isTransposeDim = + loopHelperParam.anchorIdx == tpCanonicalizer.getFirstTpIdx() or + loopHelperParam.anchorIdx == tpCanonicalizer.getSecondTpIdx(); + auto forSteps = + makeIndexArithConstantOp(opBuilder, loc, isTransposeDim ? tpSteps : 1); + auto numIter = makeIndexArithConstantOp( + opBuilder, loc, vtType.getShape()[loopHelperParam.anchorIdx]); + VectorType kernelType = + VectorType::get({tpSteps, tpSteps}, vtType.getElementType()); + // generate transpose for loop + return opBuilder.create( + loc, zero, numIter, forSteps, loopHelperParam.loopIterArgs, + [&](OpBuilder &b, Location loc, Value iv, ValueRange loopState) { + loopHelperParam.inductionVars.emplace_back(iv); + + // inner most body of the loop + if (loopHelperParam.anchorIdx == rank - 1) { + // transfer read from source tensor + Value source = tpOp->getOperand(0); + auto readSourceOp = + cast(source.getDefiningOp()); + auto padValue = b.create( + loc, b.getZeroAttr(vtType.getElementType())); + SmallVector inBoundsVal(2, true); + inBoundsVal[0] = !ShapedType::isDynamic( + vtType.getShape()[tpCanonicalizer.getFirstTpIdx()]); + inBoundsVal[1] = !ShapedType::isDynamic( + vtType.getShape()[tpCanonicalizer.getSecondTpIdx()]); + + auto transferReadOp = b.create( + loc, + /*vectorType=*/kernelType, + /*source=*/readSourceOp.getSource(), + /*indices=*/loopHelperParam.inductionVars, + /*padding=*/padValue, + /*inBounds=*/inBoundsVal); + SmallVector perm{1, 0}; + auto transposeOp = b.create( + loc, transferReadOp->getResults()[0], perm); + SmallVector writeVars(loopHelperParam.inductionVars.begin(), + loopHelperParam.inductionVars.end()); + writeVars[tpCanonicalizer.getSecondTpIdx()] = + loopHelperParam.inductionVars[tpCanonicalizer.getFirstTpIdx()]; + writeVars[tpCanonicalizer.getFirstTpIdx()] = + loopHelperParam.inductionVars[tpCanonicalizer.getSecondTpIdx()]; + auto writeOp = b.create( + loc, transposeOp->getResults()[0], loopState[0], writeVars, + inBoundsVal); + maybeYieldValue(b, loc, writeOp->getResults()); + } else { + // outter loop + loopHelperParam.anchorIdx += 1; + loopHelperParam.loopIterArgs = loopState; + auto nxtFor = generateTransposeForLoopWithLastDim( + b, tpSteps, loc, successorWriteOp, loopHelperParam); + loopHelperParam.anchorIdx -= 1; + maybeYieldValue(b, loc, nxtFor->getResults()); + } + }); +} + +void ForLoopGenerator::prepareForLoopArgs(const size_t grpIdx, + GenerateLoopHelper &loopHelper) { + SetVector &grpArgs = + getVectorBasedFusion().getGroupOpInitArgs()[grpIdx]; + loopHelper.loopIterArgs = grpArgs.getArrayRef(); + for (auto [idx, val] : llvm::enumerate(grpArgs)) 
{ + loopHelper.currentLoopStateIdxMap[val] = idx; + loopHelper.originalOperandLoopArgsMap[val] = val; + loopHelper.loopArgsOriginalOperandMap[val] = val; + } +} + +void LoopGeneratorImpl::rearrageMultiReductionIR( + const size_t grpIdx, + DenseMap> &indiceLoopMap) { + MultiReductionCanonicalizer &rdCanonicalizer = + getMultiRdCanonicalizers()[grpIdx]; + vector::MultiDimReductionOp multiReductionOp = + rdCanonicalizer.getCandidateOps()[0]; + SmallVector ¶llelAxis = rdCanonicalizer.getParallelAxis(); + SmallVector &reductionAxis = rdCanonicalizer.getReductionAxis(); + std::queue &prevOps = rdCanonicalizer.getPrevOps(); + std::queue &postOps = rdCanonicalizer.getPostOps(); + std::queue &accRelatedOps = rdCanonicalizer.getAccRelatedOps(); + std::queue &sourceRelatedOps = + rdCanonicalizer.getSourceRelatedOps(); + std::queue &opQueue = + getVectorBasedFusion().getOpGroups()[grpIdx]; + auto copyOpQueue(opQueue); + getPrevOps(prevOps, copyOpQueue, multiReductionOp); + getPostOps(postOps, copyOpQueue, multiReductionOp); + classifyAccRelatedOps(accRelatedOps, sourceRelatedOps, + multiReductionOp.getAcc().getDefiningOp(), prevOps); + + // mark source read operation need to set correct for loop var idx + std::queue tmpSourceQ(sourceRelatedOps); + DenseMap varLoopIdxMap; + VectorType groupVector = + getVectorBasedFusion().getGroupBiggestRankVectorType()[grpIdx]; + for (size_t i = 0; i < parallelAxis.size(); i++) + varLoopIdxMap[parallelAxis[i]] = i; + + size_t offset = rdCanonicalizer.hasLastDimReduction() ? 1 : 0; + for (size_t i = parallelAxis.size() + offset; + i < groupVector.getRank() + offset; i++) + varLoopIdxMap[reductionAxis[i - parallelAxis.size() - offset]] = i; + + while (!tmpSourceQ.empty()) { + auto *curOp = tmpSourceQ.front(); + tmpSourceQ.pop(); + if (isa(curOp)) + getCurrentGroupIndiceLoopMap(indiceLoopMap, grpIdx, curOp, varLoopIdxMap); + } + + // move accumulate related operation to operation first + std::queue rectifyQueue; + DenseSet pushedSet; + auto moveOperation = [&](std::queue &from, + std::queue &to) { + while (!from.empty()) { + auto cur = from.front(); + from.pop(); + if (pushedSet.contains(cur)) + continue; + + to.push(cur); + pushedSet.insert(cur); + } + }; + moveOperation(accRelatedOps, rectifyQueue); + moveOperation(opQueue, rectifyQueue); + opQueue = rectifyQueue; +} + +void ForLoopGenerator::replaceOpUsersWithForLoopResult( + scf::ForOp forOp, int grpIdx, SmallVector &nextAnchorResults, + DenseMap &nextAnchorResultsIdxMap, + DenseMap &forResultOrignalResultMap) { + + DenseSet forOpChildOps; + forOp->walk([&](Operation *op) { forOpChildOps.insert(op); }); + auto replaceIfFn = [&](OpOperand &use) { + return not forOpChildOps.contains(use.getOwner()); + }; + for (auto x : nextAnchorResults) { + auto originalResult = forResultOrignalResultMap[x]; + Value forResult = forOp->getResults()[nextAnchorResultsIdxMap[x]]; + // subsequent group must use the replaced result as operand + rectifyGroupOperands(grpIdx, originalResult, forResult); + getRewriter()->replaceOpUsesWithIf(originalResult.getDefiningOp(), + forResult, replaceIfFn); + } +} +scf::ForOp +LoopGeneratorImpl::generateMultiReductionForLoop(const size_t grpIdx) { + OpBuilder::InsertionGuard g(*getRewriter()); + DenseMap> indiceLoopMap; + rearrageMultiReductionIR(grpIdx, indiceLoopMap); + // get current loop init args + DenseMap currentLoopStateIdxMap, nextAnchorResultsIdxMap; + GenerateLoopHelper loopHelper(grpIdx, 0); + prepareForLoopArgs(grpIdx, loopHelper); + + MultiReductionCanonicalizer &rdCanonicalizer = + 
getMultiRdCanonicalizers()[grpIdx]; + + getRewriter()->setInsertionPoint(rdCanonicalizer.getCandidateOps()[0]); + loopHelper.indiceLoopMap = indiceLoopMap; + + scf::ForOp forOp = parallelAxisGenerateForLoop(*getRewriter(), loopHelper); + replaceOpUsersWithForLoopResult(forOp, grpIdx, loopHelper.nextAnchorResults, + loopHelper.nextAnchorResultsIdxMap, + loopHelper.nextAnchorResultOrignalResultMap); + + vector::MultiDimReductionOp multiReductionOp = + rdCanonicalizer.getCandidateOps()[0]; + getRewriter()->eraseOp(multiReductionOp); + + return forOp; +} + +// generate simple data movement for loop +scf::ForOp LoopGeneratorImpl::generateTransposeScalarDataMovement( + OpBuilder &opBuilder, const Location &loc, + DenseMap &tpAxisMap, GenerateLoopHelper &loopHelperParam) { + auto &tpCanonicalizer = + getTransposeCanonicalizers()[loopHelperParam.groupIdx]; + vector::TransposeOp &tpOp = tpCanonicalizer.getCandidateOps()[0]; + VectorType vtType = tpOp.getSourceVectorType(); + size_t rank = vtType.getRank(); + + auto zero = makeIndexArithConstantOp(opBuilder, loc, 0); + size_t vecStep = tpCanonicalizer.transposeOnLastDim() + ? tpCanonicalizer.getVectorStep() + : 1; + auto forSteps = makeIndexArithConstantOp( + opBuilder, loc, loopHelperParam.anchorIdx == rank - 1 ? (vecStep) : 1); + auto numIter = makeIndexArithConstantOp( + opBuilder, loc, vtType.getShape()[loopHelperParam.anchorIdx]); + + SmallVector vecShapes(1, vecStep); + VectorType kernelType = VectorType::get(vecShapes, vtType.getElementType()); + // generate transpose for loop + return opBuilder.create( + loc, zero, numIter, forSteps, loopHelperParam.loopIterArgs, + [&](OpBuilder &b, Location loc, Value iv, ValueRange loopState) { + loopHelperParam.inductionVars.emplace_back(iv); + + // inner most body of the loop + if (loopHelperParam.anchorIdx == rank - 1) { + // transfer read from source tensor + Value source = tpOp->getOperand(0); + auto readSourceOp = + cast(source.getDefiningOp()); + vector::TransferWriteOp successorWriteOp; + for (Operation *x : tpOp->getUsers()) { + if (isa(x)) { + successorWriteOp = cast(x); + break; + } + } + auto padValue = b.create( + loc, b.getZeroAttr(vtType.getElementType())); + SmallVector inBoundsVal(1, true); + SmallVector writeVars; + size_t itrIdx = 0; + while (itrIdx < rank) { + writeVars.emplace_back( + loopHelperParam.inductionVars[tpAxisMap[itrIdx]]); + itrIdx++; + } + auto transferReadOp = b.create( + loc, + /*vectorType=*/kernelType, + /*source=*/readSourceOp.getSource(), + /*indices=*/loopHelperParam.inductionVars, + /*padding=*/padValue, + /*inBounds=*/inBoundsVal); + + rectifyWriteOperationIndice(&successorWriteOp, writeVars); + + auto writeOp = b.create( + loc, transferReadOp->getResults()[0], loopState[0], writeVars, + inBoundsVal); + maybeYieldValue(b, loc, writeOp->getResults()); + } else { + // outter loop + loopHelperParam.anchorIdx += 1; + loopHelperParam.loopIterArgs = loopState; + auto nxtFor = generateTransposeScalarDataMovement(b, loc, tpAxisMap, + loopHelperParam); + loopHelperParam.anchorIdx -= 1; + maybeYieldValue(b, loc, nxtFor->getResults()); + } + }); +} + +scf::ForOp LoopGeneratorImpl::generateShapeCastReadWriteLoop( + OpBuilder &opBuilder, const size_t grpIdx, const size_t forDimIdx, + const size_t steps, const Location &loc, SmallVector &inductionVars, + ValueRange iterArgs) { + auto &scCanonicalizer = getShapeCastCanonicalizers()[grpIdx]; + vector::ShapeCastOp &scOp = scCanonicalizer.getCandidateOps()[0]; + VectorType sourceType = scOp.getSourceVectorType(); + VectorType 
destType = scOp.getResultVectorType(); + VectorType loopType = + sourceType.getRank() > destType.getRank() ? sourceType : destType; + size_t rank = loopType.getRank(); + DenseMap &opIndexMap = + getVectorBasedFusion().getOpGroupIndexMap(); + + auto zero = makeIndexArithConstantOp(opBuilder, loc, 0); + bool isLastDim = loopType.getRank() - 1 == (int64_t)forDimIdx; + auto forSteps = + makeIndexArithConstantOp(opBuilder, loc, isLastDim ? steps : 1); + auto numIter = + makeIndexArithConstantOp(opBuilder, loc, loopType.getShape()[forDimIdx]); + VectorType kernelType = + VectorType::get({(int64_t)steps}, loopType.getElementType()); + + // generate transpose for loop + return opBuilder.create( + loc, zero, numIter, forSteps, iterArgs, + [&](OpBuilder &b, Location loc, Value iv, ValueRange loopState) { + inductionVars.emplace_back(iv); + + // inner most body of the loop + if (forDimIdx == rank - 1) { + // transfer read from source tensor + Value source = scOp->getOperand(0); + auto readSourceOp = + cast(source.getDefiningOp()); + SmallVector successorWriteOps; + for (Operation *x : scOp->getUsers()) { + if (isa(x) and opIndexMap.contains(x) and + opIndexMap[x] == opIndexMap[scOp]) { + successorWriteOps.emplace_back(cast(x)); + } + } + SmallVector exprs(loopType.getRank(), AffineExpr()); + bindSymbolsList(b.getContext(), exprs); + SmallVector operands{inductionVars.begin(), + inductionVars.end()}; + SmallVector smallRankShapeVars; + + auto getSmallRankShapeVars = [&](VectorType smallType) { + size_t itrIdx = 0; + SmallVector visitedAxis(rank, false); + while ((int64_t)itrIdx < smallType.getRank()) { + + size_t endShape = getFirstTrueIndex(visitedAxis), dimSize = 1; + if (endShape >= rank) + llvm_unreachable("Invalid shape."); + // skip non corresponding axis + // e.g.: vector<32x16x1x32xbf16> -> vector<1x512x32xbf16> + while (loopType.getShape()[endShape] > + smallType.getShape()[itrIdx]) { + endShape++; + } + const size_t expandIdx = endShape; + while (endShape < rank) { + visitedAxis[endShape] = true; + dimSize *= loopType.getShape()[endShape]; + if ((int64_t)dimSize == smallType.getShape()[itrIdx]) { + break; + } + endShape += 1; + } + const size_t expandSize = endShape - expandIdx + 1; + AffineExpr calculateOffset; + SmallVector offsetVars; + + for (size_t i = 0; i < expandSize; i++) { + size_t startIdx = i + 1; + size_t otherDimsSize = 1; + while (startIdx < expandSize) { + otherDimsSize *= (loopType.getShape()[startIdx + expandIdx]); + startIdx++; + } + AffineExpr dimSize = + getAffineConstantExpr(otherDimsSize, b.getContext()); + if (i == 0) { + calculateOffset = exprs[i] * dimSize; + } else { + calculateOffset = calculateOffset + exprs[i] * dimSize; + } + + offsetVars.emplace_back(inductionVars[i + expandIdx]); + } + AffineMap map = AffineMap::get(0, expandSize, calculateOffset); + + Value offset = + b.createOrFold(loc, map, offsetVars); + smallRankShapeVars.emplace_back(offset); + itrIdx++; + } + }; + + if (loopType == sourceType) { + getSmallRankShapeVars(destType); + } else { + getSmallRankShapeVars(sourceType); + } + + auto padValue = b.create( + loc, b.getZeroAttr(loopType.getElementType())); + + SmallVector inBoundsVal(1, true); + + auto transferReadOp = b.create( + loc, + /*vectorType=*/kernelType, + /*source=*/readSourceOp->getOperands()[0], + /*indices=*/loopType == sourceType ? inductionVars + : smallRankShapeVars, + /*padding=*/padValue, + /*inBounds=*/inBoundsVal); + + SmallVector writeVars = + loopType == sourceType ? 
smallRankShapeVars : inductionVars; + SmallVector writeResults; + for (auto successorWriteOp : successorWriteOps) { + rectifyWriteOperationIndice(&successorWriteOp, writeVars); + auto writeOp = b.create( + loc, transferReadOp->getResults()[0], loopState[0], writeVars, + inBoundsVal); + writeResults.emplace_back(writeOp->getResults()[0]); + } + maybeYieldValue(b, loc, writeResults); + } else { + // outter loop + auto nxtFor = generateShapeCastReadWriteLoop( + b, grpIdx, forDimIdx + 1, steps, loc, inductionVars, loopState); + maybeYieldValue(b, loc, nxtFor->getResults()); + } + }); +} + +void ForLoopGenerator::rectifyWriteOperationIndice( + vector::TransferWriteOp *originalWriteOp, + SmallVectorImpl &writeVars) { + VectorType sucessWriteVectorType = originalWriteOp->getVectorType(); + ShapedType successWriteTensorType = + cast(originalWriteOp->getResultTypes()[0]); + size_t inMutableIdx = + successWriteTensorType.getRank() - sucessWriteVectorType.getRank(); + Operation::operand_range writeIndices = originalWriteOp->getIndices(); + + for (size_t i = 0; i < inMutableIdx; i++) + writeVars[i] = writeIndices[i]; +} + +void ForLoopGenerator::rectifyReadOperationIndice( + vector::TransferReadOp *originalReadOp, VectorType loopType, + ArrayRef inductionVars, SmallVectorImpl &readVars) { + ShapedType readTensorType = + cast(originalReadOp->getSource().getType()); + // currently only broadcast (fuse as transfer_read) will move into more inner + // loop + if (readTensorType.getRank() - 1 >= + (int64_t)getVectorBasedFusion().getOpAnchorPos()[*originalReadOp]) + return; + + int64_t itrIdx = loopType.getRank() - 1; + int64_t readIdx = readTensorType.getRank() - 1; + while (itrIdx >= 0 and readIdx >= 0) { + if (readTensorType.getShape()[readIdx] == loopType.getShape()[itrIdx]) { + readVars[readIdx] = inductionVars[itrIdx]; + readIdx--; + } + itrIdx--; + } +} + +/// generate transpose for loop +scf::ForOp LoopGeneratorImpl::generateShapeCastForLoop(const size_t grpIdx) { + OpBuilder::InsertionGuard g(*getRewriter()); + ShapeCastCanonicalizer &scCanonicalizer = + getShapeCastCanonicalizers()[grpIdx]; + vector::ShapeCastOp &scOp = scCanonicalizer.getCandidateOps()[0]; + + VectorType sourceType = scOp.getSourceVectorType(); + VectorType destType = scOp.getResultVectorType(); + DenseMap &opIndexMap = + getVectorBasedFusion().getOpGroupIndexMap(); + + OpBuilder b(scOp); + SmallVector iterArgs; + SmallVector successorWriteOps; + for (Operation *x : scOp->getUsers()) + if (isa(x) and opIndexMap.contains(x) and + opIndexMap[x] == opIndexMap[scOp]) + successorWriteOps.emplace_back(cast(x)); + + for (auto successorWriteOp : successorWriteOps) + iterArgs.emplace_back(successorWriteOp->getOperands()[1]); + + SmallVector inductionVars; + getRewriter()->setInsertionPoint(scOp); + const size_t groupStep = getVectorBasedFusion().getGroupMaxSteps()[grpIdx]; + + bool isSourceMultiple = + sourceType.getShape()[sourceType.getRank() - 1] % groupStep == 0; + bool isDestMultiple = + destType.getShape()[destType.getRank() - 1] % groupStep == 0; + + scf::ForOp forOp; + bool canVectorizedLoadStore = isDestMultiple and isSourceMultiple and + scCanonicalizer.isReadWriteOnLastDim(); + if (canVectorizedLoadStore) { + forOp = generateShapeCastReadWriteLoop( + b, grpIdx, 0, groupStep, scOp.getLoc(), inductionVars, iterArgs); + } else { + // scalar data movement + forOp = generateShapeCastReadWriteLoop(b, grpIdx, 0, 1, scOp.getLoc(), + inductionVars, iterArgs); + } + for (auto [idx, successorWriteOp] : enumerate(successorWriteOps)) + 
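+    // Each successor transfer_write is replaced by the matching for-loop
+    // result; the order follows the iter_args collected from
+    // successorWriteOps above, so result idx corresponds to write op idx.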
getRewriter()->replaceOp(successorWriteOp, forOp->getResults()[idx]); + + getRewriter()->eraseOp(scOp); + clearCurrentOperationGroup(grpIdx); + return forOp; +} + +/// mark which operation need to set correct for loop var idx +/// due to sometimes we need to chage for loop order like reduce operation. +void ForLoopGenerator::getCurrentGroupIndiceLoopMap( + DenseMap> &indiceLoopMap, + const size_t groupId, Operation *op, + const DenseMap &setIdxMap) { + if (setIdxMap.empty()) { + DenseMap forIdxMap; + VectorType groupVector = + getVectorBasedFusion().getGroupBiggestRankVectorType()[groupId]; + for (size_t i = 0; (int64_t)i < groupVector.getRank(); i++) { + forIdxMap[i] = i; + } + indiceLoopMap[op] = forIdxMap; + return; + } + indiceLoopMap[op] = setIdxMap; +} + +void ForLoopGenerator::clearCurrentOperationGroup(size_t grpIdx) { + std::queue().swap(getVectorBasedFusion().getOpGroups()[grpIdx]); +}; + +scf::ForOp LoopGeneratorImpl::generateTransposeForLoop(const size_t grpIdx) { + OpBuilder::InsertionGuard g(*getRewriter()); + // transpose rank must bigger than 2 + TransposeCanonicalizer &tpCanonicalizer = + getTransposeCanonicalizers()[grpIdx]; + vector::TransposeOp &tpOp = tpCanonicalizer.getCandidateOps()[0]; + + VectorType vtType = tpOp.getResultVectorType(); + size_t rank = vtType.getRank(); + if (rank < 2) { + llvm::llvm_unreachable_internal( + "Wrong transpose operation appear. It's rank must bigger than 2."); + return nullptr; + } + + // permutation contains last dim can use optimizing algorithm + ArrayRef permutation = tpOp.getPermutation(); + DenseSet permuteSet(permutation.begin(), permutation.end()); + bool isTwoDTranspose = tpCanonicalizer.isTwoDTranspose(); + + Operation *successorWriteOp = + getVectorBasedFusion() + .getNextTargetOperationInCurrentGroup( + tpOp, grpIdx); + + DenseMap operandIdxMap; + DenseMap originalOperandMap, operandOriginalMap, resultIdxMap, + forResultOrignalResultMap; + SmallVector iterArgs; + GenerateLoopHelper loopHelper(grpIdx, 0); + prepareForLoopArgs(grpIdx, loopHelper); + + getRewriter()->setInsertionPoint(tpOp); + int tpStep = TransposeCanonicalizer::TRANSPOSE_KERNEL::KERNEL_16X16; + // only contains last dim can use fast transpose algorithm + if ((tpCanonicalizer.getFirstTpIdx() == (rank - 1) or + tpCanonicalizer.getSecondTpIdx() == (rank - 1)) and + isTwoDTranspose) { + scf::ForOp forOp = generateTransposeForLoopWithLastDim( + *getRewriter(), tpStep, tpOp.getLoc(), successorWriteOp, loopHelper); + + getRewriter()->replaceOp(successorWriteOp, forOp); + // clear current group operation + clearCurrentOperationGroup(grpIdx); + return forOp; + } + DenseMap tpAxisMap; + size_t itrIdx = 0; + while (itrIdx < rank) { + tpAxisMap[itrIdx] = permutation[itrIdx]; + itrIdx++; + } + // scalar data movement + scf::ForOp forOp = generateTransposeScalarDataMovement( + *getRewriter(), tpOp.getLoc(), tpAxisMap, loopHelper); + + getRewriter()->replaceOp(successorWriteOp, forOp); + clearCurrentOperationGroup(grpIdx); + return forOp; +} + +template +SmallVector &SpecialOperationCanonicalizer::getCandidateOps() { + return candidateRdOps; +}; + +void MultiReductionCanonicalizer::initReductionAxis() { + auto reductionAxisRange = getCandidateOps()[0].getReductionDims(); + reductionAxis.assign(reductionAxisRange.begin(), reductionAxisRange.end()); + llvm::sort(reductionAxis); +} + +void MultiReductionCanonicalizer::initParallelAxis() { + llvm::SmallDenseSet reductionAxisSet(reductionAxis.begin(), + reductionAxis.end()); + for (int64_t i = 0; i < typeRank; ++i) + if 
(!reductionAxisSet.contains(i)) + parallelAxis.push_back(i); + + llvm::sort(parallelAxis); +} + +int64_t MultiReductionCanonicalizer::getTypeRank() { + auto srcRank = sourceType.getRank(); + typeRank = srcRank; + return srcRank; +} + +void MultiReductionCanonicalizer::getReductionAxisAndParallelAxis() { + initReductionAxis(); + initParallelAxis(); +} + +bool MultiReductionCanonicalizer::hasLastDimReduction() { + llvm::SmallDenseSet reductionAxisSet(reductionAxis.begin(), + reductionAxis.end()); + bool res = false; + if (reductionAxisSet.contains(typeRank - 1)) + res = true; + + haslastDimReduction = res; + return res; +} + +void MultiReductionCanonicalizer::prepareSpecialInfo() { + if (getCandidateOps().empty()) + return; + + sourceType = getCandidateOps()[0].getSourceVectorType(); + accType = cast(getCandidateOps()[0].getAcc().getType()); + getTypeRank(); + getReductionAxisAndParallelAxis(); + hasLastDimReduction(); + + // whether all the reduction axis is 1 + for (auto axis : reductionAxis) { + if (sourceType.getShape()[axis] != 1) { + isEmptyReduction = false; + break; + } + } +}; + +bool TransposeCanonicalizer::isTransposeOnAllOneDim() { + vector::TransposeOp tpOp = getCandidateOps()[0]; + ArrayRef permutation = tpOp.getPermutation(); + VectorType tpVectorType = tpOp.getResultVectorType(); + int64_t itrIdx = 0; + while (itrIdx < tpVectorType.getRank()) { + if (itrIdx == permutation[itrIdx]) { + itrIdx++; + continue; + } + if (tpVectorType.getShape()[itrIdx] != 1) + return false; + + itrIdx++; + } + return true; +} + +bool TransposeCanonicalizer::isTwoDTranspose() { + ArrayRef permutation = getCandidateOps()[0].getPermutation(); + + size_t rank = permutation.size(); + int diffCount = 0; + // get the first transpose axis + size_t itrIdx = 0; + while (itrIdx < rank) { + if ((int64_t)itrIdx != permutation[itrIdx]) + diffCount += 1; + + itrIdx += 1; + } + + itrIdx = 0; + while (itrIdx < rank) { + if (permutation[itrIdx] != (int64_t)itrIdx) { + firstTpIdx = itrIdx; + break; + } + itrIdx++; + } + + itrIdx = 0; + // get the second transpose axis + while (itrIdx < rank) { + if (permutation[itrIdx] == (int64_t)firstTpIdx) { + secondTpIdx = itrIdx; + break; + } + itrIdx++; + } + + const int tpStep = TRANSPOSE_KERNEL::KERNEL_16X16; + VectorType vtType = getCandidateOps()[0].getResultVectorType(); + // currently we only support shape that is an integer multiple of tpStep + if (vtType.getShape()[getFirstTpIdx()] % tpStep != 0 or + vtType.getShape()[getSecondTpIdx()] % tpStep != 0) + return false; + + return diffCount == 2; +} + +bool TransposeCanonicalizer::transposeOnLastDim() { + ArrayRef permutation = getCandidateOps()[0].getPermutation(); + size_t rank = permutation.size(); + if (permutation[rank - 1] != (int64_t)rank - 1) + return false; + + VectorType vtType = getCandidateOps()[0].getResultVectorType(); + return vtType.getShape()[rank - 1] % getVectorStep() == 0; +} + +bool ShapeCastCanonicalizer::isReadWriteOnLastDim() { + vector::ShapeCastOp &shapeCastOp = getCandidateOps()[0]; + VectorType sourceType = shapeCastOp.getSourceVectorType(); + VectorType destType = shapeCastOp.getResultVectorType(); + VectorType smallRankType = + sourceType.getRank() > destType.getRank() ? destType : sourceType; + VectorType largeRankType = + sourceType.getRank() < destType.getRank() ? destType : sourceType; + SmallVector visitedAxis(largeRankType.getRank(), false); + // Map the index of the larger rank shape to the index of the smaller rank + // shape. 
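+  // i.e. shapeIdxMap maps each smaller-rank dim to the larger-rank dims it
+  // covers; e.g. vector<1x512x32xbf16> vs vector<32x16x1x32xbf16> yields
+  // {0 -> [2], 1 -> [0, 1], 2 -> [3]}.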
+ DenseMap> shapeIdxMap; + for (size_t i = 0; (int64_t)i < smallRankType.getRank(); i++) + shapeIdxMap[i] = SmallVector(); + + int64_t itrIdx = 0; + while (itrIdx < smallRankType.getRank()) { + int64_t endShape = getFirstTrueIndex(visitedAxis), dimSize = 1; + if (endShape >= largeRankType.getRank() or endShape < 0) + llvm_unreachable("Invalid endShape."); + + // skip non corresponding axis + // e.g.: vector<32x16x1x32xbf16> -> vector<1x512x32xbf16> + while (largeRankType.getShape()[endShape] > + smallRankType.getShape()[itrIdx]) + endShape++; + + while (endShape < largeRankType.getRank()) { + visitedAxis[endShape] = true; + shapeIdxMap[itrIdx].emplace_back(endShape); + dimSize *= largeRankType.getShape()[endShape]; + if ((int64_t)dimSize == smallRankType.getShape()[itrIdx]) + break; + + endShape++; + } + itrIdx++; + } + // check if the last dim is read write + SmallVector lastDims = shapeIdxMap[smallRankType.getRank() - 1]; + DenseSet set(lastDims.begin(), lastDims.end()); + return set.contains(largeRankType.getRank() - 1); +} + +template +void addDummyInit(SmallVector &canonicalizer, size_t steps = 1) { + canonicalizer.emplace_back(T({}, steps)); +}; + +void LoopGeneratorImpl::clearSpecialOperationCanonicalizers() { + getMultiRdCanonicalizers().clear(); + getBroadcastCanonicalizers().clear(); + getTransposeCanonicalizers().clear(); + getShapeCastCanonicalizers().clear(); +} + +void LoopGeneratorImpl::dummyInitSpecialOperation(size_t steps) { + addDummyInit(getMultiRdCanonicalizers(), steps); + addDummyInit(getBroadcastCanonicalizers(), steps); + addDummyInit(getTransposeCanonicalizers(), steps); + addDummyInit(getShapeCastCanonicalizers(), steps); +} + +void LoopGeneratorImpl::initSpeicalOperationCanonicalizers() { + clearSpecialOperationCanonicalizers(); + SmallVector, 8> &opGroups = + getVectorBasedFusion().getOpGroups(); + for (auto [idx, grp] : llvm::enumerate(opGroups)) { + dummyInitSpecialOperation(getVectorBasedFusion().getGroupMaxSteps()[idx]); + if (grp.empty()) + continue; + + std::queue tempQ(grp); + while (!tempQ.empty()) { + auto op = tempQ.front(); + tempQ.pop(); + TypeSwitch(op) + .Case([&](vector::MultiDimReductionOp + multiReductionOp) { + getMultiRdCanonicalizers().back().getCandidateOps().emplace_back( + cast(op)); + getMultiRdCanonicalizers().back().prepareSpecialOperationInfo(); + }) + .Case([&](vector::TransposeOp tpOp) { + getTransposeCanonicalizers().back().getCandidateOps().emplace_back( + cast(op)); + }) + .Case([&](vector::ShapeCastOp spOp) { + getShapeCastCanonicalizers().back().getCandidateOps().emplace_back( + cast(op)); + }) + .Default([&](Operation *op) {}); + } + } +} + +template +void LoopGeneratorImpl::processSpecialOperation( + T &canonicalizers, const std::function &generateFunc) { + for (auto [groupId, canonicalizer] : llvm::enumerate(canonicalizers)) { + SmallVector &ops = canonicalizer.getCandidateOps(); + if (!ops.empty()) + // generate MultiReduction for loops + generateFunc(groupId); + } +} + +void LoopGeneratorImpl::canonicalizeSpecialOperation() { + + initSpeicalOperationCanonicalizers(); + // traverse all groups + llvm::SmallVector &multiRdCanonicalizers = + getMultiRdCanonicalizers(); + processSpecialOperation, + vector::MultiDimReductionOp>( + multiRdCanonicalizers, [this](const size_t grpIdx) { + (void)generateMultiReductionForLoop(grpIdx); + }); + // generate loop for transpose operation + SmallVector &transposeCanonicalizers = + getTransposeCanonicalizers(); + processSpecialOperation, + vector::TransposeOp>( + transposeCanonicalizers, + 
[this](const size_t grpIdx) { (void)generateTransposeForLoop(grpIdx); });
+  // generate loop for shape cast operation
+  SmallVector &shapeCastCanonicalizers =
+      getShapeCastCanonicalizers();
+  processSpecialOperation,
+                          vector::ShapeCastOp>(
+      shapeCastCanonicalizers,
+      [this](const size_t grpIdx) { (void)generateShapeCastForLoop(grpIdx); });
+}
+
+void VectorOperationCanonicalizer::run() {
+  auto &fusionStrategy = fusion.getGroupOperationFusion();
+  if (kind == CanonicalizerKind::GroupOperations) {
+    fusion.run();
+    // 1. Analyze each operation's operands and results.
+    // We need to know which operation results are used by other operations
+    // so that those results can be passed on correctly. Each operation
+    // result value is mapped to the corresponding for-loop yield result
+    // value, so an operand can be replaced as:
+    //   map(operand, for-loop yield result) -> operand = loop yield result.
+    // All such operation results are collected into this map.
+
+    // 1.a. Which results must the current group produce so that other
+    // operations can use them as operands?
+
+    // Traverse all operations. If an operand of an operation in another
+    // group, or outside any group, is the result of an operation in the
+    // current group, then the current group needs to yield that result. A
+    // `SetVector` stores the results that the current group must generate.
+
+    // 1.b. Which operands does the current group need, and where can they
+    // be obtained?
+
+    // Thanks to 1.a, we know the results generated by the operations of
+    // each group; each such result is produced as a new value through
+    // `scf.yield`. Since the scope of the parent block covers the current
+    // operation, the current operation does not need to pass these for-loop
+    // results through the `iterArgs` of the enclosing for loop. It only
+    // needs to replace its operand with the corresponding for-loop yield
+    // result.
+
+    // However, operations that are not DPS need to be canonicalized first.
+    // Canonicalization means that the operand of such an operation is a
+    // vector that we cannot access directly because it is located in another
+    // block with a different scope. Therefore the vector result is written
+    // into a temporary tensor, and the vector is read back from that tensor
+    // before the current operation operates on it. To do so, `tensor.empty`,
+    // `transfer_write` and `transfer_read` need to be inserted at the target
+    // place.
+    if (enableDebugPrinter) {
+      printGroupOps(fusion.getGroupOperationFusion().getOpGroups());
+      LDBG("___________ before analysis ________________");
+    }
+    fusion.canonicalizeEachOperationGroup();
+    if (enableDebugPrinter) {
+      LDBG("___________ after analysis ________________");
+      printGroupOps(fusion.getGroupOperationFusion().getOpGroups());
+    }
+
+    loopGenerator.setVectorBaseFusion(fusion.getGroupOperationFusion());
+    // special operation canonicalization
+    loopGenerator.canonicalizeSpecialOperation();
+
+    // 2. Generate vectorized IR for each operation group
+    for (size_t idx = 0; idx < fusionStrategy.getOpGroups().size(); ++idx)
+      loopGenerator.generateGroupOpVectorizedIR(idx);
+
+    // 3.
Some IR cleanup work + DominanceInfo domInfo; + eliminateCommonSubExpressions( + *getRewriter(), domInfo, + loopGenerator.getVectorBasedFusion().getFunction()); + } else { + // TODO: need to add directly canonicalize operations logic + llvm_unreachable("Currently not support directly canonicalize operations."); + } +} + +/// +void ForLoopGenerator::setOperationCorrectOperand( + Operation *op, const DenseMap &opPermuationMap, + GenerateLoopHelper &loopHelperParam) { + for (auto [idx, opd] : llvm::enumerate(op->getOperands())) { + if (not loopHelperParam.originalOperandLoopArgsMap.contains(opd)) + continue; + + Value loopArg = loopHelperParam.originalOperandLoopArgsMap[opd]; + if (not loopHelperParam.currentLoopStateIdxMap.contains(loopArg)) + continue; + + op->setOperand( + idx, + loopHelperParam + .loopIterArgs[loopHelperParam.currentLoopStateIdxMap.at(loopArg)]); + } + int operandOffset = isa(op) ? 2 : 1; + if (isReadOrWriteOperation(op)) { + if (not opPermuationMap.contains(op)) + llvm_unreachable("Map must contains operation."); + + auto permutationMap = opPermuationMap.at(op); + + auto dimExpr = permutationMap.getResults(); + for (auto [idx, x] : llvm::enumerate(dimExpr)) { + + if (not isa(x)) + llvm::llvm_unreachable_internal( + "Permuatation map must contains dim expr."); + + int64_t dim = 0; + if (auto d = dyn_cast(x)) { + dim = d.getPosition(); + } else if (auto d = dyn_cast(x)) { + dim = d.getValue(); + } + + ShapedType tensorType = + cast(op->getOperandTypes()[operandOffset - 1]); + int64_t varIdx = dim; + if (tensorType.getRank() > + (int64_t)loopHelperParam.inductionVars.size()) { + int64_t tensorOffset = + tensorType.getRank() - loopHelperParam.inductionVars.size(); + if (dim < tensorOffset) + continue; + + varIdx = dim - tensorOffset; + } + if (loopHelperParam.indiceLoopMap.contains(op)) + op->setOperand( + dim + operandOffset, + loopHelperParam + .inductionVars[loopHelperParam.indiceLoopMap[op][varIdx]]); + else + op->setOperand(dim + operandOffset, + loopHelperParam.inductionVars[varIdx]); + } + if (auto readOp = dyn_cast(op)) { + size_t grpIdx = getVectorBasedFusion().getOpGroupIndexMap()[op]; + VectorType loopType = + getVectorBasedFusion().getGroupBiggestRankVectorType()[grpIdx]; + SmallVector readIndices(readOp.getIndices().begin(), + readOp.getIndices().end()); + rectifyReadOperationIndice(&readOp, loopType, + loopHelperParam.inductionVars, readIndices); + readOp.getIndicesMutable().assign(readIndices); + } + } +} + +scf::ForOp ForLoopGenerator::constructNestedForOp( + const size_t groupIdx, OpBuilder &b, const Location &loc, + ArrayRef dims, GenerateLoopHelper &loopHelper) { + OpBuilder::InsertionGuard g(b); + const int loop_step = getVectorBasedFusion().getGroupMaxSteps()[groupIdx]; + // loop initialization variable + auto zero = makeIndexArithConstantOp(b, loc, 0); + auto forSteps = makeIndexArithConstantOp( + b, loc, loopHelper.anchorIdx == dims.size() - 1 ? loop_step : 1); + auto numIter = makeIndexArithConstantOp(b, loc, dims[loopHelper.anchorIdx]); + + // Create a loop and move vectorized operation into loops. + auto forOp = b.create( + loc, zero, numIter, forSteps, loopHelper.loopIterArgs, + [&](OpBuilder &b, Location loc, Value iv, ValueRange loopState) { + loopHelper.inductionVars.emplace_back(iv); + + // inner most body of the loop + if (loopHelper.anchorIdx == dims.size() - 1) { + std::queue &opQueue = + getVectorBasedFusion().getOpGroups()[groupIdx]; + loopHelper.loopIterArgs = loopState; + // 1. 
get operations in current anchor position + std::queue movingOperation; + getOperationInCurrentAnchor(loopHelper.anchorIdx, opQueue, + movingOperation); + + // 2. rewrite operation as vectorize IR + rewriteOperationAsVectorize(b, groupIdx, &movingOperation); + + // 3. move opeartions to current for block + moveOperationsToCurrentForBody(b, movingOperation, loopHelper); + + getResultInCurrentOps(loopHelper.anchorIdx, groupIdx, movingOperation, + loopHelper.nextAnchorResults, + loopHelper.nextAnchorResultsIdxMap, + loopHelper.nextAnchorResultOrignalResultMap); + maybeYieldValue(b, loc, loopHelper.nextAnchorResults); + } else { + // outter loop + + // 1. move pre-Op to current body + DenseMap nextAnchorArgsIdxMap; + SmallVector nextAnchorArgs; + DenseMap currentOriginalOperandMap = + loopHelper.originalOperandLoopArgsMap; + DenseMap currentOperandOriginalMap = + loopHelper.loopArgsOriginalOperandMap; + DenseMap currentArgsIdxMap = + loopHelper.currentLoopStateIdxMap; + + std::queue movedQueue; + std::queue &opQueue = + getVectorBasedFusion().getOpGroups()[groupIdx]; + SmallVector tmpArgs(loopState); + loopHelper.updateDataBeforePreOpMove(tmpArgs, opQueue, movedQueue); + movePreOpToCurrentAnchor(b, nextAnchorArgsIdxMap, nextAnchorArgs, + loopHelper); + loopHelper.updateDataAfterPreOpMove(nextAnchorArgsIdxMap, + nextAnchorArgs); + loopHelper.anchorIdx += 1; + auto nxtFor = + constructNestedForOp(groupIdx, b, loc, dims, loopHelper); + loopHelper.anchorIdx -= 1; + SmallVector currentArgs(loopState); + + loopHelper.updateCurrentArgsStatus(currentArgsIdxMap, currentArgs, + currentOriginalOperandMap, + currentOperandOriginalMap); + + loopHelper.updateDataBeforePostOpMove( + tmpArgs, currentArgsIdxMap, currentOriginalOperandMap, + currentOperandOriginalMap, nxtFor->getResults(), b.getBlock(), + movedQueue, loopHelper.anchorIdx); + movePostOpToCurrentAnchor(b, loopHelper); + + generateLoopResults(b, loc, loopHelper, nextAnchorArgsIdxMap); + + maybeYieldValue(b, loc, loopHelper.nextAnchorResults); + } + }); + return forOp; +} + +Value setOutGroupOperationOperandResult(Operation *op, + const VectorType &newOperandType, + IRRewriter &rewriter) { + OpBuilder::InsertionGuard g(rewriter); + auto ret = + TypeSwitch(op) + .Case([&](arith::ConstantOp constantOp) { + rewriter.setInsertionPointAfter(op); + Type resultElementType = newOperandType.getElementType(); + auto value = constantOp.getValue(); + Attribute initValueAttr; + + if (isa(value)) { + auto valueType = mlir::dyn_cast(value); + if (valueType.isSplat()) { + if (isa(valueType.getElementType())) + initValueAttr = FloatAttr::get( + resultElementType, + valueType.getSplatValue().convertToDouble()); + else + initValueAttr = IntegerAttr::get( + resultElementType, + valueType.getSplatValue().getSExtValue()); + } else { + // write original vector into tensor + // then we transfer_read from the tensor + llvm_unreachable("Not support non-splat constant value."); + } + } else if (isa(resultElementType)) { + initValueAttr = FloatAttr::get( + resultElementType, cast(value).getValueAsDouble()); + } else { + initValueAttr = IntegerAttr::get( + resultElementType, cast(value).getInt()); + } + + auto cntOp = rewriter.create( + rewriter.getUnknownLoc(), + DenseElementsAttr::get(newOperandType, {initValueAttr})); + return cntOp->getResults()[0]; + }) + .Default([&](Operation *op) { return Value(); }); + return ret; +} + +void setOperationOperandResult(Operation *op, const VectorType &newOperandType, + const DenseMap &opMap, + IRRewriter &rewriter) { + 
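+  // Rewrite every vector-typed operand and result of `op` to the vectorized
+  // type: constant operands defined outside of any fused group are
+  // re-materialized through setOutGroupOperationOperandResult, while operands
+  // produced inside a group only get their type updated in place.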
OpBuilder::InsertionGuard g(rewriter); + for (auto [idx, x] : llvm::enumerate(op->getOperands())) { + if (dyn_cast(x.getType())) { + if (!opMap.contains(x.getDefiningOp())) { + auto result = setOutGroupOperationOperandResult( + x.getDefiningOp(), newOperandType, rewriter); + op->setOperand(idx, result); + } else { + x.setType(newOperandType); + } + } + } + for (auto x : op->getResults()) + if (dyn_cast(x.getType())) + x.setType(newOperandType); +}; + +void ForLoopGenerator::createNewConstantOp( + Operation *srcOp, vector::TransferWriteOp *transferWriteOp, + size_t groupSteps) { + IRRewriter &rewriter = *getRewriter(); + OpBuilder::InsertionGuard g(rewriter); + rewriter.setInsertionPoint(srcOp); + DenseMap &opPermuationMap = + getVectorBasedFusion().getOpPermuationMap(); + + VectorType newOperandType = + getVectorBasedFusion().getTypeHelper().getVectorzedType( + cast(srcOp), groupSteps); + auto srcConstantOp = dyn_cast(srcOp); + Operation *newConstantOp; + if (isa(srcConstantOp.getValue())) { + auto valueType = dyn_cast(srcConstantOp.getValue()); + if (valueType.isSplat()) { + FailureOr res = createArithSplatConstantOp( + rewriter, srcOp->getLoc(), valueType, newOperandType); + if (failed(res)) { + llvm_unreachable("Wrong to create constant op."); + } + newConstantOp = res.value().getDefiningOp(); + } else { + // TODO: need to test not splat value + llvm_unreachable("Can't support not splat constant value."); + } + + newConstantOp->getResult(0).setType(newOperandType); + transferWriteOp->setOperand(0, newConstantOp->getResult(0)); + opPermuationMap.insert( + {*transferWriteOp, transferWriteOp->getPermutationMap()}); + setOpVectorizationPermutationMap( + *transferWriteOp, *getRewriter(), + cast(transferWriteOp->getResults()[0].getType()), + transferWriteOp->getPermutationMap()); + return; + } + llvm_unreachable("Can't support not DenseElementsAttr constant."); +} + +/// Rewrite the operations in the group to vectorized form. +void ForLoopGenerator::rewriteOperationAsVectorize( + OpBuilder &rewriter, size_t groupId, const std::queue *queue, + const size_t vectorizeStep) { + const std::queue groupOps = + !queue ? getVectorBasedFusion().getOpGroups()[groupId] : *queue; + + const DenseMap &opMap = + getVectorBasedFusion().getOpGroupIndexMap(); + DenseMap &opPermuationMap = + getVectorBasedFusion().getOpPermuationMap(); + std::queue transformQueue(groupOps); + size_t groupSteps = vectorizeStep == 0 + ? 
getVectorBasedFusion().getGroupMaxSteps()[groupId] + : vectorizeStep; + + while (!transformQueue.empty()) { + Operation *op = transformQueue.front(); + transformQueue.pop(); + VectorType newOperandType = + getVectorBasedFusion().getTypeHelper().getVectorzedType(op, groupSteps); + auto lowerResult = + TypeSwitch(op) + .Case( + [&](vector::TransferWriteOp transferWriteOp) { + Operation *srcOp = + transferWriteOp->getOperand(0).getDefiningOp(); + if (isa(srcOp)) { + createNewConstantOp(srcOp, &transferWriteOp, groupSteps); + } else { + opPermuationMap.insert( + {transferWriteOp, transferWriteOp.getPermutationMap()}); + transferWriteOp->getOperand(0).setType(newOperandType); + + setOpVectorizationPermutationMap( + transferWriteOp, rewriter, + cast( + transferWriteOp->getResult(0).getType()), + transferWriteOp.getPermutationMap()); + } + + return success(); + }) + .Case( + [&](vector::TransferReadOp transferReadOp) { + opPermuationMap.insert( + {transferReadOp, transferReadOp.getPermutationMap()}); + transferReadOp->getResult(0).setType(newOperandType); + setOpVectorizationPermutationMap( + transferReadOp, rewriter, + cast(transferReadOp.getSource().getType()), + transferReadOp.getPermutationMap()); + + return success(); + }) + .Case( + [&](vector::MultiDimReductionOp multiReductionOp) { + multiReductionOp.dump(); + llvm_unreachable("It should not appear this operation."); + return failure(); + }) + .Case([&](Operation *extfOp) { + extfOp->getResult(0).setType(newOperandType); + return success(); + }) + .Default([&](Operation *op) { + if (isSpecialOp(op)) { + op->dump(); + llvm_unreachable("It should not appear this operation."); + return failure(); + } + setOperationOperandResult(op, newOperandType, opMap, + *getRewriter()); + return success(); + }); + if (failed(lowerResult)) { + LDBG("Failed to rewrite operation: " << *op << "\n"); + llvm_unreachable("Failed to rewrite operation"); + } + } +} + +void GroupOperationFusionImpl::removeOpInCurrentGroups(size_t grpIdx, + Operation *op, + Operation *replacedOp) { + std::queue tmpOpQueue( + getGroupOperationFusion().getOpGroups()[grpIdx]); + std::queue newOpQueue; + while (!tmpOpQueue.empty()) { + auto curOp = tmpOpQueue.front(); + tmpOpQueue.pop(); + if (curOp != op) { + newOpQueue.push(curOp); + continue; + } + getGroupOperationFusion().getOpGroupIndexMap().erase(curOp); + getGroupOperationFusion().getOpAnchorPos().erase(curOp); + } + getGroupOperationFusion().getOpGroups()[grpIdx] = newOpQueue; + + // erase and replace the operation + SmallVector usesOp(op->getUsers().begin(), op->getUsers().end()); + getRewriter()->replaceOp(op, replacedOp); +} + +void GroupOperationFusionImpl::updateOpGroupInfo(size_t grpIdx) { + std::queue tmpOpQueue( + getGroupOperationFusion().getOpGroups()[grpIdx]); + // dummy init + VectorType currentMaxRankType = + getOperationMaxVectorType(tmpOpQueue.front()).value(); + getGroupOperationFusion().getGroupBiggestRankVectorType()[grpIdx] = + currentMaxRankType; + + while (!tmpOpQueue.empty()) { + auto curOp = tmpOpQueue.front(); + tmpOpQueue.pop(); + VectorType type = getOperationMaxVectorType(curOp).value(); + if (type.getRank() > currentMaxRankType.getRank()) + getGroupOperationFusion().getGroupBiggestRankVectorType()[grpIdx] = type; + } +} + +void GroupOperationFusionImpl::updateOpOperandResultInGroups( + size_t opGid, Operation *op, const Value &init, const Value &result) { + std::queue tmpOpQueue( + getGroupOperationFusion().getOpGroups()[opGid]); + std::queue newOpQueue; + while (!tmpOpQueue.empty()) { + auto curOp = 
tmpOpQueue.front(); + tmpOpQueue.pop(); + + if (curOp != op) { + newOpQueue.push(curOp); + continue; + } + + if (!failed(getOperationVectorType(init.getDefiningOp()))) { + newOpQueue.push(init.getDefiningOp()); + getGroupOperationFusion().getOpGroupIndexMap()[init.getDefiningOp()] = + opGid; + getGroupOperationFusion().getOpAnchorPos()[init.getDefiningOp()] = + getGroupOperationFusion().getOpAnchorPos()[op]; + } + newOpQueue.push(op); + + if (result && !failed(getOperationVectorType(result.getDefiningOp()))) { + newOpQueue.push(result.getDefiningOp()); + getGroupOperationFusion().getOpGroupIndexMap()[result.getDefiningOp()] = + opGid; + getGroupOperationFusion().getOpAnchorPos()[result.getDefiningOp()] = + getGroupOperationFusion().getOpGroupIndexMap()[op]; + } + } + getGroupOperationFusion().getOpGroups()[opGid] = newOpQueue; +} + +void GroupOperationFusionImpl::generateEmptyTensorAndWrite( + Operation *sourceOp, + DenseMap> &srcOpCanoniclizedMap, + size_t anchorPos, ReturnTypeKind retKind, + DenseMap &visitedOperation) { + DenseMap &opGroupIndexMap = + getGroupOperationFusion().getOpGroupIndexMap(); + SmallVector, 8> &groupOpInitArgs = + getGroupOperationFusion().getGroupOpInitArgs(); + SmallVector>, 8> + &groupOpResults = getGroupOperationFusion().getGroupOpResults(); + size_t sourceOpGid = opGroupIndexMap[sourceOp]; + + auto [tsr, writeOpresult] = + canonicalizeSourceOperation(sourceOp, visitedOperation, *getRewriter()); + auto writeOp = writeOpresult.getDefiningOp(); + srcOpCanoniclizedMap.insert({sourceOp, {tsr, writeOpresult}}); + updateOpOperandResultInGroups(sourceOpGid, sourceOp, tsr, writeOpresult); + groupOpInitArgs[sourceOpGid].insert(tsr); + groupOpResults[sourceOpGid].insert({writeOpresult, {retKind, anchorPos}}); + // write opeartion anchor pos is same with current operation + getGroupOperationFusion().getOpAnchorPos()[writeOp] = + writeOp.getVectorType().getRank() - 1; + getGroupOperationFusion().getOpPermuationMap()[writeOp] = + writeOp.getPermutationMap(); +} + +void GroupOperationFusionImpl::specialOperationRectify( + DenseMap &visitedOperation) { + auto &opGroups = getGroupOperationFusion().getOpGroups(); + + for (auto [idx, grp] : llvm::enumerate(opGroups)) { + std::queue tmpQueue(grp); + std::queue newQueue; + while (!tmpQueue.empty()) { + auto op = tmpQueue.front(); + tmpQueue.pop(); + // remain transfer read operation to do the broadcast fusion + if (isa(op)) { + auto srcOp = op->getOperand(0).getDefiningOp(); + if (not isa(srcOp)) + llvm_unreachable("Must be read operation."); + + // only have write operation, otherwise the group size will bigger + // than 1. 
Because the last operation is always a write operation in + // each group + getGroupOperationFusion().getOpAnchorPos()[srcOp] = + getGroupOperationFusion().getOpAnchorPos()[op]; + + getRewriter()->replaceOp(op, srcOp); + continue; + } + // anchor of multidim reduction rectify + if (isa(op)) { + auto accSourceOp = op->getOperand(1).getDefiningOp(); + getGroupOperationFusion().getOpAnchorPos()[accSourceOp] = + getOperationVectorType(accSourceOp)->getRank() - 1; + } + newQueue.push(op); + } + getGroupOperationFusion().getOpGroups()[idx] = newQueue; + } +} + +void GroupOperationFusionImpl::updateReturnResultKind(Operation *sourceOp, + size_t sourceOpGid, + ReturnTypeKind rtKind) { + SmallVector>, 8> + &groupOpResults = getGroupOperationFusion().getGroupOpResults(); + DenseMap &OpAnchorPos = + getGroupOperationFusion().getOpAnchorPos(); + + Value sourceResult = sourceOp->getResults()[0]; + if (srcOpCanoniclizedMap.contains(sourceOp)) + sourceResult = srcOpCanoniclizedMap[sourceOp].second; + + size_t srcOpAnchor = groupOpResults[sourceOpGid][sourceResult].second; + ReturnTypeKind prevRtKind = groupOpResults[sourceOpGid][sourceResult].first; + srcOpAnchor = std::min(srcOpAnchor, OpAnchorPos[sourceOp]); + if (prevRtKind != rtKind) { + groupOpResults[sourceOpGid][sourceResult] = + std::make_pair(ReturnTypeKind::RT_Both, srcOpAnchor); + return; + } + if (rtKind == ReturnTypeKind::RT_InGroup) + groupOpResults[sourceOpGid][sourceResult] = + std::make_pair(rtKind, srcOpAnchor); +} + +void GroupOperationFusionImpl::replaceConstantOpAsNewOp(Operation *op, + Operation *sourceOp, + size_t operandIdx) { + OpBuilder::InsertionGuard g(*getRewriter()); + DenseMap &opGroupIndexMap = + getGroupOperationFusion().getOpGroupIndexMap(); + if (!opGroupIndexMap.contains(op)) { + return; + } + // TODO: add more operation to this case, write a constant value need + // to do this + if (isa(op) and operandIdx == 0) + return; + + if (isa(op)) { + if (operandIdx == 1) { + // accumulate value, just empty tensor is okay + auto resultTensor = + getOperationResultTensor(sourceOp, visitedOperation, *getRewriter()); + auto opInit = canonicalizeCurrentOperation(op, resultTensor, operandIdx); + updateOpOperandResultInGroups(opGroupIndexMap[op], op, opInit); + return; + } + // source operation is the value + llvm::llvm_unreachable_internal( + "Need to add reduce constant operation optimization."); + } + + auto constantOp = cast(sourceOp); + getRewriter()->setInsertionPoint(constantOp); + size_t groupSteps = + getGroupOperationFusion().getGroupMaxSteps()[opGroupIndexMap[op]]; + + if (isa(constantOp.getValue())) { + VectorType newOperandType = + getGroupOperationFusion().getTypeHelper().getVectorzedType(op, + groupSteps); + auto valueType = cast(constantOp.getValue()); + if (valueType.isSplat()) { + FailureOr res = createArithSplatConstantOp( + *getRewriter(), constantOp->getLoc(), valueType, newOperandType); + if (failed(res)) + llvm_unreachable("Wrong to create constant op."); + + op->setOperand(operandIdx, res.value()); + // transfer read operation just use the constant value to do + // calculation, don't need to read. 
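+      // the now-constant transfer_read is erased from the group below and its
+      // users are rewired to the splat vector constant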
+ if (isa(op) and operandIdx == 0) + removeOpInCurrentGroups(opGroupIndexMap[op], op, + op->getOperand(0).getDefiningOp()); + return; + } + llvm_unreachable("Can't support not splat constant value."); + } +} + +void GroupOperationFusionImpl::makeSourceOpWriteResultToTensor( + Operation *sourceOp, size_t sourceOpGid, ReturnTypeKind rtKind) { + DenseMap &OpAnchorPos = + getGroupOperationFusion().getOpAnchorPos(); + SmallVector, 8> &groupOpInitArgs = + getGroupOperationFusion().getGroupOpInitArgs(); + SmallVector>, 8> + &groupOpResults = getGroupOperationFusion().getGroupOpResults(); + + if (!srcOpCanoniclizedMap.contains(sourceOp)) { + // get write operation + if (Operation *writeOp = + getGroupOperationFusion() + .getNextTargetOperationInCurrentGroup( + sourceOp, sourceOpGid)) { + auto writeOpresult = writeOp->getResults()[0]; + auto originalWriteTensor = writeOp->getOperands()[1]; + // find original tensor.empty operation + Value writeTensor = + findOriginalTensor(originalWriteTensor, sourceOp->getBlock()); + if (writeTensor != originalWriteTensor) + getGroupOperationFusion() + .getOperandOriginalValue()[originalWriteTensor] = writeTensor; + + srcOpCanoniclizedMap.insert({sourceOp, {writeTensor, writeOpresult}}); + groupOpInitArgs[sourceOpGid].insert(writeTensor); + groupOpResults[sourceOpGid].insert( + {writeOpresult, {rtKind, OpAnchorPos[sourceOp]}}); + return; + } + generateEmptyTensorAndWrite(sourceOp, srcOpCanoniclizedMap, + OpAnchorPos[sourceOp], rtKind, + visitedOperation); + return; + } + // udpate result return type + updateReturnResultKind(srcOpCanoniclizedMap[sourceOp].second.getDefiningOp(), + sourceOpGid, rtKind); +} + +void GroupOperationFusionImpl::GroupOperationReturnResultProcess( + size_t sourceOpGid, Operation *sourceOp, Operation *op, size_t operandIdx, + bool inSameGroupNeedReturn) { + ReturnTypeKind rtKind = inSameGroupNeedReturn ? ReturnTypeKind::RT_InGroup + : ReturnTypeKind::RT_OutGroup; + SmallVector, 8> &groupOpInitArgs = + getGroupOperationFusion().getGroupOpInitArgs(); + + DenseMap &opGroupIndexMap = + getGroupOperationFusion().getOpGroupIndexMap(); + // update init iterargs + auto dstRet = getOperationOperateTensor(sourceOp); + // need to generate tensor.emtpy and vector.transfer_write, write + // operand to tensor and read operand from the tensor, generate + // vector.transfer_read + if (failed(dstRet)) { + // already generate result tensor, special operation do the + // transformation by itself + if (isSpecialOp(sourceOp) and inSameGroupNeedReturn and + not isBroadcastOp(sourceOp)) + return; + makeSourceOpWriteResultToTensor(sourceOp, sourceOpGid, rtKind); + auto opInit = canonicalizeCurrentOperation( + op, srcOpCanoniclizedMap[sourceOp].second, operandIdx); + updateOpOperandResultInGroups(opGroupIndexMap[op], op, opInit); + return; + } + // if source operation is transfer_read, we need to generate a + // same transfer_read operation like source operation. 
+ if (isa(sourceOp)) { + auto transferReadOp = cast(sourceOp); + auto opInit = canonicalizeCurrentOperation(op, dstRet.value(), operandIdx, + &transferReadOp); + updateOpOperandResultInGroups(opGroupIndexMap[op], op, opInit); + return; + } + // transfer write operation + groupOpInitArgs[sourceOpGid].insert(dstRet.value()); + auto writeTensor = sourceOp->getOperand(1); + if (dstRet.value() != writeTensor) + getGroupOperationFusion().getOperandOriginalValue()[writeTensor] = + dstRet.value(); + + updateReturnResultKind(sourceOp, sourceOpGid, rtKind); +} + +void GroupOperationFusionImpl::broadcastFromElements(Operation *op, + size_t grpIdx) { + OpBuilder::InsertionGuard g(*getRewriter()); + if (not isa(op)) + llvm_unreachable("Must be broadcast operation."); + + if (not isa(op->getOperandTypes()[0])) { + auto inputBcastOp = cast(op); + size_t steps = getGroupOperationFusion().getGroupMaxSteps()[grpIdx]; + getRewriter()->setInsertionPoint(op); + VectorType newOperandType = + getGroupOperationFusion().getTypeHelper().getVectorzedType(op, steps); + if (isa_and_nonnull(op->getOperand(0).getDefiningOp())) { + auto constantOp = cast(op); + SmallVector shapes(1, steps); + auto dataType = mlir::VectorType::get( + shapes, inputBcastOp.getResultVectorType().getElementType()); + + FailureOr res = createArithSplatConstantOp( + *getRewriter(), op->getLoc(), + DenseElementsAttr::get(dataType, constantOp.getValue()), + newOperandType); + if (failed(res)) + llvm_unreachable("Wrong to create constant op."); + removeOpInCurrentGroups(grpIdx, op, res.value().getDefiningOp()); + + } else { + auto bcastOp = getRewriter()->create( + op->getLoc(), newOperandType, op->getOperands()[0]); + removeOpInCurrentGroups(grpIdx, op, bcastOp); + std::function candidateFunc = isBroadcastOp; + moveOpsFrontOrBack(&getGroupOperationFusion().getFunction(), + *getRewriter(), OPPRIORITY::THIRD, OPPRIORITY::THIRD); + } + } +} + +void GroupOperationFusionImpl::scalarOperandFromElements() { + SmallVector, 8> &opGroups = + getGroupOperationFusion().getOpGroups(); + size_t idx = 0; + for (auto &grp : opGroups) { + std::queue tmpQueue(grp); + while (!tmpQueue.empty()) { + auto op = tmpQueue.front(); + tmpQueue.pop(); + TypeSwitch(op) + .Case([&](vector::BroadcastOp &bcOp) { + broadcastFromElements(bcOp, idx); + }) + .Default([&](Operation *op) { return; }); + } + idx++; + } +} + +void GroupOperationFusionImpl::canonicalizeEachOperationGroup() { + // record the operation which has been moved + DenseSet movedOperationSet; + // record the operation's visited order, inorder to ensure set + // correct operand + size_t opCounter = 0; + DenseMap &opGroupIndexMap = + getGroupOperationFusion().getOpGroupIndexMap(); + DenseMap &OpAnchorPos = + getGroupOperationFusion().getOpAnchorPos(); + func::FuncOp func = getGroupOperationFusion().getFunction(); + + analysisGroupMaxSteps(); + + func.walk([&](Operation *op) { + visitedOperation.insert({op, opCounter++}); + + for (auto [idx, opd] : llvm::enumerate(op->getOperands())) { + Operation *sourceOp = opd.getDefiningOp(); + if (opGroupIndexMap.contains(sourceOp)) { + auto sourceOpGid = opGroupIndexMap[sourceOp]; + bool notInSameGroup = + opGroupIndexMap.contains(op) && sourceOpGid != opGroupIndexMap[op]; + bool outOfGroup = !opGroupIndexMap.contains(op); + // Different anchor in same group and source operation is in inner + // loop, we need to get source operation's result + bool inSameGroupNeedReturn = !outOfGroup and !notInSameGroup and + OpAnchorPos[sourceOp] > OpAnchorPos[op]; + + if (notInSameGroup or 
outOfGroup or inSameGroupNeedReturn) + GroupOperationReturnResultProcess(sourceOpGid, sourceOp, op, idx, + inSameGroupNeedReturn); + + continue; + } + if (isa_and_nonnull(sourceOp)) + replaceConstantOpAsNewOp(op, sourceOp, idx); + } + }); + analysisEmptyGroup(); + scalarOperandFromElements(); + specialOperationRectify(visitedOperation); + LDBG("Complete analysis group operation results\n"); +} + +void ForLoopGenerator::rectifyGroupOperands(size_t currentGroupId, + Value originalResult, + Value forResult) { + size_t totalGroupSize = getVectorBasedFusion().getOpGroups().size(); + size_t startGroup = currentGroupId; + DenseMap &operandOriginalMap = + getVectorBasedFusion().getOperandOriginalValue(); + if (operandOriginalMap.contains(originalResult)) + originalResult = operandOriginalMap[originalResult]; + while (startGroup < totalGroupSize) { + SetVector &operandVector = + getVectorBasedFusion().getGroupOpInitArgs()[startGroup++]; + if (not operandVector.contains(originalResult)) + continue; + SetVector replacedVector; + + for (auto v : operandVector) { + if (v == originalResult) { + replacedVector.insert(forResult); + continue; + } + replacedVector.insert(v); + } + getVectorBasedFusion().getGroupOpInitArgs()[startGroup - 1] = + replacedVector; + } +} + +mlir::FailureOr ForLoopGenerator::generateVectorizedForLoop( + const size_t groupId, IRRewriter &rewriter, VectorType vectorType) { + OpBuilder::InsertionGuard g(rewriter); + // prepare for loop iterargs + GenerateLoopHelper loopHelper(groupId, 0); + prepareForLoopArgs(groupId, loopHelper); + + ArrayRef shapes = vectorType.getShape(); + // generate for loop + auto forOp = constructNestedForOp(groupId, rewriter, rewriter.getUnknownLoc(), + shapes, loopHelper); + replaceOpUsersWithForLoopResult(forOp, groupId, loopHelper.nextAnchorResults, + loopHelper.nextAnchorResultsIdxMap, + loopHelper.nextAnchorResultOrignalResultMap); + + return forOp; +} + +bool LoopGeneratorImpl::isGroupHasSpecialOperation(const size_t grpIdx) { + auto &rdCanonicalizer = getMultiRdCanonicalizers()[grpIdx]; + auto &tpCanonicalizer = getTransposeCanonicalizers()[grpIdx]; + auto &shapeCastCanonicalizer = getShapeCastCanonicalizers()[grpIdx]; + return !rdCanonicalizer.getCandidateOps().empty() or + !tpCanonicalizer.getCandidateOps().empty() or + !shapeCastCanonicalizer.getCandidateOps().empty(); +} + +void LoopGeneratorImpl::generateGroupOpVectorizedIR(const int idx) { + OpBuilder::InsertionGuard g(*getRewriter()); + + auto &grp = getVectorBasedFusion().getOpGroups()[idx]; + if (grp.empty()) { + LDBG("Current operation Group is empty."); + return; + } + // TODO: special operation better fusion + if (isGroupHasSpecialOperation(idx)) + return; + + VectorType groupType = + getVectorBasedFusion().getGroupBiggestRankVectorType()[idx]; + // 1. Rewrite operation as vectorized form + // 2. Generate loop + getRewriter()->setInsertionPointAfter(grp.back()); + auto forOp = generateVectorizedForLoop(idx, *getRewriter(), groupType); + // special operation do not need to change anything + if (failed(forOp)) + return; + + moveLoopInvariantCode(forOp.value()); +} + +/// Pass that lower to physical vector. 
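+/// It groups vectorizable operations using the vector-based fusion analysis,
+/// rewrites each group into nested scf.for loops whose innermost step is
+/// derived from the hardware vector width reported by
+/// CPUTargetDescriptionAnalysis, and lowers any remaining vector.transpose
+/// ops through the Shuffle16x16 lowering pattern.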
+struct CPUPhysicalRegisterPass + : public impl::CPUPhysicalRegisterPassBase { + + void runOnOperation() final { + // + auto *ctx = &getContext(); + RewritePatternSet patterns(ctx); + auto func = getOperation(); + IRRewriter rewriter(func); + + if (hasNotSupportOperation(&func)) { + LDBG("Not support operation appears in current function."); + return; + } + // canonicalize vector operation, default use vector-based fusion + // strategy. + HardWareInfo hwInfo; + CPUTargetDescriptionAnalysis sysDesc = + getAnalysis(); + hwInfo.vectorWidth = sysDesc.getMaxVectorWidth(); + VectorOperationCanonicalizer canonicalizer( + func, hwInfo, &rewriter, CanonicalizerKind::GroupOperations); + + // affineApply operation is always used by other operations. + moveOpsFrontOrBack(&func, rewriter, OPPRIORITY::FIRST, OPPRIORITY::SECOND); + + canonicalizer.run(); + + moveOpsFrontOrBack(&func, rewriter, OPPRIORITY::LAST, OPPRIORITY::LAST); + + // transpose kernel + vector::VectorTransformsOptions transposeOptions = + vector::VectorTransformsOptions(); + transposeOptions.vectorTransposeLowering = + vector::VectorTransposeLowering::Shuffle16x16; + vector::populateVectorTransposeLoweringPatterns(patterns, transposeOptions); + + (void)applyPatternsAndFoldGreedily(func, std::move(patterns)); + } +}; + +} // namespace gc +} // namespace mlir \ No newline at end of file diff --git a/lib/gc/Transforms/Pipeline.cpp b/lib/gc/Transforms/Pipeline.cpp index 6fdc445cf..0ccbf0495 100644 --- a/lib/gc/Transforms/Pipeline.cpp +++ b/lib/gc/Transforms/Pipeline.cpp @@ -77,19 +77,20 @@ void populateTensorPasses(mlir::OpPassManager &pm) { // scf + arith + math + vector + tensor + linalg.brgemm void populateVectorPasses(mlir::OpPassManager &pm) { + pm.addNestedPass(createLowerToTileVector()); // Do promotion for math / arith ops pm.addNestedPass(math::createMathLegalizeToF32()); // sourceTypeStrs can be extended arith::ArithEmulateUnsupportedFloatsOptions options; - std::array typeStr = {"bf16"}; + std::array typeStr{"bf16"}; options.sourceTypeStrs = typeStr; options.targetTypeStr = "f32"; pm.addNestedPass( arith::createArithEmulateUnsupportedFloats(options)); // Bf16 cast elimilation pass pm.addNestedPass(mlir::createCanonicalizerPass()); - // oneDNN graph spec - pm.addNestedPass(arith::createArithExpandOpsPass()); + pm.addNestedPass(createCPUPhysicalRegisterPass()); + pm.addPass(createLoopInvariantCodeMotionPass()); // todo: lower to physical vector pass, device dependent pass populateCleanUpPasses(pm); } @@ -150,6 +151,8 @@ void populateCPURuntimePasses(mlir::OpPassManager &pm) { } void populateLoweringToLLVMPasses(mlir::OpPassManager &pm) { + pm.addPass(createConvertVectorToSCFPass()); + pm.addPass(createConvertVectorToLLVMPass()); pm.addPass(createLowerAffinePass()); pm.addPass(createFinalizeMemRefToLLVMConversionPass()); pm.addPass(createConvertSCFToCFPass()); diff --git a/lib/gc/Transforms/Utils/CMakeLists.txt b/lib/gc/Transforms/Utils/CMakeLists.txt index 94b700435..3b045fca0 100644 --- a/lib/gc/Transforms/Utils/CMakeLists.txt +++ b/lib/gc/Transforms/Utils/CMakeLists.txt @@ -2,6 +2,8 @@ gc_add_mlir_library(GcUtilsIR MatcherUtils.cpp StructuredOpMatcher.cpp ValueUtils.cpp + VectorUtils.cpp + NumericUtils.cpp DEPENDS MLIRLinalgDialect diff --git a/lib/gc/Transforms/Utils/NumericUtils.cpp b/lib/gc/Transforms/Utils/NumericUtils.cpp new file mode 100644 index 000000000..e1af31994 --- /dev/null +++ b/lib/gc/Transforms/Utils/NumericUtils.cpp @@ -0,0 +1,164 @@ +//===- NumericUtils.cpp - numeric utilities ---------------------*- C++ 
-*-===// +// +// This file is licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +#include "gc/Transforms/Utils/NumericUtils.h" + +namespace mlir { +namespace gc { + +const uint32_t kF32MantiBits = 23; +const uint32_t kF32HalfMantiBitDiff = 13; +const uint32_t kF32HalfBitDiff = 16; +const Float32Bits kF32Magic = {113 << kF32MantiBits}; +const uint32_t kF32HalfExpAdjust = (127 - 15) << kF32MantiBits; +const uint32_t kF32BfMantiBitDiff = 16; + +/// Constructs the 16 bit representation for a half precision value from a float +/// value. This implementation is adapted from Eigen. +uint16_t float2half(float floatValue) { + const Float32Bits inf = {255 << kF32MantiBits}; + const Float32Bits f16max = {(127 + 16) << kF32MantiBits}; + const Float32Bits denormMagic = {((127 - 15) + (kF32MantiBits - 10) + 1) + << kF32MantiBits}; + uint32_t signMask = 0x80000000u; + uint16_t halfValue = static_cast(0x0u); + Float32Bits f; + f.f = floatValue; + uint32_t sign = f.u & signMask; + f.u ^= sign; + + if (f.u >= f16max.u) { + const uint32_t halfQnan = 0x7e00; + const uint32_t halfInf = 0x7c00; + // Inf or NaN (all exponent bits set). + halfValue = (f.u > inf.u) ? halfQnan : halfInf; // NaN->qNaN and Inf->Inf + } else { + // (De)normalized number or zero. + if (f.u < kF32Magic.u) { + // The resulting FP16 is subnormal or zero. + // + // Use a magic value to align our 10 mantissa bits at the bottom of the + // float. As long as FP addition is round-to-nearest-even this works. + f.f += denormMagic.f; + + halfValue = static_cast(f.u - denormMagic.u); + } else { + uint32_t mantOdd = + (f.u >> kF32HalfMantiBitDiff) & 1; // Resulting mantissa is odd. + + // Update exponent, rounding bias part 1. The following expressions are + // equivalent to `f.u += ((unsigned int)(15 - 127) << kF32MantiBits) + + // 0xfff`, but without arithmetic overflow. + f.u += 0xc8000fffU; + // Rounding bias part 2. + f.u += mantOdd; + halfValue = static_cast(f.u >> kF32HalfMantiBitDiff); + } + } + + halfValue |= static_cast(sign >> kF32HalfBitDiff); + return halfValue; +} + +/// Converts the 16 bit representation of a half precision value to a float +/// value. This implementation is adapted from Eigen. +float half2float(uint16_t halfValue) { + const uint32_t shiftedExp = + 0x7c00 << kF32HalfMantiBitDiff; // Exponent mask after shift. + + // Initialize the float representation with the exponent/mantissa bits. + Float32Bits f = { + static_cast((halfValue & 0x7fff) << kF32HalfMantiBitDiff)}; + const uint32_t exp = shiftedExp & f.u; + f.u += kF32HalfExpAdjust; // Adjust the exponent + + // Handle exponent special cases. + if (exp == shiftedExp) { + // Inf/NaN + f.u += kF32HalfExpAdjust; + } else if (exp == 0) { + // Zero/Denormal? + f.u += 1 << kF32MantiBits; + f.f -= kF32Magic.f; + } + + f.u |= (halfValue & 0x8000) << kF32HalfBitDiff; // Sign bit. + return f.f; +} + +// Constructs the 16 bit representation for a bfloat value from a float value. +// This implementation is adapted from Eigen. +uint16_t float2bfloat(float floatValue) { + if (std::isnan(floatValue)) + return std::signbit(floatValue) ? 0xFFC0 : 0x7FC0; + + Float32Bits floatBits; + floatBits.f = floatValue; + uint16_t bfloatBits; + + // Least significant bit of resulting bfloat. 
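+  // Round to nearest even: add 0x7fff plus the lsb computed below, then
+  // truncate. e.g. 1.0f = 0x3F800000 -> lsb 0, 0x3F800000 + 0x7FFF =
+  // 0x3F807FFF, and 0x3F807FFF >> 16 = 0x3F80, which is 1.0 in bfloat16.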
+  uint32_t lsb = (floatBits.u >> kF32BfMantiBitDiff) & 1;
+  uint32_t roundingBias = 0x7fff + lsb;
+  floatBits.u += roundingBias;
+  bfloatBits = static_cast(floatBits.u >> kF32BfMantiBitDiff);
+  return bfloatBits;
+}
+
+// Converts the 16 bit representation of a bfloat value to a float value. This
+// implementation is adapted from Eigen.
+float bfloat2float(uint16_t bfloatBits) {
+  Float32Bits floatBits;
+  floatBits.u = static_cast(bfloatBits) << kF32BfMantiBitDiff;
+  return floatBits.f;
+}
+
+std::variant numeric_limits_minimum(Type type) {
+  Type t1 = getElementTypeOrSelf(type);
+  if (t1.isF32()) {
+    return -std::numeric_limits::infinity();
+  } else if (t1.isBF16()) {
+    return bfloat2float(float2bfloat(-std::numeric_limits::infinity()));
+  } else if (t1.isF16()) {
+    return (float)half2float(
+        float2half(-std::numeric_limits::infinity()));
+  } else if (t1.isSignedInteger(8)) {
+    return int64_t(-128);
+  } else if (t1.isSignedInteger(32)) {
+    return int64_t(std::numeric_limits::min());
+  } else if (t1.isSignlessInteger(8) or t1.isSignlessInteger(32)) {
+    return int64_t(0);
+  } else {
+    llvm_unreachable("unsupported data type");
+    return (int64_t)0;
+  }
+}
+
+std::variant numericLimitsMaximum(Type type) {
+  Type t1 = getElementTypeOrSelf(type);
+  if (t1.isF32()) {
+    return std::numeric_limits::infinity();
+  } else if (t1.isBF16()) {
+    return bfloat2float(float2bfloat(std::numeric_limits::infinity()));
+  } else if (t1.isF16()) {
+    return (float)half2float(
+        float2half(std::numeric_limits::infinity()));
+  } else if (t1.isSignedInteger(8)) {
+    return int64_t(127);
+  } else if (t1.isSignedInteger(32)) {
+    return int64_t(std::numeric_limits::max());
+  } else if (t1.isSignlessInteger(8)) {
+    return int64_t(255);
+  } else if (t1.isSignlessInteger(32)) {
+    return int64_t(std::numeric_limits::max());
+  } else {
+    llvm_unreachable("unsupported data type");
+    return (int64_t)0;
+  }
+}
+
+} // namespace gc
+} // namespace mlir
\ No newline at end of file
diff --git a/lib/gc/Transforms/Utils/VectorUtils.cpp b/lib/gc/Transforms/Utils/VectorUtils.cpp
new file mode 100644
index 000000000..2751b04e2
--- /dev/null
+++ b/lib/gc/Transforms/Utils/VectorUtils.cpp
@@ -0,0 +1,361 @@
+//===- VectorUtils.cpp - vector utilities -----------------------*- C++ -*-===//
+//
+// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+#include "gc/Transforms/Utils/VectorUtils.h"
+#include "mlir/Support/LLVM.h"
+namespace mlir {
+namespace gc {
+
+#define DEBUG_TYPE "vector-utils"
+
+#define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ")
+#define SAFE_EXPAND(X) X
+#define LDBG(X) LLVM_DEBUG(DBGS() << SAFE_EXPAND(X) << "\n")
+
+static inline void moveOpBeginningOfBlock(Operation *op, IRRewriter &rewriter) {
+  Block *block = op->getBlock();
+  if (block->getOperations().empty())
+    llvm_unreachable("Empty block.");
+
+  if (&block->front() == op)
+    return;
+  rewriter.moveOpAfter(op, op->getBlock(), op->getBlock()->begin());
+}
+
+// Special behavior for ++OPPRIORITY
+OPPRIORITY operator++(OPPRIORITY &c) {
+  using IntType = typename std::underlying_type::type;
+  c = static_cast(static_cast(c) + 1);
+  return c;
+}
+
+LogicalResult moveFront(Operation *op, IRRewriter &rewriter) {
+  Operation *backOperation = nullptr;
+  // check whether all operands are block arguments
+  bool allBlockArgs = llvm::all_of(op->getOperands(), [](Value operand) {
+    return isa(operand);
+  });
+
+  if (allBlockArgs) {
+    moveOpBeginningOfBlock(op, rewriter);
+    return success();
+  }
+  for (auto operand : op->getOperands()) {
+    if (isa(operand))
+      continue;
+
+    Operation *sourceOp = operand.getDefiningOp();
+    if (sourceOp->getBlock() != op->getBlock())
+      continue;
+    if (not backOperation) {
+      backOperation = sourceOp;
+      continue;
+    }
+
+    if (backOperation->isBeforeInBlock(sourceOp))
+      backOperation = sourceOp;
+  }
+  if (not backOperation) {
+    // all operand-defining operations live in earlier blocks
+    moveOpBeginningOfBlock(op, rewriter);
+    return success();
+  }
+  rewriter.moveOpAfter(op, backOperation);
+  return success();
+}
+
+LogicalResult moveBack(Operation *op, IRRewriter &rewriter) {
+  Operation *firstOperation = nullptr;
+  for (auto user : op->getUsers()) {
+    if (user->getBlock() != op->getBlock())
+      continue;
+    if (not firstOperation) {
+      firstOperation = user;
+      continue;
+    }
+    if (user->isBeforeInBlock(firstOperation))
+      firstOperation = user;
+  }
+  if (not firstOperation) {
+    // Don't move.
+    // TODO: consider moving the op before the block that uses it.
+ return success(); + } + rewriter.moveOpBefore(op, firstOperation); + return success(); +} + +void moveCandidateOperation( + std::queue> &candidateOps, + IRRewriter &rewriter, OPPRIORITY start, OPPRIORITY end) { + std::queue> remainOps; + OPPRIORITY itrBegin = start; + while (not remainOps.empty() or not candidateOps.empty()) { + while (not candidateOps.empty()) { + std::pair cur = candidateOps.front(); + candidateOps.pop(); + if (cur.second < start or cur.second > end) + continue; + if (cur.second != itrBegin) { + remainOps.push(cur); + continue; + } + + Operation *op = cur.first; + auto ret = + TypeSwitch(op) + .Case([&](affine::AffineApplyOp affineOp) { + return moveFront(op, rewriter); + }) + .Case( + [&](tensor::ExtractSliceOp extractOp) { + return moveFront(op, rewriter); + }) + .Case([&](tensor::EmptyOp emptyOp) { + return moveFront(op, rewriter); + }) + .Case([&](tensor::InsertSliceOp insertOp) { + return moveBack(op, rewriter); + }) + .Case([&](vector::TransferReadOp readOp) { + return moveFront(op, rewriter); + }) + .Case( + [&](vector::TransferWriteOp writeOp) { + return moveBack(op, rewriter); + }) + .Case([&](vector::BroadcastOp bcOp) { + return moveFront(op, rewriter); + }) + .Default([&](Operation *op) { return success(); }); + if (failed(ret)) { + LDBG("Wrong to move operation:" << *op << "\n"); + return; + } + } + candidateOps.swap(remainOps); + ++itrBegin; + } +} + +// Get operation priority +void getOperationPriority( + func::FuncOp *func, + std::queue> &candidateOps) { + // get the position of each operation + func->walk([&](Operation *op) { + TypeSwitch(op) + .Case([&](auto op) { + candidateOps.push(std::make_pair(op, OPPRIORITY::FIRST)); + return; + }) + .Case( + [&](auto op) { + candidateOps.push(std::make_pair(op, OPPRIORITY::SECOND)); + return; + }) + .Case([&](auto op) { + candidateOps.push(std::make_pair(op, OPPRIORITY::LAST)); + return; + }) + .Case([&](auto op) { + candidateOps.push(std::make_pair(op, OPPRIORITY::THIRD)); + return; + }) + .Default([&](Operation *op) { return; }); + }); +} + +// Need to move some operations like extract_slice or insert_slice. +// Because those operation may interpret our analysis result. e.g.: +// ``` +// clang-format off + // %21 = vector.transfer_read %18[%c0, %c0], %cst {in_bounds = [true, true]} : tensor<16x16xf32>, vector<16x16xf32> + // %22 = arith.addf %21, %20 : vector<16x16xf32> + // %23 = vector.transfer_write %22, %extracted_slice_12[%c0, %c0] {in_bounds = [true, true]} : vector<16x16xf32>, tensor<16x16xf32> + // %inserted_slice_13 = tensor.insert_slice %18 into %arg14[%arg13, 0] [16, 16] [1, 1] : tensor<16x16xf32> into tensor<32x16xf32> + // %extracted_slice_14 = tensor.extract_slice %arg16[%arg13, 0] [16, 16] [1, 1] : tensor<32x16xf32> to tensor<16x16xf32> + // %24 = vector.transfer_read %cst_0[%c0, %c0], %cst {in_bounds = [true, true]} : tensor<16x16xf32>, vector<16x16xf32> + // %25 = arith.maximumf %22, %24 : vector<16x16xf32> + // %26 = vector.transfer_write %25, %extracted_slice_14[%c0, %c0] {in_bounds = [true, true]} : vector<16x16xf32>, tensor<16x16xf32> + // %inserted_slice_15 = tensor.insert_slice %23 into %arg15[%arg13, 0] [16, 16] [1, 1] : tensor<16x16xf32> into tensor<32x16xf32> + // %inserted_slice_16 = tensor.insert_slice %26 into %arg16[%arg13, 0] [16, 16] [1, 1] : tensor<16x16xf32> into tensor<32x16xf32> +// clang-format on +// ``` +// The maximumf and addf operation can be a same group, but the extract_slice +// operation interpret us. +// The move operation(extra_slice) will check its parameters. 
+void moveOpsFrontOrBack(func::FuncOp *func, IRRewriter &rewriter,
+                        OPPRIORITY start, OPPRIORITY end) {
+  // Pre-order traversal of each op
+  std::queue<std::pair<Operation *, OPPRIORITY>> candidateOps;
+  getOperationPriority(func, candidateOps);
+  moveCandidateOperation(candidateOps, rewriter, start, end);
+  // fold and erase operations that became trivially dead after the movement
+  RewritePatternSet patterns(rewriter.getContext());
+  (void)applyPatternsAndFoldGreedily(*func, std::move(patterns));
+}
+
+mlir::FailureOr<VectorType> getOperationVectorType(Operation *op,
+                                                   bool isPrevOp) {
+  if (not op)
+    return failure();
+
+  auto isDynamicType = [](VectorType &type) { return !type.hasStaticShape(); };
+  auto ret =
+      TypeSwitch<Operation *, mlir::FailureOr<VectorType>>(op)
+          .Case<vector::TransferWriteOp>(
+              [&](vector::TransferWriteOp transferWriteOp)
+                  -> mlir::FailureOr<VectorType> {
+                if (auto retType = dyn_cast<VectorType>(
+                        transferWriteOp.getOperandTypes()[0]))
+                  return retType;
+
+                return failure();
+              })
+          .Case<vector::TransferReadOp>(
+              [&](vector::TransferReadOp transferReadOp)
+                  -> mlir::FailureOr<VectorType> {
+                return transferReadOp.getVectorType();
+              })
+          .Case<vector::MultiDimReductionOp>(
+              [&](vector::MultiDimReductionOp multiReductionOp) {
+                if (isPrevOp)
+                  return cast<VectorType>(
+                      multiReductionOp->getResultTypes()[0]);
+
+                // TODO: may need to add accumulate value vectortype
+                return cast<VectorType>(
+                    multiReductionOp.getSourceVectorType());
+              })
+          .Default([&](Operation *op) -> mlir::FailureOr<VectorType> {
+            if (isPrevOp) {
+              if (op->getResultTypes().empty())
+                return failure();
+
+              if (auto shapedType =
+                      dyn_cast<VectorType>(op->getResultTypes()[0]))
+                return shapedType;
+
+              return failure();
+            }
+            if (op->getOperandTypes().empty())
+              return failure();
+
+            if (auto shapedType =
+                    dyn_cast<VectorType>(op->getOperandTypes()[0]))
+              return shapedType;
+
+            return failure();
+          });
+  if (!failed(ret) and isDynamicType(ret.value()))
+    return failure();
+
+  return ret;
+}
+
+mlir::FailureOr<VectorType> getOperationMaxVectorType(Operation *op) {
+  if (not op)
+    return failure();
+
+  auto isDynamicType = [](VectorType &type) { return !type.hasStaticShape(); };
+  auto ret =
+      TypeSwitch<Operation *, mlir::FailureOr<VectorType>>(op)
+          .Case<vector::TransferWriteOp>(
+              [&](vector::TransferWriteOp transferWriteOp)
+                  -> mlir::FailureOr<VectorType> {
+                if (auto retType =
+                        cast<VectorType>(transferWriteOp.getOperandTypes()[0]))
+                  return retType;
+                return failure();
+              })
+          .Case<vector::TransferReadOp>(
+              [&](vector::TransferReadOp transferReadOp)
+                  -> mlir::FailureOr<VectorType> {
+                return transferReadOp.getVectorType();
+              })
+          .Case<vector::MultiDimReductionOp>(
+              [&](vector::MultiDimReductionOp multiReductionOp) {
+                return cast<VectorType>(
+                    multiReductionOp.getSourceVectorType());
+              })
+          .Default([&](Operation *op) -> mlir::FailureOr<VectorType> {
+            if (op->getResultTypes().empty() and op->getOperandTypes().empty())
+              return failure();
+
+            if (op->getResultTypes().empty() or
+                not isa<VectorType>(op->getResultTypes()[0]))
+              return cast<VectorType>(op->getOperandTypes()[0]);
+
+            if (op->getOperandTypes().empty() or
+                not isa<VectorType>(op->getOperandTypes()[0]))
+              return cast<VectorType>(op->getResultTypes()[0]);
+
+            auto opdType = cast<VectorType>(op->getOperandTypes()[0]);
+            auto retType = cast<VectorType>(op->getResultTypes()[0]);
+            return opdType.getRank() > retType.getRank() ? opdType : retType;
+          });
+  if (!failed(ret) and isDynamicType(ret.value()))
+    return failure();
+
+  return ret;
+}
+
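+/// Round \param step up to the nearest power of two; a power of two is
+/// returned unchanged, e.g. 16 -> 16, 24 -> 32, 3 -> 4. Steps larger than 64
+/// are not expected here and hit llvm_unreachable.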
+int getNearestVectorStep(const int step) {
+  if (step <= 0)
+    llvm_unreachable("Wrong step.");
+
+  int nbits = 0, n = step;
+  while (n) {
+    n = n >> 1;
+    nbits++;
+  }
+  if (nbits > 6 and (nbits != 7 or step != 64))
+    llvm_unreachable("Step must not exceed 64.");
+  return (1 << (nbits - 1)) == step ? step : (1 << nbits);
+}
+
+Value makeIndexArithConstantOp(OpBuilder &opBuilder, const Location &loc,
+                               int64_t x) {
+  return opBuilder.create<arith::ConstantOp>(
+      loc, opBuilder.getIndexType(),
+      opBuilder.getIntegerAttr(opBuilder.getIndexType(), x));
+}
+
+Value findOriginalTensor(Value writeTensor, Block *block) {
+  while (auto wtOp = dyn_cast_or_null<vector::TransferWriteOp>(
+             writeTensor.getDefiningOp())) {
+    if (block != writeTensor.getDefiningOp()->getBlock())
+      break;
+
+    writeTensor = wtOp->getOperand(1);
+  }
+  return writeTensor;
+}
+
+mlir::FailureOr<Value> getOperationOperateTensor(Operation *op) {
+  return TypeSwitch<Operation *, mlir::FailureOr<Value>>(op)
+      .Case<vector::TransferWriteOp>(
+          [&](vector::TransferWriteOp transferWriteOp) {
+            // find the original tensor.empty operation
+            auto writeTensor = transferWriteOp->getOperand(1);
+            writeTensor =
+                findOriginalTensor(writeTensor, transferWriteOp->getBlock());
+            return writeTensor;
+          })
+      .Case<vector::TransferReadOp>(
+          [&](vector::TransferReadOp transferReadOp) {
+            return transferReadOp->getOperand(0);
+          })
+      .Default([&](Operation *op) { return failure(); });
+}
+
+} // namespace gc
+} // namespace mlir
\ No newline at end of file
diff --git a/test/mlir/test/gc/Transforms/cpu-physical-register.mlir b/test/mlir/test/gc/Transforms/cpu-physical-register.mlir
new file mode 100644
index 000000000..5a97a3509
--- /dev/null
+++ b/test/mlir/test/gc/Transforms/cpu-physical-register.mlir
@@ -0,0 +1,710 @@
+// RUN: gc-opt %s --split-input-file --fold-tensor-operation --lower-to-tile-vector --CPU-physical-register-pass | FileCheck %s
+
+
+// CHECK-DAG: #[[map0:.*]] = affine_map<()[s0, s1] -> (s0 * 64 + s1)>
+// CHECK-DAG: #[[map1:.*]] = affine_map<()[s0, s1] -> (s0 + s1)>
+// CHECK-DAG: #[[map2:.*]] = affine_map<(d0) -> (d0 * 4)>
+// CHECK-DAG: #[[map3:.*]] = affine_map<(d0) -> (d0 * 128)>
+// CHECK-DAG: #[[map4:.*]] = affine_map<(d0) -> (d0 * 64)>
+// CHECK-DAG: #[[map5:.*]] = affine_map<(d0, d1) -> (d0 floordiv 32 + d1 floordiv 32)>
+// CHECK-DAG: #[[map6:.*]] = affine_map<(d0, d1) -> (d0 floordiv 16 + d1 floordiv 16)>
+// CHECK-DAG: #[[map7:.*]] = affine_map<()[s0, s1] -> (s0 * 32 + s1)>
+// CHECK-DAG: #[[map8:.*]] = affine_map<()[s0, s1] -> (s0 * 16 + s1)>
+// CHECK-DAG: #[[map9:.*]] = affine_map<(d0, d1) -> (d0 + d1)>
+
+
+
+// CHECK-LABEL: func @add_tensor_test0
+// CHECK: %[[C4096:.*]] = arith.constant 4096 : index
+// CHECK: %[[C16:.*]] = arith.constant 16 : index
+// CHECK: %[[C11008:.*]] = arith.constant 11008 : index
+// CHECK: %[[C1:.*]] = arith.constant 1 : index
+// CHECK: %[[C0:.*]] = arith.constant 0 : index
+// CHECK: %[[CST:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK: %[[TENSOR0:.*]] = tensor.empty() : tensor<11008x4096xf32>
+// CHECK: scf.for
+// CHECK: scf.for
+// CHECK: %[[READ0:.*]] = vector.transfer_read {{.*}}, {{.*}}: tensor<11008x4096xf32>, vector<16xf32>
+// CHECK: %[[READ1:.*]] = vector.transfer_read {{.*}}, {{.*}}: tensor<11008x4096xf32>, vector<16xf32>
+// CHECK: %[[ADD0:.*]] = arith.addf %[[READ0]], %[[READ1]] : vector<16xf32>
+// CHECK: %[[ADD1:.*]] = arith.addf %[[ADD0]], %[[READ1]] : vector<16xf32>
+// CHECK: %[[WRITE:.*]] = vector.transfer_write {{.*}}, {{.*}} : vector<16xf32>, tensor<11008x4096xf32>
+func.func
@add_tensor_test0(%arg0: tensor<11008x4096xf32>, %arg1: tensor<11008x4096xf32>) -> tensor<11008x4096xf32> { + %0 = tensor.empty() : tensor<11008x4096xf32> + %1 = linalg.add ins(%arg0, %arg1 : tensor<11008x4096xf32>, tensor<11008x4096xf32>) outs(%0: tensor<11008x4096xf32>) -> tensor<11008x4096xf32> + %2 = linalg.add ins(%1, %arg1 : tensor<11008x4096xf32>, tensor<11008x4096xf32>) outs(%0: tensor<11008x4096xf32>) -> tensor<11008x4096xf32> + return %2 : tensor<11008x4096xf32> +} + +// CHECK-LABEL: func @reduce_keepdimtest1 +// CHECK: %[[C32:.*]] = arith.constant 32 : index +// CHECK: %[[C64:.*]] = arith.constant 64 : index +// CHECK: %[[C16:.*]] = arith.constant 16 : index +// CHECK: %[[C1:.*]] = arith.constant 1 : index +// CHECK: %[[C0:.*]] = arith.constant 0 : index +// CHECK: %[[CST:.*]] = arith.constant 0.000000e+00 : f32 +// CHECK: %[[EMPTY0:.*]] = tensor.empty() : tensor<16x64xf32> +// CHECK: %[[EMPTY1:.*]] = tensor.empty() : tensor<16x1x64xf32> +// CHECK: scf.for +// CHECK: scf.for +// CHECK: %[[READ0:.*]] = vector.transfer_read {{.*}}, {{.*}} {in_bounds = [true]} : tensor<16x64xf32>, vector<16xf32> +// CHECK: scf.for +// CHECK: %[[READ1:.*]] = vector.transfer_read {{.*}}, {{.*}} {in_bounds = [true]} : tensor<16x32x64xf32>, vector<16xf32> +// CHECK: %[[ADD0:.*]] = arith.addf %[[READ1]], {{.*}} : vector<16xf32> +// CHECK: scf.yield +// CHECK: %[[WRITE:.*]] = vector.transfer_write {{.*}}, {{.*}} {in_bounds = [true]} : vector<16xf32>, tensor<16x64xf32> +// CHECK: scf.yield +// CHECK: scf.for %[[arg1:.*]] = %[[C0]] to %[[C16]] step %[[C1]] iter_args(%[[arg2:.*]] = %[[EMPTY1]]) -> (tensor<16x1x64xf32>) +// CHECK: scf.for %[[arg3:.*]] = %[[C0]] to %[[C1]] step %[[C1]] iter_args(%[[arg4:.*]] = %[[arg2]]) -> (tensor<16x1x64xf32>) +// CHECK: scf.for %[[arg5:.*]] = %[[C0]] to %[[C64]] step %[[C16]] iter_args(%[[arg6:.*]] = %[[arg4]]) -> (tensor<16x1x64xf32>) +// CHECK: %[[APPLY0:.*]] = affine.apply #[[map0]]()[%[[arg3]], %[[arg5]]] +// CHECK: %[[READ2:.*]] = vector.transfer_read {{.*}}, {{.*}} {in_bounds = [true]} : tensor<16x64xf32>, vector<16xf32> +// CHECK: %[[WRITE1:.*]] = vector.transfer_write {{.*}}, {{.*}} {in_bounds = [true]} : vector<16xf32>, tensor<16x1x64xf32> +func.func @reduce_keepdimtest1(%arg0: tensor<16x32x64xf32>) -> tensor<16x1x64xf32> { + %0 = tensor.empty() : tensor<16x64xf32> + %reduce = linalg.reduce + ins(%arg0:tensor<16x32x64xf32>) + outs(%0:tensor<16x64xf32>) + dimensions = [1] + (%in: f32, %out: f32) { + %1 = arith.addf %out, %in: f32 + linalg.yield %1: f32 + } + %2 = tensor.expand_shape %reduce [[0],[1, 2]] output_shape [16, 1, 64] : tensor<16x64xf32> into tensor<16x1x64xf32> + return %2 : tensor<16x1x64xf32> +} + +// CHECK-LABEL: func @fc_relu_test2 +// CHECK: %[[MATMUL:.*]] = linalg.matmul +// CHECK: scf.for +// CHECK: scf.for +// CHECK: %[[READ0:.*]] = vector.transfer_read {{.*}}, {{.*}}: tensor<512x512xf32>, vector<16xf32> +// CHECK: %[[READ1:.*]] = vector.transfer_read {{.*}}, {{.*}}: tensor<512x512xf32>, vector<16xf32> +// CHECK: %[[ADD0:.*]] = arith.addf %[[READ0]], %[[READ1]] : vector<16xf32> +// CHECK: %[[ADD1:.*]] = arith.maximumf %[[ADD0]], {{.*}} : vector<16xf32> +// CHECK: %[[WRITE:.*]] = vector.transfer_write {{.*}}, {{.*}} : vector<16xf32>, tensor<512x512xf32> +func.func @fc_relu_test2(%lhs: tensor<512x512xf32>, %rhs: tensor<512x512xf32>, + %bias: tensor<512x512xf32>, %output: tensor<512x512xf32>) + -> tensor<512x512xf32> { + + // Matrix-matrix multiplication. 
+ %matmul = linalg.matmul ins(%lhs, %rhs: tensor<512x512xf32>, tensor<512x512xf32>) + outs(%output: tensor<512x512xf32>) -> tensor<512x512xf32> + + %biased = linalg.elemwise_binary { fun = #linalg.binary_fn } + ins(%matmul, %bias : tensor<512x512xf32>, tensor<512x512xf32>) + outs(%output : tensor<512x512xf32>) -> tensor<512x512xf32> + + %c0f = arith.constant 0.0 : f32 + %relued = linalg.elemwise_binary { fun = #linalg.binary_fn } + ins(%biased, %c0f : tensor<512x512xf32>, f32) + outs(%output : tensor<512x512xf32>) -> tensor<512x512xf32> + func.return %relued : tensor<512x512xf32> +} + +// CHECK-LABEL: func @matmul_add_test3 +// CHECK: %[[MATMUL0:.*]] = linalg.matmul +// CHECK: scf.for +// CHECK: scf.for +// CHECK: %[[READ0:.*]] = vector.transfer_read {{.*}}, {{.*}} {in_bounds = [true]} : tensor<32x32xf32>, vector<16xf32> +// CHECK: %[[READ1:.*]] = vector.transfer_read {{.*}}, {{.*}} {in_bounds = [true]} : tensor<32x32xf32>, vector<16xf32> +// CHECK: %[[ADD0:.*]] = arith.addf %[[READ0]], %[[READ1]] : vector<16xf32> +// CHECK: %[[WRITE1:.*]] = vector.transfer_write {{.*}}, {{.*}} {in_bounds = [true]} : vector<16xf32>, tensor<32x32xf32> +// CHECK: scf.yield +// CHECK: scf.yield +func.func @matmul_add_test3(%arg0: tensor<8192x12288xf16>, %arg1: tensor<12288x16384xf16>, %arg2: tensor<8192x16384xf32>, %arg3: tensor<8192x16384xf32>) -> tensor<8192x16384xf32> { + + %0 = linalg.matmul {__fused_op__ = [0], name = "dot", tile_sizes = [[128, 128], [32, 32]]} ins(%arg0, %arg1 : tensor<8192x12288xf16>, tensor<12288x16384xf16>) outs(%arg2 : tensor<8192x16384xf32>) -> tensor<8192x16384xf32> + %1 = tensor.empty() : tensor<8192x16384xf32> + %2 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%0, %arg3 : tensor<8192x16384xf32>, tensor<8192x16384xf32>) outs(%1 : tensor<8192x16384xf32>) attrs = {__root_op__ = 0 : i64} { + ^bb0(%in: f32, %in_0: f32, %out: f32): + %4 = arith.addf %in, %in_0 : f32 + linalg.yield %4 : f32 + } -> tensor<8192x16384xf32> + %c0 = arith.constant 0 : index + %c8192 = arith.constant 8192 : index + %c128 = arith.constant 128 : index + %3 = scf.for %arg4 = %c0 to %c8192 step %c128 iter_args(%arg5 = %1) -> (tensor<8192x16384xf32>) { + %c0_0 = arith.constant 0 : index + %c16384 = arith.constant 16384 : index + %c128_1 = arith.constant 128 : index + %4 = scf.for %arg6 = %c0_0 to %c16384 step %c128_1 iter_args(%arg7 = %arg5) -> (tensor<8192x16384xf32>) { + %extracted_slice = tensor.extract_slice %arg0[%arg4, 0] [128, 12288] [1, 1] : tensor<8192x12288xf16> to tensor<128x12288xf16> + %extracted_slice_2 = tensor.extract_slice %arg1[0, %arg6] [12288, 128] [1, 1] : tensor<12288x16384xf16> to tensor<12288x128xf16> + %extracted_slice_3 = tensor.extract_slice %arg2[%arg4, %arg6] [128, 128] [1, 1] : tensor<8192x16384xf32> to tensor<128x128xf32> + %5 = linalg.matmul {__fused_op__ = [0], name = "dot", tile_sizes = [[128, 128], [32, 32]]} ins(%extracted_slice, %extracted_slice_2 : tensor<128x12288xf16>, tensor<12288x128xf16>) outs(%extracted_slice_3 : tensor<128x128xf32>) -> tensor<128x128xf32> + %extracted_slice_4 = tensor.extract_slice %0[%arg4, %arg6] [128, 128] [1, 1] : tensor<8192x16384xf32> to tensor<128x128xf32> + %extracted_slice_5 = tensor.extract_slice %arg3[%arg4, %arg6] [128, 128] [1, 1] : tensor<8192x16384xf32> to tensor<128x128xf32> + %extracted_slice_6 = tensor.extract_slice %arg7[%arg4, %arg6] [128, 128] [1, 1] : tensor<8192x16384xf32> to 
tensor<128x128xf32> + %6 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%5, %extracted_slice_5 : tensor<128x128xf32>, tensor<128x128xf32>) outs(%extracted_slice_6 : tensor<128x128xf32>) attrs = {__root_op__ = 0 : i64} { + ^bb0(%in: f32, %in_9: f32, %out: f32): + %8 = arith.addf %in, %in_9 : f32 + linalg.yield %8 : f32 + } -> tensor<128x128xf32> + %c0_7 = arith.constant 0 : index + %c128_8 = arith.constant 128 : index + %c32 = arith.constant 32 : index + %7 = scf.for %arg8 = %c0_7 to %c128_8 step %c32 iter_args(%arg9 = %extracted_slice_6) -> (tensor<128x128xf32>) { + %c0_9 = arith.constant 0 : index + %c128_10 = arith.constant 128 : index + %c32_11 = arith.constant 32 : index + %8 = scf.for %arg10 = %c0_9 to %c128_10 step %c32_11 iter_args(%arg11 = %arg9) -> (tensor<128x128xf32>) { + %extracted_slice_12 = tensor.extract_slice %extracted_slice[%arg8, 0] [32, 12288] [1, 1] : tensor<128x12288xf16> to tensor<32x12288xf16> + %extracted_slice_13 = tensor.extract_slice %extracted_slice_2[0, %arg10] [12288, 32] [1, 1] : tensor<12288x128xf16> to tensor<12288x32xf16> + %extracted_slice_14 = tensor.extract_slice %extracted_slice_3[%arg8, %arg10] [32, 32] [1, 1] : tensor<128x128xf32> to tensor<32x32xf32> + %9 = linalg.matmul {__fused_op__ = [0], name = "dot", tile_sizes = [[128, 128], [32, 32]]} ins(%extracted_slice_12, %extracted_slice_13 : tensor<32x12288xf16>, tensor<12288x32xf16>) outs(%extracted_slice_14 : tensor<32x32xf32>) -> tensor<32x32xf32> + %extracted_slice_15 = tensor.extract_slice %5[%arg8, %arg10] [32, 32] [1, 1] : tensor<128x128xf32> to tensor<32x32xf32> + %extracted_slice_16 = tensor.extract_slice %extracted_slice_5[%arg8, %arg10] [32, 32] [1, 1] : tensor<128x128xf32> to tensor<32x32xf32> + %extracted_slice_17 = tensor.extract_slice %arg11[%arg8, %arg10] [32, 32] [1, 1] : tensor<128x128xf32> to tensor<32x32xf32> + %10 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%9, %extracted_slice_16 : tensor<32x32xf32>, tensor<32x32xf32>) outs(%extracted_slice_17 : tensor<32x32xf32>) attrs = {__root_op__ = 0 : i64} { + ^bb0(%in: f32, %in_19: f32, %out: f32): + %11 = arith.addf %in, %in_19 : f32 + linalg.yield %11 : f32 + } -> tensor<32x32xf32> + %inserted_slice_18 = tensor.insert_slice %10 into %arg11[%arg8, %arg10] [32, 32] [1, 1] : tensor<32x32xf32> into tensor<128x128xf32> + scf.yield %inserted_slice_18 : tensor<128x128xf32> + } {__parallel_loop__ = 1 : i64} + scf.yield %8 : tensor<128x128xf32> + } {__parallel_loop__ = 1 : i64} + %inserted_slice = tensor.insert_slice %7 into %arg7[%arg4, %arg6] [128, 128] [1, 1] : tensor<128x128xf32> into tensor<8192x16384xf32> + scf.yield %inserted_slice : tensor<8192x16384xf32> + } {__parallel_loop__ = 0 : i64} + scf.yield %4 : tensor<8192x16384xf32> + } {__parallel_loop__ = 0 : i64} + return %3 : tensor<8192x16384xf32> +} + +// CHECK-LABEL: func @fuse_mlp_test4 +#map = affine_map<(d0) -> (d0 * 64)> +#map1 = affine_map<(d0) -> (d0 * 128)> +#map2 = affine_map<(d0) -> (d0 * 4)> +#map3 = affine_map<(d0) -> (d0 floordiv 16)> +#map4 = affine_map<(d0) -> (d0 floordiv 32)> +func.func @fuse_mlp_test4(%arg0: tensor<128x512xbf16>, %arg1: tensor<32x8x16x32xbf16>, %arg2: tensor<256xbf16>) -> tensor<128x256xbf16> { + %c32 = arith.constant 32 : index + %c512 = arith.constant 512 : index + %c128 = 
arith.constant 128 : index + %c64 = arith.constant 64 : index + %c0 = arith.constant 0 : index + %cst = arith.constant 0.000000e+00 : bf16 + %0 = tensor.empty() : tensor<128x256xbf16> + %1 = tensor.empty() : tensor<512x256xbf16> + %2:3 = scf.forall (%arg3, %arg4) in (2, 2) shared_outs(%arg5 = %0, %arg6 = %0, %arg7 = %0) -> (tensor<128x256xbf16>, tensor<128x256xbf16>, tensor<128x256xbf16>) { + %3 = affine.apply #map(%arg3) + %4 = affine.apply #map1(%arg4) + %extracted_slice = tensor.extract_slice %arg0[%3, 0] [64, 512] [1, 1] : tensor<128x512xbf16> to tensor<64x512xbf16> + %5 = affine.apply #map2(%arg4) + %extracted_slice_0 = tensor.extract_slice %arg1[0, %5, 0, 0] [32, 4, 16, 32] [1, 1, 1, 1] : tensor<32x8x16x32xbf16> to tensor<32x4x16x32xbf16> + %extracted_slice_1 = tensor.extract_slice %1[0, %4] [512, 128] [1, 1] : tensor<512x256xbf16> to tensor<512x128xbf16> + %extracted_slice_2 = tensor.extract_slice %arg5[%3, %4] [64, 128] [1, 1] : tensor<128x256xbf16> to tensor<64x128xbf16> + %extracted_slice_3 = tensor.extract_slice %arg2[%4] [128] [1] : tensor<256xbf16> to tensor<128xbf16> + %extracted_slice_4 = tensor.extract_slice %0[%3, %4] [64, 128] [1, 1] : tensor<128x256xbf16> to tensor<64x128xbf16> + %extracted_slice_5 = tensor.extract_slice %arg6[%3, %4] [64, 128] [1, 1] : tensor<128x256xbf16> to tensor<64x128xbf16> + %extracted_slice_6 = tensor.extract_slice %arg7[%3, %4] [64, 128] [1, 1] : tensor<128x256xbf16> to tensor<64x128xbf16> + %6:3 = scf.for %arg8 = %c0 to %c64 step %c64 iter_args(%arg9 = %extracted_slice_2, %arg10 = %extracted_slice_5, %arg11 = %extracted_slice_6) -> (tensor<64x128xbf16>, tensor<64x128xbf16>, tensor<64x128xbf16>) { + %7:3 = scf.for %arg12 = %c0 to %c128 step %c128 iter_args(%arg13 = %arg9, %arg14 = %arg10, %arg15 = %arg11) -> (tensor<64x128xbf16>, tensor<64x128xbf16>, tensor<64x128xbf16>) { + %8:3 = scf.for %arg16 = %c0 to %c512 step %c512 iter_args(%arg17 = %arg13, %arg18 = %arg14, %arg19 = %arg15) -> (tensor<64x128xbf16>, tensor<64x128xbf16>, tensor<64x128xbf16>) { + %extracted_slice_7 = tensor.extract_slice %extracted_slice[%arg8, %arg16] [64, 512] [1, 1] : tensor<64x512xbf16> to tensor<64x512xbf16> + %9 = affine.apply #map3(%arg16) + %10 = affine.apply #map4(%arg12) + %extracted_slice_8 = tensor.extract_slice %extracted_slice_0[%9, %10, 0, 0] [32, 4, 16, 32] [1, 1, 1, 1] : tensor<32x4x16x32xbf16> to tensor<32x4x16x32xbf16> + %extracted_slice_9 = tensor.extract_slice %extracted_slice_1[%arg16, %arg12] [512, 128] [1, 1] : tensor<512x128xbf16> to tensor<512x128xbf16> + %extracted_slice_10 = tensor.extract_slice %arg17[%arg8, %arg12] [64, 128] [1, 1] : tensor<64x128xbf16> to tensor<64x128xbf16> + %extracted_slice_11 = tensor.extract_slice %extracted_slice_3[%arg12] [128] [1] : tensor<128xbf16> to tensor<128xbf16> + %extracted_slice_12 = tensor.extract_slice %extracted_slice_4[%arg8, %arg12] [64, 128] [1, 1] : tensor<64x128xbf16> to tensor<64x128xbf16> + %extracted_slice_13 = tensor.extract_slice %arg18[%arg8, %arg12] [64, 128] [1, 1] : tensor<64x128xbf16> to tensor<64x128xbf16> + %extracted_slice_14 = tensor.extract_slice %arg19[%arg8, %arg12] [64, 128] [1, 1] : tensor<64x128xbf16> to tensor<64x128xbf16> + %11:3 = scf.for %arg20 = %c0 to %c64 step %c32 iter_args(%arg21 = %extracted_slice_10, %arg22 = %extracted_slice_13, %arg23 = %extracted_slice_14) -> (tensor<64x128xbf16>, tensor<64x128xbf16>, tensor<64x128xbf16>) { + %12:3 = scf.for %arg24 = %c0 to %c128 step %c32 iter_args(%arg25 = %arg21, %arg26 = %arg22, %arg27 = %arg23) -> (tensor<64x128xbf16>, 
tensor<64x128xbf16>, tensor<64x128xbf16>) { + %13:3 = scf.for %arg28 = %c0 to %c512 step %c512 iter_args(%arg29 = %arg25, %arg30 = %arg26, %arg31 = %arg27) -> (tensor<64x128xbf16>, tensor<64x128xbf16>, tensor<64x128xbf16>) { + %extracted_slice_17 = tensor.extract_slice %extracted_slice_7[%arg20, %arg28] [32, 512] [1, 1] : tensor<64x512xbf16> to tensor<32x512xbf16> + %14 = affine.apply #map3(%arg28) + %15 = affine.apply #map4(%arg24) + %extracted_slice_18 = tensor.extract_slice %extracted_slice_8[%14, %15, 0, 0] [32, 1, 16, 32] [1, 1, 1, 1] : tensor<32x4x16x32xbf16> to tensor<32x1x16x32xbf16> + %extracted_slice_19 = tensor.extract_slice %extracted_slice_9[%arg28, %arg24] [512, 32] [1, 1] : tensor<512x128xbf16> to tensor<512x32xbf16> + %unpack = tensor.unpack %extracted_slice_18 inner_dims_pos = [0, 1] inner_tiles = [16, 32] into %extracted_slice_19 : tensor<32x1x16x32xbf16> -> tensor<512x32xbf16> +// CHECK: %[[C16:.*]] = arith.constant 16 : index +// CHECK: %[[C1:.*]] = arith.constant 1 : index +// CHECK: %[[CST:.*]] = arith.constant 0.000000e+00 : bf16 +// CHECK: %[[C0:.*]] = arith.constant 0 : index +// CHECK: %[[C64:.*]] = arith.constant 64 : index +// CHECK: %[[C128:.*]] = arith.constant 128 : index +// CHECK: %[[C512:.*]] = arith.constant 512 : index +// CHECK: %[[C32:.*]] = arith.constant 32 : index +// CHECK: %[[EMPTY0:.*]] = tensor.empty() : tensor<128x256xbf16> +// CHECK: scf.forall +// CHECK-COUNT-6: scf.for +// CHECK-COUNT-4: scf.for +// CHECK: %[[READ0:.*]] = vector.transfer_read {{.*}}, {{.*}} {in_bounds = [true]} : tensor<32x1x16x32xbf16>, vector<32xbf16> +// CHECK: %[[WRITE0:.*]] = vector.transfer_write {{.*}}, {{.*}} {in_bounds = [true]} : vector<32xbf16>, tensor<32x16x1x32xbf16> +// CHECK: %[[FILL0:.*]] = linalg.fill +// CHECK-COUNT-3: scf.for +// CHECK: %[[APPLY0:.*]] = affine.apply +// CHECK: %[[READ1:.*]] = vector.transfer_read {{.*}}, {{.*}} {in_bounds = [true]} : tensor<32x512xbf16>, vector<32xbf16> +// CHECK: %[[WRITE1:.*]] = vector.transfer_write {{.*}}, {{.*}} {in_bounds = [true]} : vector<32xbf16>, tensor<1x32x512xbf16> +// CHECK-COUNT-4: scf.for +// CHECK: %[[APPLY1:.*]] = affine.apply +// CHECK: %[[READ2:.*]] = vector.transfer_read {{.*}}, {{.*}} {in_bounds = [true]} : tensor<32x16x1x32xbf16>, vector<32xbf16> +// CHECK: %[[WRITE2:.*]] = vector.transfer_write {{.*}}, {{.*}} {in_bounds = [true]} : vector<32xbf16>, tensor<1x512x32xbf16> +// CHECK: %[[MATMUL0:.*]] = linalg.batch_reduce_matmul +// CHECK-COUNT-2: scf.for +// CHECK: %[[READ3:.*]] = vector.transfer_read {{.*}}, {{.*}} {in_bounds = [true]} : tensor<32xbf16>, vector<32xbf16> +// CHECK: %[[READ4:.*]] = vector.transfer_read {{.*}}, {{.*}} {in_bounds = [true]} : tensor<32x32xbf16>, vector<32xbf16> +// CHECK: %[[ADD0:.*]] = arith.addf %[[READ4]], %[[READ3]] : vector<32xbf16> +// CHECK: %[[EXP0:.*]] = math.exp %[[ADD0]] : vector<32xbf16> +// CHECK: %[[WRITE3:.*]] = vector.transfer_write {{.*}}, {{.*}} {in_bounds = [true]} : vector<32xbf16>, tensor<32x32xbf16> +// CHECK: %[[WRITE4:.*]] = vector.transfer_write {{.*}}, {{.*}} {in_bounds = [true]} : vector<32xbf16>, tensor<32x32xbf16> + %extracted_slice_20 = tensor.extract_slice %arg29[%arg20, %arg24] [32, 32] [1, 1] : tensor<64x128xbf16> to tensor<32x32xbf16> + %16 = linalg.fill ins(%cst : bf16) outs(%extracted_slice_20 : tensor<32x32xbf16>) -> tensor<32x32xbf16> + %expanded = tensor.expand_shape %extracted_slice_17 [[0, 1], [2]] output_shape [1, 32, 512] : tensor<32x512xbf16> into tensor<1x32x512xbf16> + %expanded_21 = tensor.expand_shape %unpack [[0, 1], [2]] 
output_shape [1, 32, 512] : tensor<512x32xbf16> into tensor<1x512x32xbf16> + %17 = linalg.batch_reduce_matmul ins(%expanded, %expanded_21 : tensor<1x32x512xbf16>, tensor<1x512x32xbf16>) outs(%16 : tensor<32x32xbf16>) -> tensor<32x32xbf16> + %extracted_slice_22 = tensor.extract_slice %extracted_slice_11[%arg24] [32] [1] : tensor<128xbf16> to tensor<32xbf16> + %extracted_slice_23 = tensor.extract_slice %extracted_slice_12[%arg20, %arg24] [32, 32] [1, 1] : tensor<64x128xbf16> to tensor<32x32xbf16> + %broadcasted = linalg.broadcast ins(%extracted_slice_22 : tensor<32xbf16>) outs(%extracted_slice_23 : tensor<32x32xbf16>) dimensions = [0] + %extracted_slice_24 = tensor.extract_slice %arg30[%arg20, %arg24] [32, 32] [1, 1] : tensor<64x128xbf16> to tensor<32x32xbf16> + %18 = linalg.add ins(%17, %broadcasted : tensor<32x32xbf16>, tensor<32x32xbf16>) outs(%extracted_slice_24 : tensor<32x32xbf16>) -> tensor<32x32xbf16> + %inserted_slice_25 = tensor.insert_slice %17 into %arg29[%arg20, %arg24] [32, 32] [1, 1] : tensor<32x32xbf16> into tensor<64x128xbf16> + %extracted_slice_26 = tensor.extract_slice %arg31[%arg20, %arg24] [32, 32] [1, 1] : tensor<64x128xbf16> to tensor<32x32xbf16> + %19 = linalg.exp ins(%18 : tensor<32x32xbf16>) outs(%extracted_slice_26 : tensor<32x32xbf16>) -> tensor<32x32xbf16> + %inserted_slice_27 = tensor.insert_slice %18 into %arg30[%arg20, %arg24] [32, 32] [1, 1] : tensor<32x32xbf16> into tensor<64x128xbf16> + %inserted_slice_28 = tensor.insert_slice %19 into %arg31[%arg20, %arg24] [32, 32] [1, 1] : tensor<32x32xbf16> into tensor<64x128xbf16> + scf.yield %inserted_slice_25, %inserted_slice_27, %inserted_slice_28 : tensor<64x128xbf16>, tensor<64x128xbf16>, tensor<64x128xbf16> + } + scf.yield %13#0, %13#1, %13#2 : tensor<64x128xbf16>, tensor<64x128xbf16>, tensor<64x128xbf16> + } + scf.yield %12#0, %12#1, %12#2 : tensor<64x128xbf16>, tensor<64x128xbf16>, tensor<64x128xbf16> + } + %inserted_slice = tensor.insert_slice %11#0 into %arg17[%arg8, %arg12] [64, 128] [1, 1] : tensor<64x128xbf16> into tensor<64x128xbf16> + %inserted_slice_15 = tensor.insert_slice %11#1 into %arg18[%arg8, %arg12] [64, 128] [1, 1] : tensor<64x128xbf16> into tensor<64x128xbf16> + %inserted_slice_16 = tensor.insert_slice %11#2 into %arg19[%arg8, %arg12] [64, 128] [1, 1] : tensor<64x128xbf16> into tensor<64x128xbf16> + scf.yield %inserted_slice, %inserted_slice_15, %inserted_slice_16 : tensor<64x128xbf16>, tensor<64x128xbf16>, tensor<64x128xbf16> + } + scf.yield %8#0, %8#1, %8#2 : tensor<64x128xbf16>, tensor<64x128xbf16>, tensor<64x128xbf16> + } + scf.yield %7#0, %7#1, %7#2 : tensor<64x128xbf16>, tensor<64x128xbf16>, tensor<64x128xbf16> + } + scf.forall.in_parallel { + tensor.parallel_insert_slice %6#2 into %arg7[%3, %4] [64, 128] [1, 1] : tensor<64x128xbf16> into tensor<128x256xbf16> + tensor.parallel_insert_slice %6#1 into %arg6[%3, %4] [64, 128] [1, 1] : tensor<64x128xbf16> into tensor<128x256xbf16> + tensor.parallel_insert_slice %6#0 into %arg5[%3, %4] [64, 128] [1, 1] : tensor<64x128xbf16> into tensor<128x256xbf16> + } + } + return %2#2 : tensor<128x256xbf16> + } + +// CHECK-LABEL: func @elem_pack_transpose_inner_dims_test5 +// CHECK: %[[C32:.*]] = arith.constant 32 : index +// CHECK: %[[C4:.*]] = arith.constant 4 : index +// CHECK: %[[C256:.*]] = arith.constant 256 : index +// CHECK: %[[C16:.*]] = arith.constant 16 : index +// CHECK: %[[C128:.*]] = arith.constant 128 : index +// CHECK: %[[C1:.*]] = arith.constant 1 : index +// CHECK: %[[C0:.*]] = arith.constant 0 : index +// CHECK: %[[C32I32:.*]] = 
arith.constant 0 : i32 +// CHECK: %[[EMPTY0:.*]] = tensor.empty() : tensor<4x32x16x16xi32> +// CHECK: %[[EMPTY1:.*]] = tensor.empty() : tensor<128x256xi32> +// CHECK: %[[EMPTY2:.*]] = tensor.empty() : tensor<4x16x16x32xi32> +// CHECK: scf.for +// CHECK: scf.for +// CHECK: %[[READ0:.*]] = vector.transfer_read {{.*}}, {{.*}} {in_bounds = [true]} : tensor<128x256xi32>, vector<16xi32> +// CHECK: %[[ADD0:.*]] = arith.addi %[[READ0]], %[[READ0]] : vector<16xi32> +// CHECK: %[[WRITE0:.*]] = vector.transfer_write {{.*}}, {{.*}} {in_bounds = [true]} : vector<16xi32>, tensor<128x256xi32> +// CHECK: scf.for %[[arg2:.*]] = %[[C0]] to %[[C4]] step %[[C1]] iter_args(%[[arg3:.*]] = %[[EMPTY0]]) -> (tensor<4x32x16x16xi32>) +// CHECK: scf.for %[[arg4:.*]] = %[[C0]] to %[[C32]] step %[[C1]] iter_args(%[[arg5:.*]] = %[[arg3]]) -> (tensor<4x32x16x16xi32>) +// CHECK: scf.for %[[arg6:.*]] = %[[C0]] to %[[C16]] step %[[C1]] iter_args(%[[arg7:.*]] = %[[arg5]]) -> (tensor<4x32x16x16xi32>) +// CHECK: scf.for %[[arg8:.*]] = %[[C0]] to %[[C16]] step %[[C16]] iter_args(%[[arg9:.*]] = %[[arg7]]) -> (tensor<4x32x16x16xi32>) +// CHECK: %[[APPLY0:.*]] = affine.apply #[[map7]]()[%[[arg2]], %[[arg4]]] +// CHECK: %[[APPLY1:.*]] = affine.apply #[[map8]]()[%[[arg6]], %[[arg8]]] +// CHECK: %[[READ1:.*]] = vector.transfer_read {{.*}}, {{.*}} {in_bounds = [true]} : tensor<128x256xi32>, vector<16xi32> +// CHECK: %[[WRITE1:.*]] = vector.transfer_write {{.*}}, {{.*}} {in_bounds = [true]} : vector<16xi32>, tensor<4x32x16x16xi32> +// CHECK-COUNT-4: scf.for +// CHECK: %[[READ2:.*]] = vector.transfer_read {{.*}}, {{.*}} {in_bounds = [true]} : tensor<4x32x16x16xi32>, vector<1xi32> +// CHECK: %[[WRITE2:.*]] = vector.transfer_write {{.*}}, {{.*}} {in_bounds = [true]} : vector<1xi32>, tensor<4x16x16x32xi32> +#map5 = affine_map<(d0, d1) -> (d0, d1)> +func.func @elem_pack_transpose_inner_dims_test5(%arg0: tensor<128x256xi32>, %dest: tensor<4x16x16x32xi32>) -> tensor<4x16x16x32xi32>{ + %init = tensor.empty() : tensor<128x256xi32> + %elem = linalg.generic {indexing_maps = [#map5, #map5], iterator_types = ["parallel", "parallel"]} + ins(%arg0 : tensor<128x256xi32>) + outs(%init : tensor<128x256xi32>) { + ^bb0(%arg3: i32, %arg4: i32): + %4 = arith.addi %arg3, %arg3 : i32 + linalg.yield %4 : i32 + } -> tensor<128x256xi32> + %pack = tensor.pack %elem + inner_dims_pos = [1, 0] + inner_tiles = [16, 32] + into %dest : tensor<128x256xi32> -> tensor<4x16x16x32xi32> + return %pack : tensor<4x16x16x32xi32> +} + +// CHECK-LABEL: func @elem_pack_transpose_outer_dims_test6 +// CHECK: %[[C32:.*]] = arith.constant 32 : index +// CHECK: %[[C4:.*]] = arith.constant 4 : index +// CHECK: %[[C256:.*]] = arith.constant 256 : index +// CHECK: %[[C16:.*]] = arith.constant 16 : index +// CHECK: %[[C128:.*]] = arith.constant 128 : index +// CHECK: %[[C1:.*]] = arith.constant 1 : index +// CHECK: %[[C0:.*]] = arith.constant 0 : index +// CHECK: %[[C0I32:.*]] = arith.constant 0 : i32 +// CHECK: %[[EMPTY0:.*]] = tensor.empty() : tensor<4x32x16x16xi32> +// CHECK: %[[EMPTY1:.*]] = tensor.empty() : tensor<128x256xi32> +// CHECK: %[[EMPTY2:.*]] = tensor.empty() : tensor<16x4x32x16xi32> +// CHECK-COUNT-2: scf.for +// CHECK: %[[READ0:.*]] = vector.transfer_read {{.*}}, {{.*}} {in_bounds = [true]} : tensor<128x256xi32>, vector<16xi32> +// CHECK: %[[ADD0:.*]] = arith.addi %[[READ0]], %[[READ0]] : vector<16xi32> +// CHECK: %[[WRITE0:.*]] = vector.transfer_write {{.*}}, {{.*}} {in_bounds = [true]} : vector<16xi32>, tensor<128x256xi32> +// CHECK: scf.for %[[arg2:.*]] = %[[C0]] to 
%[[C4]] step %[[C1]] iter_args(%[[arg3:.*]] = %[[EMPTY0]]) -> (tensor<4x32x16x16xi32>) +// CHECK: scf.for %[[arg4:.*]] = %[[C0]] to %[[C32]] step %[[C1]] iter_args(%[[arg5:.*]] = %[[arg3]]) -> (tensor<4x32x16x16xi32>) +// CHECK: scf.for %[[arg6:.*]] = %[[C0]] to %[[C16]] step %[[C1]] iter_args(%[[arg7:.*]] = %[[arg5]]) -> (tensor<4x32x16x16xi32>) +// CHECK: scf.for %[[arg8:.*]] = %[[C0]] to %[[C16]] step %[[C16]] iter_args(%[[arg9:.*]] = %[[arg7]]) -> (tensor<4x32x16x16xi32>) +// CHECK: %[[APPLY0:.*]] = affine.apply #[[map7]]()[%[[arg2]], %[[arg4]]] +// CHECK: %[[APPLY1:.*]] = affine.apply #[[map8]]()[%[[arg6]], %[[arg8]]] +// CHECK: %[[READ1:.*]] = vector.transfer_read {{.*}}, {{.*}} {in_bounds = [true]} : tensor<128x256xi32>, vector<16xi32> +// CHECK: %[[WRITE1:.*]] = vector.transfer_write {{.*}}, {{.*}} {in_bounds = [true]} : vector<16xi32>, tensor<4x32x16x16xi32> +// CHECK: scf.for %[[arg2:.*]] = %[[C0]] to %[[C4]] step %[[C1]] iter_args(%[[arg3:.*]] = %[[EMPTY2]]) -> (tensor<16x4x32x16xi32>) +// CHECK: scf.for %[[arg4:.*]] = %[[C0]] to %[[C32]] step %[[C1]] iter_args(%[[arg5:.*]] = %[[arg3]]) -> (tensor<16x4x32x16xi32>) +// CHECK: scf.for %[[arg6:.*]] = %[[C0]] to %[[C16]] step %[[C1]] iter_args(%[[arg7:.*]] = %[[arg5]]) -> (tensor<16x4x32x16xi32>) +// CHECK: scf.for %[[arg8:.*]] = %[[C0]] to %[[C16]] step %[[C16]] iter_args(%[[arg9:.*]] = %[[arg7]]) -> (tensor<16x4x32x16xi32>) +// CHECK: %[[READ2:.*]] = vector.transfer_read {{.*}}, {{.*}} {in_bounds = [true]} : tensor<4x32x16x16xi32>, vector<16xi32> +// CHECK: %[[WRITE2:.*]] = vector.transfer_write {{.*}}, {{.*}} {in_bounds = [true]} : vector<16xi32>, tensor<16x4x32x16xi32> +#map6 = affine_map<(d0, d1) -> (d0, d1)> +func.func @elem_pack_transpose_outer_dims_test6(%arg0: tensor<128x256xi32>, %dest: tensor<16x4x32x16xi32>) -> tensor<16x4x32x16xi32>{ + %init = tensor.empty() : tensor<128x256xi32> + %elem = linalg.generic {indexing_maps = [#map6, #map6], iterator_types = ["parallel", "parallel"]} + ins(%arg0 : tensor<128x256xi32>) + outs(%init : tensor<128x256xi32>) { + ^bb0(%arg3: i32, %arg4: i32): + %4 = arith.addi %arg3, %arg3 : i32 + linalg.yield %4 : i32 + } -> tensor<128x256xi32> + %pack = tensor.pack %elem + outer_dims_perm = [1, 0] + inner_dims_pos = [0, 1] + inner_tiles = [32, 16] + into %dest : tensor<128x256xi32> -> tensor<16x4x32x16xi32> + return %pack : tensor<16x4x32x16xi32> +} + +// CHECK-LABEL: func @elem_pack_transpose_inner_and_outer_dims_test7 +// CHECK: %[[C32:.*]] = arith.constant 32 : index +// CHECK: %[[C4:.*]] = arith.constant 4 : index +// CHECK: %[[C256:.*]] = arith.constant 256 : index +// CHECK: %[[C16:.*]] = arith.constant 16 : index +// CHECK: %[[C128:.*]] = arith.constant 128 : index +// CHECK: %[[C1:.*]] = arith.constant 1 : index +// CHECK: %[[C0:.*]] = arith.constant 0 : index +// CHECK: %[[C0I32:.*]] = arith.constant 0 : i32 +// CHECK: %[[EMPTY0:.*]] = tensor.empty() : tensor<4x32x16x16xi32> +// CHECK: %[[EMPTY1:.*]] = tensor.empty() : tensor<128x256xi32> +// CHECK: %[[EMPTY2:.*]] = tensor.empty() : tensor<16x4x16x32xi32> +// CHECK-COUNT-2: scf.for +// CHECK: %[[READ0:.*]] = vector.transfer_read {{.*}}, {{.*}} {in_bounds = [true]} : tensor<128x256xi32>, vector<16xi32> +// CHECK: %[[ADD0:.*]] = arith.addi %[[READ0]], %[[READ0]] : vector<16xi32> +// CHECK: %[[WRITE0:.*]] = vector.transfer_write {{.*}}, {{.*}} {in_bounds = [true]} : vector<16xi32>, tensor<128x256xi32> +// CHECK: scf.for %[[arg2:.*]] = %[[C0]] to %[[C4]] step %[[C1]] iter_args(%[[arg3:.*]] = %[[EMPTY0]]) -> (tensor<4x32x16x16xi32>) +// 
CHECK: scf.for %[[arg4:.*]] = %[[C0]] to %[[C32]] step %[[C1]] iter_args(%[[arg5:.*]] = %[[arg3]]) -> (tensor<4x32x16x16xi32>) +// CHECK: scf.for %[[arg6:.*]] = %[[C0]] to %[[C16]] step %[[C1]] iter_args(%[[arg7:.*]] = %[[arg5]]) -> (tensor<4x32x16x16xi32>) +// CHECK: scf.for %[[arg8:.*]] = %[[C0]] to %[[C16]] step %[[C16]] iter_args(%[[arg9:.*]] = %[[arg7]]) -> (tensor<4x32x16x16xi32>) +// CHECK: %[[APPLY0:.*]] = affine.apply #[[map7]]()[%[[arg2]], %[[arg4]]] +// CHECK: %[[APPLY1:.*]] = affine.apply #[[map8]]()[%[[arg6]], %[[arg8]]] +// CHECK: %[[READ1:.*]] = vector.transfer_read {{.*}}, {{.*}} {in_bounds = [true]} : tensor<128x256xi32>, vector<16xi32> +// CHECK: %[[WRITE1:.*]] = vector.transfer_write {{.*}}, {{.*}} {in_bounds = [true]} : vector<16xi32>, tensor<4x32x16x16xi32> +// CHECK: scf.for %[[arg2:.*]] = %[[C0]] to %[[C4]] step %[[C1]] iter_args(%[[arg3:.*]] = %[[EMPTY2]]) -> (tensor<16x4x16x32xi32>) +// CHECK: scf.for %[[arg4:.*]] = %[[C0]] to %[[C32]] step %[[C1]] iter_args(%[[arg5:.*]] = %[[arg3]]) -> (tensor<16x4x16x32xi32>) +// CHECK: scf.for %[[arg6:.*]] = %[[C0]] to %[[C16]] step %[[C1]] iter_args(%[[arg7:.*]] = %[[arg5]]) -> (tensor<16x4x16x32xi32>) +// CHECK: scf.for %[[arg8:.*]] = %[[C0]] to %[[C16]] step %[[C1]] iter_args(%[[arg9:.*]] = %[[arg7]]) -> (tensor<16x4x16x32xi32>) +// CHECK: %[[READ2:.*]] = vector.transfer_read {{.*}}, {{.*}} {in_bounds = [true]} : tensor<4x32x16x16xi32>, vector<1xi32> +// CHECK: %[[WRITE2:.*]] = vector.transfer_write {{.*}}, {{.*}} {in_bounds = [true]} : vector<1xi32>, tensor<16x4x16x32xi32> +#map7 = affine_map<(d0, d1) -> (d0, d1)> +func.func @elem_pack_transpose_inner_and_outer_dims_test7(%arg0: tensor<128x256xi32>, %dest: tensor<16x4x16x32xi32>) -> tensor<16x4x16x32xi32>{ + %init = tensor.empty() : tensor<128x256xi32> + %elem = linalg.generic {indexing_maps = [#map7, #map7], iterator_types = ["parallel", "parallel"]} + ins(%arg0 : tensor<128x256xi32>) + outs(%init : tensor<128x256xi32>) { + ^bb0(%arg3: i32, %arg4: i32): + %4 = arith.addi %arg3, %arg3 : i32 + linalg.yield %4 : i32 + } -> tensor<128x256xi32> + %pack = tensor.pack %elem + outer_dims_perm = [1, 0] + inner_dims_pos = [1, 0] + inner_tiles = [16, 32] + into %dest : tensor<128x256xi32> -> tensor<16x4x16x32xi32> + return %pack : tensor<16x4x16x32xi32> +} + +// CHECK-LABEL: func @elem_pack_transpose_inner_and_outer_dims2_test8 +// CHECK: %[[C32:.*]] = arith.constant 32 : index +// CHECK: %[[C2:.*]] = arith.constant 2 : index +// CHECK: %[[C64:.*]] = arith.constant 64 : index +// CHECK: %[[C16:.*]] = arith.constant 16 : index +// CHECK: %[[C57:.*]] = arith.constant 57 : index +// CHECK: %[[C56:.*]] = arith.constant 56 : index +// CHECK: %[[C1:.*]] = arith.constant 1 : index +// CHECK: %[[C0:.*]] = arith.constant 0 : index +// CHECK: %[[CST:.*]] = arith.constant 0.000000e+00 : f32 +// CHECK: %[[EMPTY0:.*]] = tensor.empty() : tensor<1x56x57x2x32xf32> +// CHECK: %[[EMPTY1:.*]] = tensor.empty() : tensor<1x56x57x64xf32> +// CHECK: %[[EMPTY2:.*]] = tensor.empty() : tensor<1x2x56x57x32xf32> +// CHECK-COUNT-4: scf.for +// CHECK: %[[READ0:.*]] = vector.transfer_read {{.*}}, {{.*}} {in_bounds = [true]} : tensor<64xf32>, vector<16xf32> +// CHECK: %[[WRITE0:.*]] = vector.transfer_write {{.*}}, {{.*}} {in_bounds = [true]} : vector<16xf32>, tensor<1x56x57x64xf32> +// CHECK: scf.for %[[arg2:.*]] = %[[C0]] to %[[C1]] step %[[C1]] iter_args(%[[arg3:.*]] = %[[EMPTY0]]) -> (tensor<1x56x57x2x32xf32>) +// CHECK: scf.for %[[arg4:.*]] = %[[C0]] to %[[C56]] step %[[C1]] iter_args(%[[arg5:.*]] = %[[arg3]]) -> 
(tensor<1x56x57x2x32xf32>) +// CHECK: scf.for %[[arg6:.*]] = %[[C0]] to %[[C57]] step %[[C1]] iter_args(%[[arg7:.*]] = %[[arg5]]) -> (tensor<1x56x57x2x32xf32>) +// CHECK: scf.for %[[arg8:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[arg9:.*]] = %[[arg7]]) -> (tensor<1x56x57x2x32xf32>) +// CHECK: scf.for %[[arg10:.*]] = %[[C0]] to %[[C32]] step %[[C16]] iter_args(%[[arg11:.*]] = %[[arg9]]) -> (tensor<1x56x57x2x32xf32>) +// CHECK: %[[APPLY0:.*]] = affine.apply #[[map7]]()[%[[arg8]], %[[arg10]]] +// CHECK: %[[READ1:.*]] = vector.transfer_read {{.*}}, {{.*}} {in_bounds = [true]} : tensor<1x56x57x64xf32>, vector<16xf32> +// CHECK: %[[WRITE1:.*]] = vector.transfer_write {{.*}}, {{.*}} {in_bounds = [true]} : vector<16xf32>, tensor<1x56x57x2x32xf32> +// CHECK: scf.for %[[arg2:.*]] = %[[C0]] to %[[C1]] step %[[C1]] iter_args(%[[arg3:.*]] = %[[EMPTY2]]) -> (tensor<1x2x56x57x32xf32>) +// CHECK: scf.for %[[arg4:.*]] = %[[C0]] to %[[C56]] step %[[C1]] iter_args(%[[arg5:.*]] = %[[arg3]]) -> (tensor<1x2x56x57x32xf32>) +// CHECK: scf.for %[[arg6:.*]] = %[[C0]] to %[[C57]] step %[[C1]] iter_args(%[[arg7:.*]] = %[[arg5]]) -> (tensor<1x2x56x57x32xf32>) +// CHECK: scf.for %[[arg8:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[arg9:.*]] = %[[arg7]]) -> (tensor<1x2x56x57x32xf32>) +// CHECK: scf.for %[[arg10:.*]] = %[[C0]] to %[[C32]] step %[[C16]] iter_args(%[[arg11:.*]] = %[[arg9]]) -> (tensor<1x2x56x57x32xf32>) +// CHECK: %[[READ2:.*]] = vector.transfer_read {{.*}}, {{.*}} {in_bounds = [true]} : tensor<1x56x57x2x32xf32>, vector<16xf32> +// CHECK: %[[WRITE2:.*]] = vector.transfer_write {{.*}}, {{.*}} {in_bounds = [true]} : vector<16xf32>, tensor<1x2x56x57x32xf32> +#map8 = affine_map<(d0, d1, d2, d3) -> (d3)> +#map9 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)> +func.func @elem_pack_transpose_inner_and_outer_dims2_test8(%arg0: tensor<64xf32>, %dest: tensor<1x2x56x57x32xf32>) -> tensor<1x2x56x57x32xf32> { + %0 = tensor.empty() : tensor<1x56x57x64xf32> + %1 = linalg.generic { + indexing_maps = [#map8, #map9], + iterator_types = ["parallel", "parallel", "parallel", "parallel"]} + ins(%arg0 : tensor<64xf32>) + outs(%0 : tensor<1x56x57x64xf32>) { + ^bb0(%in: f32, %out: f32): + linalg.yield %in : f32 + } -> tensor<1x56x57x64xf32> + %2 = tensor.pack %1 outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32] into %dest : tensor<1x56x57x64xf32> -> tensor<1x2x56x57x32xf32> + return %2 : tensor<1x2x56x57x32xf32> +} + + +// CHECK-LABEL: func @broadcast_same_shape_test9 +// CHECK: %[[CST:.*]] = arith.constant 0.000000e+00 : f32 +// CHECK: %[[C0:.*]] = arith.constant 0 : index +// CHECK: %[[C1:.*]] = arith.constant 1 : index +// CHECK: %[[C2:.*]] = arith.constant 2 : index +// CHECK: %[[C16:.*]] = arith.constant 16 : index +// CHECK: scf.for +// CHECK: scf.for +// CHECK: %[[READ0:.*]] = vector.transfer_read {{.*}}, {{.*}} {in_bounds = [true]} : tensor<16xf32>, vector<16xf32> +// CHECK: %[[READ1:.*]] = vector.transfer_read {{.*}}, {{.*}} {in_bounds = [true]} : tensor<2x16xf32>, vector<16xf32> +// CHECK: %[[ADD0:.*]] = arith.addf %[[READ0]], %[[READ1]] : vector<16xf32> +// CHECK: %[[WRITE0:.*]] = vector.transfer_write {{.*}}, {{.*}} {in_bounds = [true]} : vector<16xf32>, tensor<2x16xf32> +func.func @broadcast_same_shape_test9(%input: tensor<16xf32>, %init: tensor<2x16xf32>) -> tensor<2x16xf32> { + %empty = tensor.empty() : tensor<2x16xf32> + %0 = linalg.broadcast ins(%input: tensor<16xf32>) outs(%empty: tensor<2x16xf32>) dimensions = [0] + %1 = linalg.add ins(%0, %init : tensor<2x16xf32>, 
tensor<2x16xf32>) outs(%init : tensor<2x16xf32>) -> tensor<2x16xf32> + return %1 : tensor<2x16xf32> +} + +// CHECK-LABEL: func @reduce_single_test10 +// CHECK: %[[C32:.*]] = arith.constant 32 : index +// CHECK: %[[C64:.*]] = arith.constant 64 : index +// CHECK: %[[CST:.*]] = arith.constant 0.000000e+00 : f32 +// CHECK: %[[C0:.*]] = arith.constant 0 : index +// CHECK: %[[C1:.*]] = arith.constant 1 : index +// CHECK: %[[C16:.*]] = arith.constant 16 : index +// CHECK: scf.for %[[arg2:.*]] = %[[C0]] to %[[C16]] step %[[C1]] iter_args(%[[arg3:.*]] = %arg1) -> (tensor<16x64xf32>) +// CHECK: scf.for %[[arg4:.*]] = %[[C0]] to %[[C64]] step %[[C16]] iter_args(%[[arg5:.*]] = %[[arg3]]) -> (tensor<16x64xf32>) +// CHECK: %[[READ0:.*]] = vector.transfer_read %[[arg5]][%[[arg2]], %[[arg4]]], %[[CST]] {in_bounds = [true]} : tensor<16x64xf32>, vector<16xf32> +// CHECK: scf.for %[[arg6:.*]] = %[[C0]] to %[[C32]] step %[[C1]] iter_args(%[[arg7:.*]] = %[[READ0]]) -> (vector<16xf32>) +// CHECK: %[[READ1:.*]] = vector.transfer_read %arg0[%[[arg2]], %[[arg6]], %[[arg4]]], {{.*}} {in_bounds = [true]} : tensor<16x32x64xf32>, vector<16xf32> +// CHECK: %[[ADD0:.*]] = arith.addf %[[READ1]], %[[arg7]] : vector<16xf32> +// CHECK: %[[WRITE0:.*]] = vector.transfer_write {{.*}}, %[[arg5]][%[[arg2]], %[[arg4]]] {in_bounds = [true]} : vector<16xf32>, tensor<16x64xf32> +func.func @reduce_single_test10(%input: tensor<16x32x64xf32>, + %init: tensor<16x64xf32>) -> tensor<16x64xf32> { + %reduce = linalg.reduce + ins(%input:tensor<16x32x64xf32>) + outs(%init:tensor<16x64xf32>) + dimensions = [1] + (%in: f32, %out: f32) { + %0 = arith.addf %out, %in: f32 + linalg.yield %0: f32 + } + func.return %reduce : tensor<16x64xf32> +} + +// CHECK-LABEL: func @reduce_fusePostOp_test11 +// CHECK: %[[C64:.*]] = arith.constant 64 : index +// CHECK: %[[C32:.*]] = arith.constant 32 : index +// CHECK: %[[C16:.*]] = arith.constant 16 : index +// CHECK: %[[C1:.*]] = arith.constant 1 : index +// CHECK: %[[C0:.*]] = arith.constant 0 : index +// CHECK: %[[CST:.*]] = arith.constant 0.000000e+00 : f32 +// CHECK: %[[EMPTY0:.*]] = tensor.empty() : tensor<16x32x64xf32> +// CHECK-COUNT-3: scf.for +// CHECK: %[[READ0:.*]] = vector.transfer_read {{.*}}, {{.*}} {in_bounds = [true]} : tensor<16x32x64xf32>, vector<16xf32> +// CHECK: %[[ADD0:.*]] = arith.addf %[[READ0]], %[[READ0]] : vector<16xf32> +// CHECK: %[[WRITE0:.*]] = vector.transfer_write {{.*}}, {{.*}} {in_bounds = [true]} : vector<16xf32>, tensor<16x32x64xf32> +// CHECK: scf.for %[[arg2:.*]] = %[[C0]] to %[[C16]] step %[[C1]] iter_args(%[[arg3:.*]] = %arg1) -> (tensor<16x64xf32>) +// CHECK: scf.for %[[arg4:.*]] = %[[C0]] to %[[C64]] step %[[C16]] iter_args(%[[arg5:.*]] = %[[arg3]]) -> (tensor<16x64xf32>) +// CHECK: %[[READ1:.*]] = vector.transfer_read %[[arg5]][%[[arg2]], %[[arg4]]], %[[CST]] {in_bounds = [true]} : tensor<16x64xf32>, vector<16xf32> +// CHECK: scf.for %[[arg6:.*]] = %[[C0]] to %[[C32]] step %[[C1]] iter_args(%[[arg7:.*]] = %[[READ1]]) -> (vector<16xf32>) +// CHECK: %[[READ2:.*]] = vector.transfer_read {{.*}}[%[[arg2]], %[[arg6]], %[[arg4]]], {{.*}} {in_bounds = [true]} : tensor<16x32x64xf32>, vector<16xf32> +// CHECK: %[[ADD0:.*]] = arith.addf %[[READ2]], %[[arg7]] : vector<16xf32> +// CHECK: %[[MUL:.*]] = arith.mulf {{.*}}, {{.*}} : vector<16xf32> +// CHECK: %[[WRITE0:.*]] = vector.transfer_write {{.*}}, %[[arg5]][%[[arg2]], %[[arg4]]] {in_bounds = [true]} : vector<16xf32>, tensor<16x64xf32> +func.func @reduce_fusePostOp_test11(%input: tensor<16x32x64xf32>, + %init: tensor<16x64xf32>) 
-> tensor<16x64xf32> { + %0 = linalg.add ins(%input, %input : tensor<16x32x64xf32>,tensor<16x32x64xf32>) + outs(%input : tensor<16x32x64xf32>) -> tensor<16x32x64xf32> + %reduce = linalg.reduce + ins(%0:tensor<16x32x64xf32>) + outs(%init:tensor<16x64xf32>) + dimensions = [1] + (%in: f32, %out: f32) { + %2 = arith.addf %out, %in: f32 + linalg.yield %2: f32 + } + %1 = linalg.mul ins(%reduce, %reduce : tensor<16x64xf32>, tensor<16x64xf32>) outs(%init: tensor<16x64xf32>) -> tensor<16x64xf32> + func.return %1 : tensor<16x64xf32> +} + +// CHECK-LABEL: func @reduce_fuse_test12 +// CHECK: %[[C64:.*]] = arith.constant 64 : index +// CHECK: %[[CST:.*]] = arith.constant dense<0.000000e+00> : vector<16xf32> +// CHECK: %[[C32:.*]] = arith.constant 32 : index +// CHECK: %[[CST_0:.*]] = arith.constant 0.000000e+00 : f32 +// CHECK: %[[C0:.*]] = arith.constant 0 : index +// CHECK: %[[C1:.*]] = arith.constant 1 : index +// CHECK: %[[C16:.*]] = arith.constant 16 : index +// CHECK: scf.for %[[arg2:.*]] = %[[C0]] to %[[C16]] step %[[C1]] iter_args(%[[arg3:.*]] = %arg1) -> (tensor<16x32xf32>) +// CHECK: scf.for %[[arg4:.*]] = %[[C0]] to %[[C32]] step %[[C16]] iter_args(%[[arg5:.*]] = %[[arg3]]) -> (tensor<16x32xf32>) +// CHECK: %[[READ0:.*]] = vector.transfer_read %[[arg5]][%[[arg2]], %[[arg4]]], %[[CST_0]] {in_bounds = [true]} : tensor<16x32xf32>, vector<16xf32> +// CHECK: scf.for %[[arg6:.*]] = %[[C0]] to %[[C16]] step %[[C1]] iter_args(%[[arg7:.*]] = %[[READ0]]) -> (vector<16xf32>) +// CHECK: scf.for %[[arg8:.*]] = %[[C0]] to %[[C64]] step %[[C16]] iter_args(%[[arg9:.*]] = %[[CST]]) -> (vector<16xf32>) +// CHECK: %[[APPLY0:.*]] = affine.apply #[[map9]](%[[arg4]], %[[arg6]]) +// CHECK: %[[READ1:.*]] = vector.transfer_read %arg0[%[[arg2]], %[[APPLY0]], %[[arg8]]], {{.*}} {in_bounds = [true]} : tensor<16x32x64xf32>, vector<16xf32> +// CHECK: %[[ADD0:.*]] = arith.addf %[[READ1]], %[[READ1]] : vector<16xf32> +// CHECK: %[[ADD1:.*]] = arith.addf %[[ADD0]], %[[arg9]] : vector<16xf32> +// CHECK: %[[REDUCTION:.*]] = vector.reduction , {{.*}} : vector<16xf32> into f32 +// CHECK: %[[INSERT:.*]] = vector.insert %[[REDUCTION]], %[[arg7]] [%[[arg6]]] : f32 into vector<16xf32> +// CHECK: %[[MUL:.*]] = arith.mulf {{.*}}, {{.*}} : vector<16xf32> +// CHECK: %[[WRITE0:.*]] = vector.transfer_write {{.*}}, %[[arg5]][%[[arg2]], %[[arg4]]] {in_bounds = [true]} : vector<16xf32>, tensor<16x32xf32> +func.func @reduce_fuse_test12(%input: tensor<16x32x64xf32>, + %init: tensor<16x32xf32>) -> tensor<16x32xf32> { + %0 = linalg.add ins(%input, %input : tensor<16x32x64xf32>,tensor<16x32x64xf32>) + outs(%input : tensor<16x32x64xf32>) -> tensor<16x32x64xf32> + %reduce = linalg.reduce + ins(%0:tensor<16x32x64xf32>) + outs(%init:tensor<16x32xf32>) + dimensions = [2] + (%in: f32, %out: f32) { + %2 = arith.addf %out, %in: f32 + linalg.yield %2: f32 + } + %1 = linalg.mul ins(%reduce, %reduce : tensor<16x32xf32>, tensor<16x32xf32>) outs(%init: tensor<16x32xf32>) -> tensor<16x32xf32> + func.return %1 : tensor<16x32xf32> +} + +// CHECK-LABEL: func @reduce_fuse_test13 +// CHECK: %[[CST:.*]] = arith.constant dense<0.000000e+00> : vector<16xf32> +// CHECK: %[[C64:.*]] = arith.constant 64 : index +// CHECK: %[[C32:.*]] = arith.constant 32 : index +// CHECK: %[[C16:.*]] = arith.constant 16 : index +// CHECK: %[[C1:.*]] = arith.constant 1 : index +// CHECK: %[[C0:.*]] = arith.constant 0 : index +// CHECK: %[[CST_0:.*]] = arith.constant 0.000000e+00 : f32 +// CHECK: %[[EMPTY0:.*]] = tensor.empty() : tensor<16x32x64xf32> +// CHECK: scf.for %[[arg2:.*]] = 
%[[C0]] to %[[C16]] step %[[C1]] iter_args(%[[arg3:.*]] = %[[EMPTY0]]) -> (tensor<16x32x64xf32>) +// CHECK: scf.for %[[arg4:.*]] = %[[C0]] to %[[C32]] step %[[C1]] iter_args(%[[arg5:.*]] = %[[arg3]]) -> (tensor<16x32x64xf32>) +// CHECK: scf.for %[[arg6:.*]] = %[[C0]] to %[[C64]] step %[[C16]] iter_args(%[[arg7:.*]] = %[[arg5]]) -> (tensor<16x32x64xf32>) +// CHECK: %[[READ0:.*]] = vector.transfer_read %{{.*}}[%[[arg2]], %[[arg4]], %[[arg6]]], %[[CST_0]] {in_bounds = [true]} : tensor<16x32x64xf32>, vector<16xf32> +// CHECK: %[[ADD0:.*]] = arith.addf %[[READ0]], %[[READ0]] : vector<16xf32> +// CHECK: %[[WRITE0:.*]] = vector.transfer_write %[[ADD0]], %[[arg7]][%[[arg2]], %[[arg4]], %[[arg6]]] {in_bounds = [true]} : vector<16xf32>, tensor<16x32x64xf32> +// CHECK: scf.for %[[arg2:.*]] = %[[C0]] to %[[C16]] step %[[C16]] iter_args(%[[arg3:.*]] = %[[arg1]]) -> (tensor<16xf32>) +// CHECK: %[[READ1:.*]] = vector.transfer_read %[[arg3]][%[[arg2]]], %[[CST_0]] {in_bounds = [true]} : tensor<16xf32>, vector<16xf32> +// CHECK: scf.for %[[arg4:.*]] = %[[C0]] to %[[C16]] step %[[C1]] iter_args(%[[arg5:.*]] = %[[READ1]]) -> (vector<16xf32>) +// CHECK: scf.for %[[arg6:.*]] = %[[C0]] to %[[C32]] step %[[C1]] iter_args(%[[arg7:.*]] = %[[CST]]) -> (vector<16xf32>) +// CHECK: scf.for %[[arg8:.*]] = %[[C0]] to %[[C64]] step %[[C16]] iter_args(%[[arg9:.*]] = %[[arg7]]) -> (vector<16xf32>) +// CHECK: %[[APPLY0:.*]] = affine.apply #[[map9]](%[[arg2]], %[[arg4]]) +// CHECK: %[[READ2:.*]] = vector.transfer_read {{.*}}[%[[APPLY0]], %[[arg6]], %[[arg8]]], {{.*}} {in_bounds = [true]} : tensor<16x32x64xf32>, vector<16xf32> +// CHECK: %[[ADD0:.*]] = arith.addf %[[READ2]], %[[arg9]] : vector<16xf32> +// CHECK: %[[REDUCTION:.*]] = vector.reduction , {{.*}} : vector<16xf32> into f32 +// CHECK: %[[INSERT:.*]] = vector.insert %[[REDUCTION]], %[[arg5]] [%[[arg4]]] : f32 into vector<16xf32> +// CHECK: %[[MUL:.*]] = arith.mulf {{.*}}, {{.*}} : vector<16xf32> +// CHECK: %[[WRITE1:.*]] = vector.transfer_write {{.*}}, %[[arg3]][%[[arg2]]] {in_bounds = [true]} : vector<16xf32>, tensor<16xf32> +func.func @reduce_fuse_test13(%input: tensor<16x32x64xf32>, + %init: tensor<16xf32>) -> tensor<16xf32> { + %0 = linalg.add ins(%input, %input : tensor<16x32x64xf32>,tensor<16x32x64xf32>) + outs(%input : tensor<16x32x64xf32>) -> tensor<16x32x64xf32> + %reduce = linalg.reduce + ins(%0:tensor<16x32x64xf32>) + outs(%init:tensor<16xf32>) + dimensions = [1, 2] + (%in: f32, %out: f32) { + %2 = arith.addf %out, %in: f32 + linalg.yield %2: f32 + } + %1 = linalg.mul ins(%reduce, %reduce : tensor<16xf32>, tensor<16xf32>) outs(%init: tensor<16xf32>) -> tensor<16xf32> + func.return %1 : tensor<16xf32> +} + +// CHECK-LABEL: func @add_small_tensor_test14 +// CHECK: %[[C2:.*]] = arith.constant 2 : index +// CHECK: %[[C1:.*]] = arith.constant 1 : index +// CHECK: %[[C0:.*]] = arith.constant 0 : index +// CHECK: %[[CST:.*]] = arith.constant dense<0.000000e+00> : vector<1xf32> +// CHECK: %[[CST_0:.*]] = arith.constant 0.000000e+00 : f32 +// CHECK: %[[TENSOR0:.*]] = tensor.empty() : tensor<2xf32> +// CHECK: scf.for +// CHECK: %[[READ0:.*]] = vector.transfer_read {{.*}}, {{.*}}: tensor<2xf32>, vector<1xf32> +// CHECK: %[[READ1:.*]] = vector.transfer_read {{.*}}, {{.*}}: tensor<2xf32>, vector<1xf32> +// CHECK: %[[ADD0:.*]] = arith.addf %[[READ0]], %[[READ1]] : vector<1xf32> +// CHECK: %[[ADD1:.*]] = arith.maximumf %[[ADD0]], %[[CST]] : vector<1xf32> +// CHECK: %[[WRITE:.*]] = vector.transfer_write {{.*}}, {{.*}} : vector<1xf32>, tensor<2xf32> +func.func 
@add_small_tensor_test14(%arg0: tensor<2xf32>, %arg1: tensor<2xf32>) -> tensor<2xf32> { + %0 = tensor.empty() : tensor<2xf32> + %cst = arith.constant dense<0.000000e+00> : tensor<2xf32> + %1 = linalg.add ins(%arg0, %arg1 : tensor<2xf32>, tensor<2xf32>) outs(%0: tensor<2xf32>) -> tensor<2xf32> + %2 = linalg.max ins(%1, %cst : tensor<2xf32>, tensor<2xf32>) outs(%0: tensor<2xf32>) -> tensor<2xf32> + return %2 : tensor<2xf32> +} + +// CHECK-LABEL: func @broadcast_add_test15 +// CHECK: %[[CST:.*]] = arith.constant 0.000000e+00 : f32 +// CHECK: %[[C0:.*]] = arith.constant 0 : index +// CHECK: %[[C1:.*]] = arith.constant 1 : index +// CHECK: %[[C64:.*]] = arith.constant 64 : index +// CHECK: %[[C16:.*]] = arith.constant 16 : index +// CHECK: scf.for %[[arg2:.*]] = %[[C0]] to %[[C64]] step %[[C1]] iter_args(%[[arg3:.*]] = {{.*}}) -> (tensor<64x64xf32>) +// CHECK: scf.for %[[arg4:.*]] = %[[C0]] to %[[C64]] step %[[C16]] iter_args(%[[arg5:.*]] = %[[arg3]]) -> (tensor<64x64xf32>) +// CHECK: %[[READ0:.*]] = vector.transfer_read %arg0[%[[arg4]]], %[[CST]] {in_bounds = [true]} : tensor<64xf32>, vector<16xf32> +// CHECK: %[[READ1:.*]] = vector.transfer_read %[[arg5]][%[[arg2]], %[[arg4]]], %[[CST]] {in_bounds = [true]} : tensor<64x64xf32>, vector<16xf32> +// CHECK: %[[ADD0:.*]] = arith.addf %[[READ0]], %[[READ1]] : vector<16xf32> +// CHECK: %[[WRITE:.*]] = vector.transfer_write %[[ADD0]], %[[arg5]][%[[arg2]], %[[arg4]]] {in_bounds = [true]} : vector<16xf32>, tensor<64x64xf32> +func.func @broadcast_add_test15(%arg0: tensor<64xf32>, %arg1: tensor<64x64xf32>) -> tensor<64x64xf32> { + %0 = tensor.empty() : tensor<64x64xf32> + %bcast = linalg.broadcast + ins(%arg0:tensor<64xf32>) + outs(%0:tensor<64x64xf32>) + dimensions = [0] + %out3 = linalg.add ins(%bcast, %arg1: tensor<64x64xf32>, tensor<64x64xf32>) + outs(%arg1: tensor<64x64xf32>) -> tensor<64x64xf32> + return %out3: tensor<64x64xf32> +} + +// CHECK-LABEL: func @broadcast_single_test16 +// CHECK: %[[C16:.*]] = arith.constant 16 : index +// CHECK: %[[C64:.*]] = arith.constant 64 : index +// CHECK: %[[C1:.*]] = arith.constant 1 : index +// CHECK: %[[C0:.*]] = arith.constant 0 : index +// CHECK: %[[CST:.*]] = arith.constant 0.000000e+00 : f32 +// CHECK: %[[EMPTY0:.*]] = tensor.empty() : tensor<64x64xf32> +// CHECK: scf.for %[[arg1:.*]] = %[[C0]] to %[[C64]] step %[[C1]] iter_args(%[[arg2:.*]] = %[[EMPTY0]]) -> (tensor<64x64xf32>) +// CHECK: scf.for %[[arg3:.*]] = %[[C0]] to %[[C64]] step %[[C16]] iter_args(%[[arg4:.*]] = %[[arg2]]) -> (tensor<64x64xf32>) +// CHECK: %[[READ0:.*]] = vector.transfer_read %arg0[%[[arg3]]], %[[CST]] {in_bounds = [true]} : tensor<64xf32>, vector<16xf32> +// CHECK: %[[WRITE0:.*]] = vector.transfer_write %[[READ0]], %[[arg4]][%[[arg1]], %[[arg3]]] {in_bounds = [true]} : vector<16xf32>, tensor<64x64xf32> +func.func @broadcast_single_test16(%arg0: tensor<64xf32>) -> tensor<64x64xf32> { + %0 = tensor.empty() : tensor<64x64xf32> + %bcast = linalg.broadcast + ins(%arg0: tensor<64xf32>) + outs(%0:tensor<64x64xf32>) + dimensions = [0] + return %bcast: tensor<64x64xf32> +} +