diff --git a/backends/arm/README.md b/backends/arm/README.md index 6f4642f8d4..a7458db07c 100644 --- a/backends/arm/README.md +++ b/backends/arm/README.md @@ -104,8 +104,8 @@ The Arm Backend should be considered a prototype quality at this point, likely s ## Current flows The ArmBackend has a two stage process, -- Compile to TOSA to rationalise the graph into known hardware support profiles. Currently this is to v0.80.0 TOSA BI with specific concern to a subset which gives support on Ethos-U55, the target of the initial prototype efforts. -- Lower via the ethos-u-vela compilation flow which takes TOSA v0.80.0 as an input and produces a low level commandstream for the hardware which is then passed via the delegate to the ethos-u-core-driver for direct execution. +- Compile to TOSA to rationalise the graph into known hardware support profiles. Currently this is to v0.80 TOSA BI with specific concern to a subset which gives support on Ethos-U55, the target of the initial prototype efforts. +- Lower via the ethos-u-vela compilation flow which takes TOSA v0.80 as an input and produces a low level commandstream for the hardware which is then passed via the delegate to the ethos-u-core-driver for direct execution. The ArmPartitioner is currenly used to ensure the operations converted are Ethos-U compatible, but will be extended to offer spec-correct TOSA Base inference and TOSA Main Inference generation in future. 
diff --git a/backends/arm/_passes/arm_pass_manager.py b/backends/arm/_passes/arm_pass_manager.py index 25811d077b..b4bb809b85 100644 --- a/backends/arm/_passes/arm_pass_manager.py +++ b/backends/arm/_passes/arm_pass_manager.py @@ -29,6 +29,10 @@ DecomposeSoftmaxesPass, ) from executorch.backends.arm._passes.decompose_var_pass import DecomposeVarPass +from executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass import ( + FoldAndAnnotateQParamsPass, + QuantizeFullArgument, +) from executorch.backends.arm._passes.keep_dims_false_to_squeeze_pass import ( KeepDimsFalseToSqueezePass, ) @@ -50,6 +54,7 @@ from executorch.backends.xnnpack._passes.remove_getitem_op import RemoveGetItemPass from executorch.exir import ExportedProgram from executorch.exir.backend.compile_spec_schema import CompileSpec +from executorch.exir.dialects._ops import ops as exir_ops from executorch.exir.pass_manager import PassManager @@ -80,6 +85,19 @@ def transform_to_backend_pipeline( self.add_pass(Conv1dUnsqueezePass(exported_program)) self.add_pass(DecomposeSoftmaxesPass()) self.add_pass(DecomposeLinearPass()) + self.add_pass(QuantizeFullArgument()) + self.add_pass( + FoldAndAnnotateQParamsPass( + [ + exir_ops.edge.aten.minimum.default, + exir_ops.edge.aten.maximum.default, + exir_ops.edge.aten.add.Tensor, + exir_ops.edge.aten.avg_pool2d.default, + exir_ops.edge.aten.convolution.default, + exir_ops.edge.aten.full.default, + ] + ) + ) for spec in compile_spec: if spec.key == "permute_memory_format": memory_format = spec.value.decode() diff --git a/backends/arm/_passes/fold_qdq_with_annotated_qparams_pass.py b/backends/arm/_passes/fold_qdq_with_annotated_qparams_pass.py new file mode 100644 index 0000000000..f078cf2118 --- /dev/null +++ b/backends/arm/_passes/fold_qdq_with_annotated_qparams_pass.py @@ -0,0 +1,183 @@ +# Copyright 2024 Arm Limited and/or its affiliates. +# All rights reserved. 
+# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import copy + +from typing import cast, Iterable + +from executorch.backends.arm.tosa_quant_utils import QuantArgs + +from executorch.exir.dialects._ops import ops as exir_ops +from executorch.exir.dialects.edge._ops import EdgeOpOverload + +from executorch.exir.pass_base import ExportPass, PassResult +from torch.fx import GraphModule, Node + +q_op: EdgeOpOverload = exir_ops.edge.quantized_decomposed.quantize_per_tensor.default +dq_op: EdgeOpOverload = exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default + + +def get_input_qparams(node: Node) -> dict[int, QuantArgs]: + """ + Get the input quantization parameters from a node, set by the 'FoldAndAnnotateQParamsPass'. + Raises a ValueError if the node doesn't have any parameters set. + """ + if "input_qparams" not in node.meta.keys(): + raise ValueError(f"No input quantization parameter found in node {node}") + input_qparams = cast(dict[int, QuantArgs], node.meta["input_qparams"]) + if len(input_qparams) == 0: + raise ValueError(f"No input quantization parameter found in node {node}") + return input_qparams + + +def get_output_qparams(node: Node) -> dict[int, QuantArgs]: + """ + Get the output quantization parameters from a node, set by the 'FoldAndAnnotateQParamsPass'. + Raises a ValueError if the node doesn't have any parameters set. + """ + if "output_qparams" not in node.meta.keys(): + raise ValueError(f"No output quantization parameter found in node {node}") + input_qparams = cast(dict[int, QuantArgs], node.meta["output_qparams"]) + if len(input_qparams) == 0: + raise ValueError(f"No output quantization parameter found in node {node}") + return input_qparams + + +class FoldAndAnnotateQParamsPass(ExportPass): + """ + A pass that walks the graph and removes any DQ and Q nodes before and after the target + node in the supplied list of operators. 
+ The quantization parameters from the DQ/Q nodes are stored as meta values to be + accessible for later lowering and serialization passes. + The assumption is that the quantization annotation adds DQ nodes for all tensor + inputs to the target and one Q node to the output. + + Example ('executorch_exir_dialects_edge__ops_' prefix removed from operators for readability): + + x_q: "i8[5]" = quantized_decomposed_quantize_per_tensor_default(x, 0.05487706884741783, -128, -128, 127, torch.int8) + + x_dq: "f32[5]" = quantized_decomposed_dequantize_per_tensor_default(x_q, 0.05487706884741783, -128, -128, 127, torch.int8) + aten_add_tensor: "f32[5]" = ops_aten_add_Tensor(x_dq, x_dq) + aten_add_tensor_q: "i8[5]" = quantized_decomposed_quantize_per_tensor_default(aten_add_tensor, 0.05487706884741783, -128, -128, 127, torch.int8) + + output_dq: "f32[5]" = quantized_decomposed_dequantize_per_tensor_default(aten_add_tensor_q, 0.05487706884741783, -128, -128, 127, torch.int8) + + Becomes: + x_q: "i8[5]" = quantized_decomposed_quantize_per_tensor_default(x, 0.05487706884741783, -128, -128, 127, torch.int8) + + aten_add_tensor: "i8[5]" = aten_add_Tensor(x_q, x_q) + + output_dq: "f32[5]" = quantized_decomposed_dequantize_per_tensor_default(aten_add_tensor, 0.05487706884741783, -128, -128, 127, torch.int8) + + The quantization parameters for x_dq and aten_add_tensor_q are stored in meta for the aten_add_tensor node. + + """ + + def __init__(self, targeted_ops: Iterable[EdgeOpOverload]) -> None: + super().__init__() + self.targeted_ops = targeted_ops + + def call(self, graph_module: GraphModule) -> PassResult: + + # Loop over the graph nodes and find any node in the 'targeted_ops' list. 
+ for n in graph_module.graph.nodes: + n = cast(Node, n) + if n.op != "call_function" or n.target not in self.targeted_ops: + continue + + # Make sure we haven't already set qparams meta information on the node + assert "input_qparams" not in n.meta.keys() + assert "output_qparams" not in n.meta.keys() + + # for the inputs and outputs search the graph for quantization info and + # store the information in a dict with order of the _tensor_ inputs as key, + # ignoring any other arguments to the target node. + n.meta["input_qparams"] = {} + n.meta["output_qparams"] = {} + for i, arg in enumerate(n.args): + if not isinstance(arg, Node): + continue + + # Make sure arg has requires_grad set to False + # For parameters that are not quantized, sometimes (i.e. convolution) + # the Parameter(FakeTensor(...)) has requires_grad set to True, which + # causes the retracing of the graph to fail with: + # + # E RuntimeError: isDifferentiableType(variable.scalar_type()) INTERNAL ASSERT FAILED at "/Users/runner/work/pytorch/pytorch/pytorch/torch/csrc/autograd/functions/utils.h":74, please report a bug to PyTorch. 
+ # E + # E While executing %aten_convolution_default : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten.convolution.default](args = (%quantized_decomposed_quantize_per_tensor_default, %b__frozen_param0, %p__param_constant1, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), kwargs = {}) + # E Original traceback: + # E File "/Users/perast01/src/executorch/backends/arm/test/ops/test_conv2d.py", line 110, in forward + # E x = conv(x) + # + if arg.op == "placeholder": + arg.meta["val"].requires_grad = False + + if arg.target != dq_op: + continue + + # arg.target for argument i is a dequant node, extract the information + n.meta["input_qparams"][i] = QuantArgs.from_operator( + arg.target, arg.args + ) + + # arg.args[0] is the tensor input, replace the input usage + tensor_input = cast(Node, arg.args[0]) + n.replace_input_with(arg, tensor_input) + graph_module.graph.erase_node(arg) + + # Copy the users, since we are modifying it. + users_copy = copy.copy(n.users) + for i, user in enumerate(users_copy): + if user.target != q_op: + continue + + # quantization node found here, store the quantization parameters in meta value + n.meta["output_qparams"][i] = QuantArgs.from_operator( + user.target, user.args + ) + + user.replace_all_uses_with(n) + graph_module.graph.erase_node(user) + + # retrace the graph to update the fake tensor types + graph_module = super().call(graph_module).graph_module + + graph_module.recompile() + return PassResult(graph_module, True) + + +class QuantizeFullArgument(ExportPass): + """ + Make sure the fill_value for full.default is quantized. This pass needs to be run before + the folding pass above to make sure that the retraced output of the full.default op is + the right dtype. + """ + + def call(self, graph_module: GraphModule) -> PassResult: + modified = False + # Loop over the graph nodes and find any node in the 'targeted_ops' list. 
+ for n in graph_module.graph.nodes: + n = cast(Node, n) + if n.target != exir_ops.edge.aten.full.default: + continue + + # Make sure we have a quantized operator + user = list(n.users)[0] + if user.target != q_op: + continue + + qargs = QuantArgs.from_operator(user.target, user.args) + if "dtype" not in n.kwargs.keys() or n.kwargs["dtype"] != qargs.dtype: + # replace the node arg with a quantized ditto and also set dtype + # to get the right output according to the Edge IR specification: + # exir/dialects/edge/edge.yaml:3596 + quantized_full_value = qargs.quantize_value(n.args[1]).item() + n.update_arg(1, quantized_full_value) + n.update_kwarg("dtype", qargs.dtype) + modified = True + + return PassResult(graph_module, modified) diff --git a/backends/arm/arm_backend.py b/backends/arm/arm_backend.py index 93b4548390..e2fdc42b11 100644 --- a/backends/arm/arm_backend.py +++ b/backends/arm/arm_backend.py @@ -90,7 +90,7 @@ def ethosu_compile_spec( if extra_flags is not None: self.compiler_flags.append(extra_flags) - base_tosa_version = "TOSA-0.80.0+BI" + base_tosa_version = "TOSA-0.80+BI" if "u55" in config: # Add the Ethos-U55 extension marker base_tosa_version += "+u55" diff --git a/backends/arm/operator_support/right_shift_support.py b/backends/arm/operator_support/right_shift_support.py index ee8d5965a1..bf548982b7 100644 --- a/backends/arm/operator_support/right_shift_support.py +++ b/backends/arm/operator_support/right_shift_support.py @@ -23,8 +23,8 @@ class RightShiftSupported(SupportedTOSAOperatorCheck): targets = [exir_ops.edge.aten.__rshift__.Scalar] tosa_specs = [ - TosaSpecification.create_from_string("TOSA-0.80.0+BI"), - TosaSpecification.create_from_string("TOSA-0.80.0+MI"), + TosaSpecification.create_from_string("TOSA-0.80+BI"), + TosaSpecification.create_from_string("TOSA-0.80+MI"), ] def is_node_supported(self, node: fx.Node, tosa_spec: TosaSpecification): diff --git a/backends/arm/operator_support/to_copy_support.py 
b/backends/arm/operator_support/to_copy_support.py index dcf2ce316b..f2968585f2 100644 --- a/backends/arm/operator_support/to_copy_support.py +++ b/backends/arm/operator_support/to_copy_support.py @@ -25,8 +25,8 @@ class ToCopySupported(SupportedTOSAOperatorCheck): targets = [exir_ops.edge.aten._to_copy.default] tosa_specs = [ - TosaSpecification.create_from_string("TOSA-0.80.0+BI"), - TosaSpecification.create_from_string("TOSA-0.80.0+MI"), + TosaSpecification.create_from_string("TOSA-0.80+BI"), + TosaSpecification.create_from_string("TOSA-0.80+MI"), ] SupportedTypeDict = dict[torch.dtype, list[torch.dtype]] diff --git a/backends/arm/operator_support/tosa_supported_operators.py b/backends/arm/operator_support/tosa_supported_operators.py index 7072ba6a82..7f92574cfd 100644 --- a/backends/arm/operator_support/tosa_supported_operators.py +++ b/backends/arm/operator_support/tosa_supported_operators.py @@ -35,8 +35,8 @@ def is_node_supported(self, node: fx.Node, tosa_spec: TosaSpecification) -> bool _tosa_spec_dicts: dict[ TosaSpecification, dict[str, Type[SupportedTOSAOperatorCheck]] ] = { - TosaSpecification.create_from_string("TOSA-0.80.0+BI"): {}, - TosaSpecification.create_from_string("TOSA-0.80.0+MI"): {}, + TosaSpecification.create_from_string("TOSA-0.80+BI"): {}, + TosaSpecification.create_from_string("TOSA-0.80+MI"): {}, } @@ -94,6 +94,8 @@ def is_node_supported(self, submodules, node: fx.Node) -> bool: exir_ops.edge.aten.sigmoid.default, exir_ops.edge.aten.mean.dim, exir_ops.edge.aten.mm.default, + exir_ops.edge.aten.minimum.default, + exir_ops.edge.aten.maximum.default, exir_ops.edge.aten.repeat.default, exir_ops.edge.aten.reciprocal.default, exir_ops.edge.aten.relu.default, diff --git a/backends/arm/operators/__init__.py b/backends/arm/operators/__init__.py index 8c4aa85e57..6db9c968f0 100644 --- a/backends/arm/operators/__init__.py +++ b/backends/arm/operators/__init__.py @@ -19,7 +19,9 @@ op_get_item, op_hardtanh, op_log, + op_max, op_max_pool2d, + op_min, 
op_mm, op_mul, op_permute, diff --git a/backends/arm/operators/node_visitor.py b/backends/arm/operators/node_visitor.py index 9e98ebcab9..87ef6ed4c6 100644 --- a/backends/arm/operators/node_visitor.py +++ b/backends/arm/operators/node_visitor.py @@ -25,8 +25,8 @@ class NodeVisitor: # When all node_visitors has been refactored to target a specific # version, this list should be removed. tosa_specs = [ - TosaSpecification.create_from_string("TOSA-0.80.0+BI"), - TosaSpecification.create_from_string("TOSA-0.80.0+MI"), + TosaSpecification.create_from_string("TOSA-0.80+BI"), + TosaSpecification.create_from_string("TOSA-0.80+MI"), ] def __init__(self, exported_program: ExportedProgram, tosa_spec: TosaSpecification): @@ -46,8 +46,8 @@ def define_node( # container for all node visitors _node_visitor_dicts = { - TosaSpecification.create_from_string("TOSA-0.80.0+BI"): {}, - TosaSpecification.create_from_string("TOSA-0.80.0+MI"): {}, + TosaSpecification.create_from_string("TOSA-0.80+BI"): {}, + TosaSpecification.create_from_string("TOSA-0.80+MI"): {}, } diff --git a/backends/arm/operators/op_add.py b/backends/arm/operators/op_add.py index e52f3eddae..a81e52c5c6 100644 --- a/backends/arm/operators/op_add.py +++ b/backends/arm/operators/op_add.py @@ -11,7 +11,6 @@ import executorch.backends.arm.tosa_utils as tutils import serializer.tosa_serializer as ts -import torch from executorch.backends.arm.operators.node_visitor import ( NodeVisitor, register_node_visitor, @@ -27,7 +26,7 @@ class AddVisitor_080_BI(NodeVisitor): target = "aten.add.Tensor" tosa_specs = [ - TosaSpecification.create_from_string("TOSA-0.80.0+BI"), + TosaSpecification.create_from_string("TOSA-0.80+BI"), ] def __init__(self, *args): @@ -41,33 +40,27 @@ def define_node( output: TosaArg, is_quant_node: bool, ) -> None: - input_nodes = tutils.get_two_inputs(node) - - if not is_quant_node and not all( - tensor.meta["val"].dtype in (torch.int8, torch.int32) - for tensor in input_nodes - ): - raise RuntimeError( - 
f"Unexpected non quantized {AddVisitor_080_BI.target} node." - ) - - needs_rescale = not ( - all(tensor.meta["val"].dtype == torch.int32 for tensor in input_nodes) - and node.meta["val"].dtype == torch.int32 - ) - - if needs_rescale: - # Rescale inputs to 32 bit - rescaled_inputs, scale = tqutils.rescale_nodes_to_int32( - input_nodes, tosa_graph + # Specification (0.80) states that input and output types + # should all be the same + assert inputs[0].dtype == inputs[1].dtype == output.dtype + # Handle int8 (quantized) and int32 + assert inputs[0].dtype in [ts.DType.INT8, ts.DType.INT32] + + if inputs[0].dtype == ts.DType.INT8: + rescaled_inputs, scale_back = tqutils.insert_rescale_ops_to_int32( + tosa_graph, inputs, node ) + else: + # input[0].dtype == ts.DType.INT32 + # Non quantized input, natively support by TOSA.ADD + rescaled_inputs = inputs - # Prepare add output tensor + if output.dtype == ts.DType.INT8: broadcasted_shape = tutils.tosa_shape(output.shape, output.dim_order) add_output = tosa_graph.addIntermediate(broadcasted_shape, ts.DType.INT32) else: + # output.dtype == ts.DType.INT32 add_output = output - rescaled_inputs = inputs # Do the INT32 Add tosa_graph.addOperator( @@ -80,10 +73,10 @@ def define_node( None, ) - if needs_rescale: + if output.dtype == ts.DType.INT8: # Scale output back to 8 bit # pyre-ignore - tqutils.rescale_node_back_to_int8(node, add_output, scale, tosa_graph) + tqutils.insert_rescale_op_to_int8(tosa_graph, add_output, scale_back, node) @register_node_visitor @@ -91,7 +84,7 @@ class AddVisitor_080_MI(AddVisitor_080_BI): # inheriting 'target' from BI class tosa_specs = [ - TosaSpecification.create_from_string("TOSA-0.80.0+MI"), + TosaSpecification.create_from_string("TOSA-0.80+MI"), ] def __init__(self, *args): @@ -105,11 +98,19 @@ def define_node( output: TosaArg, is_quant_node: bool, ) -> None: - if is_quant_node: + # Specification (0.80) states that input and output types + # should all be the same + assert inputs[0].dtype == 
inputs[1].dtype == output.dtype + + if inputs[0].dtype in [ts.DType.INT8, ts.DType.INT32]: # Call the inherited define_node for handling integers super().define_node(node, tosa_graph, inputs, output, is_quant_node) else: # FP32 Add lowering + assert inputs[0].dtype == ts.DType.FP32 + assert output.dtype == ts.DType.FP32 + + # MI lowering tosa_graph.addOperator( TosaOp.Op().ADD, [inputs[0].name, inputs[1].name], diff --git a/backends/arm/operators/op_avg_pool2d.py b/backends/arm/operators/op_avg_pool2d.py index 4caaad9202..9d8dd13e7e 100644 --- a/backends/arm/operators/op_avg_pool2d.py +++ b/backends/arm/operators/op_avg_pool2d.py @@ -8,30 +8,43 @@ import serializer.tosa_serializer as ts import torch + +# pyre-fixme[21]: ' Could not find a module corresponding to import `executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass` +from executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass import ( + get_input_qparams, + get_output_qparams, +) from executorch.backends.arm.operators.node_visitor import ( NodeVisitor, register_node_visitor, ) from executorch.backends.arm.tosa_mapping import TosaArg -from executorch.backends.arm.tosa_utils import build_avg_pool_2d_common +from executorch.backends.arm.tosa_specification import TosaSpecification @register_node_visitor -class AvgPool2dVisitor(NodeVisitor): +class AvgPool2dVisitor_0_80_BI(NodeVisitor): target = "aten.avg_pool2d.default" + tosa_specs = [ + TosaSpecification.create_from_string("TOSA-0.80+BI"), + ] + def __init__(self, *args): super().__init__(*args) - def define_node( + def _build_generic_avgpool2d( self, node: torch.fx.Node, tosa_graph: ts.TosaSerializer, inputs: List[TosaArg], output: TosaArg, - is_quant_node: bool, + input_zp: int, + output_zp: int, + accumulator_type, ) -> None: input_tensor = inputs[0] + kernel_size_list = inputs[1].special stride_size_list = inputs[2].special try: @@ -39,13 +52,76 @@ def define_node( except IndexError: pad_size_list = [0, 0, 0, 0] - 
build_avg_pool_2d_common( - node, - tosa_graph, - input_tensor, - kernel_size_list, - stride_size_list, - pad_size_list, - is_quant_node, - output, + attr = ts.TosaSerializerAttribute() + attr.PoolAttribute( + kernel=kernel_size_list, + stride=stride_size_list, + pad=pad_size_list, + input_zp=input_zp, + output_zp=output_zp, + accum_dtype=accumulator_type, + ) + + tosa_graph.addOperator( + ts.TosaOp.Op().AVG_POOL2D, + [input_tensor.name], + [output.name], + attr, ) + + def define_node( + self, + node: torch.fx.Node, + tosa_graph: ts.TosaSerializer, + inputs: List[TosaArg], + output: TosaArg, + is_quant_node: bool, + ) -> None: + input_tensor = inputs[0] + assert input_tensor.dtype == ts.DType.INT8 + + accumulator_type = ts.DType.INT32 + + input_qargs = get_input_qparams(node) # pyre-ignore[16] + input_zp = input_qargs[0].zp + + output_qargs = get_output_qparams(node) # pyre-ignore[16] + output_zp = output_qargs[0].zp + + self._build_generic_avgpool2d( + node, tosa_graph, inputs, output, input_zp, output_zp, accumulator_type + ) + + +@register_node_visitor +class AvgPool2dVisitor_0_80_MI(AvgPool2dVisitor_0_80_BI): + # inheriting 'target' from BI class + + tosa_specs = [ + TosaSpecification.create_from_string("TOSA-0.80+MI"), + ] + + def define_node( + self, + node: torch.fx.Node, + tosa_graph: ts.TosaSerializer, + inputs: List[TosaArg], + output: TosaArg, + is_quant_node: bool, + ) -> None: + assert ( + inputs[0].dtype == ts.DType.INT8 or inputs[0].dtype == ts.DType.FP32 + ), "Only FP32 and INT8 supported" + + if inputs[0].dtype == ts.DType.INT8: + super().define_node(node, tosa_graph, inputs, output, is_quant_node) + + if inputs[0].dtype == ts.DType.FP32: + accumulator_type = ts.DType.FP32 + # Initialize zero point to zero. 
+ input_zp = 0 + output_zp = 0 + + self._build_generic_avgpool2d( + node, tosa_graph, inputs, output, input_zp, output_zp, accumulator_type + ) diff --git a/backends/arm/operators/op_batch_norm.py b/backends/arm/operators/op_batch_norm.py index d17c3a1b81..c3b9bb0c43 100644 --- a/backends/arm/operators/op_batch_norm.py +++ b/backends/arm/operators/op_batch_norm.py @@ -13,6 +13,7 @@ register_node_visitor, ) from executorch.backends.arm.tosa_mapping import TosaArg +from executorch.backends.arm.tosa_specification import TosaSpecification from executorch.backends.arm.tosa_utils import promote_shape, tosa_shape from serializer.tosa_serializer import TosaOp @@ -21,6 +22,10 @@ class BatchNormVisitor(NodeVisitor): target = "aten._native_batch_norm_legit_no_training.default" + tosa_specs = [ + TosaSpecification.create_from_string("TOSA-0.80+MI"), + ] + def __init__(self, *args): super().__init__(*args) diff --git a/backends/arm/operators/op_conv2d.py b/backends/arm/operators/op_conv2d.py index ffbeee7306..5913cb0c34 100644 --- a/backends/arm/operators/op_conv2d.py +++ b/backends/arm/operators/op_conv2d.py @@ -8,16 +8,18 @@ import serializer.tosa_serializer as ts import torch + +# pyre-fixme[21]: 'Could not find a module corresponding to import `executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass`.' 
+from executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass import ( + get_input_qparams, + get_output_qparams, +) from executorch.backends.arm.operators.node_visitor import ( NodeVisitor, register_node_visitor, ) from executorch.backends.arm.tosa_mapping import TosaArg -from executorch.backends.arm.tosa_quant_utils import ( - build_rescale_conv_output, - get_quant_arg_downstream, - get_quant_arg_upstream, -) +from executorch.backends.arm.tosa_quant_utils import build_rescale_conv_output from executorch.backends.arm.tosa_utils import build_reshape, tosa_shape from serializer.tosa_serializer import TosaOp @@ -57,9 +59,6 @@ def define_node( ) -> None: input, weight, bias, stride, pad, dilation, _, _, group = inputs - # Currently only int8 is supported in quantized types. - actual_out_type = ts.DType.INT8 if is_quant_node else output.dtype - # Get the attributes of convolution. attr = ts.TosaSerializerAttribute() pad_attr = [val for val in pad.special for _ in (0, 1)] @@ -82,9 +81,11 @@ def define_node( dilation_attr[1], ) - input_zp = ( - get_quant_arg_upstream(node.all_input_nodes[0]).zp if is_quant_node else 0 - ) + input_zp = 0 + if inputs[0].dtype == ts.DType.INT8: + # int8 input requires quantization information + input_qparams = get_input_qparams(node) # pyre-ignore[16] + input_zp = input_qparams[0].zp attr.ConvAttribute( pad=pad_attr, @@ -100,16 +101,22 @@ def define_node( # Create a zero bias tensor if not presented out_channels = weight.shape[0] bias_name = "bias" + node.name.split("default", 1)[1] + bias_type = output.dtype + if output.dtype == ts.DType.INT8: + # Conv is quantized to int8, but the TOSA operator has + # output type int32, and the bias must be the same type + # as the TOSA output type + bias_type = ts.DType.INT32 bias = tosa_graph.addConst( [out_channels], - ts.DType.INT32 if is_quant_node else output.dtype, + bias_type, [0] * out_channels, name=bias_name, ) # The output type is int32 when input type is int8. 
conv2d_output_name = output.name - if is_quant_node: + if output.dtype == ts.DType.INT8: conv2d_res = tosa_graph.addIntermediate( tosa_shape(output.shape, output.dim_order), ts.DType.INT32 ) @@ -132,7 +139,7 @@ def define_node( weight_reshaped = tosa_graph.addIntermediate( weight_post_shape, - ts.DType.INT8 if is_quant_node else weight.dtype, + weight.dtype, ) build_reshape( tosa_graph, weight.name, weight_post_shape, weight_reshaped.name @@ -157,20 +164,19 @@ def define_node( # For quantized convolution, rescale the output value back to the same # integer value domain of the next op. Otherwise return float32 output. - if is_quant_node: + if inputs[0].dtype == ts.DType.INT8: # Get scale_factor from input, weight, and output. - input_scale = get_quant_arg_upstream(node.all_input_nodes[0]).scale - weight_scale = get_quant_arg_upstream(node.all_input_nodes[1]).scale - output_qargs = get_quant_arg_downstream(list(node.users)[0]) - + input_scale = input_qparams[0].scale # pyre-ignore [61] + weight_scale = input_qparams[1].scale # pyre-ignore [61] + output_qargs = get_output_qparams(node) # pyre-ignore [16] build_rescale_conv_output( tosa_graph, # pyre-fixme[61]: Uninitialized local [61]: Local variable `conv2d_res` is undefined, or not always defined. 
conv2d_res, output.name, - actual_out_type, + output.dtype, input_scale, weight_scale, - output_qargs.scale, - output_qargs.zp, + output_qargs[0].scale, + output_qargs[0].zp, ) diff --git a/backends/arm/operators/op_div.py b/backends/arm/operators/op_div.py index 0857e0ed32..2332e807c4 100644 --- a/backends/arm/operators/op_div.py +++ b/backends/arm/operators/op_div.py @@ -13,6 +13,7 @@ register_node_visitor, ) from executorch.backends.arm.tosa_mapping import TosaArg +from executorch.backends.arm.tosa_specification import TosaSpecification from executorch.backends.arm.tosa_utils import tosa_shape from serializer.tosa_serializer import TosaOp @@ -21,6 +22,11 @@ class DivVisitor(NodeVisitor): target = "aten.div.Tensor" + # Only supported for MI + tosa_specs = [ + TosaSpecification.create_from_string("TOSA-0.80+MI"), + ] + def __init__(self, *args): super().__init__(*args) diff --git a/backends/arm/operators/op_full.py b/backends/arm/operators/op_full.py index d2bc1377ce..23a13dd486 100644 --- a/backends/arm/operators/op_full.py +++ b/backends/arm/operators/op_full.py @@ -14,10 +14,6 @@ register_node_visitor, ) from executorch.backends.arm.tosa_mapping import TosaArg -from executorch.backends.arm.tosa_quant_utils import ( - get_quant_arg_downstream, - quantize_value, -) from executorch.backends.arm.tosa_utils import tosa_shape from torch.fx import Node @@ -41,19 +37,14 @@ def define_node( shape = tosa_shape(inputs[0].special, output.dim_order) value = inputs[1].number - if is_quant_node: - qargs = get_quant_arg_downstream(list(node.users)[0]) - qvalue = quantize_value(value, qargs) - dtype = ts.DType.INT8 - data = np.full(shape, qvalue, dtype=np.int8) + + if output.dtype == ts.DType.INT8: + fill_dtype = np.int8 else: - assert ( - output.dtype == ts.DType.FP32 - ), "'Full' currently only supports FP32 for unquantized models." 
- dtype = ts.DType.FP32 - data = np.full(shape, value, dtype=np.float32) + fill_dtype = np.float32 + data = np.full(shape, value, dtype=fill_dtype) - tosa_graph.addConst(shape, dtype, data, node.name + "full-const") + tosa_graph.addConst(shape, output.dtype, data, node.name + "full-const") tosa_graph.addOperator( ts.TosaOp.Op.IDENTITY, [node.name + "full-const"], [output.name] ) diff --git a/backends/arm/operators/op_max.py b/backends/arm/operators/op_max.py new file mode 100644 index 0000000000..58c0d44821 --- /dev/null +++ b/backends/arm/operators/op_max.py @@ -0,0 +1,79 @@ +# Copyright 2024 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-unsafe + +from typing import List + +import executorch.backends.arm.tosa_quant_utils as tqutils +import serializer.tosa_serializer as ts + +# pyre-fixme[21]: 'Could not find a module corresponding to import `executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass`.' 
+from executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass import ( + get_input_qparams, +) +from executorch.backends.arm.operators.node_visitor import ( + NodeVisitor, + register_node_visitor, +) +from executorch.backends.arm.tosa_mapping import TosaArg +from executorch.backends.arm.tosa_utils import tosa_shape + +from serializer.tosa_serializer import TosaOp +from torch.fx import Node + + +@register_node_visitor +class MaxVisitor(NodeVisitor): + target = "aten.maximum.default" + + def __init__(self, *args): + super().__init__(*args) + + def define_node( + self, + node: Node, + tosa_graph: ts.TosaSerializer, + inputs: List[TosaArg], + output: TosaArg, + is_quant_node: bool, + ) -> None: + assert inputs[0].dtype == inputs[1].dtype + + scale_back = 1.0 + max_output = output + if inputs[0].dtype == ts.DType.INT8: + input_qparams = get_input_qparams( # pyre-ignore[16]: 'Module `executorch.backends.arm` has no attribute `_passes`.' + node + ) + assert ( + len(input_qparams) == 2 + ), f"Both inputs needs to have quantization information for {node}" + # insert RESCALEs to int32 + assert ( + input_qparams[0] == input_qparams[1] + ), "Both inputs must have same quantization for MAX" + + operand_inputs, scale_back = tqutils.insert_rescale_ops_to_int32( + tosa_graph, inputs, node + ) + + output.shape = tosa_shape(output.shape, output.dim_order) + max_output = tosa_graph.addIntermediate(output.shape, ts.DType.INT32) + else: + operand_inputs = inputs + + tosa_graph.addOperator( + TosaOp.Op().MAXIMUM, + [ + operand_inputs[0].name, + operand_inputs[1].name, + ], + [max_output.name], + ) + + if output.dtype == ts.DType.INT8: + # insert RESCALE from int32 back to int8 + tqutils.insert_rescale_op_to_int8(tosa_graph, max_output, scale_back, node) diff --git a/backends/arm/operators/op_max_pool2d.py b/backends/arm/operators/op_max_pool2d.py index 74e33ddb02..0a4092e3a9 100644 --- a/backends/arm/operators/op_max_pool2d.py +++ b/backends/arm/operators/op_max_pool2d.py 
@@ -13,7 +13,7 @@ register_node_visitor, ) from executorch.backends.arm.tosa_mapping import TosaArg -from executorch.backends.arm.tosa_utils import ( +from executorch.backends.arm.tosa_quant_utils import ( get_quant_arg_downstream, get_quant_arg_upstream, ) diff --git a/backends/arm/operators/op_min.py b/backends/arm/operators/op_min.py new file mode 100644 index 0000000000..61b9e459ca --- /dev/null +++ b/backends/arm/operators/op_min.py @@ -0,0 +1,80 @@ +# Copyright 2024 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-unsafe + +from typing import List + +import executorch.backends.arm.tosa_quant_utils as tqutils + +import serializer.tosa_serializer as ts + +# pyre-fixme[21]: 'Could not find a module corresponding to import `executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass`.' +from executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass import ( + get_input_qparams, +) +from executorch.backends.arm.operators.node_visitor import ( + NodeVisitor, + register_node_visitor, +) +from executorch.backends.arm.tosa_mapping import TosaArg +from executorch.backends.arm.tosa_utils import tosa_shape + +from serializer.tosa_serializer import TosaOp +from torch.fx import Node + + +@register_node_visitor +class MinVisitor(NodeVisitor): + target = "aten.minimum.default" + + def __init__(self, *args): + super().__init__(*args) + + def define_node( + self, + node: Node, + tosa_graph: ts.TosaSerializer, + inputs: List[TosaArg], + output: TosaArg, + is_quant_node: bool, + ) -> None: + assert inputs[0].dtype == inputs[1].dtype + + scale_back = 1.0 + min_output = output + if inputs[0].dtype == ts.DType.INT8: + input_qparams = get_input_qparams( # pyre-ignore[16]: 'Module `executorch.backends.arm` has no attribute `_passes`.' 
+ node + ) + assert ( + len(input_qparams) == 2 + ), f"Both inputs needs to have quantization information for {node}" + # insert RESCALEs to int32 + assert ( + input_qparams[0] == input_qparams[1] + ), "Both inputs must have same quantization for MIN" + + operand_inputs, scale_back = tqutils.insert_rescale_ops_to_int32( + tosa_graph, inputs, node + ) + + output.shape = tosa_shape(output.shape, output.dim_order) + min_output = tosa_graph.addIntermediate(output.shape, ts.DType.INT32) + else: + operand_inputs = inputs + + tosa_graph.addOperator( + TosaOp.Op().MINIMUM, + [ + operand_inputs[0].name, + operand_inputs[1].name, + ], + [min_output.name], + ) + + if output.dtype == ts.DType.INT8: + # insert RESCALE from int32 back to int8 + tqutils.insert_rescale_op_to_int8(tosa_graph, min_output, scale_back, node) diff --git a/backends/arm/process_node.py b/backends/arm/process_node.py index 2d3a0c2786..1b9b96b5ad 100644 --- a/backends/arm/process_node.py +++ b/backends/arm/process_node.py @@ -11,10 +11,14 @@ import serializer.tosa_serializer as ts import torch import torch.fx + +# pyre-fixme[21]: 'Could not find a module corresponding to import `executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass`.' +from executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass import ( + get_input_qparams, +) from executorch.backends.arm.operators.node_visitor import NodeVisitor from executorch.backends.arm.tosa_mapping import map_dtype, TosaArg from executorch.backends.arm.tosa_quant_utils import ( - get_quant_arg_upstream, get_quantized_node_output_dtype, is_node_quantized, ) @@ -110,8 +114,12 @@ def process_quantized_bias( _, ) = consumer_node.all_input_nodes - input_node_scale = get_quant_arg_upstream(input_node).scale - weight_node_scale = get_quant_arg_upstream(weight_node).scale + input_qargs = get_input_qparams( # pyre-ignore[16]: Module `executorch.backends.arm` has no attribute `_passes`. 
+ consumer_node + ) + + input_node_scale = input_qargs[0].scale + weight_node_scale = input_qargs[1].scale bias_values_quantized = ( (parameter_values / (input_node_scale * weight_node_scale)) .round() diff --git a/backends/arm/quantizer/arm_quantizer.py b/backends/arm/quantizer/arm_quantizer.py index 6f2a5689d3..8815d40b0b 100644 --- a/backends/arm/quantizer/arm_quantizer.py +++ b/backends/arm/quantizer/arm_quantizer.py @@ -77,6 +77,7 @@ def _supported_symmetric_quantized_operators() -> Dict[str, List[OperatorPattern ], "mul": [[torch.mul]], "sub": [[torch.sub]], + "min_max": [[torch.min], [torch.max]], } return copy.deepcopy(supported_operators) @@ -267,6 +268,7 @@ class ArmQuantizer(Quantizer): "add", "sub", "mul", + "min_max", "mm", "one_to_one", "generic", diff --git a/backends/arm/quantizer/quantization_annotation/__init__.py b/backends/arm/quantizer/quantization_annotation/__init__.py index 1201df51ad..d9d27cee2a 100644 --- a/backends/arm/quantizer/quantization_annotation/__init__.py +++ b/backends/arm/quantizer/quantization_annotation/__init__.py @@ -55,6 +55,7 @@ def decorator(annotator: AnnotatorType): generic_annotator, linear_annotator, max_pool2d_annotator, + min_max_annotator, mm_annotator, mul_annotator, one_to_one_annotator, diff --git a/backends/arm/quantizer/quantization_annotation/min_max_annotator.py b/backends/arm/quantizer/quantization_annotation/min_max_annotator.py new file mode 100644 index 0000000000..43c4d20c13 --- /dev/null +++ b/backends/arm/quantizer/quantization_annotation/min_max_annotator.py @@ -0,0 +1,46 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# Copyright 2024 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+ +# pyre-unsafe + +from typing import Callable, List, Optional + +import torch +from executorch.backends.arm.quantizer import arm_quantizer_utils +from executorch.backends.arm.quantizer.quantization_annotation import register_annotator +from executorch.backends.arm.quantizer.quantization_config import QuantizationConfig +from torch.ao.quantization.quantizer import QuantizationAnnotation +from torch.fx import GraphModule, Node + + +@register_annotator("min_max") +def _annotate_min_max( + gm: GraphModule, + quantization_config: QuantizationConfig, + filter_fn: Optional[Callable[[Node], bool]] = None, +) -> Optional[List[List[Node]]]: + annotated_partitions = [] + for node in gm.graph.nodes: + if node.target not in ( + torch.ops.aten.minimum.default, + torch.ops.aten.maximum.default, + ): + continue + annotated_partitions.append(node) + min_max_node = node + if arm_quantizer_utils.is_annotated(min_max_node): + continue + + input_qspec_map, output_qspec = arm_quantizer_utils.get_shared_qspec( + min_max_node, gm, quantization_config + ) + if input_qspec_map is not None: + min_max_node.meta["quantization_annotation"] = QuantizationAnnotation( + input_qspec_map=input_qspec_map, + output_qspec=output_qspec, + _annotated=True, + ) + return annotated_partitions diff --git a/backends/arm/test/misc/test_debug_feats.py b/backends/arm/test/misc/test_debug_feats.py index f82b5afc3b..47d259da47 100644 --- a/backends/arm/test/misc/test_debug_feats.py +++ b/backends/arm/test/misc/test_debug_feats.py @@ -49,7 +49,7 @@ def _tosa_MI_pipeline(self, module: torch.nn.Module, dump_file=None): ArmTester( module, example_inputs=module.get_inputs(), - compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+MI"), + compile_spec=common.get_tosa_compile_spec("TOSA-0.80+MI"), ) .export() .to_edge() @@ -63,7 +63,7 @@ def _tosa_BI_pipeline(self, module: torch.nn.Module, dump_file=None): ArmTester( module, example_inputs=module.get_inputs(), - 
compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+BI"), + compile_spec=common.get_tosa_compile_spec("TOSA-0.80+BI"), ) .quantize() .export() @@ -111,7 +111,7 @@ def test_numerical_diff_prints(self): model, example_inputs=model.get_inputs(), compile_spec=common.get_tosa_compile_spec( - "TOSA-0.80.0+MI", + "TOSA-0.80+MI", permute_memory_to_nhwc=True, custom_path=tempfile.mkdtemp("diff_print_test"), ), @@ -138,7 +138,7 @@ def test_dump_ops_and_dtypes(): ArmTester( model, example_inputs=model.get_inputs(), - compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+BI"), + compile_spec=common.get_tosa_compile_spec("TOSA-0.80+BI"), ) .quantize() .dump_dtype_distribution() @@ -159,7 +159,7 @@ def test_dump_ops_and_dtypes_parseable(): ArmTester( model, example_inputs=model.get_inputs(), - compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+BI"), + compile_spec=common.get_tosa_compile_spec("TOSA-0.80+BI"), ) .quantize() .dump_dtype_distribution(print_table=False) @@ -187,7 +187,7 @@ def test_collate_tosa_BI_tests(self): ArmTester( model, example_inputs=model.get_inputs(), - compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+BI"), + compile_spec=common.get_tosa_compile_spec("TOSA-0.80+BI"), ) .quantize() .export() @@ -216,7 +216,7 @@ def test_dump_tosa_ops(caplog): ArmTester( model, example_inputs=model.get_inputs(), - compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+BI"), + compile_spec=common.get_tosa_compile_spec("TOSA-0.80+BI"), ) .quantize() .export() diff --git a/backends/arm/test/misc/test_dim_order_guards.py b/backends/arm/test/misc/test_dim_order_guards.py index d7406afe95..0698773e6f 100644 --- a/backends/arm/test/misc/test_dim_order_guards.py +++ b/backends/arm/test/misc/test_dim_order_guards.py @@ -34,7 +34,7 @@ def test_tosa_MI_pipeline(self): ArmTester( module, example_inputs=module.get_inputs(), - compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+MI"), + compile_spec=common.get_tosa_compile_spec("TOSA-0.80+MI"), ) .export() .to_edge() @@ 
-48,7 +48,7 @@ def test_tosa_BI_pipeline(self): ArmTester( module, example_inputs=module.get_inputs(), - compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+BI"), + compile_spec=common.get_tosa_compile_spec("TOSA-0.80+BI"), ) .quantize() .export() diff --git a/backends/arm/test/misc/test_lifted_tensor.py b/backends/arm/test/misc/test_lifted_tensor.py index 12b8d0665b..a16b1e639b 100644 --- a/backends/arm/test/misc/test_lifted_tensor.py +++ b/backends/arm/test/misc/test_lifted_tensor.py @@ -60,7 +60,7 @@ def test_partition_lifted_tensor_tosa_MI(self, op, data): ArmTester( LiftedTensor(op), example_inputs=data, - compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+MI"), + compile_spec=common.get_tosa_compile_spec("TOSA-0.80+MI"), ) .export() .to_edge() @@ -77,7 +77,7 @@ def test_partition_lifted_tensor_tosa_BI(self, op, data): ArmTester( LiftedTensor(op), example_inputs=data, - compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+BI"), + compile_spec=common.get_tosa_compile_spec("TOSA-0.80+BI"), ) .quantize() .export() @@ -95,7 +95,7 @@ def test_partition_lifted_scalar_tensor_tosa_MI(self, op, data, arg1): ArmTester( LiftedScalarTensor(op, arg1), example_inputs=(data), - compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+MI"), + compile_spec=common.get_tosa_compile_spec("TOSA-0.80+MI"), ) .export() .to_edge() @@ -110,7 +110,7 @@ def test_partition_lifted_scalar_tensor_tosa_BI(self, op, data, arg1): ArmTester( LiftedScalarTensor(op, arg1), example_inputs=(data), - compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+BI"), + compile_spec=common.get_tosa_compile_spec("TOSA-0.80+BI"), ) .quantize() .export() diff --git a/backends/arm/test/misc/test_tosa_spec.py b/backends/arm/test/misc/test_tosa_spec.py index 5cbad140b7..77b10cf315 100644 --- a/backends/arm/test/misc/test_tosa_spec.py +++ b/backends/arm/test/misc/test_tosa_spec.py @@ -16,9 +16,9 @@ from parameterized import parameterized test_valid_0_80_strings = [ - "TOSA-0.80.0+BI", - "TOSA-0.80.0+MI+8k", 
- "TOSA-0.80.0+BI+u55", + "TOSA-0.80+BI", + "TOSA-0.80+MI+8k", + "TOSA-0.80+BI+u55", ] test_valid_1_00_strings = [ "TOSA-1.00.0+INT+FP+fft", @@ -35,11 +35,11 @@ } test_invalid_strings = [ - "TOSA-0.80.0+bi", - "TOSA-0.80.0", - "TOSA-0.80.0+8k", - "TOSA-0.80.0+BI+MI", - "TOSA-0.80.0+BI+U55", + "TOSA-0.80+bi", + "TOSA-0.80", + "TOSA-0.80+8k", + "TOSA-0.80+BI+MI", + "TOSA-0.80+BI+U55", "TOSA-1.00.0+fft", "TOSA-1.00.0+fp+bf16+fft", "TOSA-1.00.0+INT+INT4+cf", @@ -50,13 +50,13 @@ ] test_compile_specs = [ - ([CompileSpec("tosa_version", "TOSA-0.80.0+BI".encode())],), - ([CompileSpec("tosa_version", "TOSA-0.80.0+BI+u55".encode())],), + ([CompileSpec("tosa_version", "TOSA-0.80+BI".encode())],), + ([CompileSpec("tosa_version", "TOSA-0.80+BI+u55".encode())],), ([CompileSpec("tosa_version", "TOSA-1.00.0+INT".encode())],), ] test_compile_specs_no_version = [ - ([CompileSpec("other_key", "TOSA-0.80.0+BI".encode())],), + ([CompileSpec("other_key", "TOSA-0.80+BI".encode())],), ([CompileSpec("other_key", "some_value".encode())],), ] diff --git a/backends/arm/test/models/test_mobilenet_v2_arm.py b/backends/arm/test/models/test_mobilenet_v2_arm.py index 5d37858d2b..fca743a6fa 100644 --- a/backends/arm/test/models/test_mobilenet_v2_arm.py +++ b/backends/arm/test/models/test_mobilenet_v2_arm.py @@ -57,7 +57,7 @@ def test_mv2_tosa_MI(self): self.mv2, example_inputs=self.model_inputs, compile_spec=common.get_tosa_compile_spec( - "TOSA-0.80.0+MI", permute_memory_to_nhwc=True + "TOSA-0.80+MI", permute_memory_to_nhwc=True ), ) .export() @@ -72,7 +72,7 @@ def test_mv2_tosa_BI(self): self.mv2, example_inputs=self.model_inputs, compile_spec=common.get_tosa_compile_spec( - "TOSA-0.80.0+BI", permute_memory_to_nhwc=True + "TOSA-0.80+BI", permute_memory_to_nhwc=True ), ) .quantize() diff --git a/backends/arm/test/ops/test_add.py b/backends/arm/test/ops/test_add.py index 4a89eee454..24faace007 100644 --- a/backends/arm/test/ops/test_add.py +++ b/backends/arm/test/ops/test_add.py @@ -62,7 +62,7 @@ 
def _test_add_tosa_MI_pipeline( ArmTester( module, example_inputs=test_data, - compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+MI"), + compile_spec=common.get_tosa_compile_spec("TOSA-0.80+MI"), ) .export() .check_count({"torch.ops.aten.add.Tensor": 1}) @@ -81,7 +81,7 @@ def _test_add_tosa_BI_pipeline( ArmTester( module, example_inputs=test_data, - compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+BI"), + compile_spec=common.get_tosa_compile_spec("TOSA-0.80+BI"), ) .quantize() .export() diff --git a/backends/arm/test/ops/test_avg_pool.py b/backends/arm/test/ops/test_avg_pool.py index f952558a63..27629701c3 100644 --- a/backends/arm/test/ops/test_avg_pool.py +++ b/backends/arm/test/ops/test_avg_pool.py @@ -58,7 +58,7 @@ def _test_avgpool2d_tosa_MI_pipeline( module, example_inputs=test_data, compile_spec=common.get_tosa_compile_spec( - "TOSA-0.80.0+MI", permute_memory_to_nhwc=True + "TOSA-0.80+MI", permute_memory_to_nhwc=True ), ) .export() @@ -81,7 +81,7 @@ def _test_avgpool2d_tosa_BI_pipeline( module, example_inputs=test_data, compile_spec=common.get_tosa_compile_spec( - "TOSA-0.80.0+BI", permute_memory_to_nhwc=True + "TOSA-0.80+BI", permute_memory_to_nhwc=True ), ) .quantize(Quantize(quantizer, get_symmetric_quantization_config())) diff --git a/backends/arm/test/ops/test_batch_norm.py b/backends/arm/test/ops/test_batch_norm.py index 297ac0af1c..17fe1bbea9 100644 --- a/backends/arm/test/ops/test_batch_norm.py +++ b/backends/arm/test/ops/test_batch_norm.py @@ -533,7 +533,7 @@ def _test_batchnorm2d_tosa_MI_pipeline( ArmTester( module, example_inputs=test_data, - compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+MI"), + compile_spec=common.get_tosa_compile_spec("TOSA-0.80+MI"), ) .export() .check_not(["torch.ops.quantized_decomposed"]) @@ -561,7 +561,7 @@ def _test_batchnorm2d_no_stats_tosa_MI_pipeline( ArmTester( module, example_example_inputs=test_data, - compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+MI"), + 
compile_spec=common.get_tosa_compile_spec("TOSA-0.80+MI"), ) .export() .check_count({"torch.ops.aten._native_batch_norm_legit.no_stats": 1}) @@ -590,7 +590,7 @@ def _test_batchnorm2d_tosa_BI_pipeline( ArmTester( module, example_inputs=test_data, - compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+BI"), + compile_spec=common.get_tosa_compile_spec("TOSA-0.80+BI"), ) .quantize() .export() diff --git a/backends/arm/test/ops/test_bmm.py b/backends/arm/test/ops/test_bmm.py index e9c89c1f2a..3a3c2772bd 100644 --- a/backends/arm/test/ops/test_bmm.py +++ b/backends/arm/test/ops/test_bmm.py @@ -58,7 +58,7 @@ def _test_bmm_tosa_MI_pipeline( ArmTester( module, example_inputs=test_data, - compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+MI"), + compile_spec=common.get_tosa_compile_spec("TOSA-0.80+MI"), ) .export() .check_not(["torch.ops.quantized_decomposed"]) @@ -78,7 +78,7 @@ def _test_bmm_tosa_BI_pipeline( ArmTester( module, example_inputs=test_data, - compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+BI"), + compile_spec=common.get_tosa_compile_spec("TOSA-0.80+BI"), ) .quantize() .export() diff --git a/backends/arm/test/ops/test_cat.py b/backends/arm/test/ops/test_cat.py index c0171f572e..115b4402f5 100644 --- a/backends/arm/test/ops/test_cat.py +++ b/backends/arm/test/ops/test_cat.py @@ -57,7 +57,7 @@ def _test_cat_tosa_MI_pipeline( ArmTester( module, example_inputs=test_data, - compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+MI"), + compile_spec=common.get_tosa_compile_spec("TOSA-0.80+MI"), ) .export() .check_count({"torch.ops.aten.cat.default": 1}) @@ -77,7 +77,7 @@ def _test_cat_tosa_BI_pipeline( ArmTester( module, example_inputs=test_data, - compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+BI"), + compile_spec=common.get_tosa_compile_spec("TOSA-0.80+BI"), ) .quantize() .export() diff --git a/backends/arm/test/ops/test_clone.py b/backends/arm/test/ops/test_clone.py index 28fee664fb..2ec2f621fa 100644 --- a/backends/arm/test/ops/test_clone.py 
+++ b/backends/arm/test/ops/test_clone.py @@ -49,7 +49,7 @@ def _test_clone_tosa_MI_pipeline( ArmTester( module, example_inputs=test_data, - compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+MI"), + compile_spec=common.get_tosa_compile_spec("TOSA-0.80+MI"), ) .export() .check_count({"torch.ops.aten.clone.default": 1}) @@ -68,7 +68,7 @@ def _test_clone_tosa_BI_pipeline( ArmTester( module, example_inputs=test_data, - compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+BI"), + compile_spec=common.get_tosa_compile_spec("TOSA-0.80+BI"), ) .quantize(Quantize(quantizer, get_symmetric_quantization_config())) .export() diff --git a/backends/arm/test/ops/test_conv1d.py b/backends/arm/test/ops/test_conv1d.py index 8a0b69299e..593260ac56 100644 --- a/backends/arm/test/ops/test_conv1d.py +++ b/backends/arm/test/ops/test_conv1d.py @@ -228,7 +228,7 @@ def _test_conv1d_tosa_MI_pipeline( module, example_inputs=test_data, compile_spec=common.get_tosa_compile_spec( - "TOSA-0.80.0+MI", permute_memory_to_nhwc=True + "TOSA-0.80+MI", permute_memory_to_nhwc=True ), ) .export() @@ -250,7 +250,7 @@ def _test_conv1d_tosa_BI_pipeline( module, example_inputs=test_data, compile_spec=common.get_tosa_compile_spec( - "TOSA-0.80.0+BI", permute_memory_to_nhwc=True + "TOSA-0.80+BI", permute_memory_to_nhwc=True ), ) .quantize() diff --git a/backends/arm/test/ops/test_conv2d.py b/backends/arm/test/ops/test_conv2d.py index fd9572ec0b..9ccac53940 100644 --- a/backends/arm/test/ops/test_conv2d.py +++ b/backends/arm/test/ops/test_conv2d.py @@ -255,7 +255,7 @@ def _test_conv2d_tosa_MI_pipeline( module, example_inputs=test_data, compile_spec=common.get_tosa_compile_spec( - "TOSA-0.80.0+MI", permute_memory_to_nhwc=True + "TOSA-0.80+MI", permute_memory_to_nhwc=True ), ) .export() @@ -277,7 +277,7 @@ def _test_conv2d_tosa_BI_pipeline( module, example_inputs=test_data, compile_spec=common.get_tosa_compile_spec( - "TOSA-0.80.0+BI", permute_memory_to_nhwc=True + "TOSA-0.80+BI", permute_memory_to_nhwc=True 
), ) .quantize() diff --git a/backends/arm/test/ops/test_conv_combos.py b/backends/arm/test/ops/test_conv_combos.py index 252695b911..17083b129f 100644 --- a/backends/arm/test/ops/test_conv_combos.py +++ b/backends/arm/test/ops/test_conv_combos.py @@ -193,7 +193,7 @@ def _test_conv_combo_tosa_MI_pipeline( module, example_inputs=test_data, compile_spec=common.get_tosa_compile_spec( - "TOSA-0.80.0+MI", permute_memory_to_nhwc=True + "TOSA-0.80+MI", permute_memory_to_nhwc=True ), ) .export() @@ -217,7 +217,7 @@ def _test_conv_combo_tosa_BI_pipeline( module, example_inputs=test_data, compile_spec=common.get_tosa_compile_spec( - "TOSA-0.80.0+BI", permute_memory_to_nhwc=True + "TOSA-0.80+BI", permute_memory_to_nhwc=True ), ) .quantize() diff --git a/backends/arm/test/ops/test_depthwise_conv.py b/backends/arm/test/ops/test_depthwise_conv.py index 3034798c1e..3ce7584086 100644 --- a/backends/arm/test/ops/test_depthwise_conv.py +++ b/backends/arm/test/ops/test_depthwise_conv.py @@ -191,7 +191,7 @@ def _test_dw_conv_tosa_MI_pipeline( module, example_inputs=test_data, compile_spec=common.get_tosa_compile_spec( - "TOSA-0.80.0+MI", permute_memory_to_nhwc=True + "TOSA-0.80+MI", permute_memory_to_nhwc=True ), ) .export() @@ -211,7 +211,7 @@ def _test_dw_conv_tosa_BI_pipeline( module, example_inputs=test_data, compile_spec=common.get_tosa_compile_spec( - "TOSA-0.80.0+BI", permute_memory_to_nhwc=True + "TOSA-0.80+BI", permute_memory_to_nhwc=True ), ) .quantize() diff --git a/backends/arm/test/ops/test_div.py b/backends/arm/test/ops/test_div.py index 8800c17a3e..d5f6174469 100644 --- a/backends/arm/test/ops/test_div.py +++ b/backends/arm/test/ops/test_div.py @@ -104,7 +104,7 @@ def _test_div_tosa_MI_pipeline( ArmTester( module, example_inputs=test_data, - compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+MI"), + compile_spec=common.get_tosa_compile_spec("TOSA-0.80+MI"), ) .export() .check_count({"torch.ops.aten.div.Tensor": 1}) @@ -123,7 +123,7 @@ def _test_div_tosa_BI_pipeline( 
ArmTester( module, example_inputs=test_data, - compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+BI"), + compile_spec=common.get_tosa_compile_spec("TOSA-0.80+BI"), ) .quantize() .export() diff --git a/backends/arm/test/ops/test_exp.py b/backends/arm/test/ops/test_exp.py index 7c49ad1aaf..3fa9f8c99f 100644 --- a/backends/arm/test/ops/test_exp.py +++ b/backends/arm/test/ops/test_exp.py @@ -42,7 +42,7 @@ def _test_exp_tosa_MI_pipeline( ArmTester( module, example_inputs=test_data, - compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+MI"), + compile_spec=common.get_tosa_compile_spec("TOSA-0.80+MI"), ) .export() .check(["torch.ops.aten.exp.default"]) @@ -60,7 +60,7 @@ def _test_exp_tosa_BI_pipeline(self, module: torch.nn.Module, test_data: Tuple): ArmTester( module, example_inputs=test_data, - compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+BI"), + compile_spec=common.get_tosa_compile_spec("TOSA-0.80+BI"), ) .quantize() .export() diff --git a/backends/arm/test/ops/test_expand.py b/backends/arm/test/ops/test_expand.py index 6bf9c920d5..084eec138f 100644 --- a/backends/arm/test/ops/test_expand.py +++ b/backends/arm/test/ops/test_expand.py @@ -50,7 +50,7 @@ def _test_expand_tosa_MI_pipeline(self, module: torch.nn.Module, test_data: Tupl ArmTester( module, example_inputs=test_data, - compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+MI"), + compile_spec=common.get_tosa_compile_spec("TOSA-0.80+MI"), ) .export() .check_count({"torch.ops.aten.expand.default": 1}) @@ -68,7 +68,7 @@ def _test_expand_tosa_BI_pipeline(self, module: torch.nn.Module, test_data: Tupl ArmTester( module, example_inputs=test_data, - compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+BI"), + compile_spec=common.get_tosa_compile_spec("TOSA-0.80+BI"), ) .quantize(Quantize(quantizer, get_symmetric_quantization_config())) .export() diff --git a/backends/arm/test/ops/test_full.py b/backends/arm/test/ops/test_full.py index 08f4f4f84f..1b6d6e6ae3 100644 --- 
a/backends/arm/test/ops/test_full.py +++ b/backends/arm/test/ops/test_full.py @@ -60,7 +60,7 @@ def _test_full_tosa_MI_pipeline( ArmTester( module, example_inputs=example_data, - compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+MI"), + compile_spec=common.get_tosa_compile_spec("TOSA-0.80+MI"), ) .export() .check_count({"torch.ops.aten.full.default": 1}) @@ -83,7 +83,7 @@ def _test_full_tosa_BI_pipeline( module, example_inputs=test_data, compile_spec=common.get_tosa_compile_spec( - "TOSA-0.80.0+BI", permute_memory_to_nhwc=permute_memory_to_nhwc + "TOSA-0.80+BI", permute_memory_to_nhwc=permute_memory_to_nhwc ), ) .quantize() diff --git a/backends/arm/test/ops/test_hardtanh.py b/backends/arm/test/ops/test_hardtanh.py index 49cc1f8d57..7125920c8c 100644 --- a/backends/arm/test/ops/test_hardtanh.py +++ b/backends/arm/test/ops/test_hardtanh.py @@ -55,7 +55,7 @@ def _test_hardtanh_tosa_MI_pipeline( ArmTester( module, example_inputs=test_data, - compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+MI"), + compile_spec=common.get_tosa_compile_spec("TOSA-0.80+MI"), ) .export() .check(["torch.ops.aten.hardtanh.default"]) @@ -76,7 +76,7 @@ def _test_hardtanh_tosa_BI_pipeline( ArmTester( module, example_inputs=test_data, - compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+BI"), + compile_spec=common.get_tosa_compile_spec("TOSA-0.80+BI"), ) .quantize(Quantize(quantizer, get_symmetric_quantization_config())) .export() diff --git a/backends/arm/test/ops/test_layer_norm.py b/backends/arm/test/ops/test_layer_norm.py index 2dbfcf8a75..2d88421fb5 100644 --- a/backends/arm/test/ops/test_layer_norm.py +++ b/backends/arm/test/ops/test_layer_norm.py @@ -78,7 +78,7 @@ def _test_layernorm_tosa_MI_pipeline( model=module, example_inputs=test_data, compile_spec=common.get_tosa_compile_spec( - "TOSA-0.80.0+MI", permute_memory_to_nhwc=True + "TOSA-0.80+MI", permute_memory_to_nhwc=True ), ) .export() @@ -99,7 +99,7 @@ def _test_layernorm_tosa_BI_pipeline( model=module, 
example_inputs=test_data, compile_spec=common.get_tosa_compile_spec( - "TOSA-0.80.0+BI", permute_memory_to_nhwc=True + "TOSA-0.80+BI", permute_memory_to_nhwc=True ), ) .quantize() diff --git a/backends/arm/test/ops/test_linear.py b/backends/arm/test/ops/test_linear.py index 45af7f48d4..cd14b7801d 100644 --- a/backends/arm/test/ops/test_linear.py +++ b/backends/arm/test/ops/test_linear.py @@ -137,7 +137,7 @@ def _test_linear_tosa_MI_pipeline( module, example_inputs=test_data, compile_spec=common.get_tosa_compile_spec( - "TOSA-0.80.0+MI", permute_memory_to_nhwc=True + "TOSA-0.80+MI", permute_memory_to_nhwc=True ), ) .export() @@ -157,7 +157,7 @@ def _test_linear_tosa_BI_pipeline( module, example_inputs=test_data, compile_spec=common.get_tosa_compile_spec( - "TOSA-0.80.0+BI", permute_memory_to_nhwc=True + "TOSA-0.80+BI", permute_memory_to_nhwc=True ), ) .quantize() diff --git a/backends/arm/test/ops/test_log.py b/backends/arm/test/ops/test_log.py index f1d21a2c17..0226a62328 100644 --- a/backends/arm/test/ops/test_log.py +++ b/backends/arm/test/ops/test_log.py @@ -42,7 +42,7 @@ def _test_log_tosa_MI_pipeline( ArmTester( module, example_inputs=test_data, - compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+MI"), + compile_spec=common.get_tosa_compile_spec("TOSA-0.80+MI"), ) .export() .check(["torch.ops.aten.log.default"]) @@ -60,7 +60,7 @@ def _test_log_tosa_BI_pipeline(self, module: torch.nn.Module, test_data: Tuple): ArmTester( module, example_inputs=test_data, - compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+BI"), + compile_spec=common.get_tosa_compile_spec("TOSA-0.80+BI"), ) .quantize() .export() diff --git a/backends/arm/test/ops/test_logsoftmax.py b/backends/arm/test/ops/test_logsoftmax.py index 910384e0a0..69c8ee06ec 100644 --- a/backends/arm/test/ops/test_logsoftmax.py +++ b/backends/arm/test/ops/test_logsoftmax.py @@ -61,7 +61,7 @@ def _test_logsoftmax_tosa_MI_pipeline( ArmTester( module, example_inputs=test_data, - 
compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+MI"), + compile_spec=common.get_tosa_compile_spec("TOSA-0.80+MI"), ) .export() .check(["torch.ops.aten.log_softmax.int"]) @@ -81,7 +81,7 @@ def _test_logsoftmax_tosa_BI_pipeline( ArmTester( module, example_inputs=test_data, - compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+BI"), + compile_spec=common.get_tosa_compile_spec("TOSA-0.80+BI"), ) .quantize() .export() diff --git a/backends/arm/test/ops/test_max_pool.py b/backends/arm/test/ops/test_max_pool.py index 55e541b610..a693c7d549 100644 --- a/backends/arm/test/ops/test_max_pool.py +++ b/backends/arm/test/ops/test_max_pool.py @@ -65,7 +65,7 @@ def _test_maxpool2d_tosa_MI_pipeline( module, example_inputs=test_data, compile_spec=common.get_tosa_compile_spec( - "TOSA-0.80.0+MI", permute_memory_to_nhwc=True + "TOSA-0.80+MI", permute_memory_to_nhwc=True ), ) .export() @@ -92,7 +92,7 @@ def _test_maxpool2d_tosa_BI_pipeline( module, example_inputs=test_data, compile_spec=common.get_tosa_compile_spec( - "TOSA-0.80.0+BI", permute_memory_to_nhwc=True + "TOSA-0.80+BI", permute_memory_to_nhwc=True ), ) .quantize(Quantize(quantizer, get_symmetric_quantization_config())) diff --git a/backends/arm/test/ops/test_maximum.py b/backends/arm/test/ops/test_maximum.py new file mode 100644 index 0000000000..7e75064522 --- /dev/null +++ b/backends/arm/test/ops/test_maximum.py @@ -0,0 +1,137 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# Copyright 2024 Arm Limited and/or its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+ +import unittest + +from typing import Tuple + +import torch +from executorch.backends.arm.test import common +from executorch.backends.arm.test.tester.arm_tester import ArmTester +from executorch.exir import EdgeCompileConfig +from executorch.exir.backend.compile_spec_schema import CompileSpec +from parameterized import parameterized + + +class TestMaximum(unittest.TestCase): + """Tests a single maximum op""" + + class Maximum(torch.nn.Module): + test_parameters = [ + ( + torch.FloatTensor([1, 2, 3, 5, 7]), + (torch.FloatTensor([2, 1, 2, 1, 10])), + ), + (torch.ones(1, 10, 4, 6), 2 * torch.ones(1, 10, 4, 6)), + (torch.randn(1, 1, 4, 4), torch.ones(1, 1, 4, 1)), + (torch.randn(1, 3, 4, 4), torch.randn(1, 3, 4, 4)), + (10000 * torch.randn(1, 1, 4, 4), torch.randn(1, 1, 4, 1)), + ] + + def __init__(self): + super().__init__() + + def forward(self, x, y): + return torch.maximum(x, y) + + _edge_compile_config: EdgeCompileConfig = EdgeCompileConfig( + _skip_dim_order=True, # TODO(T182928844): Delegate dim order op to backend. 
+ ) + + def _test_maximum_tosa_MI_pipeline( + self, module: torch.nn.Module, test_data: Tuple[torch.Tensor] + ): + ( + ArmTester( + module, + example_inputs=test_data, + compile_spec=common.get_tosa_compile_spec("TOSA-0.80+MI"), + ) + .export() + .check_count({"torch.ops.aten.maximum.default": 1}) + .check_not(["torch.ops.quantized_decomposed"]) + .to_edge(config=self._edge_compile_config) + .partition() + .check_count({"torch.ops.higher_order.executorch_call_delegate": 1}) + .to_executorch() + .run_method_and_compare_outputs(inputs=test_data) + ) + + def _test_maximum_tosa_BI_pipeline( + self, module: torch.nn.Module, test_data: Tuple[torch.Tensor] + ): + ( + ArmTester( + module, + example_inputs=test_data, + compile_spec=common.get_tosa_compile_spec("TOSA-0.80+BI"), + ) + .quantize() + .export() + .check_count({"torch.ops.aten.maximum.default": 1}) + .check(["torch.ops.quantized_decomposed"]) + .to_edge(config=self._edge_compile_config) + .partition() + .check_count({"torch.ops.higher_order.executorch_call_delegate": 1}) + .to_executorch() + .run_method_and_compare_outputs(inputs=test_data, qtol=1) + ) + + def _test_maximum_ethos_BI_pipeline( + self, + module: torch.nn.Module, + compile_spec: CompileSpec, + test_data: Tuple[torch.Tensor], + ): + tester = ( + ArmTester( + module, + example_inputs=test_data, + compile_spec=compile_spec, + ) + .quantize() + .export() + .to_edge() + .partition() + .to_executorch() + .serialize() + ) + + return tester + + @parameterized.expand(Maximum.test_parameters) + def test_maximum_tosa_MI(self, operand1: torch.Tensor, operand2: torch.Tensor): + test_data = (operand1, operand2) + self._test_maximum_tosa_MI_pipeline(self.Maximum(), test_data) + + @parameterized.expand(Maximum.test_parameters) + def test_maximum_tosa_BI(self, operand1: torch.Tensor, operand2: torch.Tensor): + test_data = (operand1, operand2) + self._test_maximum_tosa_BI_pipeline(self.Maximum(), test_data) + + @parameterized.expand(Maximum.test_parameters) + 
@unittest.expectedFailure # Bug in Vela, disabled until pin changes, bug MLETORCH-513 + def test_maximum_u55_BI(self, operand1: torch.Tensor, operand2: torch.Tensor): + test_data = (operand1, operand2) + tester = self._test_maximum_ethos_BI_pipeline( + self.Maximum(), common.get_u55_compile_spec(), test_data + ) + if common.is_option_enabled("corstone_fvp"): + tester.run_method_and_compare_outputs( + qtol=1, inputs=test_data, target_board="corstone-300" + ) + + @parameterized.expand(Maximum.test_parameters) + def test_maximum_u85_BI(self, operand1: torch.Tensor, operand2: torch.Tensor): + test_data = (operand1, operand2) + tester = self._test_maximum_ethos_BI_pipeline( + self.Maximum(), common.get_u85_compile_spec(), test_data + ) + if common.is_option_enabled("corstone_fvp"): + tester.run_method_and_compare_outputs( + qtol=1, inputs=test_data, target_board="corstone-320" + ) diff --git a/backends/arm/test/ops/test_mean_dim.py b/backends/arm/test/ops/test_mean_dim.py index e725eb1ef4..e4f6afcbd6 100644 --- a/backends/arm/test/ops/test_mean_dim.py +++ b/backends/arm/test/ops/test_mean_dim.py @@ -81,7 +81,7 @@ def _test_adaptive_avg_pool2d_tosa_MI_pipeline( ArmTester( module, example_inputs=test_data, - compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+MI"), + compile_spec=common.get_tosa_compile_spec("TOSA-0.80+MI"), ) .export() .check(["torch.ops.aten.adaptive_avg_pool2d.default"]) @@ -101,7 +101,7 @@ def _test_adaptive_avg_pool2d_tosa_BI_pipeline( ArmTester( module, example_inputs=test_data, - compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+BI"), + compile_spec=common.get_tosa_compile_spec("TOSA-0.80+BI"), ) .quantize() .export() @@ -150,7 +150,7 @@ def _test_meandim_tosa_MI_pipeline( ArmTester( module, example_inputs=test_data, - compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+MI"), + compile_spec=common.get_tosa_compile_spec("TOSA-0.80+MI"), ) .export() .check_not(["torch.ops.quantized_decomposed"]) @@ -169,7 +169,7 @@ def 
_test_meandim_tosa_BI_pipeline( ArmTester( module, example_inputs=test_data, - compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+BI"), + compile_spec=common.get_tosa_compile_spec("TOSA-0.80+BI"), ) .quantize() .export() diff --git a/backends/arm/test/ops/test_minimum.py b/backends/arm/test/ops/test_minimum.py new file mode 100644 index 0000000000..ddbdb24657 --- /dev/null +++ b/backends/arm/test/ops/test_minimum.py @@ -0,0 +1,137 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# Copyright 2024 Arm Limited and/or its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import unittest + +from typing import Tuple + +import torch +from executorch.backends.arm.test import common +from executorch.backends.arm.test.tester.arm_tester import ArmTester +from executorch.exir import EdgeCompileConfig +from executorch.exir.backend.compile_spec_schema import CompileSpec +from parameterized import parameterized + + +class TestMinimum(unittest.TestCase): + """Tests a single minimum op""" + + class Minimum(torch.nn.Module): + test_parameters = [ + ( + torch.FloatTensor([1, 2, 3, 5, 7]), + (torch.FloatTensor([2, 1, 2, 1, 10])), + ), + (torch.ones(1, 10, 4, 6), 2 * torch.ones(1, 10, 4, 6)), + (torch.randn(1, 1, 4, 4), torch.ones(1, 1, 4, 1)), + (torch.randn(1, 3, 4, 4), torch.randn(1, 3, 4, 4)), + (10000 * torch.randn(1, 1, 4, 4), torch.randn(1, 1, 4, 1)), + ] + + def __init__(self): + super().__init__() + + def forward(self, x, y): + return torch.minimum(x, y) + + _edge_compile_config: EdgeCompileConfig = EdgeCompileConfig( + _skip_dim_order=True, # TODO(T182928844): Delegate dim order op to backend. 
+ ) + + def _test_minimum_tosa_MI_pipeline( + self, module: torch.nn.Module, test_data: Tuple[torch.Tensor] + ): + ( + ArmTester( + module, + example_inputs=test_data, + compile_spec=common.get_tosa_compile_spec("TOSA-0.80+MI"), + ) + .export() + .check_count({"torch.ops.aten.minimum.default": 1}) + .check_not(["torch.ops.quantized_decomposed"]) + .to_edge(config=self._edge_compile_config) + .partition() + .check_count({"torch.ops.higher_order.executorch_call_delegate": 1}) + .to_executorch() + .run_method_and_compare_outputs(inputs=test_data) + ) + + def _test_minimum_tosa_BI_pipeline( + self, module: torch.nn.Module, test_data: Tuple[torch.Tensor] + ): + ( + ArmTester( + module, + example_inputs=test_data, + compile_spec=common.get_tosa_compile_spec("TOSA-0.80+BI"), + ) + .quantize() + .export() + .check_count({"torch.ops.aten.minimum.default": 1}) + .check(["torch.ops.quantized_decomposed"]) + .to_edge(config=self._edge_compile_config) + .partition() + .check_count({"torch.ops.higher_order.executorch_call_delegate": 1}) + .to_executorch() + .run_method_and_compare_outputs(inputs=test_data, qtol=1) + ) + + def _test_minimum_ethos_BI_pipeline( + self, + module: torch.nn.Module, + compile_spec: CompileSpec, + test_data: Tuple[torch.Tensor], + ): + tester = ( + ArmTester( + module, + example_inputs=test_data, + compile_spec=compile_spec, + ) + .quantize() + .export() + .to_edge() + .partition() + .to_executorch() + .serialize() + ) + + return tester + + @parameterized.expand(Minimum.test_parameters) + def test_minimum_tosa_MI(self, operand1: torch.Tensor, operand2: torch.Tensor): + test_data = (operand1, operand2) + self._test_minimum_tosa_MI_pipeline(self.Minimum(), test_data) + + @parameterized.expand(Minimum.test_parameters) + def test_minimum_tosa_BI(self, operand1: torch.Tensor, operand2: torch.Tensor): + test_data = (operand1, operand2) + self._test_minimum_tosa_BI_pipeline(self.Minimum(), test_data) + + @parameterized.expand(Minimum.test_parameters) + 
@unittest.expectedFailure # Bug in Vela, disabled until pin changes, bug MLETORCH-513 + def test_minimum_u55_BI(self, operand1: torch.Tensor, operand2: torch.Tensor): + test_data = (operand1, operand2) + tester = self._test_minimum_ethos_BI_pipeline( + self.Minimum(), common.get_u55_compile_spec(), test_data + ) + if common.is_option_enabled("corstone_fvp"): + tester.run_method_and_compare_outputs( + qtol=1, inputs=test_data, target_board="corstone-300" + ) + + @parameterized.expand(Minimum.test_parameters) + def test_minimum_u85_BI(self, operand1: torch.Tensor, operand2: torch.Tensor): + test_data = (operand1, operand2) + tester = self._test_minimum_ethos_BI_pipeline( + self.Minimum(), common.get_u85_compile_spec(), test_data + ) + if common.is_option_enabled("corstone_fvp"): + tester.run_method_and_compare_outputs( + qtol=1, inputs=test_data, target_board="corstone-320" + ) diff --git a/backends/arm/test/ops/test_mm.py b/backends/arm/test/ops/test_mm.py index 21b02bbd10..5fa28076aa 100644 --- a/backends/arm/test/ops/test_mm.py +++ b/backends/arm/test/ops/test_mm.py @@ -54,7 +54,7 @@ def _test_mm_tosa_MI_pipeline( ArmTester( module, example_inputs=test_data, - compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+MI"), + compile_spec=common.get_tosa_compile_spec("TOSA-0.80+MI"), ) .export() .check_count({"torch.ops.aten.mm.default": 1}) @@ -74,7 +74,7 @@ def _test_mm_tosa_BI_pipeline( ArmTester( module, example_inputs=test_data, - compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+BI"), + compile_spec=common.get_tosa_compile_spec("TOSA-0.80+BI"), ) .quantize() .export() diff --git a/backends/arm/test/ops/test_mul.py b/backends/arm/test/ops/test_mul.py index ed0cbbc828..9d789a8e33 100644 --- a/backends/arm/test/ops/test_mul.py +++ b/backends/arm/test/ops/test_mul.py @@ -73,7 +73,7 @@ def _test_mul_tosa_MI_pipeline( module, example_inputs=test_data, compile_spec=common.get_tosa_compile_spec( - "TOSA-0.80.0+MI", permute_memory_to_nhwc=True + "TOSA-0.80+MI", 
permute_memory_to_nhwc=True ), ) .export() @@ -94,7 +94,7 @@ def _test_mul_tosa_BI_pipeline( module, example_inputs=test_data, compile_spec=common.get_tosa_compile_spec( - "TOSA-0.80.0+BI", permute_memory_to_nhwc=True + "TOSA-0.80+BI", permute_memory_to_nhwc=True ), ) .quantize() diff --git a/backends/arm/test/ops/test_permute.py b/backends/arm/test/ops/test_permute.py index cadfe8cd85..b373af1401 100644 --- a/backends/arm/test/ops/test_permute.py +++ b/backends/arm/test/ops/test_permute.py @@ -60,7 +60,7 @@ def _test_permute_tosa_MI_pipeline( module, example_inputs=test_data, compile_spec=common.get_tosa_compile_spec( - "TOSA-0.80.0+MI", permute_memory_to_nhwc=permute_memory_to_nhwc + "TOSA-0.80+MI", permute_memory_to_nhwc=permute_memory_to_nhwc ), ) .export() @@ -82,7 +82,7 @@ def _test_permute_tosa_BI_pipeline( ArmTester( module, example_inputs=test_data, - compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+BI"), + compile_spec=common.get_tosa_compile_spec("TOSA-0.80+BI"), ) .quantize(Quantize(quantizer, get_symmetric_quantization_config())) .export() diff --git a/backends/arm/test/ops/test_reciprocal.py b/backends/arm/test/ops/test_reciprocal.py index 2da2dffb5b..b3233d02a9 100644 --- a/backends/arm/test/ops/test_reciprocal.py +++ b/backends/arm/test/ops/test_reciprocal.py @@ -48,7 +48,7 @@ def _test_reciprocal_tosa_MI_pipeline( ArmTester( module, example_inputs=test_data, - compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+MI"), + compile_spec=common.get_tosa_compile_spec("TOSA-0.80+MI"), ) .export() .check_count({"torch.ops.aten.reciprocal.default": 1}) @@ -67,7 +67,7 @@ def _test_reciprocal_tosa_BI_pipeline( ArmTester( module, example_inputs=test_data, - compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+BI"), + compile_spec=common.get_tosa_compile_spec("TOSA-0.80+BI"), ) .quantize() .export() diff --git a/backends/arm/test/ops/test_relu.py b/backends/arm/test/ops/test_relu.py index 595c907b32..5a7bd4f5ec 100644 --- 
a/backends/arm/test/ops/test_relu.py +++ b/backends/arm/test/ops/test_relu.py @@ -48,7 +48,7 @@ def _test_relu_tosa_MI_pipeline( ArmTester( module, example_inputs=test_data, - compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+MI"), + compile_spec=common.get_tosa_compile_spec("TOSA-0.80+MI"), ) .export() .check(["torch.ops.aten.relu.default"]) @@ -69,7 +69,7 @@ def _test_relu_tosa_BI_pipeline( ArmTester( module, example_inputs=test_data, - compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+BI"), + compile_spec=common.get_tosa_compile_spec("TOSA-0.80+BI"), ) .quantize(Quantize(quantizer, get_symmetric_quantization_config())) .export() diff --git a/backends/arm/test/ops/test_repeat.py b/backends/arm/test/ops/test_repeat.py index de555e7c80..f43f7af13c 100644 --- a/backends/arm/test/ops/test_repeat.py +++ b/backends/arm/test/ops/test_repeat.py @@ -48,7 +48,7 @@ def _test_repeat_tosa_MI_pipeline(self, module: torch.nn.Module, test_data: Tupl ArmTester( module, example_inputs=test_data, - compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+MI"), + compile_spec=common.get_tosa_compile_spec("TOSA-0.80+MI"), ) .export() .check_count({"torch.ops.aten.repeat.default": 1}) @@ -66,7 +66,7 @@ def _test_repeat_tosa_BI_pipeline(self, module: torch.nn.Module, test_data: Tupl ArmTester( module, example_inputs=test_data, - compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+BI"), + compile_spec=common.get_tosa_compile_spec("TOSA-0.80+BI"), ) .quantize(Quantize(quantizer, get_symmetric_quantization_config())) .export() diff --git a/backends/arm/test/ops/test_rshift.py b/backends/arm/test/ops/test_rshift.py index dfbd0fdb3e..4c13beb7c4 100644 --- a/backends/arm/test/ops/test_rshift.py +++ b/backends/arm/test/ops/test_rshift.py @@ -33,7 +33,7 @@ def _test_rshift_tosa_MI(self, test_data): ArmTester( self.Rshift(), example_inputs=test_data, - compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+MI"), + compile_spec=common.get_tosa_compile_spec("TOSA-0.80+MI"), ) .export() 
.to_edge_transform_and_lower() @@ -46,7 +46,7 @@ def _test_rshift_tosa_BI(self, test_data): ArmTester( self.Rshift(), example_inputs=test_data, - compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+BI"), + compile_spec=common.get_tosa_compile_spec("TOSA-0.80+BI"), ) .quantize() .export() diff --git a/backends/arm/test/ops/test_rsqrt.py b/backends/arm/test/ops/test_rsqrt.py index 2cddc8da26..2bf5fc371c 100644 --- a/backends/arm/test/ops/test_rsqrt.py +++ b/backends/arm/test/ops/test_rsqrt.py @@ -35,7 +35,7 @@ def _test_rsqrt_tosa_MI_pipeline( ArmTester( module, example_inputs=test_data, - compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+MI"), + compile_spec=common.get_tosa_compile_spec("TOSA-0.80+MI"), ) .export() .check_count({"torch.ops.aten.rsqrt.default": 1}) @@ -53,7 +53,7 @@ def _test_rsqrt_tosa_BI_pipeline( ArmTester( module, example_inputs=test_data, - compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+BI"), + compile_spec=common.get_tosa_compile_spec("TOSA-0.80+BI"), ) .quantize() .export() diff --git a/backends/arm/test/ops/test_scalars.py b/backends/arm/test/ops/test_scalars.py index 60c32a4557..f03d8f72d1 100644 --- a/backends/arm/test/ops/test_scalars.py +++ b/backends/arm/test/ops/test_scalars.py @@ -129,7 +129,7 @@ def _test_add_tosa_MI_pipeline(self, module: torch.nn.Module, test_data: tuple): ArmTester( module, example_inputs=test_data, - compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+MI"), + compile_spec=common.get_tosa_compile_spec("TOSA-0.80+MI"), ) .export() .to_edge() @@ -143,7 +143,7 @@ def _test_add_tosa_BI_pipeline(self, module: torch.nn.Module, test_data: tuple): ArmTester( module, example_inputs=test_data, - compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+BI"), + compile_spec=common.get_tosa_compile_spec("TOSA-0.80+BI"), ) .quantize() .export() diff --git a/backends/arm/test/ops/test_select.py b/backends/arm/test/ops/test_select.py index c7194833cc..f44e61c64f 100644 --- a/backends/arm/test/ops/test_select.py 
+++ b/backends/arm/test/ops/test_select.py @@ -58,7 +58,7 @@ def _test_select_tosa_MI_pipeline( module, example_inputs=test_data, compile_spec=common.get_tosa_compile_spec( - "TOSA-0.80.0+MI", permute_memory_to_nhwc=permute + "TOSA-0.80+MI", permute_memory_to_nhwc=permute ), ) .export() @@ -84,7 +84,7 @@ def _test_select_tosa_BI_pipeline( module, example_inputs=test_data, compile_spec=common.get_tosa_compile_spec( - "TOSA-0.80.0+BI", permute_memory_to_nhwc=permute + "TOSA-0.80+BI", permute_memory_to_nhwc=permute ), ) .quantize() diff --git a/backends/arm/test/ops/test_sigmoid.py b/backends/arm/test/ops/test_sigmoid.py index f12658c985..a5c6c86c52 100644 --- a/backends/arm/test/ops/test_sigmoid.py +++ b/backends/arm/test/ops/test_sigmoid.py @@ -71,7 +71,7 @@ def _test_sigmoid_tosa_MI_pipeline( ArmTester( module, example_inputs=test_data, - compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+MI"), + compile_spec=common.get_tosa_compile_spec("TOSA-0.80+MI"), ) .export() .check(["torch.ops.aten.sigmoid.default"]) @@ -89,7 +89,7 @@ def _test_sigmoid_tosa_BI_pipeline(self, module: torch.nn.Module, test_data: Tup ArmTester( module, example_inputs=test_data, - compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+BI"), + compile_spec=common.get_tosa_compile_spec("TOSA-0.80+BI"), ) .quantize() .export() diff --git a/backends/arm/test/ops/test_slice.py b/backends/arm/test/ops/test_slice.py index 0fc92b011a..511873a8c2 100644 --- a/backends/arm/test/ops/test_slice.py +++ b/backends/arm/test/ops/test_slice.py @@ -39,7 +39,7 @@ def _test_slice_tosa_MI_pipeline( ArmTester( module, example_inputs=test_data, - compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+MI"), + compile_spec=common.get_tosa_compile_spec("TOSA-0.80+MI"), ) .export() .check(["torch.ops.aten.slice.Tensor"]) @@ -60,7 +60,7 @@ def _test_slice_tosa_BI_pipeline( module, example_inputs=test_data, compile_spec=common.get_tosa_compile_spec( - "TOSA-0.80.0+BI", permute_memory_to_nhwc=permute + "TOSA-0.80+BI", 
permute_memory_to_nhwc=permute ), ) .quantize() diff --git a/backends/arm/test/ops/test_softmax.py b/backends/arm/test/ops/test_softmax.py index 30215b47f3..fd78d1a9ac 100644 --- a/backends/arm/test/ops/test_softmax.py +++ b/backends/arm/test/ops/test_softmax.py @@ -63,7 +63,7 @@ def _test_softmax_tosa_MI_pipeline( ArmTester( module, example_inputs=test_data, - compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+MI"), + compile_spec=common.get_tosa_compile_spec("TOSA-0.80+MI"), ) .export() .check(["torch.ops.aten.softmax.int"]) @@ -83,7 +83,7 @@ def _test_softmax_tosa_BI_pipeline( ArmTester( module, example_inputs=test_data, - compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+BI"), + compile_spec=common.get_tosa_compile_spec("TOSA-0.80+BI"), ) .quantize() .export() diff --git a/backends/arm/test/ops/test_split.py b/backends/arm/test/ops/test_split.py index 42395c4c2d..a1ba53c881 100644 --- a/backends/arm/test/ops/test_split.py +++ b/backends/arm/test/ops/test_split.py @@ -56,7 +56,7 @@ def _test_split_tosa_MI_pipeline( ArmTester( module, example_inputs=test_data, - compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+MI"), + compile_spec=common.get_tosa_compile_spec("TOSA-0.80+MI"), ) .export() .to_edge() @@ -79,7 +79,7 @@ def _test_split_tosa_BI_pipeline( ArmTester( module, example_inputs=test_data, - compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+BI"), + compile_spec=common.get_tosa_compile_spec("TOSA-0.80+BI"), ) .quantize() .export() diff --git a/backends/arm/test/ops/test_squeeze.py b/backends/arm/test/ops/test_squeeze.py index 7e915da645..ac26fd73fa 100644 --- a/backends/arm/test/ops/test_squeeze.py +++ b/backends/arm/test/ops/test_squeeze.py @@ -61,7 +61,7 @@ def _test_squeeze_tosa_MI_pipeline( ArmTester( module, example_inputs=test_data, - compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+MI"), + compile_spec=common.get_tosa_compile_spec("TOSA-0.80+MI"), ) .export() .check_count({export_target: 1}) @@ -82,7 +82,7 @@ def 
_test_squeeze_tosa_BI_pipeline( ArmTester( module, example_inputs=test_data, - compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+BI"), + compile_spec=common.get_tosa_compile_spec("TOSA-0.80+BI"), ) .quantize() .export() diff --git a/backends/arm/test/ops/test_sub.py b/backends/arm/test/ops/test_sub.py index 2b1ecd6436..0812f8a47a 100644 --- a/backends/arm/test/ops/test_sub.py +++ b/backends/arm/test/ops/test_sub.py @@ -44,7 +44,7 @@ def _test_sub_tosa_MI_pipeline( ArmTester( module, example_inputs=test_data, - compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+MI"), + compile_spec=common.get_tosa_compile_spec("TOSA-0.80+MI"), ) .export() .check_count({"torch.ops.aten.sub.Tensor": 1}) @@ -64,7 +64,7 @@ def _test_sub_tosa_BI_pipeline( ArmTester( module, example_inputs=test_data, - compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+BI"), + compile_spec=common.get_tosa_compile_spec("TOSA-0.80+BI"), ) .quantize() .export() diff --git a/backends/arm/test/ops/test_sum.py b/backends/arm/test/ops/test_sum.py index 111517afbb..098e0fd1bc 100644 --- a/backends/arm/test/ops/test_sum.py +++ b/backends/arm/test/ops/test_sum.py @@ -61,7 +61,7 @@ def _test_sum_tosa_MI_pipeline( ArmTester( module, example_inputs=test_data, - compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+MI"), + compile_spec=common.get_tosa_compile_spec("TOSA-0.80+MI"), ) .export() .check_count({"torch.ops.aten.sum.dim_IntList": 1}) @@ -80,7 +80,7 @@ def _test_sum_tosa_BI_pipeline( ArmTester( module, example_inputs=test_data, - compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+BI"), + compile_spec=common.get_tosa_compile_spec("TOSA-0.80+BI"), ) .quantize() .export() diff --git a/backends/arm/test/ops/test_tanh.py b/backends/arm/test/ops/test_tanh.py index 5f3859eadd..060d7933ea 100644 --- a/backends/arm/test/ops/test_tanh.py +++ b/backends/arm/test/ops/test_tanh.py @@ -44,7 +44,7 @@ def _test_tanh_tosa_MI_pipeline( ArmTester( module, example_inputs=test_data, - 
compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+MI"), + compile_spec=common.get_tosa_compile_spec("TOSA-0.80+MI"), ) .export() .check(["torch.ops.aten.tanh.default"]) @@ -62,7 +62,7 @@ def _test_tanh_tosa_BI_pipeline(self, module: torch.nn.Module, test_data: Tuple) ArmTester( module, example_inputs=test_data, - compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+BI"), + compile_spec=common.get_tosa_compile_spec("TOSA-0.80+BI"), ) .quantize() .export() diff --git a/backends/arm/test/ops/test_to_copy.py b/backends/arm/test/ops/test_to_copy.py index a9d827f903..6992ac2f8e 100644 --- a/backends/arm/test/ops/test_to_copy.py +++ b/backends/arm/test/ops/test_to_copy.py @@ -52,7 +52,7 @@ def _test_to_copy_tosa_MI_pipeline( ArmTester( module, example_inputs=test_data, - compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+MI"), + compile_spec=common.get_tosa_compile_spec("TOSA-0.80+MI"), ) .export() .dump_artifact() diff --git a/backends/arm/test/ops/test_unsqueeze.py b/backends/arm/test/ops/test_unsqueeze.py index 8936d55f8b..a6faf70af0 100644 --- a/backends/arm/test/ops/test_unsqueeze.py +++ b/backends/arm/test/ops/test_unsqueeze.py @@ -35,7 +35,7 @@ def _test_unsqueeze_tosa_MI_pipeline( ArmTester( module, example_inputs=test_data, - compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+MI"), + compile_spec=common.get_tosa_compile_spec("TOSA-0.80+MI"), ) .export() .check_count({"torch.ops.aten.unsqueeze.default": 1}) @@ -53,7 +53,7 @@ def _test_unsqueeze_tosa_BI_pipeline( ArmTester( module, example_inputs=test_data, - compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+BI"), + compile_spec=common.get_tosa_compile_spec("TOSA-0.80+BI"), ) .quantize() .export() diff --git a/backends/arm/test/ops/test_upsample_nearest2d.py b/backends/arm/test/ops/test_upsample_nearest2d.py index d03ac1e441..8984d716a3 100644 --- a/backends/arm/test/ops/test_upsample_nearest2d.py +++ b/backends/arm/test/ops/test_upsample_nearest2d.py @@ -87,7 +87,7 @@ def 
_test_upsample_nearest_2d_tosa_MI_pipeline( ArmTester( module, example_inputs=test_data, - compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+MI"), + compile_spec=common.get_tosa_compile_spec("TOSA-0.80+MI"), ) .export() .check(["torch.ops.aten.upsample_nearest2d.vec"]) @@ -111,7 +111,7 @@ def _test_upsample_nearest_2d_tosa_BI_pipeline( ArmTester( module, example_inputs=test_data, - compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+BI"), + compile_spec=common.get_tosa_compile_spec("TOSA-0.80+BI"), ) .quantize() .export() diff --git a/backends/arm/test/ops/test_var.py b/backends/arm/test/ops/test_var.py index 727cd05393..322ac5b0ed 100644 --- a/backends/arm/test/ops/test_var.py +++ b/backends/arm/test/ops/test_var.py @@ -96,7 +96,7 @@ def _test_var_tosa_MI_pipeline( ArmTester( module, example_inputs=test_data, - compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+MI"), + compile_spec=common.get_tosa_compile_spec("TOSA-0.80+MI"), ) .export() .to_edge() @@ -117,7 +117,7 @@ def _test_var_tosa_BI_pipeline( ArmTester( module, example_inputs=test_data, - compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+BI"), + compile_spec=common.get_tosa_compile_spec("TOSA-0.80+BI"), ) .quantize(Quantize(quantizer, get_symmetric_quantization_config())) .export() diff --git a/backends/arm/test/ops/test_view.py b/backends/arm/test/ops/test_view.py index 07a32fe595..1603a2a37d 100644 --- a/backends/arm/test/ops/test_view.py +++ b/backends/arm/test/ops/test_view.py @@ -56,7 +56,7 @@ def _test_view_tosa_MI_pipeline( ArmTester( module, example_inputs=test_data, - compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+MI"), + compile_spec=common.get_tosa_compile_spec("TOSA-0.80+MI"), ) .export() .check_count({"torch.ops.aten.view.default": 1}) @@ -74,7 +74,7 @@ def _test_view_tosa_BI_pipeline( ArmTester( module, example_inputs=test_data, - compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+BI"), + compile_spec=common.get_tosa_compile_spec("TOSA-0.80+BI"), ) .quantize() 
.export() diff --git a/backends/arm/test/passes/test_fold_qdq_pass.py b/backends/arm/test/passes/test_fold_qdq_pass.py new file mode 100644 index 0000000000..cd7cf75139 --- /dev/null +++ b/backends/arm/test/passes/test_fold_qdq_pass.py @@ -0,0 +1,75 @@ +# Copyright 2024 Arm Limited and/or its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import unittest + +import torch +from executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass import ( + FoldAndAnnotateQParamsPass, +) + +from executorch.backends.arm.test import common +from executorch.backends.arm.test.tester.arm_tester import ArmTester + +from executorch.backends.xnnpack.test.tester.tester import RunPasses + +from executorch.exir.dialects._ops import ops as exir_ops + + +class SimpleQuantizeModel(torch.nn.Module): + def forward(self, x, y): + return x + torch.max((x + x), (y + y)) + + def get_inputs(self): + return (torch.rand(1, 1280, 7, 7), torch.rand(1, 1280, 7, 7)) + + +class FoldAndAnnotateQParamsPassTestClass(FoldAndAnnotateQParamsPass): + def __init__(self): + super(FoldAndAnnotateQParamsPassTestClass, self).__init__( + [ + exir_ops.edge.aten.add.Tensor, + exir_ops.edge.aten.maximum.default, + ] + ) + + +class TestFoldAndAnnotateQParamsPass(unittest.TestCase): + """ + Tests the FoldAndAnnotateQParamsPass which folds dq/q nodes into + the node and stores the quantization parameters in meta. + """ + + def test_fold_qdq_pass(self): + """ + Check that the pass folds the dq/q nodes into the add and maximum nodes, + leaving one dq node and two q nodes in the representation. 
+ """ + module = SimpleQuantizeModel() + test_pass_stage = RunPasses([FoldAndAnnotateQParamsPassTestClass]) + ( + ArmTester( + module, + example_inputs=module.get_inputs(), + compile_spec=common.get_tosa_compile_spec("TOSA-0.80+BI"), + ) + .quantize() + .export() + .to_edge() + .check_count( + { + "executorch_exir_dialects_edge__ops_quantized_decomposed_dequantize_per_tensor_default": 7, + "executorch_exir_dialects_edge__ops_quantized_decomposed_quantize_per_tensor_default": 6, + } + ) + .run_passes(test_pass_stage) + .check_count( + { + "executorch_exir_dialects_edge__ops_quantized_decomposed_dequantize_per_tensor_default": 1, + "executorch_exir_dialects_edge__ops_quantized_decomposed_quantize_per_tensor_default": 2, + } + ) + ) diff --git a/backends/arm/test/passes/test_meandim_to_averagepool2d.py b/backends/arm/test/passes/test_meandim_to_averagepool2d.py index 978a4c6fe5..93badc6435 100644 --- a/backends/arm/test/passes/test_meandim_to_averagepool2d.py +++ b/backends/arm/test/passes/test_meandim_to_averagepool2d.py @@ -46,7 +46,7 @@ def test_tosa_BI_meandim_to_averagepool(self): ArmTester( module, example_inputs=module.get_inputs(), - compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+BI"), + compile_spec=common.get_tosa_compile_spec("TOSA-0.80+BI"), ) .quantize() .export() @@ -63,7 +63,7 @@ def test_tosa_BI_meandim_no_modification(self): ArmTester( module, example_inputs=module.get_inputs(), - compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+BI"), + compile_spec=common.get_tosa_compile_spec("TOSA-0.80+BI"), ) .quantize() .export() diff --git a/backends/arm/test/passes/test_unsqueeze_before_repeat_pass.py b/backends/arm/test/passes/test_unsqueeze_before_repeat_pass.py index d249c18ec8..4323e7332d 100644 --- a/backends/arm/test/passes/test_unsqueeze_before_repeat_pass.py +++ b/backends/arm/test/passes/test_unsqueeze_before_repeat_pass.py @@ -37,7 +37,7 @@ def test_tosa_MI_insert_view(self): ArmTester( module, example_inputs=inputs, - 
compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+MI"), + compile_spec=common.get_tosa_compile_spec("TOSA-0.80+MI"), ) .export() .to_edge() @@ -61,7 +61,7 @@ def test_tosa_MI_dont_insert_view(self): ArmTester( module, example_inputs=inputs, - compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+MI"), + compile_spec=common.get_tosa_compile_spec("TOSA-0.80+MI"), ) .export() .to_edge() diff --git a/backends/arm/test/quantizer/test_generic_annotater.py b/backends/arm/test/quantizer/test_generic_annotater.py index 353c8b6019..61be0ccb3e 100644 --- a/backends/arm/test/quantizer/test_generic_annotater.py +++ b/backends/arm/test/quantizer/test_generic_annotater.py @@ -32,7 +32,7 @@ def check_annotation(self, model): tester = ArmTester( model, model.example_inputs(), - common.get_tosa_compile_spec("TOSA-0.80.0+BI"), + common.get_tosa_compile_spec("TOSA-0.80+BI"), ) quant_model = tester.quantize().get_artifact() partitions = get_source_partitions(quant_model.graph, [model.op]) diff --git a/backends/arm/test/runner_utils.py b/backends/arm/test/runner_utils.py index 4de84ed345..9ae1a27cf7 100644 --- a/backends/arm/test/runner_utils.py +++ b/backends/arm/test/runner_utils.py @@ -127,7 +127,7 @@ def _get_output_node(program: ExportedProgram) -> Node: def _get_output_quantization_params( program: ExportedProgram, output_node: Node -) -> QuantizationParams: +) -> Optional[QuantizationParams]: """ Get output QuantizationParams from a program. 
Args: @@ -153,8 +153,6 @@ def _get_output_quantization_params( dtype=node.args[5], ) break # break early, there's only one output node - if quant_params is None: - raise RuntimeError("No Quantization parameters not found in exported model.") return quant_params @@ -485,13 +483,17 @@ def run_tosa_ref_model( if tosa_ref_output.dtype == np.int8: tosa_ref_output = tosa_ref_output.astype(np.int32) quant_param = self.qp_output - assert ( - quant_param is not None - ), "There are no quantization parameters, check output parameters" - tosa_ref_output = (tosa_ref_output - quant_param.zp) * quant_param.scale + if quant_param is not None: + # quant_param may be None, e.g. a quantized model can have a bool output + tosa_ref_output = ( + tosa_ref_output - quant_param.zp + ) * quant_param.scale if tosa_ref_output.dtype == np.double: tosa_ref_output = tosa_ref_output.astype("float32") + elif tosa_ref_output.dtype == bool: + # retain the bool output though for boolean related comparisons + tosa_ref_output = tosa_ref_output.astype("bool") # tosa_output is a numpy array, convert to torch tensor for comparison tosa_ref_outputs.append(torch.from_numpy(tosa_ref_output)) diff --git a/backends/arm/test/tester/arm_tester.py b/backends/arm/test/tester/arm_tester.py index 62f6803577..f663b16a9d 100644 --- a/backends/arm/test/tester/arm_tester.py +++ b/backends/arm/test/tester/arm_tester.py @@ -11,7 +11,6 @@ import executorch.backends.xnnpack.test.tester.tester as tester -import numpy as np import serializer.tosa_serializer as ts import torch.fx @@ -317,12 +316,15 @@ def run_method_and_compare_outputs( target_board, ) + quantization_scale = None if is_quantized: reference_stage = self.stages[self.stage_name(tester.Quantize)] - quantization_scale = self.runner_util.qp_output.scale + # A quantized model can produce a non-quantized (e.g. bool) output, so allow + # self.runner_util.qp_output to be None + if self.runner_util.qp_output is not None: + quantization_scale = self.runner_util.qp_output.scale else: reference_stage = 
self.stages[self.stage_name(InitialModel)] - quantization_scale = None logger.info( f"Comparing Stage '{self.stage_name(test_stage)}' with Stage '{self.stage_name(reference_stage)}'" @@ -508,7 +510,7 @@ def transpose_data_format( inputs_transposed = list(data) for i in range(len(data)): if hasattr(data[i], "shape") and len(data[i].shape) == 4: - inputs_transposed[i] = np.transpose(data[i], dim_order) + inputs_transposed[i] = torch.permute(data[i], dim_order) return tuple(inputs_transposed) def _compare_outputs( diff --git a/backends/arm/tosa_quant_utils.py b/backends/arm/tosa_quant_utils.py index b526a2aa8e..ab2d8befdc 100644 --- a/backends/arm/tosa_quant_utils.py +++ b/backends/arm/tosa_quant_utils.py @@ -57,6 +57,11 @@ def insert_rescale_ops_to_int32( the graph upstream for DQ nodes. """ + # pyre-fixme[21]: 'Could not find a module corresponding to import `executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass`.' + from executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass import ( + get_input_qparams, + ) + tensors = inputs.copy() # Reshape tensor according to TOSA dim order @@ -64,7 +69,8 @@ def insert_rescale_ops_to_int32( dim_order = tensor.dim_order tensor.shape = [tensor.shape[i] for i in dim_order] - qargs = list(cast(dict[int, QuantArgs], node.meta["input_qparams"]).values()) + input_qparams = get_input_qparams(node) # pyre-ignore[16] + qargs = input_qparams.values() # Scale the int8 quantized input to a common scale in the integer # domain @@ -84,7 +90,7 @@ def insert_rescale_ops_to_int32( return rescaled_nodes, min_scale -def insert_rescale_node_back_to_int8( +def insert_rescale_op_to_int8( tosa_graph: ts.TosaSerializer, last_tensor: TosaArg, scale: float, @@ -102,9 +108,15 @@ def insert_rescale_node_back_to_int8( in the node meta dict as opposed to 'rescale_node_back_to_int8' which search the graph downstream for Q nodes. 
""" - assert len(node.meta["output_qparams"]) == 1 + # pyre-fixme[21]: 'Could not find a module corresponding to import `executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass`.' + from executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass import ( + get_output_qparams, + ) + + output_qparams = get_output_qparams(node) # pyre-ignore[16] + assert len(output_qparams) == 1, "More than one output not supported" - qargs_out = cast(dict[int, QuantArgs], node.meta["output_qparams"])[0] + qargs_out = output_qparams[0] output_rescale_scale = scale / qargs_out.scale # Rescale Back to INT8 @@ -136,6 +148,17 @@ def quantize_value(self, x): def dequantize_value(self, qx: int) -> float: return (qx - self.zp) * self.scale + def __eq__(self, other): + if isinstance(other, QuantArgs): + return ( + self.scale == other.scale + and self.zp == other.zp + and self.qmin == other.qmin + and self.qmax == other.qmax + and self.dtype == other.dtype + ) + return False + @classmethod def from_operator(cls, op, args): if op in dq_q_ops: diff --git a/backends/arm/tosa_specification.py b/backends/arm/tosa_specification.py index 716e8daee2..232eb34afd 100644 --- a/backends/arm/tosa_specification.py +++ b/backends/arm/tosa_specification.py @@ -72,9 +72,9 @@ class from the found value or return None on failure. 
def create_from_string(repr: str) -> "TosaSpecification": """ Creates a TOSA specification class from a string representation: - TOSA-0.80.0+MI - TOSA-0.80.0+BI+8k - TOSA-0.80.0+BI+u55 # Ethos-U55 extension to handle TOSA subset + TOSA-0.80+MI + TOSA-0.80+BI+8k + TOSA-0.80+BI+u55 # Ethos-U55 extension to handle TOSA subset TOSA-0.90.0+MI TOSA-1.00.0+INT+FP+int4+cf """ diff --git a/backends/arm/tosa_utils.py b/backends/arm/tosa_utils.py index 1ae319e0cd..5bda9bbf18 100644 --- a/backends/arm/tosa_utils.py +++ b/backends/arm/tosa_utils.py @@ -7,18 +7,13 @@ import logging import os -from typing import Any, cast +from typing import Any import numpy as np import serializer.tosa_serializer as ts import torch from executorch.backends.arm.tosa_mapping import TosaArg -from executorch.backends.arm.tosa_quant_utils import ( - get_quant_arg_downstream, - get_quant_arg_upstream, - q_op, -) from executorch.exir.dialects._ops import ops as exir_ops from serializer.tosa_serializer import TosaOp from torch.fx import Node @@ -78,7 +73,7 @@ def dbg_fail(node, tosa_graph, path): # Helper function to match TOSA's broadcasting rank requirement -# Ref: TOSA 0.80.0 specification - 1.9.3. Data Layouts from +# Ref: TOSA 0.80 specification - 1.9.3. 
Data Layouts from # https://www.mlplatform.org/tosa/tosa_spec.html def promote_shape(tosa_fb, arg, promoted_shape, out_dtype): assert np.prod(arg.shape) == np.prod(promoted_shape), "Incompatible promoted shape" @@ -90,7 +85,7 @@ def promote_shape(tosa_fb, arg, promoted_shape, out_dtype): # Helper transpose function to match TOSA's shape requirements -# E.g., TOSA 0.80.0 specification - 2.3.3 CONV2D shapes: +# E.g., TOSA 0.80 specification - 2.3.3 CONV2D shapes: # https://www.mlplatform.org/tosa/tosa_spec.html#_conv2d def transpose_helper(tosa_fb, input, new_order, out_dtype): # Check new_order's length is equal to input rank @@ -140,10 +135,15 @@ def build_reshape(tosa_fb, input_name, new_shape, output_name): def is_bias_node_for_quantized_conv(node): consumer_node = list(node.users)[0] - return ( + + if ( consumer_node.target == exir_ops.edge.aten.convolution.default - and list(consumer_node.users)[0].target == q_op - ) + and consumer_node.args[2] == node + and consumer_node.meta["val"].dtype == torch.int8 + ): + return True + + return False def is_consumer_node_depthwise_conv2d(node): @@ -159,48 +159,6 @@ def is_consumer_node_depthwise_conv2d(node): return False -def build_avg_pool_2d_common( - node: torch.fx.Node, - tosa_graph: ts.TosaSerializer, - input_tensor: TosaArg, - kernel_size: list, - stride: list, - padding: list, - is_quant_node: bool, - output: TosaArg, -): - accumulator_type = input_tensor.dtype - - if is_quant_node: - # Accumulator type always is int32 when input tensor is an integer type. - accumulator_type = ts.DType.INT32 - - # Initilize zero point to zero. 
- input_zp = 0 - output_zp = 0 - - if is_quant_node: - input_zp = get_quant_arg_upstream(cast(torch.fx.Node, node.args[0])).zp - output_zp = get_quant_arg_downstream(list(node.users)[0]).zp - - attr = ts.TosaSerializerAttribute() - attr.PoolAttribute( - kernel=kernel_size, - stride=stride, - pad=padding, - input_zp=input_zp, - output_zp=output_zp, - accum_dtype=accumulator_type, - ) - - tosa_graph.addOperator( - TosaOp.Op().AVG_POOL2D, - [input_tensor.name], - [output.name], - attr, - ) - - def get_two_inputs(node: Node, check: bool = False) -> tuple[Node, Node]: """Returns two input nodes to 'node' in order. If 'node' only has one input, it is returned twice. diff --git a/examples/arm/setup.sh b/examples/arm/setup.sh index 9e20d6043a..bf922360fd 100755 --- a/examples/arm/setup.sh +++ b/examples/arm/setup.sh @@ -88,7 +88,7 @@ ethos_u_base_rev="24.08" # tosa reference model tosa_reference_model_url="https://review.mlplatform.org/tosa/reference_model" -tosa_reference_model_rev="c5570b79e90c3a36ab8c4ddb8ee3fbc2cd3f7c38" +tosa_reference_model_rev="v0.80.1" # vela vela_repo_url="https://review.mlplatform.org/ml/ethos-u/ethos-u-vela"