Skip to content

Commit

Permalink
add more algebraic simplifierrules
Browse files Browse the repository at this point in the history
  • Loading branch information
eedalong committed Apr 15, 2024
1 parent f160eb2 commit aac25a3
Showing 1 changed file with 233 additions and 3 deletions.
236 changes: 233 additions & 3 deletions tao_compiler/mlir/disc/transforms/disc_algebraic_simplifier.cc
100644 → 100755
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,233 @@ struct MulOneTensorOp : public OpRewritePattern<mhlo::MulOp> {
}
};

struct SelectSimplifierPattern : public OpRewritePattern<mhlo::SelectOp> {
using OpRewritePattern<mhlo::SelectOp>::OpRewritePattern;

LogicalResult matchAndRewrite(mhlo::SelectOp op,
PatternRewriter& rewriter) const override {
DenseElementsAttr valueAttr;
// select(x, y, y) -> y.
if (op.getOperand(1) == op.getOperand(2)) {
rewriter.replaceOp(op, op.getOperand(1));
return success();
} else if (allElementsAreSameValue(op.getOperand(0), -1)) {
// select(true, x, y) -> x.
rewriter.replaceOp(op, op.getOperand(1));
return success();
} else if (allElementsAreSameValue(op.getOperand(0), 0)) {
// select(false, x, y) -> y.
rewriter.replaceOp(op, op.getOperand(2));
return success();
} else if (matchPattern(op.getOperand(0), m_Op<mhlo::PadOp>()) &&
matchPattern(op.getOperand(1), m_Op<mhlo::PadOp>()) &&
matchPattern(op.getOperand(2), m_Op<mhlo::ConstantOp>())) {
// Select(Pad(true, False), Pad(weight, 0), 0) -> Pad(weight, 0)
auto condition = dyn_cast<mhlo::PadOp>(op.getOperand(0).getDefiningOp());
auto value1 = dyn_cast<mhlo::PadOp>(op.getOperand(1).getDefiningOp());
auto value2 =
dyn_cast<mhlo::ConstantOp>(op.getOperand(2).getDefiningOp());

if (!matchPattern(condition->getOperand(1), m_Op<mhlo::ConstantOp>())) {
return failure();
}

if (condition.getEdgePaddingLow() != value1.getEdgePaddingLow() ||
condition.getEdgePaddingHigh() != value1.getEdgePaddingHigh() ||
condition.getInteriorPadding() != value1.getInteriorPadding()) {
return failure();
}

if (allElementsAreSameValue(condition->getOperand(0), -1) &&
allElementsAreSameValue(condition->getOperand(1), 0) &&
allElementsAreSameValue(value1->getOperand(1), 0) &&
allElementsAreSameValue(op.getOperand(2), 0)) {
rewriter.replaceOp(op, op.getOperand(1));
return success();
}
}
return failure();
}
};

