
Commit e7fb2bd
[ET-VK][ez] Improve insert_prepack_node pass to handle multiple uses of constant tensors (#10488)
## Context

Refer to #6352 for why the `insert_prepack_nodes` pass is needed. The current logic of the pass assumes that each constant tensor node has only one use. In reality, however, a constant tensor node may have multiple uses, some of which require the insertion of a prepacking node and some of which do not (because those operator implementations apply their own special packing). Currently, if any use of a constant tensor node handles its own prepacking, no prepacking nodes are inserted at all. As a result, the model produces a type error at runtime when an operator receives a `TensorRef` but expects a `Tensor`.

## Changes

Improve the logic of the pass to handle constant tensor nodes with multiple uses. If some uses do not handle their own prepacking, a prepacking node is inserted and only those uses are rewired to consume it.

Differential Revision: [D73592619](https://our.internmc.facebook.com/intern/diff/D73592619/)
1 parent b669184
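For intuition, here is a minimal, self-contained sketch of the mixed-use scenario in plain `torch.fx`, with no ExecuTorch dependency. The toy module, the use of `torch.clone` as a stand-in for the real prepacking operator, and the predicate that selects which users to rewire are illustrative assumptions, not code from this PR:

```python
import torch
import torch.fx


class ToyModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.weight = torch.nn.Parameter(torch.randn(4, 4))

    def forward(self, x):
        # Two uses of the same constant tensor: one via a call_function
        # (torch.mm) and one via a call_method (sum).
        return torch.mm(x, self.weight) + self.weight.sum()


gm = torch.fx.symbolic_trace(ToyModule())

for node in gm.graph.nodes:
    if node.op != "get_attr":
        continue
    # Pretend only the mm user needs an explicit prepack node; the other
    # user "handles its own prepacking" and must keep the raw constant.
    users_to_rewire = [u for u in node.users if u.target == torch.mm]
    if not users_to_rewire:
        continue
    with gm.graph.inserting_after(node):
        # torch.clone stands in for the real prepacking operator.
        prepack_node = gm.graph.call_function(torch.clone, (node,))
    # Rewire only the selected users; the other use keeps the raw constant.
    for user in users_to_rewire:
        user.replace_input_with(node, prepack_node)

gm.recompile()
print(gm.code)  # mm now consumes clone(weight); sum still reads weight directly
```

Rewiring only the selected users with `replace_input_with`, rather than all-or-nothing, is the same strategy the updated pass takes below when a constant has a mix of prepacking and non-prepacking users.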

File tree: 1 file changed

backends/vulkan/_passes/insert_prepack_nodes.py (+21 −20)
```diff
@@ -8,10 +8,6 @@
 
 from copy import deepcopy
 
-import executorch.backends.vulkan.custom_ops_lib  # noqa
-
-import torch
-
 from executorch.backends.vulkan.op_registry import handles_own_prepacking
 from executorch.backends.vulkan.utils import is_param_node
 
@@ -31,27 +27,27 @@ def insert_prepack_nodes(program: ExportedProgram) -> ExportedProgram:
     argument into the operator implementation.
     """
 
-    def prepack_not_required(node: torch.fx.Node) -> bool:
+    for node in program.graph_module.graph.nodes:
+        # Prepacking is only needed for constant tensors. Only nodes corresponding to
+        # constant tensors will proceed beyond this point.
         if not is_param_node(program, node):
-            return True
+            continue
 
-        # Annotate that this node is going to represented as a tensorref in the Vulkan
-        # compute graph. This will be useful for later graph passes.
+        # Mark that this node is going to be represented as a TensorRef type in the
+        # Vulkan compute graph. This annotation is used in later graph passes.
         node.meta["vkdg_tensorref"] = True
 
+        # Get the list of node users that do not handle their own prepacking
+        nodes_to_replace_input = []
         for user in node.users:
-            if user.op == "call_function" and handles_own_prepacking(
-                # pyre-ignore
-                user.target
-            ):
-                return True
+            if user.op == "call_function" and not handles_own_prepacking(user.target):
+                nodes_to_replace_input.append(user)
 
-        return False
-
-    for node in program.graph_module.graph.nodes:
-        if prepack_not_required(node):
+        if len(nodes_to_replace_input) == 0:
             continue
 
+        replace_all_uses = len(nodes_to_replace_input) == len(node.users)
+
         with program.graph_module.graph.inserting_after(node):
             prepack_node = program.graph_module.graph.create_node(
                 "call_function",
@@ -74,9 +70,14 @@ def prepack_not_required(node: torch.fx.Node) -> bool:
             # Set the mem_obj_id to -1 to indicate that this node requires a dedicated
             # memory object.
             prepack_node.meta["spec"].mem_obj_id = -1
-            node.replace_all_uses_with(
-                prepack_node, lambda x, y=prepack_node: (x != y and x.op != "output")
-            )
+            if replace_all_uses:
+                node.replace_all_uses_with(
+                    prepack_node,
+                    lambda x, y=prepack_node: (x != y and x.op != "output"),
+                )
+            else:
+                for user_node in nodes_to_replace_input:
+                    user_node.replace_input_with(node, prepack_node)
 
     program.graph.eliminate_dead_code()
     return program
```
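A note on the `replace_all_uses` branch above: `torch.fx`'s `Node.replace_all_uses_with` accepts a `delete_user_cb` filter, and a use is rewired only when the callback returns `True` for that user node. The pass relies on this to skip the freshly created prepack node (which now legitimately consumes the constant) and any direct use by the graph output. A small standalone demo of the same idiom, with illustrative function and node names:

```python
import operator

import torch
import torch.fx


def f(x):
    a = x + 1
    return a * 2, a - 3  # two uses of `a`


gm = torch.fx.symbolic_trace(f)
a_node = next(n for n in gm.graph.nodes if n.target == operator.add)

with gm.graph.inserting_after(a_node):
    relu_node = gm.graph.call_function(torch.relu, (a_node,))

# Rewire every use of `a` except the freshly inserted relu node itself;
# without the filter, relu would be rewritten to consume its own output.
a_node.replace_all_uses_with(
    relu_node,
    delete_user_cb=lambda user: user is not relu_node and user.op != "output",
)
gm.recompile()
print(gm.code)  # mul and sub now consume relu(a)
```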
