14
14
"""
15
15
Summary of non-working cases.
16
16
MI:
17
- Any case with int scalar: A to_copy is inserted to cast the value which we don't partition.
18
- This makes the constant end up outside our partition and the input to the delegate becomes
19
- a to_copy placeholder. In ArmTester, the placeholder is then interpreted as an input.
20
- Potential fix: partition int -> float to_copy-ops in ArmBackend.
21
- # MLETORCH-407
22
17
Op(scalar, tensor):
23
18
One issue is that lift_constant_tensor_pass looks for a fake_tensor in the meta of the first
24
19
node which does not work the first node is a scalar.
27
22
somewhere in _transform in the to_edge step. This makes ArmPartitioner miss tagging the
28
23
data in tag_constant_data.
29
24
# MLETORCH-408
30
-
31
- BI:
32
- sub(Scalar, Tensor) becomes rsub, which either fails since the scalar does not become an attribute
33
- in scalars_to_attribute_pass, or, if added to targeted_ops in that pass, fails since rsub expects a
34
- Scalar.
35
- Potential fix: Create pass to convert rsub.Scalar to sub.Tensor
25
+ Sub or inplace-sub with an integer input.
36
26
"""
37
27
38
28
39
29
class TestScalars (unittest .TestCase ):
40
- """Tests various scalar cases for for """
30
+ """Tests various scalar cases"""
41
31
42
32
class Add (torch .nn .Module ):
43
33
def forward (self , x , y ):
@@ -133,13 +123,10 @@ def forward(self, x):
133
123
scalar = dtype [1 ]
134
124
tensor_scalar_tests .append ((test_name + "_ts" , op [1 ], tensor , scalar ))
135
125
136
- # Don't add (scalar, tensor) test case for inplace and .Scalar ops.
137
- if op [0 ][- 1 ] == "_" or op [ 0 ][ - 6 :] == "Scalar" :
126
+ # Don't add (scalar, tensor) test case for .Scalar ops.
127
+ if op [0 ][- 6 :] == "Scalar" :
138
128
continue
139
129
140
- # sub(scalar, tensor) does not work in any case.
141
- if op [0 ][0 :3 ] == "Sub" :
142
- continue
143
130
tensor_scalar_tests .append ((test_name + "_st" , op [1 ], scalar , tensor ))
144
131
145
132
tensor_const_tests = []
@@ -182,8 +169,8 @@ def _test_add_tosa_BI_pipeline(self, module: torch.nn.Module, test_data: tuple):
182
169
def test_MI (self , test_name : str , op : torch .nn .Module , x , y ):
183
170
expected_exception = None
184
171
if any (token in test_name for token in ("Sub_int" , "Sub__int" )):
185
- expected_exception = ( AssertionError , ValueError )
186
- elif test_name .endswith ("_st" ):
172
+ expected_exception = AssertionError
173
+ if test_name .endswith ("_st" ):
187
174
expected_exception = AttributeError
188
175
189
176
if expected_exception :
@@ -204,5 +191,13 @@ def test_MI_const(self, test_name: str, op: torch.nn.Module, x):
204
191
def test_BI (self , test_name : str , op : torch .nn .Module , x , y ):
205
192
self ._test_add_tosa_BI_pipeline (op , (x , y ))
206
193
194
+ # op(Scalar float, tensor) works if the scalar is constant.
195
+ @parameterized .expand (tensor_const_tests )
196
+ def test_BI_const (self , test_name : str , op : torch .nn .Module , x ):
197
+ self ._test_add_tosa_BI_pipeline (op , (x ,))
198
+
207
199
def test_shift_sub_inplace_tosa_MI (self ):
208
200
self ._test_add_tosa_MI_pipeline (self .ShiftInplaceSub (), (torch .IntTensor (5 ),))
201
+
202
+ def test_shift_sub_inplace_tosa_BI (self ):
203
+ self ._test_add_tosa_BI_pipeline (self .ShiftInplaceSub (), (torch .IntTensor (5 ),))
0 commit comments