struct SliceSimplifierPattern : public OpRewritePattern<mhlo::SliceOp> {
using OpRewritePattern<mhlo::SliceOp>::OpRewritePattern;

LogicalResult matchAndRewrite(mhlo::SliceOp op,
PatternRewriter& rewriter) const override {
if (matchPattern(op->getOperand(0), m_Op<mhlo::SliceOp>())) {
// Slice(Slice(A)) -> Slice(A)
auto innerSliceOp =
dyn_cast<mhlo::SliceOp>(op->getOperand(0).getDefiningOp());

auto innerStartIndicesAttr = innerSliceOp.getStartIndices();
auto outterStartIndicesAttr = op.getStartIndices();

auto innerLimitIndicesAttr = innerSliceOp.getLimitIndices();
auto outterLimitIndicesAttr = op.getLimitIndices();

auto innerStridesAttr = innerSliceOp.getStrides();
auto outterStridesAttr = op.getStrides();

std::vector<int64_t> newStartIndicesVec;
std::vector<int64_t> newLimitIndicesVec;
std::vector<int64_t> newStridesVec;

int rank = innerStartIndicesAttr.getNumElements();
for (int i = 0; i < rank; i++) {
// newStart = innerStart + outterStart * innerStride
// newLimit = newStart + (outterLimit - outterStart) * innerStride;
// newStride = innerStride * outterStride
auto newStart = *(innerStartIndicesAttr.begin() + i) +
*(outterStartIndicesAttr.begin() + i) *
*(innerStridesAttr.begin() + i);
auto newLimit = newStart + (*(outterLimitIndicesAttr.begin() + i) -
*(outterStartIndicesAttr.begin() + i)) *
*(innerStridesAttr.begin() + i);
auto newStride =
*(innerStridesAttr.begin() + i) * *(outterStridesAttr.begin() + i);

newStartIndicesVec.push_back(newStart.getSExtValue());
newLimitIndicesVec.push_back(newLimit.getSExtValue());
newStridesVec.push_back(newStride.getSExtValue());
}
auto newSliceOp = rewriter.create<mhlo::SliceOp>(
op.getLoc(), op.getType(), innerSliceOp->getOperand(0),
rewriter.getI64TensorAttr(newStartIndicesVec),
rewriter.getI64TensorAttr(newLimitIndicesVec),
rewriter.getI64TensorAttr(newStridesVec));
rewriter.replaceOp(op, {newSliceOp.getResult()});
return success();
} else if (matchPattern(op->getOperand(0), m_Op<mhlo::ReshapeOp>())) {
// Slice(Reshape(A)) -> Reshape(Slice(A))
auto reshapeOp =
dyn_cast<mhlo::ReshapeOp>(op->getOperand(0).getDefiningOp());
auto innerOperand = reshapeOp->getOperand(0);
auto innerOperandShape =
innerOperand.getType().dyn_cast<RankedTensorType>().getShape();
auto outterOperand = op->getOperand(0);
auto outterOperandShape =
outterOperand.getType().dyn_cast<RankedTensorType>().getShape();

auto startIndicesAttr = op.getStartIndices();
auto limitIndicesAttr = op.getLimitIndices();
auto stridesAttr = op.getStrides();

std::vector<int64_t> newStartIndicesVec;
std::vector<int64_t> newLimitIndicesVec;
std::vector<int64_t> newStridesVec;
std::vector<int64_t> newSliceResultShape;

int64_t flattenedStartIndice = 0, flattenedLimitIndice = 0;
int64_t dimAccum = 1;
for (int i = outterOperandShape.size() - 1; i >= 0; i--) {
flattenedStartIndice +=
dimAccum * (*(startIndicesAttr.begin() + i)).getSExtValue();
if (*(stridesAttr.begin() + i) != 1) {
return failure();
}
dimAccum *= outterOperandShape[i];
}
flattenedLimitIndice =
flattenedStartIndice +
op.getType().dyn_cast<RankedTensorType>().getNumElements() - 1;

dimAccum = 1;
for (int i = 0; i < innerOperandShape.size(); i++) {
dimAccum *= innerOperandShape[i];
}

size_t expected_ele_num =
op.getType().dyn_cast<RankedTensorType>().getNumElements();
size_t total_ele_num = 1;
bool shapeEnd = false;
for (int i = 0; i < innerOperandShape.size(); i++) {
dimAccum /= innerOperandShape[i];
newStartIndicesVec.push_back(flattenedStartIndice / dimAccum);
newLimitIndicesVec.push_back(flattenedLimitIndice / dimAccum + 1);
if (newLimitIndicesVec[i] <= newStartIndicesVec[i]) {
return failure();
}

newSliceResultShape.push_back(newLimitIndicesVec[i] -
newStartIndicesVec[i]);
newStridesVec.push_back(1);
total_ele_num *= (newLimitIndicesVec[i] - newStartIndicesVec[i]);
flattenedStartIndice %= dimAccum;
flattenedLimitIndice %= dimAccum;
}

if (total_ele_num != expected_ele_num) {
return failure();
}

auto newSliceResultType = mlir::RankedTensorType::get(
newSliceResultShape,
innerOperand.getType().dyn_cast<RankedTensorType>().getElementType());

auto newSliceOp = rewriter.create<mhlo::SliceOp>(
op.getLoc(), newSliceResultType, innerOperand,
rewriter.getI64TensorAttr(newStartIndicesVec),
rewriter.getI64TensorAttr(newLimitIndicesVec),
rewriter.getI64TensorAttr(newStridesVec));

auto newReshapeOp = rewriter.create<mhlo::ReshapeOp>(
op.getLoc(), op.getType(), newSliceOp);

rewriter.replaceOp(op, newReshapeOp.getResult());

return success();
}

return failure();
}
};

