Skip to content

Commit

Permalink
Fix the unstable behavior of 'test_where_warning'
Browse files Browse the repository at this point in the history
An unexpected 'test_where_warning' failure occurred if the test was restarted.
This happens because the compiler uses the cache and does not write a warning.
The environment parameter 'TRITON_ALWAYS_COMPILE' is now used to skip the cache.

Signed-off-by: Kirill Suvorov <[email protected]>
  • Loading branch information
Retribution98 committed Sep 18, 2024
1 parent f4c48a9 commit 769e49c
Showing 1 changed file with 5 additions and 0 deletions.
5 changes: 5 additions & 0 deletions python/test/unit/language/test_compile_errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -350,8 +350,13 @@ def kernel():
c = tl.full((64, ), 2, tl.float32)
tl.where(a, b, c)

always_compile_default = os.getenv('TRITON_ALWAYS_COMPILE', '0')
# Set this variable because the cache should not be used
os.environ['TRITON_ALWAYS_COMPILE'] = '1'
with pytest.warns(UserWarning):
triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constants={}))
# Return the original value
os.environ['TRITON_ALWAYS_COMPILE'] = always_compile_default


@pytest.mark.parametrize("dtype", [tl.float8e5, tl.float8e5b16, tl.float8e4nv, tl.float8e4b8, tl.float8e4b15])
Expand Down

0 comments on commit 769e49c

Please sign in to comment.