Skip to content

Commit

Permalink
use m_Constant
Browse files Browse the repository at this point in the history
  • Loading branch information
Pangoraw authored and wsmoses committed Feb 2, 2025
1 parent 4a93aa7 commit 26719d4
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 14 deletions.
26 changes: 12 additions & 14 deletions src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2935,13 +2935,13 @@ template <typename T> struct BinOpConstSimplify : public OpRewritePattern<T> {
auto lhs = op.getLhs();
auto rhs = op.getRhs();

auto lhsConst = lhs.template getDefiningOp<stablehlo::ConstantOp>();
auto rhsConst = rhs.template getDefiningOp<stablehlo::ConstantOp>();
auto lhsConst = matchPattern(lhs, m_Constant());
auto rhsConst = matchPattern(rhs, m_Constant());

if (!lhsConst && !rhsConst)
return failure();

auto constOp = lhsConst ? lhsConst : rhsConst;
auto constVal = lhsConst ? lhs : rhs;
auto otherOp = lhsConst ? rhs.template getDefiningOp<T>()
: lhs.template getDefiningOp<T>();

Expand All @@ -2951,21 +2951,19 @@ template <typename T> struct BinOpConstSimplify : public OpRewritePattern<T> {
auto otherLhs = otherOp.getRhs();
auto otherRhs = otherOp.getLhs();

if (!otherLhs.template getDefiningOp<stablehlo::ConstantOp>() &&
!otherRhs.template getDefiningOp<stablehlo::ConstantOp>())
auto otherLhsConst = matchPattern(otherLhs, m_Constant());
auto otherRhsConst = matchPattern(otherRhs, m_Constant());

if (!otherLhsConst && !otherRhsConst)
return failure();

// Both op and other have a constant operand
// group constants to a new op.
auto otherConst =
otherLhs.template getDefiningOp<stablehlo::ConstantOp>()
? otherLhs.template getDefiningOp<stablehlo::ConstantOp>()
: otherRhs.template getDefiningOp<stablehlo::ConstantOp>();
auto otherOperand =
otherConst.getResult() == otherLhs ? otherRhs : otherLhs;

auto constantAdd = rewriter.create<T>(op.getLoc(), op.getResult().getType(),
constOp, otherConst);
auto otherConstVal = otherLhsConst ? otherLhs : otherRhs;
auto otherOperand = otherLhsConst ? otherRhs : otherLhs;

auto constantAdd = rewriter.create<T>(
otherOp.getLoc(), op.getResult().getType(), constVal, otherConstVal);
rewriter.replaceOpWithNewOp<T>(op, otherOperand, constantAdd);

return success();
Expand Down
1 change: 1 addition & 0 deletions src/enzyme_ad/jax/primitives.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,6 +216,7 @@ def hlo_opts():
binary_op_transpose_simplify_or<1>;
binary_op_transpose_simplify_xor<1>;
binary_op_transpose_simplify_rem<1>;
binop_const_simplify<1>;
compare_select_simplify;
common_compare_expression_rewrite;
not_select_simplify;
Expand Down

0 comments on commit 26719d4

Please sign in to comment.