diff --git a/tests/ignite/conftest.py b/tests/ignite/conftest.py index 4e6712c43cf..265ae97e3e7 100644 --- a/tests/ignite/conftest.py +++ b/tests/ignite/conftest.py @@ -449,6 +449,16 @@ def distributed(request, local_rank, world_size): @pytest.hookimpl def pytest_pyfunc_call(pyfuncitem: pytest.Function) -> None: + if any(fx in pyfuncitem.fixturenames for fx in ["distributed", "multinode_distributed"]): + # Run distributed tests on a single worker to avoid RACE conditions + # This requires that the --dist=loadgroup option be passed to pytest. + pyfuncitem.add_marker(pytest.mark.xdist_group("distributed")) + # Add timeouts to prevent hanging + if "tpu" in pyfuncitem.fixturenames: + pyfuncitem.add_marker(pytest.mark.timeout(60)) + else: + pyfuncitem.add_marker(pytest.mark.timeout(45)) + if pyfuncitem.stash.get(is_horovod_stash_key, False): def testfunc_wrapper(test_func, **kwargs): @@ -498,14 +508,3 @@ def xla_worker(index, fn): assert ex_.code == 0, "Didn't successfully exit in XLA test" pyfuncitem.obj = functools.partial(testfunc_wrapper, pyfuncitem.obj) - - -def pytest_collection_modifyitems(items): - for item in items: - if "distributed" in item.fixturenames: - # Run distributed tests on a single worker to avoid RACE conditions - # This requires that the --dist=loadgroup option be passed to pytest. - item.add_marker(pytest.mark.xdist_group("distributed")) - item.add_marker(pytest.mark.timeout(45)) - if "multinode_distributed" in item.fixturenames: - item.add_marker(pytest.mark.timeout(45))