[FRONTEND] Fix binary compare op on constexprs (#1801)
Example:

```
if static_a == 0 and static_b == 1:
    ...
```

The return value of `static_a == 0` should be `constexpr(True)`, not a plain `True`: a bare bool object (`True`/`False`) has no `logical_and` method, so the `and` in the condition above could not be lowered.
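
As a rough illustration (a minimal sketch, not Triton's actual `constexpr` class), the wrapper matters because `and` on compile-time values is lowered through a `logical_and` call, which a bare Python bool does not provide:

```python
# Minimal sketch of the idea; NOT Triton's real constexpr implementation.
class Constexpr:
    def __init__(self, value):
        self.value = value

    # Comparisons return another Constexpr so later compile-time ops still work.
    def __eq__(self, other):
        other = other.value if isinstance(other, Constexpr) else other
        return Constexpr(self.value == other)

    # The code generator expects this method when lowering `and` on constexprs.
    def logical_and(self, other):
        other = other.value if isinstance(other, Constexpr) else other
        return Constexpr(bool(self.value) and bool(other))


static_a, static_b = Constexpr(0), Constexpr(1)
cond = (static_a == 0).logical_and(static_b == 1)  # works: Constexpr(True)
# (True).logical_and(...) would fail: a plain bool has no such method.
print(cond.value)  # True
```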
Jokeren committed Jun 19, 2023
1 parent 04e47d7 commit 1851c8c
Showing 2 changed files with 16 additions and 9 deletions.
13 changes: 9 additions & 4 deletions python/test/unit/language/test_core.py
```diff
@@ -2484,11 +2484,11 @@ def kernel(ptr, n_elements, num1, num2, type: tl.constexpr):
 # -------------
 
 
-@pytest.mark.parametrize("if_type", ["if", "if_exp", "if_and"])
+@pytest.mark.parametrize("if_type", ["if", "if_exp", "if_and_dynamic", "if_and_static"])
 def test_if(if_type):
 
     @triton.jit
-    def kernel(Cond, XTrue, XFalse, Ret, IfType: tl.constexpr, BoolVar: tl.constexpr):
+    def kernel(Cond, XTrue, XFalse, Ret, IfType: tl.constexpr, BoolVar: tl.constexpr, StaticVaue: tl.constexpr):
         pid = tl.program_id(0)
         cond = tl.load(Cond)
         if IfType == "if":
@@ -2498,17 +2498,22 @@ def kernel(Cond, XTrue, XFalse, Ret, IfType: tl.constexpr, BoolVar: tl.constexpr
                 tl.store(Ret, tl.load(XFalse))
         elif IfType == "if_exp":
             tl.store(Ret, tl.load(XTrue)) if pid % 2 else tl.store(Ret, tl.load(XFalse))
-        elif IfType == "if_and":
+        elif IfType == "if_and_dynamic":
             if BoolVar and pid % 2 == 0:
                 tl.store(Ret, tl.load(XTrue))
             else:
                 tl.store(Ret, tl.load(XFalse))
+        elif IfType == "if_and_static":
+            if StaticVaue != 0 and StaticVaue != 0:
+                tl.store(Ret, tl.load(XTrue))
+            else:
+                tl.store(Ret, tl.load(XFalse))
 
     cond = torch.ones(1, dtype=torch.int32, device='cuda')
     x_true = torch.tensor([3.14], dtype=torch.float32, device='cuda')
     x_false = torch.tensor([1.51], dtype=torch.float32, device='cuda')
     ret = torch.empty(1, dtype=torch.float32, device='cuda')
-    kernel[(1,)](cond, x_true, x_false, ret, if_type, True)
+    kernel[(1,)](cond, x_true, x_false, ret, if_type, True, 1)
     assert torch.equal(ret, x_true)
 
 
```
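
To exercise just the updated test, a pytest invocation such as the following should work (assuming pytest is installed and a CUDA device is available, since the test allocates `cuda` tensors):

```python
# One-off runner for the updated test; the path matches the file shown above.
# Assumes pytest is installed and a CUDA-capable GPU is present.
import pytest

# -k selects test_if (all four if_type parametrizations); -v lists each case.
raise SystemExit(pytest.main(["python/test/unit/language/test_core.py", "-k", "test_if", "-v"]))
```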
12 changes: 7 additions & 5 deletions python/triton/compiler/code_generator.py
```diff
@@ -595,12 +595,14 @@ def visit_Pass(self, node):
     def visit_Compare(self, node):
         if not (len(node.comparators) == 1 and len(node.ops) == 1):
             raise UnsupportedLanguageConstruct(None, node, "simultaneous multiple comparison is not supported")
-        lhs = _unwrap_if_constexpr(self.visit(node.left))
-        rhs = _unwrap_if_constexpr(self.visit(node.comparators[0]))
+        lhs = self.visit(node.left)
+        rhs = self.visit(node.comparators[0])
+        lhs_value = _unwrap_if_constexpr(lhs)
+        rhs_value = _unwrap_if_constexpr(rhs)
         if type(node.ops[0]) == ast.Is:
-            return constexpr(lhs is rhs)
+            return constexpr(lhs_value is rhs_value)
         if type(node.ops[0]) == ast.IsNot:
-            return constexpr(lhs is not rhs)
+            return constexpr(lhs_value is not rhs_value)
         method_name = self._method_name_for_comp_op.get(type(node.ops[0]))
         if method_name is None:
             raise UnsupportedLanguageConstruct(None, node, "AST comparison operator '{}' is not (currently) implemented.".format(node.ops[0].__name__))
@@ -988,7 +990,7 @@ def execute_static_assert(self, node: ast.Call) -> None:
         if not (0 < arg_count <= 2) or len(node.keywords):
             raise TypeError("`static_assert` requires one or two positional arguments only")
 
-        passed = self.visit(node.args[0])
+        passed = _unwrap_if_constexpr(self.visit(node.args[0]))
         if not isinstance(passed, bool):
             raise NotImplementedError("Assertion condition could not be determined at compile-time. Make sure that it depends only on `constexpr` values")
         if not passed:
```
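
Taken together, the two `code_generator.py` tweaks follow one pattern: unwrap `constexpr` payloads only where a raw Python value is needed (the identity checks and the `static_assert` bool check), and keep comparison results wrapped so later compile-time operations still work. A simplified, self-contained model of that flow (hypothetical helper `compare_is`; not the real `CodeGenerator`):

```python
# Simplified model of the two code_generator.py tweaks above; names mirror the
# diff, but this is NOT the real CodeGenerator.
class Constexpr:
    def __init__(self, value):
        self.value = value


def _unwrap_if_constexpr(x):
    return x.value if isinstance(x, Constexpr) else x


def compare_is(lhs, rhs, negate=False):
    # `is` / `is not` are decided on the unwrapped Python values...
    lhs_value, rhs_value = _unwrap_if_constexpr(lhs), _unwrap_if_constexpr(rhs)
    result = (lhs_value is not rhs_value) if negate else (lhs_value is rhs_value)
    # ...but the result is re-wrapped so chained compile-time ops keep working.
    return Constexpr(result)


def execute_static_assert(cond):
    # Comparisons now yield Constexpr(True/False), so unwrap before the bool check.
    passed = _unwrap_if_constexpr(cond)
    if not isinstance(passed, bool):
        raise NotImplementedError("condition must be known at compile time")
    if not passed:
        raise AssertionError("static assertion failed")


execute_static_assert(compare_is(Constexpr(None), None))  # passes: None is None
```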
