Skip to content

Commit

Permalink
2024-11-05 nightly release (8ab3385)
Browse files Browse the repository at this point in the history
  • Loading branch information
pytorchbot committed Nov 5, 2024
1 parent 429810c commit 5f47771
Show file tree
Hide file tree
Showing 28 changed files with 518 additions and 179 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
get_first_fake_tensor,
insert_q_dq_pair,
)
from executorch.backends.arm.tosa_quant_utils import dq_op, q_op
from executorch.backends.arm.tosa_quant_utils import dq_op, q_op, register_passable_op
from executorch.backends.arm.tosa_utils import is_consumer_node_depthwise_conv2d
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import ExportPass, PassResult
Expand Down Expand Up @@ -42,6 +42,9 @@ def _transpose_impl(*args, **kwargs):
return args[0]


register_passable_op(torch.ops.passthrough_to_tosa._transpose)


class AnnotateChannelsLastDimOrder(ExportPass):
"""
Annotates each node with a tosa_dim_order. tosa_dim_order can be seen as a channels-last dim-order
Expand Down
14 changes: 1 addition & 13 deletions backends/arm/_passes/insert_squeeze_after_sum_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,7 @@

import torch
import torch.fx
from executorch.backends.arm._passes.arm_pass_utils import create_node, insert_q_dq_pair

from executorch.backends.arm.tosa_quant_utils import get_quant_node_args, is_quant_node
from executorch.backends.arm._passes.arm_pass_utils import create_node
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import ExportPass, PassResult

Expand All @@ -28,8 +26,6 @@ class InsertSqueezeAfterSumPass(ExportPass):
sum(dims, keep_dim = False)
After pass:
sum(dims, keep_dim = True)
(q)
(dq)
squeeze(dim = dims)
"""

Expand All @@ -45,12 +41,6 @@ def call(self, graph_module: torch.fx.GraphModule):
continue

dim_list = cast(list[int], sum_node.args[1])
quantized = is_quant_node(sum_node)
if quantized:
qparams = get_quant_node_args(sum_node.all_input_nodes[0])
qparams = qparams + (torch.int8,)
else:
qparams = None

# Add keep_dim = True arg to sum node.
sum_node.args = sum_node.args[0:2] + (True,)
Expand All @@ -61,8 +51,6 @@ def call(self, graph_module: torch.fx.GraphModule):
)
sum_node.replace_all_uses_with(squeeze_node)
squeeze_node.args = (sum_node, dim_list)
if quantized:
sum_node = insert_q_dq_pair(graph_module.graph, sum_node, qparams)
graph_module.graph.eliminate_dead_code()
graph_module.recompile()
graph_module = super().call(graph_module).graph_module
Expand Down
4 changes: 2 additions & 2 deletions backends/arm/_passes/size_adjust_conv2d_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from typing import cast, Optional

import torch.fx
from executorch.backends.arm.tosa_quant_utils import is_quant_node
from executorch.backends.arm.tosa_quant_utils import is_node_quantized
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import ExportPass, PassResult
from torch._ops import OpOverload
Expand Down Expand Up @@ -113,7 +113,7 @@ def call(self, graph_module: torch.fx.GraphModule):
slice_node = graph.create_node(
"call_function", self.slice_op, (last_node,) + args
)
if is_quant_node(last_node):
if is_node_quantized(last_node):
q_params = last_node.args[1:]
dq_node = insert_q_dq_pair(
graph_module.graph, slice_node, q_params
Expand Down
38 changes: 14 additions & 24 deletions backends/arm/operators/op_addmm.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,13 @@
register_node_visitor,
)
from executorch.backends.arm.tosa_mapping import TosaArg
from executorch.backends.arm.tosa_quant_utils import build_rescale, get_quant_node_args
from executorch.backends.arm.tosa_quant_utils import (
build_rescale,
search_quant_arg_downstream,
search_quant_arg_upstream,
)

from executorch.backends.arm.tosa_utils import build_reshape
from executorch.exir.dialects._ops import ops as exir_ops
from serializer.tosa_serializer import TosaOp


