From b3e840093702f7ca23554138b7801053386040d7 Mon Sep 17 00:00:00 2001 From: Kefei Lu Date: Thu, 15 Feb 2024 22:31:42 -0800 Subject: [PATCH] ait: Explicitly throw when indexing a boolean tensor for masking (#992) Summary: Currently AIT hasn't implemented `tensor[boolean_tensor]` for masking. It fail shortly after this call, at: > 'Tensor' object has no attribute 'upper_bound' ``` > link-tree/aitemplate/utils/shape_utils.py(195)convert_IntVar_to_int() -> if var.upper_bound() == var.lower_bound(): ``` Reviewed By: frank-wei, khabinov Differential Revision: D53654054 --- fx2ait/fx2ait/converters/ait_converters.py | 7 +++++++ .../test/converters/test_ait_binary_op.py | 19 +++++++++++++++++++ 2 files changed, 26 insertions(+) diff --git a/fx2ait/fx2ait/converters/ait_converters.py b/fx2ait/fx2ait/converters/ait_converters.py index 3c21094ba..b04448c34 100644 --- a/fx2ait/fx2ait/converters/ait_converters.py +++ b/fx2ait/fx2ait/converters/ait_converters.py @@ -668,6 +668,13 @@ def acc_ops_getitem( isinstance(idx, Sequence) and any(isinstance(x, slice) for x in idx) ): return acc_ops_slice(target, args, kwargs, name) + + if isinstance(idx, AITTensor) and idx.dtype() == "bool": + # TODO: could do something similar to acc_ops_masked_select + raise NotImplementedError( + "AIT does not support tensor[boolean_tensor] masking yet" + ) + if isinstance(input_val, AITTensor): return acc_ops_slice(target, args, kwargs, name) diff --git a/fx2ait/fx2ait/test/converters/test_ait_binary_op.py b/fx2ait/fx2ait/test/converters/test_ait_binary_op.py index 1da7b11e1..fbedbc1e6 100644 --- a/fx2ait/fx2ait/test/converters/test_ait_binary_op.py +++ b/fx2ait/fx2ait/test/converters/test_ait_binary_op.py @@ -154,6 +154,25 @@ def forward(self, input: torch.Tensor) -> torch.Tensor: expected_ops={acc_op}, ) + def test_getitem_boolean_index(self) -> None: + """Verify that NotImplementatedError is thrown encountering + tensor[boolean_mask_tensor] + """ + + class TestModule(torch.nn.Module): + def forward(self, x: torch.Tensor, mask: torch.Tensor) -> torch.Tensor: + return x[mask] + + mod = TestModule().cuda() + x = torch.rand(10, 4).half().cuda() + mask = (torch.rand((10,)) > 0.5).cuda() + mod(x, mask) + + self.assertRaises( + NotImplementedError, + lambda: self.run_test(mod, [x, mask], expected_ops={}), + ) + # This is a common binary op combo usage for ads models. def test_binary_op_combo(self) -> None: class TestModule(torch.nn.Module):