|
22 | 22 | from executorch.backends.cadence.aot.remove_ops import (
|
23 | 23 | RemoveAliasCopyOpPass,
|
24 | 24 | RemoveBranchedQuantDequant,
|
| 25 | + RemoveCatFromSliceCopyPass, |
25 | 26 | RemoveCloneOpPass,
|
26 | 27 | RemoveContiguousOpPass,
|
27 | 28 | RemoveDetachCopyPass,
|
@@ -741,3 +742,54 @@ def forward(self, x):
|
741 | 742 | },
|
742 | 743 | )
|
743 | 744 | )
|
| 745 | + |
| 746 | + def test_remove_cat_from_slice_copy_all_removal(self) -> None: |
| 747 | + class M(torch.nn.Module): |
| 748 | + def __init__(self): |
| 749 | + super().__init__() |
| 750 | + |
| 751 | + def forward(self, x, y): |
| 752 | + x1 = torch.cat((x, y), 0) # (2, 4) |
| 753 | + return torch.slice_copy(x1, dim=0, start=0, end=1) |
| 754 | + |
| 755 | + inputs = tuple(torch.randn(2, 4) for _ in range(2)) |
| 756 | + graph_module = export_to_edge(M(), inputs).exported_program().graph_module |
| 757 | + p = RemoveCatFromSliceCopyPass() |
| 758 | + graph_module = cast(PassResult, p(graph_module)).graph_module |
| 759 | + |
| 760 | + # Ensure both cat nodes were removed |
| 761 | + self.assertEqual(count_node(graph_module, exir_ops.edge.aten.cat.default), 0) |
| 762 | + |
| 763 | + def test_remove_cat_from_slice_copy_no_removal(self) -> None: |
| 764 | + class M(torch.nn.Module): |
| 765 | + def __init__(self): |
| 766 | + super().__init__() |
| 767 | + |
| 768 | + def forward(self, x, y): |
| 769 | + x1 = torch.cat((x, y), 0) # (2, 4) |
| 770 | + return torch.slice_copy(x1, dim=0, start=0, end=3) |
| 771 | + |
| 772 | + inputs = tuple(torch.randn(2, 4) for _ in range(2)) |
| 773 | + graph_module = export_to_edge(M(), inputs).exported_program().graph_module |
| 774 | + p = RemoveCatFromSliceCopyPass() |
| 775 | + graph_module = cast(PassResult, p(graph_module)).graph_module |
| 776 | + |
| 777 | + # Ensure both cat nodes were removed |
| 778 | + self.assertEqual(count_node(graph_module, exir_ops.edge.aten.cat.default), 1) |
| 779 | + |
| 780 | + def test_remove_cat_from_slice_copy_zero_range(self) -> None: |
| 781 | + class M(torch.nn.Module): |
| 782 | + def __init__(self): |
| 783 | + super().__init__() |
| 784 | + |
| 785 | + def forward(self, x, y): |
| 786 | + x1 = torch.cat((x, y), 0) # (2, 4) |
| 787 | + return torch.slice_copy(x1, dim=0, start=0, end=0) |
| 788 | + |
| 789 | + inputs = tuple(torch.randn(2, 4) for _ in range(2)) |
| 790 | + graph_module = export_to_edge(M(), inputs).exported_program().graph_module |
| 791 | + p = RemoveCatFromSliceCopyPass() |
| 792 | + graph_module = cast(PassResult, p(graph_module)).graph_module |
| 793 | + |
| 794 | + # Ensure both cat nodes were removed |
| 795 | + self.assertEqual(count_node(graph_module, exir_ops.edge.aten.cat.default), 0) |
0 commit comments