Skip to content

Commit 814ca0b

Browse files
author
pytorchbot
committed
2024-11-29 nightly release (3475707)
1 parent c312d06 commit 814ca0b

23 files changed

+510
-77
lines changed

backends/arm/TARGETS

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -110,3 +110,14 @@ python_library(
110110
"//executorch/backends/arm/operators:node_visitor",
111111
],
112112
)
113+
114+
python_library(
115+
name = "arm_model_evaluator",
116+
src = [
117+
"util/arm_model_evaluator.py",
118+
],
119+
typing = True,
120+
deps = [
121+
"//caffe2:torch",
122+
]
123+
)

backends/arm/operator_support/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from . import ( # noqa
99
mean_dim_support,
1010
right_shift_support,
11+
to_copy_support,
1112
tosa_supported_operators,
1213
var_correction_support,
1314
)
Lines changed: 120 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,120 @@
1+
# Copyright 2024 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
# pyre-unsafe
7+
import logging
8+
9+
import torch
10+
11+
import torch.fx as fx
12+
13+
from executorch.backends.arm.operator_support.tosa_supported_operators import (
14+
register_tosa_support_check,
15+
SupportedTOSAOperatorCheck,
16+
)
17+
from executorch.backends.arm.tosa_specification import TosaSpecification
18+
from executorch.exir.dialects._ops import ops as exir_ops
19+
20+
logger = logging.getLogger(__name__)
21+
22+
23+
@register_tosa_support_check
24+
class ToCopySupported(SupportedTOSAOperatorCheck):
25+
targets = [exir_ops.edge.aten._to_copy.default]
26+
27+
tosa_specs = [
28+
TosaSpecification.create_from_string("TOSA-0.80.0+BI"),
29+
TosaSpecification.create_from_string("TOSA-0.80.0+MI"),
30+
]
31+
32+
SupportedTypeDict = dict[torch.dtype, list[torch.dtype]]
33+
34+
@staticmethod
35+
def _merge_supported_types(
36+
dtypes1: SupportedTypeDict, dtypes2: SupportedTypeDict
37+
) -> SupportedTypeDict:
38+
merged_dtypes = dtypes1
39+
for k, v in dtypes2.items():
40+
merged_dtypes[k] = merged_dtypes.get(k, []) + v
41+
return merged_dtypes
42+
43+
SUPPORTED_INT_TYPES: SupportedTypeDict = {
44+
torch.bool: [torch.int8, torch.int16, torch.int32],
45+
torch.int8: [torch.bool, torch.int16, torch.int32],
46+
torch.int16: [torch.bool, torch.int8, torch.int32],
47+
torch.int32: [torch.bool, torch.int8, torch.int16],
48+
}
49+
SUPPORTED_FLOAT_TYPES: SupportedTypeDict = {
50+
torch.int8: [torch.float16, torch.bfloat16, torch.float32],
51+
torch.int16: [torch.float16, torch.bfloat16, torch.float32],
52+
torch.int32: [torch.float16, torch.bfloat16, torch.float32],
53+
torch.bfloat16: [torch.int8, torch.int16, torch.int32, torch.float32],
54+
torch.float16: [torch.int8, torch.int16, torch.int32, torch.float32],
55+
torch.float32: [
56+
torch.int8,
57+
torch.int16,
58+
torch.int32,
59+
torch.bfloat16,
60+
torch.float16,
61+
],
62+
}
63+
ALL_SUPPORTED_TYPES = _merge_supported_types(
64+
SUPPORTED_INT_TYPES, SUPPORTED_FLOAT_TYPES
65+
)
66+
POSSIBLE_TYPE_CONVERSIONS = {torch.int64: torch.int32}
67+
68+
def is_node_supported(self, node: fx.Node, tosa_spec: TosaSpecification) -> bool:
69+
assert node.target in self.targets
70+
71+
if tosa_spec not in self.tosa_specs:
72+
return False
73+
74+
assert tosa_spec.support_integer()
75+
supported_dtypes = (
76+
self.ALL_SUPPORTED_TYPES
77+
if tosa_spec.support_float()
78+
else self.SUPPORTED_INT_TYPES
79+
)
80+
# Take into account possible type conversions
81+
supported_dtypes.update(
82+
(k, supported_dtypes[v])
83+
for k, v in self.POSSIBLE_TYPE_CONVERSIONS.items()
84+
if v in supported_dtypes
85+
)
86+
87+
# Check input type
88+
assert len(node.all_input_nodes) == 1
89+
input_val = node.all_input_nodes[0].meta["val"]
90+
assert isinstance(input_val, torch._subclasses.FakeTensor)
91+
input_dtype = input_val.dtype
92+
if input_dtype not in supported_dtypes:
93+
logger.info(
94+
f"Input dtype {input_val.dtype} is not supported in "
95+
f"{node.target.name()}."
96+
)
97+
return False
98+
99+
# Check output type
100+
output_val = node.meta["val"]
101+
assert isinstance(output_val, torch._subclasses.FakeTensor)
102+
if output_val.dtype not in supported_dtypes[input_dtype]:
103+
logger.info(
104+
f"Output dtype {output_val.dtype} is not supported in "
105+
f"{node.target.name()} for input dtype {input_dtype}. "
106+
f"Supported output types: "
107+
f"{''.join(str(t) for t in supported_dtypes[input_dtype])}"
108+
)
109+
return False
110+
111+
# Check memory format
112+
if "memory_format" in node.kwargs:
113+
if node.kwargs["memory_format"] in (torch.preserve_format,):
114+
logger.info(
115+
f"Argument 'memory_format' is not supported for "
116+
f"{node.target.name()} right now."
117+
)
118+
return False
119+
120+
return True

