Skip to content

Commit a01571f

Browse files
Arm: Enable rsub lowering (#8525)
Enable rsub lowering Update ScalarsToAttrbibutePass to replace rsub with sub. Signed-off-by: Oscar Andersson <[email protected]>
1 parent 14ff52f commit a01571f

File tree

2 files changed

+26
-19
lines changed

2 files changed

+26
-19
lines changed

backends/arm/_passes/scalars_to_attribute_pass.py

+12
Original file line numberDiff line numberDiff line change
@@ -76,5 +76,17 @@ def call(self, graph_module: GraphModule) -> PassResult:
7676
new_args.append(get_attr_node)
7777
n.args = tuple(new_args)
7878

79+
# Replace rsub.Scalar with sub.Tensor as retracing will fail otherwise
80+
if n.target == torch.ops.aten.rsub.Scalar:
81+
with graph_module.graph.inserting_after(n):
82+
reversed_args = (n.args[1], n.args[0])
83+
sub = graph_module.graph.create_node(
84+
"call_function", torch.ops.aten.sub.Tensor, reversed_args, {}
85+
)
86+
n.replace_all_uses_with(sub)
87+
sub.meta["val"] = n.meta["val"]
88+
graph_module.graph.erase_node(n)
89+
7990
graph_module.recompile()
91+
graph_module = super().call(graph_module).graph_module
8092
return PassResult(graph_module, True)

backends/arm/test/ops/test_scalars.py

+14-19
Original file line numberDiff line numberDiff line change
@@ -14,11 +14,6 @@
1414
"""
1515
Summary of non-working cases.
1616
MI:
17-
Any case with int scalar: A to_copy is inserted to cast the value which we don't partition.
18-
This makes the constant end up outside our partition and the input to the delegate becomes
19-
a to_copy placeholder. In ArmTester, the placeholder is then interpreted as an input.
20-
Potential fix: partition int -> float to_copy-ops in ArmBackend.
21-
# MLETORCH-407
2217
Op(scalar, tensor):
2318
One issue is that lift_constant_tensor_pass looks for a fake_tensor in the meta of the first
2419
node which does not work the first node is a scalar.
@@ -27,17 +22,12 @@
2722
somewhere in _transform in the to_edge step. This makes ArmPartitioner miss tagging the
2823
data in tag_constant_data.
2924
# MLETORCH-408
30-
31-
BI:
32-
sub(Scalar, Tensor) becomes rsub, which either fails since the scalar does not become an attribute
33-
in scalars_to_attribute_pass, or, if added to targeted_ops in that pass, fails since rsub expects a
34-
Scalar.
35-
Potential fix: Create pass to convert rsub.Scalar to sub.Tensor
25+
Sub or inplace-sub with an integer input.
3626
"""
3727

3828

3929
class TestScalars(unittest.TestCase):
40-
"""Tests various scalar cases for for"""
30+
"""Tests various scalar cases"""
4131

4232
class Add(torch.nn.Module):
4333
def forward(self, x, y):
@@ -133,13 +123,10 @@ def forward(self, x):
133123
scalar = dtype[1]
134124
tensor_scalar_tests.append((test_name + "_ts", op[1], tensor, scalar))
135125

136-
# Don't add (scalar, tensor) test case for inplace and .Scalar ops.
137-
if op[0][-1] == "_" or op[0][-6:] == "Scalar":
126+
# Don't add (scalar, tensor) test case for .Scalar ops.
127+
if op[0][-6:] == "Scalar":
138128
continue
139129

140-
# sub(scalar, tensor) does not work in any case.
141-
if op[0][0:3] == "Sub":
142-
continue
143130
tensor_scalar_tests.append((test_name + "_st", op[1], scalar, tensor))
144131

145132
tensor_const_tests = []
@@ -182,8 +169,8 @@ def _test_add_tosa_BI_pipeline(self, module: torch.nn.Module, test_data: tuple):
182169
def test_MI(self, test_name: str, op: torch.nn.Module, x, y):
183170
expected_exception = None
184171
if any(token in test_name for token in ("Sub_int", "Sub__int")):
185-
expected_exception = (AssertionError, ValueError)
186-
elif test_name.endswith("_st"):
172+
expected_exception = AssertionError
173+
if test_name.endswith("_st"):
187174
expected_exception = AttributeError
188175

189176
if expected_exception:
@@ -204,5 +191,13 @@ def test_MI_const(self, test_name: str, op: torch.nn.Module, x):
204191
def test_BI(self, test_name: str, op: torch.nn.Module, x, y):
205192
self._test_add_tosa_BI_pipeline(op, (x, y))
206193

194+
# op(Scalar float, tensor) works if the scalar is constant.
195+
@parameterized.expand(tensor_const_tests)
196+
def test_BI_const(self, test_name: str, op: torch.nn.Module, x):
197+
self._test_add_tosa_BI_pipeline(op, (x,))
198+
207199
def test_shift_sub_inplace_tosa_MI(self):
208200
self._test_add_tosa_MI_pipeline(self.ShiftInplaceSub(), (torch.IntTensor(5),))
201+
202+
def test_shift_sub_inplace_tosa_BI(self):
203+
self._test_add_tosa_BI_pipeline(self.ShiftInplaceSub(), (torch.IntTensor(5),))

0 commit comments

Comments
 (0)