Skip to content

Commit

Permalink
Fix pyre issues
Browse files Browse the repository at this point in the history
Address issues from pyre and add similar # pyre-ignores as in
pytorch#7362.

Signed-off-by: Oscar Andersson <[email protected]>
Change-Id: I6feaa611dcd539b3b0d21a6a7dd696ef7db691ef
  • Loading branch information
oscarandersson8218 committed Dec 20, 2024
1 parent b8343a2 commit cdd6c91
Show file tree
Hide file tree
Showing 10 changed files with 51 additions and 28 deletions.
16 changes: 7 additions & 9 deletions backends/arm/_passes/annotate_decomposed_matmul.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,18 +5,14 @@
# LICENSE file in the root directory of this source tree.

import itertools
from typing import Any, Dict, List

import torch
from executorch.backends.arm._passes.arm_pass_utils import create_node
from executorch.backends.arm.tosa_quant_utils import dq_op, q_op
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import ExportPass, PassResult
from torch.fx import GraphModule
from torch.fx.passes.utils.source_matcher_utils import (
get_source_partitions,
SourcePartition,
)
from torch.fx.passes.utils.source_matcher_utils import get_source_partitions


class AnnotateDecomposedMatmulPass(ExportPass):
Expand All @@ -28,8 +24,8 @@ class AnnotateDecomposedMatmulPass(ExportPass):
matmul-op (can be mm or bmm).
"""

def call(self, graph_module: GraphModule):
matmul_partitions: Dict[Any, List[SourcePartition]] = get_source_partitions(
def call(self, graph_module: GraphModule) -> PassResult:
matmul_partitions = get_source_partitions(
graph_module.graph,
[
torch.matmul,
Expand All @@ -56,7 +52,7 @@ def call(self, graph_module: GraphModule):
input_node = partition.input_nodes[i]
matmul_input_node = matmul_args[i]
# Remove partition input dq-node
input_node.replace_all_uses_with(input_node.args[0])
input_node.replace_all_uses_with(input_node.all_input_nodes[0])
graph_module.graph.erase_node(input_node)
input_node_qargs = input_node.args[1:]
with graph_module.graph.inserting_before(matmul_node):
Expand All @@ -81,7 +77,9 @@ def call(self, graph_module: GraphModule):
matmul_node.replace_all_uses_with(q_node)
q_node.args = (matmul_node, *output_node_qargs)
# Remove partition output q-node
partition_output.replace_all_uses_with(partition_output.args[0])
partition_output.replace_all_uses_with(
partition_output.all_input_nodes[0]
)
graph_module.graph.erase_node(partition_output)

# retrace the graph to update the fake tensor types
Expand Down
22 changes: 17 additions & 5 deletions backends/arm/_passes/fold_qdq_with_annotated_qparams_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,20 @@

import copy

from typing import cast, Iterable
from typing import cast, Dict, Iterable, Set, Tuple

from executorch.backends.arm.tosa_quant_utils import QuantArgs

from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.dialects.edge._ops import EdgeOpOverload

from executorch.exir.pass_base import ExportPass, PassResult
from executorch.exir.pass_base import (
Argument,
ExportPass,
NodeMetadata,
PassResult,
ProxyValue,
)
from torch.fx import GraphModule, Node

q_op: EdgeOpOverload = exir_ops.edge.quantized_decomposed.quantize_per_tensor.default
Expand Down Expand Up @@ -82,7 +88,7 @@ def __init__(self, targeted_ops: Iterable[EdgeOpOverload]) -> None:

def fold_and_annotate_arg(
self, graph_module: GraphModule, node: Node, arg_list: list[Node], i: int
):
) -> None:
input_qparams = None
nodes_to_remove = set()
for arg in arg_list:
Expand Down Expand Up @@ -210,11 +216,17 @@ class RetraceFoldedDtypesPass(ExportPass):
the output type of that matches the type of the output_qparams.
"""