backends/arm/operators/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@
3636
op_sub,
3737
op_sum,
3838
op_tanh,
39+
op_to_copy,
3940
op_transpose,
4041
op_unsqueeze,
4142
op_upsample_nearest2d,

backends/arm/operators/op_to_copy.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
# Copyright 2024 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
# pyre-unsafe
7+
from typing import List
8+
9+
import serializer.tosa_serializer as ts
10+
import torch
11+
import tosa.Op as TosaOp
12+
13+
from executorch.backends.arm.operators.node_visitor import (
14+
NodeVisitor,
15+
register_node_visitor,
16+
)
17+
from executorch.backends.arm.tosa_mapping import TosaArg
18+
19+
20+
@register_node_visitor
21+
class ToCopyVisitor(NodeVisitor):
22+
"""
23+
Implement the type cast functionality of _to_copy.
24+
25+
Other features like setting of the memory_format or moving a tensor to a
26+
different device are not supported.
27+
28+
Also note that the node should not be quantized.
29+
"""
30+
31+
target = "aten._to_copy.default"
32+
33+
def define_node(
34+
self,
35+
node: torch.fx.Node,
36+
tosa_graph: ts.TosaSerializer,
37+
inputs: List[TosaArg],
38+
output: TosaArg,
39+
is_quant_node: bool,
40+
) -> None:
41+
assert not is_quant_node, "Casting of quantized values is not supported."
42+
assert inputs
43+
tosa_graph.addOperator(TosaOp.Op().CAST, [inputs[0].name], [output.name])

backends/arm/test/TARGETS

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
load("@fbcode_macros//build_defs:python_library.bzl", "python_library")
2+
3+
python_library(
4+
name = "common",
5+
srcs = ["common.py"],
6+
deps = [
7+
"//executorch/backends/xnnpack/test/tester:tester",
8+
"//executorch/backends/arm:arm_backend",
9+
"//executorch/exir:lib",
10+
"//executorch/exir/backend:compile_spec_schema",
11+
]
12+
)
13+
14+
python_library(
15+
name = "runner_utils",
16+
srcs = ["runner_utils.py"],
17+
deps = [
18+
"//executorch/backends/xnnpack/test/tester:tester",
19+
"//executorch/backends/arm:arm_backend",
20+
"//executorch/exir:lib",
21+
"//executorch/exir/backend:compile_spec_schema",
22+
]
23+
)

backends/arm/test/ops/test_bmm.py

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,8 @@ class TestBMM(unittest.TestCase):
2222

