|
11 | 11 | import torch
|
12 | 12 | from executorch.backends.cadence.aot.quantizer.patterns import (
|
13 | 13 | AddmmPattern,
|
| 14 | + AddPattern, |
14 | 15 | BmmPattern,
|
15 | 16 | Conv1dPattern,
|
16 | 17 | Conv2dPattern,
|
|
41 | 42 | ReluPatterns = (ReluPattern0, ReluPattern1)
|
42 | 43 |
|
43 | 44 |
|
def get_args_and_kwargs_add(
    graph_module: GraphModule,
    inputs_inputs: List[fx.Node],
    dequants_inputs: List[fx.Node],
    quant_node: fx.Node,
) -> "Tuple[Tuple[ArgsType, ...], Dict[str, ArgsType]]":
    """Build the (args, kwargs) for the quantized add replacement op.

    For each of the two dequantized inputs, materialize its scale and zero
    point as 1-element ``aten.full`` nodes (float scale, int32 zero point),
    then pack them together with the output quantization parameters taken
    from ``quant_node``.

    Args:
        graph_module: Module whose graph receives the new ``full`` nodes.
        inputs_inputs: The two (still-quantized) input tensors of the add.
        dequants_inputs: The two dequantize nodes feeding the add;
            ``args[1]`` is the scale and ``args[2]`` the zero point of each.
        quant_node: The output quantize node; its ``args[1]``/``args[2]``
            (output scale / zero point) are passed through as-is.

    Returns:
        ``(args, {})`` ready to splice into the fused add op; kwargs are
        always empty for this op.
    """

    def _full(value: "ArgsType", dtype: "torch.dtype") -> fx.Node:
        # Wrap a scalar quant param in a 1-element tensor node: the fused
        # op expects tensor (not scalar) scale/zero-point inputs.
        return graph_module.graph.call_function(
            torch.ops.aten.full.default,
            ([1], value),
            {"dtype": dtype},
        )

    # NOTE(review): assumes dequants_inputs[0]/[1] line up positionally with
    # inputs_inputs[0]/[1] — confirm against the pattern's anchor ordering.
    x_scale = _full(dequants_inputs[0].args[1], torch.float)
    x_zero_point = _full(dequants_inputs[0].args[2], torch.int32)
    y_scale = _full(dequants_inputs[1].args[1], torch.float)
    y_zero_point = _full(dequants_inputs[1].args[2], torch.int32)

    args = (
        inputs_inputs[0],
        x_scale,
        x_zero_point,
        inputs_inputs[1],
        y_scale,
        y_zero_point,
        quant_node.args[1],  # output scale
        quant_node.args[2],  # output zero point
    )
    return args, {}
| 84 | + |
| 85 | + |
44 | 86 | # Helper function to get the args and kwargs for the linear replacement op
|
45 | 87 | def get_args_and_kwargs_linear(
|
46 | 88 | graph_module: GraphModule,
|
@@ -339,7 +381,7 @@ def call(self, graph_module: fx.GraphModule) -> PassResult: # noqa: C901
|
339 | 381 | )
|
340 | 382 | for fused_partition in fused_partitions:
|
341 | 383 | anchors = pattern.get_anchors(graph_module, fused_partition)
|
342 |
| - if not anchors: |
| 384 | + if not anchors or anchors.empty: |
343 | 385 | continue
|
344 | 386 | if any(self.is_fused(p.nodes) for p in fused_partition):
|
345 | 387 | continue
|
@@ -385,7 +427,14 @@ def call(self, graph_module: fx.GraphModule) -> PassResult: # noqa: C901
|
385 | 427 | inputs_inputs + weights_inputs + other_inputs + bias_inputs
|
386 | 428 | )
|
387 | 429 | kwargs = {}
|
388 |
| - if isinstance(pattern, (Conv1dPattern, Conv2dPattern)): |
| 430 | + if isinstance(pattern, AddPattern): |
| 431 | + args, kwargs = get_args_and_kwargs_add( |
| 432 | + graph_module, |
| 433 | + inputs_inputs, |
| 434 | + dequants_inputs, |
| 435 | + quant_node, |
| 436 | + ) |
| 437 | + elif isinstance(pattern, (Conv1dPattern, Conv2dPattern)): |
389 | 438 | args, kwargs = get_args_and_kwargs_conv(
|
390 | 439 | graph_module,
|
391 | 440 | inputs_inputs,
|
|
0 commit comments