Commit 400150b (1 parent: 462751c)
2024-10-16 nightly release (8673567)
pytorchbot committed Oct 16, 2024
Showing 76 changed files with 1,901 additions and 604 deletions.
@@ -0,0 +1,69 @@
# Copyright 2024 Arm Limited and/or its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import cast

import torch
import torch.fx
from executorch.backends.arm._passes.arm_pass_utils import create_node, insert_q_dq_pair

from executorch.backends.arm.tosa_quant_utils import get_quant_node_args, is_quant_node
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import ExportPass, PassResult


class InsertSqueezeAfterSumPass(ExportPass):
    """
    In PyTorch, the default behaviour of Tensor.sum is to squeeze
    the dimensions that are summed (keep_dim = False).
    In TOSA, however, REDUCE_SUM always preserves the
    rank of the input (keep_dim = True).
    To get a 1-1 mapping in the sum lowering, normalize the
    keep_dim = False case to keep_dim = True and add squeeze ops.

    Original:
        sum(dims, keep_dim = False)
    After pass:
        sum(dims, keep_dim = True)
        (q)
        (dq)
        squeeze(dim = dims)
    """

    def call(self, graph_module: torch.fx.GraphModule):
        for node in graph_module.graph.nodes:
            if node.op != "call_function":
                continue
            if node.target != exir_ops.edge.aten.sum.dim_IntList:
                continue
            sum_node = cast(torch.fx.Node, node)
            keep_dim = cast(bool, sum_node.args[2] if len(sum_node.args) > 2 else False)
            if keep_dim:
                continue

            dim_list = cast(list[int], sum_node.args[1])
            quantized = is_quant_node(sum_node)
            if quantized:
                qparams = get_quant_node_args(sum_node.all_input_nodes[0])
                qparams = qparams + (torch.int8,)
            else:
                qparams = None

            # Add keep_dim = True arg to sum node.
            sum_node.args = sum_node.args[0:2] + (True,)

            with graph_module.graph.inserting_after(sum_node):
                squeeze_node = create_node(
                    graph_module.graph, exir_ops.edge.aten.squeeze_copy.dims, ()
                )
                sum_node.replace_all_uses_with(squeeze_node)
                squeeze_node.args = (sum_node, dim_list)
                if quantized:
                    sum_node = insert_q_dq_pair(graph_module.graph, sum_node, qparams)
        graph_module.graph.eliminate_dead_code()
        graph_module.recompile()
        graph_module = super().call(graph_module).graph_module
        return PassResult(graph_module, True)
@@ -34,6 +34,7 @@
    op_softmax,
    op_squeeze,
    op_sub,
    op_sum,
    op_unsqueeze,
    op_view,
)
@@ -0,0 +1,96 @@
# Copyright 2023-2024 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import cast, List

import executorch.backends.arm.tosa_quant_utils as tqutils
import executorch.backends.arm.tosa_utils as tutils

import serializer.tosa_serializer as ts
from executorch.backends.arm.operators.node_visitor import (
    NodeVisitor,
    register_node_visitor,
)
from executorch.backends.arm.tosa_mapping import TosaArg
from serializer.tosa_serializer import TosaOp
from torch.fx import Node


@register_node_visitor
class AddVisitor(NodeVisitor):
    target = "aten.sum.dim_IntList"

    def __init__(self, *args):
        super().__init__(*args)

    def define_node(
        self,
        node: Node,
        tosa_graph: ts.TosaSerializer,
        inputs: List[TosaArg],
        output: TosaArg,
        is_quant_node: bool,
    ) -> None:
        input_node = inputs[0]
        input_shape = list(input_node.shape)
        dim_list = cast(list[int], inputs[1].special)
        dim_list = [dim % len(input_node.shape) for dim in dim_list]
        keep_dim = cast(bool, inputs[2].number if len(inputs) > 2 else False)
        assert keep_dim, "This case should be handled by InsertSqueezeAfterSumPass"

        if is_quant_node:

            # Rescale input to 32 bit
            rescaled_inputs, scale = tqutils.rescale_nodes_to_int32(
                [node.all_input_nodes[0]], tosa_graph
            )

            prev_node = rescaled_inputs[0]
            reduced_shape = input_shape

            # Reduce all dims in dim_list one-by-one.
            for dim in dim_list:
                # When reduced, the size of the dim becomes 1.
                reduced_shape[dim] = 1

                attr = ts.TosaSerializerAttribute()
                attr.AxisAttribute(input_node.dim_order.index(dim))

                next_node = tosa_graph.addIntermediate(
                    tutils.tosa_shape(reduced_shape, input_node.dim_order),
                    dtype=ts.DType.INT32,
                )

                tosa_graph.addOperator(
                    TosaOp.Op().REDUCE_SUM, [prev_node.name], [next_node.name], attr
                )

                prev_node = next_node
            tqutils.rescale_node_back_to_int8(node, prev_node, scale, tosa_graph)
        else:
            input_name = input_node.name
            reduced_shape = input_shape

            # Reduce all dims in dim_list one-by-one.
            for dim in dim_list:
                # When reduced, the size of the dim becomes 1.
                reduced_shape[dim] = 1

                attr = ts.TosaSerializerAttribute()
                attr.AxisAttribute(input_node.dim_order.index(dim))

                if dim == dim_list[-1]:
                    output_name = output.name
                else:
                    output_name = tosa_graph.addIntermediate(
                        tutils.tosa_shape(reduced_shape, input_node.dim_order),
                        dtype=ts.DType.FP32,
                    ).name

                tosa_graph.addOperator(
                    TosaOp.Op().REDUCE_SUM, [input_name], [output_name], attr
                )

                input_name = output_name
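Because TOSA's REDUCE_SUM reduces a single axis at a time, the visitor emits one REDUCE_SUM per entry in dim_list and threads intermediates between them. A small sketch in plain PyTorch (hypothetical shapes, not part of the commit) of why that chaining matches a multi-dim sum with keepdim=True:

import torch

x = torch.randn(2, 3, 4, 5)
dim_list = [1, 3]

# Chain single-axis reductions, mirroring one REDUCE_SUM per dim.
step = x
for dim in dim_list:
    step = step.sum(dim=dim, keepdim=True)  # the reduced dim's size becomes 1

# The chained result matches a single multi-dim sum with keepdim=True.
assert torch.allclose(step, x.sum(dim=dim_list, keepdim=True))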
backends/arm/quantizer/quantization_annotation/sigmoid_annotator.py
56 changes: 0 additions & 56 deletions — this file was deleted.
backends/arm/quantizer/quantization_annotation/sum_annotator.py
57 changes: 57 additions & 0 deletions
@@ -0,0 +1,57 @@
# Copyright 2024 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Callable, cast, List, Optional

import torch
from executorch.backends.arm.quantizer import arm_quantizer_utils
from executorch.backends.arm.quantizer.quantization_annotation import register_annotator
from executorch.backends.arm.quantizer.quantization_config import QuantizationConfig

from torch.ao.quantization.quantizer import (
    QuantizationAnnotation,
    QuantizationSpecBase,
    SharedQuantizationSpec,
)
from torch.fx import Node


@register_annotator("sum")
def _annotate_sum(
    gm: torch.fx.GraphModule,
    quantization_config: QuantizationConfig,
    filter_fn: Optional[Callable[[Node], bool]] = None,
) -> Optional[List[List[Node]]]:
    annotated_partitions = []
    for node in gm.graph.nodes:
        if node.target is not torch.ops.aten.sum.dim_IntList:
            continue
        if filter_fn and not filter_fn(node):
            continue

        sum_node = node
        if arm_quantizer_utils.is_annotated(sum_node):
            continue

        input_act = sum_node.args[0]

        if not isinstance(input_act, Node):
            continue
        if not arm_quantizer_utils.is_input_ok_for_quantization(input_act, gm):
            continue

        input_act_qspec = cast(
            Optional[QuantizationSpecBase], quantization_config.get_input_act_qspec()
        )
        input_qspec_map = {input_act: input_act_qspec}
        shared_with_input0_qspec = SharedQuantizationSpec((input_act, sum_node))

        sum_node.meta["quantization_annotation"] = QuantizationAnnotation(
            input_qspec_map=input_qspec_map,
            output_qspec=shared_with_input0_qspec,
            _annotated=True,
        )
        annotated_partitions.append([sum_node])
    return annotated_partitions
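As a rough illustration of which nodes the annotator matches, a Tensor.sum over an explicit dim list typically shows up as torch.ops.aten.sum.dim_IntList in an exported graph. A hedged sketch (uses torch.export for brevity; the real quantization flow captures the graph differently, and decompositions can vary across PyTorch versions):

import torch


class SumModule(torch.nn.Module):
    def forward(self, x):
        return x.sum(dim=[1], keepdim=False)


# Export to an FX graph and list the nodes _annotate_sum would consider.
gm = torch.export.export(SumModule(), (torch.randn(2, 3),)).module()
sum_nodes = [
    n for n in gm.graph.nodes if n.target is torch.ops.aten.sum.dim_IntList
]
print(sum_nodes)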