diff --git a/python/aitemplate/compiler/ops/common/elementwise.py b/python/aitemplate/compiler/ops/common/elementwise.py
index a5bc5847a..1881074bf 100644
--- a/python/aitemplate/compiler/ops/common/elementwise.py
+++ b/python/aitemplate/compiler/ops/common/elementwise.py
@@ -23,7 +23,7 @@
 from aitemplate.compiler.op_registry import OP_REGISTRY
 from aitemplate.compiler.ops.common.epilogue import FuncEnum
 from aitemplate.compiler.ops.common.int_elementwise import INT_ELEMENTWISE_FUNC
-
+from aitemplate.compiler.ops.tensor import cast
 from aitemplate.utils import shape_utils
 
 # pylint: disable=C0103,W0221,W0102,C0301,W0223,R1724
@@ -225,12 +225,31 @@ def __call__(self, *args: Tensor) -> Tensor:
                 symbolic_args.append(arg._attrs["int_var"].symbolic_value())
             elif isinstance(arg, Tensor):
                 converted_args.append(arg)
+                arg_dtype = normalize_dtype(arg.dtype())
                 if common_dtype is None:
-                    common_dtype = normalize_dtype(arg.dtype())
-                elif normalize_dtype(arg.dtype()) != common_dtype:
-                    raise NotImplementedError(
-                        f"Type promotions are not supported; got dtype {arg.dtype()}, but expected {common_dtype}"
-                    )
+                    common_dtype = arg_dtype
+                elif arg_dtype != common_dtype:
+                    if arg.dtype() == "bool" and common_dtype != "bool":
+                        # If this arg is bool, and the common is not bool, cast to the common type.
+                        converted_args[-1] = cast()(
+                            x=converted_args[-1], dtype=common_dtype
+                        )
+                    elif (
+                        arg.dtype() != "bool"
+                        and common_dtype == "bool"
+                        and len(converted_args) >= 2
+                    ):
+                        # If this arg is non-bool and the common type is bool,
+                        # cast all previous bool args to the non-bool type.
+                        common_dtype = arg_dtype
+                        for i in range(0, len(converted_args) - 1):
+                            converted_args[i] = cast()(
+                                x=converted_args[i], dtype=common_dtype
+                            )
+                    else:
+                        raise NotImplementedError(
+                            f"Type promotions are not supported; got dtype {arg.dtype()}, but expected {common_dtype}"
+                        )
                 symbolic_args.append(arg._attrs.get("symbolic_value", None))
             else:
                 raise RuntimeError(
diff --git a/tests/unittest/compiler/test_op_common_elementwise.py b/tests/unittest/compiler/test_op_common_elementwise.py
new file mode 100644
index 000000000..2460ac4b1
--- /dev/null
+++ b/tests/unittest/compiler/test_op_common_elementwise.py
@@ -0,0 +1,104 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+import unittest
+
+from aitemplate import compiler
+
+from aitemplate.compiler import compile_model, ops
+
+from aitemplate.compiler.base import Tensor
+from aitemplate.compiler.ops.common.epilogue import FuncEnum
+from aitemplate.compiler.transform.fuse_ops import (
+    fuse_elementwise,
+    process_singleton_elementwise,
+)
+from aitemplate.testing import detect_target
+from aitemplate.testing.test_utils import (
+    get_random_torch_tensor,
+    get_torch_empty_tensor,
+)
+
+
+def _make_graph():
+    X0 = Tensor(
+        shape=[3, 5, 7, 9],
+        dtype="float16",
+        name="X0",
+        is_input=True,
+    )
+
+    Y = ops.elementwise(FuncEnum.ABS)(ops.elementwise(FuncEnum.SIN)(X0))
+
+    Y._attrs["is_output"] = True
+    Y._attrs["name"] = "Y"
+    return Y
+
+
+class OpCommonElementwiseTestCase(unittest.TestCase):
+    def test_elementwise_type_promotion_bool_rhs(self):
+        X0 = Tensor(
+            shape=[3, 5, 2],
+            dtype="float16",
+            name="X0",
+            is_input=True,
+        )
+        X1 = Tensor(
+            shape=[3, 5, 2],
+            dtype="bool",
+            name="X1",
+            is_input=True,
+        )
+        Y = ops.elementwise(FuncEnum.MUL)(X0, X1)
+        Y._attrs["name"] = "output0"
+        Y._attrs["is_output"] = True
+        target = detect_target()
+        module = compile_model(
+            Y,
+            target,
+            "./tmp",
+            "test_elementwise_type_promotion_bool_rhs",
+        )
+        x0_pt = get_random_torch_tensor([3, 5, 2], "float16")
+        x1_pt = get_random_torch_tensor([3, 5, 2], "bool")
+        out_pt = get_torch_empty_tensor([3, 5, 2], "float16")
+        module.run_with_tensors({"X0": x0_pt, "X1": x1_pt}, {"output0": out_pt})
+
+    def test_elementwise_type_promotion_bool_lhs(self):
+        X0 = Tensor(
+            shape=[3, 5, 2],
+            dtype="bool",
+            name="X1",
+            is_input=True,
+        )
+        X1 = Tensor(
+            shape=[3, 5, 2],
+            dtype="float16",
+            name="X0",
+            is_input=True,
+        )
+        Y = ops.elementwise(FuncEnum.MUL)(X0, X1)
+        Y._attrs["name"] = "output0"
+        Y._attrs["is_output"] = True
+        target = detect_target()
+        module = compile_model(
+            Y,
+            target,
+            "./tmp",
+            "test_elementwise_type_promotion_bool_lhs",
+        )
+        x0_pt = get_random_torch_tensor([3, 5, 2], "float16")
+        x1_pt = get_random_torch_tensor([3, 5, 2], "bool")
+        out_pt = get_torch_empty_tensor([3, 5, 2], "float16")
+        module.run_with_tensors({"X0": x0_pt, "X1": x1_pt}, {"output0": out_pt})
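
Reviewer note (not part of the patch): a minimal sketch of the user-facing behaviour this diff enables, mirroring the new unit tests. The tensor names and shapes below are illustrative assumptions, not identifiers from the change itself.

from aitemplate.compiler import ops
from aitemplate.compiler.base import Tensor
from aitemplate.compiler.ops.common.epilogue import FuncEnum

# A float16 tensor and a bool mask of the same shape.
x = Tensor(shape=[3, 5, 2], dtype="float16", name="x", is_input=True)
mask = Tensor(shape=[3, 5, 2], dtype="bool", name="mask", is_input=True)

# Before this change, mixing dtypes in elementwise raised
# NotImplementedError("Type promotions are not supported ...").
# With this change, the bool operand is routed through
# cast()(x=mask, dtype="float16") inside elementwise.__call__, so the
# multiply graph builds and the common dtype stays float16.
y = ops.elementwise(FuncEnum.MUL)(x, mask)

When the bool operand comes first (the *_bool_lhs test), the new branch flips common_dtype to the later non-bool dtype and retroactively casts the earlier bool operands, so operand order does not affect the result.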