
Commit

2024-11-29 nightly release (3475707)
pytorchbot committed Nov 29, 2024
1 parent c312d06 commit 814ca0b
Showing 23 changed files with 510 additions and 77 deletions.
11 changes: 11 additions & 0 deletions backends/arm/TARGETS
@@ -110,3 +110,14 @@ python_library(
"//executorch/backends/arm/operators:node_visitor",
],
)

python_library(
name = "arm_model_evaluator",
srcs = [
"util/arm_model_evaluator.py",
],
typing = True,
deps = [
"//caffe2:torch",
]
)
1 change: 1 addition & 0 deletions backends/arm/operator_support/__init__.py
@@ -8,6 +8,7 @@
from . import ( # noqa
mean_dim_support,
right_shift_support,
to_copy_support,
tosa_supported_operators,
var_correction_support,
)
120 changes: 120 additions & 0 deletions backends/arm/operator_support/to_copy_support.py
@@ -0,0 +1,120 @@
# Copyright 2024 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe
import logging

import torch

import torch.fx as fx

from executorch.backends.arm.operator_support.tosa_supported_operators import (
register_tosa_support_check,
SupportedTOSAOperatorCheck,
)
from executorch.backends.arm.tosa_specification import TosaSpecification
from executorch.exir.dialects._ops import ops as exir_ops

logger = logging.getLogger(__name__)


@register_tosa_support_check
class ToCopySupported(SupportedTOSAOperatorCheck):
targets = [exir_ops.edge.aten._to_copy.default]

tosa_specs = [
TosaSpecification.create_from_string("TOSA-0.80.0+BI"),
TosaSpecification.create_from_string("TOSA-0.80.0+MI"),
]

SupportedTypeDict = dict[torch.dtype, list[torch.dtype]]

@staticmethod
def _merge_supported_types(
dtypes1: SupportedTypeDict, dtypes2: SupportedTypeDict
) -> SupportedTypeDict:
merged_dtypes = {k: list(v) for k, v in dtypes1.items()}  # copy so dtypes1 is not mutated
for k, v in dtypes2.items():
merged_dtypes[k] = merged_dtypes.get(k, []) + v
return merged_dtypes

SUPPORTED_INT_TYPES: SupportedTypeDict = {
torch.bool: [torch.int8, torch.int16, torch.int32],
torch.int8: [torch.bool, torch.int16, torch.int32],
torch.int16: [torch.bool, torch.int8, torch.int32],
torch.int32: [torch.bool, torch.int8, torch.int16],
}
SUPPORTED_FLOAT_TYPES: SupportedTypeDict = {
torch.int8: [torch.float16, torch.bfloat16, torch.float32],
torch.int16: [torch.float16, torch.bfloat16, torch.float32],
torch.int32: [torch.float16, torch.bfloat16, torch.float32],
torch.bfloat16: [torch.int8, torch.int16, torch.int32, torch.float32],
torch.float16: [torch.int8, torch.int16, torch.int32, torch.float32],
torch.float32: [
torch.int8,
torch.int16,
torch.int32,
torch.bfloat16,
torch.float16,
],
}
ALL_SUPPORTED_TYPES = _merge_supported_types(
SUPPORTED_INT_TYPES, SUPPORTED_FLOAT_TYPES
)
POSSIBLE_TYPE_CONVERSIONS = {torch.int64: torch.int32}

def is_node_supported(self, node: fx.Node, tosa_spec: TosaSpecification) -> bool:
assert node.target in self.targets

if tosa_spec not in self.tosa_specs:
return False

assert tosa_spec.support_integer()
supported_dtypes = (
self.ALL_SUPPORTED_TYPES
if tosa_spec.support_float()
else self.SUPPORTED_INT_TYPES
)
# Take possible type conversions into account; copy first so the
# class-level dict is not mutated across calls
supported_dtypes = dict(supported_dtypes)
supported_dtypes.update(
(k, supported_dtypes[v])
for k, v in self.POSSIBLE_TYPE_CONVERSIONS.items()
if v in supported_dtypes
)

# Check input type
assert len(node.all_input_nodes) == 1
input_val = node.all_input_nodes[0].meta["val"]
assert isinstance(input_val, torch._subclasses.FakeTensor)
input_dtype = input_val.dtype
if input_dtype not in supported_dtypes:
logger.info(
f"Input dtype {input_val.dtype} is not supported in "
f"{node.target.name()}."
)
return False

# Check output type
output_val = node.meta["val"]
assert isinstance(output_val, torch._subclasses.FakeTensor)
if output_val.dtype not in supported_dtypes[input_dtype]:
logger.info(
f"Output dtype {output_val.dtype} is not supported in "
f"{node.target.name()} for input dtype {input_dtype}. "
f"Supported output types: "
f"{''.join(str(t) for t in supported_dtypes[input_dtype])}"
)
return False

# Check memory format
if "memory_format" in node.kwargs:
if node.kwargs["memory_format"] in (torch.preserve_format,):
logger.info(
f"Argument 'memory_format' is not supported for "
f"{node.target.name()} right now."
)
return False

return True
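
For reference, the dtype gating above reduces to a dictionary lookup plus the int64-to-int32 aliasing step. Below is a minimal standalone sketch of that logic; the table covers only the integer (BI) profile, and cast_is_supported is an illustrative name, not an ExecuTorch API:

import torch

# Integer-profile table, mirroring SUPPORTED_INT_TYPES above.
table = {
    torch.bool: [torch.int8, torch.int16, torch.int32],
    torch.int8: [torch.bool, torch.int16, torch.int32],
    torch.int16: [torch.bool, torch.int8, torch.int32],
    torch.int32: [torch.bool, torch.int8, torch.int16],
}

# Alias int64 inputs to the int32 entry, mirroring POSSIBLE_TYPE_CONVERSIONS.
table.update(
    (k, table[v]) for k, v in {torch.int64: torch.int32}.items() if v in table
)

def cast_is_supported(input_dtype: torch.dtype, output_dtype: torch.dtype) -> bool:
    # A cast is supported when the (input, output) pair appears in the table.
    return output_dtype in table.get(input_dtype, [])

assert cast_is_supported(torch.int64, torch.int8)        # int64 handled via the int32 entry
assert not cast_is_supported(torch.bool, torch.float32)  # float target rejected on the BI profile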
1 change: 1 addition & 0 deletions backends/arm/operators/__init__.py
@@ -36,6 +36,7 @@
op_sub,
op_sum,
op_tanh,
op_to_copy,
op_transpose,
op_unsqueeze,
op_upsample_nearest2d,
43 changes: 43 additions & 0 deletions backends/arm/operators/op_to_copy.py
@@ -0,0 +1,43 @@
# Copyright 2024 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe
from typing import List

import serializer.tosa_serializer as ts
import torch
import tosa.Op as TosaOp

from executorch.backends.arm.operators.node_visitor import (
NodeVisitor,
register_node_visitor,
)
from executorch.backends.arm.tosa_mapping import TosaArg


@register_node_visitor
class ToCopyVisitor(NodeVisitor):
"""
Implements the type-cast functionality of _to_copy.
Other features, such as setting the memory_format or moving a tensor to a
different device, are not supported. Note that the node must not be
quantized.
"""

target = "aten._to_copy.default"

def define_node(
self,
node: torch.fx.Node,
tosa_graph: ts.TosaSerializer,
inputs: List[TosaArg],
output: TosaArg,
is_quant_node: bool,
) -> None:
assert not is_quant_node, "Casting of quantized values is not supported."
assert inputs
tosa_graph.addOperator(TosaOp.Op().CAST, [inputs[0].name], [output.name])
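
To see where such a node comes from: a dtype-changing Tensor.to() call typically lowers to aten._to_copy.default during export, which this visitor then maps to a single TOSA CAST. A small sketch (the module name is illustrative, and the lowering behavior shown is standard torch.export, not Arm-specific):

import torch

class Cast(torch.nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # A dtype-only .to() is a pure cast: no device move, no memory_format.
        return x.to(torch.int32)

ep = torch.export.export(Cast(), (torch.zeros(2, 2, dtype=torch.bool),))
# The exported graph contains an aten._to_copy.default node, which the
# visitor above lowers to TosaOp.Op().CAST.
print(ep.graph)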
23 changes: 23 additions & 0 deletions backends/arm/test/TARGETS
@@ -0,0 +1,23 @@
load("@fbcode_macros//build_defs:python_library.bzl", "python_library")

python_library(
name = "common",
srcs = ["common.py"],
deps = [
"//executorch/backends/xnnpack/test/tester:tester",
"//executorch/backends/arm:arm_backend",
"//executorch/exir:lib",
"//executorch/exir/backend:compile_spec_schema",
]
)

python_library(
name = "runner_utils",
srcs = ["runner_utils.py"],
deps = [
"//executorch/backends/xnnpack/test/tester:tester",
"//executorch/backends/arm:arm_backend",
"//executorch/exir:lib",
"//executorch/exir/backend:compile_spec_schema",
]
)
19 changes: 12 additions & 7 deletions backends/arm/test/ops/test_bmm.py
@@ -22,8 +22,8 @@ class TestBMM(unittest.TestCase):

class BMM(torch.nn.Module):
test_parameters = [
(torch.rand(5, 3, 5), torch.rand(5, 5, 2)),
(torch.rand(2, 1, 1), torch.rand(2, 1, 1)),
(torch.rand(5, 3, 5), torch.rand(5, 5, 2)),
(torch.ones(1, 55, 3), torch.ones(1, 3, 44)),
(10000 * torch.randn(10, 1, 10), torch.randn(10, 10, 5)),
(-10 * torch.randn(2, 32, 64), 5 + 5 * torch.randn(2, 64, 32)),
@@ -147,32 +147,37 @@ def test_bmm_single_input_tosa_BI(self, operand1: torch.Tensor):

@parameterized.expand(BMM.test_parameters)
@unittest.expectedFailure
def test_bmm_u55_BI(self, operand1: torch.Tensor, operand2: torch.Tensor):
def test_bmm_u55_BI_xfails(self, operand1: torch.Tensor, operand2: torch.Tensor):
test_data = (operand1, operand2)
self._test_bmm_ethosu_BI_pipeline(
self.BMM(), common.get_u55_compile_spec(), test_data
)

@parameterized.expand(BMM.test_parameters)
@common.expectedFailureOnFVP
@parameterized.expand(BMM.test_parameters[:1])
def test_bmm_u85_BI(self, operand1: torch.Tensor, operand2: torch.Tensor):
test_data = (operand1, operand2)
self._test_bmm_ethosu_BI_pipeline(
self.BMM(), common.get_u85_compile_spec(), test_data
)

@parameterized.expand(BMM.test_parameters[1:])
@common.expectedFailureOnFVP
def test_bmm_u85_BI_xfails(self, operand1: torch.Tensor, operand2: torch.Tensor):
test_data = (operand1, operand2)
self._test_bmm_ethosu_BI_pipeline(
self.BMM(), common.get_u85_compile_spec(), test_data
)

# Expected to fail with error: Warning, unsupported fusing of TOSA Rescale previous operator is of type: Memcpy
@parameterized.expand(BMMSingleInput.test_parameters)
@unittest.expectedFailure
def test_bmm_single_input_u55_BI(self, operand1: torch.Tensor):
def test_bmm_single_input_u55_BI_xfails(self, operand1: torch.Tensor):
test_data = (operand1,)
self._test_bmm_ethosu_BI_pipeline(
self.BMMSingleInput(), common.get_u55_compile_spec(), test_data
)

# Numerical issues on FVP, MLETORCH-534
@parameterized.expand(BMMSingleInput.test_parameters)
@common.expectedFailureOnFVP
def test_bmm_single_input_u85_BI(self, operand1: torch.Tensor):
test_data = (operand1,)
self._test_bmm_ethosu_BI_pipeline(
2 changes: 0 additions & 2 deletions backends/arm/test/ops/test_conv_combos.py
@@ -275,8 +275,6 @@ def test_conv_meandim_u55_BI(self):
model.get_inputs(),
)

# Numerical Issues on FVP, MLETORCH-520
@common.expectedFailureOnFVP
def test_conv_meandim_u85_BI(self):
model = ComboConv2dMeandim()
self._test_conv_combo_ethos_BI_pipeline(
27 changes: 20 additions & 7 deletions backends/arm/test/ops/test_depthwise_conv.py
@@ -156,6 +156,19 @@
("two_dw_conv2d", two_dw_conv2d),
]

testsuite_conv2d_u85 = [
("2x2_1x6x4x4_gp6_st1", dw_conv2d_2x2_1x6x4x4_gp6_st1),
("3x3_1x3x256x256_gp3_st1", dw_conv2d_3x3_1x3x256x256_gp3_st1),
("3x3_1x4x256x256_gp4_st1", dw_conv2d_3x3_1x4x256x256_gp4_st1),
("3x3_1x4x256x256_gp4_nobias", dw_conv2d_3x3_1x4x256x256_gp4_nobias),
]

testsuite_conv2d_u85_xfails = [
("3x3_2x8x198x198_gp8_st3", dw_conv2d_3x3_2x8x198x198_gp8_st3),
("two_dw_conv2d", two_dw_conv2d),
]


testsuite_conv1d = [
("2_1x6x4_gp6_st1", dw_conv1d_2_1x6x4_gp6_st1),
("two_dw_conv1d", two_dw_conv1d),
@@ -247,7 +260,7 @@ def test_dw_conv_tosa_BI(self, test_name: str, model: torch.nn.Module):
) # Works

@parameterized.expand(testsuite_conv2d, skip_on_empty=True)
@common.expectedFailureOnFVP
@unittest.expectedFailure
def test_dw_conv2d_u55_BI(
self, test_name: str, model: torch.nn.Module, set_quantize_io: bool = False
):
@@ -274,10 +287,8 @@ def test_dw_conv1d_u55_BI(
model.get_inputs(),
)

# All test cases except 3x3_1x3x256x256_gp3_st1 have numerical issues on FVP. MLETORCH-520
@parameterized.expand(testsuite_conv1d[:-2] + testsuite_conv2d)
@common.expectedFailureOnFVP
def test_dw_conv_u85_BI_xfails(
@parameterized.expand(testsuite_conv1d + testsuite_conv2d_u85)
def test_dw_conv_u85_BI(
self, test_name: str, model: torch.nn.Module, set_quantize_io: bool = False
):
self._test_dw_conv_ethos_BI_pipeline(
@@ -288,8 +299,10 @@ def test_dw_conv_u85_BI_xfails(
model.get_inputs(),
)

@parameterized.expand(testsuite_conv1d[-2:])
def test_dw_conv_u85_BI(
# All test cases except 3x3_1x3x256x256_gp3_st1 have numerical issues on FVP. MLETORCH-520
@parameterized.expand(testsuite_conv2d_u85_xfails)
@common.expectedFailureOnFVP
def test_dw_conv_u85_BI_xfails(
self, test_name: str, model: torch.nn.Module, set_quantize_io: bool = False
):
self._test_dw_conv_ethos_BI_pipeline(
[Diffs for the remaining 14 changed files did not load.]
