Patch summary: is_zero_tensor_node now matches aten.zeros_like.default in
addition to aten.zeros.default; both ops are passed to
fx_infra.decomp.remove_pre_convert_decomp at import time; the test runs
fx_infra.safe_run_decompositions with pre_convert_decomp() before run_passes.
NOTE(review): line breaks were restored from a copy flattened onto one line.
Hunk line counts agree with the @@ headers, but the leading indentation of the
Python lines is assumed (2-space style) - TODO confirm against the repository
before applying.

diff --git a/ai_edge_torch/generative/fx_passes/remove_sdpa_zero_mask_pass.py b/ai_edge_torch/generative/fx_passes/remove_sdpa_zero_mask_pass.py
index 95fcfe5e..49c59d85 100644
--- a/ai_edge_torch/generative/fx_passes/remove_sdpa_zero_mask_pass.py
+++ b/ai_edge_torch/generative/fx_passes/remove_sdpa_zero_mask_pass.py
@@ -16,11 +16,17 @@
 from ai_edge_torch import lowertools
 import torch
 
+fx_infra.decomp.remove_pre_convert_decomp(torch.ops.aten.zeros.default)
+fx_infra.decomp.remove_pre_convert_decomp(torch.ops.aten.zeros_like.default)
+
 
 class RemoveSDPACompositeZeroMaskPass(fx_infra.ExportedProgramPassBase):
 
   def is_zero_tensor_node(self, node: torch.fx.Node):
-    return node.target == torch.ops.aten.zeros.default
+    return node.target in (
+        torch.ops.aten.zeros.default,
+        torch.ops.aten.zeros_like.default,
+    )
 
   def call(self, exported_program: torch.export.ExportedProgram):
     graph = exported_program.graph_module.graph
diff --git a/ai_edge_torch/generative/fx_passes/test/test_remove_sdpa_zero_mask_pass.py b/ai_edge_torch/generative/fx_passes/test/test_remove_sdpa_zero_mask_pass.py
index 476d93fb..68821748 100644
--- a/ai_edge_torch/generative/fx_passes/test/test_remove_sdpa_zero_mask_pass.py
+++ b/ai_edge_torch/generative/fx_passes/test/test_remove_sdpa_zero_mask_pass.py
@@ -41,6 +41,9 @@ def forward(self, *args, **kwargs):
     module = func
 
   exported_program = torch.export.export(module, export_args)
+  exported_program = fx_infra.safe_run_decompositions(
+      exported_program, fx_infra.decomp.pre_convert_decomp()
+  )
   exported_program = fx_infra.run_passes(
       exported_program,
       [