Arm backend: qdq folding support for remaining operators (#7340)
* Add TOSA table as custom edge op

Edge operators that are lowered to TOSA TABLEs are converted to a
custom edge IR table-op.

Signed-off-by: Oscar Andersson <[email protected]>
Change-Id: I147008c30b9b46c7b8ae1a1c15bc540fea614a69
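As an illustration of the substitution this commit describes, here is a minimal FX sketch. The helper and the `table_op` handle are stand-ins, not the actual API registered by InsertTableOpsPass:

# Illustrative sketch only: swap a unary edge op (e.g. sigmoid) for a
# custom table op. `table_op` stands in for the custom edge IR op that
# the Arm backend registers; the real handle lives in insert_table_ops.py.
from torch.fx import GraphModule, Node


def replace_with_table_op(graph_module: GraphModule, node: Node, table_op) -> None:
    graph = graph_module.graph
    with graph.inserting_before(node):
        # The table op consumes the same already-quantized input; the actual
        # int8 -> int8 lookup values are derived later from the folded qparams.
        table_node = graph.create_node(
            "call_function", table_op, args=(node.args[0],)
        )
        table_node.meta = node.meta.copy()
    node.replace_all_uses_with(table_node)
    graph.erase_node(node)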

* Add support for concat q/dq folding

This is a special case: for concat, node.args contains lists that can
hold many incoming dq-nodes.

Signed-off-by: Oscar Andersson <[email protected]>
Change-Id: Icf511a8bdeaaffb597b18455ab7f1fbd947ce3ca
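For illustration, a self-contained sketch of the list-argument case; the names below are stand-ins, and the real logic lives in FoldAndAnnotateQParamsPass.fold_and_annotate_arg later in this diff:

# Sketch: cat's first argument is a list of nodes, so q/dq folding must
# verify that every incoming dq node carries identical quantization params.
from torch.fx import Node


def shared_dq_qparams(arg_list: list[Node], dq_op) -> tuple | None:
    """Return the common qparams of the dq inputs, or None if unquantized."""
    qparams = None
    for arg in arg_list:
        if arg.target != dq_op:
            continue
        arg_qparams = tuple(arg.args[1:])  # (scale, zero_point, qmin, qmax, dtype)
        if qparams is not None and qparams != arg_qparams:
            raise RuntimeError("Input qparams do not match!")
        qparams = arg_qparams
    return qparams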

* Increase q/dq folding coverage

Add support for q/dq folding of more operators, such as hardtanh,
maxpool2d, mul, relu, select, sub, and to_copy.

Signed-off-by: Oscar Andersson <[email protected]>
Change-Id: Ifdabda4c927dade41c000859054696844c546f7b

* Add support for sum q/dq folding

After q/dq folding, sum is retraced to an operator with an int64 output
dtype. This patch adds a pass (RetraceFoldedDtypesPass) that manually
forces the dtype back to int8.

Signed-off-by: Oscar Andersson <[email protected]>
Change-Id: Ifa737a398c5a878d52cd76a2392499905da085ce

* Complete q/dq folding coverage

Add support for q/dq folding for the remaining supported ops in the Arm
backend.

Signed-off-by: Oscar Andersson <[email protected]>
Change-Id: I9012b4a501ce018c9771c729706be3b031a5c7ae

* Remove is_quant_node from NodeVisitor.define_node

Signed-off-by: Oscar Andersson <[email protected]>
Change-Id: Ibb17add461dc79e022a7f4accde29f9f9d61b16d

* Fix pyre issues

Address issues reported by pyre and add # pyre-ignore comments similar
to those in #7362.

Signed-off-by: Oscar Andersson <[email protected]>
Change-Id: I6feaa611dcd539b3b0d21a6a7dd696ef7db691ef

---------

Signed-off-by: Oscar Andersson <[email protected]>
oscarandersson8218 authored Jan 7, 2025
1 parent 62c6346 commit a29b208
Showing 53 changed files with 822 additions and 749 deletions.
89 changes: 89 additions & 0 deletions backends/arm/_passes/annotate_decomposed_matmul.py
@@ -0,0 +1,89 @@
# Copyright 2024 Arm Limited and/or its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import itertools

import torch
from executorch.backends.arm._passes.arm_pass_utils import create_node
from executorch.backends.arm.tosa_quant_utils import dq_op, q_op
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import ExportPass, PassResult
from torch.fx import GraphModule
from torch.fx.passes.utils.source_matcher_utils import get_source_partitions


class AnnotateDecomposedMatmulPass(ExportPass):
"""
torch.matmul can be decomposed in many ways, for instance:
dq -> matmul -> q can become
dq -> repeat -> view -> bmm -> view -> dq which makes quantization folding
difficult. This helper function find all matmul partitions and annotate its
matmul-op (can be mm or bmm).
"""

    def call(self, graph_module: GraphModule) -> PassResult:
        matmul_partitions = get_source_partitions(
            graph_module.graph,
            [
                torch.matmul,
            ],
            None,
        )
        matmul_partitions = list(
            itertools.chain.from_iterable(matmul_partitions.values())
        )
        matmul_targets = {
            exir_ops.edge.aten.mm.default,
            exir_ops.edge.aten.bmm.default,
        }
        for partition in matmul_partitions:
            quantized_input = all(
                input_node.target == dq_op for input_node in partition.input_nodes
            )
            matmul_node = [
                node for node in partition.nodes if node.target in matmul_targets
            ][0]
            if quantized_input:
                matmul_args = matmul_node.all_input_nodes
                for i in range(len(matmul_args)):
                    input_node = partition.input_nodes[i]
                    matmul_input_node = matmul_args[i]
                    # Remove partition input dq-node
                    input_node.replace_all_uses_with(input_node.all_input_nodes[0])
                    graph_module.graph.erase_node(input_node)
                    input_node_qargs = input_node.args[1:]
                    with graph_module.graph.inserting_before(matmul_node):
                        # Create new dq-node before matmul
                        dq_node = create_node(
                            graph=graph_module.graph,
                            op_target=dq_op,
                        )
                        dq_node.args = (matmul_input_node, *input_node_qargs)
                        matmul_node.replace_input_with(matmul_input_node, dq_node)

            partition_output = list(partition.output_nodes[0].users)[0]
            quantized_output = partition_output.target == q_op
            if quantized_output:
                output_node_qargs = partition_output.args[1:]
                with graph_module.graph.inserting_after(matmul_node):
                    # Create q-node after matmul
                    q_node = create_node(
                        graph=graph_module.graph,
                        op_target=q_op,
                    )
                    matmul_node.replace_all_uses_with(q_node)
                    q_node.args = (matmul_node, *output_node_qargs)
                # Remove partition output q-node
                partition_output.replace_all_uses_with(
                    partition_output.all_input_nodes[0]
                )
                graph_module.graph.erase_node(partition_output)

        # retrace the graph to update the fake tensor types
        graph_module = super().call(graph_module).graph_module

        graph_module.recompile()
        return PassResult(graph_module, True)
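A hedged usage sketch for the pass above; the export, quantization, and to_edge steps are elided and the module is illustrative:

# Illustrative standalone usage. In the real flow the pass runs inside
# ArmPassManager, after quantization has produced the
# dq -> (expand/view) -> bmm/mm -> (view) -> q partition it targets.
import torch
from executorch.backends.arm._passes.annotate_decomposed_matmul import (
    AnnotateDecomposedMatmulPass,
)


class MatMulModule(torch.nn.Module):
    def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        return torch.matmul(x, y)


# After export, quantization, and to_edge (elided here), apply the pass to
# the resulting GraphModule:
#   result = AnnotateDecomposedMatmulPass()(edge_graph_module)
#   annotated_gm = result.graph_module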
58 changes: 46 additions & 12 deletions backends/arm/_passes/arm_pass_manager.py
@@ -11,6 +11,9 @@
from executorch.backends.arm._passes.annotate_channels_last_dim_order_pass import (
    AnnotateChannelsLastDimOrder,
)
from executorch.backends.arm._passes.annotate_decomposed_matmul import (
    AnnotateDecomposedMatmulPass,
)
from executorch.backends.arm._passes.cast_int64_pass import CastInt64ToInt32Pass
from executorch.backends.arm._passes.conv1d_unsqueeze_pass import Conv1dUnsqueezePass
from executorch.backends.arm._passes.convert_expand_copy_to_repeat import (
@@ -32,7 +35,9 @@
from executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass import (
    FoldAndAnnotateQParamsPass,
    QuantizeFullArgument,
    RetraceFoldedDtypesPass,
)
from executorch.backends.arm._passes.insert_table_ops import InsertTableOpsPass
from executorch.backends.arm._passes.keep_dims_false_to_squeeze_pass import (
    KeepDimsFalseToSqueezePass,
)
@@ -67,24 +72,15 @@ def transform_to_backend_pipeline(
        self, exported_program: ExportedProgram, compile_spec: list[CompileSpec]
    ):
        """Apply passes before transforming program to backend"""
        self.add_pass(CastInt64ToInt32Pass(exported_program))
        self.add_pass(DecomposeLinearPass())
        self.add_pass(RemoveGetItemPass())
        self.add_pass(UnsqueezeScalarPlaceholdersPass(exported_program))
        self.add_pass(SizeAdjustConv2DPass())
        self.add_pass(RemoveClonePass())
        self.add_pass(ConvertExpandCopyToRepeatPass())
        self.add_pass(DecomposeLayerNormPass())
        self.add_pass(UnsqueezeBeforeRepeatPass())
        self.add_pass(DecomposeVarPass())
        self.add_pass(ConvertMeanDimToAveragePool())
        self.add_pass(DecomposeMeanDimPass())
        self.add_pass(MatchArgRanksPass(exported_program))
        self.add_pass(DecomposeDivPass())
        self.add_pass(KeepDimsFalseToSqueezePass())
        self.add_pass(ConvertSplitToSlicePass())
        self.add_pass(Conv1dUnsqueezePass(exported_program))
        self.add_pass(DecomposeSoftmaxesPass())
        self.add_pass(DecomposeLinearPass())
        # TODO MLETORCH-558
        self.add_pass(AnnotateDecomposedMatmulPass())
        self.add_pass(QuantizeFullArgument())
        self.add_pass(
            FoldAndAnnotateQParamsPass(
@@ -93,11 +89,49 @@
                    exir_ops.edge.aten.maximum.default,
                    exir_ops.edge.aten.add.Tensor,
                    exir_ops.edge.aten.avg_pool2d.default,
                    exir_ops.edge.aten.bmm.default,
                    exir_ops.edge.aten.cat.default,
                    exir_ops.edge.aten.convolution.default,
                    exir_ops.edge.aten.clone.default,
                    exir_ops.edge.aten.exp.default,
                    exir_ops.edge.aten.expand_copy.default,
                    exir_ops.edge.aten.full.default,
                    exir_ops.edge.aten.hardtanh.default,
                    exir_ops.edge.aten.log.default,
                    exir_ops.edge.aten.max_pool2d.default,
                    exir_ops.edge.aten.mm.default,
                    exir_ops.edge.aten.mul.Tensor,
                    exir_ops.edge.aten.permute_copy.default,
                    exir_ops.edge.aten.reciprocal.default,
                    exir_ops.edge.aten.relu.default,
                    exir_ops.edge.aten.repeat.default,
                    exir_ops.edge.aten.rsqrt.default,
                    exir_ops.edge.aten.select_copy.int,
                    exir_ops.edge.aten.sigmoid.default,
                    exir_ops.edge.aten.slice_copy.Tensor,
                    exir_ops.edge.aten.squeeze_copy.dims,
                    exir_ops.edge.aten.sub.Tensor,
                    exir_ops.edge.aten.sum.dim_IntList,
                    exir_ops.edge.aten.tanh.default,
                    exir_ops.edge.aten.unsqueeze_copy.default,
                    exir_ops.edge.aten.upsample_nearest2d.vec,
                    exir_ops.edge.aten.view_copy.default,
                ]
            )
        )
        self.add_pass(RetraceFoldedDtypesPass())
        self.add_pass(InsertTableOpsPass(exported_program))
        self.add_pass(ConvertExpandCopyToRepeatPass())
        self.add_pass(UnsqueezeBeforeRepeatPass())
        self.add_pass(CastInt64ToInt32Pass(exported_program))
        self.add_pass(UnsqueezeScalarPlaceholdersPass(exported_program))
        self.add_pass(SizeAdjustConv2DPass())
        self.add_pass(RemoveClonePass())
        self.add_pass(MatchArgRanksPass(exported_program))
        self.add_pass(DecomposeDivPass())
        self.add_pass(KeepDimsFalseToSqueezePass())
        self.add_pass(Conv1dUnsqueezePass(exported_program))
        self.add_pass(DecomposeSoftmaxesPass())
        for spec in compile_spec:
            if spec.key == "permute_memory_format":
                memory_format = spec.value.decode()
22 changes: 2 additions & 20 deletions backends/arm/_passes/conv1d_unsqueeze_pass.py
@@ -12,10 +12,8 @@
from executorch.backends.arm._passes.arm_pass_utils import (
    create_node,
    get_param_tensor,
    insert_q_dq_pair,
    is_param_node,
)
from executorch.backends.arm.tosa_quant_utils import dq_op, q_op
from executorch.exir import ExportedProgram
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import ExportPass, PassResult
@@ -27,10 +25,8 @@ class Conv1dUnsqueezePass(ExportPass):
    supports 2d and 3d convolution. This is done by modifying the graph to do the
    following:
    1) unsqueeze the convolution's input from 3d to 4d
    2) if the input to unsqueeze is quantized, insert q/dq-pair after unsqueeze
    3) perform a conv2d (with a modified version of the original conv1d args)
    4) squeeze the output back down to 3d.
    5) if all users of squeeze are quantized, insert q/dq-pair before squeeze
    2) perform a conv2d (with a modified version of the original conv1d args)
    3) squeeze the output back down to 3d.
    """

    def __init__(self, exported_program: ExportedProgram) -> None:
@@ -94,8 +90,6 @@ def call(self, graph_module: torch.fx.GraphModule):
                continue

            kernel_node = node.args[1]
            if kernel_node.target == dq_op:
                kernel_node = kernel_node.args[0]

            if not is_param_node(self.exported_program, kernel_node):
                raise AssertionError(
@@ -131,11 +125,6 @@
            )
            node.replace_input_with(input_node, unsqueeze_before)

            # If Quantized we must insert unsqueeze --> q --> dq --> node
            if input_node.target == dq_op:
                q_params = input_node.args[1:]
                insert_q_dq_pair(graph, unsqueeze_before, q_params)

            with graph.inserting_after(node):
                squeeze_after = create_node(
                    graph,
@@ -151,13 +140,6 @@
            for user in original_users:
                user.replace_input_with(node, squeeze_after)

            # If quantized, insert conv2d --> q --> dq --> squeeze
            if all(
                original_user.target == q_op for original_user in original_users
            ):
                q_params = original_users[0].args[1:]
                insert_q_dq_pair(graph, node, q_params)

        graph_module.recompile()
        # Since we are overriding "call", we need to call the parent's "call"
        # to retrace the graph and regenerate metadata
119 changes: 88 additions & 31 deletions backends/arm/_passes/fold_qdq_with_annotated_qparams_pass.py
@@ -6,14 +6,20 @@

import copy

from typing import cast, Iterable
from typing import cast, Dict, Iterable, Set, Tuple

from executorch.backends.arm.tosa_quant_utils import QuantArgs

from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.dialects.edge._ops import EdgeOpOverload

from executorch.exir.pass_base import ExportPass, PassResult
from executorch.exir.pass_base import (
    Argument,
    ExportPass,
    NodeMetadata,
    PassResult,
    ProxyValue,
)
from torch.fx import GraphModule, Node

q_op: EdgeOpOverload = exir_ops.edge.quantized_decomposed.quantize_per_tensor.default
@@ -80,6 +86,46 @@ def __init__(self, targeted_ops: Iterable[EdgeOpOverload]) -> None:
        super().__init__()
        self.targeted_ops = targeted_ops

    def fold_and_annotate_arg(
        self, graph_module: GraphModule, node: Node, arg_list: list[Node], i: int
    ) -> None:
        input_qparams = None
        nodes_to_remove = set()
        for arg in arg_list:
            if not isinstance(arg, Node):
                return
            """
            Make sure arg has requires_grad set to False
            For parameters that are not quantized, sometimes (i.e. convolution)
            the Parameter(FakeTensor(...)) has requires_grad set to True, which
            causes the retracing of the graph to fail with:

            E RuntimeError: isDifferentiableType(variable.scalar_type()) INTERNAL ASSERT FAILED at "/Users/runner/work/pytorch/pytorch/pytorch/torch/csrc/autograd/functions/utils.h":74, please report a bug to PyTorch.
            E
            E While executing %aten_convolution_default : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten.convolution.default](args = (%quantized_decomposed_quantize_per_tensor_default, %b__frozen_param0, %p__param_constant1, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), kwargs = {})
            E Original traceback:
            E File "/Users/perast01/src/executorch/backends/arm/test/ops/test_conv2d.py", line 110, in forward
            E x = conv(x)
            """
            if arg.op == "placeholder":
                arg.meta["val"].requires_grad = False

            arg_quant_params = None
            if arg.target == dq_op:
                arg_quant_params = QuantArgs.from_operator(arg.target, arg.args)
                # add arg to nodes_to_remove to fold the dq-node
                nodes_to_remove.add(arg)
            if input_qparams is not None and input_qparams != arg_quant_params:
                # Two args are quantized differently
                raise RuntimeError("Input qparams does not match!")
            input_qparams = arg_quant_params
        if input_qparams is not None:
            node.meta["input_qparams"][i] = input_qparams
        for n in nodes_to_remove:
            assert n.target == dq_op
            n.replace_all_uses_with(n.args[0])
            graph_module.graph.erase_node(n)

    def call(self, graph_module: GraphModule) -> PassResult:

        # Loop over the graph nodes and find any node in the 'targeted_ops' list.
@@ -98,36 +144,11 @@ def call(self, graph_module: GraphModule) -> PassResult:
n.meta["input_qparams"] = {}
n.meta["output_qparams"] = {}
for i, arg in enumerate(n.args):
if not isinstance(arg, Node):
continue

# Make sure arg has requires_grad set to False
# For parameters that are not quantized, sometimes (i.e. convolution)
# the Parameter(FakeTensor(...)) has requires_grad set to True, which
# causes the retracing of the graph to fail with:
#
# E RuntimeError: isDifferentiableType(variable.scalar_type()) INTERNAL ASSERT FAILED at "/Users/runner/work/pytorch/pytorch/pytorch/torch/csrc/autograd/functions/utils.h":74, please report a bug to PyTorch.
# E
# E While executing %aten_convolution_default : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten.convolution.default](args = (%quantized_decomposed_quantize_per_tensor_default, %b__frozen_param0, %p__param_constant1, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), kwargs = {})
# E Original traceback:
# E File "/Users/perast01/src/executorch/backends/arm/test/ops/test_conv2d.py", line 110, in forward
# E x = conv(x)
#
if arg.op == "placeholder":
arg.meta["val"].requires_grad = False

if arg.target != dq_op:
continue

# arg.target for argument i is a dequant node, extract the information
n.meta["input_qparams"][i] = QuantArgs.from_operator(
arg.target, arg.args
)
if isinstance(arg, list):
self.fold_and_annotate_arg(graph_module, n, arg, i)

# arg.args[0] is the tensor input, replace the input usage
tensor_input = cast(Node, arg.args[0])
n.replace_input_with(arg, tensor_input)
graph_module.graph.erase_node(arg)
elif isinstance(arg, Node):
self.fold_and_annotate_arg(graph_module, n, [arg], i)

# Copy the users, since we are modifying it.
users_copy = copy.copy(n.users)
@@ -181,3 +202,39 @@ def call(self, graph_module: GraphModule) -> PassResult:
                modified = True

        return PassResult(graph_module, modified)


class RetraceFoldedDtypesPass(ExportPass):
    """
    FoldAndAnnotateQParamsPass folds dq and q nodes. When the graph is retraced
    some operators are retraced to types that cannot be handled by TOSA. One
    such example is sum.dim_IntList:
        q (int8) -> dq (fp32) -> sum (fp32) -> q (int8) ...
    After folding it becomes:
        q (int8) -> sum (int64) -> ...
    This pass changes the types of the ops in self.targeted_ops, such as sum,
    so that the output type of each op matches the type of its output_qparams.
    """

    targeted_ops: Set[EdgeOpOverload] = {
        exir_ops.edge.aten.sum.dim_IntList,
    }

    def call_operator(
        self,
        op,  # pyre-ignore
        args: Tuple[Argument, ...],
        kwargs: Dict[str, Argument],
        meta: NodeMetadata,
    ) -> ProxyValue:
        if op not in self.targeted_ops:
            return super().call_operator(op, args, kwargs, meta)

        node_kwargs = kwargs.copy()
        output_qparams = meta["output_qparams"]
        if len(output_qparams) == 0:
            return super().call_operator(op, args, kwargs, meta)

        output_dtype = output_qparams[0].dtype
        node_kwargs["dtype"] = output_dtype
        return super().call_operator(op, args, node_kwargs, meta)
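The dtype behavior this pass corrects can be reproduced directly with the aten op; a small self-contained check, independent of the pass itself:

# Summing an int8 tensor promotes to int64 by default, which TOSA cannot
# handle; forcing the dtype kwarg keeps the output int8.
import torch

x = torch.ones(2, 3, dtype=torch.int8)
default_sum = torch.ops.aten.sum.dim_IntList(x, [1], False)
forced_sum = torch.ops.aten.sum.dim_IntList(x, [1], False, dtype=torch.int8)
assert default_sum.dtype == torch.int64  # PyTorch's default type promotion
assert forced_sum.dtype == torch.int8  # what RetraceFoldedDtypesPass enforces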