-
Notifications
You must be signed in to change notification settings - Fork 413
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
2024-11-29 nightly release (3475707)
- Loading branch information
pytorchbot
committed
Nov 29, 2024
1 parent
c312d06
commit 814ca0b
Showing
23 changed files
with
510 additions
and
77 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,120 @@ | ||
# Copyright 2024 Arm Limited and/or its affiliates. | ||
# | ||
# This source code is licensed under the BSD-style license found in the | ||
# LICENSE file in the root directory of this source tree. | ||
|
||
# pyre-unsafe | ||
import logging | ||
|
||
import torch | ||
|
||
import torch.fx as fx | ||
|
||
from executorch.backends.arm.operator_support.tosa_supported_operators import ( | ||
register_tosa_support_check, | ||
SupportedTOSAOperatorCheck, | ||
) | ||
from executorch.backends.arm.tosa_specification import TosaSpecification | ||
from executorch.exir.dialects._ops import ops as exir_ops | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
|
||
@register_tosa_support_check | ||
class ToCopySupported(SupportedTOSAOperatorCheck): | ||
targets = [exir_ops.edge.aten._to_copy.default] | ||
|
||
tosa_specs = [ | ||
TosaSpecification.create_from_string("TOSA-0.80.0+BI"), | ||
TosaSpecification.create_from_string("TOSA-0.80.0+MI"), | ||
] | ||
|
||
SupportedTypeDict = dict[torch.dtype, list[torch.dtype]] | ||
|
||
@staticmethod | ||
def _merge_supported_types( | ||
dtypes1: SupportedTypeDict, dtypes2: SupportedTypeDict | ||
) -> SupportedTypeDict: | ||
merged_dtypes = dtypes1 | ||
for k, v in dtypes2.items(): | ||
merged_dtypes[k] = merged_dtypes.get(k, []) + v | ||
return merged_dtypes | ||
|
||
SUPPORTED_INT_TYPES: SupportedTypeDict = { | ||
torch.bool: [torch.int8, torch.int16, torch.int32], | ||
torch.int8: [torch.bool, torch.int16, torch.int32], | ||
torch.int16: [torch.bool, torch.int8, torch.int32], | ||
torch.int32: [torch.bool, torch.int8, torch.int16], | ||
} | ||
SUPPORTED_FLOAT_TYPES: SupportedTypeDict = { | ||
torch.int8: [torch.float16, torch.bfloat16, torch.float32], | ||
torch.int16: [torch.float16, torch.bfloat16, torch.float32], | ||
torch.int32: [torch.float16, torch.bfloat16, torch.float32], | ||
torch.bfloat16: [torch.int8, torch.int16, torch.int32, torch.float32], | ||
torch.float16: [torch.int8, torch.int16, torch.int32, torch.float32], | ||
torch.float32: [ | ||
torch.int8, | ||
torch.int16, | ||
torch.int32, | ||
torch.bfloat16, | ||
torch.float16, | ||
], | ||
} | ||
ALL_SUPPORTED_TYPES = _merge_supported_types( | ||
SUPPORTED_INT_TYPES, SUPPORTED_FLOAT_TYPES | ||
) | ||
POSSIBLE_TYPE_CONVERSIONS = {torch.int64: torch.int32} | ||
|
||
def is_node_supported(self, node: fx.Node, tosa_spec: TosaSpecification) -> bool: | ||
assert node.target in self.targets | ||
|
||
if tosa_spec not in self.tosa_specs: | ||
return False | ||
|
||
assert tosa_spec.support_integer() | ||
supported_dtypes = ( | ||
self.ALL_SUPPORTED_TYPES | ||
if tosa_spec.support_float() | ||
else self.SUPPORTED_INT_TYPES | ||
) | ||
# Take into account possible type conversions | ||
supported_dtypes.update( | ||
(k, supported_dtypes[v]) | ||
for k, v in self.POSSIBLE_TYPE_CONVERSIONS.items() | ||
if v in supported_dtypes | ||
) | ||
|
||
# Check input type | ||
assert len(node.all_input_nodes) == 1 | ||
input_val = node.all_input_nodes[0].meta["val"] | ||
assert isinstance(input_val, torch._subclasses.FakeTensor) | ||
input_dtype = input_val.dtype | ||
if input_dtype not in supported_dtypes: | ||
logger.info( | ||
f"Input dtype {input_val.dtype} is not supported in " | ||
f"{node.target.name()}." | ||
) | ||
return False | ||
|
||
# Check output type | ||
output_val = node.meta["val"] | ||
assert isinstance(output_val, torch._subclasses.FakeTensor) | ||
if output_val.dtype not in supported_dtypes[input_dtype]: | ||
logger.info( | ||
f"Output dtype {output_val.dtype} is not supported in " | ||
f"{node.target.name()} for input dtype {input_dtype}. " | ||
f"Supported output types: " | ||
f"{''.join(str(t) for t in supported_dtypes[input_dtype])}" | ||
) | ||
return False | ||
|
||
# Check memory format | ||
if "memory_format" in node.kwargs: | ||
if node.kwargs["memory_format"] in (torch.preserve_format,): | ||
logger.info( | ||
f"Argument 'memory_format' is not supported for " | ||
f"{node.target.name()} right now." | ||
) | ||
return False | ||
|
||
return True |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -36,6 +36,7 @@ | |
op_sub, | ||
op_sum, | ||
op_tanh, | ||
op_to_copy, | ||
op_transpose, | ||
op_unsqueeze, | ||
op_upsample_nearest2d, | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,43 @@ | ||
# Copyright 2024 Arm Limited and/or its affiliates. | ||
# | ||
# This source code is licensed under the BSD-style license found in the | ||
# LICENSE file in the root directory of this source tree. | ||
|
||
# pyre-unsafe | ||
from typing import List | ||
|
||
import serializer.tosa_serializer as ts | ||
import torch | ||
import tosa.Op as TosaOp | ||
|
||
from executorch.backends.arm.operators.node_visitor import ( | ||
NodeVisitor, | ||
register_node_visitor, | ||
) | ||
from executorch.backends.arm.tosa_mapping import TosaArg | ||
|
||
|
||
@register_node_visitor | ||
class ToCopyVisitor(NodeVisitor): | ||
""" | ||
Implement the type cast functionality of _to_copy. | ||
Other features like setting of the memory_format or moving a tensor to a | ||
different device are not supported. | ||
Also note that the node should not be quantized. | ||
""" | ||
|
||
target = "aten._to_copy.default" | ||
|
||
def define_node( | ||
self, | ||
node: torch.fx.Node, | ||
tosa_graph: ts.TosaSerializer, | ||
inputs: List[TosaArg], | ||
output: TosaArg, | ||
is_quant_node: bool, | ||
) -> None: | ||
assert not is_quant_node, "Casting of quantized values is not supported." | ||
assert inputs | ||
tosa_graph.addOperator(TosaOp.Op().CAST, [inputs[0].name], [output.name]) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,23 @@ | ||
load("@fbcode_macros//build_defs:python_library.bzl", "python_library") | ||
|
||
python_library( | ||
name = "common", | ||
srcs = ["common.py"], | ||
deps = [ | ||
"//executorch/backends/xnnpack/test/tester:tester", | ||
"//executorch/backends/arm:arm_backend", | ||
"//executorch/exir:lib", | ||
"//executorch/exir/backend:compile_spec_schema", | ||
] | ||
) | ||
|
||
python_library( | ||
name = "runner_utils", | ||
srcs = ["runner_utils.py"], | ||
deps = [ | ||
"//executorch/backends/xnnpack/test/tester:tester", | ||
"//executorch/backends/arm:arm_backend", | ||
"//executorch/exir:lib", | ||
"//executorch/exir/backend:compile_spec_schema", | ||
] | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.