2323
class BMM(torch.nn.Module):
2424
test_parameters = [
25-
(torch.rand(5, 3, 5), torch.rand(5, 5, 2)),
2625
(torch.rand(2, 1, 1), torch.rand(2, 1, 1)),
26+
(torch.rand(5, 3, 5), torch.rand(5, 5, 2)),
2727
(torch.ones(1, 55, 3), torch.ones(1, 3, 44)),
2828
(10000 * torch.randn(10, 1, 10), torch.randn(10, 10, 5)),
2929
(-10 * torch.randn(2, 32, 64), 5 + 5 * torch.randn(2, 64, 32)),
@@ -147,32 +147,37 @@ def test_bmm_single_input_tosa_BI(self, operand1: torch.Tensor):
147147

148148
@parameterized.expand(BMM.test_parameters)
149149
@unittest.expectedFailure
150-
def test_bmm_u55_BI(self, operand1: torch.Tensor, operand2: torch.Tensor):
150+
def test_bmm_u55_BI_xfails(self, operand1: torch.Tensor, operand2: torch.Tensor):
151151
test_data = (operand1, operand2)
152152
self._test_bmm_ethosu_BI_pipeline(
153153
self.BMM(), common.get_u55_compile_spec(), test_data
154154
)
155155

156-
@parameterized.expand(BMM.test_parameters)
157-
@common.expectedFailureOnFVP
156+
@parameterized.expand(BMM.test_parameters[:1])
158157
def test_bmm_u85_BI(self, operand1: torch.Tensor, operand2: torch.Tensor):
159158
test_data = (operand1, operand2)
160159
self._test_bmm_ethosu_BI_pipeline(
161160
self.BMM(), common.get_u85_compile_spec(), test_data
162161
)
163162

163+
@parameterized.expand(BMM.test_parameters[1:])
164+
@common.expectedFailureOnFVP
165+
def test_bmm_u85_BI_xfails(self, operand1: torch.Tensor, operand2: torch.Tensor):
166+
test_data = (operand1, operand2)
167+
self._test_bmm_ethosu_BI_pipeline(
168+
self.BMM(), common.get_u85_compile_spec(), test_data
169+
)
170+
164171
# Expected to fail with error: Warning, unsupported fusing of TOSA Rescale previous operator is of type: Memcpy
165172
@parameterized.expand(BMMSingleInput.test_parameters)
166173
@unittest.expectedFailure
167-
def test_bmm_single_input_u55_BI(self, operand1: torch.Tensor):
174+
def test_bmm_single_input_u55_BI_xfails(self, operand1: torch.Tensor):
168175
test_data = (operand1,)
169176
self._test_bmm_ethosu_BI_pipeline(
170177
self.BMMSingleInput(), common.get_u55_compile_spec(), test_data
171178
)
172179

173-
# Numerical issues on FVP, MLETORCH 534
174180
@parameterized.expand(BMMSingleInput.test_parameters)
175-
@common.expectedFailureOnFVP
176181
def test_bmm_single_input_u85_BI(self, operand1: torch.Tensor):
177182
test_data = (operand1,)
178183
self._test_bmm_ethosu_BI_pipeline(

backends/arm/test/ops/test_conv_combos.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -275,8 +275,6 @@ def test_conv_meandim_u55_BI(self):
275275
model.get_inputs(),
276276
)
277277

278-
# Numerical Issues on FVP, MLETORCH-520
279-
@common.expectedFailureOnFVP
280278
def test_conv_meandim_u85_BI(self):
281279
model = ComboConv2dMeandim()
282280
self._test_conv_combo_ethos_BI_pipeline(

backends/arm/test/ops/test_depthwise_conv.py

Lines changed: 20 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -156,6 +156,19 @@
156156
("two_dw_conv2d", two_dw_conv2d),
157157
]
158158

159+
testsuite_conv2d_u85 = [
160+
("2x2_1x6x4x4_gp6_st1", dw_conv2d_2x2_1x6x4x4_gp6_st1),
161+
("3x3_1x3x256x256_gp3_st1", dw_conv2d_3x3_1x3x256x256_gp3_st1),
162+
("3x3_1x4x256x256_gp4_st1", dw_conv2d_3x3_1x4x256x256_gp4_st1),
163+
("3x3_1x4x256x256_gp4_nobias", dw_conv2d_3x3_1x4x256x256_gp4_nobias),
164+
]
165+
166+
testsuite_conv2d_u85_xfails = [
167+
("3x3_2x8x198x198_gp8_st3", dw_conv2d_3x3_2x8x198x198_gp8_st3),
168+
("two_dw_conv2d", two_dw_conv2d),
169+
]
170+
171+
159172
testsuite_conv1d = [
160173
("2_1x6x4_gp6_st1", dw_conv1d_2_1x6x4_gp6_st1),
161174
("two_dw_conv1d", two_dw_conv1d),
@@ -247,7 +260,7 @@ def test_dw_conv_tosa_BI(self, test_name: str, model: torch.nn.Module):
247260
) # Works
248261

249262
@parameterized.expand(testsuite_conv2d, skip_on_empty=True)
250-
@common.expectedFailureOnFVP
263+
@unittest.expectedFailure
251264
def test_dw_conv2d_u55_BI(
252265
self, test_name: str, model: torch.nn.Module, set_quantize_io: bool = False
253266
):
@@ -274,10 +287,8 @@ def test_dw_conv1d_u55_BI(
274287
model.get_inputs(),
275288
)
276289

277-
# All test cases except 3x3_1x3x256x256_gp3_st1 have numerical issues on FVP. MLETORCH-520
278-
@parameterized.expand(testsuite_conv1d[:-2] + testsuite_conv2d)
279-
@common.expectedFailureOnFVP
280-
def test_dw_conv_u85_BI_xfails(
290+
@parameterized.expand(testsuite_conv1d + testsuite_conv2d_u85)
291+
def test_dw_conv_u85_BI(
281292
self, test_name: str, model: torch.nn.Module, set_quantize_io: bool = False
282293
):
283294
self._test_dw_conv_ethos_BI_pipeline(
@@ -288,8 +299,10 @@ def test_dw_conv_u85_BI_xfails(
288299
model.get_inputs(),
289300
)
290301

291-
@parameterized.expand(testsuite_conv1d[-2:])
292-
def test_dw_conv_u85_BI(
302+
# All test cases except 3x3_1x3x256x256_gp3_st1 have numerical issues on FVP. MLETORCH-520
303+
@parameterized.expand(testsuite_conv2d_u85_xfails)
304+
@common.expectedFailureOnFVP
305+
def test_dw_conv_u85_BI_xfails(
293306
self, test_name: str, model: torch.nn.Module, set_quantize_io: bool = False
294307
):
295308
self._test_dw_conv_ethos_BI_pipeline(

0 commit comments

Comments
 (0)