Add more algebraic simplify rules (#1291)
add more algebraic simplify rules
eedalong authored May 20, 2024
1 parent b90fb75 commit 59c9279
Showing 2 changed files with 71 additions and 3 deletions.
55 changes: 52 additions & 3 deletions tao_compiler/mlir/disc/transforms/disc_algebraic_simplifier.cc
@@ -90,6 +90,55 @@ struct MulOneTensorOp : public OpRewritePattern<mhlo::MulOp> {
}
};

struct SelectSimplifierPattern : public OpRewritePattern<mhlo::SelectOp> {
using OpRewritePattern<mhlo::SelectOp>::OpRewritePattern;

LogicalResult matchAndRewrite(mhlo::SelectOp op,
PatternRewriter& rewriter) const override {
// select(x, y, y) -> y.
if (op.getOperand(1) == op.getOperand(2)) {
rewriter.replaceOp(op, op.getOperand(1));
return success();
} else if (allElementsAreSameValue(op.getOperand(0), -1)) {
// select(true, x, y) -> x.
rewriter.replaceOp(op, op.getOperand(1));
return success();
} else if (allElementsAreSameValue(op.getOperand(0), 0)) {
// select(false, x, y) -> y.
rewriter.replaceOp(op, op.getOperand(2));
return success();
} else if (matchPattern(op.getOperand(0), m_Op<mhlo::PadOp>()) &&
matchPattern(op.getOperand(1), m_Op<mhlo::PadOp>()) &&
matchPattern(op.getOperand(2), m_Op<mhlo::ConstantOp>())) {
      // select(pad(true, false), pad(weight, 0), 0) -> pad(weight, 0),
      // provided both pads use the same padding configuration.
auto condition = dyn_cast<mhlo::PadOp>(op.getOperand(0).getDefiningOp());
auto value1 = dyn_cast<mhlo::PadOp>(op.getOperand(1).getDefiningOp());

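      // The condition pad's padding value must itself be a constant so that
      // the check further below can verify it pads with `false`.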
if (!matchPattern(condition->getOperand(1), m_Op<mhlo::ConstantOp>())) {
return failure();
}

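      // Both pads must pad the same positions: identical low, high, and
      // interior padding.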
if (condition.getEdgePaddingLow() != value1.getEdgePaddingLow() ||
condition.getEdgePaddingHigh() != value1.getEdgePaddingHigh() ||
condition.getInteriorPadding() != value1.getInteriorPadding()) {
return failure();
}

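      // The condition pads an all-true tensor with `false`, the selected
      // value pads with zeros, and the on-false operand is an all-zero
      // constant, so the select simply reproduces the padded value.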
if (allElementsAreSameValue(condition->getOperand(0), -1) &&
allElementsAreSameValue(condition->getOperand(1), 0) &&
allElementsAreSameValue(value1->getOperand(1), 0) &&
allElementsAreSameValue(op.getOperand(2), 0)) {
rewriter.replaceOp(op, op.getOperand(1));
return success();
}
}
return failure();
}
};
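// Sketch (not part of this commit): the pattern above relies on an
// allElementsAreSameValue helper that is defined elsewhere in this file and
// does not appear in the diff. Below is a minimal sketch of the kind of check
// it is assumed to perform: splat constants only, with -1 standing for the
// all-ones integer splat (i.e. an i1 `true`). The name and details are
// illustrative, not the committed implementation.
static bool allElementsAreSameValueSketch(Value value, int64_t target) {
  // Bind the elements attribute of the defining constant, if any
  // (uses matchPattern/m_Constant from mlir/IR/Matchers.h).
  DenseElementsAttr attr;
  if (!matchPattern(value, m_Constant(&attr))) return false;
  // Only splat constants are considered in this sketch.
  if (!attr.isSplat()) return false;
  Type elemTy = attr.getElementType();
  if (elemTy.isa<IntegerType>()) {
    // Sign-extend so an all-true i1 splat compares equal to -1.
    return attr.getSplatValue<APInt>().getSExtValue() == target;
  }
  if (elemTy.isa<FloatType>()) {
    // The pattern above only asks about float zeros, so a zero check is
    // enough here.
    return target == 0 && attr.getSplatValue<APFloat>().isZero();
  }
  return false;
}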

// convert:
// mhlo.pow(x, const integer n)
// to:
@@ -554,7 +603,6 @@ void populateDiscAlgebraicSimplifierPatterns(RewritePatternSet& patterns) {
TrunciSimplifierPattern,
IndexCastSimplifierPattern
>(patterns.getContext());

if (isMemIntensiveOptExperimentalEnabled()) {
// Will be enabled by default after a set of robustness testing.
patterns.insert<FoldBcastOfComputationOnConstantPattern>(
@@ -564,7 +612,8 @@ void populateDiscAlgebraicSimplifierPatterns(RewritePatternSet& patterns) {
// zero tensor related patterns
patterns.insert<
AddZeroTensorOp,
MulOneTensorOp,
SelectSimplifierPattern
>(patterns.getContext());
// clang-format on
}
@@ -600,4 +649,4 @@ createDiscAlgebraicSimplifierPass() {
}

} // namespace disc_ral
} // namespace mlir
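For context (not part of this diff): the pattern set populated above, including the new SelectSimplifierPattern, is typically handed to MLIR's greedy rewrite driver by the pass. A minimal sketch of that wiring, with illustrative names, might look like the following; the committed pass wiring itself is unchanged by this commit.

// Sketch only, illustrative names.
void runAlgebraicSimplifierSketch(func::FuncOp func) {
  RewritePatternSet patterns(func.getContext());
  populateDiscAlgebraicSimplifierPatterns(patterns);
  // Apply the patterns (including SelectSimplifierPattern) until fixpoint.
  if (failed(applyPatternsAndFoldGreedily(func, std::move(patterns)))) {
    func.emitError("algebraic simplifier patterns did not converge");
  }
}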
19 changes: 19 additions & 0 deletions (second changed file: FileCheck tests for the algebraic simplifier)
@@ -238,3 +238,22 @@ func.func @index_cast_simp(%arg0: index) -> index {
%1 = arith.index_cast %0 : i64 to index
return %1 : index
}

// -----

// CHECK-LABEL: @select_simp
func.func @select_simp(%arg0: tensor<16xf16>) -> (tensor<20xf16>, tensor<20xf16>) {
// CHECK: %0 = mhlo.constant dense<0.000000e+00> : tensor<20xf16>
// CHECK: %1 = mhlo.constant dense<0.000000e+00> : tensor<f16>
// CHECK: %2 = "mhlo.pad"(%arg0, %1) {edge_padding_high = dense<4> : tensor<1xi64>, edge_padding_low = dense<0> : tensor<1xi64>, interior_padding = dense<0> : tensor<1xi64>} : (tensor<16xf16>, tensor<f16>) -> tensor<20xf16>
%1 = mhlo.constant dense<true> : tensor<i1>
%3 = mhlo.constant dense<0.000000e+00> : tensor<20xf16>
%4 = mhlo.constant dense<0.000000e+00> : tensor<f16>
%5 = mhlo.constant dense<true> : tensor<16xi1>
%6 = "mhlo.pad"(%5, %1) {edge_padding_high = dense<4> : tensor<1xi64>, edge_padding_low = dense<0> : tensor<1xi64>, interior_padding = dense<0> : tensor<1xi64>} : (tensor<16xi1>, tensor<i1>) -> tensor<20xi1>
%7 = "mhlo.pad"(%arg0, %4) {edge_padding_high = dense<4> : tensor<1xi64>, edge_padding_low = dense<0> : tensor<1xi64>, interior_padding = dense<0> : tensor<1xi64>} : (tensor<16xf16>, tensor<f16>) -> tensor<20xf16>
%8 = "mhlo.select"(%6, %7, %3): (tensor<20xi1>, tensor<20xf16>, tensor<20xf16>) -> tensor<20xf16>
%10 = mhlo.constant dense<false> : tensor<20xi1>
%11 = "mhlo.select"(%10, %7, %3): (tensor<20xi1>, tensor<20xf16>, tensor<20xf16>) -> tensor<20xf16>
return %8, %11 : tensor<20xf16>, tensor<20xf16>
}