targeted_ops = {
targeted_ops: Set[EdgeOpOverload] = {
exir_ops.edge.aten.sum.dim_IntList,
}

def call_operator(self, op, args, kwargs, meta):
def call_operator(
self,
op, # pyre-ignore
args: Tuple[Argument, ...],
kwargs: Dict[str, Argument],
meta: NodeMetadata,
) -> ProxyValue:
if op not in self.targeted_ops:
return super().call_operator(op, args, kwargs, meta)

Expand Down
9 changes: 5 additions & 4 deletions backends/arm/_passes/insert_table_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,15 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Callable
from typing import Callable, Dict

import torch
from executorch.backends.arm._passes.arm_pass_utils import create_node
from executorch.backends.arm.tosa_quant_utils import QuantArgs
from executorch.exir import ExportedProgram

from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.dialects.edge._ops import EdgeOpOverload

from executorch.exir.pass_base import ExportPass, PassResult
from torch.fx import GraphModule
Expand All @@ -22,7 +23,7 @@


@impl(lib, "_table")
def _table_impl(*args, **kwargs):
def _table_impl(*args, **kwargs): # pyre-ignore
return args[0]


Expand All @@ -34,7 +35,7 @@ class InsertTableOpsPass(ExportPass):
which will be used to produce the table values in operators/op_table.py.
"""

table_ops = {
table_ops: Dict[EdgeOpOverload, Callable[[torch.Tensor], torch.Tensor]] = {
exir_ops.edge.aten.exp.default: torch.exp,
exir_ops.edge.aten.log.default: torch.log,
exir_ops.edge.aten.reciprocal.default: torch.reciprocal,
Expand All @@ -43,7 +44,7 @@ class InsertTableOpsPass(ExportPass):
exir_ops.edge.aten.tanh.default: torch.tanh,
}

def __init__(self, exported_program: ExportedProgram):
def __init__(self, exported_program: ExportedProgram) -> None:
super().__init__()
self.exported_program = exported_program

Expand Down
8 changes: 5 additions & 3 deletions backends/arm/operators/op_bmm.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@

import serializer.tosa_serializer as ts
import torch

# pyre-fixme[21]: 'Could not find a module corresponding to import `executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass`.'
from executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass import (
get_input_qparams,
get_output_qparams,
Expand Down Expand Up @@ -49,7 +51,7 @@ def define_node(
# for a later rescale.

if inputs[0].dtype == ts.DType.INT8:
input_qparams = get_input_qparams(node)
input_qparams = get_input_qparams(node) # pyre-ingore[16]
input0_zp = input_qparams[0].zp
input1_zp = input_qparams[1].zp
bmm_result = tosa_graph.addIntermediate(output.shape, ts.DType.INT32)
Expand All @@ -71,9 +73,9 @@ def define_node(

# As INT8 accumulates into INT32, we need to rescale it back to INT8
if output.dtype == ts.DType.INT8:
output_qparams = get_output_qparams(node)[0]
output_qparams = get_output_qparams(node)[0] # pyre-ignore[16]
final_output_scale = (
input_qparams[0].scale * input_qparams[1].scale
input_qparams[0].scale * input_qparams[1].scale # pyre-ignore[61]
) / output_qparams.scale

build_rescale(
Expand Down
4 changes: 3 additions & 1 deletion backends/arm/operators/op_hardtanh.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@

import serializer.tosa_serializer as ts
import torch

# pyre-fixme[21]: 'Could not find a module corresponding to import `executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass`.'
from executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass import (
get_input_qparams,
)
Expand Down Expand Up @@ -39,7 +41,7 @@ def define_node(

if inputs[0].dtype == ts.DType.INT8:
# Get quant parameters
input_qparams = get_input_qparams(node)
input_qparams = get_input_qparams(node) # pyre-ignore[16]
qargs = input_qparams[0]
# Convert to quantized representation
clamp_min_qs = quantize_value(inputs[1].number, qargs)
Expand Down
6 changes: 4 additions & 2 deletions backends/arm/operators/op_max_pool2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@

import serializer.tosa_serializer as ts
import torch

# pyre-fixme[21]: 'Could not find a module corresponding to import `executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass`.'
from executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass import (
get_input_qparams,
get_output_qparams,
Expand Down Expand Up @@ -49,12 +51,12 @@ def define_node(
# Initilize zero point to zero.
input_zp = 0
if inputs[0].dtype == ts.DType.INT8:
input_qparams = get_input_qparams(node)
input_qparams = get_input_qparams(node) # pyre-ignore[16]
input_zp = input_qparams[0].zp

output_zp = 0
if output.dtype == ts.DType.INT8:
output_qparams = get_output_qparams(node)
output_qparams = get_output_qparams(node) # pyre-ignore[16]
output_zp = output_qparams[0].zp

attr = ts.TosaSerializerAttribute()
Expand Down
6 changes: 4 additions & 2 deletions backends/arm/operators/op_mm.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@

import serializer.tosa_serializer as ts
import torch

# pyre-fixme[21]: 'Could not find a module corresponding to import `executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass`.'
from executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass import (
get_input_qparams,
get_output_qparams,
Expand Down Expand Up @@ -52,7 +54,7 @@ def define_node(
# The output also needs to be rank 3
output_new_shape = (1, output.shape[0], output.shape[1])

input_qparams = get_input_qparams(node)
input_qparams = get_input_qparams(node) # pyre-ignore[16]
assert len(input_qparams) == 2
input0_qparams = input_qparams[0]
input1_qparams = input_qparams[1]
Expand All @@ -78,7 +80,7 @@ def define_node(
)

# As INT8 accumulates into INT32, we need to rescale it back to INT8
output_qparams = get_output_qparams(node)
output_qparams = get_output_qparams(node) # pyre-ignore[16]
assert len(output_qparams) == 1

final_output_scale = (
Expand Down
4 changes: 3 additions & 1 deletion backends/arm/operators/op_mul.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@

import serializer.tosa_serializer as ts
import torch

# pyre-fixme[21]: 'Could not find a module corresponding to import `executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass`.'
from executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass import (
get_input_qparams,
)
Expand Down Expand Up @@ -43,7 +45,7 @@ def define_node(
assert inputs[0].dtype == inputs[1].dtype == output.dtype == ts.DType.INT8
input_A = inputs[0]
input_B = inputs[1]
input_qparams = get_input_qparams(node)
input_qparams = get_input_qparams(node) # pyre-ignore[16]
input_A_qargs = input_qparams[0]
input_B_qargs = input_qparams[1]
input_A.shape = tutils.tosa_shape(input_A.shape, input_A.dim_order)
Expand Down
2 changes: 2 additions & 0 deletions backends/arm/operators/op_relu.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
import executorch.backends.arm.tosa_quant_utils as tqutils
import serializer.tosa_serializer as ts
import torch.fx

# pyre-fixme[21]: 'Could not find a module corresponding to import `executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass`.'
from executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass import (
get_output_qparams,
)
Expand Down
2 changes: 1 addition & 1 deletion backends/arm/tosa_quant_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@ def quantize_value(self, x):
self.qmax,
).to(self.dtype)

def dequantize_value(self, qx: int) -> float:
def dequantize_value(self, qx: torch.Tensor) -> torch.Tensor:
return (qx - self.zp) * self.scale

def __eq__(self, other):
Expand Down

0 comments on commit cdd6c91

Please sign in to comment.