From d6d19ce51ad6c01750bd072746f66f50686f65c0 Mon Sep 17 00:00:00 2001
From: Digant Desai
Date: Wed, 22 Jan 2025 13:38:52 -0800
Subject: [PATCH] Move XNNPACKQuantizer from PyTorch to ExecuTorch (#7804)

Summary:
Pull Request resolved: https://github.com/pytorch/executorch/pull/7804
X-link: https://github.com/pytorch/pytorch/pull/144940

This migrates XNNPACKQuantizer from PyTorch to ExecuTorch.

Rationale: The main motivation is to avoid a PyTorch pin update in OSS every time XNNPACKQuantizer changes, which can be rather frequent.

Other impact and considerations:
- The PT2e flow (which lives in PyTorch) relies heavily on XNNPACKQuantizer as an "example" quantizer implementation and, more importantly, for tests. For now, we will keep torch.ao.quantization.xnnpack_quantizer as is, but mark it as non-BC and deprecated to discourage new dependencies on it.
- Other OSS repositories using XNNPACKQuantizer from PyTorch now have to take an additional dependency on ExecuTorch.

Reviewed By: mcr229

Differential Revision: D68191752
---
 backends/cadence/aot/TARGETS | 1 +
 backends/cadence/aot/export_example.py | 8 +-
 backends/cadence/aot/quantizer/TARGETS | 1 +
 backends/cadence/aot/quantizer/quantizer.py | 12 +-
 backends/example/TARGETS | 1 +
 backends/example/example_quantizer.py | 2 +-
 backends/transforms/targets.bzl | 1 +
 .../test_duplicate_dynamic_quant_chain.py | 4 +-
 backends/vulkan/quantizer/TARGETS | 1 +
 backends/vulkan/quantizer/vulkan_quantizer.py | 8 +-
 backends/xnnpack/quantizer/TARGETS | 20 +
 .../xnnpack/quantizer/xnnpack_quantizer.py | 436 +++++++
 .../quantizer/xnnpack_quantizer_utils.py | 1127 +++++++++++++++++
 backends/xnnpack/test/TARGETS | 17 +
 backends/xnnpack/test/ops/test_conv2d.py | 10 +-
 backends/xnnpack/test/ops/test_linear.py | 12 +-
 .../test/quantizer/test_pt2e_quantization.py | 824 ++++++++++++
 .../test/quantizer/test_representation.py | 311 +++++
 .../test/quantizer/test_xnnpack_quantizer.py | 1090 ++++++++++++++++
 backends/xnnpack/test/test_xnnpack_utils.py | 8 +-
 backends/xnnpack/test/tester/TARGETS | 2 +
 backends/xnnpack/test/tester/tester.py | 12 +-
 docs/source/llm/getting-started.md | 2 +-
 ...e-delegates-executorch-xnnpack-delegate.md | 2 +-
 .../tutorial-xnnpack-delegate-lowering.md | 2 +-
 .../export-to-executorch-tutorial.py | 4 +-
 examples/models/llava/export_llava.py | 8 +-
 .../models/phi-3-mini/export_phi-3-mini.py | 8 +-
 examples/xnnpack/quantization/example.py | 8 +-
 examples/xnnpack/quantization/utils.py | 5 +-
 exir/tests/TARGETS | 4 +
 exir/tests/test_passes.py | 12 +-
 exir/tests/test_quantization.py | 10 +-
 exir/tests/test_quantize_io_pass.py | 10 +-
 extension/llm/export/TARGETS | 1 +
 extension/llm/export/quantizer_lib.py | 8 +-
 36 files changed, 3919 insertions(+), 73 deletions(-)
 create mode 100644 backends/xnnpack/quantizer/TARGETS
 create mode 100644 backends/xnnpack/quantizer/xnnpack_quantizer.py
 create mode 100644 backends/xnnpack/quantizer/xnnpack_quantizer_utils.py
 create mode 100644 backends/xnnpack/test/quantizer/test_pt2e_quantization.py
 create mode 100644 backends/xnnpack/test/quantizer/test_representation.py
 create mode 100644 backends/xnnpack/test/quantizer/test_xnnpack_quantizer.py

diff --git a/backends/cadence/aot/TARGETS b/backends/cadence/aot/TARGETS index b1484855d6..0590e69460 100644 --- a/backends/cadence/aot/TARGETS +++ b/backends/cadence/aot/TARGETS @@ -65,6 +65,7 @@ python_library( "//executorch/backends/cadence/aot/quantizer:fusion_pass", "//executorch/backends/cadence/runtime:runtime", "//executorch/backends/cadence/aot/quantizer:quantizer", +
"//executorch/backends/xnnpack/quantizer:xnnpack_quantizer", "//executorch/backends/transforms:decompose_sdpa", "//executorch/backends/transforms:remove_clone_ops", "//executorch/exir:lib", diff --git a/backends/cadence/aot/export_example.py b/backends/cadence/aot/export_example.py index 28a1a60a2a..0345aa6e2e 100644 --- a/backends/cadence/aot/export_example.py +++ b/backends/cadence/aot/export_example.py @@ -23,13 +23,13 @@ from executorch.backends.cadence.aot.quantizer.quantizer import CadenceDefaultQuantizer from executorch.backends.cadence.runtime import runtime from executorch.backends.cadence.runtime.executor import BundledProgramManager -from executorch.exir import ExecutorchProgramManager -from torch import nn -from torch.ao.quantization.observer import HistogramObserver, MinMaxObserver -from torch.ao.quantization.quantizer.xnnpack_quantizer_utils import ( +from executorch.backends.xnnpack.quantizer.xnnpack_quantizer_utils import ( QuantizationConfig, QuantizationSpec, ) +from executorch.exir import ExecutorchProgramManager +from torch import nn +from torch.ao.quantization.observer import HistogramObserver, MinMaxObserver from .utils import save_bpte_program, save_pte_program diff --git a/backends/cadence/aot/quantizer/TARGETS b/backends/cadence/aot/quantizer/TARGETS index 6290626216..75eab631dd 100644 --- a/backends/cadence/aot/quantizer/TARGETS +++ b/backends/cadence/aot/quantizer/TARGETS @@ -34,6 +34,7 @@ python_library( ":patterns", ":utils", "//caffe2:torch", + "//executorch/backends/xnnpack/quantizer:xnnpack_quantizer_utils", ], ) diff --git a/backends/cadence/aot/quantizer/quantizer.py b/backends/cadence/aot/quantizer/quantizer.py index 65979919ed..e47ae15e8d 100644 --- a/backends/cadence/aot/quantizer/quantizer.py +++ b/backends/cadence/aot/quantizer/quantizer.py @@ -26,18 +26,18 @@ is_annotated, no_outside_users, ) +from executorch.backends.xnnpack.quantizer.xnnpack_quantizer_utils import ( + OperatorConfig, + QuantizationAnnotation, + QuantizationConfig, + QuantizationSpec, +) from torch import fx from torch.ao.quantization.observer import HistogramObserver, MinMaxObserver from torch.ao.quantization.quantizer import DerivedQuantizationSpec, Quantizer from torch.ao.quantization.quantizer.composable_quantizer import ComposableQuantizer -from torch.ao.quantization.quantizer.xnnpack_quantizer_utils import ( - OperatorConfig, - QuantizationAnnotation, - QuantizationConfig, - QuantizationSpec, -) act_qspec = QuantizationSpec( diff --git a/backends/example/TARGETS b/backends/example/TARGETS index 59df492e02..e99a408cdb 100644 --- a/backends/example/TARGETS +++ b/backends/example/TARGETS @@ -11,6 +11,7 @@ python_library( deps = [ "//caffe2:torch", "//executorch/backends/example/example_operators:example_operators_lib", + "//executorch/backends/xnnpack/quantizer:xnnpack_quantizer", ], ) diff --git a/backends/example/example_quantizer.py b/backends/example/example_quantizer.py index 7034b9792f..74a0057ba4 100644 --- a/backends/example/example_quantizer.py +++ b/backends/example/example_quantizer.py @@ -9,11 +9,11 @@ import torch from executorch.backends.example.example_operators.ops import module_to_annotator +from executorch.backends.xnnpack.quantizer.xnnpack_quantizer_utils import OperatorConfig from torch import fx from torch.ao.quantization.observer import HistogramObserver, MinMaxObserver from torch.ao.quantization.pt2e.graph_utils import find_sequential_partitions from torch.ao.quantization.quantizer import QuantizationSpec, Quantizer -from 
torch.ao.quantization.quantizer.xnnpack_quantizer_utils import OperatorConfig def get_uint8_tensor_spec(observer_or_fake_quant_ctr): diff --git a/backends/transforms/targets.bzl b/backends/transforms/targets.bzl index 14725636f3..09ef0f59c5 100644 --- a/backends/transforms/targets.bzl +++ b/backends/transforms/targets.bzl @@ -195,6 +195,7 @@ def define_common_targets(): deps = [ "fbsource//third-party/pypi/expecttest:expecttest", # @manual ":duplicate_dynamic_quant_chain", + "//executorch/backends/xnnpack/quantizer:xnnpack_quantizer", "//caffe2:torch", "//executorch/exir:lib", ], diff --git a/backends/transforms/test/test_duplicate_dynamic_quant_chain.py b/backends/transforms/test/test_duplicate_dynamic_quant_chain.py index 637ce807c1..91d2ddc916 100644 --- a/backends/transforms/test/test_duplicate_dynamic_quant_chain.py +++ b/backends/transforms/test/test_duplicate_dynamic_quant_chain.py @@ -11,11 +11,11 @@ from executorch.backends.transforms.duplicate_dynamic_quant_chain import ( DuplicateDynamicQuantChainPass, ) -from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e -from torch.ao.quantization.quantizer.xnnpack_quantizer import ( +from executorch.backends.xnnpack.quantizer.xnnpack_quantizer import ( get_symmetric_quantization_config, XNNPACKQuantizer, ) +from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e # TODO: Move away from using torch's internal testing utils from torch.testing._internal.common_quantization import ( diff --git a/backends/vulkan/quantizer/TARGETS b/backends/vulkan/quantizer/TARGETS index 7cc5b79eb2..5650f2bd72 100644 --- a/backends/vulkan/quantizer/TARGETS +++ b/backends/vulkan/quantizer/TARGETS @@ -9,5 +9,6 @@ python_library( ], deps = [ "//caffe2:torch", + "//executorch/backends/xnnpack/quantizer:xnnpack_quantizer_utils", ], ) diff --git a/backends/vulkan/quantizer/vulkan_quantizer.py b/backends/vulkan/quantizer/vulkan_quantizer.py index 451f18977e..2ea3e321dc 100644 --- a/backends/vulkan/quantizer/vulkan_quantizer.py +++ b/backends/vulkan/quantizer/vulkan_quantizer.py @@ -12,15 +12,15 @@ from typing import Any, Callable, Dict, Optional import torch -from torch.ao.quantization.observer import MinMaxObserver, PerChannelMinMaxObserver -from torch.ao.quantization.qconfig import _ObserverOrFakeQuantizeConstructor -from torch.ao.quantization.quantizer import QuantizationSpec, Quantizer -from torch.ao.quantization.quantizer.xnnpack_quantizer_utils import ( +from executorch.backends.xnnpack.quantizer.xnnpack_quantizer_utils import ( _convert_scalars_to_attrs, OP_TO_ANNOTATOR, propagate_annotation, QuantizationConfig, ) +from torch.ao.quantization.observer import MinMaxObserver, PerChannelMinMaxObserver +from torch.ao.quantization.qconfig import _ObserverOrFakeQuantizeConstructor +from torch.ao.quantization.quantizer import QuantizationSpec, Quantizer from torch.fx import Node diff --git a/backends/xnnpack/quantizer/TARGETS b/backends/xnnpack/quantizer/TARGETS new file mode 100644 index 0000000000..c14998d0e3 --- /dev/null +++ b/backends/xnnpack/quantizer/TARGETS @@ -0,0 +1,20 @@ +load("@fbcode_macros//build_defs:python_library.bzl", "python_library") + +python_library( + name = "xnnpack_quantizer", + srcs = ["xnnpack_quantizer.py"], + deps = [ + ":xnnpack_quantizer_utils", + "//caffe2:torch", + "//executorch/exir:lib", + ], +) + +python_library( + name = "xnnpack_quantizer_utils", + srcs = ["xnnpack_quantizer_utils.py"], + deps = [ + "//caffe2:torch", + "//executorch/exir:lib", + ], +) diff --git 
a/backends/xnnpack/quantizer/xnnpack_quantizer.py b/backends/xnnpack/quantizer/xnnpack_quantizer.py new file mode 100644 index 0000000000..04a02c8fec --- /dev/null +++ b/backends/xnnpack/quantizer/xnnpack_quantizer.py @@ -0,0 +1,436 @@ +# mypy: allow-untyped-defs +from __future__ import annotations + +import copy +import functools +from typing import Any, Callable, Optional, TYPE_CHECKING + +import torch +import torch._dynamo as torchdynamo +import torch.nn.functional as F +from executorch.backends.xnnpack.quantizer.xnnpack_quantizer_utils import ( + _convert_scalars_to_attrs, + OP_TO_ANNOTATOR, + OperatorConfig, + OperatorPatternType, + propagate_annotation, + QuantizationConfig, +) +from torch.ao.quantization.fake_quantize import ( + FakeQuantize, + FusedMovingAvgObsFakeQuantize, +) +from torch.ao.quantization.observer import ( + HistogramObserver, + MinMaxObserver, + MovingAverageMinMaxObserver, + MovingAveragePerChannelMinMaxObserver, + PerChannelMinMaxObserver, + PlaceholderObserver, +) +from torch.ao.quantization.quantizer import QuantizationSpec, Quantizer +from torch.ao.quantization.quantizer.utils import _get_module_name_filter + + +if TYPE_CHECKING: + from torch.ao.quantization.qconfig import _ObserverOrFakeQuantizeConstructor + from torch.fx import Node + + +__all__ = [ + "XNNPACKQuantizer", + "get_symmetric_quantization_config", +] + + +def _get_dynamo_graph(function: Callable, inputs) -> torch.fx.Graph: + gm, _ = torchdynamo.export(function, aten_graph=True)(*inputs) + gm.graph.eliminate_dead_code() + return gm.graph + + +def _get_linear_patterns(input_size: list[int]): + in_channels = input_size[-1] + out_channels = 8 # hard coding but this should not matter + weight = torch.ones((out_channels, in_channels)) + bias = torch.ones((out_channels,)) + act = torch.ones(input_size) + + def linear_op(act, weight, bias=None): + return F.linear(act, weight, bias) + + pattern_w_bias = _get_dynamo_graph(linear_op, (act, weight, bias)) + pattern_wo_bias = _get_dynamo_graph(linear_op, (act, weight)) + return [pattern_w_bias, pattern_wo_bias] + + +def _supported_symmetric_quantized_operators() -> dict[str, list[OperatorPatternType]]: + supported_operators: dict[str, list[OperatorPatternType]] = { + # Both conv and linear should be able to handle relu + hardtanh fusion since + # those are clamp ops + "conv2d": [ + [torch.nn.Conv2d, torch.nn.ReLU], + [torch.nn.Conv2d, F.relu], + [F.conv2d, torch.nn.ReLU], + [F.conv2d, F.relu], + ], + "linear": [[torch.nn.Linear], [F.linear]], + "add": [[torch.add]], + "adaptive_avg_pool2d": [ + [torch.nn.AdaptiveAvgPool2d], + [F.adaptive_avg_pool2d], + ], + } + return copy.deepcopy(supported_operators) + + +def _get_supported_symmetric_config_and_operators() -> list[OperatorConfig]: + supported_config_and_operators: list[OperatorConfig] = [] + for quantization_config in [ + get_symmetric_quantization_config(), + get_symmetric_quantization_config(is_qat=True), + get_symmetric_quantization_config(is_per_channel=True), + get_symmetric_quantization_config(is_per_channel=True, is_qat=True), + ]: + ops = _supported_symmetric_quantized_operators() + supported_config_and_operators.extend( + OperatorConfig(quantization_config, pattern_list) + for pattern_list in ops.values() + ) + return copy.deepcopy(supported_config_and_operators) + + +@functools.lru_cache +def get_symmetric_quantization_config( + is_per_channel: bool = False, + is_qat: bool = False, + is_dynamic: bool = False, + act_qmin: int = -128, + act_qmax: int = 127, + weight_qmin: int = -127, + weight_qmax: 
int = 127, +): + extra_args: dict[str, Any] = {"eps": 2**-12} + if is_qat: + if is_dynamic: + act_observer_or_fake_quant_ctr = FakeQuantize + dynamic_quant_observer = MovingAverageMinMaxObserver.with_args( + averaging_constant=1 + ) + extra_args["observer"] = dynamic_quant_observer + else: + act_observer_or_fake_quant_ctr = FusedMovingAvgObsFakeQuantize # type: ignore[assignment] + else: + if is_dynamic: + act_observer_or_fake_quant_ctr = PlaceholderObserver # type: ignore[assignment] + else: + act_observer_or_fake_quant_ctr = HistogramObserver # type: ignore[assignment] + + act_quantization_spec = QuantizationSpec( + dtype=torch.int8, + quant_min=act_qmin, + quant_max=act_qmax, + qscheme=torch.per_tensor_affine, + is_dynamic=is_dynamic, + observer_or_fake_quant_ctr=act_observer_or_fake_quant_ctr.with_args( + **extra_args, + ), + ) + weight_qscheme = ( + torch.per_channel_symmetric if is_per_channel else torch.per_tensor_symmetric + ) + weight_observer_or_fake_quant_ctr: _ObserverOrFakeQuantizeConstructor = ( + MinMaxObserver + ) + if is_qat: + # TODO: qat + per channel? + weight_observer_or_fake_quant_ctr = FusedMovingAvgObsFakeQuantize + elif is_per_channel: + weight_observer_or_fake_quant_ctr = PerChannelMinMaxObserver + + extra_args: dict[str, Any] = {"eps": 2**-12} + if is_qat: + if weight_qscheme == torch.per_tensor_symmetric: + extra_args["observer"] = MovingAverageMinMaxObserver + else: + extra_args["observer"] = MovingAveragePerChannelMinMaxObserver # type: ignore[dict-item] + weight_quantization_spec = QuantizationSpec( + dtype=torch.int8, + quant_min=weight_qmin, + quant_max=weight_qmax, + qscheme=weight_qscheme, + ch_axis=0, + is_dynamic=False, + observer_or_fake_quant_ctr=weight_observer_or_fake_quant_ctr.with_args( + **extra_args + ), + ) + + bias_quantization_spec = None + if is_dynamic: + quantization_config = QuantizationConfig( + act_quantization_spec, + None, + weight_quantization_spec, + bias_quantization_spec, + is_qat, + ) + else: + quantization_config = QuantizationConfig( + act_quantization_spec, + act_quantization_spec, + weight_quantization_spec, + bias_quantization_spec, + is_qat, + ) + return quantization_config + + +def _get_supported_config_and_operators() -> list[OperatorConfig]: + return _get_supported_symmetric_config_and_operators() + + +def _get_module_type_filter(tp: Callable): + """Get the module_type_filter function for a given module type, the filter accepts + a node and checks if the node comes from a module that has certain module type + + For example: + node: linear_op = call_function[...](...) # comes from a module with type Block -> Sub -> Linear + + + >> module_type_filter = _get_module_type_filter(Sub) # submodule with type `Sub`, under the `Block` submodule + >> print(module_type_filter(node)) + True # the node is from the submodule `Sub` (same for `Block` and `Linear` as well) + """ + + tp_str = tp.__module__ + "." + tp.__qualname__ + + def module_type_filter(n: Node) -> bool: + # example: { + # 'L__self___sub': ("L['self'].sub", ), + # 'L__self___sub_linear': ("L['self'].sub.linear", ) + # } + nn_module_stack = n.meta.get("nn_module_stack", {}) + types = [] + for _, t in nn_module_stack.values(): + # export() returns str, but older APIs (e.g. capture_pre_autograd_graph) + # return type. Handle both cases. + if isinstance(t, type): + t = t.__module__ + "." 
+ t.__qualname__ + types.append(t) + return tp_str in types + + return module_type_filter + + +def _get_not_module_type_or_name_filter( + tp_list: list[Callable], module_name_list: list[str] +) -> Callable[[Node], bool]: + module_type_filters = [_get_module_type_filter(tp) for tp in tp_list] + module_name_list_filters = [_get_module_name_filter(m) for m in module_name_list] + + def not_module_type_or_name_filter(n: Node) -> bool: + return not any(f(n) for f in module_type_filters + module_name_list_filters) + + return not_module_type_or_name_filter + + +class XNNPACKQuantizer(Quantizer): + supported_config_and_operators = _get_supported_config_and_operators() + STATIC_QAT_ONLY_OPS = [ + "conv_bn_relu", + "conv_bn", + "conv_transpose_bn_relu", + "conv_transpose_bn", + ] + + # static quantization ops (both PTQ and QAT) + # Preserve the order that fusions come before singular ops + STATIC_OPS = [ + "linear_relu", + "linear", + "conv_relu", + "conv", + "conv_transpose_relu", + "adaptive_avg_pool2d", + # TODO: move this to BoltNNQuantizer? + "gru_io_only", + "add_relu", + "add", + "mul_relu", + "mul", + "cat", + ] + + DYNAMIC_OPS = [ + "linear", + ] + + def __init__(self) -> None: + super().__init__() + self.global_config: Optional[QuantizationConfig] = None + self.operator_type_config: dict[ + torch._ops.OpOverloadPacket, Optional[QuantizationConfig] + ] = {} + self.module_type_config: dict[Callable, Optional[QuantizationConfig]] = {} + self.module_name_config: dict[str, Optional[QuantizationConfig]] = {} + + @classmethod + def get_supported_quantization_configs(cls) -> list[QuantizationConfig]: + op_configs: set[QuantizationConfig] = { + spec for spec, _ in cls.supported_config_and_operators + } + return list(op_configs) + + @classmethod + def get_supported_operator_for_quantization_config( + cls, quantization_config: Optional[QuantizationConfig] + ) -> list[OperatorPatternType]: + if quantization_config is None: + all_ops = [] + for _, ops in cls.supported_config_and_operators: + all_ops.extend(ops) + return all_ops + + for config, ops in cls.supported_config_and_operators: + # note: this assumes each entry in cls.supported_spec_and_operators + # corresponds to one spec, e.g. 
we don't have + # [(spec1, op_list1), (spec1, op_list2), (spec2, op_list3)] + # where the first and second entry have the same spec but did not + # merge the op list + if config == quantization_config: + return ops + return [] + + def set_global(self, quantization_config: QuantizationConfig) -> XNNPACKQuantizer: + self.global_config = quantization_config + return self + + def set_operator_type( + self, + operator_type: torch._ops.OpOverloadPacket, + quantization_config: QuantizationConfig, + ) -> XNNPACKQuantizer: + self.operator_type_config[operator_type] = quantization_config + return self + + def set_module_type( + self, module_type: Callable, quantization_config: QuantizationConfig + ): + """Set quantization_config for a submodule with type: `module_type`, for example: + quantizer.set_module_name(Sub) or quantizer.set_module_name(nn.Linear), it will quantize all supported operator/operator + patterns in the submodule with this module type with the given `quantization_config` + """ + self.module_type_config[module_type] = quantization_config + return self + + def set_module_name( + self, module_name: str, quantization_config: Optional[QuantizationConfig] + ): + """Set quantization_config for a submodule with name: `module_name`, for example: + quantizer.set_module_name("blocks.sub"), it will quantize all supported operator/operator + patterns in the submodule with this module name with the given `quantization_config` + """ + assert ( + quantization_config is not None + ), " quantization_config == None is not supported yet" + self.module_name_config[module_name] = quantization_config + return self + + def transform_for_annotation( + self, model: torch.fx.GraphModule + ) -> torch.fx.GraphModule: + """Transforms scalar values to tensor attributes""" + return _convert_scalars_to_attrs(model) + + def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule: + """just handling global spec for now""" + # hacked for handling dynamic linear quant. will fix later. 
+ if self.global_config and self.global_config.input_activation.is_dynamic: # type: ignore[union-attr] + model = self._annotate_for_dynamic_quantization_config(model) + else: + model = self._annotate_for_static_quantization_config(model) + propagate_annotation(model) + return model + + def _annotate_all_static_patterns( + self, + model: torch.fx.GraphModule, + quantization_config: Optional[QuantizationConfig], + filter_fn: Optional[Callable[[Node], bool]] = None, + ) -> torch.fx.GraphModule: + # TODO: implement the support for None to be canceling out previous annotations + if quantization_config is None: + return model + + if quantization_config.is_qat: + for op in self.STATIC_QAT_ONLY_OPS: + OP_TO_ANNOTATOR[op](model, quantization_config, filter_fn) + for op in self.STATIC_OPS: + OP_TO_ANNOTATOR[op](model, quantization_config, filter_fn) + return model + + def _annotate_all_dynamic_patterns( + self, + model: torch.fx.GraphModule, + quantization_config: Optional[QuantizationConfig], + filter_fn: Optional[Callable[[Node], bool]] = None, + ) -> torch.fx.GraphModule: + # TODO: implement the support for None to be canceling out previous annotations + if quantization_config is None: + return model + + for op in self.DYNAMIC_OPS: + OP_TO_ANNOTATOR[op](model, quantization_config, filter_fn) + return model + + def _annotate_for_static_quantization_config( + self, model: torch.fx.GraphModule + ) -> torch.fx.GraphModule: + module_name_list = list(self.module_name_config.keys()) + for module_name, config in self.module_name_config.items(): + self._annotate_all_static_patterns( + model, config, _get_module_name_filter(module_name) + ) + + tp_list = list(self.module_type_config.keys()) + for module_type, config in self.module_type_config.items(): + self._annotate_all_static_patterns( + model, config, _get_module_type_filter(module_type) + ) + + self._annotate_all_static_patterns( + model, + self.global_config, + _get_not_module_type_or_name_filter(tp_list, module_name_list), + ) + return model + + def _annotate_for_dynamic_quantization_config( + self, model: torch.fx.GraphModule + ) -> torch.fx.GraphModule: + module_name_list = list(self.module_name_config.keys()) + for module_name, config in self.module_name_config.items(): + self._annotate_all_dynamic_patterns( + model, config, _get_module_name_filter(module_name) + ) + + tp_list = list(self.module_type_config.keys()) + for module_type, config in self.module_type_config.items(): + self._annotate_all_dynamic_patterns( + model, config, _get_module_type_filter(module_type) + ) + + self._annotate_all_dynamic_patterns( + model, + self.global_config, + _get_not_module_type_or_name_filter(tp_list, module_name_list), + ) + return model + + def validate(self, model: torch.fx.GraphModule) -> None: + pass + + @classmethod + def get_supported_operators(cls) -> list[OperatorConfig]: + return cls.supported_config_and_operators diff --git a/backends/xnnpack/quantizer/xnnpack_quantizer_utils.py b/backends/xnnpack/quantizer/xnnpack_quantizer_utils.py new file mode 100644 index 0000000000..b655c36e28 --- /dev/null +++ b/backends/xnnpack/quantizer/xnnpack_quantizer_utils.py @@ -0,0 +1,1127 @@ +# mypy: allow-untyped-defs +import itertools +import typing +from dataclasses import dataclass +from typing import Callable, NamedTuple, Optional + +import torch +import torch.nn.functional as F +from torch._subclasses import FakeTensor +from torch.ao.quantization.fx.utils import get_new_attr_name_with_prefix +from torch.ao.quantization.pt2e.export_utils import _WrapperModule 
+from torch.ao.quantization.pt2e.utils import ( + _get_aten_graph_module_for_pattern, + _is_conv_node, + _is_conv_transpose_node, +) +from torch.ao.quantization.quantizer import ( + QuantizationAnnotation, + QuantizationSpec, + SharedQuantizationSpec, +) +from torch.ao.quantization.quantizer.utils import ( + _annotate_input_qspec_map, + _annotate_output_qspec, +) +from torch.fx import Node +from torch.fx.passes.utils.matcher_with_name_node_map_utils import ( + SubgraphMatcherWithNameNodeMap, +) +from torch.fx.passes.utils.source_matcher_utils import get_source_partitions + + +__all__ = [ + "OperatorConfig", + "OperatorPatternType", + "QuantizationConfig", + "get_input_act_qspec", + "get_output_act_qspec", + "get_weight_qspec", + "get_bias_qspec", + "OP_TO_ANNOTATOR", + "propagate_annotation", +] + + +# In the absence of better name, just winging it with QuantizationConfig +@dataclass(eq=True, frozen=True) +class QuantizationConfig: + input_activation: Optional[QuantizationSpec] + output_activation: Optional[QuantizationSpec] + weight: Optional[QuantizationSpec] + bias: Optional[QuantizationSpec] + # TODO: remove, since we can use observer_or_fake_quant_ctr to express this + is_qat: bool = False + + +# Use Annotated because list[Callable].__module__ is read-only. +OperatorPatternType = typing.Annotated[list[Callable], None] +OperatorPatternType.__module__ = ( + "executorch.backends.xnnpack.quantizer.xnnpack_quantizer_utils" +) + +AnnotatorType = Callable[ + [ + torch.fx.GraphModule, + Optional[QuantizationConfig], + Optional[Callable[[Node], bool]], + ], + Optional[list[list[Node]]], +] +OP_TO_ANNOTATOR: dict[str, AnnotatorType] = {} + + +def register_annotator(op: str) -> Callable[[AnnotatorType], None]: + def decorator(annotator: AnnotatorType) -> None: + OP_TO_ANNOTATOR[op] = annotator + + return decorator + + +class OperatorConfig(NamedTuple): + # fix List[str] with List[List[Union[nn.Module, FunctionType, BuiltinFunctionType]]] + # Basically we are mapping a quantization config to some list of patterns. + # a pattern is defined as a list of nn module, function or builtin function names + # e.g. [nn.Conv2d, torch.relu, torch.add] + # We have not resolved whether fusion can be considered internal details of the + # quantizer hence it does not need communication to user. + # Note this pattern is not really informative since it does not really + # tell us the graph structure resulting from the list of ops. 
+ config: QuantizationConfig + operators: list[OperatorPatternType] + + +def _is_annotated(nodes: list[Node]): + """ + Given a list of nodes (that represents an operator pattern), + check if any of the node is annotated, return True if any of the node + is annotated, otherwise return False + """ + annotated = False + for node in nodes: + annotated = annotated or ( + "quantization_annotation" in node.meta + and node.meta["quantization_annotation"]._annotated + ) + return annotated + + +def _mark_nodes_as_annotated(nodes: list[Node]): + for node in nodes: + if node is not None: + if "quantization_annotation" not in node.meta: + node.meta["quantization_annotation"] = QuantizationAnnotation() + node.meta["quantization_annotation"]._annotated = True + + +def get_input_act_qspec(quantization_config: Optional[QuantizationConfig]): + if quantization_config is None: + return None + if quantization_config.input_activation is None: + return None + quantization_spec: QuantizationSpec = quantization_config.input_activation + assert quantization_spec.qscheme in [ + torch.per_tensor_affine, + torch.per_tensor_symmetric, + ] + return quantization_spec + + +def get_output_act_qspec(quantization_config: Optional[QuantizationConfig]): + if quantization_config is None: + return None + if quantization_config.output_activation is None: + return None + quantization_spec: QuantizationSpec = quantization_config.output_activation + assert quantization_spec.qscheme in [ + torch.per_tensor_affine, + torch.per_tensor_symmetric, + ] + return quantization_spec + + +def get_weight_qspec(quantization_config: Optional[QuantizationConfig]): + if quantization_config is None: + return None + assert quantization_config is not None + if quantization_config.weight is None: + return None + quantization_spec: QuantizationSpec = quantization_config.weight + if quantization_spec.qscheme not in [ + torch.per_tensor_symmetric, + torch.per_channel_symmetric, + None, + ]: + raise ValueError( + f"Unsupported quantization_spec {quantization_spec} for weight" + ) + return quantization_spec + + +def get_bias_qspec(quantization_config: Optional[QuantizationConfig]): + if quantization_config is None: + return None + assert quantization_config is not None + if quantization_config.bias is None: + return None + quantization_spec: QuantizationSpec = quantization_config.bias + assert ( + quantization_spec.dtype == torch.float + ), "Only float dtype for bias is supported for bias right now" + return quantization_spec + + +@register_annotator("linear") +def _annotate_linear( + gm: torch.fx.GraphModule, + quantization_config: Optional[QuantizationConfig], + filter_fn: Optional[Callable[[Node], bool]] = None, +) -> Optional[list[list[Node]]]: + annotated_partitions = [] + input_act_qspec = get_input_act_qspec(quantization_config) + output_act_qspec = get_output_act_qspec(quantization_config) + weight_qspec = get_weight_qspec(quantization_config) + bias_qspec = get_bias_qspec(quantization_config) + for node in gm.graph.nodes: + if node.op != "call_function" or node.target != torch.ops.aten.linear.default: + continue + if filter_fn and not filter_fn(node): + continue + act_node = node.args[0] + weight_node = node.args[1] + bias_node = None + if len(node.args) > 2: + bias_node = node.args[2] + + if _is_annotated([node]) is False: # type: ignore[list-item] + _annotate_input_qspec_map( + node, + act_node, + input_act_qspec, + ) + _annotate_input_qspec_map( + node, + weight_node, + weight_qspec, + ) + nodes_to_mark_annotated = [node, weight_node] + if 
bias_node: + _annotate_input_qspec_map( + node, + bias_node, + bias_qspec, + ) + nodes_to_mark_annotated.append(bias_node) + _annotate_output_qspec(node, output_act_qspec) + _mark_nodes_as_annotated(nodes_to_mark_annotated) + annotated_partitions.append(nodes_to_mark_annotated) + + return annotated_partitions + + +@register_annotator("linear_relu") +def _annotate_linear_relu( + gm: torch.fx.GraphModule, + quantization_config: Optional[QuantizationConfig], + filter_fn: Optional[Callable[[Node], bool]] = None, +) -> Optional[list[list[Node]]]: + annotated_partitions = [] + input_act_qspec = get_input_act_qspec(quantization_config) + output_act_qspec = get_output_act_qspec(quantization_config) + weight_qspec = get_weight_qspec(quantization_config) + bias_qspec = get_bias_qspec(quantization_config) + for node in gm.graph.nodes: + if node.op != "call_function" or node.target not in [ + torch.ops.aten.relu.default, + torch.ops.aten.relu_.default, + ]: + continue + relu_node = node + maybe_linear_node = node.args[0] + if ( + not isinstance(maybe_linear_node, Node) + or maybe_linear_node.op != "call_function" + or maybe_linear_node.target != torch.ops.aten.linear.default + ): + continue + + linear_node = maybe_linear_node + if len(linear_node.users) > 1: + # if linear node has multiple users, then it can't be fused with relu + continue + + input_qspec_map = {} + input_act = linear_node.args[0] + assert isinstance(input_act, Node) + input_qspec_map[input_act] = input_act_qspec + + weight = linear_node.args[1] + assert isinstance(weight, Node) + input_qspec_map[weight] = weight_qspec + + # adding weight node to the partition as well + partition = [relu_node, linear_node, weight] + bias = linear_node.args[2] if len(linear_node.args) > 2 else None + if isinstance(bias, Node): + input_qspec_map[bias] = bias_qspec + partition.append(bias) + + if _is_annotated(partition): + continue + + if filter_fn and any(not filter_fn(n) for n in partition): + continue + + linear_node.meta["quantization_annotation"] = QuantizationAnnotation( + input_qspec_map=input_qspec_map, + _annotated=True, + ) + relu_node.meta["quantization_annotation"] = QuantizationAnnotation( + output_qspec=output_act_qspec, + _annotated=True, + ) + _mark_nodes_as_annotated(partition) + annotated_partitions.append(partition) + return annotated_partitions + + +@register_annotator("conv") +def _annotate_conv( + gm: torch.fx.GraphModule, + quantization_config: Optional[QuantizationConfig], + filter_fn: Optional[Callable[[Node], bool]] = None, +) -> Optional[list[list[Node]]]: + annotated_partitions = [] + for n in gm.graph.nodes: + if n.op != "call_function" or n.target not in [ + torch.ops.aten.conv1d.default, + torch.ops.aten.conv2d.default, + ]: + continue + conv_node = n + + input_qspec_map = {} + input_act = conv_node.args[0] + assert isinstance(input_act, Node) + input_qspec_map[input_act] = get_input_act_qspec(quantization_config) + + weight = conv_node.args[1] + assert isinstance(weight, Node) + input_qspec_map[weight] = get_weight_qspec(quantization_config) + + # adding weight node to the partition as well + partition = [conv_node, conv_node.args[1]] + + bias = conv_node.args[2] if len(conv_node.args) > 2 else None + if isinstance(bias, Node): + input_qspec_map[bias] = get_bias_qspec(quantization_config) + partition.append(bias) + + if _is_annotated(partition): + continue + + if filter_fn and any(not filter_fn(n) for n in partition): + continue + + conv_node.meta["quantization_annotation"] = QuantizationAnnotation( + 
input_qspec_map=input_qspec_map, + output_qspec=get_output_act_qspec(quantization_config), + _annotated=True, + ) + _mark_nodes_as_annotated(partition) + annotated_partitions.append(partition) + return annotated_partitions + + +def _do_annotate_conv_relu( + gm: torch.fx.GraphModule, + quantization_config: Optional[QuantizationConfig], + filter_fn: Optional[Callable[[Node], bool]] = None, + is_conv_transpose: bool = False, +): + annotated_partitions = [] + for n in gm.graph.nodes: + if n.op != "call_function" or n.target not in [ + torch.ops.aten.relu.default, + torch.ops.aten.relu_.default, + ]: + continue + relu_node = n + maybe_conv_node = n.args[0] + + is_conv_node = _is_conv_transpose_node if is_conv_transpose else _is_conv_node + if not isinstance(maybe_conv_node, Node) or not is_conv_node(maybe_conv_node): + continue + conv_node = maybe_conv_node + + if len(conv_node.users) > 1: + # relu shouldn't be fuseable to conv if there are other users + # of convolution + continue + + input_qspec_map = {} + input_act = conv_node.args[0] + assert isinstance(input_act, Node) + input_qspec_map[input_act] = get_input_act_qspec(quantization_config) + + weight = conv_node.args[1] + assert isinstance(weight, Node) + input_qspec_map[weight] = get_weight_qspec(quantization_config) + + # adding weight node to the partition as well + partition = [relu_node, conv_node, conv_node.args[1]] + bias = conv_node.args[2] if len(conv_node.args) > 2 else None + if isinstance(bias, Node): + input_qspec_map[bias] = get_bias_qspec(quantization_config) + partition.append(bias) + + if _is_annotated(partition): + continue + + if filter_fn and any(not filter_fn(n) for n in partition): + continue + + conv_node.meta["quantization_annotation"] = QuantizationAnnotation( + input_qspec_map=input_qspec_map, _annotated=True + ) + relu_node.meta["quantization_annotation"] = QuantizationAnnotation( + output_qspec=get_output_act_qspec(quantization_config), # type: ignore[arg-type] + _annotated=True, + ) + _mark_nodes_as_annotated(partition) + annotated_partitions.append(partition) + return annotated_partitions + + +@register_annotator("conv_relu") +def _annotate_conv_relu( + gm: torch.fx.GraphModule, + quantization_config: Optional[QuantizationConfig], + filter_fn: Optional[Callable[[Node], bool]] = None, +) -> Optional[list[list[Node]]]: + return _do_annotate_conv_relu( + gm, quantization_config, filter_fn, is_conv_transpose=False + ) + + +@register_annotator("conv_transpose_relu") +def _annotate_conv_transpose_relu( + gm: torch.fx.GraphModule, + quantization_config: Optional[QuantizationConfig], + filter_fn: Optional[Callable[[Node], bool]] = None, +) -> Optional[list[list[Node]]]: + return _do_annotate_conv_relu( + gm, quantization_config, filter_fn, is_conv_transpose=True + ) + + +@register_annotator("conv_bn") +def _annotate_conv_bn( + gm: torch.fx.GraphModule, + quantization_config: Optional[QuantizationConfig], + filter_fn: Optional[Callable[[Node], bool]] = None, +) -> Optional[list[list[Node]]]: + """ + Find conv + batchnorm parititions + Note: This is only used for QAT. In PTQ, batchnorm should already be fused into the conv. 
+ """ + return _do_annotate_conv_bn(gm, quantization_config, filter_fn, has_relu=False) + + +@register_annotator("conv_bn_relu") +def _annotate_conv_bn_relu( + gm: torch.fx.GraphModule, + quantization_config: Optional[QuantizationConfig], + filter_fn: Optional[Callable[[Node], bool]] = None, +) -> Optional[list[list[Node]]]: + """ + Find conv + batchnorm + relu parititions + Note: This is only used for QAT. In PTQ, batchnorm should already be fused into the conv. + """ + return _do_annotate_conv_bn(gm, quantization_config, filter_fn, has_relu=True) + + +@register_annotator("conv_transpose_bn") +def _annotate_conv_transpose_bn( + gm: torch.fx.GraphModule, + quantization_config: Optional[QuantizationConfig], + filter_fn: Optional[Callable[[Node], bool]] = None, +) -> Optional[list[list[Node]]]: + """ + Find conv_transpose + batchnorm parititions + Note: This is only used for QAT. In PTQ, batchnorm should already be fused into the conv. + """ + return _do_annotate_conv_bn( + gm, quantization_config, filter_fn, has_relu=False, is_conv_transpose=True + ) + + +@register_annotator("conv_transpose_bn_relu") +def _annotate_conv_transpose_bn_relu( + gm: torch.fx.GraphModule, + quantization_config: Optional[QuantizationConfig], + filter_fn: Optional[Callable[[Node], bool]] = None, +) -> Optional[list[list[Node]]]: + """ + Find conv_transpose + batchnorm + relu parititions + Note: This is only used for QAT. In PTQ, batchnorm should already be fused into the conv. + """ + return _do_annotate_conv_bn( + gm, quantization_config, filter_fn, has_relu=True, is_conv_transpose=True + ) + + +def _do_annotate_conv_bn( + gm: torch.fx.GraphModule, + quantization_config: Optional[QuantizationConfig], + filter_fn: Optional[Callable[[Node], bool]], + has_relu: bool, + is_conv_transpose: bool = False, +) -> list[list[Node]]: + """ + Given a function that takes in a `conv_fn` and returns a conv-bn[-relu] pattern, + return a list of annotated partitions. + + The output of the pattern must include a dictionary from string name to node + for the following names: "input", "conv", "weight", "bias", and "output". 
+ """ + + # Example inputs for conv-bn1d patterns + _conv1d_bn_example_inputs = ( + torch.randn(1, 1, 3), # x + torch.randn(1, 1, 1), # conv_weight + torch.randn(1), # conv_bias + torch.randn(1), # bn_weight + torch.randn(1), # bn_bias + torch.randn(1), # bn_running_mean + torch.randn(1), # bn_running_var + ) + + # Example inputs for conv-bn2d patterns + _conv2d_bn_example_inputs = ( + torch.randn(1, 1, 3, 3), # x + torch.randn(1, 1, 1, 1), # conv_weight + torch.randn(1), # conv_bias + torch.randn(1), # bn_weight + torch.randn(1), # bn_bias + torch.randn(1), # bn_running_mean + torch.randn(1), # bn_running_var + ) + + def get_pattern(conv_fn: Callable, relu_is_inplace: bool): + def _conv_bn(x, conv_weight, conv_bias, bn_weight, bn_bias, bn_rm, bn_rv): + conv = conv_fn(x, conv_weight, conv_bias) + bn = F.batch_norm(conv, bn_rm, bn_rv, bn_weight, bn_bias, training=True) + if has_relu: + output = F.relu_(bn) if relu_is_inplace else F.relu(bn) + else: + output = bn + return output, { + "input": x, + "conv": conv, + "weight": conv_weight, + "bias": conv_bias, + "output": output, + } + + return _WrapperModule(_conv_bn) + + # Needed for matching, otherwise the matches gets filtered out due to unused + # nodes returned by batch norm + gm.graph.eliminate_dead_code() + gm.recompile() + + matches = [] + if is_conv_transpose: + combinations = [ + (F.conv_transpose1d, _conv1d_bn_example_inputs), + (F.conv_transpose2d, _conv2d_bn_example_inputs), + ] + else: + combinations = [ + (F.conv1d, _conv1d_bn_example_inputs), # type: ignore[list-item] + (F.conv2d, _conv2d_bn_example_inputs), # type: ignore[list-item] + ] + + # Add `is_cuda` and `relu_is_inplace` dimensions + combinations = itertools.product( # type: ignore[assignment] + combinations, + [True, False] if torch.cuda.is_available() else [False], # is_cuda + [True, False] if has_relu else [False], # relu_is_inplace + ) + + # Match against all conv dimensions and cuda variants + for (conv_fn, example_inputs), is_cuda, relu_is_inplace in combinations: # type: ignore[misc] + pattern = get_pattern(conv_fn, relu_is_inplace) # type: ignore[has-type] + pattern = _get_aten_graph_module_for_pattern(pattern, example_inputs, is_cuda) # type: ignore[has-type] + pattern.graph.eliminate_dead_code() + pattern.recompile() + matcher = SubgraphMatcherWithNameNodeMap(pattern, ignore_literals=True) + matches.extend(matcher.match(gm.graph)) + + # Annotate nodes returned in the matches + annotated_partitions = [] + for match in matches: + name_node_map = match.name_node_map + input_node = name_node_map["input"] + conv_node = name_node_map["conv"] + weight_node = name_node_map["weight"] + bias_node = name_node_map["bias"] + output_node = name_node_map["output"] + + # TODO: annotate the uses of input, weight, and bias separately instead + # of assuming they come from a single conv node. This is not possible today + # because input may have multiple users, and we can't rely on the conv node + # always being the first user. 
This was the case in models with skip + # connections like resnet18 + + # Validate conv args + if conv_node.args[0] is not input_node: + raise ValueError("Conv arg did not contain input node ", input_node) + if conv_node.args[1] is not weight_node: + raise ValueError("Conv arg did not contain weight node ", weight_node) + if len(conv_node.args) > 2 and conv_node.args[2] is not bias_node: + raise ValueError("Conv arg did not contain bias node ", bias_node) + + # Skip if the partition is already annotated or is filtered out by the user + partition = [conv_node, weight_node] + if bias_node is not None: + partition.append(bias_node) + if _is_annotated(partition): + continue + if filter_fn and any(not filter_fn(n) for n in partition): + continue + + # Annotate conv inputs and pattern output + input_qspec_map = {} + input_qspec_map[input_node] = get_input_act_qspec(quantization_config) + input_qspec_map[weight_node] = get_weight_qspec(quantization_config) + if bias_node is not None: + input_qspec_map[bias_node] = get_bias_qspec(quantization_config) + conv_node.meta["quantization_annotation"] = QuantizationAnnotation( + input_qspec_map=input_qspec_map, + _annotated=True, + ) + output_node.meta["quantization_annotation"] = QuantizationAnnotation( + output_qspec=get_output_act_qspec(quantization_config), # type: ignore[arg-type] + _annotated=True, + ) + _mark_nodes_as_annotated(partition) + annotated_partitions.append(partition) + return annotated_partitions + + +@register_annotator("gru_io_only") +def _annotate_gru_io_only( + gm: torch.fx.GraphModule, + quantization_config: Optional[QuantizationConfig], + filter_fn: Optional[Callable[[Node], bool]] = None, +) -> Optional[list[list[Node]]]: + gru_partitions = get_source_partitions(gm.graph, [torch.nn.GRU], filter_fn) + gru_partitions = list(itertools.chain.from_iterable(gru_partitions.values())) + annotated_partitions = [] + for gru_partition in gru_partitions: + annotated_partitions.append(gru_partition.nodes) + output_nodes = gru_partition.output_nodes + input_nodes = gru_partition.input_nodes + # skip annotation if it is already annotated + if _is_annotated(input_nodes + output_nodes): + continue + # inside each GRU partition, we should be able to annotate each linear + # subgraph + input_act = input_nodes[0] + input_act_user = next(iter(input_act.users.keys())) + assert isinstance(input_act, Node) + assert isinstance(input_act_user, Node) + input_act_user.meta["quantization_annotation"] = QuantizationAnnotation( + input_qspec_map={ + input_act: get_input_act_qspec(quantization_config), + }, + _annotated=True, + ) + + hidden_state = input_nodes[1] + hidden_state_user = next(iter(hidden_state.users.keys())) + assert isinstance(hidden_state, Node) + assert isinstance(hidden_state_user, Node) + hidden_state_user.meta["quantization_annotation"] = QuantizationAnnotation( + input_qspec_map={ + hidden_state: get_input_act_qspec(quantization_config), + }, + _annotated=True, + ) + + assert len(output_nodes) == 2, "expecting GRU to have two outputs" + for output in output_nodes: + output.meta["quantization_annotation"] = QuantizationAnnotation( + output_qspec=get_output_act_qspec(quantization_config), + _annotated=True, + ) + nodes_to_mark_annotated = list(gru_partition.nodes) + _mark_nodes_as_annotated(nodes_to_mark_annotated) + return annotated_partitions + + +@register_annotator("adaptive_avg_pool2d") +def _annotate_adaptive_avg_pool2d( + gm: torch.fx.GraphModule, + quantization_config: Optional[QuantizationConfig], + filter_fn: 
Optional[Callable[[Node], bool]] = None, +) -> Optional[list[list[Node]]]: + """Always annotate adaptive_avg_pool2d op""" + module_partitions = get_source_partitions( + gm.graph, [torch.nn.AdaptiveAvgPool2d, F.adaptive_avg_pool2d], filter_fn + ) + partitions = list(itertools.chain.from_iterable(module_partitions.values())) + annotated_partitions = [] + for partition in partitions: + pool_node = partition.output_nodes[0] + if ( + pool_node.op != "call_function" + or pool_node.target != torch.ops.aten.adaptive_avg_pool2d.default + ): + raise ValueError(f"{pool_node} is not an aten adaptive_avg_pool2d operator") + + if _is_annotated([pool_node]): + continue + + annotated_partitions.append(partition.nodes) + input_act = pool_node.args[0] + assert isinstance(input_act, Node) + + # only annotate input output sharing operator + # when the output of the input node is annotated + if ( + "quantization_annotation" not in input_act.meta + or not input_act.meta["quantization_annotation"]._annotated + or input_act.meta["quantization_annotation"].output_qspec is None + ): + input_act_qspec = get_input_act_qspec(quantization_config) + else: + input_act_qspec = SharedQuantizationSpec(input_act) + + # output sharing with input + output_act_qspec = SharedQuantizationSpec((input_act, pool_node)) + pool_node.meta["quantization_annotation"] = QuantizationAnnotation( + input_qspec_map={ + input_act: input_act_qspec, + }, + output_qspec=output_act_qspec, + _annotated=True, + ) + return annotated_partitions + + +def _is_input_large_scalar(node: Node, gm: torch.fx.GraphModule): + """Check if input is a large scalar value. So that we can skip quantization for the node + since histc op (in HistogramObserver) only works for values up to certain upper bound + """ + if node.op == "get_attr": + qualified_name = str(node.target) + module_path, _, name = qualified_name.rpartition(".") + submod = gm.get_submodule(module_path) + tensor = getattr(submod, name) + # torch.histc works until this upper bound + HISTC_UPPER_BOUND = 3.4028235e15 + return tensor.numel() == 1 and abs(tensor.item()) > HISTC_UPPER_BOUND + return False + + +def _is_input_non_float_tensor(node: Node): + """Check if the input is not a float tensor, so that we can skip quantization for the node + since observers only works with float Tensors + """ + if "val" not in node.meta or not isinstance(node.meta["val"], FakeTensor): + return True + return node.meta["val"].dtype != torch.float32 + + +@register_annotator("add_relu") +def _annotate_add_relu( + gm: torch.fx.GraphModule, + quantization_config: Optional[QuantizationConfig], + filter_fn: Optional[Callable[[Node], bool]] = None, +) -> Optional[list[list[Node]]]: + annotated_partitions = [] + for node in gm.graph.nodes: + if node.op != "call_function" or node.target not in [ + torch.ops.aten.relu.default, + torch.ops.aten.relu_.default, + ]: + continue + relu_node = node + maybe_add = node.args[0] + if ( + not isinstance(maybe_add, Node) + or maybe_add.op != "call_function" + or maybe_add.target + not in [ + torch.ops.aten.add.Tensor, + torch.ops.aten.add_.Tensor, + ] + ): + continue + + add_node = maybe_add + + if len(add_node.users) > 1: + # add can't be fused with ReLU if the result of add is being used + # else where in the graph + continue + + partition = [relu_node, add_node] + + if _is_annotated(partition): + continue + + if filter_fn and any(not filter_fn(n) for n in partition): + continue + + input_act_qspec = get_input_act_qspec(quantization_config) + output_act_qspec = 
get_output_act_qspec(quantization_config) + + input_qspec_map = {} + input_act0 = add_node.args[0] + if isinstance(input_act0, Node): + if _is_input_large_scalar(input_act0, gm): + continue + if _is_input_non_float_tensor(input_act0): + continue + partition.append(input_act0) + input_qspec_map[input_act0] = input_act_qspec + + input_act1 = add_node.args[1] + if isinstance(input_act1, Node): + if _is_input_large_scalar(input_act1, gm): + continue + if _is_input_non_float_tensor(input_act1): + continue + partition.append(input_act1) + input_qspec_map[input_act1] = input_act_qspec + + add_node.meta["quantization_annotation"] = QuantizationAnnotation( + input_qspec_map=input_qspec_map, + _annotated=True, + ) + relu_node.meta["quantization_annotation"] = QuantizationAnnotation( + output_qspec=output_act_qspec, + _annotated=True, + ) + annotated_partitions.append(partition) + return annotated_partitions + + +@register_annotator("add") +def _annotate_add( + gm: torch.fx.GraphModule, + quantization_config: Optional[QuantizationConfig], + filter_fn: Optional[Callable[[Node], bool]] = None, +) -> Optional[list[list[Node]]]: + annotated_partitions = [] + for node in gm.graph.nodes: + if node.op != "call_function" or node.target not in [ + torch.ops.aten.add.Tensor, + torch.ops.aten.add_.Tensor, + ]: + continue + add_node = node + partition = [add_node] + + if _is_annotated(partition): + continue + + if filter_fn and any(not filter_fn(n) for n in partition): + continue + + input_act_qspec = get_input_act_qspec(quantization_config) + output_act_qspec = get_output_act_qspec(quantization_config) + + input_qspec_map = {} + input_act0 = add_node.args[0] + if isinstance(input_act0, Node): + if _is_input_large_scalar(input_act0, gm): + continue + if _is_input_non_float_tensor(input_act0): + continue + input_qspec_map[input_act0] = input_act_qspec + partition.append(input_act0) + + input_act1 = add_node.args[1] + if isinstance(input_act1, Node): + if _is_input_large_scalar(input_act1, gm): + continue + if _is_input_non_float_tensor(input_act1): + continue + input_qspec_map[input_act1] = input_act_qspec + partition.append(input_act1) + + add_node.meta["quantization_annotation"] = QuantizationAnnotation( + input_qspec_map=input_qspec_map, + output_qspec=output_act_qspec, + _annotated=True, + ) + annotated_partitions.append(partition) + return annotated_partitions + + +@register_annotator("mul_relu") +def _annotate_mul_relu( + gm: torch.fx.GraphModule, + quantization_config: Optional[QuantizationConfig], + filter_fn: Optional[Callable[[Node], bool]] = None, +) -> Optional[list[list[Node]]]: + annotated_partitions = [] + for node in gm.graph.nodes: + if node.op != "call_function" or node.target not in [ + torch.ops.aten.relu.default, + torch.ops.aten.relu_.default, + ]: + continue + relu_node = node + maybe_mul = node.args[0] + if ( + not isinstance(maybe_mul, Node) + or maybe_mul.op != "call_function" + or maybe_mul.target + not in [ + torch.ops.aten.mul.Tensor, + torch.ops.aten.mul_.Tensor, + ] + ): + continue + + mul_node = maybe_mul + if len(mul_node.users) > 1: + # mul can't be fused with ReLU if the result of mul is being used + # else where in the graph + continue + + partition = [relu_node, mul_node] + + if _is_annotated(partition): + continue + + if filter_fn and any(not filter_fn(n) for n in partition): + continue + + input_act_qspec = get_input_act_qspec(quantization_config) + output_act_qspec = get_output_act_qspec(quantization_config) + + input_qspec_map = {} + input_act0 = mul_node.args[0] + if 
isinstance(input_act0, Node): + if _is_input_large_scalar(input_act0, gm): + continue + if _is_input_non_float_tensor(input_act0): + continue + partition.append(input_act0) + input_qspec_map[input_act0] = input_act_qspec + + input_act1 = mul_node.args[1] + if isinstance(input_act1, Node): + if _is_input_large_scalar(input_act1, gm): + continue + if _is_input_non_float_tensor(input_act1): + continue + partition.append(input_act1) + input_qspec_map[input_act1] = input_act_qspec + + mul_node.meta["quantization_annotation"] = QuantizationAnnotation( + input_qspec_map=input_qspec_map, + _annotated=True, + ) + relu_node.meta["quantization_annotation"] = QuantizationAnnotation( + output_qspec=output_act_qspec, + _annotated=True, + ) + annotated_partitions.append(partition) + return annotated_partitions + + +@register_annotator("mul") +def _annotate_mul( + gm: torch.fx.GraphModule, + quantization_config: Optional[QuantizationConfig], + filter_fn: Optional[Callable[[Node], bool]] = None, +) -> Optional[list[list[Node]]]: + annotated_partitions = [] + for node in gm.graph.nodes: + if node.op != "call_function" or node.target not in [ + torch.ops.aten.mul.Tensor, + torch.ops.aten.mul_.Tensor, + ]: + continue + + mul_node = node + partition = [mul_node] + if _is_annotated(partition): + continue + + if filter_fn and any(not filter_fn(n) for n in partition): + continue + + input_act_qspec = get_input_act_qspec(quantization_config) + output_act_qspec = get_output_act_qspec(quantization_config) + + input_qspec_map = {} + input_act0 = mul_node.args[0] + if isinstance(input_act0, Node): + if _is_input_large_scalar(input_act0, gm): + continue + if _is_input_non_float_tensor(input_act0): + continue + input_qspec_map[input_act0] = input_act_qspec + partition.append(input_act0) + + input_act1 = mul_node.args[1] + if isinstance(input_act1, Node): + if _is_input_large_scalar(input_act1, gm): + continue + if _is_input_non_float_tensor(input_act1): + continue + input_qspec_map[input_act1] = input_act_qspec + partition.append(input_act0) + + mul_node.meta["quantization_annotation"] = QuantizationAnnotation( + input_qspec_map=input_qspec_map, + output_qspec=output_act_qspec, + _annotated=True, + ) + annotated_partitions.append(partition) + return annotated_partitions + + +# TODO: remove Optional in return type, fix annotated_partitions logic +@register_annotator("cat") +def _annotate_cat( + gm: torch.fx.GraphModule, + quantization_config: Optional[QuantizationConfig], + filter_fn: Optional[Callable[[Node], bool]] = None, +) -> Optional[list[list[Node]]]: + cat_partitions = get_source_partitions(gm.graph, [torch.cat], filter_fn) + cat_partitions = list(itertools.chain.from_iterable(cat_partitions.values())) + annotated_partitions = [] + for cat_partition in cat_partitions: + cat_node = cat_partition.output_nodes[0] + if _is_annotated([cat_node]): + continue + + if cat_node.target != torch.ops.aten.cat.default: + # TODO: change this to AnnotationException + raise Exception( # noqa: TRY002 + f"Expected cat node: torch.ops.aten.cat.default, but found {cat_node.target}" + " please check if you are calling the correct capture API" + ) + + annotated_partitions.append(cat_partition.nodes) + + input_act_qspec = get_input_act_qspec(quantization_config) + inputs = cat_node.args[0] + + input_qspec_map = {} + input_act0 = inputs[0] # type: ignore[index] + if isinstance(input_act0, Node): + input_qspec_map[input_act0] = input_act_qspec + + shared_with_input0_qspec = SharedQuantizationSpec((input_act0, cat_node)) # type: 
ignore[arg-type] + for input_act in inputs[1:]: # type: ignore[index, union-attr] + if input_act not in input_qspec_map: + input_qspec_map[input_act] = shared_with_input0_qspec # type: ignore[index] + + output_act_qspec = shared_with_input0_qspec + + cat_node.meta["quantization_annotation"] = QuantizationAnnotation( + input_qspec_map=input_qspec_map, + output_qspec=output_act_qspec, + _annotated=True, + ) + return annotated_partitions + + +def _is_share_obs_or_fq_op(op: Callable) -> bool: + return op in [ + torch.ops.aten.relu.default, + torch.ops.aten.hardtanh.default, + torch.ops.aten.hardtanh_.default, + torch.ops.aten.max_pool2d.default, + torch.ops.aten.mean.default, + torch.ops.aten.mean.dim, + torch.ops.aten.permute.default, + torch.ops.aten.permute_copy.default, + torch.ops.aten.squeeze.dim, + torch.ops.aten.squeeze_copy.dim, + # TODO: remove? + torch.ops.aten.adaptive_avg_pool2d.default, + torch.ops.aten.view_copy.default, + torch.ops.aten.view.default, + torch.ops.aten.slice_copy.Tensor, + torch.ops.aten.flatten.using_ints, + ] + + +def propagate_annotation(model: torch.fx.GraphModule) -> None: + for n in model.graph.nodes: + if n.op != "call_function" or not _is_share_obs_or_fq_op(n.target): + continue + + prev_node = n.args[0] + if not isinstance(prev_node, Node): + continue + + quantization_annotation = prev_node.meta.get("quantization_annotation", None) + if not quantization_annotation: + continue + + output_qspec = quantization_annotation.output_qspec + if not output_qspec: + continue + + # make sure current node is not annotated + if ( + "quantization_annotation" in n.meta + and n.meta["quantization_annotation"]._annotated + ): + continue + + shared_qspec = SharedQuantizationSpec(prev_node) + # propagate the previous output_qspec to the current node + n.meta["quantization_annotation"] = QuantizationAnnotation( + input_qspec_map={ + prev_node: shared_qspec, + }, + output_qspec=shared_qspec, + _annotated=True, + ) + + +# TODO: make the list of ops customizable +def _convert_scalars_to_attrs(model: torch.fx.GraphModule) -> torch.fx.GraphModule: + for n in model.graph.nodes: + if n.op != "call_function" or n.target not in [ + torch.ops.aten.add.Tensor, + torch.ops.aten.mul.Tensor, + ]: + continue + args = list(n.args) + new_args = [] + for i in range(len(args)): + if isinstance(args[i], torch.fx.Node): + new_args.append(args[i]) + continue + prefix = "_tensor_constant_" + get_new_attr_name = get_new_attr_name_with_prefix(prefix) + tensor_constant_name = get_new_attr_name(model) + float_tensor = torch.tensor(float(args[i])) + model.register_buffer(tensor_constant_name, float_tensor) + fake_mode = n.meta["val"].fake_mode + with model.graph.inserting_before(n): + get_attr_node = model.graph.create_node( + "get_attr", tensor_constant_name, (), {} + ) + get_attr_node.meta["val"] = fake_mode.from_tensor( + float_tensor, static_shapes=True + ) + new_args.append(get_attr_node) + n.args = tuple(new_args) + model.recompile() + return model diff --git a/backends/xnnpack/test/TARGETS b/backends/xnnpack/test/TARGETS index b2db8060e1..b3143743b9 100644 --- a/backends/xnnpack/test/TARGETS +++ b/backends/xnnpack/test/TARGETS @@ -35,6 +35,7 @@ runtime.python_test( ], deps = [ "//executorch/backends/xnnpack/partition:xnnpack_partitioner", + "//executorch/backends/xnnpack/quantizer:xnnpack_quantizer", "//executorch/backends/xnnpack/test/tester:tester", "//executorch/devtools:lib", "//executorch/devtools/bundled_program:config", @@ -76,3 +77,19 @@ runtime.python_test( 
"//executorch/backends/xnnpack:xnnpack_preprocess", ], ) + +runtime.python_test( + name = "test_xnnpack_quantizer", + srcs = glob([ + "quantizer/*.py", + ]), + deps = [ + "//executorch/backends/xnnpack:xnnpack_preprocess", + "//executorch/backends/xnnpack/quantizer:xnnpack_quantizer", + "//pytorch/ao:torchao", # @manual + "//caffe2:torch", + ], + external_deps = [ + "libtorch", + ], +) diff --git a/backends/xnnpack/test/ops/test_conv2d.py b/backends/xnnpack/test/ops/test_conv2d.py index 533b9ab90c..78bb288bc3 100644 --- a/backends/xnnpack/test/ops/test_conv2d.py +++ b/backends/xnnpack/test/ops/test_conv2d.py @@ -18,14 +18,16 @@ except: has_quantized_ops = False +from executorch.backends.xnnpack.quantizer.xnnpack_quantizer import ( + get_symmetric_quantization_config, +) +from executorch.backends.xnnpack.quantizer.xnnpack_quantizer_utils import ( + QuantizationConfig, +) from executorch.backends.xnnpack.test.test_xnnpack_utils import randomize_bn from executorch.backends.xnnpack.test.tester import Quantize, Tester from executorch.exir.dialects._ops import ops as exir_ops -from torch.ao.quantization.quantizer.xnnpack_quantizer import ( - get_symmetric_quantization_config, -) -from torch.ao.quantization.quantizer.xnnpack_quantizer_utils import QuantizationConfig class Conv2d(torch.nn.Module): diff --git a/backends/xnnpack/test/ops/test_linear.py b/backends/xnnpack/test/ops/test_linear.py index 393b831c2a..eccda406b8 100644 --- a/backends/xnnpack/test/ops/test_linear.py +++ b/backends/xnnpack/test/ops/test_linear.py @@ -18,17 +18,19 @@ XnnpackFloatingPointPartitioner, XnnpackPartitioner, ) + +from executorch.backends.xnnpack.quantizer.xnnpack_quantizer import ( + get_symmetric_quantization_config, +) +from executorch.backends.xnnpack.quantizer.xnnpack_quantizer_utils import ( + QuantizationConfig, +) from executorch.backends.xnnpack.test.tester import Quantize, Tester from executorch.backends.xnnpack.test.tester.tester import ( Partition, ToEdgeTransformAndLower, ) -from torch.ao.quantization.quantizer.xnnpack_quantizer import ( - get_symmetric_quantization_config, -) -from torch.ao.quantization.quantizer.xnnpack_quantizer_utils import QuantizationConfig - try: from torchao.quantization.quant_api import ( int8_dynamic_activation_int4_weight, diff --git a/backends/xnnpack/test/quantizer/test_pt2e_quantization.py b/backends/xnnpack/test/quantizer/test_pt2e_quantization.py new file mode 100644 index 0000000000..b030a78d03 --- /dev/null +++ b/backends/xnnpack/test/quantizer/test_pt2e_quantization.py @@ -0,0 +1,824 @@ +# (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. 
+ +# pyre-strict + +from collections import Counter +from typing import Dict, Tuple + +import torch +from executorch.backends.xnnpack.quantizer.xnnpack_quantizer import ( + get_symmetric_quantization_config, + XNNPACKQuantizer, +) +from torch.ao.quantization import ( + compare_results, + CUSTOM_KEY, + default_per_channel_symmetric_qnnpack_qconfig, + extract_results_from_loggers, + generate_numeric_debug_handle, + NUMERIC_DEBUG_HANDLE_KEY, + observer, + prepare_for_propagation_comparison, +) +from torch.ao.quantization.pt2e.graph_utils import bfs_trace_with_node_process +from torch.ao.quantization.qconfig import ( + float_qparams_weight_only_qconfig, + per_channel_weight_observer_range_neg_127_to_127, + QConfig, + weight_observer_range_neg_127_to_127, +) +from torch.ao.quantization.qconfig_mapping import QConfigMapping +from torch.ao.quantization.quantize_pt2e import ( + convert_pt2e, + prepare_pt2e, + prepare_qat_pt2e, +) +from torch.ao.quantization.quantizer import Quantizer +from torch.ao.quantization.quantizer.composable_quantizer import ComposableQuantizer +from torch.ao.quantization.quantizer.embedding_quantizer import EmbeddingQuantizer +from torch.export import export_for_training +from torch.testing._internal.common_quantization import ( + NodeSpec as ns, + PT2EQuantizationTestCase, + TestHelperModules, +) +from torch.testing._internal.common_utils import ( + instantiate_parametrized_tests, + TemporaryFileName, + TestCase, +) + + +class TestQuantizePT2E(PT2EQuantizationTestCase): + def _get_pt2e_quantized_linear( + self, is_per_channel: bool = False + ) -> torch.fx.GraphModule: + class M(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.linear = torch.nn.Linear(2, 2) + + def forward(self, x): + return self.linear(x) + + quantizer = XNNPACKQuantizer() + operator_config = get_symmetric_quantization_config( + is_per_channel=is_per_channel + ) + quantizer.set_global(operator_config) + example_inputs = (torch.randn(2, 2),) + m = M().eval() + return self._quantize(m, quantizer, example_inputs) + + def test_dont_fold_other_constant(self) -> None: + """Make sure the constant propagation does not apply to things unrelated to + quantization + """ + + class M(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.linear = torch.nn.Linear(2, 2) + self.dont_fold_me = torch.nn.Parameter(torch.randn(2, 2)) + + def forward(self, x): + t = self.dont_fold_me.t() + return self.linear(x) + t + + quantizer = XNNPACKQuantizer() + operator_config = get_symmetric_quantization_config(is_per_channel=False) + # only quantize linear, so add is not quantized and the constant Tensor + # should not be folded + quantizer.set_module_type(torch.nn.Linear, operator_config) + example_inputs = (torch.randn(2, 2),) + m = M().eval() + m = self._quantize(m, quantizer, example_inputs) + node_occurrence = { + # quantize op for weight node is folded + ns.call_function( + torch.ops.quantized_decomposed.quantize_per_tensor.default + ): 2, + ns.call_function( + torch.ops.quantized_decomposed.dequantize_per_tensor.default + ): 3, + # transpose op not folded + ns.call_function(torch.ops.aten.t.default): 1, + } + self.checkGraphModuleNodes(m, expected_node_occurrence=node_occurrence) + + def test_fold_all_ops_before_quantize(self) -> None: + """Test folding all ops that's before quantized operator: + Before: + get_attr(weight) -> transpose -> quantize -> dequantize + After: + get_attr(folded_weight) -> dequantize + """ + + class M(torch.nn.Module): + def __init__(self) -> None: + 
super().__init__() + self.weight = torch.randn(2, 2) + + def forward(self, x): + t = self.weight.t() + return torch.nn.functional.linear(x, t) + + quantizer = XNNPACKQuantizer() + operator_config = get_symmetric_quantization_config(is_per_channel=False) + quantizer.set_global(operator_config) + example_inputs = (torch.randn(2, 2),) + m = M().eval() + m = self._quantize(m, quantizer, example_inputs) + node_occurrence = { + # quantize op for weight node is folded + ns.call_function( + torch.ops.quantized_decomposed.quantize_per_tensor.default + ): 2, + ns.call_function( + torch.ops.quantized_decomposed.dequantize_per_tensor.default + ): 3, + } + self.checkGraphModuleNodes(m, expected_node_occurrence=node_occurrence) + + def test_composable_quantizer_throw(self) -> None: + class BadQuantizer(Quantizer): + def annotate(self, gm: torch.fx.GraphModule) -> torch.fx.GraphModule: + for n in gm.graph.nodes: + n.meta["quantization_annotation"] = None + + def validate(self, model: torch.fx.GraphModule) -> None: + pass + + quantizer = XNNPACKQuantizer() + quantization_config = get_symmetric_quantization_config(is_per_channel=True) + quantizer.set_global(quantization_config) + bad_quantizer = BadQuantizer() + composable_quantizer = ComposableQuantizer([quantizer, bad_quantizer]) + m_eager = TestHelperModules.ConvLinearWPermute().eval() + example_inputs = (torch.randn(2, 3, 4, 4),) + self.assertRaises( + RuntimeError, + lambda: self._test_quantizer( + m_eager, example_inputs, composable_quantizer, {} + ), + ) + + def test_composable_quantizer_linear_conv(self) -> None: + dynamic_quantizer = XNNPACKQuantizer() + quantization_config_dynamic = get_symmetric_quantization_config( + is_per_channel=False, is_dynamic=True + ) + dynamic_quantizer.set_global(quantization_config_dynamic) + static_quantizer = XNNPACKQuantizer() + quantization_config = get_symmetric_quantization_config(is_per_channel=True) + static_quantizer.set_global(quantization_config) + # Note that dynamic quantization must be applied first here. 
+ # this is because static quantizer also quantizes linear with static qspec + # and if we apply static_quantizer first then dynamic_quantizer cannot be applied + composable_quantizer = ComposableQuantizer( + [dynamic_quantizer, static_quantizer] + ) + m_eager = TestHelperModules.ConvLinearWPermute().eval() + + node_occurrence = { + torch.ops.quantized_decomposed.quantize_per_tensor.tensor: 1, + torch.ops.quantized_decomposed.dequantize_per_tensor.tensor: 1, + # note: quantize op for weights are const propagated + torch.ops.quantized_decomposed.quantize_per_tensor.default: 3, + torch.ops.quantized_decomposed.dequantize_per_tensor.default: 4, + # note: quantize op for weights are const propagated + torch.ops.quantized_decomposed.quantize_per_channel.default: 0, + torch.ops.quantized_decomposed.dequantize_per_channel.default: 1, + } + act_affine_quant_obs = observer.PlaceholderObserver.with_args( + dtype=torch.qint8, + qscheme=torch.per_tensor_affine, + quant_min=-128, + quant_max=127, + eps=2**-12, + is_dynamic=True, + ) + dynamic_qconfig = QConfig( + activation=act_affine_quant_obs, + weight=weight_observer_range_neg_127_to_127, + ) + # Test with 2d inputs + example_inputs = (torch.randn(2, 3, 4, 4),) + qconfig = default_per_channel_symmetric_qnnpack_qconfig + qconfig_mapping = QConfigMapping().set_global(qconfig) + qconfig_mapping.set_object_type(torch.nn.Linear, dynamic_qconfig) + # Had to turn off check against fx because fx quant workflow does not seem + # to propagate observers for permute node for this model. + # Suprisingly it does propagate it for EmbeddingConvLinearModule + # TODO: Figure out the right behavior for propagation + self._test_quantizer( + m_eager, + example_inputs, + composable_quantizer, + node_occurrence, + [], + False, + qconfig_mapping, + ) + + def test_embedding_conv_linear_quantization(self) -> None: + m_eager = TestHelperModules.EmbeddingConvLinearModule().eval() + indices = torch.tensor( + [ + 9, + 6, + 5, + 7, + 8, + 8, + 9, + 2, + 8, + 6, + 6, + 9, + 1, + 6, + 8, + 8, + 3, + 2, + 3, + 6, + 3, + 6, + 5, + 7, + 0, + 8, + 4, + 6, + 5, + 8, + 2, + 3, + ] + ) + indices = torch.unsqueeze(indices, 0) + example_inputs = (indices,) + + embedding_quantizer = EmbeddingQuantizer() + dynamic_quantizer = XNNPACKQuantizer() + quantization_config_dynamic = get_symmetric_quantization_config( + is_per_channel=True, is_dynamic=True + ) + dynamic_quantizer.set_global(quantization_config_dynamic) + static_quantizer = XNNPACKQuantizer() + quantization_config = get_symmetric_quantization_config(is_per_channel=True) + static_quantizer.set_global(quantization_config) + composed_quantizer = ComposableQuantizer( + [embedding_quantizer, dynamic_quantizer, static_quantizer] + ) + + act_affine_quant_obs = observer.PlaceholderObserver.with_args( + dtype=torch.qint8, + qscheme=torch.per_tensor_affine, + quant_min=-128, + quant_max=127, + eps=2**-12, + is_dynamic=True, + ) + dynamic_qconfig = QConfig( + activation=act_affine_quant_obs, + weight=per_channel_weight_observer_range_neg_127_to_127, + ) + qconfig = default_per_channel_symmetric_qnnpack_qconfig + qconfig_mapping = QConfigMapping().set_global(qconfig) + qconfig_mapping.set_object_type(torch.nn.Linear, dynamic_qconfig) + qconfig_mapping = qconfig_mapping.set_object_type( + torch.nn.Embedding, float_qparams_weight_only_qconfig + ) + + node_occurrence = { + torch.ops.quantized_decomposed.quantize_per_tensor.default: 4, + torch.ops.quantized_decomposed.dequantize_per_tensor.default: 4, + 
torch.ops.quantized_decomposed.quantize_per_tensor.tensor: 1, + torch.ops.quantized_decomposed.dequantize_per_tensor.tensor: 1, + # note: quantize op for weights are const propagated + torch.ops.quantized_decomposed.quantize_per_channel.default: 0, + torch.ops.quantized_decomposed.dequantize_per_channel.default: 3, + } + self._test_quantizer( + m_eager, + example_inputs, + composed_quantizer, + node_occurrence, + [], + True, + qconfig_mapping, + ) + + def test_disallow_eval_train(self) -> None: + m = TestHelperModules.ConvWithBNRelu(relu=True) + example_inputs = (torch.rand(3, 3, 5, 5),) + + # Before export: this is OK + m.eval() + m.train() + + # After export: this is not OK + m = export_for_training(m, example_inputs).module() + with self.assertRaises(NotImplementedError): + m.eval() + with self.assertRaises(NotImplementedError): + m.train() + + # After prepare: still not OK + quantizer = XNNPACKQuantizer() + m = prepare_qat_pt2e(m, quantizer) # pyre-ignore[6] + with self.assertRaises(NotImplementedError): + m.eval() + with self.assertRaises(NotImplementedError): + m.train() + + # After convert: still not OK + m = convert_pt2e(m) + with self.assertRaises(NotImplementedError): + m.eval() + with self.assertRaises(NotImplementedError): + m.train() + + def _get_bn_train_eval_ops(self) -> Tuple[torch._ops.OpOverload]: + return ( + torch.ops.aten.batch_norm.default, + torch.ops.aten.batch_norm.default, + ) + + def _get_node( + self, m: torch.fx.GraphModule, target: torch._ops.OpOverload + ) -> torch.fx.Node: + """ + Return the first node matching the specified target, throwing an exception + if no such batch norm node is found. + """ + for n in m.graph.nodes: + if n.target == target: + return n + raise ValueError("Did not find node with target ", target) + + def test_allow_exported_model_train_eval(self) -> None: + class M(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.bn = torch.nn.BatchNorm2d(3) + self.dropout = torch.nn.Dropout(0.5) + + def forward(self, x): + x = self.bn(x) + x = self.dropout(x) + return x + + m = M().train() + example_inputs = (torch.randn(1, 3, 3, 3),) + bn_train_op, bn_eval_op = self._get_bn_train_eval_ops() # pyre-ignore[23] + m = export_for_training(m, example_inputs).module() + + def _assert_ops_are_correct( + m: torch.fx.GraphModule, train: bool + ) -> None: # pyre-ignore [53] + bn_op = bn_train_op if train else bn_eval_op + bn_node = self._get_node(m, bn_op) + self.assertTrue(bn_node is not None) + dropout_node = self._get_node(m, torch.ops.aten.dropout.default) + self.assertEqual(dropout_node.args[2], train) + + # Before wrapping: this is not OK + with self.assertRaises(NotImplementedError): + m.eval() + with self.assertRaises(NotImplementedError): + m.train() + + # After wrapping: does not error and swaps the ops accordingly + torch.ao.quantization.allow_exported_model_train_eval(m) # pyre-ignore[6] + m.eval() + _assert_ops_are_correct(m, train=False) # pyre-ignore[6] + m.train() + _assert_ops_are_correct(m, train=True) # pyre-ignore[6] + + # After prepare but before wrapping: this is not OK + quantizer = XNNPACKQuantizer() + m = prepare_qat_pt2e(m, quantizer) # pyre-ignore[6] + with self.assertRaises(NotImplementedError): + m.eval() + with self.assertRaises(NotImplementedError): + m.train() + + # After prepare and after wrapping: does not error and swaps the ops accordingly + torch.ao.quantization.allow_exported_model_train_eval(m) + m.eval() + _assert_ops_are_correct(m, train=False) + m.train() + _assert_ops_are_correct(m, 
train=True) + + # After convert but before wrapping: this is not OK + m = convert_pt2e(m, fold_quantize=True) + with self.assertRaises(NotImplementedError): + m.eval() + with self.assertRaises(NotImplementedError): + m.train() + + # After convert and after wrapping: does not error and swaps the ops accordingly + torch.ao.quantization.allow_exported_model_train_eval(m) + m.eval() + _assert_ops_are_correct(m, train=False) + m.train() + _assert_ops_are_correct(m, train=True) + + def test_constant_prop_preserve_metadata(self) -> None: + """Test to make sure the get_attr node for const propagated weight Tensor gets the correct + metadata (from original get_attr node from weight) + """ + + class M(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.linear = torch.nn.Linear(2, 2) + + def forward(self, x): + return self.linear(x) + + quantizer = XNNPACKQuantizer() + operator_config = get_symmetric_quantization_config() + quantizer.set_global(operator_config) + example_inputs = (torch.randn(2, 2),) + m = M().eval() + m = export_for_training( + m, + example_inputs, + ).module() + weight_meta = None + for n in m.graph.nodes: # pyre-ignore[16] + if ( + n.op == "get_attr" + and next(iter(n.users)).target == torch.ops.aten.linear.default + ): + weight_meta = n.meta + break + assert weight_meta is not None, "Expect to find metadata for weight node" + + m = prepare_pt2e(m, quantizer) # pyre-ignore[6] + m(*example_inputs) + m = convert_pt2e(m) + + for n in m.graph.nodes: + if n.op == "get_attr" and "frozen_param" in n.target: + for key in n.meta: + self.assertEqual(n.meta[key], weight_meta[key]) + + def test_reentrant(self) -> None: + """Test we can safely call quantization apis multiple times""" + m = TestHelperModules.ConvBnReLU2dAndLinearReLU() + example_inputs = (torch.randn(3, 3, 10, 10),) + + quantizer = XNNPACKQuantizer().set_global( + get_symmetric_quantization_config(is_per_channel=True, is_qat=True) + ) + m.conv_bn_relu = export_for_training( + m.conv_bn_relu, example_inputs + ).module() # pyre-ignore[8] + m.conv_bn_relu = prepare_qat_pt2e( + m.conv_bn_relu, quantizer + ) # pyre-ignore[6, 8] + m(*example_inputs) + m.conv_bn_relu = convert_pt2e(m.conv_bn_relu) # pyre-ignore[6, 8] + + quantizer = XNNPACKQuantizer().set_module_type( + torch.nn.Linear, get_symmetric_quantization_config(is_per_channel=False) + ) + m = export_for_training(m, example_inputs).module() + m = prepare_pt2e(m, quantizer) # pyre-ignore[6] + m = convert_pt2e(m) + + node_occurrence = { + ns.call_function( + torch.ops.quantized_decomposed.quantize_per_tensor.default + ): 4, + # one for weight + ns.call_function( + torch.ops.quantized_decomposed.dequantize_per_tensor.default + ): 5, + ns.call_function( + torch.ops.quantized_decomposed.dequantize_per_channel.default + ): 1, + } + node_list = [ + ns.call_function( + torch.ops.quantized_decomposed.dequantize_per_tensor.default + ), + ns.call_function(torch.ops.aten.conv2d.default), + ns.call_function(torch.ops.aten.relu.default), + ns.call_function( + torch.ops.quantized_decomposed.quantize_per_tensor.default + ), + ns.call_function( + torch.ops.quantized_decomposed.dequantize_per_tensor.default + ), + ns.call_function(torch.ops.aten.linear.default), + ns.call_function( + torch.ops.quantized_decomposed.quantize_per_tensor.default + ), + ] + self.checkGraphModuleNodes( + m, expected_node_occurrence=node_occurrence, expected_node_list=node_list + ) + + def test_groupwise_per_channel_quant(self) -> None: + m = TestHelperModules.GroupwiseConv2d() + quantizer = 
XNNPACKQuantizer() + operator_config = get_symmetric_quantization_config(is_per_channel=True) + quantizer.set_global(operator_config) + example_inputs = m.example_inputs() + m = self._quantize(m, quantizer, example_inputs) + # make sure it runs + m(*example_inputs) + + def test_preserve_nn_module_stack(self) -> None: + """Test we can preserve nn_module_stack on replaced pattern's nodes""" + m = TestHelperModules.ConvBnReLU2dAndLinearReLU() + example_inputs = (torch.randn(3, 3, 10, 10),) + + quantizer = XNNPACKQuantizer().set_global( + get_symmetric_quantization_config(is_per_channel=True, is_qat=True) + ) + + def check_nn_module(node: torch.fx.Node) -> None: + self.assertTrue("nn_module_stack" in node.meta) + self.assertTrue( + "ConvWithBNRelu" in node.meta["nn_module_stack"]["L__self__"][1] + ) + + m.conv_bn_relu = export_for_training( + m.conv_bn_relu, example_inputs + ).module() # pyre-ignore[8] + for node in m.conv_bn_relu.graph.nodes: # pyre-ignore[16] + if node.op not in ["placeholder", "output", "get_attr"]: + check_nn_module(node) + m.conv_bn_relu = prepare_qat_pt2e( + m.conv_bn_relu, quantizer + ) # pyre-ignore[6, 8] + for node in m.conv_bn_relu.graph.nodes: # pyre-ignore[16] + if node.name == "mul": + check_nn_module(node) + + def test_speed(self) -> None: + import time # noqa: F401 + + def dynamic_quantize_pt2e( + model, example_inputs + ) -> torch.fx.GraphModule: # pyre-ignore[2] + torch._dynamo.reset() + model = export_for_training(model, example_inputs).module() + # Per channel quantization for weight + # Dynamic quantization for activation + # Please read a detail: https://fburl.com/code/30zds51q + embedding_quantizer = EmbeddingQuantizer() + dynamic_quantizer = XNNPACKQuantizer() + operator_config_dynamic = get_symmetric_quantization_config( + is_per_channel=True, is_dynamic=True + ) + dynamic_quantizer.set_global(operator_config_dynamic) + composed_quantizer = ComposableQuantizer( + [embedding_quantizer, dynamic_quantizer] + ) + # prev = time.time() + model = prepare_qat_pt2e(model, composed_quantizer) # pyre-ignore[6] + # cur = time.time() + # print("prepare time:", cur - prev) + # Without Calibraiton, scale/zero value will have an initialized value of 1.0 + # Per channel quantization needs a proper scale/zero shape/value to work properly. + # So we need to run calibration before converting to quantized model. + model(*example_inputs) + # prev = time.time() + model = convert_pt2e(model) + # cur = time.time() + # uncomment to see the time + # print("convert time:", cur - prev) + return model + + class M(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.linear = torch.nn.Linear(5, 5) + + def forward(self, x): + return self.linear(x) + + m = M().eval() + example_inputs = (torch.randn(5, 5),) + _ = dynamic_quantize_pt2e(m, example_inputs) + + def test_multi_users_without_output_observer(self) -> None: + """ + Test the case in which a node is used by multiple users, + and had its output observer removed. 
+ """ + + class M(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.conv = torch.nn.Conv2d(3, 3, 3) + + def forward(self, x): + x = self.conv(x) + return x, x + 1 + + example_inputs = (torch.randn(1, 3, 5, 5),) + m = M() + m = export_for_training(m, example_inputs).module() + quantizer = XNNPACKQuantizer().set_global( + get_symmetric_quantization_config(), + ) + m = prepare_pt2e(m, quantizer) # pyre-ignore[6] + m(*example_inputs) + + # Remove output observer + observer_to_remove = None + for n in m.graph.nodes: + if n.op == "output": + observer_to_remove = n.args[0][0] + assert observer_to_remove.op == "call_module" + assert observer_to_remove.target.startswith("activation_post_process_") + break + assert observer_to_remove is not None + observer_to_remove.replace_all_uses_with(observer_to_remove.args[0]) + m.graph.erase_node(observer_to_remove) + m.recompile() + + # Convert should succeed + m = convert_pt2e(m) + m(*example_inputs) + + def test_fold_quantize(self) -> None: + """Test to make sure the quantized model gets quantized weight (quantize_per_tensor op is folded)""" + m = self._get_pt2e_quantized_linear() + node_occurrence = { + # quantize op for weight node is folded + ns.call_function( + torch.ops.quantized_decomposed.quantize_per_tensor.default + ): 2, + ns.call_function( + torch.ops.quantized_decomposed.dequantize_per_tensor.default + ): 3, + } + self.checkGraphModuleNodes(m, expected_node_occurrence=node_occurrence) + + def test_fold_quantize_per_channel(self) -> None: + """Test to make sure the quantized model gets quantized weight (quantize_per_channel op is folded)""" + m = self._get_pt2e_quantized_linear(is_per_channel=True) + node_occurrence = { + # quantize op for weight node is folded + ns.call_function( + torch.ops.quantized_decomposed.quantize_per_tensor.default + ): 2, + ns.call_function( + torch.ops.quantized_decomposed.dequantize_per_channel.default + ): 1, + ns.call_function( + torch.ops.quantized_decomposed.dequantize_per_tensor.default + ): 2, + } + self.checkGraphModuleNodes(m, expected_node_occurrence=node_occurrence) + + def test_save_load(self) -> None: + """Test save/load a quantized model""" + m = self._get_pt2e_quantized_linear() + example_inputs = (torch.randn(2, 2),) + ref_res = m(*example_inputs) + + with TemporaryFileName() as fname: + # serialization + quantized_ep = torch.export.export(m, example_inputs, strict=True) + torch.export.save(quantized_ep, fname) + # deserialization + loaded_ep = torch.export.load(fname) + loaded_quantized_model = loaded_ep.module() + res = loaded_quantized_model(*example_inputs) + self.assertEqual(ref_res, res) + + +instantiate_parametrized_tests(TestQuantizePT2E) + + +class TestNumericDebugger(TestCase): + + def _extract_debug_handles(self, model) -> Dict[str, int]: # pyre-ignore[2] + debug_handle_map: Dict[str, int] = {} + + def _extract_debug_handles_from_node(node: torch.fx.Node) -> None: + nonlocal debug_handle_map + if ( + CUSTOM_KEY in node.meta + and NUMERIC_DEBUG_HANDLE_KEY in node.meta[CUSTOM_KEY] + ): + debug_handle_map[str(node)] = node.meta[CUSTOM_KEY][ + NUMERIC_DEBUG_HANDLE_KEY + ] + + bfs_trace_with_node_process(model, _extract_debug_handles_from_node) + return debug_handle_map + + def _assert_each_node_has_debug_handle(self, model) -> None: # pyre-ignore[2] + def _assert_node_has_debug_handle(node: torch.fx.Node) -> None: + self.assertTrue( + CUSTOM_KEY in node.meta + and NUMERIC_DEBUG_HANDLE_KEY in node.meta[CUSTOM_KEY], + f"Node {node} doesn't have debug handle", + ) + + 
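+        # bfs_trace_with_node_process walks the exported model (including any
+        # submodules) and applies the assertion above to every node.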
bfs_trace_with_node_process(model, _assert_node_has_debug_handle) + + def test_quantize_pt2e_preserve_handle(self) -> None: + m = TestHelperModules.Conv2dThenConv1d() + example_inputs = m.example_inputs() + ep = export_for_training(m, example_inputs) + generate_numeric_debug_handle(ep) + m = ep.module() + + quantizer = XNNPACKQuantizer().set_global( + get_symmetric_quantization_config(is_per_channel=False) + ) + m = prepare_pt2e(m, quantizer) # pyre-ignore[6] + debug_handle_map = self._extract_debug_handles(m) + res_counter = Counter(debug_handle_map.values()) + repeated_debug_handle_ids = [1, 2, 3] + # 3 ids were repeated because we copy over the id from node to its output observer + # torch.ops.aten.conv2d.default, torch.ops.aten.squeeze.dim and torch.ops.aten.conv1d.default + for dh_id in repeated_debug_handle_ids: + self.assertEqual(res_counter[dh_id], 2) + + m(*example_inputs) + m = convert_pt2e(m) + self._assert_each_node_has_debug_handle(ep) + debug_handle_map = self._extract_debug_handles(m) + res_counter = Counter(debug_handle_map.values()) + # same set of ids where repeated, because we copy over the id from observer/fake_quant to + # dequantize node + repeated_debug_handle_ids = [1, 2, 3] + for dh_id in repeated_debug_handle_ids: + self.assertEqual(res_counter[dh_id], 2) + + def test_extract_results_from_loggers(self) -> None: + m = TestHelperModules.Conv2dThenConv1d() + example_inputs = m.example_inputs() + ep = export_for_training(m, example_inputs) + generate_numeric_debug_handle(ep) + m = ep.module() + m_ref_logger = prepare_for_propagation_comparison(m) # pyre-ignore[6] + + quantizer = XNNPACKQuantizer().set_global( + get_symmetric_quantization_config(is_per_channel=False) + ) + m = prepare_pt2e(m, quantizer) # pyre-ignore[6] + m(*example_inputs) + m = convert_pt2e(m) + m_quant_logger = prepare_for_propagation_comparison(m) + + m_ref_logger(*example_inputs) + m_quant_logger(*example_inputs) + ref_results = extract_results_from_loggers(m_ref_logger) + quant_results = extract_results_from_loggers(m_quant_logger) + comparison_results = compare_results( + ref_results, quant_results + ) # pyre-ignore[6] + for node_summary in comparison_results.values(): + if len(node_summary.results) > 0: + self.assertGreaterEqual( + node_summary.results[0].sqnr, 35 + ) # pyre-ignore[6] + + def test_extract_results_from_loggers_list_output(self) -> None: + m = TestHelperModules.Conv2dWithSplit() + example_inputs = m.example_inputs() + ep = export_for_training(m, example_inputs) + generate_numeric_debug_handle(ep) + m = ep.module() + m_ref_logger = prepare_for_propagation_comparison(m) # pyre-ignore[6] + + quantizer = XNNPACKQuantizer().set_global( + get_symmetric_quantization_config(is_per_channel=False) + ) + m = prepare_pt2e(m, quantizer) # pyre-ignore[6] + m(*example_inputs) + m = convert_pt2e(m) + m_quant_logger = prepare_for_propagation_comparison(m) + + m_ref_logger(*example_inputs) + m_quant_logger(*example_inputs) + ref_results = extract_results_from_loggers(m_ref_logger) + quant_results = extract_results_from_loggers(m_quant_logger) + comparison_results = compare_results( + ref_results, quant_results + ) # pyre-ignore[6] + for node_summary in comparison_results.values(): + if len(node_summary.results) > 0: + sqnr = node_summary.results[0].sqnr + if isinstance(sqnr, list): + for sqnr_i in sqnr: + self.assertGreaterEqual(sqnr_i, 35) + else: + self.assertGreaterEqual(sqnr, 35) # pyre-ignore[6] diff --git a/backends/xnnpack/test/quantizer/test_representation.py 
b/backends/xnnpack/test/quantizer/test_representation.py new file mode 100644 index 0000000000..83cecaec5a --- /dev/null +++ b/backends/xnnpack/test/quantizer/test_representation.py @@ -0,0 +1,311 @@ +# Owner(s): ["oncall: quantization"] +import copy +from typing import Any, Optional + +import torch +from executorch.backends.xnnpack.quantizer.xnnpack_quantizer import ( + get_symmetric_quantization_config, + XNNPACKQuantizer, +) +from torch._higher_order_ops.out_dtype import out_dtype # noqa: F401 +from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e +from torch.ao.quantization.quantizer import Quantizer +from torch.export import export_for_training +from torch.testing._internal.common_quantization import ( + NodeSpec as ns, + QuantizationTestCase, + skipIfNoQNNPACK, + TestHelperModules, +) + + +@skipIfNoQNNPACK +class TestPT2ERepresentation(QuantizationTestCase): + def _test_representation( + self, + model: torch.nn.Module, + example_inputs: tuple[Any, ...], + quantizer: Quantizer, + ref_node_occurrence: dict[ns, int], + non_ref_node_occurrence: dict[ns, int], + fixed_output_tol: Optional[float] = None, + output_scale_idx: int = 2, + ) -> None: + # resetting dynamo cache + torch._dynamo.reset() + model = export_for_training( + model, + example_inputs, + ).module() + model_copy = copy.deepcopy(model) + + model = prepare_pt2e(model, quantizer) # pyre-ignore[6] + # Calibrate + model(*example_inputs) + model = convert_pt2e(model, use_reference_representation=True) + self.checkGraphModuleNodes(model, expected_node_occurrence=ref_node_occurrence) + # make sure it runs + pt2e_quant_output = model(*example_inputs) + + # TODO: torchdynamo times out when we do this, we can enable numerical checking + # after that is fixed + model_copy = prepare_pt2e(model_copy, quantizer) # pyre-ignore[6] + # Calibrate + model_copy(*example_inputs) + model_copy = convert_pt2e(model_copy, use_reference_representation=False) + self.checkGraphModuleNodes( + model_copy, expected_node_occurrence=non_ref_node_occurrence + ) + pt2e_quant_output_copy = model_copy(*example_inputs) + + output_tol = None + if fixed_output_tol is not None: + output_tol = fixed_output_tol + else: + idx = 0 + for n in model_copy.graph.nodes: + if ( + n.target + == torch.ops.quantized_decomposed.quantize_per_tensor.default + ): + idx += 1 + if idx == output_scale_idx: + output_tol = n.args[1] + assert output_tol is not None + + # make sure the result is off by one at most in the quantized integer representation + self.assertTrue( + torch.max(torch.abs(pt2e_quant_output_copy - pt2e_quant_output)) + <= (2 * output_tol + 1e-5) + ) + + def test_static_linear(self): + class M(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.linear = torch.nn.Linear(5, 5) + + def forward(self, x): + return self.linear(x) + + quantizer = XNNPACKQuantizer() + operator_config = get_symmetric_quantization_config(is_per_channel=False) + quantizer.set_global(operator_config) + example_inputs = (torch.randn(2, 5),) + + self._test_representation( + M().eval(), + example_inputs, + quantizer, + ref_node_occurrence={}, + non_ref_node_occurrence={}, + ) + + def test_dynamic_linear(self): + class M(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.linear = torch.nn.Linear(5, 5) + + def forward(self, x): + return self.linear(x) + + quantizer = XNNPACKQuantizer() + operator_config = get_symmetric_quantization_config( + is_per_channel=False, is_dynamic=True + ) + quantizer.set_global(operator_config) + 
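+        # With dynamic activation quantization there is no constant output scale to
+        # derive a tolerance from, so the comparison below passes a fixed tolerance.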
example_inputs = (torch.randn(2, 5),) + + self._test_representation( + M().eval(), + example_inputs, + quantizer, + ref_node_occurrence={}, + non_ref_node_occurrence={}, + fixed_output_tol=1e-4, + ) + + def test_conv2d(self): + class M(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.conv2d = torch.nn.Conv2d(3, 3, 3) + + def forward(self, x): + return self.conv2d(x) + + quantizer = XNNPACKQuantizer() + operator_config = get_symmetric_quantization_config(is_per_channel=False) + quantizer.set_global(operator_config) + example_inputs = (torch.randn(1, 3, 3, 3),) + + self._test_representation( + M().eval(), + example_inputs, + quantizer, + ref_node_occurrence={}, + non_ref_node_occurrence={}, + ) + + def test_add(self): + class M(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + + def forward(self, x, y): + return x + y + + quantizer = XNNPACKQuantizer() + quantization_config = get_symmetric_quantization_config(is_per_channel=True) + quantizer.set_global(quantization_config) + M().eval() + + example_inputs = ( + torch.randn(1, 3, 3, 3), + torch.randn(1, 3, 3, 3), + ) + + self._test_representation( + M().eval(), + example_inputs, + quantizer, + ref_node_occurrence={}, + non_ref_node_occurrence={}, + ) + + def test_add_relu(self): + class M(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + + def forward(self, x, y): + out = x + y + out = torch.nn.functional.relu(out) + return out + + quantizer = XNNPACKQuantizer() + operator_config = get_symmetric_quantization_config(is_per_channel=True) + quantizer.set_global(operator_config) + + example_inputs = ( + torch.randn(1, 3, 3, 3), + torch.randn(1, 3, 3, 3), + ) + ref_node_occurrence = { + ns.call_function(out_dtype): 2, + } + + self._test_representation( + M().eval(), + example_inputs, + quantizer, + ref_node_occurrence=ref_node_occurrence, + non_ref_node_occurrence={}, + ) + + def test_maxpool2d(self): + quantizer = XNNPACKQuantizer() + operator_config = get_symmetric_quantization_config(is_per_channel=True) + quantizer.set_global(operator_config) + m_eager = TestHelperModules.ConvMaxPool2d().eval() + + example_inputs = (torch.randn(1, 2, 2, 2),) + + self._test_representation( + m_eager, + example_inputs, + quantizer, + ref_node_occurrence={}, + non_ref_node_occurrence={}, + ) + + def test_qdq_per_channel(self): + """Test representation for quantize_per_channel and dequantize_per_channel op""" + + class M(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.linear = torch.nn.Linear(5, 5) + + def forward(self, x): + return self.linear(x) + + quantizer = XNNPACKQuantizer() + # use per channel quantization for weight + operator_config = get_symmetric_quantization_config(is_per_channel=True) + quantizer.set_global(operator_config) + M().eval() + + inputs = [ + (torch.randn(1, 5),), + (torch.randn(1, 3, 5),), + (torch.randn(1, 3, 3, 5),), + (torch.randn(1, 3, 3, 3, 5),), + ] + for example_inputs in inputs: + ref_node_occurrence = { + ns.call_function( + torch.ops.quantized_decomposed.quantize_per_channel.default + ): 0, + ns.call_function( + torch.ops.quantized_decomposed.dequantize_per_channel.default + ): 0, + } + non_ref_node_occurrence = { + # quantize_per_channel is folded + ns.call_function( + torch.ops.quantized_decomposed.quantize_per_channel.default + ): 0, + ns.call_function( + torch.ops.quantized_decomposed.dequantize_per_channel.default + ): 1, + } + + self._test_representation( + M().eval(), + example_inputs, + quantizer, + ref_node_occurrence, + 
non_ref_node_occurrence, + output_scale_idx=2, + ) + + def test_qdq(self): + """Test representation for quantize and dequantize op""" + + class M(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + + def forward(self, x, y): + return x + y + + quantizer = XNNPACKQuantizer() + quantization_config = get_symmetric_quantization_config(is_per_channel=True) + quantizer.set_global(quantization_config) + M().eval() + + example_inputs = ( + torch.randn(1, 3, 3, 3), + torch.randn(1, 3, 3, 3), + ) + ref_node_occurrence = { + ns.call_function(torch.ops.quantized_decomposed.quantize_per_tensor): 0, + ns.call_function(torch.ops.quantized_decomposed.dequantize_per_tensor): 0, + } + non_ref_node_occurrence = { + ns.call_function( + torch.ops.quantized_decomposed.quantize_per_tensor.default + ): 3, + ns.call_function( + torch.ops.quantized_decomposed.dequantize_per_tensor.default + ): 3, + } + self._test_representation( + M().eval(), + example_inputs, + quantizer, + ref_node_occurrence, + non_ref_node_occurrence, + ) diff --git a/backends/xnnpack/test/quantizer/test_xnnpack_quantizer.py b/backends/xnnpack/test/quantizer/test_xnnpack_quantizer.py new file mode 100644 index 0000000000..c9e6d2230a --- /dev/null +++ b/backends/xnnpack/test/quantizer/test_xnnpack_quantizer.py @@ -0,0 +1,1090 @@ +# Owner(s): ["oncall: mobile"] +import copy +import operator + +import torch +import torch._dynamo as torchdynamo +from executorch.backends.xnnpack.quantizer.xnnpack_quantizer import ( + get_symmetric_quantization_config, + XNNPACKQuantizer, +) +from torch.ao.ns.fx.utils import compute_sqnr +from torch.ao.quantization import ( + default_dynamic_fake_quant, + default_dynamic_qconfig, + observer, + QConfig, + QConfigMapping, +) +from torch.ao.quantization.backend_config import get_qnnpack_backend_config +from torch.ao.quantization.qconfig import ( + default_per_channel_symmetric_qnnpack_qconfig, + default_symmetric_qnnpack_qconfig, + per_channel_weight_observer_range_neg_127_to_127, + weight_observer_range_neg_127_to_127, +) +from torch.ao.quantization.quantize_fx import ( + _convert_to_reference_decomposed_fx, + convert_to_reference_fx, + prepare_fx, +) +from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e +from torch.export import export_for_training +from torch.testing._internal.common_quantization import ( + NodeSpec as ns, + PT2EQuantizationTestCase, + skip_if_no_torchvision, + skipIfNoQNNPACK, + TestHelperModules, +) +from torch.testing._internal.common_quantized import override_quantized_engine + + +@skipIfNoQNNPACK +class TestXNNPACKQuantizer(PT2EQuantizationTestCase): + def test_conv1d(self): + quantizer = XNNPACKQuantizer() + quantization_config = get_symmetric_quantization_config(is_per_channel=True) + quantizer.set_global(quantization_config) + example_inputs = (torch.randn(1, 3, 5),) + node_occurrence = { + # input and output are using quantize_per_tensor and weight is using quantize_per_channel + torch.ops.quantized_decomposed.quantize_per_tensor.default: 2, + torch.ops.quantized_decomposed.dequantize_per_tensor.default: 2, + torch.ops.quantized_decomposed.quantize_per_channel.default: 0, + torch.ops.quantized_decomposed.dequantize_per_channel.default: 1, + } + node_list = [ + torch.ops.quantized_decomposed.dequantize_per_tensor.default, + torch.ops.aten.conv1d.default, + torch.ops.quantized_decomposed.quantize_per_tensor.default, + ] + self._test_quantizer( + TestHelperModules.ConvWithBNRelu(dim=1, relu=False, bn=False), + example_inputs, + quantizer, + 
node_occurrence, + node_list, + ) + + def test_conv2d(self): + quantizer = XNNPACKQuantizer() + quantization_config = get_symmetric_quantization_config(is_per_channel=True) + quantizer.set_global(quantization_config) + example_inputs = (torch.randn(1, 3, 5, 5),) + node_occurrence = { + # input and output are using quantize_per_tensor and weight is using quantize_per_channel + torch.ops.quantized_decomposed.quantize_per_tensor.default: 2, + torch.ops.quantized_decomposed.dequantize_per_tensor.default: 2, + # quantize_per_channel for weights are const propagated + torch.ops.quantized_decomposed.quantize_per_channel.default: 0, + torch.ops.quantized_decomposed.dequantize_per_channel.default: 1, + } + node_list = [ + torch.ops.quantized_decomposed.dequantize_per_tensor.default, + torch.ops.aten.conv2d.default, + torch.ops.quantized_decomposed.quantize_per_tensor.default, + ] + self._test_quantizer( + TestHelperModules.ConvWithBNRelu(relu=False, bn=False), + example_inputs, + quantizer, + node_occurrence, + node_list, + ) + + def test_conv1d_with_conv2d(self): + quantizer = XNNPACKQuantizer() + quantization_config = get_symmetric_quantization_config(is_per_channel=True) + quantizer.set_global(quantization_config) + node_occurrence = { + # input and output are using quantize_per_tensor and weight is using quantize_per_channel + torch.ops.quantized_decomposed.quantize_per_tensor.default: 4, + torch.ops.quantized_decomposed.dequantize_per_tensor.default: 4, + torch.ops.quantized_decomposed.quantize_per_channel.default: 0, + torch.ops.quantized_decomposed.dequantize_per_channel.default: 2, + } + node_list = [ + torch.ops.quantized_decomposed.dequantize_per_tensor.default, + torch.ops.aten.conv2d.default, + torch.ops.quantized_decomposed.quantize_per_tensor.default, + torch.ops.quantized_decomposed.dequantize_per_tensor.default, + torch.ops.aten.conv1d.default, + torch.ops.quantized_decomposed.quantize_per_tensor.default, + ] + m = TestHelperModules.Conv2dThenConv1d() + self._test_quantizer( + m, + m.example_inputs(), + quantizer, + node_occurrence, + node_list, + ) + + def test_linear(self): + quantizer = XNNPACKQuantizer() + quantization_config = get_symmetric_quantization_config(is_per_channel=True) + quantizer.set_global(quantization_config) + m_eager = TestHelperModules.TwoLinearModule().eval() + + # Test with 2d inputs + example_inputs_2d = (torch.randn(9, 8),) + example_inputs_3d = (torch.randn(9, 10, 8),) + example_inputs_4d = (torch.randn(9, 10, 11, 8),) + node_occurrence = { + # input and output are using quantize_per_tensor and weight is using quantize_per_channel + torch.ops.quantized_decomposed.quantize_per_tensor.default: 3, + torch.ops.quantized_decomposed.dequantize_per_tensor.default: 3, + # quantize_per_channel for weights are const propagated + torch.ops.quantized_decomposed.quantize_per_channel.default: 0, + torch.ops.quantized_decomposed.dequantize_per_channel.default: 2, + } + qconfig = default_per_channel_symmetric_qnnpack_qconfig + qconfig_mapping = QConfigMapping().set_global(qconfig) + for example_inputs in [example_inputs_2d, example_inputs_3d, example_inputs_4d]: + self._test_quantizer( + m_eager, + example_inputs, + quantizer, + node_occurrence, + [], + True, + qconfig_mapping, + ) + + def test_linear_relu(self): + quantizer = XNNPACKQuantizer() + quantization_config = get_symmetric_quantization_config(is_per_channel=True) + quantizer.set_global(quantization_config) + m_eager = TestHelperModules.LinearReluModel().eval() + + # Test with 2d inputs + example_inputs_2d = 
(torch.randn(1, 5),) + example_inputs_3d = (torch.randn(1, 2, 5),) + example_inputs_4d = (torch.randn(1, 2, 3, 5),) + + node_occurrence = { + # input and output are using quantize_per_tensor and weight is using quantize_per_channel + # There should not be extra quantize_per_tensor or dequantize_per_tensors for relu + torch.ops.quantized_decomposed.quantize_per_tensor.default: 2, + torch.ops.quantized_decomposed.dequantize_per_tensor.default: 2, + # quantize_per_channel for weights are const propagated + torch.ops.quantized_decomposed.quantize_per_channel.default: 0, + torch.ops.quantized_decomposed.dequantize_per_channel.default: 1, + } + qconfig = default_per_channel_symmetric_qnnpack_qconfig + qconfig_mapping = QConfigMapping().set_global(qconfig) + for example_inputs in [example_inputs_2d, example_inputs_3d, example_inputs_4d]: + self._test_quantizer( + m_eager, + example_inputs, + quantizer, + node_occurrence, + [], # node_list + False, # executorch_backend_config() does not fuse linear-relu + qconfig_mapping, + ) + + def test_conv_linear_no_permute(self): + quantizer = XNNPACKQuantizer() + quantization_config = get_symmetric_quantization_config(is_per_channel=True) + quantizer.set_global(quantization_config) + node_occurrence = { + # input and output are using quantize_per_tensor and weight is using quantize_per_channel + torch.ops.quantized_decomposed.quantize_per_tensor.default: 5, + torch.ops.quantized_decomposed.dequantize_per_tensor.default: 5, + # quantize_per_channel for weights are const propagated + torch.ops.quantized_decomposed.quantize_per_channel.default: 0, + torch.ops.quantized_decomposed.dequantize_per_channel.default: 3, + } + qconfig = default_per_channel_symmetric_qnnpack_qconfig + qconfig_mapping = QConfigMapping().set_global(qconfig) + # Test with 2d inputs + example_inputs = (torch.randn(2, 3, 4, 4),) + self._test_quantizer( + TestHelperModules.Conv2dWithTwoLinear(), + example_inputs, + quantizer, + node_occurrence, + [], + True, + qconfig_mapping, + ) + + def test_conv_linear(self): + quantizer = XNNPACKQuantizer() + quantization_config = get_symmetric_quantization_config(is_per_channel=True) + quantizer.set_global(quantization_config) + + # Test with 2d inputs + example_inputs = (torch.randn(2, 3, 4, 4),) + node_occurrence = { + torch.ops.quantized_decomposed.quantize_per_tensor.default: 5, + torch.ops.quantized_decomposed.dequantize_per_tensor.default: 5, + # quantize_per_channel for weights are const propagated + torch.ops.quantized_decomposed.quantize_per_channel.default: 0, + torch.ops.quantized_decomposed.dequantize_per_channel.default: 3, + } + qconfig = default_per_channel_symmetric_qnnpack_qconfig + qconfig_mapping = QConfigMapping().set_global(qconfig) + self._test_quantizer( + TestHelperModules.Conv2dWithTwoLinearPermute(), + example_inputs, + quantizer, + node_occurrence, + [], + True, + qconfig_mapping, + ) + + def test_linear_with_dynamic_shape(self): + quantizer = XNNPACKQuantizer() + quantization_config = get_symmetric_quantization_config(is_per_channel=True) + quantizer.set_global(quantization_config) + m_eager = TestHelperModules.TwoLinearModule().eval() + + # Test with 2d inputs + example_inputs_3d = (torch.randn(9, 10, 8),) + node_occurrence = { + # input and output are using quantize_per_tensor and weight is using quantize_per_channel + torch.ops.quantized_decomposed.quantize_per_tensor.default: 3, + torch.ops.quantized_decomposed.dequantize_per_tensor.default: 3, + # quantize_per_channel for weights are const propagated + 
torch.ops.quantized_decomposed.quantize_per_channel.default: 0, + torch.ops.quantized_decomposed.dequantize_per_channel.default: 2, + } + qconfig = default_per_channel_symmetric_qnnpack_qconfig + qconfig_mapping = QConfigMapping().set_global(qconfig) + self._test_quantizer( + m_eager, + example_inputs_3d, + quantizer, + node_occurrence, + [], + True, + qconfig_mapping, + export_with_dynamic_shape=True, + ) + + def test_obs_sharing_ops(self): + quantizer = XNNPACKQuantizer() + quantization_config = get_symmetric_quantization_config(is_per_channel=True) + quantizer.set_global(quantization_config) + m = TestHelperModules.Conv2dWithObsSharingOps().eval() + example_inputs = (torch.randn(1, 3, 5, 5),) + node_occurrence = { + # input and output are using quantize_per_tensor and weight is using quantize_per_channel + torch.ops.quantized_decomposed.quantize_per_tensor.default: 5, + torch.ops.quantized_decomposed.dequantize_per_tensor.default: 5, + # quantize_per_channel for weights are const propagated + torch.ops.quantized_decomposed.quantize_per_channel.default: 0, + torch.ops.quantized_decomposed.dequantize_per_channel.default: 1, + } + node_list = [ + torch.ops.quantized_decomposed.dequantize_per_tensor.default, + torch.ops.aten.conv2d.default, + torch.ops.quantized_decomposed.quantize_per_tensor.default, + torch.ops.quantized_decomposed.dequantize_per_tensor.default, + torch.ops.aten.adaptive_avg_pool2d.default, + torch.ops.quantized_decomposed.quantize_per_tensor.default, + torch.ops.quantized_decomposed.dequantize_per_tensor.default, + torch.ops.aten.hardtanh.default, + torch.ops.quantized_decomposed.quantize_per_tensor.default, + torch.ops.quantized_decomposed.dequantize_per_tensor.default, + torch.ops.aten.mean.default, + torch.ops.quantized_decomposed.quantize_per_tensor.default, + torch.ops.quantized_decomposed.dequantize_per_tensor.default, + ] + self._test_quantizer(m, example_inputs, quantizer, node_occurrence, node_list) + + def test_set_module_name(self): + class Sub(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.linear = torch.nn.Linear(5, 5) + + def forward(self, x): + return self.linear(x) + + class M(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.linear = torch.nn.Linear(5, 5) + self.sub = Sub() + + def forward(self, x): + x = self.linear(x) + x = self.sub(x) + return x + + m = M().eval() + example_inputs = (torch.randn(3, 5),) + quantizer = XNNPACKQuantizer() + quantization_config = get_symmetric_quantization_config(is_per_channel=True) + quantizer.set_module_name("sub", quantization_config) + node_occurrence = { + torch.ops.aten.linear.default: 2, + # input and output for the second linear + torch.ops.quantized_decomposed.quantize_per_tensor.default: 2, + torch.ops.quantized_decomposed.dequantize_per_tensor.default: 2, + } + node_list = [ + # first linear is not quantized + torch.ops.aten.linear.default, + # second linear is quantized + torch.ops.quantized_decomposed.quantize_per_tensor.default, + torch.ops.quantized_decomposed.dequantize_per_tensor.default, + torch.ops.aten.linear.default, + torch.ops.quantized_decomposed.quantize_per_tensor.default, + torch.ops.quantized_decomposed.dequantize_per_tensor.default, + ] + self._test_quantizer(m, example_inputs, quantizer, node_occurrence, node_list) + + def test_set_module_name_with_underscores(self) -> None: + """Test that if a module name has an underscore, we can still quantize it""" + + class M(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + # 
This module name has underscores, which can be part of a mangled + # name. + self.foo_bar = torch.nn.Linear(2, 2) + self.baz = torch.nn.Linear(2, 2) + + def forward(self, x): + return self.baz(self.foo_bar(x)) + + quantizer = XNNPACKQuantizer() + # Set global to no quantization and then per-channel for a specific submodule. + quantizer.set_module_name( + "foo_bar", get_symmetric_quantization_config(is_per_channel=True) + ) + example_inputs = (torch.randn(2, 2),) + m = M().eval() + m = export_for_training(m, example_inputs).module() + m = prepare_pt2e(m, quantizer) # pyre-ignore[6] + # Use a linear count instead of names because the names might change, but + # the order should be the same. + count = 0 + for n in m.graph.nodes: + if n.op == "call_function" and n.target == torch.ops.aten.linear.default: + # Get the weight observer to see the per-channel vs per-tensor. + weight_observer_node = n.args[1] + if count == 0: + # The weight tensor should be per-tensor and not per-channel + # for foo_bar. + self.assertEqual(weight_observer_node.op, "call_module") + observer_instance = getattr(m, weight_observer_node.target) + self.assertEqual( + observer_instance.qscheme, torch.per_channel_symmetric + ) + else: + # For baz it should have no observer at all. + self.assertNotEqual(weight_observer_node.op, "call_module") + count += 1 + + def test_set_module_type(self): + class Sub(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.linear = torch.nn.Linear(5, 5) + + def forward(self, x): + return self.linear(x) + + class M(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.linear = torch.nn.Linear(5, 5) + self.sub = Sub() + + def forward(self, x): + x = self.linear(x) + x = self.sub(x) + return x + + m = M().eval() + example_inputs = (torch.randn(3, 5),) + quantizer = XNNPACKQuantizer() + quantization_config = get_symmetric_quantization_config(is_per_channel=True) + quantizer.set_module_type(Sub, quantization_config) + node_occurrence = { + torch.ops.aten.linear.default: 2, + # input and output for the second linear + torch.ops.quantized_decomposed.quantize_per_tensor.default: 2, + torch.ops.quantized_decomposed.dequantize_per_tensor.default: 2, + } + node_list = [ + # first linear is not quantized + torch.ops.aten.linear.default, + # second linear is quantized + torch.ops.quantized_decomposed.quantize_per_tensor.default, + torch.ops.quantized_decomposed.dequantize_per_tensor.default, + torch.ops.aten.linear.default, + torch.ops.quantized_decomposed.quantize_per_tensor.default, + torch.ops.quantized_decomposed.dequantize_per_tensor.default, + ] + self._test_quantizer(m, example_inputs, quantizer, node_occurrence, node_list) + + def test_set_module_type_case_2(self): + class M(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.conv = torch.nn.Conv2d( + in_channels=3, + out_channels=3, + kernel_size=3, + stride=1, + padding=1, + bias=True, + ) + self.conv2 = torch.nn.Conv2d( + in_channels=3, + out_channels=3, + kernel_size=3, + stride=1, + padding=1, + bias=True, + ) + self.conv3 = torch.nn.Conv2d( + in_channels=3, + out_channels=3, + kernel_size=3, + stride=1, + padding=1, + bias=True, + ) + self.relu = torch.nn.ReLU() + self.avgpool = torch.nn.AdaptiveAvgPool2d((1, 1)) + self.fc = torch.nn.Linear(3, 16) + + def forward(self, x): + x1 = self.conv(x) + x2 = self.relu(self.conv2(x1) + self.conv3(x1)) + x3 = self.avgpool(x2) + x4 = torch.flatten(x3, 1) + x5 = self.fc(x4) + return x5 + + m = M().eval() + example_inputs = 
(torch.randn(1, 3, 16, 16),) + quantizer = XNNPACKQuantizer() + quantization_config = get_symmetric_quantization_config(is_per_channel=True) + # We only want to annotate Linear type + quantizer.set_module_type(torch.nn.Linear, quantization_config) + node_occurrence = { + torch.ops.aten.conv2d.default: 3, + torch.ops.aten.linear.default: 1, + # input and output for the linear + torch.ops.quantized_decomposed.quantize_per_tensor.default: 2, + torch.ops.quantized_decomposed.dequantize_per_tensor.default: 2, + } + node_list = [ + # only the linear is quantized + torch.ops.quantized_decomposed.quantize_per_tensor.default, + torch.ops.quantized_decomposed.dequantize_per_tensor.default, + torch.ops.aten.linear.default, + torch.ops.quantized_decomposed.quantize_per_tensor.default, + torch.ops.quantized_decomposed.dequantize_per_tensor.default, + ] + self._test_quantizer(m, example_inputs, quantizer, node_occurrence, node_list) + + def test_propagate_annotation(self): + quantizer = XNNPACKQuantizer() + quantization_config = get_symmetric_quantization_config(is_per_channel=True) + quantizer.set_global(quantization_config) + m = TestHelperModules.Conv2dPropAnnotaton().eval() + example_inputs = (torch.randn(1, 3, 5, 5),) + + # program capture + m = export_for_training( + m, + example_inputs, + ).module() + + m = prepare_pt2e(m, quantizer) + m(*example_inputs) + for n in m.graph.nodes: + if n.target in [ + torch.ops.aten.view.default, + torch.ops.aten.hardtanh.default, + ]: + input_act = getattr(m, n.args[0].target) + output_act = getattr(m, next(iter(n.users)).target) + self.assertIs(input_act, output_act) + + m = convert_pt2e(m) + node_occurrence = { + # input and output are using quantize_per_tensor and weight is using quantize_per_channel + ns.call_function( + torch.ops.quantized_decomposed.quantize_per_tensor.default + ): 5, + ns.call_function( + torch.ops.quantized_decomposed.dequantize_per_tensor.default + ): 5, + # note: quantize op for weights are const propagated + ns.call_function( + torch.ops.quantized_decomposed.quantize_per_channel.default + ): 0, + ns.call_function( + torch.ops.quantized_decomposed.dequantize_per_channel.default + ): 2, + } + self.checkGraphModuleNodes(m, expected_node_occurrence=node_occurrence) + + def test_dynamic_linear(self): + quantizer = XNNPACKQuantizer() + quantization_config = get_symmetric_quantization_config( + is_per_channel=True, is_dynamic=True + ) + quantizer.set_global(quantization_config) + m_eager = TestHelperModules.TwoLinearModule().eval() + + node_occurrence = { + # input and output are using quantize_per_tensor and weight is using quantize_per_channel + torch.ops.quantized_decomposed.quantize_per_tensor.tensor: 2, + torch.ops.quantized_decomposed.dequantize_per_tensor.tensor: 2, + # note: quantize op for weights are const propagated + torch.ops.quantized_decomposed.quantize_per_channel.default: 0, + torch.ops.quantized_decomposed.dequantize_per_channel.default: 2, + } + act_affine_quant_obs = observer.PlaceholderObserver.with_args( + dtype=torch.qint8, + qscheme=torch.per_tensor_affine, + quant_min=-128, + quant_max=127, + eps=2**-12, + is_dynamic=True, + ) + qconfig = QConfig( + activation=act_affine_quant_obs, + weight=per_channel_weight_observer_range_neg_127_to_127, + ) + qconfig_mapping = QConfigMapping().set_global(qconfig) + # Test with 2d inputs + example_inputs_2d = (torch.randn(9, 8),) + example_inputs_4d = (torch.randn(9, 10, 11, 8),) + for example_inputs in [example_inputs_2d, example_inputs_4d]: + self._test_quantizer( + m_eager, + 
example_inputs, + quantizer, + node_occurrence, + [], + True, + qconfig_mapping, + ) + + def test_dynamic_linear_int4_weight(self): + quantizer = XNNPACKQuantizer() + quantization_config = get_symmetric_quantization_config( + is_per_channel=True, + is_dynamic=True, + weight_qmin=0, + weight_qmax=15, + ) + quantizer.set_global(quantization_config) + m_eager = TestHelperModules.TwoLinearModule().eval() + + node_occurrence = { + # input and output are using quantize_per_tensor and weight is using quantize_per_channel + torch.ops.quantized_decomposed.quantize_per_tensor.tensor: 2, + torch.ops.quantized_decomposed.dequantize_per_tensor.tensor: 2, + # note: quantize op for weights are const propagated + torch.ops.quantized_decomposed.quantize_per_channel.default: 0, + torch.ops.quantized_decomposed.dequantize_per_channel.default: 2, + } + act_affine_quant_obs = observer.PlaceholderObserver.with_args( + dtype=torch.qint8, + qscheme=torch.per_tensor_affine, + quant_min=-128, + quant_max=127, + eps=2**-12, + is_dynamic=True, + ) + qconfig = QConfig( + activation=act_affine_quant_obs, + weight=per_channel_weight_observer_range_neg_127_to_127.with_args( + quant_min=0, quant_max=15 + ), + ) + qconfig_mapping = QConfigMapping().set_global(qconfig) + # Test with 2d inputs + example_inputs_2d = (torch.randn(9, 8),) + example_inputs_4d = (torch.randn(9, 10, 11, 8),) + for example_inputs in [example_inputs_2d, example_inputs_4d]: + self._test_quantizer( + m_eager, + example_inputs, + quantizer, + node_occurrence, + [], + True, + qconfig_mapping, + ) + + def test_qat_dynamic_linear(self): + quantizer = XNNPACKQuantizer() + quantization_config = get_symmetric_quantization_config( + is_per_channel=True, + is_dynamic=True, + is_qat=True, + ) + quantizer.set_global(quantization_config) + m_eager = TestHelperModules.TwoLinearModule().eval() + + node_occurrence = { + torch.ops.quantized_decomposed.choose_qparams.tensor: 2, + # input and output are using quantize_per_tensor and weight is using quantize_per_channel + torch.ops.quantized_decomposed.quantize_per_tensor.tensor: 2, + torch.ops.quantized_decomposed.dequantize_per_tensor.tensor: 2, + # note: quantize op for weights are const propagated + torch.ops.quantized_decomposed.quantize_per_channel.default: 0, + torch.ops.quantized_decomposed.dequantize_per_channel.default: 2, + } + act_affine_quant_obs = default_dynamic_fake_quant + qconfig = QConfig( + activation=act_affine_quant_obs, + weight=per_channel_weight_observer_range_neg_127_to_127, + ) + qconfig_mapping = QConfigMapping().set_global(qconfig) + # Test with 2d inputs + example_inputs_2d = (torch.randn(9, 8),) + example_inputs_4d = (torch.randn(9, 10, 11, 8),) + for example_inputs in [example_inputs_2d, example_inputs_4d]: + self._test_quantizer( + m_eager, + example_inputs, + quantizer, + node_occurrence, + [], + True, + qconfig_mapping, + is_qat=True, + ) + + def test_dynamic_linear_with_conv(self): + quantizer = XNNPACKQuantizer() + quantization_config = get_symmetric_quantization_config( + is_per_channel=False, is_dynamic=True + ) + quantizer.set_global(quantization_config) + m_eager = TestHelperModules.ConvLinearWPermute().eval() + + node_occurrence = { + # input and output are using quantize_per_tensor and weight is using quantize_per_channel + torch.ops.quantized_decomposed.quantize_per_tensor.tensor: 1, + torch.ops.quantized_decomposed.dequantize_per_tensor.tensor: 1, + # note: quantize op for weights are const propagated + torch.ops.quantized_decomposed.quantize_per_tensor.default: 0, + 
torch.ops.quantized_decomposed.dequantize_per_tensor.default: 1, + } + + training_ir_node_occurrence = { + # input and output are using quantize_per_tensor and weight is using quantize_per_channel + # In training IR, the decomposition is different. + # `torch.ops.quantized_decomposed.quantize_per_tensor.default` nodes becomes + # `torch.ops.quantized_decomposed.quantize_per_tensor.tensor` nodes. + torch.ops.quantized_decomposed.quantize_per_tensor.tensor: 2, + torch.ops.quantized_decomposed.dequantize_per_tensor.tensor: 2, + # note: quantize op for weights are const propagated + torch.ops.quantized_decomposed.quantize_per_tensor.default: 0, + torch.ops.quantized_decomposed.dequantize_per_tensor.default: 0, + } + act_affine_quant_obs = observer.PlaceholderObserver.with_args( + dtype=torch.qint8, + qscheme=torch.per_tensor_affine, + quant_min=-128, + quant_max=127, + eps=2**-12, + is_dynamic=True, + ) + qconfig = QConfig( + activation=act_affine_quant_obs, + weight=weight_observer_range_neg_127_to_127, + ) + # Test with 2d inputs + example_inputs = (torch.randn(2, 3, 4, 4),) + qconfig_mapping = QConfigMapping().set_global(qconfig) + self._test_quantizer( + m_eager, + example_inputs, + quantizer, + node_occurrence, + [], + True, + qconfig_mapping, + training_ir_node_occurrence=training_ir_node_occurrence, + ) + + def test_gru(self): + """this is a test for annotating fp32 GRU so that it produces + q -> dq -> fp32_gru -> q -> dq, this is currently enough for our use cases, + but we may change the annotation to be more precise in the future + """ + + class RNNDynamicModel(torch.nn.Module): + def __init__(self, mod_type): + super().__init__() + self.qconfig = default_dynamic_qconfig + if mod_type == "GRU": + self.mod = torch.nn.GRU(2, 2).to(dtype=torch.float) + if mod_type == "LSTM": + self.mod = torch.nn.LSTM(2, 2).to(dtype=torch.float) + + def forward(self, input_tensor, hidden_tensor): + input_tensor = 1 * input_tensor + hidden_tensor = 1 * hidden_tensor + output_tensor, hidden_out = self.mod(input_tensor, hidden_tensor) + return 1 * output_tensor, 1 * hidden_out + + with override_quantized_engine("qnnpack"): + model_fx = RNNDynamicModel("GRU") + niter = 10 + example_inputs = ( + # input_tensor + torch.tensor([[100, -155], [-155, 100], [100, -155]], dtype=torch.float) + .unsqueeze(0) + .repeat(niter, 1, 1), + # hidden_tensor + # (D * num_layers, N, H_out) + torch.tensor([[[100, -155]]], dtype=torch.float).repeat(1, 3, 1), + ) + model_graph = copy.deepcopy(model_fx) + + qconfig_mapping = QConfigMapping().set_object_type( + operator.mul, default_symmetric_qnnpack_qconfig + ) + model_fx = prepare_fx( + model_fx, + qconfig_mapping, + example_inputs, + backend_config=get_qnnpack_backend_config(), + ) + model_fx(*example_inputs) + model_fx = _convert_to_reference_decomposed_fx(model_fx) + + with torchdynamo.config.patch(allow_rnn=True): + model_graph = export_for_training( + model_graph, + example_inputs, + ).module() + quantizer = XNNPACKQuantizer() + quantization_config = get_symmetric_quantization_config( + is_per_channel=False, is_dynamic=False + ) + quantizer.set_global(quantization_config) + model_graph = prepare_pt2e(model_graph, quantizer) + model_graph(*example_inputs) + model_graph = convert_pt2e(model_graph) + self.assertEqual(model_fx(*example_inputs), model_graph(*example_inputs)) + + def test_linear_gru(self): + """this test is to make sure GRU annotation does not interfere with linear annotation""" + + class RNNDynamicModel(torch.nn.Module): + def __init__(self, mod_type): + 
super().__init__() + self.qconfig = default_dynamic_qconfig + self.linear = torch.nn.Linear(2, 2) + if mod_type == "GRU": + self.mod = torch.nn.GRU(2, 2).to(dtype=torch.float) + if mod_type == "LSTM": + self.mod = torch.nn.LSTM(2, 2).to(dtype=torch.float) + + def forward(self, input_tensor, hidden_tensor): + input_tensor = self.linear(input_tensor) + input_tensor = 1 * input_tensor + hidden_tensor = 1 * hidden_tensor + output_tensor, hidden_out = self.mod(input_tensor, hidden_tensor) + return 1 * output_tensor, 1 * hidden_out + + with override_quantized_engine("qnnpack"): + model_fx = RNNDynamicModel("GRU") + niter = 10 + example_inputs = ( + # input_tensor + torch.tensor([[100, -155], [-155, 100], [100, -155]], dtype=torch.float) + .unsqueeze(0) + .repeat(niter, 1, 1), + # hidden_tensor + # (D * num_layers, N, H_out) + torch.tensor([[[100, -155]]], dtype=torch.float).repeat(1, 3, 1), + ) + model_graph = copy.deepcopy(model_fx) + + qconfig_mapping = ( + QConfigMapping() + .set_object_type(operator.mul, default_symmetric_qnnpack_qconfig) + .set_object_type(torch.nn.Linear, default_symmetric_qnnpack_qconfig) + ) + model_fx = prepare_fx( + model_fx, + qconfig_mapping, + example_inputs, + backend_config=get_qnnpack_backend_config(), + ) + model_fx(*example_inputs) + model_fx = _convert_to_reference_decomposed_fx(model_fx) + + with torchdynamo.config.patch(allow_rnn=True): + model_graph = export_for_training( + model_graph, + example_inputs, + ).module() + quantizer = XNNPACKQuantizer() + quantization_config = get_symmetric_quantization_config( + is_per_channel=False, is_dynamic=False + ) + quantizer.set_global(quantization_config) + model_graph = prepare_pt2e(model_graph, quantizer) + model_graph(*example_inputs) + model_graph = convert_pt2e(model_graph) + self.assertEqual(model_fx(*example_inputs), model_graph(*example_inputs)) + + def test_add_and_inplace_add(self): + quantizer = XNNPACKQuantizer() + quantization_config = get_symmetric_quantization_config(is_per_channel=True) + quantizer.set_global(quantization_config) + example_inputs = ( + torch.randn(1, 3, 5, 5), + torch.randn(1, 3, 5, 5), + ) + node_occurrence = { + # two input and one output for first add, and output for second add + torch.ops.quantized_decomposed.quantize_per_tensor.default: 4, + torch.ops.quantized_decomposed.dequantize_per_tensor.default: 5, + } + node_list = [ + torch.ops.quantized_decomposed.dequantize_per_tensor.default, + torch.ops.quantized_decomposed.dequantize_per_tensor.default, + torch.ops.aten.add.Tensor, + torch.ops.quantized_decomposed.quantize_per_tensor.default, + torch.ops.quantized_decomposed.dequantize_per_tensor.default, + # TODO torch.ops.aten.add.Tensor, + torch.ops.quantized_decomposed.quantize_per_tensor.default, + ] + self._test_quantizer( + TestHelperModules.AddInplaceAdd(), + example_inputs, + quantizer, + node_occurrence, + node_list, + ) + + def test_mul_and_inplace_mul(self): + quantizer = XNNPACKQuantizer() + quantization_config = get_symmetric_quantization_config(is_per_channel=True) + quantizer.set_global(quantization_config) + example_inputs = ( + torch.randn(1, 3, 5, 5), + torch.randn(1, 3, 5, 5), + ) + node_occurrence = { + # two input and one output for first add, and output for second add + torch.ops.quantized_decomposed.quantize_per_tensor.default: 4, + torch.ops.quantized_decomposed.dequantize_per_tensor.default: 5, + } + node_list = [ + torch.ops.quantized_decomposed.dequantize_per_tensor.default, + torch.ops.quantized_decomposed.dequantize_per_tensor.default, + 
torch.ops.aten.mul.Tensor, + torch.ops.quantized_decomposed.quantize_per_tensor.default, + torch.ops.quantized_decomposed.dequantize_per_tensor.default, + # TODO torch.ops.aten.mul.Tensor, + torch.ops.quantized_decomposed.quantize_per_tensor.default, + ] + self._test_quantizer( + TestHelperModules.MulInplaceMul(), + example_inputs, + quantizer, + node_occurrence, + node_list, + ) + + def test_add_mul_scalar(self): + quantizer = XNNPACKQuantizer() + quantization_config = get_symmetric_quantization_config(is_per_channel=True) + quantizer.set_global(quantization_config) + example_inputs = (torch.randn(1, 3, 5, 5),) + node_occurrence = { + # two input and one output for first add, and output for second add + torch.ops.quantized_decomposed.quantize_per_tensor.default: 5, + # TODO torch.ops.quantized_decomposed.dequantize_per_tensor.default: 9, + } + node_list = [ + torch.ops.quantized_decomposed.dequantize_per_tensor.default, + torch.ops.quantized_decomposed.dequantize_per_tensor.default, + torch.ops.aten.add.Tensor, + torch.ops.quantized_decomposed.quantize_per_tensor.default, + torch.ops.quantized_decomposed.dequantize_per_tensor.default, + torch.ops.aten.mul.Tensor, + torch.ops.quantized_decomposed.quantize_per_tensor.default, + torch.ops.quantized_decomposed.dequantize_per_tensor.default, + # TODO torch.ops.aten.add.Tensor, + torch.ops.quantized_decomposed.quantize_per_tensor.default, + torch.ops.quantized_decomposed.dequantize_per_tensor.default, + # TODO torch.ops.aten.mul.Tensor, + torch.ops.quantized_decomposed.quantize_per_tensor.default, + ] + self._test_quantizer( + TestHelperModules.AddMulScalar(), + example_inputs, + quantizer, + node_occurrence, + node_list, + ) + + def test_mul_float32_max(self): + class M(torch.nn.Module): + def forward(self, x): + return x * 3.4028235e38 + + quantizer = XNNPACKQuantizer() + quantization_config = get_symmetric_quantization_config(is_per_channel=True) + quantizer.set_global(quantization_config) + example_inputs = (torch.randn(1, 3, 5, 5),) + # not quantized + node_occurrence = { + torch.ops.quantized_decomposed.quantize_per_tensor.default: 0, + torch.ops.quantized_decomposed.dequantize_per_tensor.default: 0, + } + node_list = [ + torch.ops.aten.mul.Tensor, + ] + self._test_quantizer( + M(), + example_inputs, + quantizer, + node_occurrence, + node_list, + ) + + def test_add_mul_long(self): + class M(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.t = torch.tensor([100]) + + def forward(self, x): + x = x + self.t + x = x * self.t + return x + + quantizer = XNNPACKQuantizer() + quantization_config = get_symmetric_quantization_config(is_per_channel=True) + quantizer.set_global(quantization_config) + example_inputs = (torch.randn(1, 3, 5, 5),) + # not quantized + node_occurrence = { + torch.ops.quantized_decomposed.quantize_per_tensor.default: 0, + torch.ops.quantized_decomposed.dequantize_per_tensor.default: 0, + } + node_list = [ + torch.ops.aten.add.Tensor, + torch.ops.aten.mul.Tensor, + ] + self._test_quantizer( + M(), + example_inputs, + quantizer, + node_occurrence, + node_list, + ) + + def test_cat_same_node(self): + """Ensure that concatenating the same node does not cause any unexpected behavior""" + + class M(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x): + x = torch.cat([x, x]) + return x + + quantizer = XNNPACKQuantizer() + quantization_config = get_symmetric_quantization_config(is_per_channel=True) + quantizer.set_global(quantization_config) + example_inputs = 
(torch.randn(1, 3, 5, 5),)
+        node_occurrence = {
+            torch.ops.quantized_decomposed.quantize_per_tensor.default: 2,
+            torch.ops.quantized_decomposed.dequantize_per_tensor.default: 2,
+        }
+        node_list = [
+            torch.ops.quantized_decomposed.quantize_per_tensor.default,
+            torch.ops.quantized_decomposed.dequantize_per_tensor.default,
+            torch.ops.aten.cat.default,
+            torch.ops.quantized_decomposed.quantize_per_tensor.default,
+            torch.ops.quantized_decomposed.dequantize_per_tensor.default,
+        ]
+        self._test_quantizer(
+            M(),
+            example_inputs,
+            quantizer,
+            node_occurrence,
+            node_list,
+        )
+
+
+# TODO: express this using self._test_quantizer, add test for inception_v4
+class TestXNNPACKQuantizerModels(PT2EQuantizationTestCase):
+    @skip_if_no_torchvision
+    @skipIfNoQNNPACK
+    def test_resnet18(self):
+        import torchvision
+
+        with override_quantized_engine("qnnpack"):
+            example_inputs = (torch.randn(1, 3, 224, 224),)
+            m = torchvision.models.resnet18().eval()
+            m_copy = copy.deepcopy(m)
+            # program capture
+            m = export_for_training(
+                m,
+                example_inputs,
+            ).module()
+
+            quantizer = XNNPACKQuantizer()
+            quantization_config = get_symmetric_quantization_config(is_per_channel=True)
+            quantizer.set_global(quantization_config)
+            m = prepare_pt2e(m, quantizer)
+            # checking that we inserted observers correctly for maxpool operator (input and
+            # output share observer instance)
+            self.assertEqual(
+                id(m.activation_post_process_3), id(m.activation_post_process_2)
+            )
+            after_prepare_result = m(*example_inputs)
+            m = convert_pt2e(m)
+
+            after_quant_result = m(*example_inputs)
+
+            # comparing with existing fx graph mode quantization reference flow
+            qconfig = default_per_channel_symmetric_qnnpack_qconfig
+            qconfig_mapping = QConfigMapping().set_global(qconfig)
+            backend_config = get_qnnpack_backend_config()
+            m_fx = prepare_fx(
+                m_copy, qconfig_mapping, example_inputs, backend_config=backend_config
+            )
+            after_prepare_result_fx = m_fx(*example_inputs)
+            m_fx = convert_to_reference_fx(m_fx, backend_config=backend_config)
+
+            after_quant_result_fx = m_fx(*example_inputs)
+
+            # The result matches exactly after prepare.
+            # Note: this is currently always true since we are only inserting observers;
+            # the check becomes useful once we add QAT examples,
+            # but we can still manually inspect the printed observers to make sure
+            # they match.
+            self.assertEqual(after_prepare_result, after_prepare_result_fx)
+            self.assertEqual(
+                compute_sqnr(after_prepare_result, after_prepare_result_fx),
+                torch.tensor(float("inf")),
+            )
+            # there are slight differences after convert due to different implementations
+            # of quant/dequant
+            self.assertTrue(
+                torch.max(after_quant_result - after_quant_result_fx) < 1e-1
+            )
+            self.assertTrue(
+                compute_sqnr(after_quant_result, after_quant_result_fx) > 35
+            )
diff --git a/backends/xnnpack/test/test_xnnpack_utils.py b/backends/xnnpack/test/test_xnnpack_utils.py
index ea9217e04a..f11075cf26 100644
--- a/backends/xnnpack/test/test_xnnpack_utils.py
+++ b/backends/xnnpack/test/test_xnnpack_utils.py
@@ -16,6 +16,10 @@
     XnnpackDynamicallyQuantizedPartitioner,
     XnnpackPartitioner,
 )
+from executorch.backends.xnnpack.quantizer.xnnpack_quantizer import (
+    get_symmetric_quantization_config,
+    XNNPACKQuantizer,
+)
 from executorch.backends.xnnpack.utils.configs import (
     get_transform_passes,
     get_xnnpack_edge_compile_config,
@@ -68,10 +72,6 @@
 )
 
 from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
-from torch.ao.quantization.quantizer.xnnpack_quantizer import (
-    
get_symmetric_quantization_config, - XNNPACKQuantizer, -) from torch.export import export_for_training from torch.testing import FileCheck diff --git a/backends/xnnpack/test/tester/TARGETS b/backends/xnnpack/test/tester/TARGETS index 21fc329cef..0ba34cc0bf 100644 --- a/backends/xnnpack/test/tester/TARGETS +++ b/backends/xnnpack/test/tester/TARGETS @@ -16,6 +16,8 @@ runtime.python_library( deps = [ "//caffe2:torch", "//executorch/backends/xnnpack/partition:xnnpack_partitioner", + "//executorch/backends/xnnpack/quantizer:xnnpack_quantizer", + "//executorch/backends/xnnpack/quantizer:xnnpack_quantizer_utils", "//executorch/backends/xnnpack/utils:xnnpack_utils", "//executorch/devtools/visualization:visualization", "//executorch/exir:lib", diff --git a/backends/xnnpack/test/tester/tester.py b/backends/xnnpack/test/tester/tester.py index 8510a0e438..dc885135bb 100644 --- a/backends/xnnpack/test/tester/tester.py +++ b/backends/xnnpack/test/tester/tester.py @@ -43,15 +43,17 @@ logger.warning(f"{e=}") pass +from executorch.backends.xnnpack.quantizer.xnnpack_quantizer import ( + get_symmetric_quantization_config, + XNNPACKQuantizer, +) +from executorch.backends.xnnpack.quantizer.xnnpack_quantizer_utils import ( + QuantizationConfig, +) from executorch.exir.program._program import _transform from torch._export.pass_base import PassType from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e from torch.ao.quantization.quantizer.quantizer import Quantizer -from torch.ao.quantization.quantizer.xnnpack_quantizer import ( - get_symmetric_quantization_config, - XNNPACKQuantizer, -) -from torch.ao.quantization.quantizer.xnnpack_quantizer_utils import QuantizationConfig from torch.export import export, ExportedProgram from torch.testing import FileCheck from torch.utils._pytree import tree_flatten diff --git a/docs/source/llm/getting-started.md b/docs/source/llm/getting-started.md index f0de7cc9c9..fd5dd3ba3a 100644 --- a/docs/source/llm/getting-started.md +++ b/docs/source/llm/getting-started.md @@ -620,7 +620,7 @@ quantized operators where available. 
from executorch.backends.transforms.duplicate_dynamic_quant_chain import ( DuplicateDynamicQuantChainPass, ) -from torch.ao.quantization.quantizer.xnnpack_quantizer import ( +from executorch.backends.xnnpack.quantizer.xnnpack_quantizer import ( get_symmetric_quantization_config, XNNPACKQuantizer, ) diff --git a/docs/source/native-delegates-executorch-xnnpack-delegate.md b/docs/source/native-delegates-executorch-xnnpack-delegate.md index b21f4c4d44..de54de7706 100644 --- a/docs/source/native-delegates-executorch-xnnpack-delegate.md +++ b/docs/source/native-delegates-executorch-xnnpack-delegate.md @@ -80,7 +80,7 @@ The XNNPACK delegate can also be used as a backend to execute symmetrically quan ### Configuring the XNNPACKQuantizer ```python -from torch.ao.quantization.quantizer.xnnpack_quantizer import ( +from executorch.backends.xnnpack.quantizer.xnnpack_quantizer import ( XNNPACKQuantizer, get_symmetric_quantization_config, ) diff --git a/docs/source/tutorial-xnnpack-delegate-lowering.md b/docs/source/tutorial-xnnpack-delegate-lowering.md index c81f61878c..1c71a6ba80 100644 --- a/docs/source/tutorial-xnnpack-delegate-lowering.md +++ b/docs/source/tutorial-xnnpack-delegate-lowering.md @@ -86,7 +86,7 @@ sample_inputs = (torch.randn(1, 3, 224, 224), ) mobilenet_v2 = export_for_training(mobilenet_v2, sample_inputs).module() # 2-stage export for quantization path from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e -from torch.ao.quantization.quantizer.xnnpack_quantizer import ( +from executorch.backends.xnnpack.quantizer.xnnpack_quantizer import ( get_symmetric_quantization_config, XNNPACKQuantizer, ) diff --git a/docs/source/tutorials_source/export-to-executorch-tutorial.py b/docs/source/tutorials_source/export-to-executorch-tutorial.py index 34839a9c1f..86a816f143 100644 --- a/docs/source/tutorials_source/export-to-executorch-tutorial.py +++ b/docs/source/tutorials_source/export-to-executorch-tutorial.py @@ -194,11 +194,11 @@ def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: print("Pre-Autograd ATen Dialect Graph") print(pre_autograd_aten_dialect) -from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e -from torch.ao.quantization.quantizer.xnnpack_quantizer import ( +from executorch.backends.xnnpack.quantizer.xnnpack_quantizer import ( get_symmetric_quantization_config, XNNPACKQuantizer, ) +from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e quantizer = XNNPACKQuantizer().set_global(get_symmetric_quantization_config()) prepared_graph = prepare_pt2e(pre_autograd_aten_dialect, quantizer) # type: ignore[arg-type] diff --git a/examples/models/llava/export_llava.py b/examples/models/llava/export_llava.py index dabb07e61c..97a724c368 100644 --- a/examples/models/llava/export_llava.py +++ b/examples/models/llava/export_llava.py @@ -12,6 +12,10 @@ ConfigPrecisionType, ) from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner +from executorch.backends.xnnpack.quantizer.xnnpack_quantizer import ( + get_symmetric_quantization_config, + XNNPACKQuantizer, +) from executorch.examples.models.llama.export_llama_lib import ( build_args_parser, get_quantizer_and_quant_params, @@ -41,10 +45,6 @@ from executorch.extension.llm.export.builder import DType, LLMEdgeManager from executorch.extension.llm.tokenizer.tokenizer import Tokenizer from executorch.util.activation_memory_profiler import generate_memory_trace -from torch.ao.quantization.quantizer.xnnpack_quantizer import ( - 
get_symmetric_quantization_config, - XNNPACKQuantizer, -) from torch.export import Dim from torch.nn.attention import SDPBackend diff --git a/examples/models/phi-3-mini/export_phi-3-mini.py b/examples/models/phi-3-mini/export_phi-3-mini.py index 305b83457d..8fa948e7dc 100644 --- a/examples/models/phi-3-mini/export_phi-3-mini.py +++ b/examples/models/phi-3-mini/export_phi-3-mini.py @@ -13,14 +13,14 @@ DuplicateDynamicQuantChainPass, ) from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner -from executorch.backends.xnnpack.utils.configs import get_xnnpack_edge_compile_config -from executorch.exir import to_edge -from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e -from torch.ao.quantization.quantizer.xnnpack_quantizer import ( +from executorch.backends.xnnpack.quantizer.xnnpack_quantizer import ( get_symmetric_quantization_config, XNNPACKQuantizer, ) +from executorch.backends.xnnpack.utils.configs import get_xnnpack_edge_compile_config +from executorch.exir import to_edge +from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e from torch.export import export_for_training from transformers import Phi3ForCausalLM diff --git a/examples/xnnpack/quantization/example.py b/examples/xnnpack/quantization/example.py index 141b0701d0..3e30c23921 100644 --- a/examples/xnnpack/quantization/example.py +++ b/examples/xnnpack/quantization/example.py @@ -12,6 +12,10 @@ import time import torch +from executorch.backends.xnnpack.quantizer.xnnpack_quantizer import ( + get_symmetric_quantization_config, + XNNPACKQuantizer, +) from executorch.exir import EdgeCompileConfig from executorch.exir.capture._config import ExecutorchBackendConfig from executorch.extension.export_util.utils import export_to_edge, save_pte_program @@ -26,10 +30,6 @@ prepare_fx, ) from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e -from torch.ao.quantization.quantizer.xnnpack_quantizer import ( - get_symmetric_quantization_config, - XNNPACKQuantizer, -) from ...models import MODEL_NAME_TO_MODEL from ...models.model_factory import EagerModelFactory diff --git a/examples/xnnpack/quantization/utils.py b/examples/xnnpack/quantization/utils.py index 6f8aa3913f..de59c076a8 100644 --- a/examples/xnnpack/quantization/utils.py +++ b/examples/xnnpack/quantization/utils.py @@ -6,12 +6,13 @@ import logging -from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e -from torch.ao.quantization.quantizer.xnnpack_quantizer import ( +from executorch.backends.xnnpack.quantizer.xnnpack_quantizer import ( get_symmetric_quantization_config, XNNPACKQuantizer, ) +from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e + def quantize(model, example_inputs): """This is the official recommended flow for quantization in pytorch 2.0 export""" diff --git a/exir/tests/TARGETS b/exir/tests/TARGETS index 1fefcae721..13253b0edc 100644 --- a/exir/tests/TARGETS +++ b/exir/tests/TARGETS @@ -222,6 +222,8 @@ python_unittest( "//executorch/exir/program:program", "//executorch/extension/pybindings:portable_lib", # @manual "//executorch/backends/xnnpack/partition:xnnpack_partitioner", + "//executorch/backends/xnnpack/quantizer:xnnpack_quantizer", + "//executorch/backends/xnnpack/quantizer:xnnpack_quantizer_utils", ], ) @@ -312,6 +314,7 @@ python_unittest( "//executorch/exir:lib", "//executorch/exir/passes:quant_fusion_pass", "//executorch/exir/passes:spec_prop_pass", + "//executorch/backends/xnnpack/quantizer:xnnpack_quantizer", 
"//pytorch/vision:torchvision", ], ) @@ -459,5 +462,6 @@ python_unittest( "//caffe2:torch", "//executorch/exir:lib", "//executorch/exir/passes:quantize_io_pass", + "//executorch/backends/xnnpack/quantizer:xnnpack_quantizer", ], ) diff --git a/exir/tests/test_passes.py b/exir/tests/test_passes.py index 5691bf870e..0583b6d163 100644 --- a/exir/tests/test_passes.py +++ b/exir/tests/test_passes.py @@ -17,6 +17,13 @@ import executorch.exir.memory_planning # noqa import torch from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner +from executorch.backends.xnnpack.quantizer.xnnpack_quantizer import ( + get_symmetric_quantization_config, + XNNPACKQuantizer, +) +from executorch.backends.xnnpack.quantizer.xnnpack_quantizer_utils import ( + QuantizationConfig, +) from executorch.exir import EdgeCompileConfig, EdgeProgramManager, memory, to_edge from executorch.exir.dialects._ops import bind_pattern_to_op, ops, ops as exir_ops from executorch.exir.dialects.edge._ops import EdgeOpOverload @@ -67,11 +74,6 @@ from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e from torch.ao.quantization.quantizer import QuantizationSpec -from torch.ao.quantization.quantizer.xnnpack_quantizer import ( - get_symmetric_quantization_config, - XNNPACKQuantizer, -) -from torch.ao.quantization.quantizer.xnnpack_quantizer_utils import QuantizationConfig from torch.export import export from torch.export.graph_signature import InputKind, InputSpec, TensorArgument from torch.fx import GraphModule, subgraph_rewriter diff --git a/exir/tests/test_quantization.py b/exir/tests/test_quantization.py index 148d7f4f9d..61e3410186 100644 --- a/exir/tests/test_quantization.py +++ b/exir/tests/test_quantization.py @@ -10,6 +10,11 @@ import torch import torchvision + +from executorch.backends.xnnpack.quantizer.xnnpack_quantizer import ( + get_symmetric_quantization_config, + XNNPACKQuantizer, +) from executorch.exir import EdgeCompileConfig, to_edge from executorch.exir.passes.quant_fusion_pass import QuantFusionPass from executorch.exir.passes.spec_prop_pass import SpecPropPass @@ -23,11 +28,6 @@ convert_pt2e, prepare_pt2e, ) - -from torch.ao.quantization.quantizer.xnnpack_quantizer import ( - get_symmetric_quantization_config, - XNNPACKQuantizer, -) from torch.export import export from torch.testing import FileCheck from torch.testing._internal.common_quantized import override_quantized_engine diff --git a/exir/tests/test_quantize_io_pass.py b/exir/tests/test_quantize_io_pass.py index b3899b008c..aab941b538 100644 --- a/exir/tests/test_quantize_io_pass.py +++ b/exir/tests/test_quantize_io_pass.py @@ -8,6 +8,11 @@ import unittest import torch + +from executorch.backends.xnnpack.quantizer.xnnpack_quantizer import ( + get_symmetric_quantization_config, + XNNPACKQuantizer, +) from executorch.exir import EdgeCompileConfig, to_edge_transform_and_lower from executorch.exir.passes.quantize_io_pass import ( get_config_method_name, @@ -16,11 +21,6 @@ ) from executorch.exir.tensor import get_scalar_type from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e - -from torch.ao.quantization.quantizer.xnnpack_quantizer import ( - get_symmetric_quantization_config, - XNNPACKQuantizer, -) from torch.testing import FileCheck op_str = { diff --git a/extension/llm/export/TARGETS b/extension/llm/export/TARGETS index bcfc130add..7303d2d66a 100644 --- a/extension/llm/export/TARGETS +++ b/extension/llm/export/TARGETS @@ -33,6 +33,7 @@ runtime.python_library( 
"//executorch/backends/vulkan/partitioner:vulkan_partitioner", "//executorch/backends/vulkan/quantizer:vulkan_quantizer", "//executorch/backends/xnnpack/partition:xnnpack_partitioner", + "//executorch/backends/xnnpack/quantizer:xnnpack_quantizer", "//executorch/exir:delegate", "//executorch/exir:lib", "//executorch/exir/backend:backend_details", diff --git a/extension/llm/export/quantizer_lib.py b/extension/llm/export/quantizer_lib.py index 3a9eebd2c3..55e530553f 100644 --- a/extension/llm/export/quantizer_lib.py +++ b/extension/llm/export/quantizer_lib.py @@ -11,14 +11,14 @@ from typing import List, Optional import torch - -from torch.ao.quantization.quantizer import Quantizer -from torch.ao.quantization.quantizer.embedding_quantizer import EmbeddingQuantizer -from torch.ao.quantization.quantizer.xnnpack_quantizer import ( +from executorch.backends.xnnpack.quantizer.xnnpack_quantizer import ( get_symmetric_quantization_config, XNNPACKQuantizer, ) +from torch.ao.quantization.quantizer import Quantizer +from torch.ao.quantization.quantizer.embedding_quantizer import EmbeddingQuantizer + FORMAT = "[%(levelname)s %(asctime)s %(filename)s:%(lineno)s] %(message)s" logging.basicConfig(level=logging.INFO, format=FORMAT)