Review notes (text ahead of the first "diff --git" line is skipped by patch tools):
- Angle-bracket arguments were missing from the reviewed copy of this patch.
  They are restored where the surrounding code determines them. The
  mhlo::ConstantOp argument of the two m_Op<> matchers and dense<true> for %10
  in the test are assumptions -- confirm against the original change.
- The index lines carry the blob ids of the original change; not recomputed.

diff --git a/tao_compiler/mlir/disc/transforms/disc_algebraic_simplifier.cc b/tao_compiler/mlir/disc/transforms/disc_algebraic_simplifier.cc
index 1babba1c947..965e8ccfcd6 100644
--- a/tao_compiler/mlir/disc/transforms/disc_algebraic_simplifier.cc
+++ b/tao_compiler/mlir/disc/transforms/disc_algebraic_simplifier.cc
@@ -90,6 +90,64 @@ struct MulOneTensorOp : public OpRewritePattern {
   }
 };
 
+// Folds `mhlo.select` ops whose result is statically one of the operands:
+//   select(pred, x, x)                          -> x
+//   select(all-true, x, y)                      -> x
+//   select(all-false, x, y)                     -> y
+//   select(pad(all-true, false), pad(w, 0), 0)  -> pad(w, 0)
+struct SelectSimplifierPattern : public OpRewritePattern<mhlo::SelectOp> {
+  using OpRewritePattern<mhlo::SelectOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(mhlo::SelectOp op,
+                                PatternRewriter& rewriter) const override {
+    Value pred = op.getOperand(0);
+    Value onTrue = op.getOperand(1);
+    Value onFalse = op.getOperand(2);
+
+    // select(pred, x, x) -> x.
+    if (onTrue == onFalse) {
+      rewriter.replaceOp(op, onTrue);
+      return success();
+    }
+    // select(true, x, y) -> x; -1 is the value this pattern passes to the
+    // helper to test for an all-true predicate.
+    if (allElementsAreSameValue(pred, -1)) {
+      rewriter.replaceOp(op, onTrue);
+      return success();
+    }
+    // select(false, x, y) -> y.
+    if (allElementsAreSameValue(pred, 0)) {
+      rewriter.replaceOp(op, onFalse);
+      return success();
+    }
+
+    // select(pad(true, false), pad(w, 0), 0) -> pad(w, 0): the predicate is
+    // true on the payload and false on the padding, where pad(w, 0) and the
+    // false branch are both zero.
+    auto predPad = pred.getDefiningOp<mhlo::PadOp>();
+    auto truePad = onTrue.getDefiningOp<mhlo::PadOp>();
+    if (!predPad || !truePad ||
+        !matchPattern(onFalse, m_Op<mhlo::ConstantOp>()) ||
+        !matchPattern(predPad->getOperand(1), m_Op<mhlo::ConstantOp>())) {
+      return failure();
+    }
+    // Both pads must use the same padding configuration.
+    if (predPad.getEdgePaddingLow() != truePad.getEdgePaddingLow() ||
+        predPad.getEdgePaddingHigh() != truePad.getEdgePaddingHigh() ||
+        predPad.getInteriorPadding() != truePad.getInteriorPadding()) {
+      return failure();
+    }
+    if (!allElementsAreSameValue(predPad->getOperand(0), -1) ||
+        !allElementsAreSameValue(predPad->getOperand(1), 0) ||
+        !allElementsAreSameValue(truePad->getOperand(1), 0) ||
+        !allElementsAreSameValue(onFalse, 0)) {
+      return failure();
+    }
+    rewriter.replaceOp(op, onTrue);
+    return success();
+  }
+};
+
 // convert:
 //   mhlo.pow(x, const integer n)
 // to:
@@ -564,7 +622,8 @@ void populateDiscAlgebraicSimplifierPatterns(RewritePatternSet& patterns) {
   // zero tensor related patterns
   patterns.insert<
     AddZeroTensorOp,
-    MulOneTensorOp
+    MulOneTensorOp,
+    SelectSimplifierPattern
   >(patterns.getContext());
   // clang-format on
 }
diff --git a/tao_compiler/mlir/disc/transforms/tests/disc-algebraic-simplifier.mlir b/tao_compiler/mlir/disc/transforms/tests/disc-algebraic-simplifier.mlir
index d331a8ed5a2..590af202c02 100644
--- a/tao_compiler/mlir/disc/transforms/tests/disc-algebraic-simplifier.mlir
+++ b/tao_compiler/mlir/disc/transforms/tests/disc-algebraic-simplifier.mlir
@@ -238,3 +238,26 @@ func.func @index_cast_simp(%arg0: index) -> index {
   %1 = arith.index_cast %0 : i64 to index
   return %1 : index
 }
+
+// -----
+
+// CHECK-LABEL: @select_simp
+func.func @select_simp(%arg0: tensor<16xf16>) -> (tensor<20xf16>, tensor<20xf16>) {
+  // CHECK: %[[ZERO:.*]] = mhlo.constant dense<0.000000e+00> : tensor<f16>
+  // CHECK: %[[PAD:.*]] = "mhlo.pad"(%arg0, %[[ZERO]])
+  // CHECK-SAME: edge_padding_high = dense<4> : tensor<1xi64>
+  // CHECK-NOT: mhlo.select
+  // CHECK: return %[[PAD]], %[[PAD]]
+  %1 = mhlo.constant dense<false> : tensor<i1>
+  %3 = mhlo.constant dense<0.000000e+00> : tensor<20xf16>
+  %4 = mhlo.constant dense<0.000000e+00> : tensor<f16>
+  %5 = mhlo.constant dense<true> : tensor<16xi1>
+  %6 = "mhlo.pad"(%5, %1) {edge_padding_high = dense<4> : tensor<1xi64>, edge_padding_low = dense<0> : tensor<1xi64>, interior_padding = dense<0> : tensor<1xi64>} : (tensor<16xi1>, tensor<i1>) -> tensor<20xi1>
+  %7 = "mhlo.pad"(%arg0, %4) {edge_padding_high = dense<4> : tensor<1xi64>, edge_padding_low = dense<0> : tensor<1xi64>, interior_padding = dense<0> : tensor<1xi64>} : (tensor<16xf16>, tensor<f16>) -> tensor<20xf16>
+  // select(pad(true, false), pad(x, 0), 0) -> pad(x, 0)
+  %8 = "mhlo.select"(%6, %7, %3): (tensor<20xi1>, tensor<20xf16>, tensor<20xf16>) -> tensor<20xf16>
+  %10 = mhlo.constant dense<true> : tensor<20xi1>
+  // select(true, x, y) -> x
+  %11 = "mhlo.select"(%10, %7, %3): (tensor<20xi1>, tensor<20xf16>, tensor<20xf16>) -> tensor<20xf16>
+  return %8, %11 : tensor<20xf16>, tensor<20xf16>
+}