
Flaky test tests/fft_test.py::FftTest::testFftfreq5 #24798

Open · apivovarov opened this issue Nov 8, 2024 · 2 comments
Labels: bug

apivovarov (Contributor)

Description
Description

I ran tests/fft_test.py::FftTest::testFftfreq5 on a g5.24xlarge instance, which has 4 GPU devices.

The test passes when executed individually:

pytest -s -v tests/fft_test.py::FftTest::testFftfreq5

But the test consistently fails when executed as part of the full tests/fft_test.py run (even with a single pytest worker):

pytest -n 1 -s -v tests/fft_test.py

The problem is that the x1 and x2 arguments end up on different devices, dev0 and dev3:
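For context, jit refuses to run when its arguments are committed to different devices. The sketch below is illustrative only (it is not the code the test actually runs): it uses `jax.device_put` to commit the two operands to the first and last visible devices, which reproduces the same class of `ValueError` whenever more than one device is present.

```python
import jax
import jax.numpy as jnp

# Illustrative sketch, not the test's actual code. The device choices
# (first and last visible device) are assumptions for demonstration.
devs = jax.devices()
x1 = jax.device_put(jnp.arange(10, dtype=jnp.float32), devs[0])
x2 = jax.device_put(jnp.float32(1.0), devs[-1])

if len(devs) > 1:
    # x1 and x2 are committed to different devices, so the jitted
    # true_divide raises "Received incompatible devices ...".
    try:
        jnp.true_divide(x1, x2)
    except ValueError as e:
        print("raised:", e)
else:
    # With a single device the placements agree and the division succeeds.
    print(jnp.true_divide(x1, x2))
```

Committing both operands to the same device (e.g. `jax.device_put(x2, devs[0])`) makes the division succeed, which is why the test passes in isolation: nothing has previously committed arrays to other devices.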

fun = <function true_divide at 0x796a1ebc3d00>
jit_info = PjitInfo(fun_sourceinfo='true_divide at /home/ubuntu/workspace/jax/jax/_src/numpy/ufuncs.py:2292', fun_signature=<Sign...e, backend=None, keep_unused=False, inline=True, abstracted_axes=None, use_resource_env=False, compiler_options_kvs=())
args = (Array([ 0.,  1.,  2.,  3.,  4., -5., -4., -3., -2., -1.], dtype=float32), Array(1., dtype=float32)), kwargs = {}
p = PjitParams(consts=[], params={'jaxpr': { lambda ; a:f32[10] b:f32[]. let c:f32[10] = div a b in (c,) }, 'in_shardings'...*), {})), out_tree=PyTreeDef(*), donated_invars=(False, False), arg_names=('x1', 'x2'), num_consts=0, attrs_tracked=[])
args_flat = [Array([ 0.,  1.,  2.,  3.,  4., -5., -4., -3., -2., -1.], dtype=float32), Array(1., dtype=float32)], arg = Array(1., dtype=float32)
fails = [DeviceAssignmentMismatch(da=(CudaDevice(id=0),), m_type=<MismatchType.ARG_SHARDING: 0>, source_info=None), DeviceAssignmentMismatch(da=(CudaDevice(id=3),), m_type=<MismatchType.ARG_SHARDING: 0>, source_info=None)]
api_name = 'jit', fun_name = 'true_divide'
msg = 'Received incompatible devices for jitted computation. Got argument x1 of true_divide with shape float32[10] and device ids [0] on platform GPU and argument x2 of true_divide with shape float32[] and device ids [3] on platform GPU'

    def _python_pjit_helper(fun, jit_info, *args, **kwargs):
      p, args_flat = _infer_params(fun, jit_info, args, kwargs)
    
      for arg in args_flat:
        dispatch.check_arg(arg)
    
      if p.attrs_tracked:
        init_states = _get_states(p.attrs_tracked)
        args_flat = [*init_states, *args_flat]
    
      try:
        out_flat = pjit_p.bind(*args_flat, **p.params)
      except pxla.DeviceAssignmentMismatchError as e:
        fails, = e.args
        api_name = 'jit' if p.params['resource_env'] is None else 'pjit'
        fun_name = getattr(fun, '__qualname__', getattr(fun, '__name__', str(fun)))
        msg = _device_assignment_mismatch_error(
            fun_name, fails, args_flat, api_name, p.arg_names)
>       raise ValueError(msg) from None
E       ValueError: Received incompatible devices for jitted computation. Got argument x1 of true_divide with shape float32[10] and device ids [0] on platform GPU and argument x2 of true_divide with shape float32[] and device ids [3] on platform GPU

jax/_src/pjit.py:195: ValueError
================================================================================== short test summary info ==================================================================================
FAILED tests/fft_test.py::FftTest::testFftfreq5 - ValueError: Received incompatible devices for jitted computation. Got argument x1 of true_divide with shape float32[10] and device ids [0] on platform GPU and argument x2 of true_divid...
========================================================================= 1 failed, 96 passed, 2 skipped in 13.86s ==========================================================================

System info (python version, jaxlib version, accelerator, etc.)

>>> import jax; jax.print_environment_info()
jax:    0.4.36.dev20241007+86038f84e
jaxlib: 0.4.35
numpy:  2.1.3
python: 3.10.12 (main, Sep 11 2024, 15:47:36) [GCC 11.4.0]
device info: NVIDIA A10G-4, 4 local devices
process_count: 1
platform: uname_result(system='Linux', node='ip-172-31-15-167', release='6.8.0-1018-aws', version='#19~22.04.1-Ubuntu SMP Wed Oct  9 16:48:22 UTC 2024', machine='x86_64')


$ nvidia-smi
Fri Nov  8 18:29:02 2024       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 560.35.03              Driver Version: 560.35.03      CUDA Version: 12.6     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|=========================================+========================+======================|
|   0  NVIDIA A10G                    Off |   00000000:00:1B.0 Off |                    0 |
|  0%   19C    P0             29W /  300W |     259MiB /  23028MiB |      2%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
|   1  NVIDIA A10G                    Off |   00000000:00:1C.0 Off |                    0 |
|  0%   19C    P0             27W /  300W |     259MiB /  23028MiB |      2%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
|   2  NVIDIA A10G                    Off |   00000000:00:1D.0 Off |                    0 |
|  0%   20C    P0             29W /  300W |     259MiB /  23028MiB |      1%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
|   3  NVIDIA A10G                    Off |   00000000:00:1E.0 Off |                    0 |
|  0%   19C    P0             27W /  300W |     259MiB /  23028MiB |      1%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
                                                                                         
+-----------------------------------------------------------------------------------------+
| Processes:                                                                              |
|  GPU   GI   CI        PID   Type   Process name                              GPU Memory |
|        ID   ID                                                               Usage      |
|=========================================================================================|
|    0   N/A  N/A    537624      C   python3                                       250MiB |
|    1   N/A  N/A    537624      C   python3                                       250MiB |
|    2   N/A  N/A    537624      C   python3                                       250MiB |
|    3   N/A  N/A    537624      C   python3                                       250MiB |
+-----------------------------------------------------------------------------------------+
apivovarov added the bug label on Nov 8, 2024
hawkinsp (Collaborator) commented Nov 8, 2024

We've been debugging this after seeing it in our own CI. Fix coming soon, hopefully.

apivovarov (Contributor, Author)

Thank you! Another test that fails on the 4-GPU setup is described in #24796.
@hawkinsp
