Skip to content

Commit e6d44fa

Browse files
authored
Add a compiler pass that removes cat ops feeding slice_copy ops
Differential Revision: D70425971 Pull Request resolved: #8857
1 parent 169e6ae commit e6d44fa

File tree

2 files changed

+118
-0
lines changed

2 files changed

+118
-0
lines changed

backends/cadence/aot/remove_ops.py

+66
Original file line numberDiff line numberDiff line change
@@ -807,6 +807,72 @@ def remove_branched(
807807
user.replace_all_uses_with(node.args[0])
808808

809809

810+
class RemoveCatFromSliceCopyPass(ExportPass):
    """Bypass `slice_copy(cat(...))` when the sliced range lies entirely
    inside a single cat input.

    The slice_copy is rewired directly to that input; the cat then becomes
    dead code and is removed by `eliminate_dead_code` in `call`.
    """

    def _remove_unused_cat(self, graph_module: torch.fx.GraphModule) -> None:
        """Rewire each eligible slice_copy past its feeding cat node.

        Eligibility: step == 1, slice and cat operate on the same dim,
        dtypes agree, and one cat input fully covers [start, end).
        """
        slice_copy_nodes = [
            node
            for node in graph_module.graph.nodes
            if node.target == exir_ops.edge.aten.slice_copy.Tensor
        ]
        for slice_copy_node in slice_copy_nodes:
            # Defaults mirror aten.slice_copy: dim=0, start=0, end=inf, step=1.
            slice_dim, start_idx, end_idx, step = 0, 0, float("inf"), 1
            input_node, *other_args = slice_copy_node.args
            if len(other_args) >= 1:
                slice_dim = other_args[0]
            if len(other_args) >= 2:
                start_idx = other_args[1]
            if len(other_args) >= 3:
                end_idx = other_args[2]
            if len(other_args) >= 4:
                step = other_args[3]
            if step != 1:
                # A strided slice can sample from several cat inputs even if
                # [start, end) fits inside one extent; don't rewrite.
                continue
            slice_copy_dtype = slice_copy_node.meta["val"].dtype
            if input_node.target != exir_ops.edge.aten.cat.default:
                continue
            cat_dtype = input_node.meta["val"].dtype
            if slice_copy_dtype != cat_dtype:
                continue
            # aten.cat defaults to dim=0 when the dim argument is elided.
            # Bug fix: the original compared the tuple `input_node.args[1:]`
            # against the int `slice_dim`, which could never match when the
            # dim argument was present in the graph.
            cat_dim = input_node.args[1] if len(input_node.args) > 1 else 0
            if cat_dim != slice_dim:
                continue
            cat_output_shape = input_node.meta["val"].shape
            # Normalize a negative start and clamp end to the cat extent so
            # the containment check below works on concrete offsets.
            start_idx = (
                cat_output_shape[cat_dim] + start_idx if start_idx < 0 else start_idx
            )
            end_idx = (
                cat_output_shape[cat_dim]
                if end_idx > cat_output_shape[cat_dim]
                else end_idx
            )
            base_idx = 0
            cat_input_to_keep = None
            for cat_input_node in input_node.args[0]:
                cat_input_shape = cat_input_node.meta["val"].shape
                cat_input_extent = list(cat_input_shape)[cat_dim]
                cat_input_dtype = cat_input_node.meta["val"].dtype

                # Keep this input iff [start_idx, end_idx) lies entirely
                # within its [base_idx, base_idx + extent) span of the cat
                # output, and its dtype matches the slice_copy output.
                if (
                    slice_copy_dtype == cat_input_dtype
                    and base_idx <= start_idx
                    and end_idx <= cat_input_extent + base_idx
                ):
                    cat_input_to_keep = cat_input_node
                    break
                # Bug fix: every input occupies its extent in the cat output
                # even when its dtype disqualifies it from being kept. The
                # original `continue`d past the increment on a dtype
                # mismatch, miscomputing offsets for all later inputs.
                base_idx += cat_input_extent
            if cat_input_to_keep is not None:
                slice_copy_node.replace_input_with(input_node, cat_input_to_keep)

    def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
        """Run the rewrite, then recompile and DCE so bypassed cats vanish."""
        self._remove_unused_cat(graph_module)
        graph_module.recompile()
        graph_module.graph.eliminate_dead_code()
        return super().call(graph_module)
874+
875+
810876
# The following class consolidates functions to remove ops that are redundant
811877
# in Jarvis. Currently, each function in this class iterates over each node of
812878
# the graph module once. In future, we could consolidate them into a monolithic

backends/cadence/aot/tests/test_remove_ops_passes.py

+52
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
from executorch.backends.cadence.aot.remove_ops import (
2323
RemoveAliasCopyOpPass,
2424
RemoveBranchedQuantDequant,
25+
RemoveCatFromSliceCopyPass,
2526
RemoveCloneOpPass,
2627
RemoveContiguousOpPass,
2728
RemoveDetachCopyPass,
@@ -741,3 +742,54 @@ def forward(self, x):
741742
},
742743
)
743744
)
745+
746+
def test_remove_cat_from_slice_copy_all_removal(self) -> None:
    """Slice range [0, 1) lies entirely within the first cat input, so the
    pass bypasses the cat and dead-code elimination removes it."""

    class M(torch.nn.Module):
        def forward(self, x, y):
            joined = torch.cat((x, y), 0)  # (4, 4)
            return torch.slice_copy(joined, dim=0, start=0, end=1)

    inputs = (torch.randn(2, 4), torch.randn(2, 4))
    gm = export_to_edge(M(), inputs).exported_program().graph_module
    gm = cast(PassResult, RemoveCatFromSliceCopyPass()(gm)).graph_module

    # No cat node should survive the pass.
    self.assertEqual(count_node(gm, exir_ops.edge.aten.cat.default), 0)
763+
def test_remove_cat_from_slice_copy_no_removal(self) -> None:
    """Slice range [0, 3) straddles both cat inputs (each contributes two
    rows along dim 0), so no single input covers it and the cat must stay."""

    class M(torch.nn.Module):
        def __init__(self):
            super().__init__()

        def forward(self, x, y):
            x1 = torch.cat((x, y), 0)  # (4, 4)
            return torch.slice_copy(x1, dim=0, start=0, end=3)

    inputs = tuple(torch.randn(2, 4) for _ in range(2))
    graph_module = export_to_edge(M(), inputs).exported_program().graph_module
    p = RemoveCatFromSliceCopyPass()
    graph_module = cast(PassResult, p(graph_module)).graph_module

    # Bug fix (comment only): the original comment claimed the cat was
    # removed; this test asserts the opposite — the cat is preserved.
    self.assertEqual(count_node(graph_module, exir_ops.edge.aten.cat.default), 1)
780+
def test_remove_cat_from_slice_copy_zero_range(self) -> None:
    """An empty slice range [0, 0) is trivially contained in the first cat
    input, so the cat is bypassed and eliminated."""

    class M(torch.nn.Module):
        def forward(self, x, y):
            joined = torch.cat((x, y), 0)  # (4, 4)
            return torch.slice_copy(joined, dim=0, start=0, end=0)

    inputs = (torch.randn(2, 4), torch.randn(2, 4))
    gm = export_to_edge(M(), inputs).exported_program().graph_module
    gm = cast(PassResult, RemoveCatFromSliceCopyPass()(gm)).graph_module

    # No cat node should survive the pass.
    self.assertEqual(count_node(gm, exir_ops.edge.aten.cat.default), 0)

0 commit comments

Comments
 (0)