Skip to content

Commit

Permalink
[testing] Complete test_const with the rest of the failing tests. (#4599
Browse files Browse the repository at this point in the history
)

As per title
  • Loading branch information
lezcano committed Aug 29, 2024
1 parent 23c744b commit c76f7a1
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 7 deletions.
17 changes: 12 additions & 5 deletions python/test/unit/language/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -3606,12 +3606,19 @@ def kernel_constexpr(in_ptr: tl.const, out, c_out: tl.const, choose_const: tl.co
with pytest.raises(triton.CompilationError) as exc_info:
patched_kernel[(1, )](input, output, output, choose_const, SIZE, SIZE)
if constexpr:
assert "Cannot store to a constant pointer" in str(exc_info.value.__cause__), "Wrong error message!"
elif not constexpr and mode == "call":
assert "Inconsistent return types" in str(exc_info.value.__cause__), "Wrong error message!"
error = "Cannot store to a constant pointer"
else:
# TODO: Add error messages for the other cases
pass
if mode == "call":
error = "Inconsistent return types"
elif mode == "if":
error = "Mismatched type for final_out"
elif mode == "ternary":
error = "Ternary expression with dynamic condition has inconsistent type"
else:
assert mode == "direct" and choose_const
error = "Cannot store to a constant pointer"
error_msg = exc_info.value.error_message or str(exc_info.value.__cause__)
assert error in error_msg, "Wrong error message!"
else:
patched_kernel[(1, )](input, output, output, choose_const, SIZE, SIZE)
assert torch.all(input == output)
Expand Down
4 changes: 2 additions & 2 deletions python/triton/compiler/code_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -609,7 +609,7 @@ def visit_then_else_blocks(self, node, liveins, then_block, else_block):
then_ty = then_defs[name].type
else_ty = else_defs[name].type
assert then_ty == else_ty, \
f'mismatched type for {name} between then block ({then_ty}) '\
f'Mismatched type for {name} between then block ({then_ty}) '\
f'and else block ({else_ty})'
names.append(name)
ret_types.append(then_ty)
Expand Down Expand Up @@ -731,7 +731,7 @@ def visit_IfExp(self, node):
self._set_insertion_point_and_loc(ip, last_loc)

assert then_val.type == else_val.type, \
f'ternary expression with dynamic condition has inconsistent types {then_val.type} and {else_val.type}'
f'Ternary expression with dynamic condition has inconsistent types {then_val.type} and {else_val.type}'
ret_type = then_val.type

ret_type_ir = [ret_type.to_ir(self.builder)] if ret_type != language.void else []
Expand Down

0 comments on commit c76f7a1

Please sign in to comment.