Skip to content

Commit

Permalink
Fix the unstable behavior of 'test_where_warning'
Browse files Browse the repository at this point in the history
An unexpected 'test_where_warning' failure occurred if the test was restarted.
This happens because the compiler uses the cache and does not write a warning.
The environment parameter 'TRITON_ALWAYS_COMPILE' is now used to skip the cache.

Signed-off-by: Kirill Suvorov <[email protected]>
  • Loading branch information
Retribution98 committed Sep 18, 2024
1 parent f4c48a9 commit a2d73b3
Showing 1 changed file with 5 additions and 3 deletions.
8 changes: 5 additions & 3 deletions python/test/unit/language/test_compile_errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -341,7 +341,7 @@ def kernel(a=1, B: tl.constexpr = ""):
triton.compile(triton.compiler.ASTSource(fn=kernel, signature={'a': 'i32'}, constants={'B': ""}))


def test_where_warning():
def test_where_warning(monkeypatch):

@triton.jit
def kernel():
Expand All @@ -350,8 +350,10 @@ def kernel():
c = tl.full((64, ), 2, tl.float32)
tl.where(a, b, c)

with pytest.warns(UserWarning):
triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constants={}))
with monkeypatch.context() as m:
m.setenv("TRITON_ALWAYS_COMPILE", "1")
with pytest.warns(UserWarning):
triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constants={}))


@pytest.mark.parametrize("dtype", [tl.float8e5, tl.float8e5b16, tl.float8e4nv, tl.float8e4b8, tl.float8e4b15])
Expand Down

0 comments on commit a2d73b3

Please sign in to comment.