From 727fb8edad3c32eb63ee5f5336f6011922892d27 Mon Sep 17 00:00:00 2001 From: pytorchbot Date: Tue, 3 Dec 2024 11:35:19 +0000 Subject: [PATCH] 2024-12-03 nightly release (4f47cc9ce7dc82b2a09918695f98dbc9ee047515) --- .github/scripts/extract_benchmark_results.py | 153 ++++-- .github/workflows/android-perf.yml | 24 +- .github/workflows/apple-perf.yml | 24 +- CMakeLists.txt | 31 +- backends/arm/_passes/arm_pass_manager.py | 4 + .../_passes/unsqueeze_before_repeat_pass.py | 62 +++ backends/arm/operators/op_repeat.py | 31 +- backends/arm/test/common.py | 4 +- backends/arm/test/ops/test_repeat.py | 11 +- .../test_unsqueeze_before_repeat_pass.py | 74 +++ backends/cadence/aot/compiler.py | 26 +- backends/cadence/hifi/kernels/kernels.h | 6 +- .../nnlib/xa_nn_elm_minimum_maximum_f32.c | 14 +- backends/test/README.md | 0 backends/test/TARGETS | 8 + backends/test/multi_method_delegate_test.cpp | 164 ++++++ backends/test/targets.bzl | 29 ++ .../partition/config/generic_node_configs.py | 15 + backends/xnnpack/test/ops/bilinear2d.py | 21 + .../LlmBenchmarkRunner.java | 2 +- examples/models/llama/TARGETS | 2 + examples/models/llama/eval_llama.py | 11 +- examples/models/llama/eval_llama_lib.py | 64 +++ examples/models/llama/export_llama_lib.py | 7 + examples/models/llama/model.py | 19 + examples/models/llama/runner/eager.py | 6 +- examples/models/llama/runner/generation.py | 27 +- .../source_transformation/attention_sink.py | 237 ++++++++- .../test_attention_sink.py | 473 +++++++++++++++++- exir/emit/test/test_emit.py | 66 +-- exir/memory_planning.py | 11 +- exir/program/TARGETS | 3 +- exir/program/_program.py | 22 + exir/tests/test_joint_graph.py | 26 +- exir/tests/test_memory_planning.py | 65 ++- exir/tests/test_remove_view_copy.py | 24 +- .../pytorch/minibench/BenchmarkMetric.java | 2 +- extension/llm/custom_ops/CMakeLists.txt | 12 +- kernels/quantized/CMakeLists.txt | 16 +- kernels/quantized/cpu/op_dequantize.cpp | 209 +++++++- kernels/quantized/test/op_dequantize_test.cpp | 50 +- .../executorch/build/runtime_wrapper.bzl | 2 +- test/models/export_delegated_program.py | 88 +++- test/models/targets.bzl | 24 + 44 files changed, 1878 insertions(+), 291 deletions(-) create mode 100644 backends/arm/_passes/unsqueeze_before_repeat_pass.py create mode 100644 backends/arm/test/passes/test_unsqueeze_before_repeat_pass.py create mode 100644 backends/test/README.md create mode 100644 backends/test/TARGETS create mode 100644 backends/test/multi_method_delegate_test.cpp create mode 100644 backends/test/targets.bzl diff --git a/.github/scripts/extract_benchmark_results.py b/.github/scripts/extract_benchmark_results.py index 113ff2a420..bfa6c06312 100755 --- a/.github/scripts/extract_benchmark_results.py +++ b/.github/scripts/extract_benchmark_results.py @@ -310,6 +310,7 @@ def transform( workflow_run_attempt: int, job_name: str, job_id: int, + schema_version: str, ) -> List: """ Transform the benchmark results into the format writable into the benchmark database @@ -319,45 +320,91 @@ def transform( for r in benchmark_results: r["deviceInfo"]["device"] = job_name - # TODO (huydhn): This is the current schema of the database oss_ci_benchmark_v2, - # and I'm trying to fit ET benchmark results into it, which is kind of awkward. 
- # However, the schema is going to be updated soon - return [ - { - # GH-info to identify where the benchmark is run - "repo": repo, - "head_branch": head_branch, - "workflow_id": workflow_run_id, - "run_attempt": workflow_run_attempt, - "job_id": job_id, - # The model - "name": f"{r['benchmarkModel']['name']} {r['benchmarkModel'].get('backend', '')}".strip(), - "dtype": ( - r["benchmarkModel"]["quantization"] - if r["benchmarkModel"]["quantization"] - else "unknown" - ), - # The metric value - "metric": r["metric"], - "actual": r["actualValue"], - "target": r["targetValue"], - # The device - "device": r["deviceInfo"]["device"], - "arch": r["deviceInfo"].get("os", ""), - # Not used here, just set it to something unique here - "filename": workflow_name, - "test_name": app_type, - "runner": job_name, - } - for r in benchmark_results - ] + if schema_version == "v2": + # TODO (huydhn): Clean up this branch after ExecuTorch dashboard migrates to v3 + return [ + { + # GH-info to identify where the benchmark is run + "repo": repo, + "head_branch": head_branch, + "workflow_id": workflow_run_id, + "run_attempt": workflow_run_attempt, + "job_id": job_id, + # The model + "name": f"{r['benchmarkModel']['name']} {r['benchmarkModel'].get('backend', '')}".strip(), + "dtype": ( + r["benchmarkModel"]["quantization"] + if r["benchmarkModel"]["quantization"] + else "unknown" + ), + # The metric value + "metric": r["metric"], + "actual": r["actualValue"], + "target": r["targetValue"], + # The device + "device": r["deviceInfo"]["device"], + "arch": r["deviceInfo"].get("os", ""), + # Not used here, just set it to something unique here + "filename": workflow_name, + "test_name": app_type, + "runner": job_name, + } + for r in benchmark_results + ] + elif schema_version == "v3": + quantization = ( + r["benchmarkModel"]["quantization"] + if r["benchmarkModel"]["quantization"] + else "unknown" + ) + # From https://github.com/pytorch/pytorch/wiki/How-to-integrate-with-PyTorch-OSS-benchmark-database + return [ + { + "benchmark": { + "name": "ExecuTorch", + "mode": "inference", + "dtype": quantization, + "extra_info": { + "app_type": app_type, + }, + }, + "model": { + "name": r["benchmarkModel"]["name"], + "type": "OSS model", + "backend": r["benchmarkModel"].get("backend", ""), + "extra_info": { + "quantization": quantization, + }, + }, + "metric": { + "name": r["metric"], + "benchmark_values": [r["actualValue"]], + "target_value": r["targetValue"], + "extra_info": { + "method": r.get("method", ""), + }, + }, + "runners": [ + { + "name": r["deviceInfo"]["device"], + "type": r["deviceInfo"]["os"], + "avail_mem_in_gb": r["deviceInfo"].get("availMem", ""), + "total_mem_in_gb": r["deviceInfo"].get("totalMem", ""), + } + ], + } + for r in benchmark_results + ] def main() -> None: args = parse_args() - # Across all devices - all_benchmark_results = [] + # Across all devices, keeping both schemas for now until ExecuTorch dashboard migrates to v3 + all_benchmark_results = { + "v2": [], + "v3": [], + } with open(args.artifacts) as f: for artifact in json.load(f): @@ -384,23 +431,31 @@ def main() -> None: ) if benchmark_results: - benchmark_results = transform( - app_type, - benchmark_results, - args.repo, - args.head_branch, - args.workflow_name, - args.workflow_run_id, - args.workflow_run_attempt, - job_name, - extract_job_id(args.artifacts), - ) - all_benchmark_results.extend(benchmark_results) + for schema in all_benchmark_results.keys(): + results = transform( + app_type, + benchmark_results, + args.repo, + 
args.head_branch, + args.workflow_name, + args.workflow_run_id, + args.workflow_run_attempt, + job_name, + extract_job_id(args.artifacts), + schema, + ) + all_benchmark_results[schema].extend(results) + + for schema in all_benchmark_results.keys(): + if not all_benchmark_results.get(schema): + continue + + output_dir = os.path.join(args.output_dir, schema) + os.mkdir(output_dir) - if all_benchmark_results: output_file = os.path.basename(args.artifacts) - with open(f"{args.output_dir}/{output_file}", "w") as f: - json.dump(all_benchmark_results, f) + with open(f"{output_dir}/{output_file}", "w") as f: + json.dump(all_benchmark_results[schema], f) if __name__ == "__main__": diff --git a/.github/workflows/android-perf.yml b/.github/workflows/android-perf.yml index 93ec4fe4e7..76e5f5a1b9 100644 --- a/.github/workflows/android-perf.yml +++ b/.github/workflows/android-perf.yml @@ -298,15 +298,25 @@ jobs: --workflow-run-attempt ${{ github.run_attempt }} done - ls -lah benchmark-results - - for BENCHMARK_RESULTS in benchmark-results/*.json; do - cat "${BENCHMARK_RESULTS}" - echo + for SCHEMA in v2 v3; do + for BENCHMARK_RESULTS in benchmark-results/"${SCHEMA}"/*.json; do + cat "${BENCHMARK_RESULTS}" + echo + done done - - name: Upload the benchmark results + # TODO (huydhn): Remove v2 schema once the benchmark dashboard finishes the migration + - name: Upload the benchmark results (v2) + uses: pytorch/test-infra/.github/actions/upload-benchmark-results@main + with: + benchmark-results-dir: benchmark-results/v2 + dry-run: false + schema-version: v2 + + - name: Upload the benchmark results (v3) uses: pytorch/test-infra/.github/actions/upload-benchmark-results@main with: - benchmark-results-dir: 'benchmark-results' + benchmark-results-dir: benchmark-results/v3 dry-run: false + schema-version: v3 + github-token: ${{ secrets.GITHUB_TOKEN }} diff --git a/.github/workflows/apple-perf.yml b/.github/workflows/apple-perf.yml index 7de308b1a6..f14e40b942 100644 --- a/.github/workflows/apple-perf.yml +++ b/.github/workflows/apple-perf.yml @@ -372,15 +372,25 @@ jobs: --workflow-run-attempt ${{ github.run_attempt }} done - ls -lah benchmark-results - - for BENCHMARK_RESULTS in benchmark-results/*.json; do - cat "${BENCHMARK_RESULTS}" - echo + for SCHEMA in v2 v3; do + for BENCHMARK_RESULTS in benchmark-results/"${SCHEMA}"/*.json; do + cat "${BENCHMARK_RESULTS}" + echo + done done - - name: Upload the benchmark results + # TODO (huydhn): Remove v2 schema once the benchmark dashboard finishes the migration + - name: Upload the benchmark results (v2) + uses: pytorch/test-infra/.github/actions/upload-benchmark-results@main + with: + benchmark-results-dir: benchmark-results/v2 + dry-run: false + schema-version: v2 + + - name: Upload the benchmark results (v3) uses: pytorch/test-infra/.github/actions/upload-benchmark-results@main with: - benchmark-results-dir: 'benchmark-results' + benchmark-results-dir: benchmark-results/v3 dry-run: false + schema-version: v3 + github-token: ${{ secrets.GITHUB_TOKEN }} diff --git a/CMakeLists.txt b/CMakeLists.txt index f960dced37..3b242b1ded 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -56,6 +56,21 @@ if(NOT CMAKE_BUILD_TYPE) set(CMAKE_BUILD_TYPE Debug) endif() +# Setup RPATH. 
+# See https://gitlab.kitware.com/cmake/community/-/wikis/doc/cmake/RPATH-handling +# Use separate rpaths during build and install phases +set(CMAKE_SKIP_BUILD_RPATH OFF) +# Don't use the install-rpath during the build phase +set(CMAKE_BUILD_WITH_INSTALL_RPATH ON) +# Automatically add all linked folders that are NOT in the build directory to +# the rpath (per library?) +# TODO: Doesn't work for us right now because we are not installing .so's into the +# correct locations. For example we have libcustom_ops_aot_lib.so depending on +# _portable_lib.so, which was eventually put under /executorch/extension/pybindings/ +# but this rpath is not automatically added because at build time it seems `portable_lib` +# is being built under the same directory, so no extra rpath is being added. To +# properly fix this we need to install `portable_lib` into the correct path. +set(CMAKE_INSTALL_RPATH_USE_LINK_PATH ON) # ------------------------------ OPTIONS ------------------------------------- # WARNING: Please don't add example specific options in this CMakeLists.txt. # Instead please use `find_package(executorch REQUIRED)` in the example @@ -682,22 +697,6 @@ if(EXECUTORCH_BUILD_PTHREADPOOL endif() if(EXECUTORCH_BUILD_PYBIND) - # Setup RPATH. - # See https://gitlab.kitware.com/cmake/community/-/wikis/doc/cmake/RPATH-handling - if(APPLE) - set(CMAKE_MACOSX_RPATH ON) - set(_rpath_portable_origin "@loader_path") - else() - set(_rpath_portable_origin $ORIGIN) - endif(APPLE) - # Use separate rpaths during build and install phases - set(CMAKE_SKIP_BUILD_RPATH FALSE) - # Don't use the install-rpath during the build phase - set(CMAKE_BUILD_WITH_INSTALL_RPATH FALSE) - set(CMAKE_INSTALL_RPATH "${_rpath_portable_origin}") - # Automatically add all linked folders that are NOT in the build directory to - # the rpath (per library?) - set(CMAKE_INSTALL_RPATH_USE_LINK_PATH TRUE) add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/third-party/pybind11) if(NOT EXECUTORCH_BUILD_EXTENSION_DATA_LOADER) diff --git a/backends/arm/_passes/arm_pass_manager.py b/backends/arm/_passes/arm_pass_manager.py index 1e2b26ef64..25811d077b 100644 --- a/backends/arm/_passes/arm_pass_manager.py +++ b/backends/arm/_passes/arm_pass_manager.py @@ -41,6 +41,9 @@ ScalarsToAttributePass, ) from executorch.backends.arm._passes.size_adjust_conv2d_pass import SizeAdjustConv2DPass +from executorch.backends.arm._passes.unsqueeze_before_repeat_pass import ( + UnsqueezeBeforeRepeatPass, +) from executorch.backends.arm._passes.unsqueeze_scalar_placeholders_pass import ( UnsqueezeScalarPlaceholdersPass, ) @@ -66,6 +69,7 @@ def transform_to_backend_pipeline( self.add_pass(RemoveClonePass()) self.add_pass(ConvertExpandCopyToRepeatPass()) self.add_pass(DecomposeLayerNormPass()) + self.add_pass(UnsqueezeBeforeRepeatPass()) self.add_pass(DecomposeVarPass()) self.add_pass(ConvertMeanDimToAveragePool()) self.add_pass(DecomposeMeanDimPass()) diff --git a/backends/arm/_passes/unsqueeze_before_repeat_pass.py b/backends/arm/_passes/unsqueeze_before_repeat_pass.py new file mode 100644 index 0000000000..01983baa9a --- /dev/null +++ b/backends/arm/_passes/unsqueeze_before_repeat_pass.py @@ -0,0 +1,62 @@ +# Copyright 2024 Arm Limited and/or its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+# pyre-unsafe +import torch +import torch.fx +from executorch.backends.arm._passes.arm_pass_utils import ( + create_node, + get_first_fake_tensor, +) +from executorch.exir.dialects._ops import ops as exir_ops +from executorch.exir.pass_base import ExportPass, PassResult + + +class UnsqueezeBeforeRepeatPass(ExportPass): + """ + A TOSA TILE op only supports rank(in) == rank(out). + To support Pytorch's repeat which can also add dimensions, + we add an explicit view op before which adds the new dimensions. + New dimensions are appendend at the front, see + https://pytorch.org/docs/stable/generated/torch.Tensor.expand.html + + Original: + repeat(multiples) + After pass: + view(shape = [1]*num_new_dims + old_shape) + repeat(multiples) + """ + + def call(self, graph_module: torch.fx.GraphModule): + modified_graph = False + for node in graph_module.graph.nodes: + if node.op != "call_function": + continue + if node.target != exir_ops.edge.aten.repeat.default: + continue + + old_shape = list(get_first_fake_tensor(node.all_input_nodes[0]).shape) + old_rank = len(old_shape) + multiples = node.args[1] + new_rank = len(multiples) + if old_rank == new_rank: + continue + + num_new_dims = new_rank - old_rank + new_shape = [1] * num_new_dims + old_shape + + with graph_module.graph.inserting_before(node): + view_node = create_node( + graph_module.graph, + exir_ops.edge.aten.view_copy.default, + (node.all_input_nodes[0], new_shape), + ) + node.replace_input_with(node.all_input_nodes[0], view_node) + modified_graph = True + + if modified_graph: + graph_module.recompile() + graph_module = super().call(graph_module).graph_module + return PassResult(graph_module, modified_graph) diff --git a/backends/arm/operators/op_repeat.py b/backends/arm/operators/op_repeat.py index 20de9e0846..1e4dc4e23c 100644 --- a/backends/arm/operators/op_repeat.py +++ b/backends/arm/operators/op_repeat.py @@ -32,37 +32,8 @@ def define_node( is_quant_node: bool, ) -> None: - item_name = inputs[0].name - shape = inputs[0].shape - rank = len(shape) multiples = inputs[1].special - new_rank = len(multiples) - - assert new_rank >= rank - - # TILE only supports rank(in) == rank(out). To add more dims, we need a reshape first. 
- if new_rank > rank: - # Add length 1 dimensions to shape to match multiples - num_new_dims = new_rank - rank - expanded_shape = tuple( - 1 if i < num_new_dims else shape[i - num_new_dims] - for i in range(new_rank) - ) - expanded_shape = tosa_shape(expanded_shape, output.dim_order) - dtype = ( - ts.dtype_str_to_val("INT8") - if is_quant_node - else ts.dtype_str_to_val("FP32") - ) - - rescale_out = tosa_graph.addIntermediate(expanded_shape, dtype) - rescale_attr = ts.TosaSerializerAttribute() - rescale_attr.ReshapeAttribute(expanded_shape) - tosa_graph.addOperator( - TosaOp.Op().RESHAPE, [item_name], [rescale_out.name], rescale_attr - ) - item_name = rescale_out.name attr = ts.TosaSerializerAttribute() attr.TileAttribute(tosa_shape(multiples, output.dim_order)) - tosa_graph.addOperator(TosaOp.Op().TILE, [item_name], [output.name], attr) + tosa_graph.addOperator(TosaOp.Op().TILE, [inputs[0].name], [output.name], attr) diff --git a/backends/arm/test/common.py b/backends/arm/test/common.py index 48214a48a7..d755ffc8b1 100644 --- a/backends/arm/test/common.py +++ b/backends/arm/test/common.py @@ -11,9 +11,9 @@ from datetime import datetime from pathlib import Path -from conftest import is_option_enabled - from executorch.backends.arm.arm_backend import ArmCompileSpecBuilder + +from executorch.backends.arm.test.conftest import is_option_enabled from executorch.exir.backend.compile_spec_schema import CompileSpec diff --git a/backends/arm/test/ops/test_repeat.py b/backends/arm/test/ops/test_repeat.py index 20c57ba749..de555e7c80 100644 --- a/backends/arm/test/ops/test_repeat.py +++ b/backends/arm/test/ops/test_repeat.py @@ -37,6 +37,7 @@ class Repeat(torch.nn.Module): (torch.randn(3), (2, 2)), (torch.randn(3), (1, 2, 3)), (torch.randn((3, 3)), (2, 2, 2)), + (torch.randn((3, 3, 3)), (2, 1, 2, 4)), ] def forward(self, x: torch.Tensor, multiples: Sequence): @@ -106,12 +107,20 @@ def test_repeat_tosa_MI(self, test_input, multiples): def test_repeat_tosa_BI(self, test_input, multiples): self._test_repeat_tosa_BI_pipeline(self.Repeat(), (test_input, multiples)) - @parameterized.expand(Repeat.test_parameters) + @parameterized.expand(Repeat.test_parameters[:-1]) def test_repeat_u55_BI(self, test_input, multiples): self._test_repeat_ethosu_pipeline( common.get_u55_compile_spec(), self.Repeat(), (test_input, multiples) ) + # Final test requires transpose which is not supported on u55. + @parameterized.expand(Repeat.test_parameters[-1:]) + @unittest.expectedFailure + def test_repeat_u55_BI_xfails(self, test_input, multiples): + self._test_repeat_ethosu_pipeline( + common.get_u55_compile_spec(), self.Repeat(), (test_input, multiples) + ) + @parameterized.expand(Repeat.test_parameters) def test_repeat_u85_BI(self, test_input, multiples): self._test_repeat_ethosu_pipeline( diff --git a/backends/arm/test/passes/test_unsqueeze_before_repeat_pass.py b/backends/arm/test/passes/test_unsqueeze_before_repeat_pass.py new file mode 100644 index 0000000000..d249c18ec8 --- /dev/null +++ b/backends/arm/test/passes/test_unsqueeze_before_repeat_pass.py @@ -0,0 +1,74 @@ +# Copyright 2024 Arm Limited and/or its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+import unittest + +import torch +from executorch.backends.arm._passes.unsqueeze_before_repeat_pass import ( + UnsqueezeBeforeRepeatPass, +) +from executorch.backends.arm.test import common +from executorch.backends.arm.test.tester.arm_tester import ArmTester +from executorch.backends.xnnpack.test.tester.tester import RunPasses + + +class Repeat(torch.nn.Module): + """ + Basic repeat model. + """ + + def forward(self, x: torch.Tensor): + return x.repeat(2, 2, 2, 2) + + +class TestUnsqueezeBeforeRepeatPass(unittest.TestCase): + def test_tosa_MI_insert_view(self): + """ + When rank(input) != number of repeated dimensions (=4 in Repeat module), + insert view. + """ + module = Repeat() + inputs = (torch.rand((2, 3, 4)),) + test_pass_stage = RunPasses([UnsqueezeBeforeRepeatPass]) + ( + ( + ArmTester( + module, + example_inputs=inputs, + compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+MI"), + ) + .export() + .to_edge() + .check(["aten_repeat_default"]) + .check_not(["aten_view_copy_default"]) + .run_passes(test_pass_stage) + .check(["aten_repeat_default", "aten_view_copy_default"]) + ) + ) + + def test_tosa_MI_dont_insert_view(self): + """ + When rank(input) == number of repeated dimensions (=4 in Repeat module), + DON'T insert view. + """ + module = Repeat() + inputs = (torch.rand((2, 3, 4, 1)),) + test_pass_stage = RunPasses([UnsqueezeBeforeRepeatPass]) + ( + ( + ArmTester( + module, + example_inputs=inputs, + compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+MI"), + ) + .export() + .to_edge() + .check(["aten_repeat_default"]) + .check_not(["aten_view_copy_default"]) + .run_passes(test_pass_stage) + .check(["aten_repeat_default"]) + .check_not(["aten_view_copy_default"]) + ) + ) diff --git a/backends/cadence/aot/compiler.py b/backends/cadence/aot/compiler.py index 937e3e39bc..6b3a023181 100644 --- a/backends/cadence/aot/compiler.py +++ b/backends/cadence/aot/compiler.py @@ -28,6 +28,7 @@ to_edge, ) from executorch.exir.pass_base import PassResult +from torch._inductor.decomposition import remove_decompositions from torch.ao.quantization.pt2e.export_utils import model_is_exported from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e @@ -58,16 +59,33 @@ def convert_pt2( Returns a GraphModule with the converted model. """ + # Get default decompositions + decomp_table = torch.export.default_decompositions() + # Select ops to keep + ops_to_keep = [ + torch.ops.aten.conv1d.default, + torch.ops.aten.conv2d.default, + torch.ops.aten.layer_norm.default, + torch.ops.aten.linear.default, + torch.ops.aten.matmul.default, + ] + # Remove decompositions for the ops we want to keep + # pyre-fixme[6]: For 1st argument expected `Dict[typing.Callable[..., typing.Any + remove_decompositions(decomp_table, ops_to_keep) # Export with dynamo - model_gm = torch.export.export_for_training(model, inputs).module() + model_gm = ( + torch.export.export_for_training(model, inputs) + .run_decompositions(decomp_table) + .module() + ) - if model_gm_has_SDPA(model_gm): # pyre-fixme[6] + if model_gm_has_SDPA(model_gm): # Decompose SDPA - DecomposeScaledDotProductAttention(False)(model_gm) # pyre-fixme[6] + DecomposeScaledDotProductAttention(False)(model_gm) # Swap _safe_softmax with _softmax (see https://github.com/pytorch/pytorch/pull/133882 # for details). 
- result = ReplaceSafeSoftmaxWithSoftmax()(model_gm) # pyre-fixme[6] + result = ReplaceSafeSoftmaxWithSoftmax()(model_gm) assert result is not None model_gm = result.graph_module diff --git a/backends/cadence/hifi/kernels/kernels.h b/backends/cadence/hifi/kernels/kernels.h index 10927adc2a..2eabfb9507 100644 --- a/backends/cadence/hifi/kernels/kernels.h +++ b/backends/cadence/hifi/kernels/kernels.h @@ -92,9 +92,9 @@ extern "C" WORD32 xa_nn_elm_mul_broadcast_4D_f32xf32_f32( const WORD32* const p_inp2_shape); extern "C" void xa_nn_elm_pow_f32( - FLOAT32* restrict z, - const FLOAT32* restrict x, - const FLOAT32* restrict y, + FLOAT32* __restrict__ z, + const FLOAT32* __restrict__ x, + const FLOAT32* __restrict__ y, WORD32 N); extern "C" WORD32 xa_nn_elm_where_f32xf32_f32( diff --git a/backends/cadence/hifi/third-party/nnlib/xa_nn_elm_minimum_maximum_f32.c b/backends/cadence/hifi/third-party/nnlib/xa_nn_elm_minimum_maximum_f32.c index 3af93fc00c..61d7bcf781 100644 --- a/backends/cadence/hifi/third-party/nnlib/xa_nn_elm_minimum_maximum_f32.c +++ b/backends/cadence/hifi/third-party/nnlib/xa_nn_elm_minimum_maximum_f32.c @@ -19,12 +19,12 @@ * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. ******************************************************************************/ -#include "nnlib-hifi4/xa_nnlib/include/xa_type_def.h" -#include "nnlib-hifi4/xa_nnlib/algo/common/include/xa_nnlib_common_fpu.h" -#include "nnlib-hifi4/xa_nnlib/algo/common/include/xa_nn_common.h" -#include "nnlib-hifi4/xa_nnlib/algo/common/include/xa_nnlib_err_chk.h" -#include "nnlib-hifi4/xa_nnlib/algo/kernels/basic/hifi4/xa_nn_basic_state.h" -#include "nnlib-hifi4/xa_nnlib/include/nnlib/xa_nnlib_kernels_api.h" +#include "xa_type_def.h" +#include "xa_nnlib_common_fpu.h" +#include "xa_nn_common.h" +#include "xa_nnlib_err_chk.h" +// #include "xa_nn_basic_state.h" +#include "xa_nnlib_kernels_api.h" #if !HAVE_VFPU DISCARD_FUN_FOR_NONVOID_RETURN( @@ -844,4 +844,4 @@ WORD32 xa_nn_elm_minimum_broadcast_4D_f32xf32_f32(FLOAT32 * __restrict__ p_out, } return 0; } -#endif \ No newline at end of file +#endif diff --git a/backends/test/README.md b/backends/test/README.md new file mode 100644 index 0000000000..e69de29bb2 diff --git a/backends/test/TARGETS b/backends/test/TARGETS new file mode 100644 index 0000000000..a6c52d105f --- /dev/null +++ b/backends/test/TARGETS @@ -0,0 +1,8 @@ +# Any targets that should be shared between fbcode and xplat must be defined in +# targets.bzl. This file can contain fbcode-only targets. + +load(":targets.bzl", "define_common_targets") + +oncall("executorch") + +define_common_targets(is_fbcode = True) diff --git a/backends/test/multi_method_delegate_test.cpp b/backends/test/multi_method_delegate_test.cpp new file mode 100644 index 0000000000..e24585434c --- /dev/null +++ b/backends/test/multi_method_delegate_test.cpp @@ -0,0 +1,164 @@ +#include + +#include +#include +#include +#include + +#include +#include + +#include +#include +#include + +using executorch::runtime::Error; +using executorch::runtime::EValue; +using executorch::runtime::HierarchicalAllocator; +using executorch::runtime::MemoryManager; +using executorch::runtime::Method; +using executorch::runtime::MethodMeta; +using executorch::runtime::Program; +using executorch::runtime::Result; +using executorch::runtime::Span; + +using executorch::extension::FileDataLoader; +using executorch::extension::MallocMemoryAllocator; +using executorch::extension::prepare_input_tensors; + +/* + * Backend agnostic base class. 
+ */ +class ETPTEMethodRunBaseTest : public ::testing::Test { + protected: + void SetUp() override { + executorch::runtime::runtime_init(); + } + + // Runs the PTE e2e without using outside resources. + // This will run in a single thread. + // TODO(T208989128) - Add Synchronizer based run method. + void run( + const int id, + const std::string& kTestPTEPath, + const std::string& kMethodName, + std::atomic& count) const { + Result loader = FileDataLoader::from(kTestPTEPath.c_str()); + ASSERT_EQ(loader.error(), Error::Ok); + + Result program = Program::load( + &loader.get(), Program::Verification::InternalConsistency); + ASSERT_EQ(program.error(), Error::Ok); + + Result method_meta = program->method_meta(kMethodName.c_str()); + ASSERT_EQ(method_meta.error(), Error::Ok); + + const size_t num_memory_planned_buffers = + method_meta->num_memory_planned_buffers(); + + std::vector> planned_buffers; + std::vector> planned_spans; + for (size_t i = 0; i < num_memory_planned_buffers; ++i) { + const size_t buffer_size = + static_cast(method_meta->memory_planned_buffer_size(i).get()); + planned_buffers.push_back(std::make_unique(buffer_size)); + planned_spans.push_back({planned_buffers.back().get(), buffer_size}); + } + + auto method_allocator = std::make_unique(); + auto memory_planned_allocator = std::make_unique( + Span(planned_spans.data(), planned_spans.size())); + auto temp_allocator = std::make_unique(); + + auto memory_manager = std::make_unique( + method_allocator.get(), + memory_planned_allocator.get(), + temp_allocator.get()); + + Result method = + program->load_method(kMethodName.c_str(), memory_manager.get()); + ASSERT_EQ(method.error(), Error::Ok); + + auto inputs = prepare_input_tensors(*method); + ASSERT_EQ(inputs.error(), Error::Ok); + + Error err = method->execute(); + for (int i = 0; i < id % 7; i++) { + err = method->execute(); + ASSERT_EQ(err, Error::Ok); + } + + std::vector outputs(method->outputs_size()); + err = method->get_outputs(outputs.data(), outputs.size()); + ET_CHECK(err == Error::Ok); + // TODO(T208989129) - Add validation of outputs using bundled + // inputs/outputs. + count++; + } +}; + +class XNNPACKMultiDelegateTest : public ETPTEMethodRunBaseTest { + protected: + std::string kTestPTE1Path, kTestPTE2Path; + std::string kMethodName; + int num_threads; + + void SetUp() override { + ETPTEMethodRunBaseTest::SetUp(); + const char* pte1_path = + std::getenv("ET_XNNPACK_GENERATED_ADD_LARGE_PTE_PATH"); + if (pte1_path == nullptr) { + std::cerr << "ET_XNNPACK_GENERATED_ADD_LARGE_PTE_PATH is not set" + << std::endl; + FAIL(); + } + kTestPTE1Path = std::string(pte1_path); + + const char* pte2_path = + std::getenv("ET_XNNPACK_GENERATED_SUB_LARGE_PTE_PATH"); + if (pte1_path == nullptr) { + std::cerr << "ET_XNNPACK_GENERATED_SUB_LARGE_PTE_PATH is not set" + << std::endl; + FAIL(); + } + kTestPTE2Path = std::string(pte2_path); + + num_threads = 40; + kMethodName = "forward"; + } +}; + +// This test is to validate the assumption that the delegate is thread safe. +// That includes the following: +// 1. The delegate can be initilized by multiple threads in parallel. +// 2. The delegate can be executed by multiple threads in parallel. +// 3. The delegate can be destroyed by multiple threads in parallel. +// Regardless of the underlying implementation of the delegate. +// This is particularly important when we have shared resources across +// delegate instances through a singleton backend instance. 
+TEST_F(XNNPACKMultiDelegateTest, MultipleThreads) { + ASSERT_NE(kTestPTE1Path.size(), 0); + ASSERT_NE(kTestPTE2Path.size(), 0); + ASSERT_NE(num_threads, 0); + ASSERT_NE(kMethodName.size(), 0); + + std::vector threads(num_threads); + std::atomic count{0}; + + for (int i = 0; i < num_threads; i++) { + threads[i] = std::thread([&, i]() { + run(i, i % 7 ? kTestPTE1Path : kTestPTE2Path, kMethodName, count); + }); + } + for (int i = 0; i < num_threads; i++) { + threads[i].join(); + } + ASSERT_EQ(count, num_threads); +} + +// TODO(T208989291): Add more tests here. For example, +// - PTEs with multiple methods +// - PTEs with proucer and consumer relationships in different threads +// - PTEs with more than 1 delegate instances +// - PTEs with different type of delegate instances +// - Add more patterns of delegate initialization and execution diff --git a/backends/test/targets.bzl b/backends/test/targets.bzl new file mode 100644 index 0000000000..6588c57fcc --- /dev/null +++ b/backends/test/targets.bzl @@ -0,0 +1,29 @@ +load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime") + +def define_common_targets(is_fbcode = False): + """Defines targets that should be shared between fbcode and xplat. + + The directory containing this targets.bzl file should also contain both + TARGETS and BUCK files that call this function. + """ + if not runtime.is_oss and is_fbcode: + modules_env = { + "ET_XNNPACK_GENERATED_ADD_LARGE_PTE_PATH": "$(location fbcode//executorch/test/models:exported_xnnp_delegated_programs[ModuleAddLarge.pte])", + "ET_XNNPACK_GENERATED_SUB_LARGE_PTE_PATH": "$(location fbcode//executorch/test/models:exported_xnnp_delegated_programs[ModuleSubLarge.pte])", + } + + runtime.cxx_test( + name = "multi_method_delegate_test", + srcs = [ + "multi_method_delegate_test.cpp", + ], + deps = [ + "//executorch/runtime/executor:program", + "//executorch/extension/data_loader:file_data_loader", + "//executorch/extension/memory_allocator:malloc_memory_allocator", + "//executorch/kernels/portable:generated_lib", + "//executorch/backends/xnnpack:xnnpack_backend", + "//executorch/extension/runner_util:inputs", + ], + env = modules_env, + ) diff --git a/backends/xnnpack/partition/config/generic_node_configs.py b/backends/xnnpack/partition/config/generic_node_configs.py index b95d7c5b89..f08b8ccb3c 100644 --- a/backends/xnnpack/partition/config/generic_node_configs.py +++ b/backends/xnnpack/partition/config/generic_node_configs.py @@ -303,6 +303,21 @@ def get_original_aten(self) -> Optional[torch._ops.OpOverload]: class UpsampleBilinear2dConfig(GenericNodePartitionerConfig): target_name = "upsample_bilinear2d.vec" + def check_constraints(self, node: torch.fx.Node, ep: ExportedProgram) -> bool: + """ + XNNPACK's static_resize_bilinear does not support dynamic output sizes + """ + if not self.check_common_constraints(node, ep): + return False + + is_output_dynamic = "val" in node.meta and any( + isinstance(d, torch.SymInt) for d in node.meta["val"].shape + ) + if is_output_dynamic: + why(node, reason="dynamic output sizes are not supported") + return False + return True + def supported_precision_types(self) -> List[ConfigPrecisionType]: return [ConfigPrecisionType.FP32] diff --git a/backends/xnnpack/test/ops/bilinear2d.py b/backends/xnnpack/test/ops/bilinear2d.py index bf89e2196f..6a19476365 100644 --- a/backends/xnnpack/test/ops/bilinear2d.py +++ b/backends/xnnpack/test/ops/bilinear2d.py @@ -8,6 +8,7 @@ import torch from executorch.backends.xnnpack.test.tester import Tester +from 
executorch.backends.xnnpack.test.tester.tester import Export class TestUpsampleBilinear2d(unittest.TestCase): @@ -118,3 +119,23 @@ def test_fp32_static_resize_bilinear2d_antialiased(self): ) .check_not(["torch.ops.higher_order.executorch_call_delegate"]) ) + + def test_fp32_bilinear2d_dynamic_bilinear2d_not_partitioned(self): + """ + Verify that upsample_bilinear2d ops with dynamic output sizes are not partitioned. + """ + example_inputs = (torch.randn(2, 3, 4, 5),) + dynamic_shapes = { + "x": { + 2: torch.export.Dim("h", min=1, max=10), + 3: torch.export.Dim("w", min=1, max=12), + } + } + ( + Tester(self.StaticResizeBilinear2dModule(), example_inputs) + .export(Export(dynamic_shapes)) + .to_edge_transform_and_lower() + # NOTE The decomposition is partially delegated. This will need to be replaced + # with the aten upsample op once decomp is removed. + .check("executorch_exir_dialects_edge__ops_aten_index_Tensor") + ) diff --git a/examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/LlmBenchmarkRunner.java b/examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/LlmBenchmarkRunner.java index 7236fe317b..8c2d60252a 100644 --- a/examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/LlmBenchmarkRunner.java +++ b/examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/LlmBenchmarkRunner.java @@ -187,7 +187,7 @@ public BenchmarkMetric( // the .pte model itself instead of parsing its name public static BenchmarkMetric.BenchmarkModel extractBackendAndQuantization(final String model) { final Matcher m = - Pattern.compile("(?\\w+)_(?\\w+)_(?\\w+)").matcher(model); + Pattern.compile("(?\\w+)_(?[\\w\\+]+)_(?\\w+)").matcher(model); if (m.matches()) { return new BenchmarkMetric.BenchmarkModel( m.group("name"), m.group("backend"), m.group("quantization")); diff --git a/examples/models/llama/TARGETS b/examples/models/llama/TARGETS index 284520d4d5..445bcd673b 100644 --- a/examples/models/llama/TARGETS +++ b/examples/models/llama/TARGETS @@ -150,6 +150,8 @@ runtime.python_library( "@EXECUTORCH_CLIENTS", ], deps = [ + "fbsource//third-party/pypi/tqdm:tqdm", + "fbsource//third-party/pypi/datasets:datasets", "fbsource//third-party/pypi/lm-eval:lm-eval", "fbsource//third-party/pypi/tiktoken:tiktoken", ":export_library", diff --git a/examples/models/llama/eval_llama.py b/examples/models/llama/eval_llama.py index 09157789bd..7c959d08b9 100644 --- a/examples/models/llama/eval_llama.py +++ b/examples/models/llama/eval_llama.py @@ -10,7 +10,11 @@ import torch -from .eval_llama_lib import build_args_parser, eval_llama +from .eval_llama_lib import ( + build_args_parser, + eval_llama, + eval_llama_with_attention_sink, +) FORMAT = "[%(levelname)s %(asctime)s %(filename)s:%(lineno)s] %(message)s" logging.basicConfig(level=logging.INFO, format=FORMAT) @@ -24,7 +28,10 @@ def main() -> None: args = parser.parse_args() # Overrides this arg, because evaluation requires full logits. 
args.generate_full_logits = True - eval_llama(modelname, args) # pyre-ignore + if args.use_attention_sink: + eval_llama_with_attention_sink(modelname, args) # pyre-ignore + else: + eval_llama(modelname, args) # pyre-ignore if __name__ == "__main__": diff --git a/examples/models/llama/eval_llama_lib.py b/examples/models/llama/eval_llama_lib.py index dd01365ba5..a7f0f88cd9 100644 --- a/examples/models/llama/eval_llama_lib.py +++ b/examples/models/llama/eval_llama_lib.py @@ -10,6 +10,8 @@ from typing import Optional, Union import torch + +from datasets import load_dataset from executorch.examples.models.llama.export_llama_lib import ( get_quantizer_and_quant_params, ) @@ -21,6 +23,8 @@ ) from executorch.extension.llm.tokenizer.utils import get_tokenizer from lm_eval.evaluator import simple_evaluate +from torch.nn import CrossEntropyLoss +from tqdm import tqdm from .evaluate.eager_eval import EagerEvalWrapper @@ -280,6 +284,9 @@ def build_args_parser() -> argparse.ArgumentParser: help="Save the checkpoint after source transformations, for other evaluation platform to run the same checkpoint.", ) + # Set of parameters secpific to AttentionSink. + parser.add_argument("--attention_sink_eval_tokens", type=int, default=0) + return parser @@ -309,3 +316,60 @@ def eval_llama( for task, res in eval_results["results"].items(): print(f"{task}: {res}") + + +def eval_llama_with_attention_sink(model_name: str, args: argparse.ArgumentParser): + """ + Evaluate the model's perplexity when AttentionSink is enabled. + + This is mostly copied from https://github.com/mit-han-lab/streaming-llm/blob/main/examples/eval_long_ppl.py + """ + assert args.use_attention_sink is not None # pyre-ignore [16] + assert args.attention_sink_eval_tokens > 0 # pyre-ignore [16] + attention_sink_params = args.use_attention_sink.split(",") + assert len(attention_sink_params) == 3 + sink_size = int(attention_sink_params[0]) + window_size = int(attention_sink_params[1]) + + assert args.max_seq_length == sink_size + window_size # pyre-ignore [16] + + device = "cuda" if torch.cuda.is_available() else "cpu" + manager: LLMEdgeManager = _prepare_for_llama_export(args) + model = manager.model.eval().to(device=device) + tokenizer = get_tokenizer(args.tokenizer_path) # pyre-ignore [16] + + eval_data = load_dataset("wikitext", "wikitext-2-raw-v1", split="test") + + nlls = [] + loss_fn = CrossEntropyLoss(reduction="none") + progress_bar = tqdm(total=args.attention_sink_eval_tokens) + input_pos = 0 + while input_pos < args.attention_sink_eval_tokens: + for text in eval_data["text"]: # pyre-ignore [16] + tokens = tokenizer.encode(text, bos=False, eos=False) + if len(tokens) <= 0: + continue + with torch.no_grad(): + num_tokens = min( + len(tokens) - 1, args.attention_sink_eval_tokens - input_pos + ) + logits = model( + torch.tensor( + [tokens[:num_tokens]], dtype=torch.int64, device=device + ), + torch.tensor([input_pos], dtype=torch.int64, device=device), + ).squeeze(dim=0) + neg_log_likelihood = loss_fn( + logits, + torch.tensor( + [tokens[1 : num_tokens + 1]], dtype=torch.int64, device=device + ).view(-1), + ) + nlls.append(neg_log_likelihood) + input_pos += num_tokens + progress_bar.update(num_tokens) + if input_pos >= args.attention_sink_eval_tokens: + break + ppl = torch.exp(torch.cat(nlls).mean()) + print(f"Perplexity: {ppl.item()}") + return ppl.item() diff --git a/examples/models/llama/export_llama_lib.py b/examples/models/llama/export_llama_lib.py index 9a290968a3..ea4296cc52 100644 --- a/examples/models/llama/export_llama_lib.py +++ 
b/examples/models/llama/export_llama_lib.py @@ -448,6 +448,13 @@ def build_args_parser() -> argparse.ArgumentParser: help="type of embedding quantization for pre-quantized checkpoint, ',', e.g., '8,1024'.", ) + parser.add_argument( + "--use_attention_sink", + default=None, + type=str, + help="Use attention sink to have fluent multi-round conversation. ',,', e.g., '4,2044,1024'.", + ) + parser.add_argument( "--output_prune_map", default=None, diff --git a/examples/models/llama/model.py b/examples/models/llama/model.py index 0f83e404a3..2385aba6d5 100644 --- a/examples/models/llama/model.py +++ b/examples/models/llama/model.py @@ -201,6 +201,25 @@ def __init__(self, **kwargs): sanitize_checkpoint_from_pre_quantization(checkpoint) + if hasattr(self.args, "use_attention_sink") and self.args.use_attention_sink: + from .source_transformation.attention_sink import enable_attention_sink + + attention_sink_params = self.args.use_attention_sink.split(",") + assert len(attention_sink_params) == 3 + sink_size = int(attention_sink_params[0]) + window_size = int(attention_sink_params[1]) + eviction_batch_size = int(attention_sink_params[2]) + + assert self.args.max_seq_length == sink_size + window_size + + self.model_ = enable_attention_sink( + module=self.model_, + params=model_args, + sink_size=sink_size, + window_size=window_size, + eviction_batch_size=eviction_batch_size, + ) + # assign=True: load params/buffers by assignment instead of performing an in-place copy. # Because we are using device="meta", tensors do not have memory associated with them # and an in-place copy is a no-op. Use assign=True in load_state_dict for this scenario. diff --git a/examples/models/llama/runner/eager.py b/examples/models/llama/runner/eager.py index 7b4ebf36a5..559b4e0489 100644 --- a/examples/models/llama/runner/eager.py +++ b/examples/models/llama/runner/eager.py @@ -84,7 +84,11 @@ def execute_runner(runner_class: Type[LlamaRunner]) -> None: with torch.no_grad(): runner = runner_class(args) # pyre-ignore: Missing argument [20] generated_tokens = ( - runner.chat_completion(temperature=args.temperature) + runner.chat_completion( + max_seq_len=1000000 if args.use_attention_sink else args.max_seq_length, + temperature=args.temperature, + show_progress=args.show_tokens, + ) if args.chat else runner.text_completion( prompt=args.prompt, diff --git a/examples/models/llama/runner/generation.py b/examples/models/llama/runner/generation.py index 13ac750305..891ce20db3 100644 --- a/examples/models/llama/runner/generation.py +++ b/examples/models/llama/runner/generation.py @@ -168,18 +168,19 @@ def text_completion( def chat_completion( self, + max_seq_len: int, temperature: float = 0.6, top_p: float = 0.9, + show_progress: bool = False, ) -> List[int]: """ Perform multi-turn chat with the language model. Args: - prompt (str): Text prompt for completion. + max_seq_len (int): Maximum number of tokens to generate for each prompt. temperature (float, optional): Temperature value for controlling randomness in sampling. Defaults to 0.6. top_p (float, optional): Top-p probability threshold for nucleus sampling. Defaults to 0.9. - echo (bool, optional): Flag indicating whether to include prompt tokens in the generated output. Defaults to False. - + show_progress (bool, optional): Flag indicating whether to show number of tokens generated. Returns: Generated list of tokens. 
@@ -188,20 +189,26 @@ def chat_completion( """ exit_prompt = "exit" tokens = [] + pre_stop_token = [] prompt = input("Me: ") while prompt and prompt != exit_prompt: print("LLM: ", end="", flush=True) - new_tokens = self.generate( - prompt_tokens=self.tokenizer.encode( - self._format_prompt(prompt), bos=True, eos=False - ), - max_seq_len=self.max_seq_len, + prompt_tokens = self.tokenizer.encode( + self._format_prompt(prompt), bos=True, eos=False + ) + generated_tokens = self.generate( + prompt_tokens=pre_stop_token + prompt_tokens, + max_seq_len=max_seq_len, temperature=temperature, top_p=top_p, - echo=True, + echo=False, pos_base=len(tokens) - 1 if len(tokens) > 0 else 0, ) - tokens.extend(new_tokens) + pre_stop_token = generated_tokens[-1:] + tokens.extend(prompt_tokens) + tokens.extend(generated_tokens) + if show_progress: + print(f"[Generated {len(tokens)} tokens]") prompt = input("Me: ") return tokens diff --git a/examples/models/llama/source_transformation/attention_sink.py b/examples/models/llama/source_transformation/attention_sink.py index 8f4fd1ebd2..b534a98e07 100644 --- a/examples/models/llama/source_transformation/attention_sink.py +++ b/examples/models/llama/source_transformation/attention_sink.py @@ -7,15 +7,22 @@ # Components for supporting Attention Sink. See # https://arxiv.org/abs/2309.17453 for more details about Attention Sink. +import types from typing import Optional import torch -from executorch.examples.models.llama.llama_transformer import ModelArgs, Rope +from executorch.examples.models.llama.llama_transformer import ( + Attention, + KVCache, + ModelArgs, + Rope, +) from executorch.examples.models.llama.rope import ( apply_rotary_emb_to_k, hf_apply_rotary_emb_to_k, ) +from torchao.quantization.quant_api import _replace_with_custom_fn_if_matches_filter class RopeWithAttentionSink(Rope): @@ -87,3 +94,231 @@ def rerotate_k( ) return self.apply_rotary_emb_to_k(k, rerotation_cos, rerotation_sin) + + +class KVCacheWithAttentionSink(KVCache): + """ + KV cache that supports attention sink. It keeps the initial few tokens as attention sink. + For other tokens, it uses a sliding window to keep the most recent tokens. + + Parameters: + window_size: the size of the sliding window + sink_size: the number of initial tokens to keep as attention sink + eviction_batch_size: the number of tokens to evict in batch when there is not enough space in the KV cache + """ + + def __init__( + self, + n_heads: int, + head_dim: int, + transpose_cache: bool, + enable_dynamic_shape: bool, + rope: RopeWithAttentionSink, + window_size: int, + sink_size: int, + eviction_batch_size: int, + max_batch_size: int = 1, + dtype=torch.float32, + ): + super().__init__( + max_batch_size=max_batch_size, + max_seq_length=window_size + sink_size, + n_heads=n_heads, + head_dim=head_dim, + transpose_cache=transpose_cache, + enable_dynamic_shape=enable_dynamic_shape, + dtype=dtype, + ) + self.rope = rope + self.window_size = window_size + self.sink_size = sink_size + self.eviction_batch_size = eviction_batch_size + self.position_shift = 0 + + def evict_tokens(self, input_pos: torch.Tensor, seq_len: int) -> int: + """ + Evict old tokens from the cache to make rooms for new tokens. 
+ + Parameters: + input_pos: the start position of the incoming token in the actual sequence + seq_len: the length of the incoming sequence + rope: the rope object to use for rerotating k + + Returns: + the number of tokens to evict from the cache which is also the number of + positions to shift for incoming tokens + """ + input_pos_item = input_pos.item() + torch._check_is_size(input_pos_item) + if input_pos_item + self.position_shift + seq_len > self.max_seq_length: + # There are not enough spaces in the cache to store the new tokens. + # We need to evict some old tokens and shift some recent tokens. + num_to_evict = max( + input_pos_item + self.position_shift - self.max_seq_length + seq_len, + self.eviction_batch_size, + ) + num_to_keep = ( + input_pos_item + self.position_shift - self.sink_size - num_to_evict + ) + num_empty_space = self.window_size - num_to_keep + dim_to_slice = 2 if self.transpose_cache else 1 + k_to_keep = self.k_cache.narrow( + dim_to_slice, + self.sink_size + num_to_evict, # pyre-ignore [6] + num_to_keep, # pyre-ignore [6] + ) + if self.transpose_cache: + k_to_keep = self.rope.rerotate_k( + k=k_to_keep.transpose(1, 2), + original_position=( # pyre-ignore [6] + self.sink_size + num_to_evict + ), + new_position=self.sink_size, + ).transpose(1, 2) + else: + k_to_keep = self.rope.rerotate_k( + k=k_to_keep, + original_position=( # pyre-ignore [6] + self.sink_size + num_to_evict + ), + new_position=self.sink_size, + ) + self.k_cache = torch.cat( + [ + self.k_cache.narrow(dim_to_slice, 0, self.sink_size), + k_to_keep, + torch.zeros_like( + self.k_cache.narrow( + dim_to_slice, 0, num_empty_space # pyre-ignore [6] + ) + ), + ], + dim=dim_to_slice, + ) + self.v_cache = torch.cat( + [ + self.v_cache.narrow(dim_to_slice, 0, self.sink_size), + self.v_cache.narrow( + dim_to_slice, + self.sink_size + num_to_evict, # pyre-ignore [6] + num_to_keep, # pyre-ignore [6] + ), + torch.zeros_like( + self.v_cache.narrow( + dim_to_slice, 0, num_empty_space # pyre-ignore [6] + ) + ), + ], + dim=dim_to_slice, + ) + self.position_shift -= num_to_evict # pyre-ignore [8] + return self.position_shift + + +def attention_sink_forward( + self, + x: torch.Tensor, + freqs_cos: torch.Tensor, + freqs_sin: torch.Tensor, + input_pos: Optional[torch.Tensor] = None, +): + assert self.use_kv_cache + assert input_pos is not None + + bsz, seqlen, _ = x.shape + + # QKV + q, k, v = self.wq(x), self.wk(x), self.wv(x) + # We need view_copy elimination + q = q.view(bsz, seqlen, self.n_local_heads, self.head_dim) + k = k.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim) + v = v.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim) + + # Prepare for space in KV cache and get position shift + position_shift = self.kv_cache.evict_tokens(input_pos, seqlen) + + # RoPE relative positional embeddings with shifted position in KV cache + q, k = self.rope.forward(q, k, freqs_cos, freqs_sin) + + output = self.SDPA(input_pos + position_shift, q, k, v, bsz, seqlen, self.mask) + return self.wo(output) + + +def _replace_rope( + module: torch.nn.Module, rope_with_attention_sink: RopeWithAttentionSink +): + def filter_fn(child: torch.nn.Module, cur_fqn: str) -> bool: + return isinstance(child, Rope) + + def replacement_fn(child: torch.nn.Module) -> torch.nn.Module: + return rope_with_attention_sink + + _replace_with_custom_fn_if_matches_filter(module, replacement_fn, filter_fn) + + +def _replace_attention( + module: torch.nn.Module, + rope_with_attention_sink: RopeWithAttentionSink, + sink_size: int, + window_size: int, + 
eviction_batch_size: int, +): + for _, child_module in module._modules.items(): + if len(list(child_module.children())) > 0: # pyre-ignore [16] + _replace_attention( + module=child_module, # pyre-ignore [6] + rope_with_attention_sink=rope_with_attention_sink, + sink_size=sink_size, + window_size=window_size, + eviction_batch_size=eviction_batch_size, + ) + + if isinstance(child_module, Attention): + kv_cache = child_module.kv_cache + kv_cache_with_attention_sink = KVCacheWithAttentionSink( + n_heads=kv_cache.n_heads, + head_dim=kv_cache.head_dim, + transpose_cache=kv_cache.transpose_cache, + enable_dynamic_shape=kv_cache.enable_dynamic_shape, + rope=rope_with_attention_sink, + max_batch_size=kv_cache.max_batch_size, + window_size=window_size, + sink_size=sink_size, + eviction_batch_size=eviction_batch_size, + dtype=kv_cache.k_cache.dtype, + ) + child_module.kv_cache = kv_cache_with_attention_sink + child_module.SDPA.kv_cache = kv_cache_with_attention_sink + child_module.forward = types.MethodType( # pyre-ignore + attention_sink_forward, child_module + ) + + +def enable_attention_sink( + module: torch.nn.Module, + params: ModelArgs, + sink_size: int, + window_size: int, + eviction_batch_size: int, +) -> torch.nn.Module: + """ + Transform the model to be able to run inference with Attention Sink. + There mainly three steps: + - Replace Rope with RopeWithAttentionSink + - Replace Attention's KVCache with KVCacheWithAttentionSink, forward with attention_sink_forward + """ + rope_with_attention_sink = RopeWithAttentionSink( + params=params, + window_size=window_size, + sink_size=sink_size, + eviction_batch_size=eviction_batch_size, + ) + _replace_rope(module, rope_with_attention_sink) + _replace_attention( + module=module, + rope_with_attention_sink=rope_with_attention_sink, + sink_size=sink_size, + window_size=window_size, + eviction_batch_size=eviction_batch_size, + ) + return module diff --git a/examples/models/llama/source_transformation/test_attention_sink.py b/examples/models/llama/source_transformation/test_attention_sink.py index 8eaa992dc3..4ffecf1e9c 100644 --- a/examples/models/llama/source_transformation/test_attention_sink.py +++ b/examples/models/llama/source_transformation/test_attention_sink.py @@ -10,6 +10,7 @@ from executorch.examples.models.llama.llama_transformer import ModelArgs from executorch.examples.models.llama.source_transformation.attention_sink import ( + KVCacheWithAttentionSink, RopeWithAttentionSink, ) from parameterized import parameterized @@ -79,14 +80,10 @@ def test_get_freqs( def test_rotate(self, original_position, new_position): seq_len = 32 - q = torch.rand( - 1, seq_len, self.params.n_heads, self.params.head_dim, dtype=torch.float32 - ) + size = (1, seq_len, self.params.n_heads, self.params.head_dim) + q = torch.rand(*size, dtype=torch.float32) k = torch.rand( - 1, - seq_len, - self.params.n_heads, - self.params.head_dim, + *size, dtype=torch.float32, ) freqs_cos, freqs_sin = self.rope_with_attention_sink.get_freqs( @@ -118,3 +115,465 @@ def test_rotate(self, original_position, new_position): ) torch.testing.assert_close(rerotated_k, expected_k) + + +class KVCacheWithAttentionSinkTest(unittest.TestCase): + + _single_evict_test_cases = [ + [False, 4, 1], + [True, 4, 1], + ] + + _batch_evict_test_cases = [ + [False, 4, 8], + [True, 4, 8], + ] + + _sliding_window_test_cases = [ + [False, 0, 1], + [True, 0, 1], + ] + + def _init_cache(self, transpose_cache, sink_size, eviction_batch_size): + self.params = ModelArgs( + use_kv_cache=True, + 
enable_dynamic_shape=True, + max_seq_len=self.window_size + sink_size, + ) + self.rope_with_attention_sink = RopeWithAttentionSink( + params=self.params, + window_size=self.window_size, + sink_size=sink_size, + eviction_batch_size=eviction_batch_size, + ) + self.kv_cache = KVCacheWithAttentionSink( + n_heads=self.params.n_heads, + head_dim=self.params.head_dim, + transpose_cache=transpose_cache, + enable_dynamic_shape=self.params.enable_dynamic_shape, + rope=self.rope_with_attention_sink, + max_batch_size=self.max_batch_size, + window_size=self.window_size, + sink_size=sink_size, + eviction_batch_size=eviction_batch_size, + dtype=self.dtype, + ) + + def _rand_kv_with_length(self, transpose_cache, seq_len): + size = ( + ( + self.max_batch_size, + seq_len, + self.params.n_heads, + self.params.head_dim, + ) + if not transpose_cache + else ( + self.max_batch_size, + self.params.n_heads, + seq_len, + self.params.head_dim, + ) + ) + if not transpose_cache: + k = torch.rand( + *size, + dtype=self.dtype, + ) + v = torch.rand( + *size, + dtype=self.dtype, + ) + else: + k = torch.rand( + *size, + dtype=self.dtype, + ) + v = torch.rand( + *size, + dtype=self.dtype, + ) + return k, v + + def _zero_kv_with_length(self, transpose_cache, seq_len): + size = ( + ( + self.max_batch_size, + seq_len, + self.params.n_heads, + self.params.head_dim, + ) + if not transpose_cache + else ( + self.max_batch_size, + self.params.n_heads, + seq_len, + self.params.head_dim, + ) + ) + if not transpose_cache: + k = torch.zeros( + *size, + dtype=self.dtype, + ) + v = torch.zeros( + *size, + dtype=self.dtype, + ) + else: + k = torch.zeros( + *size, + dtype=self.dtype, + ) + v = torch.zeros( + *size, + dtype=self.dtype, + ) + return k, v + + def _get_dim_to_slice(self, transpose_cache): + return 2 if transpose_cache else 1 + + def _get_expected_rotated_k( + self, transpose_cache, k, original_position, new_position + ): + if transpose_cache: + return self.rope_with_attention_sink.rerotate_k( + k=k.transpose(1, 2), + original_position=original_position, + new_position=new_position, + ).transpose(1, 2) + else: + return self.rope_with_attention_sink.rerotate_k( + k=k, original_position=original_position, new_position=new_position + ) + + def setUp(self): + torch.manual_seed(42) + self.max_batch_size = 1 + self.window_size = 28 + self.dtype = torch.float32 + + @parameterized.expand( + _single_evict_test_cases + _batch_evict_test_cases + _sliding_window_test_cases + ) + def test_evict_empty_cache(self, transpose_cache, sink_size, eviction_batch_size): + self._init_cache(transpose_cache, sink_size, eviction_batch_size) + + # KV cache is empty, evict does nothing + input_pos = torch.tensor([0], dtype=torch.int32) + assert self.kv_cache.evict_tokens(input_pos, 1) == 0 + + expected_k, expected_v = self._zero_kv_with_length( + transpose_cache, self.window_size + sink_size + ) + + torch.testing.assert_close(self.kv_cache.k_cache, expected_k) + torch.testing.assert_close(self.kv_cache.v_cache, expected_v) + + @parameterized.expand( + _single_evict_test_cases + _batch_evict_test_cases + _sliding_window_test_cases + ) + def test_evict_without_shift(self, transpose_cache, sink_size, eviction_batch_size): + dimension_to_slice = self._get_dim_to_slice(transpose_cache) + + self._init_cache(transpose_cache, sink_size, eviction_batch_size) + + # KV cache has enough spaces for new tokens, no shift + input_pos = torch.tensor([0], dtype=torch.int32) + k, v = self._rand_kv_with_length(transpose_cache, 10) + + self.kv_cache.update(input_pos, k, v) + 
+ input_pos = torch.tensor([10], dtype=torch.int32) + assert self.kv_cache.evict_tokens(input_pos, 1) == 0 + + zero_k, zero_v = self._zero_kv_with_length( + transpose_cache, self.window_size + sink_size - 10 + ) + + expected_k = torch.cat( + [ + k, + zero_k, + ], + dim=dimension_to_slice, + ) + expected_v = torch.cat( + [ + v, + zero_v, + ], + dim=dimension_to_slice, + ) + + torch.testing.assert_close(self.kv_cache.k_cache, expected_k) + torch.testing.assert_close(self.kv_cache.v_cache, expected_v) + + @parameterized.expand(_single_evict_test_cases) + def test_evict_with_some_shift( + self, transpose_cache, sink_size, eviction_batch_size + ): + dimension_to_slice = self._get_dim_to_slice(transpose_cache) + + self._init_cache(transpose_cache, sink_size, eviction_batch_size) + + # KV cache has some spaces for new tokens but not all, shift some tokens + input_pos = torch.tensor([0], dtype=torch.int32) + k, v = self._rand_kv_with_length(transpose_cache, 5) + + self.kv_cache.update(input_pos, k, v) + + input_pos = torch.tensor([5], dtype=torch.int32) + k1, v1 = self._rand_kv_with_length(transpose_cache, 5) + + self.kv_cache.update(input_pos, k1, v1) + + input_pos = torch.tensor([10], dtype=torch.int32) + assert self.kv_cache.evict_tokens(input_pos, 24) == -2 + + zero_k, zero_v = self._zero_kv_with_length(transpose_cache, 24) + expected_k = torch.cat( + [ + k.narrow(dimension_to_slice, 0, sink_size), + self._get_expected_rotated_k( + transpose_cache, k1.narrow(dimension_to_slice, 1, 4), 6, 4 + ), + zero_k, + ], + dim=dimension_to_slice, + ) + expected_v = torch.cat( + [ + v.narrow(dimension_to_slice, 0, sink_size), + v1.narrow(dimension_to_slice, 1, 4), + zero_v, + ], + dim=dimension_to_slice, + ) + + torch.testing.assert_close(self.kv_cache.k_cache, expected_k) + torch.testing.assert_close(self.kv_cache.v_cache, expected_v) + + @parameterized.expand(_single_evict_test_cases) + def test_evict_with_all_shift( + self, transpose_cache, sink_size, eviction_batch_size + ): + dimension_to_slice = self._get_dim_to_slice(transpose_cache) + + self._init_cache(transpose_cache, sink_size, eviction_batch_size) + + # KV cache has no spaces for new tokens, shift all tokens + input_pos = torch.tensor([0], dtype=torch.int32) + k, v = self._rand_kv_with_length(transpose_cache, 5) + + self.kv_cache.update(input_pos, k, v) + + input_pos = torch.tensor([5], dtype=torch.int32) + k1, v1 = self._rand_kv_with_length(transpose_cache, 27) + + self.kv_cache.update(input_pos, k1, v1) + + input_pos = torch.tensor([32], dtype=torch.int32) + assert self.kv_cache.evict_tokens(input_pos, 6) == -6 + + zero_k, zero_v = self._zero_kv_with_length(transpose_cache, 6) + expected_k = torch.cat( + [ + k.narrow(dimension_to_slice, 0, sink_size), + self._get_expected_rotated_k( + transpose_cache, k1.narrow(dimension_to_slice, 5, 22), 10, 4 + ), + zero_k, + ], + dim=dimension_to_slice, + ) + expected_v = torch.cat( + [ + v.narrow(dimension_to_slice, 0, sink_size), + v1.narrow(dimension_to_slice, 5, 22), + zero_v, + ], + dim=dimension_to_slice, + ) + + torch.testing.assert_close(self.kv_cache.k_cache, expected_k) + torch.testing.assert_close(self.kv_cache.v_cache, expected_v) + + @parameterized.expand(_sliding_window_test_cases) + def test_evict_with_some_shift_for_sliding_window( + self, transpose_cache, sink_size, eviction_batch_size + ): + dimension_to_slice = self._get_dim_to_slice(transpose_cache) + + self._init_cache(transpose_cache, sink_size, eviction_batch_size) + + # KV cache has some spaces for new tokens but not all, shift some 
tokens + input_pos = torch.tensor([0], dtype=torch.int32) + k, v = self._rand_kv_with_length(transpose_cache, 5) + + self.kv_cache.update(input_pos, k, v) + + input_pos = torch.tensor([5], dtype=torch.int32) + k1, v1 = self._rand_kv_with_length(transpose_cache, 5) + + self.kv_cache.update(input_pos, k1, v1) + + input_pos = torch.tensor([10], dtype=torch.int32) + assert self.kv_cache.evict_tokens(input_pos, 20) == -2 + + zero_k, zero_v = self._zero_kv_with_length(transpose_cache, 20) + expected_k = torch.cat( + [ + self._get_expected_rotated_k( + transpose_cache, k.narrow(dimension_to_slice, 2, 3), 2, 0 + ), + self._get_expected_rotated_k(transpose_cache, k1, 5, 3), + zero_k, + ], + dim=dimension_to_slice, + ) + expected_v = torch.cat( + [ + v.narrow(dimension_to_slice, 2, 3), + v1, + zero_v, + ], + dim=dimension_to_slice, + ) + + torch.testing.assert_close(self.kv_cache.k_cache, expected_k) + torch.testing.assert_close(self.kv_cache.v_cache, expected_v) + + @parameterized.expand(_sliding_window_test_cases) + def test_evict_with_all_shift_for_sliding_window( + self, transpose_cache, sink_size, eviction_batch_size + ): + dimension_to_slice = self._get_dim_to_slice(transpose_cache) + + self._init_cache(transpose_cache, sink_size, eviction_batch_size) + + # KV cache has no spaces for new tokens, shift all tokens + input_pos = torch.tensor([0], dtype=torch.int32) + k, v = self._rand_kv_with_length(transpose_cache, 5) + + self.kv_cache.update(input_pos, k, v) + + input_pos = torch.tensor([5], dtype=torch.int32) + k1, v1 = self._rand_kv_with_length(transpose_cache, 23) + + self.kv_cache.update(input_pos, k1, v1) + + input_pos = torch.tensor([28], dtype=torch.int32) + assert self.kv_cache.evict_tokens(input_pos, 6) == -6 + + zero_k, zero_v = self._zero_kv_with_length(transpose_cache, 6) + expected_k = torch.cat( + [ + self._get_expected_rotated_k( + transpose_cache, k1.narrow(dimension_to_slice, 1, 22), 6, 0 + ), + zero_k, + ], + dim=dimension_to_slice, + ) + expected_v = torch.cat( + [ + v1.narrow(dimension_to_slice, 1, 22), + zero_v, + ], + dim=dimension_to_slice, + ) + + torch.testing.assert_close(self.kv_cache.k_cache, expected_k) + torch.testing.assert_close(self.kv_cache.v_cache, expected_v) + + @parameterized.expand(_batch_evict_test_cases) + def test_batch_evict_with_seq_len( + self, transpose_cache, sink_size, eviction_batch_size + ): + dimension_to_slice = self._get_dim_to_slice(transpose_cache) + + self._init_cache(transpose_cache, sink_size, eviction_batch_size) + + # KV cache has some spaces for new tokens but not all, shift some tokens + input_pos = torch.tensor([0], dtype=torch.int32) + k, v = self._rand_kv_with_length(transpose_cache, 5) + + self.kv_cache.update(input_pos, k, v) + + input_pos = torch.tensor([5], dtype=torch.int32) + k1, v1 = self._rand_kv_with_length(transpose_cache, 25) + + self.kv_cache.update(input_pos, k1, v1) + + input_pos = torch.tensor([30], dtype=torch.int32) + assert self.kv_cache.evict_tokens(input_pos, 12) == -10 + + zero_k, zero_v = self._zero_kv_with_length(transpose_cache, 12) + expected_k = torch.cat( + [ + k.narrow(dimension_to_slice, 0, sink_size), + self._get_expected_rotated_k( + transpose_cache, k1.narrow(dimension_to_slice, 9, 16), 14, 4 + ), + zero_k, + ], + dim=dimension_to_slice, + ) + expected_v = torch.cat( + [ + v.narrow(dimension_to_slice, 0, sink_size), + v1.narrow(dimension_to_slice, 9, 16), + zero_v, + ], + dim=dimension_to_slice, + ) + + torch.testing.assert_close(self.kv_cache.k_cache, expected_k) + 
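[editor note] The eviction expectations asserted in these tests follow a simple pattern. Below is a minimal sketch of that bookkeeping, reconstructed only from the expected return values of evict_tokens; the function name, the sink_size/window_size/eviction_batch_size values, and the max-with-batch rule are assumptions, not the ExecuTorch implementation.

def expected_eviction_shift(
    input_pos: int,
    seq_len: int,
    sink_size: int,
    window_size: int,
    eviction_batch_size: int,
) -> int:
    # Total cache capacity is the attention sink plus the sliding window.
    capacity = sink_size + window_size
    overflow = input_pos + seq_len - capacity
    if overflow <= 0:
        return 0  # enough free slots, nothing is shifted
    # Evict at least `overflow` tokens, but never fewer than one eviction batch.
    return -max(overflow, eviction_batch_size)

# Consistent with the asserts in the surrounding tests if sink_size=4 and
# window_size=28 (assumed), with eviction_batch_size of 1 or 8:
assert expected_eviction_shift(10, 24, 4, 28, 1) == -2
assert expected_eviction_shift(32, 6, 4, 28, 1) == -6
assert expected_eviction_shift(30, 12, 4, 28, 8) == -10
assert expected_eviction_shift(30, 6, 4, 28, 8) == -8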
torch.testing.assert_close(self.kv_cache.v_cache, expected_v) + + @parameterized.expand(_batch_evict_test_cases) + def test_batch_evict_with_batch_size( + self, transpose_cache, sink_size, eviction_batch_size + ): + dimension_to_slice = self._get_dim_to_slice(transpose_cache) + + self._init_cache(transpose_cache, sink_size, eviction_batch_size) + + # KV cache has no spaces for new tokens, shift all tokens + input_pos = torch.tensor([0], dtype=torch.int32) + k, v = self._rand_kv_with_length(transpose_cache, 5) + + self.kv_cache.update(input_pos, k, v) + + input_pos = torch.tensor([5], dtype=torch.int32) + k1, v1 = self._rand_kv_with_length(transpose_cache, 25) + + self.kv_cache.update(input_pos, k1, v1) + + input_pos = torch.tensor([30], dtype=torch.int32) + assert self.kv_cache.evict_tokens(input_pos, 6) == -8 + + zero_k, zero_v = self._zero_kv_with_length(transpose_cache, 10) + expected_k = torch.cat( + [ + k.narrow(dimension_to_slice, 0, sink_size), + self._get_expected_rotated_k( + transpose_cache, k1.narrow(dimension_to_slice, 7, 18), 12, 4 + ), + zero_k, + ], + dim=dimension_to_slice, + ) + expected_v = torch.cat( + [ + v.narrow(dimension_to_slice, 0, sink_size), + v1.narrow(dimension_to_slice, 7, 18), + zero_v, + ], + dim=dimension_to_slice, + ) + + torch.testing.assert_close(self.kv_cache.k_cache, expected_k) + torch.testing.assert_close(self.kv_cache.v_cache, expected_v) diff --git a/exir/emit/test/test_emit.py b/exir/emit/test/test_emit.py index 2feeefc4ef..1ebe9b2224 100644 --- a/exir/emit/test/test_emit.py +++ b/exir/emit/test/test_emit.py @@ -340,10 +340,8 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: exir.print_program.pretty_print(program) deboxed_int_list = [] - for item in program.execution_plan[0].values[5].val.items: # pyre-ignore[16] - deboxed_int_list.append( - program.execution_plan[0].values[item].val.int_val # pyre-ignore[16] - ) + for item in program.execution_plan[0].values[5].val.items: + deboxed_int_list.append(program.execution_plan[0].values[item].val.int_val) self.assertEqual(IntList(deboxed_int_list), IntList([2, 0, 1])) @@ -459,11 +457,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: # Check the mul operator's stack trace contains f -> g -> h self.assertTrue( "return torch.mul(x, torch.randn(3, 2))" - in program.execution_plan[0] # pyre-ignore[16] - .chains[0] - .stacktrace[1] - .items[-1] - .context + in program.execution_plan[0].chains[0].stacktrace[1].items[-1].context ) self.assertEqual( program.execution_plan[0].chains[0].stacktrace[1].items[-1].name, "f" @@ -616,11 +610,7 @@ def false_fn(y: torch.Tensor) -> torch.Tensor: if not isinstance(inst.instr_args, KernelCall): continue - op = ( - program.execution_plan[0] - .operators[inst.instr_args.op_index] # pyre-ignore[16] - .name - ) + op = program.execution_plan[0].operators[inst.instr_args.op_index].name if "mm" in op: num_mm += 1 @@ -657,19 +647,13 @@ def map_fn(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: # generate the tensor on which this iteration will operate on. 
self.assertEqual( op_table[ - program.execution_plan[0] # pyre-ignore[16] - .chains[0] - .instructions[0] - .instr_args.op_index + program.execution_plan[0].chains[0].instructions[0].instr_args.op_index ].name, "aten::sym_size", ) self.assertEqual( op_table[ - program.execution_plan[0] # pyre-ignore[16] - .chains[0] - .instructions[1] - .instr_args.op_index + program.execution_plan[0].chains[0].instructions[1].instr_args.op_index ].name, "aten::select_copy", ) @@ -681,28 +665,19 @@ def map_fn(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: # We check here that both of these have been generated. self.assertEqual( op_table[ - program.execution_plan[0] # pyre-ignore[16] - .chains[0] - .instructions[-5] - .instr_args.op_index + program.execution_plan[0].chains[0].instructions[-5].instr_args.op_index ].name, "executorch_prim::et_copy_index", ) self.assertEqual( op_table[ - program.execution_plan[0] # pyre-ignore[16] - .chains[0] - .instructions[-4] - .instr_args.op_index + program.execution_plan[0].chains[0].instructions[-4].instr_args.op_index ].name, "executorch_prim::add", ) self.assertEqual( op_table[ - program.execution_plan[0] # pyre-ignore[16] - .chains[0] - .instructions[-3] - .instr_args.op_index + program.execution_plan[0].chains[0].instructions[-3].instr_args.op_index ].name, "executorch_prim::eq", ) @@ -716,10 +691,7 @@ def map_fn(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: ) self.assertEqual( op_table[ - program.execution_plan[0] # pyre-ignore[16] - .chains[0] - .instructions[-1] - .instr_args.op_index + program.execution_plan[0].chains[0].instructions[-1].instr_args.op_index ].name, "executorch_prim::sub", ) @@ -1300,9 +1272,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: # this triggers the actual emission of the graph program = program_mul._emitter_output.program node = None - program.execution_plan[0].chains[0].instructions[ # pyre-ignore[16] - 0 - ].instr_args.op_index + program.execution_plan[0].chains[0].instructions[0].instr_args.op_index # Find the multiplication node in the graph that was emitted. for node in program_mul.exported_program().graph.nodes: @@ -1314,7 +1284,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: # Find the multiplication instruction in the program that was emitted. 
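[editor note] The assertions in this file all resolve an instruction back to its operator through the same chain of lookups. A small illustrative helper (not part of the codebase) that captures the traversal used here and in the loop below, assuming the instruction is a KernelCall:

def op_name_of(program, instr_idx, plan_idx=0, chain_idx=0):
    # execution_plan -> chain -> instruction -> entry in the operator table
    plan = program.execution_plan[plan_idx]
    instruction = plan.chains[chain_idx].instructions[instr_idx]
    return plan.operators[instruction.instr_args.op_index].name

# The loop that follows is equivalent to scanning for the first index where
# "mul" appears in op_name_of(program, idx).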
for idx in range(len(program.execution_plan[0].chains[0].instructions)): instruction = program.execution_plan[0].chains[0].instructions[idx] - op_index = instruction.instr_args.op_index # pyre-ignore[16] + op_index = instruction.instr_args.op_index if "mul" in program.execution_plan[0].operators[op_index].name: break @@ -1453,9 +1423,7 @@ def forward(self, x, y): exec_prog._emitter_output.program self.assertIsNotNone(exec_prog.delegate_map) self.assertIsNotNone(exec_prog.delegate_map.get("forward")) - self.assertIsNotNone( - exec_prog.delegate_map.get("forward").get(0) # pyre-ignore[16] - ) + self.assertIsNotNone(exec_prog.delegate_map.get("forward").get(0)) self.assertEqual( exec_prog.delegate_map.get("forward").get(0).get("name"), "BackendWithCompilerExample", @@ -1568,9 +1536,7 @@ def forward(self, x): model = model.to_executorch() model.dump_executorch_program(True) self.assertTrue( - model.executorch_program.execution_plan[0] # pyre-ignore[16] - .values[0] - .val.allocation_info + model.executorch_program.execution_plan[0].values[0].val.allocation_info is not None ) executorch_module = _load_for_executorch_from_buffer(model.buffer) @@ -1611,9 +1577,7 @@ def forward(self, x): ) model.dump_executorch_program(True) self.assertTrue( - model.executorch_program.execution_plan[0] # pyre-ignore[16] - .values[0] - .val.allocation_info + model.executorch_program.execution_plan[0].values[0].val.allocation_info is not None ) executorch_module = _load_for_executorch_from_buffer(model.buffer) diff --git a/exir/memory_planning.py b/exir/memory_planning.py index 6ec740bd9f..6f0ab2a392 100644 --- a/exir/memory_planning.py +++ b/exir/memory_planning.py @@ -268,7 +268,9 @@ def _is_inplace_node(node: torch.fx.Node) -> bool: ) -def update_tensor_lifetime(spec: TensorSpec, node_idx: int) -> None: +def update_tensor_lifetime( + node: torch.fx.Node, spec: TensorSpec, node_idx: int +) -> None: r""" Update the lifetime of the tensor to cover node_idx. 
A tensor's lifetime are represented by the index of the first and last node referring @@ -279,7 +281,10 @@ def update_tensor_lifetime(spec: TensorSpec, node_idx: int) -> None: node_idx: extend the tensor's lifetime to cover node_idx """ start, end = spec.lifetime - start = node_idx if start is None or start > node_idx else start + if node.op == "placeholder": + start = 0 + else: + start = node_idx if start is None or start > node_idx else start end = node_idx if end is None or end < node_idx else end spec.lifetime = [start, end] @@ -444,7 +449,7 @@ def update_all_tensors_lifetime( do_assertion=False, ignore_dynamic_unbound_tensor=False, ): - update_tensor_lifetime(spec, node_idx) + update_tensor_lifetime(node, spec, node_idx) specs.add(spec) return specs diff --git a/exir/program/TARGETS b/exir/program/TARGETS index fc73abf1ff..674d7baa35 100644 --- a/exir/program/TARGETS +++ b/exir/program/TARGETS @@ -1,4 +1,5 @@ load("@fbcode_macros//build_defs:python_library.bzl", "python_library") +load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime") oncall("executorch") @@ -43,7 +44,7 @@ python_library( "//executorch/exir/passes:spec_prop_pass", "//executorch/exir/passes:weights_to_outputs_pass", "//executorch/exir/verification:verifier", - ], + ] + (["//executorch/exir/program/fb:logger"] if not runtime.is_oss else []) ) python_library( diff --git a/exir/program/_program.py b/exir/program/_program.py index b136d6cead..cbfa110528 100644 --- a/exir/program/_program.py +++ b/exir/program/_program.py @@ -75,8 +75,24 @@ Val = Any +from typing import Any, Callable + from torch.library import Library +try: + from executorch.exir.program.fb.logger import et_logger +except ImportError: + # Define a stub decorator that does nothing + def et_logger(api_name: str) -> Callable[[Any], Any]: + def decorator(func: Callable[..., Any]) -> Callable[..., Any]: + def wrapper(self: Any, *args: Any, **kwargs: Any) -> Any: + return func(self, *args, **kwargs) + + return wrapper + + return decorator + + # This is the reserved namespace that is used to register ops to that will # be prevented from being decomposed during to_edge_transform_and_lower. edge_no_decomp_namespace = "EDGE_DO_NOT_DECOMP" @@ -957,6 +973,7 @@ def _gen_edge_manager_for_partitioners( return edge_manager +@et_logger("to_edge_transform_and_lower") def to_edge_transform_and_lower( programs: Union[ExportedProgram, Dict[str, ExportedProgram]], transform_passes: Optional[ @@ -1110,6 +1127,7 @@ def to_edge_with_preserved_ops( ) +@et_logger("to_edge") def to_edge( programs: Union[ExportedProgram, Dict[str, ExportedProgram]], constant_methods: Optional[Dict[str, Any]] = None, @@ -1204,8 +1222,10 @@ def exported_program(self, method_name: str = "forward") -> ExportedProgram: """ Returns the ExportedProgram specified by 'method_name'. 
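[editor note] The update_tensor_lifetime change earlier in this patch pins placeholder lifetimes to the start of the program, which is what the new test_placeholder_lifetime test below exercises. A minimal sketch of the rule with illustrative values (not the ExecuTorch implementation):

def lifetime_after_visit(lifetime, node_idx, is_placeholder):
    start, end = lifetime
    # Placeholders (graph inputs) are treated as live from node 0, so a pass
    # that inserts ops ahead of their first use cannot shrink their lifetime.
    if is_placeholder:
        start = 0
    else:
        start = node_idx if start is None or start > node_idx else start
    end = node_idx if end is None or end < node_idx else end
    return [start, end]

# A placeholder first referenced at node 7 now spans [0, 7];
# a non-placeholder tensor first referenced there spans [7, 7].
assert lifetime_after_visit([None, None], 7, True) == [0, 7]
assert lifetime_after_visit([None, None], 7, False) == [7, 7]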
""" + return self._edge_programs[method_name] + @et_logger("transform") def transform( self, passes: Union[Sequence[PassType], Dict[str, Sequence[PassType]]], @@ -1253,6 +1273,7 @@ def transform( new_programs, copy.deepcopy(self._config_methods), compile_config ) + @et_logger("to_backend") def to_backend( self, partitioner: Union[Partitioner, Dict[str, Partitioner]] ) -> "EdgeProgramManager": @@ -1296,6 +1317,7 @@ def to_backend( new_edge_programs, copy.deepcopy(self._config_methods), config ) + @et_logger("to_executorch") def to_executorch( self, config: Optional[ExecutorchBackendConfig] = None, diff --git a/exir/tests/test_joint_graph.py b/exir/tests/test_joint_graph.py index 2413e2b498..f3b6f0ed55 100644 --- a/exir/tests/test_joint_graph.py +++ b/exir/tests/test_joint_graph.py @@ -73,25 +73,21 @@ def forward(self, x, y): # assert that the weight and bias have proper data_buffer_idx and allocation_info self.assertEqual( - et.executorch_program.execution_plan[0] # pyre-ignore - .values[0] - .val.data_buffer_idx, + et.executorch_program.execution_plan[0].values[0].val.data_buffer_idx, 1, ) self.assertEqual( - et.executorch_program.execution_plan[0] # pyre-ignore - .values[1] - .val.data_buffer_idx, + et.executorch_program.execution_plan[0].values[1].val.data_buffer_idx, 2, ) self.assertEqual( - et.executorch_program.execution_plan[0] # pyre-ignore + et.executorch_program.execution_plan[0] .values[0] .val.allocation_info.memory_offset_low, 0, ) self.assertEqual( - et.executorch_program.execution_plan[0] # pyre-ignore + et.executorch_program.execution_plan[0] .values[1] .val.allocation_info.memory_offset_low, 48, @@ -106,7 +102,7 @@ def forward(self, x, y): self.assertTrue(torch.allclose(loss, et_outputs[0])) self.assertTrue( - torch.allclose(m.linear.weight.grad, et_outputs[1]) # pyre-ignore[6] + torch.allclose(m.linear.weight.grad, et_outputs[1]) # pyre-ignore ) self.assertTrue(torch.allclose(m.linear.bias.grad, et_outputs[2])) self.assertTrue(torch.allclose(m.linear.weight, et_outputs[3])) @@ -118,23 +114,17 @@ def forward(self, x, y): # gradient outputs start at index 1 self.assertEqual( - et.executorch_program.execution_plan[1] # pyre-ignore - .values[0] - .val.int_val, + et.executorch_program.execution_plan[1].values[0].val.int_val, 1, ) self.assertEqual( - et.executorch_program.execution_plan[2] # pyre-ignore - .values[0] - .val.string_val, + et.executorch_program.execution_plan[2].values[0].val.string_val, "linear.weight", ) # parameter outputs start at index 3 self.assertEqual( - et.executorch_program.execution_plan[3] # pyre-ignore - .values[0] - .val.int_val, + et.executorch_program.execution_plan[3].values[0].val.int_val, 3, ) diff --git a/exir/tests/test_memory_planning.py b/exir/tests/test_memory_planning.py index 90398035e7..5e4573a2ba 100644 --- a/exir/tests/test_memory_planning.py +++ b/exir/tests/test_memory_planning.py @@ -14,6 +14,7 @@ import torch from executorch.exir import ExecutorchBackendConfig, to_edge +from executorch.exir.dialects._ops import ops as exir_ops from executorch.exir.memory_planning import ( filter_nodes, get_node_tensor_specs, @@ -21,7 +22,7 @@ naive, Verifier, ) -from executorch.exir.pass_base import PassResult +from executorch.exir.pass_base import ExportPass, PassResult from executorch.exir.pass_manager import PassManager from executorch.exir.passes import ( # noqa MemoryPlanningPass, @@ -593,3 +594,65 @@ def count_planned_inputs( num_placeholders, 5, ) + + def test_placeholder_lifetime(self) -> None: + class TestModel(torch.nn.Module): + def 
__init__(self) -> None: + super().__init__() + self.linear = torch.nn.Linear(5, 5) + + def forward(self, a, b, x): + a = a + b + b = a + b + y = self.linear(x) + return a, b, y + + model = TestModel() + example_inputs = (torch.rand(1, 6, 2), torch.rand(1, 6, 2), torch.randn(5, 5)) + exported_model = torch.export.export(model, example_inputs) + edge = to_edge(exported_model) + + class TestPass(ExportPass): + def call(self, graph_module: torch.fx.GraphModule) -> PassResult: + permute_dims = [1, 0, 2] + for node in graph_module.graph.nodes: + if node.op == "placeholder" and str(node) == "a": + inverse_dims = [ + permute_dims.index(x) for x in range(len(permute_dims)) + ] + + with graph_module.graph.inserting_after(node): + permute = graph_module.graph.call_function( + exir_ops.edge.aten.permute_copy.default, + args=(node, inverse_dims), + ) + permute.meta = node.meta.copy() + node.meta["val"] = node.meta["val"].permute(permute_dims) + node.replace_all_uses_with( + permute, lambda x, permute=permute: x is not permute + ) + break + return PassResult(graph_module, True) + + edge = edge.transform([TestPass()]) + et = edge.to_executorch() + et_program = et.executorch_program + inputs = et_program.execution_plan[0].inputs + self.assertNotEqual( + et_program.execution_plan[0] # pyre-ignore + .values[inputs[0]] + .val.allocation_info.memory_offset_low, + et_program.execution_plan[0] # pyre-ignore + .values[inputs[1]] + .val.allocation_info.memory_offset_low, + ) + + constants = 0 + for node in et.exported_program().graph_module.graph.nodes: + if node.op == "placeholder" and node.meta.get("spec"): + meta_spec = node.meta["spec"] + if meta_spec.const is True: + constants += 1 + self.assertIsNone(node.meta["spec"].mem_offset) + self.assertIsNone(node.meta["spec"].mem_id) + self.assertEqual(constants, 2) diff --git a/exir/tests/test_remove_view_copy.py b/exir/tests/test_remove_view_copy.py index 0925a8abc8..318dc085b4 100644 --- a/exir/tests/test_remove_view_copy.py +++ b/exir/tests/test_remove_view_copy.py @@ -196,24 +196,14 @@ def test_spec(self) -> None: instructions = plan.chains[0].instructions self.assertEqual(len(instructions), 7) + self.assertEqual(instructions[0].instr_args.op_index, 0) # view @ idx2 + self.assertEqual(instructions[1].instr_args.op_index, 0) # view @ idx3 + self.assertEqual(instructions[2].instr_args.op_index, 1) # aten:mul @ idx6 + self.assertEqual(instructions[3].instr_args.op_index, 0) # view @ idx7 + self.assertEqual(instructions[4].instr_args.op_index, 1) # aten:mul @ idx9 self.assertEqual( - instructions[0].instr_args.op_index, 0 # pyre-ignore - ) # view @ idx2 - self.assertEqual( - instructions[1].instr_args.op_index, 0 # pyre-ignore - ) # view @ idx3 - self.assertEqual( - instructions[2].instr_args.op_index, 1 # pyre-ignore - ) # aten:mul @ idx6 - self.assertEqual( - instructions[3].instr_args.op_index, 0 # pyre-ignore - ) # view @ idx7 - self.assertEqual( - instructions[4].instr_args.op_index, 1 # pyre-ignore - ) # aten:mul @ idx9 - self.assertEqual( - instructions[5].instr_args.op_index, 2 # pyre-ignore + instructions[5].instr_args.op_index, 2 ) # aten:view_copy @ idx11 self.assertEqual( - instructions[6].instr_args.op_index, 2 # pyre-ignore + instructions[6].instr_args.op_index, 2 ) # aten:view_copy @ idx11 diff --git a/extension/benchmark/android/benchmark/app/src/main/java/org/pytorch/minibench/BenchmarkMetric.java b/extension/benchmark/android/benchmark/app/src/main/java/org/pytorch/minibench/BenchmarkMetric.java index 22ee7b8480..66ab50550a 100644 --- 
a/extension/benchmark/android/benchmark/app/src/main/java/org/pytorch/minibench/BenchmarkMetric.java +++ b/extension/benchmark/android/benchmark/app/src/main/java/org/pytorch/minibench/BenchmarkMetric.java @@ -63,7 +63,7 @@ public BenchmarkMetric( // the .pte model itself instead of parsing its name public static BenchmarkMetric.BenchmarkModel extractBackendAndQuantization(final String model) { final Matcher m = - Pattern.compile("(?\\w+)_(?\\w+)_(?\\w+)").matcher(model); + Pattern.compile("(?\\w+)_(?[\\w\\+]+)_(?\\w+)").matcher(model); if (m.matches()) { return new BenchmarkMetric.BenchmarkModel( m.group("name"), m.group("backend"), m.group("quantization")); diff --git a/extension/llm/custom_ops/CMakeLists.txt b/extension/llm/custom_ops/CMakeLists.txt index 811eb87ac6..16ca4fff80 100644 --- a/extension/llm/custom_ops/CMakeLists.txt +++ b/extension/llm/custom_ops/CMakeLists.txt @@ -84,6 +84,14 @@ if(EXECUTORCH_BUILD_KERNELS_CUSTOM_AOT) target_include_directories( custom_ops_aot_lib PRIVATE "${CMAKE_CURRENT_BINARY_DIR}/../../../include" ) + # TODO: This only works if we install portable_lib.so to + # /executorch/extension/pybindings/. + if(APPLE) + set(RPATH "@loader_path/../../pybindings") + else() + set(RPATH "$ORIGIN/../../pybindings") + endif() + set_target_properties(custom_ops_aot_lib PROPERTIES INSTALL_RPATH ${RPATH}) if(TARGET portable_lib) # If we have portable_lib built, custom_ops_aot_lib gives the ability to use # the ops in PyTorch and ExecuTorch through pybind @@ -109,5 +117,7 @@ if(EXECUTORCH_BUILD_KERNELS_CUSTOM_AOT) ${_common_compile_options} -DET_USE_THREADPOOL ) - install(TARGETS custom_ops_aot_lib DESTINATION lib) + install(TARGETS custom_ops_aot_lib + LIBRARY DESTINATION executorch/extension/llm/custom_ops + ) endif() diff --git a/kernels/quantized/CMakeLists.txt b/kernels/quantized/CMakeLists.txt index 9d2b14d8eb..7e49f73b09 100644 --- a/kernels/quantized/CMakeLists.txt +++ b/kernels/quantized/CMakeLists.txt @@ -126,16 +126,14 @@ if(NOT CMAKE_GENERATOR STREQUAL "Xcode" # installed location of our _portable_lib.so file. To see these LC_* # values, run `otool -l libquantized_ops_lib.dylib`. if(APPLE) - set_target_properties( - quantized_ops_aot_lib - PROPERTIES # Assume this library will be installed in - # /executorch/kernels/quantized/, and the - # _portable_lib.so is installed in - # /executorch/extension/pybindings/ - BUILD_RPATH "@loader_path/../../extensions/pybindings" - INSTALL_RPATH "@loader_path/../../extensions/pybindings" - ) + set(RPATH "@loader_path/../../extensions/pybindings") + else() + set(RPATH "$ORIGIN/../../extensions/pybindings") endif() + set_target_properties( + quantized_ops_aot_lib PROPERTIES BUILD_RPATH ${RPATH} INSTALL_RPATH + ${RPATH} + ) endif() endif() endif() diff --git a/kernels/quantized/cpu/op_dequantize.cpp b/kernels/quantized/cpu/op_dequantize.cpp index 847f764b0e..f07592fbfb 100644 --- a/kernels/quantized/cpu/op_dequantize.cpp +++ b/kernels/quantized/cpu/op_dequantize.cpp @@ -11,6 +11,9 @@ #include #include #include +#if defined(__aarch64__) || defined(__ARM_NEON) +#include +#endif /** * For an input tensor, use the scale and zero_point arguments to quantize it. 
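[editor note] The per-channel path added below reduces to the usual affine dequantization, out = (q - zero_point) * scale, applied with one scale/zero_point pair per slice along the axis. A hedged NumPy sketch of that formula (illustrative only, not the ExecuTorch kernel):

import numpy as np

def dequantize_per_channel(q, scales, zero_points, axis):
    # Broadcast one (scale, zero_point) pair per slice along `axis`.
    shape = [1] * q.ndim
    shape[axis] = -1
    s = np.asarray(scales, dtype=np.float64).reshape(shape)
    z = np.asarray(zero_points, dtype=np.int64).reshape(shape)
    return ((q.astype(np.int64) - z) * s).astype(np.float32)

# Mirrors the values in the updated tests: a 3x2 tensor of 100s with
# scales [0.5, 1.0] and zero_points [30, 60] on axis=1 becomes
# [[35., 40.], [35., 40.], [35., 40.]].
q = np.full((3, 2), 100, dtype=np.int8)
print(dequantize_per_channel(q, [0.5, 1.0], [30, 60], axis=1))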
@@ -22,6 +25,8 @@ namespace native { using Tensor = exec_aten::Tensor; using Scalar = exec_aten::Scalar; using ScalarType = exec_aten::ScalarType; +using StridesType = exec_aten::StridesType; +using SizesType = exec_aten::SizesType; namespace { @@ -63,6 +68,183 @@ void check_dequantize_per_tensor_args( quant_max); } +/** + * Useful to reduce a tensor `in` over a given dimension `dim` using the + * reduce function `fn`, which should have the following signature: + * void fn(const size_t size, const size_t stride, const size_t base_ix) + * where `size` and `stride` are the size and stride of the dimension being + * reduced and `base_ix` is the index of the first element of the reduction. + */ +template +void apply_over_unpacked_dim( + const Fn& fn, + const exec_aten::Tensor& in, + const int64_t& dim) { + if (in.numel() == 0) { + return; + } + + ET_CHECK_MSG(in.dim() > 0, "Input tensor must have at least one dimension"); + ET_CHECK_VALID_DIM(dim, in.dim()); + + const size_t d = ET_NORMALIZE_IX(dim, in.dim()); + const size_t dim_size = in.size(d); + const size_t outer_size = getLeadingDims(in, d); + const size_t inner_size = getTrailingDims(in, d); + // Loop through all outer dimensions + for (size_t outer_idx = 0; outer_idx < outer_size; ++outer_idx) { + // Loop through dim + for (size_t unpacked_dim_idx = 0; unpacked_dim_idx < dim_size; + ++unpacked_dim_idx) { + fn(inner_size, outer_idx, unpacked_dim_idx); + } + } +} + +void dequantize_optimized( + const int8_t* in, + const double scale, + const int64_t zero_point, + float* out, + int64_t quant_min, + int64_t quant_max, + size_t numel) { + ET_CHECK_MSG( + zero_point >= quant_min, + "zero_point must be %" PRId64 " <= quant_min %" PRId64, + zero_point, + quant_min); + ET_CHECK_MSG( + zero_point <= quant_max, + "zero_point must be %" PRId64 " >= quant_max %" PRId64, + zero_point, + quant_max); + size_t i = 0; +#if defined(__aarch64__) || defined(__ARM_NEON) + int8x8_t zero_point_vec = vdup_n_s8(zero_point); + float32x4_t scales = vdupq_n_f32(static_cast(scale)); + constexpr int32_t kVecSize = 16; + const size_t num_vecs = numel / kVecSize; + const int8_t* in_copy = in; + float* out_copy = out; + for (; i < num_vecs; i++) { + int8x16_t in_vec = vld1q_s8(in_copy); + int16x8_t sub_vec_0_7 = vsubl_s8(vget_low_s8(in_vec), zero_point_vec); + int32x4_t sub_vec_0_3 = vmovl_s16(vget_low_s16(sub_vec_0_7)); + int32x4_t sub_vec_4_7 = vmovl_s16(vget_high_s16(sub_vec_0_7)); + float32x4_t out_vec_0_3 = vmulq_f32(vcvtq_f32_s32(sub_vec_0_3), scales); + float32x4_t out_vec_4_7 = vmulq_f32(vcvtq_f32_s32(sub_vec_4_7), scales); + + int16x8_t sub_vec_8_15 = vsubl_s8(vget_high_s8(in_vec), zero_point_vec); + int32x4_t sub_vec_8_11 = vmovl_s16(vget_low_s16(sub_vec_8_15)); + int32x4_t sub_vec_12_15 = vmovl_s16(vget_high_s16(sub_vec_8_15)); + float32x4_t out_vec_8_11 = vmulq_f32(vcvtq_f32_s32(sub_vec_8_11), scales); + float32x4_t out_vec_12_15 = vmulq_f32(vcvtq_f32_s32(sub_vec_12_15), scales); + vst1q_f32(out_copy + 0, out_vec_0_3); + vst1q_f32(out_copy + 4, out_vec_4_7); + vst1q_f32(out_copy + 8, out_vec_8_11); + vst1q_f32(out_copy + 12, out_vec_12_15); + in_copy += kVecSize; + out_copy += kVecSize; + } + i = i * kVecSize; +#endif + for (; i < numel; i++) { + out[i] = (in[i] - zero_point) * scale; + } +} + +float get_scale(const Tensor& scale, size_t channel_ix) { + ET_CHECK_MSG( + (scale.scalar_type() == ScalarType::Double) || + (scale.scalar_type() == ScalarType::Float), + "scale.scalar_type() %" PRId8 " is not double or float type", + 
static_cast(scale.scalar_type())); + if (scale.scalar_type() == ScalarType::Double) { + return static_cast(scale.const_data_ptr()[channel_ix]); + } else { + return scale.const_data_ptr()[channel_ix]; + } +} + +bool can_use_optimized_dequantize_per_channel( + const Tensor& in, + const ScalarType in_dtype, + exec_aten::optional& out_dtype) { + bool is_contiguous = false; +#ifdef USE_ATEN_LIB + is_contiguous = in.is_contiguous(); +#else + is_contiguous = executorch::runtime::is_contiguous_dim_order( + in.dim_order().data(), in.dim()); +#endif + if (!is_contiguous || (in_dtype != ScalarType::Char) || + (out_dtype.has_value() && out_dtype.value() != ScalarType::Float)) { + return false; + } + return true; +} + +void dequantize_per_channel_optimized( + const Tensor& in, + const Tensor& scales, + const optional& opt_zero_points, + Tensor& out, + int64_t axis, + int64_t quant_min, + int64_t quant_max, + ScalarType in_dtype, + exec_aten::optional& out_dtype) { + check_dequantize_per_tensor_args( + in, quant_min, quant_max, in_dtype, out_dtype, out); + ET_CHECK_MSG( + in_dtype == ScalarType::Char, + "in.scalar_type() %" PRId8 " is not supported:", + static_cast(in.scalar_type())); + if (out_dtype.has_value()) { + ET_CHECK_MSG( + out_dtype.value() == ScalarType::Float, + "Only float output is supported"); + } + const int8_t* in_data = in.const_data_ptr(); + float* out_data = out.mutable_data_ptr(); + const int64_t* zero_points_data = nullptr; + if (opt_zero_points.has_value()) { + zero_points_data = opt_zero_points.value().const_data_ptr(); + } + const StridesType axis_stride = in.strides()[axis]; + const StridesType outer_stride = in.size(axis) * axis_stride; + apply_over_unpacked_dim( + [in_data, + out_data, + &scales, + zero_points_data, + axis_stride, + outer_stride, + quant_min, + quant_max]( + SizesType numel, SizesType outer_idx, SizesType unpacked_dim_idx) { + const int8_t* in_data_local = + in_data + outer_idx * outer_stride + unpacked_dim_idx * axis_stride; + const double scale = get_scale(scales, unpacked_dim_idx); + const int64_t zero_point = zero_points_data != nullptr + ? 
zero_points_data[unpacked_dim_idx] + : 0; + float* out_data_local = out_data + outer_idx * outer_stride + + unpacked_dim_idx * axis_stride; + dequantize_optimized( + in_data_local, + scale, + zero_point, + out_data_local, + quant_min, + quant_max, + numel); + }, + in, + axis); +} + } // namespace /** @@ -172,19 +354,6 @@ Tensor& dequantize_per_tensor_tensor_args_out( return out; } -float get_scale(const Tensor& scale, size_t channel_ix) { - ET_CHECK_MSG( - (scale.scalar_type() == ScalarType::Double) || - (scale.scalar_type() == ScalarType::Float), - "scale.scalar_type() %" PRId8 " is not double or float type", - static_cast(scale.scalar_type())); - if (scale.scalar_type() == ScalarType::Double) { - return static_cast(scale.const_data_ptr()[channel_ix]); - } else { - return scale.const_data_ptr()[channel_ix]; - } -} - Tensor& dequantize_per_channel_out( const Tensor& input, const Tensor& scale, @@ -229,6 +398,20 @@ Tensor& dequantize_per_channel_out( check_dequantize_per_tensor_args( input, quant_min, quant_max, dtype, out_dtype, out); + if (can_use_optimized_dequantize_per_channel(input, dtype, out_dtype)) { + dequantize_per_channel_optimized( + input, + scale, + opt_zero_points, + out, + axis, + quant_min, + quant_max, + dtype, + out_dtype); + return out; + } + // a list contains all dimensions except axis int64_t dims[kTensorDimensionLimit]; for (int64_t i = 0; i < input.dim() - 1; i++) { diff --git a/kernels/quantized/test/op_dequantize_test.cpp b/kernels/quantized/test/op_dequantize_test.cpp index 8d23e74e41..676aa32690 100644 --- a/kernels/quantized/test/op_dequantize_test.cpp +++ b/kernels/quantized/test/op_dequantize_test.cpp @@ -123,13 +123,13 @@ TEST(OpDequantizeOutTest, TensorArgOverload) { EXPECT_TENSOR_EQ(out, expected); } -TEST(OpDequantizeOutTest, DequantizePerChannel) { - et_pal_init(); - TensorFactory tf_byte; +template +void test_per_channel_dtype() { + TensorFactory tf; TensorFactory tf_double; TensorFactory tf_long; - Tensor input = tf_byte.full({3, 2}, 100); + Tensor input = tf.full({3, 2}, 100); Tensor scale = tf_double.make({2}, {0.5, 1}); Tensor zero_point = tf_long.make({2}, {30, 60}); int64_t quant_min = 0; @@ -147,7 +147,7 @@ TEST(OpDequantizeOutTest, DequantizePerChannel) { /*axis=*/1, quant_min, quant_max, - ScalarType::Byte, + DTYPE, optional(), out); @@ -168,7 +168,7 @@ TEST(OpDequantizeOutTest, DequantizePerChannel) { /*axis=*/0, quant_min, quant_max, - ScalarType::Byte, + DTYPE, optional(), out); @@ -176,7 +176,7 @@ TEST(OpDequantizeOutTest, DequantizePerChannel) { // Test with a different axis out = tfo.zeros({3}); - input = tf_byte.make({3}, {100, 100, 100}); + input = tf.make({3}, {100, 100, 100}); scale = tf_double.make({3}, {0.5, 0.75, 1}); zero_point = tf_long.make({3}, {30, 50, 60}); // (100 - 30) * 0.5 @@ -190,8 +190,42 @@ TEST(OpDequantizeOutTest, DequantizePerChannel) { /*axis=*/0, quant_min, quant_max, - ScalarType::Byte, + DTYPE, + optional(), + out); + EXPECT_TENSOR_EQ(out, expected); + + // Test with a different axis + input = tf.full({3, 19}, 100); + out = tfo.zeros({3, 19}); + scale = tf_double.make({3}, {0.5, 0.75, 1}); + zero_point = tf_long.make({3}, {30, 50, 60}); + // (100 - 30) * 0.5 + // (100 - 50) * 0.75 + // (100 - 60) * 1 + expected = tfo.make( + {3, 19}, + {35, 35, 35, 35, 35, 35, 35, 35, 35, 35, 35, 35, + 35, 35, 35, 35, 35, 35, 35, 37.5, 37.5, 37.5, 37.5, 37.5, + 37.5, 37.5, 37.5, 37.5, 37.5, 37.5, 37.5, 37.5, 37.5, 37.5, 37.5, 37.5, + 37.5, 37.5, 40, 40, 40, 40, 40, 40, 40, 40, 40, 40, + 40, 40, 40, 40, 40, 40, 40, 40, 40}); + 
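[editor note] One plausible reason for the 3x19 shape in this new test case: 19 is not a multiple of 16, so, assuming the NEON path above processes 16 int8 values per iteration, each per-channel block exercises both the vectorized body and the scalar tail. A pure-Python sketch of that split, purely illustrative:

def dequantize_1d(q, scale, zero_point, vec_width=16):
    out = [0.0] * len(q)
    i = 0
    # "vectorized" body: whole multiples of vec_width
    while i + vec_width <= len(q):
        for j in range(i, i + vec_width):
            out[j] = (q[j] - zero_point) * scale
        i += vec_width
    # scalar tail: the remaining len(q) % vec_width elements
    for j in range(i, len(q)):
        out[j] = (q[j] - zero_point) * scale
    return out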
dequantize_per_channel_out( + input, + scale, + zero_point, + /*axis=*/0, + quant_min, + quant_max, + DTYPE, optional(), out); + EXPECT_TENSOR_EQ(out, expected); } + +TEST(OpDequantizeOutTest, DequantizePerChannel) { + et_pal_init(); + test_per_channel_dtype(); + test_per_channel_dtype(); +} diff --git a/shim/xplat/executorch/build/runtime_wrapper.bzl b/shim/xplat/executorch/build/runtime_wrapper.bzl index ea5b2eee1f..ad3ba3bf2d 100644 --- a/shim/xplat/executorch/build/runtime_wrapper.bzl +++ b/shim/xplat/executorch/build/runtime_wrapper.bzl @@ -59,7 +59,7 @@ def _patch_executorch_references(targets, use_static_deps = False): return targets out_targets = [] for target in targets: - if target.startswith("//xplat/executorch"): + if target.startswith("//xplat/executorch/") or target.startswith("//xplat/executorch:"): fail("References to executorch build targets must use " + "`//executorch`, not `//xplat/executorch`") diff --git a/test/models/export_delegated_program.py b/test/models/export_delegated_program.py index e9dccdbdf1..a37fe32e55 100644 --- a/test/models/export_delegated_program.py +++ b/test/models/export_delegated_program.py @@ -13,7 +13,7 @@ import executorch.exir as exir import torch -from executorch.exir import to_edge +from executorch.exir import EdgeCompileConfig, to_edge, to_edge_transform_and_lower from executorch.exir.backend.backend_api import to_backend from executorch.exir.backend.backend_details import BackendDetails, PreprocessResult from executorch.exir.backend.test.backend_with_compiler_demo import ( @@ -52,6 +52,41 @@ def get_random_inputs(self) -> Sequence[torch.Tensor]: return (torch.ones(2, 2), 2 * torch.ones(2, 2), 3 * torch.ones(2, 2)) +class ModuleAddLarge(nn.Module): + def __init__(self): + super().__init__() + + def forward( + self, a: torch.Tensor, b: torch.Tensor, c: torch.Tensor + ) -> torch.Tensor: + x: torch.Tensor = torch.add(a, b) + y: torch.Tensor = torch.add(x, c) + z: torch.Tensor = torch.add(x, y) + return z + + def get_random_inputs(self) -> Sequence[torch.Tensor]: + n = 10 # to create a large tensor + return (torch.ones(n, n, n), 2 * torch.ones(n, n, n), 3 * torch.ones(n, n, n)) + + +class ModuleSubLarge(nn.Module): + def __init__(self): + super().__init__() + + def forward( + self, a: torch.Tensor, b: torch.Tensor, c: torch.Tensor + ) -> torch.Tensor: + x: torch.Tensor = torch.sub(a, b) + y: torch.Tensor = torch.sub(x, c) + z: torch.Tensor = torch.sub(x, y) + w: torch.Tensor = torch.sub(z, c) + return w + + def get_random_inputs(self) -> Sequence[torch.Tensor]: + n = 10 # to create a large tensor + return (torch.ones(n, n, n), 2 * torch.ones(n, n, n), 3 * torch.ones(n, n, n)) + + # # Backends # @@ -95,30 +130,45 @@ def __init__(self, fn): def forward(self, *args, **kwargs): return self.fn(*args, **kwargs) - edge: exir.EdgeProgramManager = to_edge( - export(WrapperModule(getattr(eager_module, method)), args=inputs) + exported_program = export(WrapperModule(getattr(eager_module, method)), args=inputs) + + edge_config = EdgeCompileConfig(_check_ir_validity=False) + et_config = exir.ExecutorchBackendConfig( + extract_delegate_segments=extract_delegate_segments, + constant_tensor_alignment=constant_tensor_alignemnt, + delegate_alignment=delegate_alignment, ) - lowered_module = to_backend(backend_id, edge.exported_program(), compile_specs=[]) + if backend_id == "XnnpackBackend": + from executorch.backends.xnnpack.partition.xnnpack_partitioner import ( + XnnpackPartitioner, + ) - class CompositeModule(nn.Module): - def __init__(self): - 
super().__init__() - self.lowered_module = lowered_module + executorch_program = to_edge_transform_and_lower( + exported_program, + compile_config=edge_config, + partitioner=[XnnpackPartitioner()], + ).to_executorch(config=et_config) + else: + edge: exir.EdgeProgramManager = to_edge(exported_program) + lowered_module = to_backend( + backend_id, edge.exported_program(), compile_specs=[] + ) - def forward(self, *args, **kwargs): - return self.lowered_module(*args, **kwargs) + class CompositeModule(nn.Module): + def __init__(self): + super().__init__() + self.lowered_module = lowered_module - composite_module = CompositeModule() - composite_module(*inputs) + def forward(self, *args, **kwargs): + return self.lowered_module(*args, **kwargs) - executorch_program = to_edge(export(composite_module, args=inputs)).to_executorch( - config=exir.ExecutorchBackendConfig( - extract_delegate_segments=extract_delegate_segments, - constant_tensor_alignment=constant_tensor_alignemnt, - delegate_alignment=delegate_alignment, - ) - ) + composite_module = CompositeModule() + composite_module(*inputs) + + executorch_program = to_edge( + export(composite_module, args=inputs) + ).to_executorch(config=et_config) return executorch_program.buffer diff --git a/test/models/targets.bzl b/test/models/targets.bzl index aea47c9e03..f291a17c62 100644 --- a/test/models/targets.bzl +++ b/test/models/targets.bzl @@ -117,6 +117,8 @@ def define_common_targets(): par_style = "xar", deps = [ ":export_delegated_program_lib", + "//executorch/backends/xnnpack/partition:xnnpack_partitioner", + ], visibility = [], # Private ) @@ -124,6 +126,8 @@ def define_common_targets(): # Class names of nn.Modules for :exported_delegated_programs to export. DELEGATED_MODULES_TO_EXPORT = [ "ModuleAddMul", + "ModuleAddLarge", + "ModuleSubLarge", ] # Name of the backend to use when exporting delegated programs. @@ -153,3 +157,23 @@ def define_common_targets(): "//executorch/test/...", ], ) + + runtime.genrule( + name = "exported_xnnp_delegated_programs", + cmd = "$(exe :export_delegated_program)" + + " --modules " + ",".join(DELEGATED_MODULES_TO_EXPORT) + + " --backend_id " + "XnnpackBackend" + + " --outdir $OUT", + outs = { + fname + ".pte": [fname + ".pte"] + for fname in DELEGATED_MODULES_TO_EXPORT + }, + default_outs = ["."], + visibility = [ + "//executorch/runtime/executor/test/...", + "//executorch/backends/test/...", + "//executorch/test/...", + "@EXECUTORCH_CLIENTS", + ], + env = {"PYTORCH_DISABLE_JUSTKNOBS": "1",}, + )
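[editor note] For reference, the new exported_xnnp_delegated_programs genrule boils down to an invocation like the following; the flags come from the cmd string above, while the local path and output directory are illustrative:

python test/models/export_delegated_program.py \
  --modules ModuleAddLarge,ModuleSubLarge \
  --backend_id XnnpackBackend \
  --outdir /tmp/xnnpack_delegated_ptes

This should emit ModuleAddLarge.pte and ModuleSubLarge.pte lowered through the XnnpackPartitioner branch added to export_delegated_program.py in this patch.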