Commit 7c16c49
[ET-VK][ez] Update requirements for partitioning to_dim_order_copy (#7949)
Pull Request resolved: #7859

## Context

The previous registration of the to_dim_order_copy op was incorrect. There is currently no implementation of the op in the Vulkan backend, but since Vulkan manages memory layout internally, the op node can be removed as long as the only thing being changed is the dim order. In some instances, however, the op is also used to change the dtype, in which case the node cannot be removed and the Vulkan delegate cannot execute it correctly. This change updates the registration of the op to reflect that restriction.

This diff should unblock enabling dim order ops for Vulkan.

ghstack-source-id: 262710507
@exported-using-ghexport

Differential Revision: [D68528213](https://our.internmc.facebook.com/intern/diff/D68528213/)

Co-authored-by: Stephen Jia <[email protected]>

(cherry picked from commit 5ee5f2f)
1 parent: 374f034
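For illustration (not part of the commit), the two flavors of `to()` that this restriction separates can be sketched as below. The module names are hypothetical, and the assumption that these calls lower to `dim_order_ops._to_dim_order_copy` in the edge dialect depends on the export configuration:

```python
import torch

# Hypothetical modules sketching the two cases the partitioner must separate.

class DimOrderOnly(torch.nn.Module):
    # Only the memory format (i.e. the dim order) changes. Since Vulkan
    # manages memory layout internally, the resulting node can be removed.
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x.to(memory_format=torch.channels_last)

class DtypeCast(torch.nn.Module):
    # The dtype changes as well. There is no Vulkan implementation for the
    # cast, so the resulting node must stay out of the Vulkan partition.
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x.to(dtype=torch.half, memory_format=torch.channels_last)
```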

1 file changed (+31 -2 lines)

backends/vulkan/op_registry.py (+31 -2)
@@ -228,8 +228,6 @@ def update_features_impl(op: OpKey):
         exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default,
         exir_ops.edge.quantized_decomposed.dequantize_per_tensor.tensor,
         exir_ops.edge.quantized_decomposed.dequantize_per_channel.default,
-        # dim order copy operator will be removed; memory layout is handled internally
-        exir_ops.edge.dim_order_ops._to_dim_order_copy.default,
     ]
 )
 def register_ephemeral_op(features: OpFeatures):
@@ -322,6 +320,37 @@ def check_to_copy_node(node: torch.fx.Node) -> bool:
     return features
 
 
+@update_features(exir_ops.edge.dim_order_ops._to_dim_order_copy.default)
+def register_to_copy_dim_order_op(features: OpFeatures):
+    features.texture_impl = TextureImplFeatures(
+        uses_axis_map=True,
+        valid_packed_dims=all_packed_dims,
+    )
+    features.buffer_impl = True
+    features.resize_fn = True
+
+    # Currently there is no "real" implementation for to_dim_order_copy, but it can be
+    # removed as long as the operator is not changing the dtype, i.e. the operator call
+    # is modifying the dim order only. Therefore, check that the input and output
+    # dtypes are the same; if so, the operator is safe to remove.
+    def check_dim_order_copy_node(node: torch.fx.Node) -> bool:
+        in_arg = node.args[0]
+        if not isinstance(in_arg, torch.fx.Node):
+            return False
+
+        in_tensor = in_arg.meta.get("val", None)
+        out_tensor = node.meta.get("val", None)
+
+        if in_tensor.dtype != out_tensor.dtype:
+            return False
+
+        return True
+
+    features.check_node_fn = check_dim_order_copy_node
+
+    return features
+
+
 @update_features(
     [
         exir_ops.edge.aten.bmm.default,
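As a quick illustration of how the new check behaves, here is a minimal, contrived FX sketch. The hand-built graph and the aten._to_copy stand-in target are assumptions for demonstration only; the real check receives edge-dialect nodes whose meta["val"] is populated during export:

```python
import torch
import torch.fx as fx

# Standalone copy of the predicate from the diff above.
def check_dim_order_copy_node(node: fx.Node) -> bool:
    in_arg = node.args[0]
    if not isinstance(in_arg, fx.Node):
        return False
    in_tensor = in_arg.meta.get("val", None)
    out_tensor = node.meta.get("val", None)
    return in_tensor.dtype == out_tensor.dtype

# Hand-built graph with example tensors attached to meta["val"].
graph = fx.Graph()
x = graph.placeholder("x")
x.meta["val"] = torch.empty(1, 3, 8, 8, dtype=torch.float32)

# Stand-in for exir_ops.edge.dim_order_ops._to_dim_order_copy.default.
node = graph.call_function(torch.ops.aten._to_copy.default, (x,))

node.meta["val"] = torch.empty(1, 3, 8, 8, dtype=torch.float32)
print(check_dim_order_copy_node(node))   # True: dim order only, removable

node.meta["val"] = torch.empty(1, 3, 8, 8, dtype=torch.float16)
print(check_dim_order_copy_node(node))   # False: dtype cast, keep off Vulkan
```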
