diff --git a/backends/cadence/aot/compiler.py b/backends/cadence/aot/compiler.py
index c7cea31b49..bf4a274134 100644
--- a/backends/cadence/aot/compiler.py
+++ b/backends/cadence/aot/compiler.py
@@ -33,6 +33,7 @@
     ExecutorchProgramManager,
     to_edge,
 )
+from executorch.exir.dialects._ops import ops as exir_ops
 from executorch.exir.pass_base import PassResult
 from executorch.exir.passes import ToOutVarPass
 from executorch.exir.passes.sym_shape_eval_pass import HintBasedSymShapeEvalPass
@@ -186,7 +187,6 @@ def export_to_edge(
     edge_prog_manager = to_edge(
         expo_program,
         compile_config=EdgeCompileConfig(
-            _skip_dim_order=True,
             # Allow specific non-core aten ops in the IR.
             _core_aten_ops_exception_list=[
                 torch.ops.aten._native_batch_norm_legit_functional.default,
@@ -194,6 +194,10 @@ def export_to_edge(
                 torch.ops.aten.linalg_vector_norm.default,
                 torch.ops.aten.unfold.default,
                 torch.ops.aten.angle.default,
+                # Cadence replaces to_dim_order_copy with _to_copy for performance,
+                # so skip the dim order check for the _to_copy op.
+                # We should remove this exception once Cadence supports dim order.
+                exir_ops.edge.aten._to_copy.default,
             ],
         ),
         constant_methods=constant_methods,
diff --git a/backends/cadence/aot/replace_ops.py b/backends/cadence/aot/replace_ops.py
index 89ef821c56..6a579b622f 100644
--- a/backends/cadence/aot/replace_ops.py
+++ b/backends/cadence/aot/replace_ops.py
@@ -11,6 +11,7 @@
 
 # pyre-unsafe
 
+import copy
 import math
 from operator import neg
 from typing import cast, Dict, Iterable, Sequence, Set, Tuple
@@ -35,7 +36,12 @@
 from executorch.backends.cadence.aot.utils import get_edge_overload_packet
 from executorch.exir.dialects._ops import ops as exir_ops
 from executorch.exir.dialects.edge._ops import EdgeOpOverload, EdgeOpOverloadPacket
+from executorch.exir.dim_order_utils import get_memory_format
 from executorch.exir.pass_base import ExportPass, NodeMetadata, PassResult, ProxyValue
+from executorch.exir.passes.dim_order_ops_registry import (
+    DimOrderOpsMap,
+    MemoryFormatOpsMap,
+)
 from torch._subclasses import FakeTensor
 from torch.fx.node import Argument
 
@@ -1799,6 +1805,62 @@ def call_operator(
         )
 
 
+@register_cadence_pass(CadencePassAttribute(opt_level=0))
+class ReplaceToDimOrderCopyWithToCopyPass(ExportPass):
+    """
+    dim_order_ops::to_dim_order_copy is not supported, so this is an opt_level=0 pass.
+    If the dim order is sequential, we don't need the extra work with strides and
+    can just use to_copy.
+ """ + + def call_operator( + self, + op, + args: Tuple[Argument, ...], + kwargs: Dict[str, Argument], + meta: NodeMetadata, + ) -> ProxyValue: + if op not in DimOrderOpsMap: + return super().call_operator(op, args, kwargs, meta) + + # new kwargs with dim_order, and no memory_format for the new op + nkwargs = dict(copy.deepcopy(kwargs)) # orig kwargs are immutable + + ndim = None + + # can always get the shape, assuming rank is specialized + + # pyre-ignore[16]: `None` has no attribute `to_tensor` + if isinstance(args[0], ProxyValue) and args[0].is_tensor(): + # pyre-ignore[16]: `None` has no attribute `to_tensor` + ndim = args[0].to_tensor().dim() + elif isinstance(args[0], torch.Tensor): + # pyre-ignore[16]: `None` has no attribute `dim` + ndim = args[0].dim() + elif isinstance(args[0], torch.fx.immutable_collections.immutable_list): + # pyre-ignore[6]: Incompatible parameter type + ndim = len(args[0]) + else: + assert 0, f"Expecting a Tensor or a ProxyValue but got {type(args[0])}" + + # get the "to" memory format for the EdgeOp + default_dim_order = list(range(ndim)) + dim_order = nkwargs.pop("dim_order", default_dim_order) + + # bring back memory format + # pyre-ignore[6]: Incompatible parameter type + nkwargs["memory_format"] = get_memory_format(dim_order) + + memory_format_op = MemoryFormatOpsMap[op] + + return super().call_operator( + memory_format_op, + args, + nkwargs, + meta, + ) + + @register_cadence_pass(CadencePassAttribute(opt_level=0)) class ReplaceFullLikeWithFullPass(ExportPass): """ @@ -2108,4 +2170,5 @@ class CadenceReplaceOpsInGraph: ReplaceSingleElementTensorArgumentsFromFullOpWithScalarPass, ReplaceAtenAvgPoolWithJarvisAvgPoolPass, ReplaceAtenLinalgVectorNormWithCadenceLinalgVectorNormPass, + ReplaceToDimOrderCopyWithToCopyPass, ]