Expand Down Expand Up @@ -67,12 +70,7 @@ def define_node(
input_zp = 0
if is_quant_node:
input_node = node.all_input_nodes[1]
# rank > 2 linear layer
if input_node.target == exir_ops.edge.aten.view_copy.default:
quant_node = input_node.all_input_nodes[0]
else:
quant_node = input_node
input_zp = get_quant_node_args(quant_node).zp
input_zp = search_quant_arg_upstream(input_node).zp
attr.ConvAttribute(
pad=pad_attr,
stride=stride_attr,
Expand Down Expand Up @@ -107,24 +105,16 @@ def define_node(
# Read inputs' parent nodes
_, input_node, weight_node = node.all_input_nodes

# rank > 2 linear layer
if input_node.target == exir_ops.edge.aten.view_copy.default:
quant_node = input_node.all_input_nodes[0]
input_scale = get_quant_node_args(quant_node).scale
consumer_node = list(node.users)[0]
consumer_consumer_node = list(consumer_node.users)[0]
quant_args = get_quant_node_args(consumer_consumer_node)
consumer_node_scale = quant_args.scale
consumer_node_node_zp = quant_args.zp
else:
input_scale = get_quant_node_args(input_node).scale
consumer_node = list(node.users)[0]
quant_args = get_quant_node_args(consumer_node)
consumer_node_scale = quant_args.scale
consumer_node_node_zp = quant_args.zp
qargs = search_quant_arg_upstream(input_node)
input_scale = qargs.scale
consumer_node = list(node.users)[0]
quant_args = search_quant_arg_downstream(consumer_node)

consumer_node_scale = quant_args.scale
consumer_node_node_zp = quant_args.zp

weight_node_q_node = weight_node.all_input_nodes[0]
weight_scale = get_quant_node_args(weight_node_q_node).scale
weight_scale = search_quant_arg_upstream(weight_node_q_node).scale

output_rescale_scale = (input_scale * weight_scale) / consumer_node_scale

Expand Down
16 changes: 10 additions & 6 deletions backends/arm/operators/op_bmm.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,11 @@
register_node_visitor,
)
from executorch.backends.arm.tosa_mapping import TosaArg
from executorch.backends.arm.tosa_quant_utils import build_rescale, get_quant_node_args
from executorch.backends.arm.tosa_quant_utils import (
build_rescale,
search_quant_arg_downstream,
search_quant_arg_upstream,
)
from executorch.backends.arm.tosa_utils import get_two_inputs
from serializer.tosa_serializer import TosaOp

Expand Down Expand Up @@ -42,8 +46,10 @@ def define_node(
# For INT8, we need to get the zero points and add an intermediate tensor
# for a later rescale.
if is_quant_node:
input0_zp = get_quant_node_args(input0).zp
input1_zp = get_quant_node_args(input1).zp
input0_q_params = search_quant_arg_upstream(input0)
input1_q_params = search_quant_arg_upstream(input1)
input0_zp = input0_q_params.zp
input1_zp = input1_q_params.zp
bmm_result = tosa_graph.addIntermediate(output.shape, ts.DType.INT32)
bmm_output_name = bmm_result.name
else:
Expand All @@ -63,9 +69,7 @@ def define_node(

# As INT8 accumulates into INT32, we need to rescale it back to INT8
if is_quant_node:
input0_q_params = get_quant_node_args(input0)
input1_q_params = get_quant_node_args(input1)
output_q_params = get_quant_node_args(list(node.users)[0])
output_q_params = search_quant_arg_downstream(list(node.users)[0])

final_output_scale = (
input0_q_params.scale * input1_q_params.scale
Expand Down
22 changes: 13 additions & 9 deletions backends/arm/operators/op_conv2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
# LICENSE file in the root directory of this source tree.

# pyre-unsafe
from typing import cast, List
from typing import List

import serializer.tosa_serializer as ts
import torch
Expand All @@ -15,9 +15,10 @@
from executorch.backends.arm.tosa_mapping import TosaArg
from executorch.backends.arm.tosa_quant_utils import (
build_rescale_conv_output,
get_quant_node_args,
search_quant_arg_downstream,
search_quant_arg_upstream,
)
from executorch.backends.arm.tosa_utils import build_reshape, getNodeArgs, tosa_shape
from executorch.backends.arm.tosa_utils import build_reshape, tosa_shape

from serializer.tosa_serializer import TosaOp

Expand Down Expand Up @@ -82,7 +83,9 @@ def define_node(
)

input_zp = (
get_quant_node_args(node.all_input_nodes[0]).zp if is_quant_node else 0
search_quant_arg_upstream(node.all_input_nodes[0]).zp
if is_quant_node
else 0
)

attr.ConvAttribute(
Expand Down Expand Up @@ -158,9 +161,10 @@ def define_node(
# integer value domain of the next op. Otherwise return float32 output.
if is_quant_node:
# Get scale_factor from input, weight, and output.
_, input_scale, _, _, _, _ = getNodeArgs(cast(torch.fx.Node, node.args[0]))
_, weight_scale, _, _, _, _ = getNodeArgs(cast(torch.fx.Node, node.args[1]))
_, output_scale, output_zp, _, _, _ = getNodeArgs(list(node.users)[0])
input_scale = search_quant_arg_upstream(node.all_input_nodes[0]).scale
weight_scale = search_quant_arg_upstream(node.all_input_nodes[1]).scale
output_qargs = search_quant_arg_downstream(list(node.users)[0])

build_rescale_conv_output(
tosa_graph,
# pyre-fixme[61]: Uninitialized local [61]: Local variable `conv2d_res` is undefined, or not always defined.
Expand All @@ -169,6 +173,6 @@ def define_node(
actual_out_type,
input_scale,
weight_scale,
output_scale,
output_zp,
output_qargs.scale,
output_qargs.zp,
)
7 changes: 4 additions & 3 deletions backends/arm/operators/op_exp.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,10 @@

from executorch.backends.arm.tosa_quant_utils import (
dequantize_value,
get_quant_node_args,
QuantArgs,
quantize_value,
search_quant_arg_downstream,
search_quant_arg_upstream,
)
from serializer.tosa_serializer import TosaOp
from torch.fx import Node
Expand Down Expand Up @@ -48,9 +49,9 @@ def define_node(

# Create attribute for 8 bit table lookup.
input_node = node.all_input_nodes[0]
in_quantargs = get_quant_node_args(input_node)
in_quantargs = search_quant_arg_upstream(input_node)
output_node = list(node.users)[0]
out_quantargs = get_quant_node_args(output_node)
out_quantargs = search_quant_arg_downstream(output_node)

table = exp_table_8bit(in_quantargs, out_quantargs)
table_attr = ts.TosaSerializerAttribute()
Expand Down
11 changes: 6 additions & 5 deletions backends/arm/operators/op_full.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,10 @@
register_node_visitor,
)
from executorch.backends.arm.tosa_mapping import TosaArg
from executorch.backends.arm.tosa_quant_utils import get_quant_node_args
from executorch.backends.arm.tosa_quant_utils import (
quantize_value,
search_quant_arg_downstream,
)
from executorch.backends.arm.tosa_utils import tosa_shape
from torch.fx import Node

Expand All @@ -39,10 +42,8 @@ def define_node(

value = inputs[1].number
if is_quant_node:
qargs = get_quant_node_args(list(node.users)[0])
qvalue = np.clip(
np.round(value / qargs.scale) + qargs.zp, qargs.qmin, qargs.qmax
)
qargs = search_quant_arg_downstream(list(node.users)[0])
qvalue = quantize_value(value, qargs)
dtype = ts.DType.INT8
data = np.full(shape, qvalue, dtype=np.int8)
else:
Expand Down
13 changes: 7 additions & 6 deletions backends/arm/operators/op_hardtanh.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,10 @@
)
from executorch.backends.arm.tosa_mapping import TosaArg

from executorch.backends.arm.tosa_quant_utils import get_quant_node_args
from executorch.backends.arm.tosa_quant_utils import (
quantize_value,
search_quant_arg_upstream,
)
from serializer.tosa_serializer import TosaOp


Expand All @@ -37,12 +40,10 @@ def define_node(

if is_quant_node:
# Get quant parameters
scale, zp, qmin, qmax = get_quant_node_args(node.all_input_nodes[0])
qargs = search_quant_arg_upstream(node.all_input_nodes[0])
# Convert to quantized representation
clamp_min_qs = round((inputs[1].number / scale) + zp)
clamp_min_qs = max(clamp_min_qs, qmin)
clamp_max_qs = round((inputs[2].number / scale) + zp)
clamp_max_qs = min(clamp_max_qs, qmax)
clamp_min_qs = quantize_value(inputs[1].number, qargs)
clamp_max_qs = quantize_value(inputs[2].number, qargs)
# Set fp values to 0.0 since they are not used
clamp_min_fp = 0.0
clamp_max_fp = 0.0
Expand Down
7 changes: 4 additions & 3 deletions backends/arm/operators/op_log.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,10 @@

from executorch.backends.arm.tosa_quant_utils import (
dequantize_value,
get_quant_node_args,
QuantArgs,
quantize_value,
search_quant_arg_downstream,
search_quant_arg_upstream,
)
from serializer.tosa_serializer import TosaOp
from torch.fx import Node
Expand Down Expand Up @@ -49,9 +50,9 @@ def define_node(

# Create attribute for 8 bit table lookup.
input_node = node.all_input_nodes[0]
in_quantargs = get_quant_node_args(input_node)
in_quantargs = search_quant_arg_upstream(input_node)
output_node = list(node.users)[0]
out_quantargs = get_quant_node_args(output_node)
out_quantargs = search_quant_arg_downstream(output_node)

table = log_table_8bit(in_quantargs, out_quantargs)
table_attr = ts.TosaSerializerAttribute()
Expand Down
16 changes: 10 additions & 6 deletions backends/arm/operators/op_mm.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,11 @@
register_node_visitor,
)
from executorch.backends.arm.tosa_mapping import TosaArg
from executorch.backends.arm.tosa_quant_utils import build_rescale, get_quant_node_args
from executorch.backends.arm.tosa_quant_utils import (
build_rescale,
search_quant_arg_downstream,
search_quant_arg_upstream,
)
from executorch.backends.arm.tosa_utils import (
build_reshape,
expand_dims,
Expand Down Expand Up @@ -54,8 +58,8 @@ def define_node(
# For INT8, we need to get the zero point, otherwise it is 0
input0_zp, input1_zp = 0, 0
if is_quant_node:
input0_zp = get_quant_node_args(input0).zp
input1_zp = get_quant_node_args(input1).zp
input0_zp = search_quant_arg_upstream(input0).zp
input1_zp = search_quant_arg_upstream(input1).zp

mat_mul_result = tosa_graph.addIntermediate(
output_new_shape, ts.DType.INT32 if is_quant_node else output.dtype
Expand Down Expand Up @@ -86,9 +90,9 @@ def define_node(

# As INT8 accumulates into INT32, we need to rescale it back to INT8
if is_quant_node:
input0_q_params = get_quant_node_args(input0)
input1_q_params = get_quant_node_args(input1)
output_q_params = get_quant_node_args(list(node.users)[0])
input0_q_params = search_quant_arg_upstream(input0)
input1_q_params = search_quant_arg_upstream(input1)
output_q_params = search_quant_arg_downstream(list(node.users)[0])

final_output_scale = (
input0_q_params.scale * input1_q_params.scale
Expand Down
4 changes: 2 additions & 2 deletions backends/arm/operators/op_mul.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,10 +37,10 @@ def define_node(
if is_quant_node:
input_A = inputs[0]
input_B = inputs[1]
input_A_qargs = tqutils.get_quant_node_args(
input_A_qargs = tqutils.search_quant_arg_upstream(
cast(torch.fx.Node, node.args[0])
)
input_B_qargs = tqutils.get_quant_node_args(
input_B_qargs = tqutils.search_quant_arg_upstream(
cast(torch.fx.Node, node.args[1])
)

Expand Down
Loading

0 comments on commit 5f47771

Please sign in to comment.