forked from llvm/torch-mlir
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy path comparisons.py
141 lines (115 loc) · 3.37 KB
/
comparisons.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
# RUN: %PYTHON %s | npcomp-opt -split-input-file | FileCheck %s --dump-input=fail
from npcomp.compiler.numpy import test_config
# Decorator built by npcomp's test_config helper (defined outside this file).
# Presumably it imports each decorated function through the numpy compiler and
# prints the resulting IR to stdout, which the RUN line at the top of this file
# pipes through npcomp-opt into FileCheck -- TODO confirm against test_config.
import_global = test_config.create_import_dump_decorator()
# CHECK-LABEL: func @binary_lt_
@import_global
def binary_lt_():
  # `x < y` on two integer locals. The directives expect each literal to be
  # materialized as an i64 `constant`, and the comparison to lower to
  # basicpy.binary_compare with the "Lt" predicate applied to exactly those
  # two constant values (bound as A and B).
  # CHECK: %[[A:.*]] = constant 1 : i64
  # CHECK: %[[B:.*]] = constant 2 : i64
  x = 1
  y = 2
  # CHECK: {{.*}} = basicpy.binary_compare %[[A]] "Lt" %[[B]] : i64, i64
  return x < y
# CHECK-LABEL: func @binary_gt_
@import_global
def binary_gt_():
  # `x > y` on two integer locals, expected to lower to basicpy.binary_compare
  # with the "Gt" predicate. Operand SSA names are wildcarded, so only the
  # predicate and the i64 operand types are pinned.
  x = 1
  y = 2
  # CHECK: {{.*}} = basicpy.binary_compare {{.*}} "Gt" {{.*}} : i64, i64
  return x > y
# CHECK-LABEL: func @binary_lte_
@import_global
def binary_lte_():
  # `x <= y` on two integer locals, expected to lower to
  # basicpy.binary_compare with the "LtE" predicate (i64 operands; operand
  # SSA names are wildcarded).
  x = 1
  y = 2
  # CHECK: {{.*}} = basicpy.binary_compare {{.*}} "LtE" {{.*}} : i64, i64
  return x <= y
# CHECK-LABEL: func @binary_gte_
@import_global
def binary_gte_():
  # `x >= y` on two integer locals, expected to lower to
  # basicpy.binary_compare with the "GtE" predicate (i64 operands; operand
  # SSA names are wildcarded).
  x = 1
  y = 2
  # CHECK: {{.*}} = basicpy.binary_compare {{.*}} "GtE" {{.*}} : i64, i64
  return x >= y
# CHECK-LABEL: func @binary_eq_
@import_global
def binary_eq_():
  # `x == y` on two integer locals, expected to lower to
  # basicpy.binary_compare with the "Eq" predicate (i64 operands; operand
  # SSA names are wildcarded).
  x = 1
  y = 2
  # CHECK: {{.*}} = basicpy.binary_compare {{.*}} "Eq" {{.*}} : i64, i64
  return x == y
# CHECK-LABEL: func @binary_neq_
@import_global
def binary_neq_():
  # `x != y` on two integer locals, expected to lower to
  # basicpy.binary_compare with the "NotEq" predicate (i64 operands; operand
  # SSA names are wildcarded).
  x = 1
  y = 2
  # CHECK: {{.*}} = basicpy.binary_compare {{.*}} "NotEq" {{.*}} : i64, i64
  return x != y
# CHECK-LABEL: func @binary_is_
@import_global
def binary_is_():
  # `x is y` on two integer locals, expected to lower to
  # basicpy.binary_compare with the "Is" predicate. Only the emitted
  # predicate is tested here; the runtime identity semantics of ints are not
  # (this function is never called in this file).
  x = 1
  y = 2
  # CHECK: {{.*}} = basicpy.binary_compare {{.*}} "Is" {{.*}} : i64, i64
  return x is y
