diff --git a/python/test/unit/language/test_core.py b/python/test/unit/language/test_core.py
index a2459891ebe5..759f082dea51 100644
--- a/python/test/unit/language/test_core.py
+++ b/python/test/unit/language/test_core.py
@@ -2484,11 +2484,11 @@ def kernel(ptr, n_elements, num1, num2, type: tl.constexpr):
 # -------------
 
 
-@pytest.mark.parametrize("if_type", ["if", "if_exp", "if_and"])
+@pytest.mark.parametrize("if_type", ["if", "if_exp", "if_and_dynamic", "if_and_static"])
 def test_if(if_type):
 
     @triton.jit
-    def kernel(Cond, XTrue, XFalse, Ret, IfType: tl.constexpr, BoolVar: tl.constexpr):
+    def kernel(Cond, XTrue, XFalse, Ret, IfType: tl.constexpr, BoolVar: tl.constexpr, StaticValue: tl.constexpr):
         pid = tl.program_id(0)
         cond = tl.load(Cond)
         if IfType == "if":
@@ -2498,17 +2498,22 @@ def kernel(Cond, XTrue, XFalse, Ret, IfType: tl.constexpr, BoolVar: tl.constexpr
             tl.store(Ret, tl.load(XFalse))
         elif IfType == "if_exp":
             tl.store(Ret, tl.load(XTrue)) if pid % 2 else tl.store(Ret, tl.load(XFalse))
-        elif IfType == "if_and":
+        elif IfType == "if_and_dynamic":
             if BoolVar and pid % 2 == 0:
                 tl.store(Ret, tl.load(XTrue))
             else:
                 tl.store(Ret, tl.load(XFalse))
+        elif IfType == "if_and_static":
+            if StaticValue != 0 and StaticValue != 0:
+                tl.store(Ret, tl.load(XTrue))
+            else:
+                tl.store(Ret, tl.load(XFalse))
 
     cond = torch.ones(1, dtype=torch.int32, device='cuda')
     x_true = torch.tensor([3.14], dtype=torch.float32, device='cuda')
     x_false = torch.tensor([1.51], dtype=torch.float32, device='cuda')
     ret = torch.empty(1, dtype=torch.float32, device='cuda')
-    kernel[(1,)](cond, x_true, x_false, ret, if_type, True)
+    kernel[(1,)](cond, x_true, x_false, ret, if_type, True, 1)
     assert torch.equal(ret, x_true)
 
 
diff --git a/python/triton/compiler/code_generator.py b/python/triton/compiler/code_generator.py
index 7737a7902adb..0d91e0d35bd0 100644
--- a/python/triton/compiler/code_generator.py
+++ b/python/triton/compiler/code_generator.py
@@ -595,12 +595,14 @@ def visit_Pass(self, node):
     def visit_Compare(self, node):
         if not (len(node.comparators) == 1 and len(node.ops) == 1):
             raise UnsupportedLanguageConstruct(None, node, "simultaneous multiple comparison is not supported")
-        lhs = _unwrap_if_constexpr(self.visit(node.left))
-        rhs = _unwrap_if_constexpr(self.visit(node.comparators[0]))
+        lhs = self.visit(node.left)
+        rhs = self.visit(node.comparators[0])
+        lhs_value = _unwrap_if_constexpr(lhs)
+        rhs_value = _unwrap_if_constexpr(rhs)
         if type(node.ops[0]) == ast.Is:
-            return constexpr(lhs is rhs)
+            return constexpr(lhs_value is rhs_value)
         if type(node.ops[0]) == ast.IsNot:
-            return constexpr(lhs is not rhs)
+            return constexpr(lhs_value is not rhs_value)
         method_name = self._method_name_for_comp_op.get(type(node.ops[0]))
         if method_name is None:
             raise UnsupportedLanguageConstruct(None, node, "AST comparison operator '{}' is not (currently) implemented.".format(node.ops[0].__name__))
@@ -988,7 +990,7 @@ def execute_static_assert(self, node: ast.Call) -> None:
         if not (0 < arg_count <= 2) or len(node.keywords):
             raise TypeError("`static_assert` requires one or two positional arguments only")
 
-        passed = self.visit(node.args[0])
+        passed = _unwrap_if_constexpr(self.visit(node.args[0]))
         if not isinstance(passed, bool):
             raise NotImplementedError("Assertion condition could not be determined at compile-time. Make sure that it depends only on `constexpr` values")
         if not passed:
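
For context, here is a minimal sketch (hypothetical, not part of the patch) of the behavior this change enables, assuming Triton's public `tl.constexpr`/`tl.static_assert` API and a CUDA device. Because `visit_Compare` now returns a `constexpr`-wrapped result instead of a bare Python value, a comparison on a `constexpr` argument can feed both `tl.static_assert` (which unwraps it via `_unwrap_if_constexpr`) and an `and` whose operands are all known at compile time:

```python
import torch
import triton
import triton.language as tl


@triton.jit
def _static_and_kernel(Out, BLOCK: tl.constexpr):
    # `BLOCK > 0` stays a constexpr, so static_assert can unwrap it
    # to a plain Python bool at compile time.
    tl.static_assert(BLOCK > 0)
    # Both operands of `and` are constexpr comparisons, so the condition
    # is resolved during compilation and only the taken branch is lowered.
    if BLOCK > 0 and BLOCK % 2 == 0:
        tl.store(Out, 1.0)
    else:
        tl.store(Out, 0.0)


out = torch.zeros(1, dtype=torch.float32, device='cuda')
_static_and_kernel[(1,)](out, BLOCK=16)
assert out.item() == 1.0
```

The new `if_and_static` test case exercises the same path: with both operands of `and` statically known, the untaken branch should be pruned during code generation rather than lowered to IR, which is why the test previously failed when comparisons were unwrapped to raw Python values too early.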