From aac25a3b3aa828b758d82230740699029c8ff05f Mon Sep 17 00:00:00 2001
From: eedalong
Date: Mon, 15 Apr 2024 20:33:05 +0800
Subject: [PATCH] add more algebraic simplifier rules

---
Notes (reviewer revision of the submitted patch):
* ReshapeSimplifierPattern now builds reshape(A) with the outer result type;
  the submitted version replaced the outer reshape with the inner reshape's
  result, whose type differs.
* slice(slice(A)) derives the new limit from the last element actually read,
  so it no longer exceeds the extent of A when the inner stride is > 1.
* slice(reshape(A)) additionally requires the sliced elements to be
  contiguous in reshape(A), all shapes to be static and the result non-empty.
* OptimizationBarrierSimplifierPattern uses Operation::getNumOperands().
* Dropped the unrelated file-mode change, the blank-line removal and the
  removal of the trailing newline.
* Template arguments and indentation were reconstructed from a
  whitespace-collapsed copy; apply with --ignore-whitespace if the context
  does not match. The post-image blob id on the index line predates this
  revision.

 .../transforms/disc_algebraic_simplifier.cc        | 278 +++++++++++++++++-
 1 file changed, 277 insertions(+), 1 deletion(-)

diff --git a/tao_compiler/mlir/disc/transforms/disc_algebraic_simplifier.cc b/tao_compiler/mlir/disc/transforms/disc_algebraic_simplifier.cc
index 1babba1c947..17ad00282c3 100644
--- a/tao_compiler/mlir/disc/transforms/disc_algebraic_simplifier.cc
+++ b/tao_compiler/mlir/disc/transforms/disc_algebraic_simplifier.cc
@@ -90,6 +90,278 @@ struct MulOneTensorOp : public OpRewritePattern {
   }
 };
 
