From 0d22f912511fe6419b867930b59ff66e28f2c35c Mon Sep 17 00:00:00 2001 From: Bradley Davis Date: Tue, 30 Jan 2024 13:34:58 -0800 Subject: [PATCH] allow binary ops where only one arg is an immutable IntVarTensor (#987) Summary: rarely, the output of a call to .size() is one of the operands in a binary op, and as such has type IntVarTensor. In this case, it is okay to forward the call to elementwise (instead of raising an error before a call to int_elementwise()). Reviewed By: khabinov Differential Revision: D53240440 --- fx2ait/fx2ait/converters/utils.py | 20 ++++++++++++++++ .../test/converters/test_ait_binary_op.py | 24 +++++++++++++++++++ 2 files changed, 44 insertions(+) diff --git a/fx2ait/fx2ait/converters/utils.py b/fx2ait/fx2ait/converters/utils.py index d64cb084a..c7bf901b2 100644 --- a/fx2ait/fx2ait/converters/utils.py +++ b/fx2ait/fx2ait/converters/utils.py @@ -86,6 +86,26 @@ def create_binary_op( ) return res + if ( + isinstance(lhs, AITTensor) + and isinstance(rhs, IntVarTensor) + and isinstance(rhs._attrs["int_var"], IntImm) + and rhs_is_constant + ): + # If rhs is a constant IntVarTensor but lhs is not, proceed + rhs = rhs_constant + return elementwise(op_type)(lhs, rhs) + + if ( + isinstance(rhs, AITTensor) + and isinstance(lhs, IntVarTensor) + and isinstance(lhs._attrs["int_var"], IntImm) + and lhs_is_constant + ): + # If lhs is a constant IntVarTensor but rhs is not, proceed + lhs = lhs_constant + return elementwise(op_type)(lhs, rhs) + if isinstance(lhs, IntVarTensor) or isinstance(rhs, IntVarTensor): lhs = IntVarTensor(IntImm(lhs)) if isinstance(lhs, int) else lhs rhs = IntVarTensor(IntImm(rhs)) if isinstance(rhs, int) else rhs diff --git a/fx2ait/fx2ait/test/converters/test_ait_binary_op.py b/fx2ait/fx2ait/test/converters/test_ait_binary_op.py index 1a13daaa9..1da7b11e1 100644 --- a/fx2ait/fx2ait/test/converters/test_ait_binary_op.py +++ b/fx2ait/fx2ait/test/converters/test_ait_binary_op.py @@ -167,3 +167,27 @@ def forward(self, input: 
torch.Tensor) -> torch.Tensor: [torch.randn(2, 4).half().cuda()], expected_ops={acc_ops.reshape, acc_ops.mul}, ) + + def test_binary_one_intmm_constant_lhs(self) -> None: + class TestModule(torch.nn.Module): + def forward(self, input: torch.Tensor) -> torch.Tensor: + return torch.add(input, input.size()[0]) + + model = TestModule().cuda() + self.run_test( + model, + [torch.randn((1, 1)).half().cuda()], + expected_ops={acc_ops.add}, + ) + + def test_binary_one_intmm_constant_rhs(self) -> None: + class TestModule(torch.nn.Module): + def forward(self, input: torch.Tensor) -> torch.Tensor: + return torch.add(input.size()[0], input) + + model = TestModule().cuda() + self.run_test( + model, + [torch.randn((1, 1)).half().cuda()], + expected_ops={acc_ops.add}, + )