Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Revert "Revert "Add full operator to fold dq/q handling" (#7351)" #7362

Merged
merged 6 commits into from
Dec 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions backends/arm/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -104,8 +104,8 @@ The Arm Backend should be considered a prototype quality at this point, likely s
## Current flows

The ArmBackend has a two stage process,
- Compile to TOSA to rationalise the graph into known hardware support profiles. Currently this is to v0.80.0 TOSA BI with specific concern to a subset which gives support on Ethos-U55, the target of the initial prototype efforts.
- Lower via the ethos-u-vela compilation flow which takes TOSA v0.80.0 as an input and produces a low level commandstream for the hardware which is then passed via the delegate to the ethos-u-core-driver for direct execution.
- Compile to TOSA to rationalise the graph into known hardware support profiles. Currently this is to v0.80 TOSA BI with specific concern to a subset which gives support on Ethos-U55, the target of the initial prototype efforts.
- Lower via the ethos-u-vela compilation flow which takes TOSA v0.80 as an input and produces a low level commandstream for the hardware which is then passed via the delegate to the ethos-u-core-driver for direct execution.

The ArmPartitioner is currently used to ensure the operations converted are Ethos-U compatible, but will be extended to offer spec-correct TOSA Base Inference and TOSA Main Inference generation in the future.

Expand Down
18 changes: 18 additions & 0 deletions backends/arm/_passes/arm_pass_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,10 @@
DecomposeSoftmaxesPass,
)
from executorch.backends.arm._passes.decompose_var_pass import DecomposeVarPass
from executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass import (
FoldAndAnnotateQParamsPass,
QuantizeFullArgument,
)
from executorch.backends.arm._passes.keep_dims_false_to_squeeze_pass import (
KeepDimsFalseToSqueezePass,
)
Expand All @@ -50,6 +54,7 @@
from executorch.backends.xnnpack._passes.remove_getitem_op import RemoveGetItemPass
from executorch.exir import ExportedProgram
from executorch.exir.backend.compile_spec_schema import CompileSpec
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_manager import PassManager


