From 72734f086b3a70a0399b7e9d21b83d5d8dc7e1d5 Mon Sep 17 00:00:00 2001 From: Jez Ng Date: Tue, 30 Jul 2024 16:10:04 -0400 Subject: [PATCH] [TEST] Remove unnecessary capability query in test_subproc.py (#4432) It's an unused parameter. Seems like it was used in the first iteration of this test, but that changed along the way. Removing this query makes the test portable across backends. --- python/test/unit/runtime/test_subproc.py | 16 ++++------------ 1 file changed, 4 insertions(+), 12 deletions(-) diff --git a/python/test/unit/runtime/test_subproc.py b/python/test/unit/runtime/test_subproc.py index 9599cbd6adab..7240fb7bb562 100644 --- a/python/test/unit/runtime/test_subproc.py +++ b/python/test/unit/runtime/test_subproc.py @@ -1,8 +1,6 @@ import multiprocessing import shutil -import torch - import triton import triton.language as tl from triton.compiler import ASTSource @@ -10,7 +8,7 @@ target = triton.runtime.driver.active.get_current_target() -def compile_fn(attrs, capability): +def compile_fn(attrs): @triton.jit def kernel_sub(a, b, o, N: tl.constexpr): @@ -27,18 +25,15 @@ def kernel_sub(a, b, o, N: tl.constexpr): def test_compile_in_subproc() -> None: - major, minor = torch.cuda.get_device_capability(0) - cc = major * 10 + minor config = triton.compiler.AttrsDescriptor(tuple(range(4)), ()) - multiprocessing.set_start_method('fork') - proc = multiprocessing.Process(target=compile_fn, args=(config, cc)) + proc = multiprocessing.Process(target=compile_fn, args=(config, )) proc.start() proc.join() assert proc.exitcode == 0 -def compile_fn_dot(attrs, capability): +def compile_fn_dot(attrs): @triton.jit def kernel_dot(Z): @@ -52,12 +47,9 @@ def kernel_dot(Z): def test_compile_in_forked_subproc(fresh_triton_cache) -> None: - major, minor = torch.cuda.get_device_capability(0) - capability = major * 10 + minor config = triton.compiler.AttrsDescriptor(tuple(range(1)), ()) - assert multiprocessing.get_start_method() == 'fork' - proc = multiprocessing.Process(target=compile_fn_dot, args=(config, capability)) + proc = multiprocessing.Process(target=compile_fn_dot, args=(config, )) proc.start() proc.join() assert proc.exitcode == 0