Description

I tried to run tests/fft_test.py::FftTest::testFftfreq5 on a g5.24xlarge instance, which has 4 GPU devices.

The test works fine when executed individually:

pytest -s -v tests/fft_test.py::FftTest::testFftfreq5

But it fails consistently when executed as part of the full tests/fft_test.py run, even in single pytest worker mode. The problem is that the x1 and x2 arguments end up on different devices (device 0 and device 3):
pytest -n 1 -s -v tests/fft_test.py

fun = <function true_divide at 0x796a1ebc3d00>
jit_info = PjitInfo(fun_sourceinfo='true_divide at /home/ubuntu/workspace/jax/jax/_src/numpy/ufuncs.py:2292', fun_signature=<Sign...e, backend=None, keep_unused=False, inline=True, abstracted_axes=None, use_resource_env=False, compiler_options_kvs=())
args = (Array([ 0., 1., 2., 3., 4., -5., -4., -3., -2., -1.], dtype=float32), Array(1., dtype=float32)), kwargs = {}
p = PjitParams(consts=[], params={'jaxpr': { lambda ; a:f32[10] b:f32[]. let c:f32[10] = div a b in (c,) }, 'in_shardings'...*), {})), out_tree=PyTreeDef(*), donated_invars=(False, False), arg_names=('x1', 'x2'), num_consts=0, attrs_tracked=[])
args_flat = [Array([ 0., 1., 2., 3., 4., -5., -4., -3., -2., -1.], dtype=float32), Array(1., dtype=float32)]
arg = Array(1., dtype=float32)
fails = [DeviceAssignmentMismatch(da=(CudaDevice(id=0),), m_type=<MismatchType.ARG_SHARDING: 0>, source_info=None), DeviceAssignmentMismatch(da=(CudaDevice(id=3),), m_type=<MismatchType.ARG_SHARDING: 0>, source_info=None)]
api_name = 'jit', fun_name = 'true_divide'
msg = 'Received incompatible devices for jitted computation. Got argument x1 of true_divide with shape float32[10] and device ids [0] on platform GPU and argument x2 of true_divide with shape float32[] and device ids [3] on platform GPU'

    def _python_pjit_helper(fun, jit_info, *args, **kwargs):
      p, args_flat = _infer_params(fun, jit_info, args, kwargs)
      for arg in args_flat:
        dispatch.check_arg(arg)
      if p.attrs_tracked:
        init_states = _get_states(p.attrs_tracked)
        args_flat = [*init_states, *args_flat]
      try:
        out_flat = pjit_p.bind(*args_flat, **p.params)
      except pxla.DeviceAssignmentMismatchError as e:
        fails, = e.args
        api_name = 'jit' if p.params['resource_env'] is None else 'pjit'
        fun_name = getattr(fun, '__qualname__', getattr(fun, '__name__', str(fun)))
        msg = _device_assignment_mismatch_error(
            fun_name, fails, args_flat, api_name, p.arg_names)
>       raise ValueError(msg) from None
E       ValueError: Received incompatible devices for jitted computation. Got argument x1 of true_divide with shape float32[10] and device ids [0] on platform GPU and argument x2 of true_divide with shape float32[] and device ids [3] on platform GPU

jax/_src/pjit.py:195: ValueError
================================================== short test summary info ==================================================
FAILED tests/fft_test.py::FftTest::testFftfreq5 - ValueError: Received incompatible devices for jitted computation. Got argument x1 of true_divide with shape float32[10] and device ids [0] on platform GPU and argument x2 of true_divid...
========================================= 1 failed, 96 passed, 2 skipped in 13.86s ==========================================
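For reference, the same class of error can be reproduced directly, outside the test suite. This is a hypothetical minimal sketch (the array values, device indices, and variable names are illustrative and assume at least 4 visible GPUs), not the code path the test actually takes:

import jax
import jax.numpy as jnp

# Commit the two operands of a jitted binary op to different GPUs;
# jit then cannot pick a single device assignment and raises the same error.
devs = jax.devices()  # assumes >= 4 GPU devices are visible
x1 = jax.device_put(jnp.arange(10, dtype=jnp.float32), devs[0])
x2 = jax.device_put(jnp.array(1.0, dtype=jnp.float32), devs[3])
jnp.true_divide(x1, x2)  # ValueError: Received incompatible devices for jitted computation ...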
System info (python version, jaxlib version, accelerator, etc.)

>>> import jax; jax.print_environment_info()
jax:    0.4.36.dev20241007+86038f84e
jaxlib: 0.4.35
numpy:  2.1.3
python: 3.10.12 (main, Sep 11 2024, 15:47:36) [GCC 11.4.0]
device info: NVIDIA A10G-4, 4 local devices
process_count: 1
platform: uname_result(system='Linux', node='ip-172-31-15-167', release='6.8.0-1018-aws', version='#19~22.04.1-Ubuntu SMP Wed Oct 9 16:48:22 UTC 2024', machine='x86_64')

$ nvidia-smi
Fri Nov  8 18:29:02 2024
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 560.35.03              Driver Version: 560.35.03      CUDA Version: 12.6     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|=========================================+========================+======================|
|   0  NVIDIA A10G                    Off |   00000000:00:1B.0 Off |                    0 |
|  0%   19C    P0             29W /  300W |     259MiB /  23028MiB |      2%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
|   1  NVIDIA A10G                    Off |   00000000:00:1C.0 Off |                    0 |
|  0%   19C    P0             27W /  300W |     259MiB /  23028MiB |      2%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
|   2  NVIDIA A10G                    Off |   00000000:00:1D.0 Off |                    0 |
|  0%   20C    P0             29W /  300W |     259MiB /  23028MiB |      1%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
|   3  NVIDIA A10G                    Off |   00000000:00:1E.0 Off |                    0 |
|  0%   19C    P0             27W /  300W |     259MiB /  23028MiB |      1%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+

+-----------------------------------------------------------------------------------------+
| Processes:                                                                              |
|  GPU   GI   CI        PID   Type   Process name                              GPU Memory |
|        ID   ID                                                               Usage      |
|=========================================================================================|
|    0   N/A  N/A    537624      C   python3                                       250MiB |
|    1   N/A  N/A    537624      C   python3                                       250MiB |
|    2   N/A  N/A    537624      C   python3                                       250MiB |
|    3   N/A  N/A    537624      C   python3                                       250MiB |
+-----------------------------------------------------------------------------------------+
We've been debugging this after seeing it in our own CI. Fix coming soon, hopefully.
Thank you! Another test that fails on the 4-GPU setup is described in #24796. @hawkinsp
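For anyone who hits this while the fix is pending, a quick way to confirm where each operand is committed, and a generic way to co-locate them by hand, looks roughly like the sketch below. The names x1 and x2 are placeholders rather than the values the test builds, and this is a general workaround pattern, not the upstream fix:

import jax
import jax.numpy as jnp

# Placeholder operands; in the failing test they arrive already committed
# to different GPUs by earlier tests in the same process.
x1 = jnp.arange(10, dtype=jnp.float32)
x2 = jnp.array(1.0, dtype=jnp.float32)

print(x1.devices(), x2.devices())  # shows which device(s) each array lives on

# Explicitly moving both operands to one device sidesteps the
# "incompatible devices" error for a jitted op like true_divide.
dev = jax.devices()[0]
out = jnp.true_divide(jax.device_put(x1, dev), jax.device_put(x2, dev))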