Skip to content

Commit

Permalink
rebase master branch
Browse files Browse the repository at this point in the history
  • Loading branch information
eedalong committed Apr 2, 2024
2 parents 4d2a04c + 7414b51 commit 3b5a624
Show file tree
Hide file tree
Showing 13 changed files with 728 additions and 209 deletions.
22 changes: 22 additions & 0 deletions tao_compiler/mlir/disc/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -492,6 +492,27 @@ cc_library(
alwayslink = 1,
)

cc_library(
name = "disc_collective_ops_rewriter",
srcs = ["transforms/disc_collective_ops_rewriter.cc"],
hdrs = ["transforms/passes.h"],
includes = [
"tensorflow/compiler/xla/mlir_hlo/include",
"."
],
deps = [
":disc_util",
":pass_details",
":mhlo_disc",
"@org_tensorflow//tensorflow/compiler/xla/mlir_hlo:mlir_hlo",
"@llvm-project//mlir:FuncDialect",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:Transforms",
],
alwayslink = 1,
)

cc_library(
name = "mhlo_decomp_rewriters",
srcs = ["transforms/mhlo_decomp_rewriters.cc"],
Expand Down Expand Up @@ -2407,6 +2428,7 @@ cc_library(
":input_inline_fusion",
":lhlo_fusion_inliner",
":mhlo_decomp_rewriters",
":disc_collective_ops_rewriter",
":mhlo_mark_shape_calc",
":mhlo_placer",
":ral_inject_execution_context",
Expand Down
1 change: 1 addition & 0 deletions tao_compiler/mlir/disc/disc_compiler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -245,6 +245,7 @@ LogicalResult LowerHLOToLLVM(ModuleOp m, const DISCLoweringOptions& options) {
pm.addPass(disc_ral::createDiscInputOutputAliasPass());
pm.addPass(mlir::createInlinerPass());
// TODO(disc): Lower HLO shape constraints instead of eliding them here.
pm.addNestedPass<FuncOp>(disc_ral::createDiscCollectiveOpsRewriterPass());
pm.addNestedPass<FuncOp>(disc_ral::createDiscMhloDecompositionRewriterPass());
pm.addNestedPass<FuncOp>(disc_ral::createDiscRemoveShapeConstraintsPass());
pm.addNestedPass<FuncOp>(createCanonicalizerPass());
Expand Down
52 changes: 51 additions & 1 deletion tao_compiler/mlir/disc/transforms/disc_algebraic_simplifier.cc
100644 → 100755
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,55 @@ struct MulOneTensorOp : public OpRewritePattern<mhlo::MulOp> {
}
};

struct SelectSimplifier : public OpRewritePattern<mhlo::SelectOp> {
using OpRewritePattern<mhlo::SelectOp>::OpRewritePattern;

LogicalResult matchAndRewrite(mhlo::SelectOp op,
PatternRewriter& rewriter) const override {
DenseElementsAttr valueAttr;
// select(x, y, y) -> y.
if (op.getOperand(1) == op.getOperand(2)) {
rewriter.replaceOp(op, op.getOperand(1));
return success();
} else if (allElementsAreSameValue(op.getOperand(0), -1)) {
// select(true, x, y) -> x.
rewriter.replaceOp(op, op.getOperand(1));
return success();
} else if (allElementsAreSameValue(op.getOperand(0), 0)) {
// select(false, x, y) -> y.
rewriter.replaceOp(op, op.getOperand(2));
return success();
} else if (matchPattern(op.getOperand(0), m_Op<mhlo::PadOp>()) &&
matchPattern(op.getOperand(1), m_Op<mhlo::PadOp>()) &&
matchPattern(op.getOperand(2), m_Op<mhlo::ConstantOp>())) {
// Select(Pad(true, False), Pad(weight, 0), 0) -> Pad(weight, 0)
auto condition = dyn_cast<mhlo::PadOp>(op.getOperand(0).getDefiningOp());
auto value1 = dyn_cast<mhlo::PadOp>(op.getOperand(1).getDefiningOp());
auto value2 =
dyn_cast<mhlo::ConstantOp>(op.getOperand(2).getDefiningOp());

if (!matchPattern(condition->getOperand(1), m_Op<mhlo::ConstantOp>())) {
return failure();
}

if (condition.getEdgePaddingLow() != value1.getEdgePaddingLow() ||
condition.getEdgePaddingHigh() != value1.getEdgePaddingHigh() ||
condition.getInteriorPadding() != value1.getInteriorPadding()) {
return failure();
}

if (allElementsAreSameValue(condition->getOperand(0), -1) &&
allElementsAreSameValue(condition->getOperand(1), 0) &&
allElementsAreSameValue(value1->getOperand(1), 0) &&
allElementsAreSameValue(op.getOperand(2), 0)) {
rewriter.replaceOp(op, op.getOperand(1));
return success();
}
}
return failure();
}
};

// convert:
// mhlo.pow(x, const integer n)
// to:
Expand Down Expand Up @@ -564,7 +613,8 @@ void populateDiscAlgebraicSimplifierPatterns(RewritePatternSet& patterns) {
// zero tensor related patterns
patterns.insert<
AddZeroTensorOp,
MulOneTensorOp
MulOneTensorOp,
SelectSimplifier
>(patterns.getContext());
// clang-format on
}
Expand Down
Loading

0 comments on commit 3b5a624

Please sign in to comment.