Skip to content

Commit

Permalink
fix arith cmpi (#119)
Browse files Browse the repository at this point in the history
  • Loading branch information
makslevental authored Feb 11, 2025
1 parent 4d77301 commit 56ac9da
Show file tree
Hide file tree
Showing 4 changed files with 30 additions and 29 deletions.
3 changes: 2 additions & 1 deletion mlir/extras/dialects/ext/arith.py
Original file line number Diff line number Diff line change
Expand Up @@ -367,9 +367,10 @@ def _binary_op(
if signedness is not None:
predicate = signedness + predicate
else:
if lhs.dtype.is_signed:
if lhs.dtype.is_signed or lhs.dtype.is_signless:
predicate = "s" + predicate
else:
assert lhs.dtype.is_unsigned
predicate = "u" + predicate
return lhs.__class__(op(predicate, lhs, rhs, loc=loc), dtype=lhs.dtype)
else:
Expand Down
8 changes: 4 additions & 4 deletions tests/test_operator_overloading.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,10 +148,10 @@ def test_arith_cmp(ctx: MLIRContext):
module {
%c1_i32 = arith.constant 1 : i32
%c2_i32 = arith.constant 2 : i32
%0 = arith.cmpi ult, %c1_i32, %c2_i32 : i32
%1 = arith.cmpi ule, %c1_i32, %c2_i32 : i32
%2 = arith.cmpi ugt, %c1_i32, %c2_i32 : i32
%3 = arith.cmpi uge, %c1_i32, %c2_i32 : i32
%0 = arith.cmpi slt, %c1_i32, %c2_i32 : i32
%1 = arith.cmpi sle, %c1_i32, %c2_i32 : i32
%2 = arith.cmpi sgt, %c1_i32, %c2_i32 : i32
%3 = arith.cmpi sge, %c1_i32, %c2_i32 : i32
%4 = arith.cmpi eq, %c1_i32, %c2_i32 : i32
%5 = arith.cmpi ne, %c1_i32, %c2_i32 : i32
%6 = arith.andi %c1_i32, %c2_i32 : i32
Expand Down
6 changes: 3 additions & 3 deletions tests/test_regions.py
Original file line number Diff line number Diff line change
Expand Up @@ -405,7 +405,7 @@ def foo1():
^bb1: // no predecessors
%c2_i32 = arith.constant 2 : i32
%c3_i32 = arith.constant 3 : i32
%0 = arith.cmpi ult, %c2_i32, %c3_i32 : i32
%0 = arith.cmpi slt, %c2_i32, %c3_i32 : i32
cf.cond_br %0, ^bb2, ^bb3
^bb2: // pred: ^bb1
%c4_i32 = arith.constant 4 : i32
Expand Down Expand Up @@ -448,7 +448,7 @@ def foo1():
^bb1: // no predecessors
%c2_i32 = arith.constant 2 : i32
%c3_i32 = arith.constant 3 : i32
%0 = arith.cmpi ult, %c2_i32, %c3_i32 : i32
%0 = arith.cmpi slt, %c2_i32, %c3_i32 : i32
cf.cond_br %0, ^bb2(%c2_i32, %c3_i32 : i32, i32), ^bb3(%c2_i32, %c3_i32 : i32, i32)
^bb2(%1: i32, %2: i32): // pred: ^bb1
%c4_i32 = arith.constant 4 : i32
Expand Down Expand Up @@ -583,7 +583,7 @@ def foo1():
^bb1: // no predecessors
%c2_i32 = arith.constant 2 : i32
%c3_i32 = arith.constant 3 : i32
%0 = arith.cmpi ult, %c2_i32, %c3_i32 : i32
%0 = arith.cmpi slt, %c2_i32, %c3_i32 : i32
cf.cond_br %0, ^bb2, ^bb3
^bb2: // pred: ^bb1
%c4_i32 = arith.constant 4 : i32
Expand Down
42 changes: 21 additions & 21 deletions tests/test_scf.py
Original file line number Diff line number Diff line change
Expand Up @@ -2685,7 +2685,7 @@ def test_while_2(ctx: MLIRContext):
%c1_i32 = arith.constant 1 : i32
%c2_i32 = arith.constant 2 : i32
%0:2 = scf.while (%arg0 = %c1_i32, %arg1 = %c2_i32) : (i32, i32) -> (i32, i32) {
%1 = arith.cmpi ult, %c1_i32, %c2_i32 : i32
%1 = arith.cmpi slt, %c1_i32, %c2_i32 : i32
scf.condition(%1) %arg0, %arg1 : i32, i32
} do {
^bb0(%arg0: i32, %arg1: i32):
Expand Down Expand Up @@ -2715,7 +2715,7 @@ def foo():
%c1_i32 = arith.constant 1 : i32
%c2_i32 = arith.constant 2 : i32
%0:2 = scf.while (%arg0 = %c1_i32, %arg1 = %c2_i32) : (i32, i32) -> (i32, i32) {
%1 = arith.cmpi ult, %c1_i32, %c2_i32 : i32
%1 = arith.cmpi slt, %c1_i32, %c2_i32 : i32
scf.condition(%1) %arg0, %arg1 : i32, i32
} do {
^bb0(%arg0: i32, %arg1: i32):
Expand Down Expand Up @@ -2745,7 +2745,7 @@ def foo():
%c1_i32 = arith.constant 1 : i32
%c2_i32 = arith.constant 2 : i32
%0:2 = scf.while (%arg0 = %c1_i32, %arg1 = %c2_i32) : (i32, i32) -> (i32, i32) {
%1 = arith.cmpi ult, %c1_i32, %c2_i32 : i32
%1 = arith.cmpi slt, %c1_i32, %c2_i32 : i32
scf.condition(%1) %arg0, %arg1 : i32, i32
} do {
^bb0(%arg0: i32, %arg1: i32):
Expand Down Expand Up @@ -2775,10 +2775,10 @@ def foo():
module {
%c1_i32 = arith.constant 1 : i32
%c2_i32 = arith.constant 2 : i32
%0 = arith.cmpi ult, %c1_i32, %c2_i32 : i32
%0 = arith.cmpi slt, %c1_i32, %c2_i32 : i32
scf.if %0 {
%1:2 = scf.while (%arg0 = %c1_i32, %arg1 = %c2_i32) : (i32, i32) -> (i32, i32) {
%2 = arith.cmpi ult, %c1_i32, %c2_i32 : i32
%2 = arith.cmpi slt, %c1_i32, %c2_i32 : i32
scf.condition(%2) %arg0, %arg1 : i32, i32
} do {
^bb0(%arg0: i32, %arg1: i32):
Expand Down Expand Up @@ -2815,28 +2815,28 @@ def foo():
module {
%c1_i32 = arith.constant 1 : i32
%c2_i32 = arith.constant 2 : i32
%0 = arith.cmpi ult, %c1_i32, %c2_i32 : i32
%0 = arith.cmpi slt, %c1_i32, %c2_i32 : i32
scf.if %0 {
%1:2 = scf.while (%arg0 = %c1_i32, %arg1 = %c2_i32) : (i32, i32) -> (i32, i32) {
%2 = arith.cmpi ult, %c1_i32, %c2_i32 : i32
%2 = arith.cmpi slt, %c1_i32, %c2_i32 : i32
scf.condition(%2) %arg0, %arg1 : i32, i32
} do {
^bb0(%arg0: i32, %arg1: i32):
scf.yield %c1_i32, %c2_i32 : i32, i32
}
} else {
%1 = arith.cmpi ult, %c1_i32, %c2_i32 : i32
%1 = arith.cmpi slt, %c1_i32, %c2_i32 : i32
scf.if %1 {
%2:2 = scf.while (%arg0 = %c1_i32, %arg1 = %c2_i32) : (i32, i32) -> (i32, i32) {
%3 = arith.cmpi ult, %c1_i32, %c2_i32 : i32
%3 = arith.cmpi slt, %c1_i32, %c2_i32 : i32
scf.condition(%3) %arg0, %arg1 : i32, i32
} do {
^bb0(%arg0: i32, %arg1: i32):
scf.yield %c1_i32, %c2_i32 : i32, i32
}
} else {
%2:2 = scf.while (%arg0 = %c1_i32, %arg1 = %c2_i32) : (i32, i32) -> (i32, i32) {
%3 = arith.cmpi ult, %c1_i32, %c2_i32 : i32
%3 = arith.cmpi slt, %c1_i32, %c2_i32 : i32
scf.condition(%3) %arg0, %arg1 : i32, i32
} do {
^bb0(%arg0: i32, %arg1: i32):
Expand Down Expand Up @@ -2873,10 +2873,10 @@ def foo():
module {
%c1_i32 = arith.constant 1 : i32
%c2_i32 = arith.constant 2 : i32
%0 = arith.cmpi ult, %c1_i32, %c2_i32 : i32
%0 = arith.cmpi slt, %c1_i32, %c2_i32 : i32
%1:2 = scf.if %0 -> (i32, i32) {
%2:2 = scf.while (%arg0 = %c1_i32, %arg1 = %c2_i32) : (i32, i32) -> (i32, i32) {
%3 = arith.cmpi ult, %c1_i32, %c2_i32 : i32
%3 = arith.cmpi slt, %c1_i32, %c2_i32 : i32
scf.condition(%3) %arg0, %arg1 : i32, i32
} do {
^bb0(%arg0: i32, %arg1: i32):
Expand All @@ -2885,7 +2885,7 @@ def foo():
scf.yield %2#0, %2#1 : i32, i32
} else {
%2:2 = scf.while (%arg0 = %c1_i32, %arg1 = %c2_i32) : (i32, i32) -> (i32, i32) {
%3 = arith.cmpi ult, %c1_i32, %c2_i32 : i32
%3 = arith.cmpi slt, %c1_i32, %c2_i32 : i32
scf.condition(%3) %arg0, %arg1 : i32, i32
} do {
^bb0(%arg0: i32, %arg1: i32):
Expand Down Expand Up @@ -2922,10 +2922,10 @@ def foo():
module {
%c1_i32 = arith.constant 1 : i32
%c2_i32 = arith.constant 2 : i32
%0 = arith.cmpi ult, %c1_i32, %c2_i32 : i32
%0 = arith.cmpi slt, %c1_i32, %c2_i32 : i32
%1:2 = scf.if %0 -> (i32, i32) {
%2:2 = scf.while (%arg0 = %c1_i32, %arg1 = %c2_i32) : (i32, i32) -> (i32, i32) {
%3 = arith.cmpi ult, %c1_i32, %c2_i32 : i32
%3 = arith.cmpi slt, %c1_i32, %c2_i32 : i32
scf.condition(%3) %arg0, %arg1 : i32, i32
} do {
^bb0(%arg0: i32, %arg1: i32):
Expand All @@ -2934,7 +2934,7 @@ def foo():
scf.yield %2#0, %2#1 : i32, i32
} else {
%2:2 = scf.while (%arg0 = %c1_i32, %arg1 = %c2_i32) : (i32, i32) -> (i32, i32) {
%3 = arith.cmpi ult, %c1_i32, %c2_i32 : i32
%3 = arith.cmpi slt, %c1_i32, %c2_i32 : i32
scf.condition(%3) %arg0, %arg1 : i32, i32
} do {
^bb0(%arg0: i32, %arg1: i32):
Expand Down Expand Up @@ -2971,10 +2971,10 @@ def foo():
module {
%c1_i32 = arith.constant 1 : i32
%c2_i32 = arith.constant 2 : i32
%0 = arith.cmpi ult, %c1_i32, %c2_i32 : i32
%0 = arith.cmpi slt, %c1_i32, %c2_i32 : i32
%1:2 = scf.if %0 -> (i32, i32) {
%2:2 = scf.while (%arg0 = %c1_i32, %arg1 = %c2_i32) : (i32, i32) -> (i32, i32) {
%3 = arith.cmpi ult, %c1_i32, %c2_i32 : i32
%3 = arith.cmpi slt, %c1_i32, %c2_i32 : i32
scf.condition(%3) %arg0, %arg1 : i32, i32
} do {
^bb0(%arg0: i32, %arg1: i32):
Expand All @@ -2983,7 +2983,7 @@ def foo():
scf.yield %2#0, %2#1 : i32, i32
} else {
%2:2 = scf.while (%arg0 = %c1_i32, %arg1 = %c2_i32) : (i32, i32) -> (i32, i32) {
%3 = arith.cmpi ult, %c1_i32, %c2_i32 : i32
%3 = arith.cmpi slt, %c1_i32, %c2_i32 : i32
scf.condition(%3) %arg0, %arg1 : i32, i32
} do {
^bb0(%arg0: i32, %arg1: i32):
Expand Down Expand Up @@ -3017,11 +3017,11 @@ def foo():
%c1_i32 = arith.constant 1 : i32
%c2_i32 = arith.constant 2 : i32
%0:2 = scf.while (%arg0 = %c1_i32, %arg1 = %c2_i32) : (i32, i32) -> (i32, i32) {
%1 = arith.cmpi ult, %c1_i32, %c2_i32 : i32
%1 = arith.cmpi slt, %c1_i32, %c2_i32 : i32
scf.condition(%1) %arg0, %arg1 : i32, i32
} do {
^bb0(%arg0: i32, %arg1: i32):
%1 = arith.cmpi ult, %c1_i32, %c2_i32 : i32
%1 = arith.cmpi slt, %c1_i32, %c2_i32 : i32
scf.if %1 {
%c3_i32 = arith.constant 3 : i32
}
Expand Down

0 comments on commit 56ac9da

Please sign in to comment.