Skip to content

Commit

Permalink
[refactor] Support static short circuit bool operations (#3958)
Browse files Browse the repository at this point in the history
  • Loading branch information
lin-hitonami authored Jan 7, 2022
1 parent ef6237a commit ec54b89
Show file tree
Hide file tree
Showing 2 changed files with 48 additions and 7 deletions.
37 changes: 30 additions & 7 deletions python/taichi/lang/ast/ast_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -591,16 +591,39 @@ def inner(operands):

return inner

@staticmethod
def build_static_short_circuit_and(operands):
for operand in operands:
if not operand.ptr:
return operand.ptr
return operands[-1].ptr

@staticmethod
def build_static_short_circuit_or(operands):
for operand in operands:
if operand.ptr:
return operand.ptr
return operands[-1].ptr

@staticmethod
def build_BoolOp(ctx, node):
build_stmts(ctx, node.values)
ops = {
ast.And: ASTTransformer.build_short_circuit_and,
ast.Or: ASTTransformer.build_short_circuit_or,
} if impl.get_runtime().short_circuit_operators else {
ast.And: ASTTransformer.build_normal_bool_op(ti_ops.logical_and),
ast.Or: ASTTransformer.build_normal_bool_op(ti_ops.logical_or),
}
if ctx.is_in_static_scope:
ops = {
ast.And: ASTTransformer.build_static_short_circuit_and,
ast.Or: ASTTransformer.build_static_short_circuit_or,
}
elif impl.get_runtime().short_circuit_operators:
ops = {
ast.And: ASTTransformer.build_short_circuit_and,
ast.Or: ASTTransformer.build_short_circuit_or,
}
else:
ops = {
ast.And:
ASTTransformer.build_normal_bool_op(ti_ops.logical_and),
ast.Or: ASTTransformer.build_normal_bool_op(ti_ops.logical_or),
}
op = ops.get(type(node.op))
node.ptr = op(node.values)
return node.ptr
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,3 +47,21 @@ def func() -> ti.i32:
return False or True

assert func() == 1


@ti.test(debug=True)
def test_static_or():
@ti.kernel
def func() -> ti.i32:
return ti.static(0 or 3 or 5)

assert func() == 3


@ti.test(debug=True)
def test_static_and():
@ti.kernel
def func() -> ti.i32:
return ti.static(5 and 2 and 0)

assert func() == 0

0 comments on commit ec54b89

Please sign in to comment.