From c76f7a12f915d6bf32ed68e29294f454e9dc2d8e Mon Sep 17 00:00:00 2001 From: Mario Lezcano Casado <3291265+lezcano@users.noreply.github.com> Date: Thu, 29 Aug 2024 18:48:48 +0100 Subject: [PATCH] [testing] Complete test_const with the rest of the failing tests. (#4599) As per title --- python/test/unit/language/test_core.py | 17 ++++++++++++----- python/triton/compiler/code_generator.py | 4 ++-- 2 files changed, 14 insertions(+), 7 deletions(-) diff --git a/python/test/unit/language/test_core.py b/python/test/unit/language/test_core.py index 9e5ff8a2ce37..c496ecff8e19 100644 --- a/python/test/unit/language/test_core.py +++ b/python/test/unit/language/test_core.py @@ -3606,12 +3606,19 @@ def kernel_constexpr(in_ptr: tl.const, out, c_out: tl.const, choose_const: tl.co with pytest.raises(triton.CompilationError) as exc_info: patched_kernel[(1, )](input, output, output, choose_const, SIZE, SIZE) if constexpr: - assert "Cannot store to a constant pointer" in str(exc_info.value.__cause__), "Wrong error message!" - elif not constexpr and mode == "call": - assert "Inconsistent return types" in str(exc_info.value.__cause__), "Wrong error message!" + error = "Cannot store to a constant pointer" else: - # TODO: Add error messages for the other cases - pass + if mode == "call": + error = "Inconsistent return types" + elif mode == "if": + error = "Mismatched type for final_out" + elif mode == "ternary": + error = "Ternary expression with dynamic condition has inconsistent type" + else: + assert mode == "direct" and choose_const + error = "Cannot store to a constant pointer" + error_msg = exc_info.value.error_message or str(exc_info.value.__cause__) + assert error in error_msg, "Wrong error message!" else: patched_kernel[(1, )](input, output, output, choose_const, SIZE, SIZE) assert torch.all(input == output) diff --git a/python/triton/compiler/code_generator.py b/python/triton/compiler/code_generator.py index ee3426bd4c65..199953df6570 100644 --- a/python/triton/compiler/code_generator.py +++ b/python/triton/compiler/code_generator.py @@ -609,7 +609,7 @@ def visit_then_else_blocks(self, node, liveins, then_block, else_block): then_ty = then_defs[name].type else_ty = else_defs[name].type assert then_ty == else_ty, \ - f'mismatched type for {name} between then block ({then_ty}) '\ + f'Mismatched type for {name} between then block ({then_ty}) '\ f'and else block ({else_ty})' names.append(name) ret_types.append(then_ty) @@ -731,7 +731,7 @@ def visit_IfExp(self, node): self._set_insertion_point_and_loc(ip, last_loc) assert then_val.type == else_val.type, \ - f'ternary expression with dynamic condition has inconsistent types {then_val.type} and {else_val.type}' + f'Ternary expression with dynamic condition has inconsistent types {then_val.type} and {else_val.type}' ret_type = then_val.type ret_type_ir = [ret_type.to_ir(self.builder)] if ret_type != language.void else []