
Commit 75d4abc

Enable quantized add
Differential Revision: D69441041
Pull Request resolved: #8584
1 parent 3e188fe

5 files changed: +146 −3 lines


backends/cadence/aot/ops_registrations.py (+44)
@@ -99,6 +99,10 @@
     "quantized_add(Tensor X, Tensor X_scale, Tensor X_zero_point, Tensor Y, Tensor Y_scale, "
     "Tensor Y_zero_point, float out_scale, int out_zero_point) -> (Tensor Z)"
 )
+lib.define(
+    "quantized_add.per_tensor(Tensor X, float X_scale, int X_zero_point, Tensor Y, float Y_scale, "
+    "int Y_zero_point, float out_scale, int out_zero_point) -> (Tensor Z)"
+)
 lib.define(
     "quantized_mul(Tensor X, Tensor X_scale, Tensor X_zero_point, Tensor Y, Tensor Y_scale, "
     "Tensor Y_zero_point, float out_scale, int out_zero_point) -> (Tensor Z)"
@@ -175,6 +179,10 @@
     "quantized_add.out(Tensor X, Tensor X_scale, Tensor X_zero_point, Tensor Y, Tensor Y_scale, "
     "Tensor Y_zero_point, float out_scale, int out_zero_point, *, Tensor(a!) out) -> Tensor(a!)"
 )
+lib.define(
+    "quantized_add.per_tensor_out(Tensor X, float X_scale, int X_zero_point, Tensor Y, float Y_scale, "
+    "int Y_zero_point, float out_scale, int out_zero_point, *, Tensor(a!) out) -> Tensor(a!)"
+)
 lib.define(
     "quantized_mul.out(Tensor X, Tensor X_scale, Tensor X_zero_point, Tensor Y, Tensor Y_scale, "
     "Tensor Y_zero_point, float out_scale, int out_zero_point, *, Tensor(a!) out) -> Tensor(a!)"
@@ -290,6 +298,42 @@ def dequantize_per_tensor_meta(
     return input.new_empty(input.size(), dtype=torch.float)
 
 
+@register_fake("cadence::quantized_add")
+def quantized_add_meta(
+    X: torch.Tensor,
+    X_scale: torch.Tensor,
+    X_zero_point: torch.Tensor,
+    Y: torch.Tensor,
+    Y_scale: torch.Tensor,
+    Y_zero_point: torch.Tensor,
+    out_scale: float,
+    out_zero_point: int,
+) -> torch.Tensor:
+    out_size = X.size()
+    if list(X.size()) == [1]:
+        out_size = Y.size()
+
+    return X.new_empty(out_size, dtype=X.dtype)
+
+
+@register_fake("cadence::quantized_add.per_tensor")
+def quantized_add_per_tensor_meta(
+    X: torch.Tensor,
+    X_scale: float,
+    X_zero_point: int,
+    Y: torch.Tensor,
+    Y_scale: float,
+    Y_zero_point: int,
+    out_scale: float,
+    out_zero_point: int,
+) -> torch.Tensor:
+    out_size = X.size()
+    if list(X.size()) == [1]:
+        out_size = Y.size()
+
+    return X.new_empty(out_size, dtype=X.dtype)
+
+
 @register_fake("cadence::quantized_linear")
 def quantized_linear_meta(
     src: torch.Tensor,
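
Note: the two meta kernels above only compute output shapes (broadcasting a shape-[1] X up to Y's shape); the numerics of the `.per_tensor` overload follow the usual dequantize-add-requantize recipe. A minimal reference sketch in plain PyTorch (a hypothetical helper, not the Cadence kernel):

    import torch

    def quantized_add_per_tensor_ref(
        X: torch.Tensor, X_scale: float, X_zero_point: int,
        Y: torch.Tensor, Y_scale: float, Y_zero_point: int,
        out_scale: float, out_zero_point: int,
    ) -> torch.Tensor:
        # Dequantize both operands to float.
        x_fp = (X.to(torch.float) - X_zero_point) * X_scale
        y_fp = (Y.to(torch.float) - Y_zero_point) * Y_scale
        # Add in float, then requantize with the output qparams.
        z = torch.round((x_fp + y_fp) / out_scale) + out_zero_point
        info = torch.iinfo(X.dtype)
        return z.clamp(info.min, info.max).to(X.dtype)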

backends/cadence/aot/quantizer/fusion_pass.py (+51 −2)
@@ -11,6 +11,7 @@
 import torch
 from executorch.backends.cadence.aot.quantizer.patterns import (
     AddmmPattern,
+    AddPattern,
     BmmPattern,
     Conv1dPattern,
     Conv2dPattern,
@@ -41,6 +42,47 @@
 ReluPatterns = (ReluPattern0, ReluPattern1)
 
 
+def get_args_and_kwargs_add(
+    graph_module: GraphModule,
+    inputs_inputs: List[fx.Node],
+    dequants_inputs: List[fx.Node],
+    quant_node: fx.Node,
+) -> Tuple[Tuple[ArgsType, ...], Dict[str, ArgsType]]:
+    X_scale_ = graph_module.graph.call_function(
+        torch.ops.aten.full.default,
+        ([1], dequants_inputs[0].args[1]),
+        {"dtype": torch.float},
+    )
+    X_zero_point_ = graph_module.graph.call_function(
+        torch.ops.aten.full.default,
+        ([1], dequants_inputs[0].args[2]),
+        {"dtype": torch.int32},
+    )
+    Y_scale_ = graph_module.graph.call_function(
+        torch.ops.aten.full.default,
+        ([1], dequants_inputs[1].args[1]),
+        {"dtype": torch.float},
+    )
+    Y_zero_point_ = graph_module.graph.call_function(
+        torch.ops.aten.full.default,
+        ([1], dequants_inputs[1].args[2]),
+        {"dtype": torch.int32},
+    )
+    args = (
+        inputs_inputs[0],
+        X_scale_,
+        X_zero_point_,
+        inputs_inputs[1],
+        Y_scale_,
+        Y_zero_point_,
+        quant_node.args[1],
+        quant_node.args[2],
+    )
+
+    kwargs = {}
+    return args, kwargs
+
+
 # Helper function to get the args and kwargs for the linear replacement op
 def get_args_and_kwargs_linear(
     graph_module: GraphModule,
@@ -339,7 +381,7 @@ def call(self, graph_module: fx.GraphModule) -> PassResult:  # noqa: C901
         )
         for fused_partition in fused_partitions:
             anchors = pattern.get_anchors(graph_module, fused_partition)
-            if not anchors:
+            if not anchors or anchors.empty:
                 continue
             if any(self.is_fused(p.nodes) for p in fused_partition):
                 continue
@@ -385,7 +427,14 @@ def call(self, graph_module: fx.GraphModule) -> PassResult:  # noqa: C901
                 inputs_inputs + weights_inputs + other_inputs + bias_inputs
             )
             kwargs = {}
-            if isinstance(pattern, (Conv1dPattern, Conv2dPattern)):
+            if isinstance(pattern, AddPattern):
+                args, kwargs = get_args_and_kwargs_add(
+                    graph_module,
+                    inputs_inputs,
+                    dequants_inputs,
+                    quant_node,
+                )
+            elif isinstance(pattern, (Conv1dPattern, Conv2dPattern)):
                 args, kwargs = get_args_and_kwargs_conv(
                     graph_module,
                     inputs_inputs,
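
Net effect of the new branch: a dequantize → add → quantize chain collapses into a single op, with the scale/zero-point values read off the dequant nodes (`args[1]`/`args[2]`) and materialized as shape-[1] `aten.full` tensors, since the base `quantized_add` overload takes Tensor qparams. A hand-written before/after sketch (op names abbreviated):

    # Before fusion (typical quantized graph):
    #   x_fp = dequantize_per_tensor(x_q, x_scale, x_zp, ...)
    #   y_fp = dequantize_per_tensor(y_q, y_scale, y_zp, ...)
    #   z_fp = aten.add.Tensor(x_fp, y_fp)
    #   z_q  = quantize_per_tensor(z_fp, out_scale, out_zp, ...)
    #
    # After fusion:
    #   z_q = cadence.quantized_add(
    #       x_q, aten.full([1], x_scale), aten.full([1], x_zp),
    #       y_q, aten.full([1], y_scale), aten.full([1], y_zp),
    #       out_scale, out_zp,
    #   )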

backends/cadence/aot/quantizer/patterns.py (+33)
@@ -43,6 +43,7 @@ class PartitionAnchors:
     output: List[Union[Tuple[fx.Node], Tuple[fx.Node, SharedQuantizationSpec]]] = field(
         default_factory=list
     )
+    empty: bool = False
 
 
 class QuantizationPattern(ABC):
@@ -101,6 +102,38 @@ def replacement_op(self) -> OpOverload:
         return torch.ops.cadence.quantized_linear
 
 
+class AddPattern(QuantizationPattern):
+    def partition_types(self) -> List[OpOverload]:
+        return [torch.ops.aten.add.Tensor]
+
+    def get_anchors(
+        self, gm: fx.GraphModule, fused_partition: List[fx.GraphModule]
+    ) -> PartitionAnchors:
+        # pyre-fixme[29]: `Union[BoundMethod[typing.Callable(torch._C.TensorBase.__ge...
+        add_node = fused_partition[0].nodes[-1]
+
+        # Bail if:
+        #   - the add node is not a tensor add
+        #   - the add node has kwargs (e.g. alpha)
+        is_tensor_add = isinstance(add_node.args[0], fx.Node) and isinstance(
+            add_node.args[1], fx.Node
+        )
+        if not is_tensor_add or len(add_node.kwargs) > 0:
+            return PartitionAnchors(
+                empty=True,
+            )
+
+        return PartitionAnchors(
+            inputs=[(add_node, 0), (add_node, 1)],
+            weights=[],
+            biases=[],
+            output=[(add_node,)],
+        )
+
+    def replacement_op(self) -> OpOverload:
+        return torch.ops.cadence.quantized_add.default
+
+
 class BmmPattern(QuantizationPattern):
     def partition_types(self) -> List[OpOverload]:
         return [torch.ops.aten.bmm.default]
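
`AddPattern` therefore only claims two-tensor adds with no kwargs; scalar adds and `alpha`-scaled adds return an `empty` anchor set and stay in float. Illustrative cases (not from the diff):

    z = x + y                      # aten.add.Tensor(x, y): anchored and fused
    z = x + 1.0                    # args[1] is a scalar, not an fx.Node: skipped
    z = torch.add(x, y, alpha=2)   # kwargs present: skipped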

backends/cadence/aot/quantizer/quantizer.py (+14 −1)
@@ -12,6 +12,7 @@
 import torch
 from executorch.backends.cadence.aot.quantizer.patterns import (
     AddmmPattern,
+    AddPattern,
     BmmPattern,
     Conv1dPattern,
     Conv2dPattern,
@@ -109,7 +110,7 @@ def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule:
                 continue
 
             anchors = self.pattern.get_anchors(model, fused_partition)
-            if not anchors:
+            if not anchors or anchors.empty:
                 continue
             if is_annotated(
                 [
@@ -211,3 +212,15 @@ def __init__(
         self,
     ) -> None:
         super().__init__([])
+
+
+class CadenceWakeWordQuantizer(CadenceQuantizer):
+    """
+    Quantizer for WakeWord, including add
+    """
+
+    def __init__(self, quantizers: Optional[list[Quantizer]] = None) -> None:
+        if quantizers is None:
+            quantizers = get_cadence_default_quantizers()
+        quantizers.append(CadenceAtenQuantizer(AddPattern(), qconfig_A8uW8u))
+        super().__init__(quantizers)
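
Usage mirrors the other `CadenceQuantizer` subclasses in the standard PT2E flow; a minimal sketch (`model` and `example_inputs` are placeholders, and the graph-capture entry point varies by PyTorch version):

    import torch
    from executorch.backends.cadence.aot.quantizer.quantizer import (
        CadenceWakeWordQuantizer,
    )
    from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e

    quantizer = CadenceWakeWordQuantizer()
    gm = torch.export.export(model, example_inputs).module()
    prepared = prepare_pt2e(gm, quantizer)
    prepared(*example_inputs)           # calibration pass
    converted = convert_pt2e(prepared)  # add ops quantized alongside the defaults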

backends/cadence/aot/replace_ops.py (+4)
@@ -1839,6 +1839,10 @@ class ReplaceSingleElementTensorArgumentsFromFullOpWithScalarPass(ExportPass):
     replaced_scalar_args: dict[
         EdgeOpOverloadPacket, tuple[EdgeOpOverload, Sequence[int]]
     ] = {
+        exir_ops.edge.cadence.quantized_add: (
+            exir_ops.edge.cadence.quantized_add.per_tensor,
+            [1, 2, 4, 5],
+        ),
         exir_ops.edge.cadence.quantized_conv: (
             exir_ops.edge.cadence.quantized_conv.per_tensor,
             [8, 9, 12, 13],
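
The index list `[1, 2, 4, 5]` picks out `quantized_add`'s tensor qparam arguments; when each is a single-element `aten.full`, this pass swaps in the fill value and retargets the call to the `.per_tensor` overload. Argument positions per the schema above (hand-annotated):

    # quantized_add(X, X_scale, X_zero_point, Y, Y_scale, Y_zero_point,
    #               out_scale, out_zero_point)
    # index:        0  1        2             3  4        5
    # Replacing indices 1, 2, 4, 5 with scalars yields:
    # quantized_add.per_tensor(X, x_scale, x_zp, Y, y_scale, y_zp,
    #                          out_scale, out_zero_point)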