Expand Down Expand Up @@ -80,6 +85,19 @@ def transform_to_backend_pipeline(
self.add_pass(Conv1dUnsqueezePass(exported_program))
self.add_pass(DecomposeSoftmaxesPass())
self.add_pass(DecomposeLinearPass())
self.add_pass(QuantizeFullArgument())
self.add_pass(
FoldAndAnnotateQParamsPass(
[
exir_ops.edge.aten.minimum.default,
exir_ops.edge.aten.maximum.default,
exir_ops.edge.aten.add.Tensor,
exir_ops.edge.aten.avg_pool2d.default,
exir_ops.edge.aten.convolution.default,
exir_ops.edge.aten.full.default,
]
)
)
for spec in compile_spec:
if spec.key == "permute_memory_format":
memory_format = spec.value.decode()
Expand Down
183 changes: 183 additions & 0 deletions backends/arm/_passes/fold_qdq_with_annotated_qparams_pass.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,183 @@
# Copyright 2024 Arm Limited and/or its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import copy

from typing import cast, Iterable

from executorch.backends.arm.tosa_quant_utils import QuantArgs

from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.dialects.edge._ops import EdgeOpOverload

from executorch.exir.pass_base import ExportPass, PassResult
from torch.fx import GraphModule, Node

# Edge-dialect per-tensor quantize/dequantize ops. These are the Q/DQ nodes
# that FoldAndAnnotateQParamsPass folds away around its targeted operators.
q_op: EdgeOpOverload = exir_ops.edge.quantized_decomposed.quantize_per_tensor.default
dq_op: EdgeOpOverload = exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default


def get_input_qparams(node: Node) -> dict[int, QuantArgs]:
    """
    Return the input quantization parameters stored on ``node``.

    The parameters are attached to ``node.meta["input_qparams"]`` by the
    'FoldAndAnnotateQParamsPass', keyed by tensor-input index.
    Raises a ValueError if the node has no input quantization parameters.
    """
    # Missing key and empty dict are both "no parameters" — one falsy check
    # covers both cases with the same error message as before.
    qparams = node.meta.get("input_qparams")
    if not qparams:
        raise ValueError(f"No input quantization parameter found in node {node}")
    return cast(dict[int, QuantArgs], qparams)


def get_output_qparams(node: Node) -> dict[int, QuantArgs]:
    """
    Return the output quantization parameters stored on ``node``.

    The parameters are attached to ``node.meta["output_qparams"]`` by the
    'FoldAndAnnotateQParamsPass', keyed by user (output) index.
    Raises a ValueError if the node has no output quantization parameters.
    """
    # Missing key and empty dict are both "no parameters" — one falsy check
    # covers both cases with the same error message as before.
    qparams = node.meta.get("output_qparams")
    if not qparams:
        raise ValueError(f"No output quantization parameter found in node {node}")
    return cast(dict[int, QuantArgs], qparams)


class FoldAndAnnotateQParamsPass(ExportPass):
    """
    A pass that walks the graph and removes any DQ and Q nodes before and after the target
    node in the supplied list of operators.
    The quantization parameters from the DQ/Q nodes are stored as meta values
    ("input_qparams" / "output_qparams") to be accessible for later lowering and
    serialization passes.
    The assumption is that the quantization annotation adds DQ nodes for all tensor
    inputs to the target node and one Q node to the output.

    Example ('executorch_exir_dialects_edge__ops_' prefix removed from operators for readability):

        x_q: "i8[5]" = quantized_decomposed_quantize_per_tensor_default(x, 0.05487706884741783, -128, -128, 127, torch.int8)

        x_dq: "f32[5]" = quantized_decomposed_dequantize_per_tensor_default(x_q, 0.05487706884741783, -128, -128, 127, torch.int8)
        aten_add_tensor: "f32[5]" = ops_aten_add_Tensor(x_dq, x_dq)
        aten_add_tensor_q: "i8[5]" = quantized_decomposed_quantize_per_tensor_default(aten_add_tensor, 0.05487706884741783, -128, -128, 127, torch.int8)

        output_dq: "f32[5]" = quantized_decomposed_dequantize_per_tensor_default(aten_add_tensor_q, 0.05487706884741783, -128, -128, 127, torch.int8)

    Becomes:
        x_q: "i8[5]" = quantized_decomposed_quantize_per_tensor_default(x, 0.05487706884741783, -128, -128, 127, torch.int8)

        aten_add_tensor: "i8[5]" = aten_add_Tensor(x_q, x_q)

        output_dq: "f32[5]" = quantized_decomposed_dequantize_per_tensor_default(aten_add_tensor, 0.05487706884741783, -128, -128, 127, torch.int8)

    The quantization parameters for x_dq and aten_add_tensor_q are stored in meta
    for the aten_add_tensor node.
    """

    def __init__(self, targeted_ops: Iterable[EdgeOpOverload]) -> None:
        """targeted_ops: edge operators whose surrounding DQ/Q nodes should be folded."""
        super().__init__()
        self.targeted_ops = targeted_ops

    def call(self, graph_module: GraphModule) -> PassResult:

        # Loop over the graph nodes and find any node in the 'targeted_ops' list.
        for n in graph_module.graph.nodes:
            n = cast(Node, n)
            if n.op != "call_function" or n.target not in self.targeted_ops:
                continue

            # Make sure we haven't already set qparams meta information on the node
            # (this pass must only process each targeted node once).
            assert "input_qparams" not in n.meta.keys()
            assert "output_qparams" not in n.meta.keys()

            # for the inputs and outputs search the graph for quantization info and
            # store the information in a dict with order of the _tensor_ inputs as key,
            # ignoring any other arguments to the target node.
            n.meta["input_qparams"] = {}
            n.meta["output_qparams"] = {}
            for i, arg in enumerate(n.args):
                if not isinstance(arg, Node):
                    continue

                # Make sure arg has requires_grad set to False
                # For parameters that are not quantized, sometimes (i.e. convolution)
                # the Parameter(FakeTensor(...)) has requires_grad set to True, which
                # causes the retracing of the graph to fail with:
                #
                # E       RuntimeError: isDifferentiableType(variable.scalar_type()) INTERNAL ASSERT FAILED at "/Users/runner/work/pytorch/pytorch/pytorch/torch/csrc/autograd/functions/utils.h":74, please report a bug to PyTorch.
                # E
                # E       While executing %aten_convolution_default : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten.convolution.default](args = (%quantized_decomposed_quantize_per_tensor_default, %b__frozen_param0, %p__param_constant1, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), kwargs = {})
                # E       Original traceback:
                # E         File "/Users/perast01/src/executorch/backends/arm/test/ops/test_conv2d.py", line 110, in forward
                # E           x = conv(x)
                #
                if arg.op == "placeholder":
                    arg.meta["val"].requires_grad = False

                if arg.target != dq_op:
                    continue

                # arg.target for argument i is a dequant node, extract the information
                n.meta["input_qparams"][i] = QuantArgs.from_operator(
                    arg.target, arg.args
                )

                # arg.args[0] is the tensor input, replace the input usage
                tensor_input = cast(Node, arg.args[0])
                n.replace_input_with(arg, tensor_input)
                graph_module.graph.erase_node(arg)

            # Copy the users, since we are modifying it.
            # (Folding a Q user mutates n.users, so iterate over a snapshot.)
            users_copy = copy.copy(n.users)
            for i, user in enumerate(users_copy):
                if user.target != q_op:
                    continue

                # quantization node found here, store the quantization parameters in meta value
                n.meta["output_qparams"][i] = QuantArgs.from_operator(
                    user.target, user.args
                )

                user.replace_all_uses_with(n)
                graph_module.graph.erase_node(user)

        # retrace the graph to update the fake tensor types
        graph_module = super().call(graph_module).graph_module

        graph_module.recompile()
        return PassResult(graph_module, True)


class QuantizeFullArgument(ExportPass):
    """
    Make sure the fill_value for full.default is quantized. This pass needs to be run before
    the folding pass above to make sure that the retraced output of the full.default op is
    the right dtype.
    """

    def call(self, graph_module: GraphModule) -> PassResult:
        """Quantize the fill value of every full.default node followed by a q node.

        Returns a PassResult whose 'modified' flag is True iff at least one
        full.default node was rewritten.
        """
        modified = False
        # Loop over the graph nodes and find all full.default nodes.
        for n in graph_module.graph.nodes:
            n = cast(Node, n)
            if n.target != exir_ops.edge.aten.full.default:
                continue

            # Make sure we have a quantized operator. A full node without any
            # users (dead code) has no q node to fold, so skip it instead of
            # crashing on an empty users list.
            users = list(n.users)
            if not users:
                continue
            user = users[0]
            if user.target != q_op:
                continue

            qargs = QuantArgs.from_operator(user.target, user.args)
            if "dtype" not in n.kwargs.keys() or n.kwargs["dtype"] != qargs.dtype:
                # replace the node arg with a quantized dito and also set dtype
                # to get the right output according to the Edge IR specification:
                # exir/dialects/edge/edge.yaml:3596
                quantized_full_value = qargs.quantize_value(n.args[1]).item()
                n.update_arg(1, quantized_full_value)
                n.update_kwarg("dtype", qargs.dtype)
                modified = True

        return PassResult(graph_module, modified)
2 changes: 1 addition & 1 deletion backends/arm/arm_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ def ethosu_compile_spec(
if extra_flags is not None:
self.compiler_flags.append(extra_flags)

base_tosa_version = "TOSA-0.80.0+BI"
base_tosa_version = "TOSA-0.80+BI"
if "u55" in config:
# Add the Ethos-U55 extension marker
base_tosa_version += "+u55"
Expand Down
4 changes: 2 additions & 2 deletions backends/arm/operator_support/right_shift_support.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,8 @@ class RightShiftSupported(SupportedTOSAOperatorCheck):
targets = [exir_ops.edge.aten.__rshift__.Scalar]

tosa_specs = [
TosaSpecification.create_from_string("TOSA-0.80.0+BI"),
TosaSpecification.create_from_string("TOSA-0.80.0+MI"),
TosaSpecification.create_from_string("TOSA-0.80+BI"),
TosaSpecification.create_from_string("TOSA-0.80+MI"),
]

def is_node_supported(self, node: fx.Node, tosa_spec: TosaSpecification):
Expand Down
4 changes: 2 additions & 2 deletions backends/arm/operator_support/to_copy_support.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,8 @@ class ToCopySupported(SupportedTOSAOperatorCheck):
targets = [exir_ops.edge.aten._to_copy.default]

tosa_specs = [
TosaSpecification.create_from_string("TOSA-0.80.0+BI"),
TosaSpecification.create_from_string("TOSA-0.80.0+MI"),
TosaSpecification.create_from_string("TOSA-0.80+BI"),
TosaSpecification.create_from_string("TOSA-0.80+MI"),
]

SupportedTypeDict = dict[torch.dtype, list[torch.dtype]]
Expand Down
6 changes: 4 additions & 2 deletions backends/arm/operator_support/tosa_supported_operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,8 @@ def is_node_supported(self, node: fx.Node, tosa_spec: TosaSpecification) -> bool
_tosa_spec_dicts: dict[
TosaSpecification, dict[str, Type[SupportedTOSAOperatorCheck]]
] = {
TosaSpecification.create_from_string("TOSA-0.80.0+BI"): {},
TosaSpecification.create_from_string("TOSA-0.80.0+MI"): {},
TosaSpecification.create_from_string("TOSA-0.80+BI"): {},
TosaSpecification.create_from_string("TOSA-0.80+MI"): {},
}


Expand Down Expand Up @@ -94,6 +94,8 @@ def is_node_supported(self, submodules, node: fx.Node) -> bool:
exir_ops.edge.aten.sigmoid.default,
exir_ops.edge.aten.mean.dim,
exir_ops.edge.aten.mm.default,
exir_ops.edge.aten.minimum.default,
exir_ops.edge.aten.maximum.default,
exir_ops.edge.aten.repeat.default,
exir_ops.edge.aten.reciprocal.default,
exir_ops.edge.aten.relu.default,
Expand Down
2 changes: 2 additions & 0 deletions backends/arm/operators/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,9 @@
op_get_item,
op_hardtanh,
op_log,
op_max,
op_max_pool2d,
op_min,
op_mm,
op_mul,
op_permute,
Expand Down
8 changes: 4 additions & 4 deletions backends/arm/operators/node_visitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,8 @@ class NodeVisitor:
# When all node_visitors has been refactored to target a specific
# version, this list should be removed.
tosa_specs = [
TosaSpecification.create_from_string("TOSA-0.80.0+BI"),
TosaSpecification.create_from_string("TOSA-0.80.0+MI"),
TosaSpecification.create_from_string("TOSA-0.80+BI"),
TosaSpecification.create_from_string("TOSA-0.80+MI"),
]

def __init__(self, exported_program: ExportedProgram, tosa_spec: TosaSpecification):
Expand All @@ -46,8 +46,8 @@ def define_node(

# container for all node visitors
_node_visitor_dicts = {
TosaSpecification.create_from_string("TOSA-0.80.0+BI"): {},
TosaSpecification.create_from_string("TOSA-0.80.0+MI"): {},
TosaSpecification.create_from_string("TOSA-0.80+BI"): {},
TosaSpecification.create_from_string("TOSA-0.80+MI"): {},
}


Expand Down
Loading
Loading