Support Empty Input Tensors and > 5 Cat Inputs #7855

Open · wants to merge 2 commits into main
15 changes: 1 addition & 14 deletions backends/xnnpack/_passes/TARGETS
@@ -4,20 +4,7 @@ oncall("executorch")

python_library(
name = "xnnpack_passes",
srcs = [
"__init__.py",
"channels_last_tagged_reshape_pass.py",
"conv1d_unsqueeze_pass.py",
"convert_to_linear.py",
"convert_to_sdpa.py",
"convert_to_upsample_bilinear2d.py",
"fuse_activation_pass.py",
"fuse_batch_norm_with_conv.py",
"prelu_reshape_pass.py",
"remove_getitem_op.py",
"tag_implicit_q_dq_pass.py",
"xnnpack_pass.py",
],
srcs = native.glob(["*.py"]),
deps = [
"//caffe2:torch",
"//executorch/backends/transforms:addmm_mm_to_linear",
2 changes: 2 additions & 0 deletions backends/xnnpack/_passes/__init__.py
@@ -17,6 +17,7 @@
from executorch.backends.xnnpack._passes.convert_to_upsample_bilinear2d import (
ConvertToUpsampleBilinear2d,
)
from executorch.backends.xnnpack._passes.decompose_cat import DecomposeConcatenate
from executorch.backends.xnnpack._passes.fuse_activation_pass import FuseActivationPass
from executorch.backends.xnnpack._passes.fuse_batch_norm_with_conv import (
FuseBatchNormWithConvPass,
@@ -63,6 +64,7 @@ def __init__(
ConstPropPass,
FuseBatchNormWithConvPass,
FuseActivationPass,
DecomposeConcatenate,
RemoveGetItemPass,
Conv1dUnsqueezePass,
PReLUReshapePass,
99 changes: 99 additions & 0 deletions backends/xnnpack/_passes/decompose_cat.py
@@ -0,0 +1,99 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import logging

import torch
from executorch.backends.xnnpack.utils.quant_utils import is_dequant, is_quant
from executorch.exir.dialects._ops import ops as exir_ops

from executorch.exir.pass_base import ExportPass, PassResult

logger = logging.getLogger(__name__)
logger.setLevel(logging.WARNING)


class DecomposeConcatenate(ExportPass):
"""
XNNPACK's Concatenate operation only supports concatenating up to 5 tensors
at a time. To support cat nodes with more than 5 inputs, we decompose them
into a sequence of cat nodes, each taking at most 5 inputs.

Example:
Before Pass:
cat: "f32" = torch.ops.aten.cat.default([t1, t2, t3, t4, t5, t6], 1);

After Pass:
cat: "f32" = torch.ops.aten.cat.default([t1, t2, t3, t4, t5], 1);
cat_1: "f32" = torch.ops.aten.cat.default([cat, t6], 1);
"""

def call(self, graph_module: torch.fx.GraphModule):
gm = graph_module
for node in gm.graph.nodes:
if (
node.op == "call_function"
and node.target.__name__ == "aten.cat.default"
):
concat_args = node.args
nodes_to_concat = node.args[0]
if len(nodes_to_concat) <= 5:
continue

# The cat is treated as quantized only when every input comes from a
# dequantize node and every user is a quantize node (the q/dq sandwich).
is_quantized = all(
is_dequant(n) for n in nodes_to_concat
) and all(is_quant(user) for user in node.users.keys())

# replace the cat args with the same args but only with the first 5 nodes
new_concat_args = (nodes_to_concat[:5],) + concat_args[1:]
node.args = new_concat_args

remainder_nodes_to_concat = nodes_to_concat[5:]
with gm.graph.inserting_after(node):
logger.debug(f"Decomposing cat node {node}")
remainder_concat_node = gm.graph.create_node(
"call_function",
target=exir_ops.edge.aten.cat.default,
args=([],),  # placeholder; replaced with the remainder inputs below
kwargs=node.kwargs,
)
node.replace_all_uses_with(remainder_concat_node)
if is_quantized:
# if quantized we need to enforce the q/dq pattern for the newly inserted
# concat node
q_params = nodes_to_concat[0].args[1:]
q_kwargs = nodes_to_concat[0].kwargs
# The quantizer enforces that all inputs and the output of a concat node
# share the same qparams, so the newly inserted q/dq pair must reuse the
# qparams of the first quantized input to the cat.
with gm.graph.inserting_after(node):
logger.debug(
f"Inserting Q/DQ pair for new cat node {remainder_concat_node}"
)
q_node = gm.graph.create_node(
"call_function",
target=exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
args=(node,) + q_params,
kwargs=q_kwargs,
)
with gm.graph.inserting_after(q_node):
dq_node = gm.graph.create_node(
"call_function",
target=exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default,
args=(q_node,) + q_params,
kwargs=q_kwargs,
)
remainder_concat_node.args = (
[dq_node] + remainder_nodes_to_concat,
) + node.args[1:]
else:
remainder_concat_node.args = (
[node] + remainder_nodes_to_concat,
) + node.args[1:]

gm.recompile()
new_gm = super().call(gm).graph_module
return PassResult(new_gm, True)
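
Note: because each intermediate cat is fed back in as an input, every cat after the first can absorb at most 4 new tensors. A minimal standalone sketch of the resulting node count, assuming the chaining behavior above; the expected_num_cats helper is illustrative and not part of the PR (it mirrors the num_cats arithmetic in the pass tests below):

import math

def expected_num_cats(num_inputs: int) -> int:
    # A single cat suffices for up to 5 inputs.
    if num_inputs <= 5:
        return 1
    # The first cat consumes 5 inputs; each subsequent cat consumes the
    # previous cat's output plus at most 4 new inputs.
    return 1 + math.ceil((num_inputs - 5) / 4)

assert expected_num_cats(6) == 2   # [t1..t5] -> cat; [cat, t6] -> cat_1
assert expected_num_cats(10) == 3
assert expected_num_cats(18) == 5
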
4 changes: 2 additions & 2 deletions backends/xnnpack/partition/config/generic_node_configs.py
@@ -181,10 +181,10 @@ def check_constraints(self, node: torch.fx.Node, ep: ExportedProgram) -> bool:

num_tensors = len(node.all_input_nodes)

if not (num_tensors >= 2 and num_tensors <= 5):
if not (num_tensors >= 2):
why(
node,
reason=f"only support concatenation of 2 - 5 tensors, got {num_tensors} tensors",
reason=f"only support concatenation of > 2 tensors, got {num_tensors} tensors",
)
return False

79 changes: 44 additions & 35 deletions backends/xnnpack/test/ops/test_cat.py
@@ -14,9 +14,13 @@

class TestCat(unittest.TestCase):
class Cat(torch.nn.Module):
def __init__(self, dim=0):
super().__init__()
self.dim = dim

def forward(self, *args):
xs = [*args]
x = torch.cat(xs)
x = torch.cat(xs, dim=self.dim)
return x + x # Quantize by propagation.

def _test_cat(self, module, inputs, cat_num=1, quant=False, quant_ops=2):
@@ -27,7 +31,6 @@ def _test_cat(self, module, inputs, cat_num=1, quant=False, quant_ops=2):
tester.quantize()

tester.export().check_count({"torch.ops.aten.cat": 1})
tester.dump_artifact()

if quant:
# Expect multiple quantize ops - one per input, cat, and add.
@@ -93,6 +96,29 @@ def test_fp16_cat4(self):
)
self._test_cat(self.Cat(), inputs)

def test_fp16_cat5(self):
"""
Using Clamp2 because fp16 add is done in fp32 ATM. Need to fix that first.
"""
inputs = (
torch.randn(1, 2, 3).to(torch.float16),
torch.randn(3, 2, 3).to(torch.float16),
torch.randn(2, 2, 3).to(torch.float16),
torch.randn(5, 2, 3).to(torch.float16),
torch.randn(5, 2, 3).to(torch.float16),
)
self._test_cat(self.Cat(), inputs)

def test_fp16_cat_gt_5(self):
"""
Using Clamp2 because fp16 add is done in fp32 ATM. Need to fix that first.
"""
for num_inputs in range(6, 10):
inputs = []
for _ in range(num_inputs):
inputs.append(torch.randn(1, 2, 3).to(torch.float16))
self._test_cat(self.Cat(), tuple(inputs))

def test_fp32_cat2(self):
inputs = (torch.randn(1, 2, 3), torch.randn(3, 2, 3))
self._test_cat(self.Cat(), inputs)
@@ -120,6 +146,13 @@ def test_fp32_cat5(self):
)
self._test_cat(self.Cat(), inputs)

def test_fp32_cat_gt_5(self):
for num_inputs in range(6, 10):
inputs = []
for _ in range(num_inputs):
inputs.append(torch.randn(1, 2, 3))
self._test_cat(self.Cat(), tuple(inputs))

def test_qs8_cat2(self):
inputs = (torch.randn(1, 2, 3), torch.randn(3, 2, 3))
self._test_cat(self.Cat(), inputs, cat_num=2, quant=True)
@@ -137,46 +170,22 @@ def test_qs8_cat4(self):
)
self._test_cat(self.Cat(), inputs, cat_num=4, quant=True)

def test_fp32_cat_unsupported(self):
"""
XNNPACK only supports concatenating up to 4 values, so it should not delegate here.
"""
inputs = (
torch.randn(1, 2, 3),
torch.randn(3, 2, 3),
torch.randn(2, 2, 3),
torch.randn(5, 2, 3),
torch.randn(1, 2, 3),
torch.randn(2, 2, 3),
)
(
Tester(self.Cat(), inputs)
.export()
.check_count({"torch.ops.aten.cat": 1})
.to_edge_transform_and_lower()
.check_count({"executorch_exir_dialects_edge__ops_aten_cat": 1})
)

def test_fp32_cat_unsupported_legacy_mode(self):
"""
XNNPACK only supports concatenating up to 5 values, so it should not delegate here.
"""
inputs = (
torch.randn(1, 2, 3),
torch.randn(3, 2, 3),
torch.randn(2, 2, 3),
torch.randn(5, 2, 3),
torch.randn(1, 2, 3),
torch.randn(6, 2, 3),
)
(
Tester(self.Cat(), inputs)
.export()
.check_count({"torch.ops.aten.cat": 1})
.to_edge()
.partition()
.check_count({"executorch_exir_dialects_edge__ops_aten_cat": 1})
)

def test_qs8_cat5(self):
inputs = (
torch.randn(1, 2, 3),
torch.randn(3, 2, 3),
torch.randn(2, 2, 3),
torch.randn(5, 2, 3),
torch.randn(1, 2, 3),
)
self._test_cat(self.Cat(), inputs, cat_num=5, quant=True)

def test_qs8_cat_gt_5(self):
for num_inputs in range(6, 10):
inputs = []
for _ in range(num_inputs):
inputs.append(torch.randn(1, 2, 3))
self._test_cat(self.Cat(), tuple(inputs), cat_num=num_inputs, quant=True)

class CatNegativeDim(torch.nn.Module):
def __init__(self):
108 changes: 108 additions & 0 deletions backends/xnnpack/test/passes/test_decompose_cat_pass.py
@@ -0,0 +1,108 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import math
import unittest

import torch
from executorch.backends.xnnpack._passes.decompose_cat import DecomposeConcatenate
from executorch.backends.xnnpack.test.tester import RunPasses, Tester


class TestDecomposeCatPass(unittest.TestCase):
PassStage = RunPasses([DecomposeConcatenate])
cat_name = "executorch_exir_dialects_edge__ops_aten_cat_default"

class Cat(torch.nn.Module):
def forward(self, *args):
xs = [*args]
x = torch.cat(xs)
return x + x # Quantize by propagation.

def test_cat_gt_5(self):
inputs = [
torch.randn(1, 2, 3),
]
for num_inputs in range(6, 10):
inputs = []
for _ in range(num_inputs):
inputs.append(torch.randn(1, 2, 3))

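# Expected cat count after decomposition: the first cat consumes 5 inputs,
# and each subsequent cat consumes the previous result plus at most 4 more.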
num_cats = int(len(inputs) > 5)
num_cats += math.ceil((len(inputs) - 5) / 4)
(
Tester(self.Cat(), tuple(inputs))
.export()
.to_edge()
.check_count({self.cat_name: 1})
.run_passes(self.PassStage)
.check_count({self.cat_name: num_cats})
.run_method_and_compare_outputs()
)

def test_cat_gt_10(self):
inputs = [
torch.randn(1, 2, 3),
]
for num_inputs in [11, 16, 18]:
inputs = []
for _ in range(num_inputs):
inputs.append(torch.randn(1, 2, 3))

num_cats = int(len(inputs) > 5)
num_cats += math.ceil((len(inputs) - 5) / 4)
(
Tester(self.Cat(), tuple(inputs))
.export()
.to_edge()
.check_count({self.cat_name: 1})
.run_passes(self.PassStage)
.check_count({self.cat_name: num_cats})
.run_method_and_compare_outputs()
)

def test_qs8_cat_gt_5(self):
inputs = [
torch.randn(1, 2, 3),
]
for num_inputs in range(6, 10):
inputs = []
for _ in range(num_inputs):
inputs.append(torch.randn(1, 2, 3))

num_cats = int(len(inputs) > 5)
num_cats += math.ceil((len(inputs) - 5) / 4)
(
Tester(self.Cat(), tuple(inputs))
.quantize()
.export()
.to_edge()
.check_count({self.cat_name: 1})
.run_passes(self.PassStage)
.check_count({self.cat_name: num_cats})
.run_method_and_compare_outputs()
)

def test_qs8_cat_gt_10(self):
inputs = [
torch.randn(1, 2, 3),
]
for num_inputs in [11, 16, 18]:
inputs = []
for _ in range(num_inputs):
inputs.append(torch.randn(1, 2, 3))

num_cats = int(len(inputs) > 5)
num_cats += math.ceil((len(inputs) - 5) / 4)
(
Tester(self.Cat(), tuple(inputs))
.quantize()
.export()
.to_edge()
.check_count({self.cat_name: 1})
.run_passes(self.PassStage)
.check_count({self.cat_name: num_cats})
.run_method_and_compare_outputs()
)