
Commit 5e4d6b6 (parent: ed82561)

Arm: Add op_constant_pad_nd visitor (#8464)

add op_constant_pad_nd visitor

6 files changed, +225 -0 lines changed

backends/arm/operator_support/tosa_supported_operators.py (+1)

@@ -138,6 +138,7 @@ def is_node_supported(self, submodules, node: fx.Node) -> bool:
             operator.getitem,
             exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
             exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default,
+            exir_ops.edge.aten.constant_pad_nd.default,
         ]

         return supported
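With constant_pad_nd in the supported-operator list, an eager-mode pad like the sketch below becomes eligible for partitioning to the Arm delegate once lowered to exir_ops.edge.aten.constant_pad_nd.default. The PadOnly module and its values are illustrative, not part of the commit:

import torch
import torch.nn.functional as F

class PadOnly(torch.nn.Module):
    # F.pad with mode="constant" lowers to aten.constant_pad_nd after export.
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return F.pad(x, pad=(1, 2), mode="constant", value=0.5)

print(PadOnly()(torch.rand(1, 3, 8, 8)).shape)  # torch.Size([1, 3, 8, 11])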

backends/arm/operators/__init__.py (+1)

@@ -12,6 +12,7 @@
     op_bmm,
     op_cat,
     op_clamp,
+    op_constant_pad_nd,
     op_conv2d,
     op_eq,
     op_exp,
backends/arm/operators/op_constant_pad_nd.py (new file, +74)

# Copyright 2025 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe

from typing import List

import serializer.tosa_serializer as ts
import torch

from executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass import (
    get_input_qparams,
)
from executorch.backends.arm.operators.node_visitor import (
    NodeVisitor,
    register_node_visitor,
)
from executorch.backends.arm.tosa_mapping import TosaArg
from serializer.tosa_serializer import TosaOp


@register_node_visitor
class ConstantPadNDVisitor(NodeVisitor):

    target = "aten.constant_pad_nd.default"

    def define_node(
        self,
        node: torch.fx.Node,
        tosa_graph: ts.TosaSerializer,
        inputs: List[TosaArg],
        output: TosaArg,
    ) -> None:

        if inputs[0].dtype == ts.DType.INT8:
            input_qparams = get_input_qparams(node)
            qargs = input_qparams[0]
            pad_const_qs = qargs.quantize_value(inputs[2].number).item()
            pad_const_fp = 0.0
        else:
            pad_const_fp = inputs[2].number
            pad_const_qs = 0

        rank = len(output.shape)
        # Each dim needs 2 padding values. For example, to pad the last dimension, the pad has
        # the form (padding_left, padding_right); to pad the last two dimensions, the pad has
        # the form (padding_left, padding_right, padding_top, padding_bottom), and so on. For
        # PyTorch NCHW format, the padding values are in reverse dimension order, so first
        # reverse the input padding parameters.
        input_pad = sum(
            [
                [inputs[1].special[i], inputs[1].special[i + 1]]
                for i in range(0, len(inputs[1].special), 2)
            ][::-1],
            [],
        )
        # Then, add dummy zeros so that input_pad and output_pad have the same size.
        input_pad = [0] * (rank * 2 - len(inputs[1].special)) + input_pad
        # For PyTorch NCHW format, dim order is [0, ..., rank-1]
        input_dim_order = list(range(rank))
        output_pad = [0] * rank * 2

        # Map input padding parameters into output padding parameters. TOSA is NHWC format.
        for input_dim_idx, input_dim in enumerate(input_dim_order):
            output_dim_idx = output.dim_order.index(input_dim)
            output_pad[output_dim_idx * 2 : (output_dim_idx + 1) * 2] = input_pad[
                input_dim_idx * 2 : (input_dim_idx + 1) * 2
            ]

        attr = ts.TosaSerializerAttribute()
        attr.PadAttribute(tosa_graph.builder, output_pad, pad_const_qs, pad_const_fp)

        tosa_graph.addOperator(TosaOp.Op().PAD, [inputs[0].name], [output.name], attr)
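To make the reordering concrete: PyTorch's pad tuple lists (left, right) pairs starting from the last dimension, while TOSA wants one (before, after) pair per dimension in the output's dim order. A minimal, self-contained sketch of the same logic on plain Python lists (the reorder_pad helper and the example values are illustrative, not part of the commit):

def reorder_pad(pad, rank, dim_order):
    # Group into (left, right) pairs and reverse to get dimension order 0..rank-1.
    pairs = [[pad[i], pad[i + 1]] for i in range(0, len(pad), 2)][::-1]
    input_pad = sum(pairs, [])
    # Prepend zeros for leading dimensions the pad tuple did not cover.
    input_pad = [0] * (rank * 2 - len(pad)) + input_pad
    # Scatter each dimension's pair to its position in the output dim order.
    output_pad = [0] * rank * 2
    for input_dim in range(rank):
        out_idx = dim_order.index(input_dim)
        output_pad[out_idx * 2 : (out_idx + 1) * 2] = input_pad[
            input_dim * 2 : (input_dim + 1) * 2
        ]
    return output_pad

# Pad the last two dims of an NCHW tensor: W by (1, 0), H by (2, 3),
# remapped to an NHWC dim order of (0, 2, 3, 1).
print(reorder_pad((1, 0, 2, 3), rank=4, dim_order=(0, 2, 3, 1)))
# -> [0, 0, 2, 3, 1, 0, 0, 0]  i.e. N (0, 0), H (2, 3), W (1, 0), C (0, 0)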

backends/arm/quantizer/quantization_annotator.py (+4)

@@ -172,6 +172,7 @@ def _match_pattern(
     torch.ops.aten.chunk.default,
     torch.ops.aten.contiguous.default,
     torch.ops.aten.upsample_nearest2d.vec,
+    torch.ops.aten.pad.default,
 ]

 # Operators that can inherit the quantization specs from its parent node
@@ -216,6 +217,7 @@ def any_or_hardtanh_min_zero(n: Node):
             torch.ops.aten.conv1d.default,
             torch.ops.aten.conv2d.default,
             torch.ops.aten.linear.default,
+            torch.ops.aten.conv2d.padding,
         ],
         [torch.ops.aten.relu.default, torch.ops.aten.hardtanh.default],
     ],
@@ -225,6 +227,7 @@ def any_or_hardtanh_min_zero(n: Node):
         torch.ops.aten.conv1d.default,
         torch.ops.aten.conv2d.default,
         torch.ops.aten.linear.default,
+        torch.ops.aten.conv2d.padding,
     ):
         quant_properties.quant_inputs = [
             _QuantProperty(0, input_act_qspec),
@@ -237,6 +240,7 @@ def any_or_hardtanh_min_zero(n: Node):
         torch.ops.aten.conv1d.default,
         torch.ops.aten.conv2d.default,
         torch.ops.aten.linear.default,
+        torch.ops.aten.conv2d.padding,
     ):
         quant_properties.quant_inputs = [
             _QuantProperty(0, input_act_qspec),
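The conv2d.padding entries cover convolutions constructed with string padding (e.g. nn.Conv2d(..., padding="same"), as the test below uses), which trace to the aten.conv2d.padding overload rather than aten.conv2d.default. A quick illustrative call of that overload with made-up shapes:

import torch

x = torch.rand(1, 3, 8, 8)
w = torch.rand(3, 3, 3, 3)
# conv2d.padding(input, weight, bias, stride, padding: str, dilation, groups)
y = torch.ops.aten.conv2d.padding(x, w, None, [1, 1], "same", [1, 1], 1)
print(y.shape)  # torch.Size([1, 3, 8, 8]) -- "same" preserves spatial size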

backends/arm/quantizer/quantization_config.py (+1)

@@ -78,6 +78,7 @@ def _derive_qparams_fn(
             torch.ops.aten.conv1d.default,
             torch.ops.aten.conv2d.default,
             torch.ops.aten.linear.default,
+            torch.ops.aten.conv2d.padding,
         ]:
             input_act = node.args[0]
             weight = node.args[1]
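For context, _derive_qparams_fn derives the bias quantization parameters from the op's input activation and weight observers; listing conv2d.padding extends that to string-padded convolutions. A numeric sketch of the usual PT2E derivation rule, assuming (as is standard, though not shown in this diff) bias_scale = act_scale * weight_scale with a zero-point of 0:

# Illustrative scales, not values from the commit:
act_scale, weight_scale = 0.02, 0.001
bias_scale, bias_zero_point = act_scale * weight_scale, 0
print(bias_scale, bias_zero_point)  # ~2e-05 0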
backends/arm/test/ops/test_constant_pad_nd.py (new file, +144)

# Copyright 2025 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

#
# Test the constant_pad_nd op which pads the input tensor at specific dimension(s).
#
import unittest
from typing import Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
from executorch.backends.arm.test import common
from executorch.backends.arm.test.tester.arm_tester import ArmTester
from parameterized import parameterized

test_data_suite = [
    ("4dim_last1dim", torch.rand(1, 1, 16, 16), (1, 1, 0, 0, 0, 0, 0, 0), 1),
    ("4dim_last2dim", torch.rand(1, 1, 16, 16), (1, 0, 1, 0, 0, 0, 0, 0), 2),
    ("4dim_last3dim", torch.rand(1, 1, 16, 16), (1, 1, 0, 2, 0, 2, 0, 0), 3),
    ("4dim_last4dim", torch.rand(1, 1, 16, 16), (1, 0, 1, 1, 0, 2, 0, 2), 4),
    ("3dim_last1dim", torch.rand(1, 1, 16), (1, 1, 0, 0, 0, 0), 1),
    ("3dim_last2dim", torch.rand(1, 1, 16), (1, 0, 1, 1, 0, 0), 2),
    ("3dim_last3dim", torch.rand(1, 1, 16), (1, 0, 1, 0, 1, 1), 3),
    ("2dim_last1dim", torch.rand(1, 1, 16), (1, 1, 0, 0), 1),
    ("2dim_last2dim", torch.rand(1, 1, 16), (1, 0, 1, 1), 2),
]


class TestConstantPadND(unittest.TestCase):
    """Tests pad."""

    class ConstantPadND(torch.nn.Module):
        def __init__(self, pad: Tuple, value: float | None = None):
            super().__init__()
            self.dim = len(pad) // 2
            self.value = value
            in_channels = 1
            # Only apply conv2d when the input dim == 4.
            if self.dim == 4:
                in_channels += pad[-3] + pad[-4]

                self.conv2d = nn.Conv2d(
                    in_channels=in_channels,
                    out_channels=3,
                    kernel_size=3,
                    bias=True,
                    stride=(2, 2),
                    padding=0,
                )

                in_channels = 3
                in_channels += pad[-3] + pad[-4]
                self.conv2d_1 = nn.Conv2d(
                    in_channels=in_channels,
                    out_channels=3,
                    kernel_size=3,
                    bias=True,
                    padding="same",
                )

            nonzero_idx = len(pad)
            for i in range(0, len(pad), 2):
                if pad[i] + pad[i + 1] == 0:
                    nonzero_idx = i
                    break
            self.pad = pad[:nonzero_idx]
            self.relu = nn.ReLU()
            self.sigmoid = nn.Sigmoid()

        def forward(self, x: torch.Tensor):
            x = F.pad(x, pad=self.pad, mode="constant", value=self.value)
            if self.dim == 4:
                x = self.conv2d(x)
            x = self.relu(x)

            x = F.pad(x, pad=self.pad, mode="constant", value=self.value)
            if self.dim == 4:
                x = self.conv2d_1(x)
            x = self.sigmoid(x)
            return x

    def _test_constant_pad_nd_tosa_MI_pipeline(
        self, module: torch.nn.Module, test_data: Tuple[torch.Tensor]
    ):
        (
            ArmTester(
                module,
                example_inputs=test_data,
                compile_spec=common.get_tosa_compile_spec("TOSA-0.80+MI"),
            )
            .export()
            .check_count({"torch.ops.aten.pad.default": 2})
            .to_edge()
            .partition()
            .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
            .to_executorch()
            .run_method_and_compare_outputs(inputs=test_data)
        )

    def _test_constant_pad_nd_tosa_BI_pipeline(
        self, module: torch.nn.Module, test_data: Tuple[torch.Tensor]
    ):
        (
            ArmTester(
                module,
                example_inputs=test_data,
                compile_spec=common.get_tosa_compile_spec("TOSA-0.80+BI"),
            )
            .quantize()
            .export()
            .check_count({"torch.ops.aten.pad.default": 2})
            .to_edge()
            .partition()
            .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
            .to_executorch()
            .run_method_and_compare_outputs(inputs=test_data, qtol=1)
        )

    @parameterized.expand(test_data_suite)
    def test_constant_pad_nd_tosa_MI(
        self,
        test_name: str,
        test_data: torch.Tensor,
        padding: Tuple,
        value: float | None = None,
    ):
        self._test_constant_pad_nd_tosa_MI_pipeline(
            self.ConstantPadND(padding, value), (test_data,)
        )

    @parameterized.expand(test_data_suite)
    def test_constant_pad_nd_tosa_BI(
        self,
        test_name: str,
        test_data: torch.Tensor,
        padding: Tuple,
        value: float | None = None,
    ):
        self._test_constant_pad_nd_tosa_BI_pipeline(
            self.ConstantPadND(padding, value), (test_data,)
        )
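One detail of the test module worth highlighting: __init__ truncates the pad tuple at the first all-zero (left, right) pair, so F.pad only receives the dimensions that are actually padded. A standalone illustration of that loop, using one of the suite's own entries:

# Trailing zero pairs are dropped (the "4dim_last3dim" case from test_data_suite):
pad = (1, 1, 0, 2, 0, 2, 0, 0)
nonzero_idx = len(pad)
for i in range(0, len(pad), 2):
    if pad[i] + pad[i + 1] == 0:
        nonzero_idx = i
        break
print(pad[:nonzero_idx])  # (1, 1, 0, 2, 0, 2)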