struct ReshapeSimplifierPattern : public OpRewritePattern<mhlo::ReshapeOp> {
using OpRewritePattern<mhlo::ReshapeOp>::OpRewritePattern;

LogicalResult matchAndRewrite(mhlo::ReshapeOp op,
PatternRewriter& rewriter) const override {
if (matchPattern(op->getOperand(0), m_Op<mhlo::ReshapeOp>())) {
// Reshape(Reshape(A)) -> Reshape(A)
rewriter.replaceOp(op, op->getOperand(0));
return success();
}

return failure();
}
};

struct OptimizationBarrierSimplifierPattern
: public OpRewritePattern<mhlo::OptimizationBarrierOp> {
using OpRewritePattern<mhlo::OptimizationBarrierOp>::OpRewritePattern;

LogicalResult matchAndRewrite(mhlo::OptimizationBarrierOp op,
PatternRewriter& rewriter) const override {
if (op->getOperandNumber() == 1 &&
matchPattern(op->getOperand(0), m_Op<mhlo::ReshapeOp>())) {
// OptimizationBarrier(Reshape(AllGather)) ->
// Reshape(OptimizationBarrier(AllGather))
auto reshapeOp =
dyn_cast<mhlo::ReshapeOp>(op->getOperand(0).getDefiningOp());
if (!matchPattern(reshapeOp->getOperand(0), m_Op<mhlo::AllGatherOp>())) {
return failure();
}
auto newOptimizationBarrierOp =
rewriter.create<mhlo::OptimizationBarrierOp>(
op.getLoc(), reshapeOp->getOperand(0).getType(),
reshapeOp->getOperand(0));
auto newReshapeOp = rewriter.create<mhlo::ReshapeOp>(
op.getLoc(), reshapeOp.getType(),
newOptimizationBarrierOp->getResult(0));
rewriter.replaceOp(op, newReshapeOp.getResult());
return success();
}

return failure();
}
};

// convert:
// mhlo.pow(x, const integer n)
// to:
Expand Down Expand Up @@ -554,7 +781,6 @@ void populateDiscAlgebraicSimplifierPatterns(RewritePatternSet& patterns) {
TrunciSimplifierPattern,
IndexCastSimplifierPattern
>(patterns.getContext());

if (isMemIntensiveOptExperimentalEnabled()) {
// Will be enabled by default after a set of robustness testing.
patterns.insert<FoldBcastOfComputationOnConstantPattern>(
Expand All @@ -564,7 +790,11 @@ void populateDiscAlgebraicSimplifierPatterns(RewritePatternSet& patterns) {
// zero tensor related patterns
patterns.insert<
AddZeroTensorOp,
MulOneTensorOp
MulOneTensorOp,
SelectSimplifierPattern,
ReshapeSimplifierPattern,
SliceSimplifierPattern,
OptimizationBarrierSimplifierPattern
>(patterns.getContext());
// clang-format on
}
Expand Down Expand Up @@ -600,4 +830,4 @@ createDiscAlgebraicSimplifierPass() {
}

} // namespace disc_ral
} // namespace mlir
} // namespace mlir

0 comments on commit aac25a3

Please sign in to comment.