Nightly builds on Mac ARM fail complex function numerical tests #24787

Open
hawkinsp opened this issue Nov 8, 2024 · 2 comments
Labels
bug Something isn't working

@hawkinsp
Collaborator

hawkinsp commented Nov 8, 2024

Description

On Mac ARM, the following functions are failing in the nightly jax/jaxlib build:

________ FunctionAccuracyTest.testSuccessOnComplexPlane_log1p_complex64 ________
[gw7] darwin -- Python 3.10.13 /Users/kbuilder/.jax-pyenv/versions/3.10.13/bin/python
tests/lax_test.py:4212: in testSuccessOnComplexPlane
    self._testOnComplexPlaneWorker(name, dtype, 'success')
tests/lax_test.py:4438: in _testOnComplexPlaneWorker
    self.assertAllClose(
jax/_src/test_util.py:1263: in assertAllClose
    self.assertArraysAllClose(x, y, check_dtypes=False, atol=atol, rtol=rtol,
jax/_src/test_util.py:1228: in assertArraysAllClose
    _assert_numpy_allclose(x, y, atol=atol, rtol=rtol, err_msg=err_msg)
jax/_src/public_test_util.py:128: in _assert_numpy_allclose
    np.testing.assert_allclose(a, b, **kw, err_msg=err_msg)
/Users/kbuilder/.jax-pyenv/versions/3.10.13/lib/python3.10/contextlib.py:79: in inner
    return func(*args, **kwds)
E   AssertionError: 
E   Not equal to tolerance rtol=1e-06, atol=1e-06
E   log1p in q1.imag, is_cpu=True is_cuda=False,
E   jax.numpy.log1p((4.517187335295603e-08+1.1754943508222875e-38j)) -> (4.517187335295603e-08+0j) [(0.7578582763671875+0j)], expected (4.517187335295603e-08+1.1754943508222875e-38j) [(0.7578582763671875+1.9721522630525295e-31j)]
E   jax.numpy.log1p((4.517187335295603e-08+5.204541205073606e-31j)) -> (4.517187335295603e-08+5.204541205073606e-31j) [(0.7578582763671875+8.731771197842018e-24j)], expected (4.517187335295603e-08+5.204540734875866e-31j) [(0.7578582763671875+8.731770408981113e-24j)]
E   jax.numpy.log1p((1.0202490673854366e-15+1.1754943508222875e-38j)) -> (1.0202490673854366e-15+0j) [(0.5743491649627686+0j)], expected (1.0202490673854366e-15+1.1754943508222875e-38j) [(0.5743491649627686+6.617444900424222e-24j)]
E   jax.numpy.log1p((4.517187335295603e-08+2.3043281796057088e-23j)) -> (4.517187335295603e-08+2.3043281796057088e-23j) [(0.7578582763671875+3.866021160413177e-16j)], expected (4.517187335295603e-08+2.3043280218335278e-23j) [(0.7578582763671875+3.866020895715381e-16j)]
E   jax.numpy.log1p((2.3043281796057088e-23+1.1754943508222875e-38j)) -> (2.3043281796057088e-23+0j) [(0.8705505728721619+0j)], expected (2.3043281796057088e-23+1.1754943508222875e-38j) [(0.8705505728721619+4.440892098500626e-16j)]
E   jax.numpy.log1p((5.204541205073606e-31+1.1754943508222875e-38j)) -> (5.204541205073606e-31+0j) [(0.6597539782524109+0j)], expected (5.204541205073606e-31+1.1754943508222875e-38j) [(0.6597539782524109+1.4901161193847656e-08j)]
E   jax.numpy.log1p((4.517187335295603e-08+4.517187335295603e-08j)) -> (4.517187335295603e-08+4.517187335295603e-08j) [(0.37892913818359375+0.37892913818359375j)], expected (4.517187335295603e-08+4.517186980024235e-08j) [(0.37892913818359375+0.37892910838127136j)]
E   jax.numpy.log1p((4.517187335295603e-08+2j)) -> (0.8047189712524414+1.1071487665176392j) [(0.4023594856262207+0.5535743832588196j)], expected (0.8047189712524414+1.1071486473083496j) [(0.4023594856262207+0.5535743236541748j)]
E   jax.numpy.log1p((1.1754943508222875e-38+1.1754943508222875e-38j)) -> (1.1754943508222875e-38+0j) [(0.5+0j)], expected (1.1754943508222875e-38+1.1754943508222875e-38j) [(0.5+0.5j)]
E   Mismatched elements: 1 / 121 (0.826%)
E   Max absolute difference among violations: 0.5
E   Max relative difference among violations: 1.
E    ACTUAL: array([[0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00,
E           0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00,
E           0.000000e+00, 0.000000e+00, 0.000000e+00],...
E    DESIRED: array([[5.000000e-01, 1.490116e-08, 4.440892e-16, 6.617445e-24,
E           1.972152e-31, 0.000000e+00, 0.000000e+00, 0.000000e+00,
E           0.000000e+00, 0.000000e+00, 0.000000e+00],...
_______ FunctionAccuracyTest.testSuccessOnComplexPlane_square_complex64 ________
[gw7] darwin -- Python 3.10.13 /Users/kbuilder/.jax-pyenv/versions/3.10.13/bin/python
tests/lax_test.py:4212: in testSuccessOnComplexPlane
    self._testOnComplexPlaneWorker(name, dtype, 'success')
tests/lax_test.py:4438: in _testOnComplexPlaneWorker
    self.assertAllClose(
jax/_src/test_util.py:1263: in assertAllClose
    self.assertArraysAllClose(x, y, check_dtypes=False, atol=atol, rtol=rtol,
jax/_src/test_util.py:1228: in assertArraysAllClose
    _assert_numpy_allclose(x, y, atol=atol, rtol=rtol, err_msg=err_msg)
jax/_src/public_test_util.py:128: in _assert_numpy_allclose
    np.testing.assert_allclose(a, b, **kw, err_msg=err_msg)
/Users/kbuilder/.jax-pyenv/versions/3.10.13/lib/python3.10/contextlib.py:79: in inner
    return func(*args, **kwds)
E   AssertionError: 
E   Not equal to tolerance rtol=1e-06, atol=1e-06
E   square in ninfj.real, is_cpu=True is_cuda=False,
E   jax.numpy.square((-1.735863837493982e+23-infj)) -> (nan+infj) [(nan+infj)], expected (-inf+infj) [(-inf+infj)]
E   jax.numpy.square((-3.4028234663852886e+38-infj)) -> (nan+infj) [(nan+infj)], expected (-inf+infj) [(-inf+infj)]
E   jax.numpy.square((-7.685595991398373e+30-infj)) -> (nan+infj) [(nan+infj)], expected (-inf+infj) [(-inf+infj)]
E   jax.numpy.square((1.735863837493982e+23-infj)) -> (nan-infj) [(nan-infj)], expected (-inf-infj) [(-inf-infj)]
E   jax.numpy.square((3.4028234663852886e+38-infj)) -> (nan-infj) [(nan-infj)], expected (-inf-infj) [(-inf-infj)]
E   jax.numpy.square((7.685595991398373e+30-infj)) -> (nan-infj) [(nan-infj)], expected (-inf-infj) [(-inf-infj)]
E   nan location mismatch:
E    ACTUAL: array([ nan,  nan,  nan,  nan, -inf, -inf, -inf, -inf, -inf, -inf, -inf,
E          -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf,  nan,
E           nan,  nan,  nan], dtype=float32)
E    DESIRED: array([ nan, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf,
E          -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf,
E          -inf, -inf,  nan], dtype=float32)
_______ FunctionAccuracyTest.testSuccessOnComplexPlane_arcsin_complex64 ________
[gw6] darwin -- Python 3.10.13 /Users/kbuilder/.jax-pyenv/versions/3.10.13/bin/python
tests/lax_test.py:4212: in testSuccessOnComplexPlane
    self._testOnComplexPlaneWorker(name, dtype, 'success')
tests/lax_test.py:4438: in _testOnComplexPlaneWorker
    self.assertAllClose(
jax/_src/test_util.py:1263: in assertAllClose
    self.assertArraysAllClose(x, y, check_dtypes=False, atol=atol, rtol=rtol,
jax/_src/test_util.py:1228: in assertArraysAllClose
    _assert_numpy_allclose(x, y, atol=atol, rtol=rtol, err_msg=err_msg)
jax/_src/public_test_util.py:128: in _assert_numpy_allclose
    np.testing.assert_allclose(a, b, **kw, err_msg=err_msg)
/Users/kbuilder/.jax-pyenv/versions/3.10.13/lib/python3.10/contextlib.py:79: in inner
    return func(*args, **kwds)
E   AssertionError: 
E   Not equal to tolerance rtol=1e-06, atol=1e-06
E   arcsin in q1.real, is_cpu=True is_cuda=False,
E   jax.numpy.arcsin((1.1754943508222875e-38+4.517187335295603e-08j)) -> 4.517187335295603e-08j [0.7578582763671875j], expected (1.1754943508222875e-38+4.517187335295603e-08j) [(1.9721522630525295e-31+0.7578582763671875j)]
E   jax.numpy.arcsin((1.1754943508222875e-38+1.0202490673854366e-15j)) -> 1.0202490673854366e-15j [0.5743491649627686j], expected (1.1754943508222875e-38+1.0202490673854366e-15j) [(6.617444900424222e-24+0.5743491649627686j)]
E   jax.numpy.arcsin((1.0202490673854366e-15+2j)) -> (4.562692259942242e-16+1.4436354637145996j) [(2.281346129971121e-16+0.7218177318572998j)], expected (4.562692789337834e-16+1.4436354637145996j) [(2.281346394668917e-16+0.7218177318572998j)]
E   jax.numpy.arcsin((1.1754943508222875e-38+2.3043281796057088e-23j)) -> 2.3043281796057088e-23j [0.8705505728721619j], expected (1.1754943508222875e-38+2.3043281796057088e-23j) [(4.440892098500626e-16+0.8705505728721619j)]
E   jax.numpy.arcsin((1.1754943508222875e-38+5.204541205073606e-31j)) -> 5.204541205073606e-31j [0.6597539782524109j], expected (1.1754943508222875e-38+5.204541205073606e-31j) [(1.4901161193847656e-08+0.6597539782524109j)]
E   jax.numpy.arcsin((4.517187335295603e-08+4.517187335295603e-08j)) -> (4.5171876905669706e-08+4.517187335295603e-08j) [(0.37892916798591614+0.37892913818359375j)], expected (4.517187335295603e-08+4.517187335295603e-08j) [(0.37892913818359375+0.37892913818359375j)]
E   jax.numpy.arcsin((4.517187335295603e-08+1.0202490673854366e-15j)) -> (4.5171876905669706e-08+1.0202490673854366e-15j) [(0.7578583359718323+1.7116938977324025e-08j)], expected (4.517187335295603e-08+1.0202490673854366e-15j) [(0.7578582763671875+1.7116938977324025e-08j)]
E   jax.numpy.arcsin((4.517187335295603e-08+1.1754943508222875e-38j)) -> (4.5171876905669706e-08+1.175494490952134e-38j) [(0.7578583359718323+1.9721524981513997e-31j)], expected (4.517187335295603e-08+1.1754943508222875e-38j) [(0.7578582763671875+1.9721522630525295e-31j)]
E   jax.numpy.arcsin((4.517187335295603e-08+2.3043281796057088e-23j)) -> (4.5171876905669706e-08+2.30432833737789e-23j) [(0.7578583359718323+3.866021425110973e-16j)], expected (4.517187335295603e-08+2.3043281796057088e-23j) [(0.7578582763671875+3.866021160413177e-16j)]
E   jax.numpy.arcsin((4.517187335295603e-08+5.204541205073606e-31j)) -> (4.5171876905669706e-08+5.204541675271346e-31j) [(0.7578583359718323+8.731771986702923e-24j)], expected (4.517187335295603e-08+5.204541205073606e-31j) [(0.7578582763671875+8.731771197842018e-24j)]
E   jax.numpy.arcsin((1.1754943508222875e-38+1.1754943508222875e-38j)) -> 1.1754943508222875e-38j [0.5j], expected (1.1754943508222875e-38+1.1754943508222875e-38j) [(0.5+0.5j)]
E   Mismatched elements: 1 / 121 (0.826%)
E   Max absolute difference among violations: 0.5
E   Max relative difference among violations: 1.
E    ACTUAL: array([[0.000000e+00, 6.597540e-01, 8.705506e-01, 5.743492e-01,
E           7.578583e-01, 3.926991e-01, 4.908739e-02, 2.454369e-02,
E           2.454369e-02, 1.227185e-02, 1.227185e-02],...
E    DESIRED: array([[5.000000e-01, 6.597540e-01, 8.705506e-01, 5.743492e-01,
E           7.578583e-01, 3.926991e-01, 4.908739e-02, 2.454369e-02,
E           2.454369e-02, 1.227185e-02, 1.227185e-02],...
_______ FunctionAccuracyTest.testSuccessOnComplexPlane_arcsinh_complex64 _______
[gw6] darwin -- Python 3.10.13 /Users/kbuilder/.jax-pyenv/versions/3.10.13/bin/python
tests/lax_test.py:4212: in testSuccessOnComplexPlane
    self._testOnComplexPlaneWorker(name, dtype, 'success')
tests/lax_test.py:4438: in _testOnComplexPlaneWorker
    self.assertAllClose(
jax/_src/test_util.py:1263: in assertAllClose
    self.assertArraysAllClose(x, y, check_dtypes=False, atol=atol, rtol=rtol,
jax/_src/test_util.py:1228: in assertArraysAllClose
    _assert_numpy_allclose(x, y, atol=atol, rtol=rtol, err_msg=err_msg)
jax/_src/public_test_util.py:128: in _assert_numpy_allclose
    np.testing.assert_allclose(a, b, **kw, err_msg=err_msg)
/Users/kbuilder/.jax-pyenv/versions/3.10.13/lib/python3.10/contextlib.py:79: in inner
    return func(*args, **kwds)
E   AssertionError: 
E   Not equal to tolerance rtol=1e-06, atol=1e-06
E   arcsinh in q1.imag, is_cpu=True is_cuda=False,
E   jax.numpy.arcsinh((4.517187335295603e-08+1.1754943508222875e-38j)) -> (4.517187335295603e-08+0j) [(0.7578582763671875+0j)], expected (4.517187335295603e-08+1.1754943508222875e-38j) [(0.7578582763671875+1.9721522630525295e-31j)]
E   jax.numpy.arcsinh((1.0202490673854366e-15+1.1754943508222875e-38j)) -> (1.0202490673854366e-15+0j) [(0.5743491649627686+0j)], expected (1.0202490673854366e-15+1.1754943508222875e-38j) [(0.5743491649627686+6.617444900424222e-24j)]
E   jax.numpy.arcsinh((2+1.0202490673854366e-15j)) -> (1.4436354637145996+4.562692259942242e-16j) [(0.7218177318572998+2.281346129971121e-16j)], expected (1.4436354637145996+4.562692789337834e-16j) [(0.7218177318572998+2.281346394668917e-16j)]
E   jax.numpy.arcsinh((2.3043281796057088e-23+1.1754943508222875e-38j)) -> (2.3043281796057088e-23+0j) [(0.8705505728721619+0j)], expected (2.3043281796057088e-23+1.1754943508222875e-38j) [(0.8705505728721619+4.440892098500626e-16j)]
E   jax.numpy.arcsinh((5.204541205073606e-31+1.1754943508222875e-38j)) -> (5.204541205073606e-31+0j) [(0.6597539782524109+0j)], expected (5.204541205073606e-31+1.1754943508222875e-38j) [(0.6597539782524109+1.4901161193847656e-08j)]
E   jax.numpy.arcsinh((4.517187335295603e-08+4.517187335295603e-08j)) -> (4.517187335295603e-08+4.5171876905669706e-08j) [(0.37892913818359375+0.37892916798591614j)], expected (4.517187335295603e-08+4.517187335295603e-08j) [(0.37892913818359375+0.37892913818359375j)]
E   jax.numpy.arcsinh((1.0202490673854366e-15+4.517187335295603e-08j)) -> (1.0202490673854366e-15+4.5171876905669706e-08j) [(1.7116938977324025e-08+0.7578583359718323j)], expected (1.0202490673854366e-15+4.517187335295603e-08j) [(1.7116938977324025e-08+0.7578582763671875j)]
E   jax.numpy.arcsinh((1.1754943508222875e-38+4.517187335295603e-08j)) -> (1.175494490952134e-38+4.5171876905669706e-08j) [(1.9721524981513997e-31+0.7578583359718323j)], expected (1.1754943508222875e-38+4.517187335295603e-08j) [(1.9721522630525295e-31+0.7578582763671875j)]
E   jax.numpy.arcsinh((2.3043281796057088e-23+4.517187335295603e-08j)) -> (2.30432833737789e-23+4.5171876905669706e-08j) [(3.866021425110973e-16+0.7578583359718323j)], expected (2.3043281796057088e-23+4.517187335295603e-08j) [(3.866021160413177e-16+0.7578582763671875j)]
E   jax.numpy.arcsinh((5.204541205073606e-31+4.517187335295603e-08j)) -> (5.204541675271346e-31+4.5171876905669706e-08j) [(8.731771986702923e-24+0.7578583359718323j)], expected (5.204541205073606e-31+4.517187335295603e-08j) [(8.731771197842018e-24+0.7578582763671875j)]
E   jax.numpy.arcsinh((1.1754943508222875e-38+1.1754943508222875e-38j)) -> (1.1754943508222875e-38+0j) [(0.5+0j)], expected (1.1754943508222875e-38+1.1754943508222875e-38j) [(0.5+0.5j)]
E   Mismatched elements: 1 / 121 (0.826%)
E   Max absolute difference among violations: 0.5
E   Max relative difference among violations: 1.
E    ACTUAL: array([[0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00,
E           0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00,
E           0.000000e+00, 0.000000e+00, 0.000000e+00],...
E    DESIRED: array([[5.000000e-01, 1.490116e-08, 4.440892e-16, 6.617445e-24,
E           1.972152e-31, 0.000000e+00, 0.000000e+00, 0.000000e+00,
E           0.000000e+00, 0.000000e+00, 0.000000e+00],...
=========================== short test summary info ============================
FAILED tests/lax_test.py::FunctionAccuracyTest::testSuccessOnComplexPlane_log1p_complex64
FAILED tests/lax_test.py::FunctionAccuracyTest::testSuccessOnComplexPlane_square_complex64
FAILED tests/lax_test.py::FunctionAccuracyTest::testSuccessOnComplexPlane_arcsin_complex64
FAILED tests/lax_test.py::FunctionAccuracyTest::testSuccessOnComplexPlane_arcsinh_complex64
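For context, the check that fails is `np.testing.assert_allclose` with `rtol=1e-06, atol=1e-06` (see the tracebacks above), which accepts an element only when |actual - desired| <= atol + rtol * |desired|. The sketch below plugs in the first element of the log1p/arcsin/arcsinh ACTUAL and DESIRED arrays (0.0 vs 0.5) to show that one such element is enough to fail the whole comparison. It only illustrates the tolerance rule; it does not reproduce the JAX computation.

```python
import numpy as np

rtol = atol = 1e-6
# First element of the ACTUAL and DESIRED arrays in the log1p/arcsin/arcsinh reports.
actual, desired = 0.0, 0.5

# Elementwise acceptance criterion used by np.testing.assert_allclose.
within = abs(actual - desired) <= atol + rtol * abs(desired)
print(within)  # False

try:
    np.testing.assert_allclose(actual, desired, rtol=rtol, atol=atol)
    failed = False
except AssertionError:
    failed = True
print(failed)  # True
```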

@pearu can you PTAL?

System info (python version, jaxlib version, accelerator, etc.)

Mac ARM.

@hawkinsp hawkinsp added the bug Something isn't working label Nov 8, 2024
@hawkinsp
Collaborator Author

hawkinsp commented Nov 8, 2024

There are also failures on Linux ARM:

=================================== FAILURES ===================================
_______ FunctionAccuracyTest.testSuccessOnComplexPlane_square_complex64 ________
[gw15] linux -- Python 3.10.15 /usr/bin/python3.10

self = <lax_test.FunctionAccuracyTest testMethod=testSuccessOnComplexPlane_square_complex64>
name = 'square', dtype = <class 'numpy.complex64'>

    @parameterized.named_parameters(
      dict(testcase_name=f"_{name}_{dtype.__name__}", name=name, dtype=dtype)
      for name, dtype in itertools.product(
          _functions_on_complex_plane,
          jtu.dtypes.supported([np.complex64, np.complex128]),
      ))
    @jtu.skip_on_devices("tpu")
    def testSuccessOnComplexPlane(self, name, dtype):
>     self._testOnComplexPlaneWorker(name, dtype, 'success')

tests/lax_test.py:4212: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 
tests/lax_test.py:4438: in _testOnComplexPlaneWorker
    self.assertAllClose(
jax/_src/test_util.py:1263: in assertAllClose
    self.assertArraysAllClose(x, y, check_dtypes=False, atol=atol, rtol=rtol,
jax/_src/test_util.py:1228: in assertArraysAllClose
    _assert_numpy_allclose(x, y, atol=atol, rtol=rtol, err_msg=err_msg)
jax/_src/public_test_util.py:128: in _assert_numpy_allclose
    np.testing.assert_allclose(a, b, **kw, err_msg=err_msg)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

args = (<function assert_allclose.<locals>.compare at 0xfffce5419480>, array([ nan,  nan,  nan,  nan, -inf, -inf, -inf, -inf,...inf,
       -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf,
       -inf, -inf,  nan], dtype=float32))
kwds = {'equal_nan': True, 'err_msg': 'square in ninfj.real, is_cpu=True is_cuda=False,\njax.numpy.square((-1.735863837493982..., expected (-inf-infj) [(-inf-infj)]', 'header': 'Not equal to tolerance rtol=1e-06, atol=1e-06', 'strict': False, ...}

    @wraps(func)
    def inner(*args, **kwds):
        with self._recreate_cm():
>           return func(*args, **kwds)
E           AssertionError: 
E           Not equal to tolerance rtol=1e-06, atol=1e-06
E           square in ninfj.real, is_cpu=True is_cuda=False,
E           jax.numpy.square((-1.735863837493982e+23-infj)) -> (nan+infj) [(nan+infj)], expected (-inf+infj) [(-inf+infj)]
E           jax.numpy.square((-3.4028234663852886e+38-infj)) -> (nan+infj) [(nan+infj)], expected (-inf+infj) [(-inf+infj)]
E           jax.numpy.square((-7.685595991398373e+30-infj)) -> (nan+infj) [(nan+infj)], expected (-inf+infj) [(-inf+infj)]
E           jax.numpy.square((1.735863837493982e+23-infj)) -> (nan-infj) [(nan-infj)], expected (-inf-infj) [(-inf-infj)]
E           jax.numpy.square((3.4028234663852886e+38-infj)) -> (nan-infj) [(nan-infj)], expected (-inf-infj) [(-inf-infj)]
E           jax.numpy.square((7.685595991398373e+30-infj)) -> (nan-infj) [(nan-infj)], expected (-inf-infj) [(-inf-infj)]
E           nan location mismatch:
E            ACTUAL: array([ nan,  nan,  nan,  nan, -inf, -inf, -inf, -inf, -inf, -inf, -inf,
E                  -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf,  nan,
E                   nan,  nan,  nan], dtype=float32)
E            DESIRED: array([ nan, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf,
E                  -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf,
E                  -inf, -inf,  nan], dtype=float32)

/usr/lib/python3.10/contextlib.py:79: AssertionError
=========================== short test summary info ============================
FAILED tests/lax_test.py::FunctionAccuracyTest::testSuccessOnComplexPlane_square_complex64
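The expected `square` values follow from the identity square(x + iy) = (x - y)(x + y) + 2xy·i. The NumPy sketch below uses two hypothetical helpers (`square_ref`, `naive_square_real`; neither is JAX's implementation) in float32 arithmetic. The factored real part reproduces the expected (-inf ± infj) for the failing inputs with y = -inf, while the naive form x*x - y*y gives inf - inf = nan whenever x*x overflows float32 (|x| above roughly 1.8e19), which covers every failing input listed above. Whether the nightly build actually evaluates the naive form is an assumption, not something the logs establish.

```python
import numpy as np


def square_ref(x, y):
    """Hypothetical reference for (x + iy)**2 using float32 arithmetic."""
    x, y = np.float32(x), np.float32(y)
    with np.errstate(over="ignore"):  # 2 * x overflows to inf when |x| is near float32 max
        re = (x - y) * (x + y)  # factored form: no inf - inf when only y is infinite
        im = np.float32(2) * x * y
    return complex(float(re), float(im))


def naive_square_real(x, y):
    """Hypothetical naive real part x*x - y*y, for comparison."""
    x, y = np.float32(x), np.float32(y)
    with np.errstate(over="ignore", invalid="ignore"):
        return float(x * x - y * y)  # x*x overflows to inf, then inf - inf is nan


for x in (-1.735863837493982e+23, 1.735863837493982e+23, 3.4028234663852886e+38):
    print(x, square_ref(x, -np.inf), naive_square_real(x, -np.inf))
```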

@pearu
Collaborator

pearu commented Nov 14, 2024

Update:

  • Add square_p #24874 presumably fixes the square tests.
  • I cannot reproduce the Mac ARM log1p, arcsin, and arcsinh test failures on Linux ARM. In each case, only one sample fails:
    jax.numpy.log1p((4.517187335295603e-08+1.1754943508222875e-38j)) -> (4.517187335295603e-08+0j)
    jax.numpy.arcsin((1.1754943508222875e-38+4.517187335295603e-08j)) -> 4.517187335295603e-08j
    jax.numpy.arcsinh((4.517187335295603e-08+1.1754943508222875e-38j)) -> (4.517187335295603e-08+0j)
    
    (All other samples are within the allowed error tolerance.)
    The expected results are:
    >>> jax.numpy.log1p((4.517187335295603e-08+1.1754943508222875e-38j))
    Array(4.5171873e-08+1.1754944e-38j, dtype=complex64, weak_type=True)
    >>> jax.numpy.arcsin((1.1754943508222875e-38+4.517187335295603e-08j))
    Array(1.1754944e-38+4.5171873e-08j, dtype=complex64, weak_type=True)
    >>> jax.numpy.arcsinh((4.517187335295603e-08+1.1754943508222875e-38j))
    Array(4.5171873e-08+1.1754944e-38j, dtype=complex64, weak_type=True)
    It is possible that the flush-to-zero (FTZ) modes differ between Mac ARM and Linux ARM. Can this be verified? (See also the feature request in "argsort incorrectly handles very small floating-point numbers and -0.0 compared to PyTorch" #24280.)
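As a possible lead for the FTZ question: 1.1754943508222875e-38 is exactly the smallest positive normal float32 (2**-126, `np.finfo(np.float32).tiny`), so these samples sit on the normal/subnormal boundary. The NumPy sketch below (an illustration, not JAX's implementation) evaluates Im(log1p(x + iy)) = atan2(y, 1 + x) in float64 for the failing log1p sample: the exact value lies slightly below 2**-126 and only rounds back up to it in float32, and halving 2**-126 yields a subnormal, which a flush-to-zero mode would replace with 0. Whether the jaxlib kernels on Mac ARM actually pass through such a subnormal intermediate is an assumption.

```python
import numpy as np

tiny = np.finfo(np.float32).tiny  # smallest positive normal float32, 2**-126
x = 4.517187335295603e-08
y = 1.1754943508222875e-38  # the failing sample's imaginary part, equal to float(tiny)

# Float64 reference for Im(log1p(x + iy)) = atan2(y, 1 + x).
im_ref = np.arctan2(y, 1.0 + x)
print(im_ref < float(tiny))  # True: the exact value is just below 2**-126
print(np.float32(im_ref) == tiny)  # True: it rounds back up to tiny in float32

# Anything below tiny survives only as a subnormal; under FTZ it would become 0.
# (Assumption: this is the mechanism behind the Mac ARM zeros; not verified here.)
sub = tiny / np.float32(2)
print(sub, 0 < sub < tiny)
```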