# CHECK-LABEL: func @binary_is_not_
@import_global
def binary_is_not_():
  # `x is not y` must import as a single "IsNot" binary_compare rather than
  # as "Is" followed by a separate negation.
  x = 1
  y = 2
  # CHECK: {{.*}} = basicpy.binary_compare {{.*}} "IsNot" {{.*}} : i64, i64
  return x is not y
# CHECK-LABEL: func @binary_in_
@import_global
def binary_in_():
  # `x in y` with two ints would raise TypeError if CPython executed it, but
  # this function is never called in this file; the test only verifies that
  # the importer emits basicpy.binary_compare with the "In" predicate.
  x = 1
  y = 2
  # CHECK: {{.*}} = basicpy.binary_compare {{.*}} "In" {{.*}} : i64, i64
  return x in y
# CHECK-LABEL: func @binary_not_in_
@import_global
def binary_not_in_():
  # As with binary_in_, `x not in y` on two ints is never executed here; the
  # test only verifies that it imports as a single "NotIn" binary_compare.
  x = 1
  y = 2
  # CHECK: {{.*}} = basicpy.binary_compare {{.*}} "NotIn" {{.*}} : i64, i64
  return x not in y
# CHECK-LABEL: func @short_circuit
@import_global
def short_circuit():
  # Every other test in this file begins with a label; the one above scopes
  # the directives below to this function, so they cannot accidentally match
  # IR emitted for the preceding test.
  #
  # The chained comparison `x < y == z >= omega` means
  # `x < y and y == z and z >= omega`, where each link is evaluated only if
  # the previous one was true. The directives expect that to lower into
  # nested scf.if regions: each later comparison sits in the "then" region of
  # the previous one, and every "else" region yields the same
  # basicpy.bool_constant false. The outer result is passed through
  # basicpy.unknown_cast before being returned.
  # CHECK: %[[X:.*]] = constant 1 : i64
  # CHECK: %[[Y:.*]] = constant 2 : i64
  # CHECK: %[[Z:.*]] = constant 3 : i64
  # CHECK: %[[OMEGA:.*]] = constant 5 : i64
  x = 1
  y = 2
  z = 3
  omega = 5
  # CHECK: %[[FALSE:.*]] = basicpy.bool_constant false
  # CHECK: %[[CMP0:.*]] = basicpy.binary_compare %[[X]] "Lt" %[[Y]]
  # CHECK: %[[CMP0_CAST:.*]] = basicpy.bool_cast %[[CMP0]] : !basicpy.BoolType -> i1
  # CHECK: %[[IF0:.*]] = scf.if %[[CMP0_CAST]] -> (!basicpy.BoolType) {
  # CHECK: %[[CMP1:.*]] = basicpy.binary_compare %[[Y]] "Eq" %[[Z]]
  # CHECK: %[[CMP1_CAST:.*]] = basicpy.bool_cast %[[CMP1]] : !basicpy.BoolType -> i1
  # CHECK: %[[IF1:.*]] = scf.if %[[CMP1_CAST]] {{.*}} {
  # CHECK: %[[CMP2:.*]] = basicpy.binary_compare %[[Z]] "GtE" %[[OMEGA]]
  # CHECK: scf.yield %[[CMP2]]
  # CHECK: } else {
  # CHECK: scf.yield %[[FALSE]]
  # CHECK: }
  # CHECK: scf.yield %[[IF1]]
  # CHECK: } else {
  # CHECK: scf.yield %[[FALSE]]
  # CHECK: }
  # CHECK: %[[RESULT:.*]] = basicpy.unknown_cast %[[IF0]]
  # CHECK: return %[[RESULT]]
  return x < y == z >= omega
# CHECK-LABEL: nested_short_circuit_expression
@import_global
def nested_short_circuit_expression():
  x = 1
  y = 2
  z = 3
  # Verify that the (z + 6) gets nested into the if: the right-hand operand of
  # the chain's second link must be computed lazily, as the first operations
  # inside the first scf.if region, not eagerly before the branch.
  # CHECK: scf.if {{.*}} {
  # CHECK-NEXT: constant 6
  # CHECK-NEXT: binary_expr
  return x < y == (z + 6)