The following test code produces the error shown below when using JAX v0.4.37 (it worked error-free on earlier versions). This is running on Ubuntu 24.04 with Python 3.12 and CUDA 12.4.
import jax
import jax.numpy as jnp
from jax import jit


def compute_svd_rot(m: jnp.ndarray, eps=1e-10) -> jnp.ndarray:
    """Maps 3x3 matrices onto SO(3) via symmetric orthogonalization.
    Source: Google research - https://github.com/google-research/google-research/blob/193eb9d7b643ee5064cb37fd8e6e3ecde78737dc/special_orthogonalization/utils.py#L93-L115
    """
    # Alternative regularization (disabled): only perturb rank-deficient inputs.
    # m = jax.lax.cond(jnp.linalg.matrix_rank(m) < 3,
    #                  true_fun=lambda x: x + jnp.eye(3) * 1e-10,
    #                  false_fun=lambda x: x,
    #                  operand=m)
    # Regularize with a small multiple of the identity.
    m_reg = m + jnp.eye(3) * eps
    U, _, Vh = jnp.linalg.svd(m_reg, full_matrices=False)
    # det(U @ Vh) is +1 or -1; scaling the last column of U by it
    # makes the result a proper rotation (det = +1).
    det = jnp.linalg.det(jnp.matmul(U, Vh))
    return jnp.matmul(jnp.c_[U[:, :-1], U[:, -1] * det], Vh)


jit_compute_svd_rot = jit(compute_svd_rot)

# test compute_svd_rot, input is a (3,3) array
test_input = jnp.array([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]])
test_output = jit_compute_svd_rot(test_input)
print(test_output)
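For reference, the projection computed by `compute_svd_rot` can be written as below. This is the standard SVD-based special orthogonalization, stated here to match the code above rather than quoted from the linked source: $U$ and $V^\top$ correspond to `U` and `Vh`, and the factor $\det(UV^\top)$ is what `jnp.c_[U[:, :-1], U[:, -1] * det]` applies to the last column of $U$.

```latex
M + \epsilon I = U \Sigma V^\top, \qquad
R = U \,\operatorname{diag}\!\left(1,\ 1,\ \det(U V^\top)\right) V^\top,

\det R = \det(U)\,\det(U V^\top)\,\det(V^\top) = \det(U V^\top)^2 = 1 .
```

Since $\det(UV^\top) = \pm 1$, the flipped column keeps $U$ orthogonal and forces $\det R = +1$, so $R \in SO(3)$.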
2024-12-10 23:49:08.560387: W external/xla/xla/service/gpu/ir_emitter_unnested.cc:1171] Unable to parse backend config for custom call: Could not convert JSON string to proto: Expected : between key:value pair.
= true, full_matrice
^
Fall back to parse the raw backend config str.
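Assuming the message is only a warning, one possible way to hide it is to raise XLA's C++ log threshold before JAX is imported. This is a minimal sketch that assumes this jaxlib build still honours the `TF_CPP_MIN_LOG_LEVEL` environment variable (where `2` filters out `W`-level messages); note that it hides all XLA warnings, not just this one.

```python
import os

# Assumption: XLA's C++ logging in this jaxlib build reads TF_CPP_MIN_LOG_LEVEL.
# 0 = everything, 1 = hide INFO, 2 = hide INFO and WARNING, 3 = also hide ERROR.
# It must be set before jax (and therefore jaxlib/XLA) is first imported.
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"

import jax
import jax.numpy as jnp

# Same kind of call that triggered the warning: a jitted reduced SVD.
svd = jax.jit(lambda m: jnp.linalg.svd(m, full_matrices=False))
U, S, Vh = svd(jnp.eye(3))
print(S)
```

Because the variable is read when the native library initializes logging, setting it after `import jax` may have no effect; exporting it in the shell before launching Python avoids that ordering issue.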
System info (python version, jaxlib version, accelerator, etc.)
jax: 0.4.37
jaxlib: 0.4.36
numpy: 1.26.4
python: 3.12.2 | packaged by conda-forge | (main, Feb 16 2024, 20:50:58) [GCC 12.3.0]
device info: NVIDIA GeForce RTX 4090-1, 1 local devices
Driver Version: 550.135 CUDA Version: 12.4
Thanks for the report. This is known and I believe that's a warning, not an error. Are you seeing a runtime exception somewhere, or just extra logging?
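One way to check whether the message affects results is to verify numerically that the jitted output is still a proper rotation ($R^\top R = I$, $\det R = +1$). This is a sketch: `compute_svd_rot` repeats the steps from the report, while `is_rotation` and the tolerance `atol=1e-5` are hypothetical helpers and choices for illustration, not part of the original code.

```python
import jax
import jax.numpy as jnp


def compute_svd_rot(m: jnp.ndarray, eps: float = 1e-10) -> jnp.ndarray:
    """Project a 3x3 matrix onto SO(3) (same steps as the function in the report)."""
    m_reg = m + jnp.eye(3) * eps
    U, _, Vh = jnp.linalg.svd(m_reg, full_matrices=False)
    det = jnp.linalg.det(jnp.matmul(U, Vh))
    return jnp.matmul(jnp.c_[U[:, :-1], U[:, -1] * det], Vh)


def is_rotation(r: jnp.ndarray, atol: float = 1e-5) -> bool:
    """Hypothetical check: R^T R == I and det(R) == +1 within a float32 tolerance."""
    orthogonal = jnp.allclose(r.T @ r, jnp.eye(3), atol=atol)
    proper = jnp.allclose(jnp.linalg.det(r), 1.0, atol=atol)
    return bool(orthogonal) and bool(proper)


# Random input, run through the jitted path that emits the warning.
key = jax.random.PRNGKey(0)
m = jax.random.normal(key, (3, 3))
r = jax.jit(compute_svd_rot)(m)
print(is_rotation(r))  # True
```

If this holds on the GPU where the warning appears, the fallback parse of the backend config is not changing the numerical result.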