+// Folds mhlo.select ops whose result is statically one of the operands:
+//   select(pred, y, y)                         -> y
+//   select(all-true, x, y)                     -> x
+//   select(all-false, x, y)                    -> y
+//   select(pad(all-true, false), pad(w, 0), 0) -> pad(w, 0)
+struct SelectSimplifierPattern : public OpRewritePattern<mhlo::SelectOp> {
+  using OpRewritePattern<mhlo::SelectOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(mhlo::SelectOp op,
+                                PatternRewriter& rewriter) const override {
+    Value pred = op.getOperand(0);
+    Value onTrue = op.getOperand(1);
+    Value onFalse = op.getOperand(2);
+
+    // select(pred, y, y) -> y.
+    if (onTrue == onFalse) {
+      rewriter.replaceOp(op, onTrue);
+      return success();
+    }
+    // select(true, x, y) -> x.
+    // NOTE(review): `true` is matched as -1, presumably the sign-extended i1
+    // value; confirm against allElementsAreSameValue.
+    if (allElementsAreSameValue(pred, -1)) {
+      rewriter.replaceOp(op, onTrue);
+      return success();
+    }
+    // select(false, x, y) -> y.
+    if (allElementsAreSameValue(pred, 0)) {
+      rewriter.replaceOp(op, onFalse);
+      return success();
+    }
+
+    // select(pad(true, false), pad(weight, 0), 0) -> pad(weight, 0)
+    // The predicate is true exactly where `weight` lands and false in the
+    // padding, where both branches are zero, so the select is a no-op.
+    auto predPad = pred.getDefiningOp<mhlo::PadOp>();
+    auto onTruePad = onTrue.getDefiningOp<mhlo::PadOp>();
+    if (!predPad || !onTruePad ||
+        !onFalse.getDefiningOp<mhlo::ConstantOp>() ||
+        !predPad->getOperand(1).getDefiningOp<mhlo::ConstantOp>()) {
+      return failure();
+    }
+    // Both pads must place their operands at the same positions.
+    if (predPad.getEdgePaddingLow() != onTruePad.getEdgePaddingLow() ||
+        predPad.getEdgePaddingHigh() != onTruePad.getEdgePaddingHigh() ||
+        predPad.getInteriorPadding() != onTruePad.getInteriorPadding()) {
+      return failure();
+    }
+    if (!allElementsAreSameValue(predPad->getOperand(0), -1) ||
+        !allElementsAreSameValue(predPad->getOperand(1), 0) ||
+        !allElementsAreSameValue(onTruePad->getOperand(1), 0) ||
+        !allElementsAreSameValue(onFalse, 0)) {
+      return failure();
+    }
+    rewriter.replaceOp(op, onTrue);
+    return success();
+  }
+};
+
+// Simplifies mhlo.slice ops fed by another slice or by a reshape:
+//   slice(slice(A))   -> slice(A)
+//   slice(reshape(A)) -> reshape(slice(A)), when the sliced elements form one
+//                        contiguous row-major run that is also a box in A.
+struct SliceSimplifierPattern : public OpRewritePattern<mhlo::SliceOp> {
+  using OpRewritePattern<mhlo::SliceOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(mhlo::SliceOp op,
+                                PatternRewriter& rewriter) const override {
+    Value input = op->getOperand(0);
+    if (auto innerSliceOp = input.getDefiningOp<mhlo::SliceOp>()) {
+      return mergeSlices(op, innerSliceOp, rewriter);
+    }
+    if (auto reshapeOp = input.getDefiningOp<mhlo::ReshapeOp>()) {
+      return swapWithReshape(op, reshapeOp, rewriter);
+    }
+    return failure();
+  }
+
+ private:
+  // Returns element `i` of an index attribute as a signed 64-bit integer.
+  static int64_t at(DenseIntElementsAttr attr, int64_t i) {
+    return (*(attr.begin() + i)).getSExtValue();
+  }
+
+  // slice(slice(A)) -> slice(A)
+  static LogicalResult mergeSlices(mhlo::SliceOp op,
+                                   mhlo::SliceOp innerSliceOp,
+                                   PatternRewriter& rewriter) {
+    DenseIntElementsAttr innerStarts = innerSliceOp.getStartIndices();
+    DenseIntElementsAttr innerStrides = innerSliceOp.getStrides();
+    DenseIntElementsAttr outerStarts = op.getStartIndices();
+    DenseIntElementsAttr outerLimits = op.getLimitIndices();
+    DenseIntElementsAttr outerStrides = op.getStrides();
+
+    std::vector<int64_t> newStarts;
+    std::vector<int64_t> newLimits;
+    std::vector<int64_t> newStrides;
+    const int64_t rank = innerStarts.getNumElements();
+    for (int64_t i = 0; i < rank; ++i) {
+      const int64_t innerStart = at(innerStarts, i);
+      const int64_t innerStride = at(innerStrides, i);
+      const int64_t outerStart = at(outerStarts, i);
+      const int64_t outerStride = at(outerStrides, i);
+      if (innerStride <= 0 || outerStride <= 0) {
+        return failure();
+      }
+      // Number of elements the outer slice keeps along this dimension.
+      const int64_t count =
+          (at(outerLimits, i) - outerStart + outerStride - 1) / outerStride;
+      if (count <= 0) {
+        // Empty result: any empty in-bounds range of A is equivalent.
+        newStarts.push_back(innerStart);
+        newLimits.push_back(innerStart);
+        newStrides.push_back(1);
+        continue;
+      }
+      // Outer element k reads A at
+      //   innerStart + (outerStart + k * outerStride) * innerStride.
+      // The limit is one past the last element actually read, so it stays
+      // within the range the inner slice was already allowed to touch.
+      const int64_t newStart = innerStart + outerStart * innerStride;
+      const int64_t newStride = innerStride * outerStride;
+      newStarts.push_back(newStart);
+      newLimits.push_back(newStart + (count - 1) * newStride + 1);
+      newStrides.push_back(newStride);
+    }
+    auto newSliceOp = rewriter.create<mhlo::SliceOp>(
+        op.getLoc(), op.getType(), innerSliceOp->getOperand(0),
+        rewriter.getI64TensorAttr(newStarts),
+        rewriter.getI64TensorAttr(newLimits),
+        rewriter.getI64TensorAttr(newStrides));
+    rewriter.replaceOp(op, newSliceOp.getResult());
+    return success();
+  }
+
+  // slice(reshape(A)) -> reshape(slice(A))
+  static LogicalResult swapWithReshape(mhlo::SliceOp op,
+                                       mhlo::ReshapeOp reshapeOp,
+                                       PatternRewriter& rewriter) {
+    Value source = reshapeOp->getOperand(0);
+    auto sourceType = source.getType().dyn_cast<RankedTensorType>();
+    auto reshapedType =
+        op->getOperand(0).getType().dyn_cast<RankedTensorType>();
+    auto resultType = op.getType().dyn_cast<RankedTensorType>();
+    // The index arithmetic below needs every extent to be known.
+    if (!sourceType || !reshapedType || !resultType ||
+        !sourceType.hasStaticShape() || !reshapedType.hasStaticShape() ||
+        !resultType.hasStaticShape()) {
+      return failure();
+    }
+    const int64_t numElements = resultType.getNumElements();
+    if (numElements == 0) {
+      return failure();
+    }
+
+    const ArrayRef<int64_t> sourceShape = sourceType.getShape();
+    const ArrayRef<int64_t> reshapedShape = reshapedType.getShape();
+    DenseIntElementsAttr starts = op.getStartIndices();
+    DenseIntElementsAttr limits = op.getLimitIndices();
+    DenseIntElementsAttr strides = op.getStrides();
+
+    // Row-major linear offsets, within reshape(A), of the first and the last
+    // element selected by the slice.
+    int64_t firstOffset = 0;
+    int64_t lastOffset = 0;
+    int64_t dimAccum = 1;
+    for (int64_t i = static_cast<int64_t>(reshapedShape.size()) - 1; i >= 0;
+         --i) {
+      if (at(strides, i) != 1) {
+        return failure();
+      }
+      firstOffset += dimAccum * at(starts, i);
+      lastOffset += dimAccum * (at(limits, i) - 1);
+      dimAccum *= reshapedShape[i];
+    }
+    // The selected box has numElements distinct offsets in
+    // [firstOffset, lastOffset], so it is one contiguous run only if it fills
+    // that range. Without this check, e.g. slicing [0:2, 0:2] out of a 4x4
+    // reshape of a 16-element tensor would be rewritten to read elements 0..3.
+    if (lastOffset - firstOffset + 1 != numElements) {
+      return failure();
+    }
+
+    // Decompose both offsets into coordinates of A.
+    dimAccum = sourceType.getNumElements();
+    std::vector<int64_t> newStarts;
+    std::vector<int64_t> newLimits;
+    std::vector<int64_t> newStrides;
+    std::vector<int64_t> newShape;
+    int64_t boxElements = 1;
+    for (int64_t extent : sourceShape) {
+      dimAccum /= extent;
+      const int64_t start = firstOffset / dimAccum;
+      const int64_t limit = lastOffset / dimAccum + 1;
+      if (limit <= start) {
+        return failure();
+      }
+      newStarts.push_back(start);
+      newLimits.push_back(limit);
+      newStrides.push_back(1);
+      newShape.push_back(limit - start);
+      boxElements *= limit - start;
+      firstOffset %= dimAccum;
+      lastOffset %= dimAccum;
+    }
+    // The run must also be a box in A for a single slice to express it.
+    if (boxElements != numElements) {
+      return failure();
+    }
+
+    auto newSliceType =
+        RankedTensorType::get(newShape, sourceType.getElementType());
+    auto newSliceOp = rewriter.create<mhlo::SliceOp>(
+        op.getLoc(), newSliceType, source,
+        rewriter.getI64TensorAttr(newStarts),
+        rewriter.getI64TensorAttr(newLimits),
+        rewriter.getI64TensorAttr(newStrides));
+    auto newReshapeOp = rewriter.create<mhlo::ReshapeOp>(
+        op.getLoc(), op.getType(), newSliceOp.getResult());
+    rewriter.replaceOp(op, newReshapeOp.getResult());
+    return success();
+  }
+};
+
+// reshape(reshape(A)) -> reshape(A)
+struct ReshapeSimplifierPattern : public OpRewritePattern<mhlo::ReshapeOp> {
+  using OpRewritePattern<mhlo::ReshapeOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(mhlo::ReshapeOp op,
+                                PatternRewriter& rewriter) const override {
+    auto innerReshapeOp = op->getOperand(0).getDefiningOp<mhlo::ReshapeOp>();
+    if (!innerReshapeOp) {
+      return failure();
+    }
+    // Reshape A straight to the final type. Forwarding the inner reshape's
+    // result instead would hand users a value of the wrong shape.
+    auto newReshapeOp = rewriter.create<mhlo::ReshapeOp>(
+        op.getLoc(), op.getType(), innerReshapeOp->getOperand(0));
+    rewriter.replaceOp(op, newReshapeOp.getResult());
+    return success();
+  }
+};
+
+// optimization_barrier(reshape(all_gather)) ->
+//     reshape(optimization_barrier(all_gather))
+struct OptimizationBarrierSimplifierPattern
+    : public OpRewritePattern<mhlo::OptimizationBarrierOp> {
+  using OpRewritePattern<mhlo::OptimizationBarrierOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(mhlo::OptimizationBarrierOp op,
+                                PatternRewriter& rewriter) const override {
+    // Only single-operand barriers are handled.
+    if (op->getNumOperands() != 1) {
+      return failure();
+    }
+    auto reshapeOp = op->getOperand(0).getDefiningOp<mhlo::ReshapeOp>();
+    if (!reshapeOp) {
+      return failure();
+    }
+    Value gathered = reshapeOp->getOperand(0);
+    if (!gathered.getDefiningOp<mhlo::AllGatherOp>()) {
+      return failure();
+    }
+    auto newBarrierOp = rewriter.create<mhlo::OptimizationBarrierOp>(
+        op.getLoc(), gathered.getType(), gathered);
+    auto newReshapeOp = rewriter.create<mhlo::ReshapeOp>(
+        op.getLoc(), reshapeOp.getType(), newBarrierOp->getResult(0));
+    rewriter.replaceOp(op, newReshapeOp.getResult());
+    return success();
+  }
+};
+
 // convert:
 // mhlo.pow(x, const integer n)
 // to:
@@ -564,7 +836,11 @@ void populateDiscAlgebraicSimplifierPatterns(RewritePatternSet& patterns) {
   // zero tensor related patterns
   patterns.insert<
     AddZeroTensorOp,
-    MulOneTensorOp
+    MulOneTensorOp,
+    SelectSimplifierPattern,
+    ReshapeSimplifierPattern,
+    SliceSimplifierPattern,
+    OptimizationBarrierSimplifierPattern
   >(patterns.getContext());
   // clang-format on
 }