Skip to content

Commit

Permalink
dialects: (arith) select canonicalization patterns (#3368)
Browse files Browse the repository at this point in the history
Adds some simple canonicalization patterns for `arith.select`
  • Loading branch information
alexarice authored Oct 31, 2024
1 parent a0d3757 commit 656457a
Show file tree
Hide file tree
Showing 3 changed files with 104 additions and 8 deletions.
40 changes: 34 additions & 6 deletions tests/filecheck/dialects/arith/canonicalize.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -3,19 +3,18 @@
%lhsf32, %rhsf32 = "test.op"() : () -> (f32, f32)
%lhsvec, %rhsvec = "test.op"() : () -> (vector<4xf32>, vector<4xf32>)

// CHECK: %lhsf32, %rhsf32 = "test.op"() : () -> (f32, f32)
// CHECK-NEXT: %lhsvec, %rhsvec = "test.op"() : () -> (vector<4xf32>, vector<4xf32>)
// CHECK-NEXT: %addf = arith.addf %lhsf32, %rhsf32 : f32
// CHECK-NEXT: %addf_vector = arith.addf %lhsvec, %rhsvec : vector<4xf32>
// CHECK-NEXT: "test.op"(%addf, %addf_vector) : (f32, vector<4xf32>) -> ()
%addf = arith.addf %lhsf32, %rhsf32 : f32
%addf_1 = arith.addf %lhsf32, %rhsf32 : f32
%addf_vector = arith.addf %lhsvec, %rhsvec : vector<4xf32>
%addf_vector_1 = arith.addf %lhsvec, %rhsvec : vector<4xf32>

"test.op"(%addf, %addf_vector) : (f32, vector<4xf32>) -> ()

// CHECK: %lhsf32, %rhsf32 = "test.op"() : () -> (f32, f32)
// CHECK-NEXT: %lhsvec, %rhsvec = "test.op"() : () -> (vector<4xf32>, vector<4xf32>)
// CHECK-NEXT: %addf = arith.addf %lhsf32, %rhsf32 : f32
// CHECK-NEXT: %addf_vector = arith.addf %lhsvec, %rhsvec : vector<4xf32>
// CHECK-NEXT: "test.op"(%addf, %addf_vector) : (f32, vector<4xf32>) -> ()

func.func @test_const_const() {
%a = arith.constant 2.9979 : f32
%b = arith.constant 3.1415 : f32
Expand Down Expand Up @@ -62,3 +61,32 @@ func.func @test_const_var_const() {
// CHECK-NEXT: %5 = arith.mulf %4, %0 fastmath<fast> : f32
// CHECK-NEXT: "test.op"(%3, %5) : (f32, f32) -> ()
}

// CHECK: %lhs, %rhs = "test.op"() : () -> (f32, f32)
// CHECK-NEXT: %ctrue = arith.constant true
// CHECK-NEXT: "test.op"(%lhs, %lhs) : (f32, f32) -> ()

%lhs, %rhs = "test.op"() : () -> (f32, f32)
%ctrue = arith.constant true
%cfalse = arith.constant false
%select_true = arith.select %ctrue, %lhs, %rhs : f32
%select_false = arith.select %ctrue, %lhs, %rhs : f32
"test.op"(%select_true, %select_false) : (f32, f32) -> ()

// CHECK: %cond = "test.op"() : () -> i1
// CHECK-NEXT: %select_false_true = arith.xori %cond, %ctrue : i1
// CHECK-NEXT: "test.op"(%cond, %select_false_true) : (i1, i1) -> ()

%cond = "test.op"() : () -> (i1)
%select_true_false = arith.select %cond, %ctrue, %cfalse : i1
%select_false_true = arith.select %cond, %cfalse, %ctrue : i1
"test.op"(%select_true_false, %select_false_true) : (i1, i1) -> ()

%x, %y = "test.op"() : () -> (i1, i64)

// CHECK: %x, %y = "test.op"() : () -> (i1, i64)
// CHECK-NEXT: "test.op"(%y) : (i64) -> ()

%z = arith.select %x, %y, %y : i64

"test.op"(%z) : (i64) -> ()
14 changes: 13 additions & 1 deletion xdsl/dialects/arith.py
Original file line number Diff line number Diff line change
Expand Up @@ -858,6 +858,18 @@ def print(self, printer: Printer):
printer.print_attribute(self.lhs.type)


class SelectHasCanonicalizationPatterns(HasCanonicalizationPatternsTrait):
@classmethod
def get_canonicalization_patterns(cls) -> tuple[RewritePattern, ...]:
from xdsl.transforms.canonicalization_patterns.arith import (
SelectConstPattern,
SelectSamePattern,
SelectTrueFalsePattern,
)

return (SelectConstPattern(), SelectTrueFalsePattern(), SelectSamePattern())


@irdl_op_definition
class Select(IRDLOperation):
"""
Expand All @@ -873,7 +885,7 @@ class Select(IRDLOperation):
rhs = operand_def(Attribute)
result = result_def(Attribute)

traits = frozenset([Pure()])
traits = frozenset([Pure(), SelectHasCanonicalizationPatterns()])

# TODO replace with trait
def verify_(self) -> None:
Expand Down
58 changes: 57 additions & 1 deletion xdsl/transforms/canonicalization_patterns/arith.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from xdsl.dialects import arith, builtin
from xdsl.dialects.builtin import IntegerAttr
from xdsl.dialects.builtin import IntegerAttr, IntegerType
from xdsl.pattern_rewriter import (
PatternRewriter,
RewritePattern,
Expand Down Expand Up @@ -102,3 +102,59 @@ def match_and_rewrite(
rebuild = type(op)(cnsts, val, flags)
rewriter.replace_matched_op([cnsts, rebuild])
rewriter.replace_op(u, [], [rebuild.result])


class SelectConstPattern(RewritePattern):
"""
arith.select %true %x %y = %x
arith.select %false %x %y = %y
"""

@op_type_rewrite_pattern
def match_and_rewrite(self, op: arith.Select, rewriter: PatternRewriter):
if not isinstance(condition := op.cond.owner, arith.Constant):
return

assert isinstance(const_cond := condition.value, IntegerAttr)

if const_cond.value.data == 1:
rewriter.replace_matched_op((), (op.lhs,))
if const_cond.value.data == 0:
rewriter.replace_matched_op((), (op.rhs,))


class SelectTrueFalsePattern(RewritePattern):
"""
arith.select %x %true %false = %x
arith.select %x %false %true = %x xor 1
"""

@op_type_rewrite_pattern
def match_and_rewrite(self, op: arith.Select, rewriter: PatternRewriter):
if op.result.type != IntegerType(1):
return

if not isinstance(lhs := op.lhs.owner, arith.Constant) or not isinstance(
rhs := op.rhs.owner, arith.Constant
):
return

assert isinstance(lhs_value := lhs.value, IntegerAttr)
assert isinstance(rhs_value := rhs.value, IntegerAttr)

if lhs_value.value.data == 1 and rhs_value.value.data == 0:
rewriter.replace_matched_op((), (op.cond,))

if lhs_value.value.data == 0 and rhs_value.value.data == 1:
rewriter.replace_matched_op(arith.XOrI(op.cond, rhs))


class SelectSamePattern(RewritePattern):
"""
arith.select %x %y %y = %y
"""

@op_type_rewrite_pattern
def match_and_rewrite(self, op: arith.Select, rewriter: PatternRewriter):
if op.lhs == op.rhs:
rewriter.replace_matched_op((), (op.lhs,))

0 comments on commit 656457a

Please sign in to comment.