Unable to parse backend config for custom call: Could not convert JSON string to proto: Expected : between key:value pair. #25389

Open
adam-hartshorne opened this issue Dec 10, 2024 · 4 comments
Labels
bug Something isn't working

Comments

@adam-hartshorne

adam-hartshorne commented Dec 10, 2024

Description

The following test code produces the error above when using JAX v0.4.37 (it ran without this message on earlier versions). This is running on Ubuntu 24.04 with Python 3.12 and CUDA 12.4.

import jax
import jax.numpy as jnp
from jax import jit

def compute_svd_rot(m: jnp.ndarray, eps=1e-10) -> jnp.ndarray:
    """Maps 3x3 matrices onto SO(3) via symmetric orthogonalization.
    Source: Google Research - https://github.com/google-research/google-research/blob/193eb9d7b643ee5064cb37fd8e6e3ecde78737dc/special_orthogonalization/utils.py#L93-L115
    """
    # Disabled variant: regularize only when m is rank-deficient.
    # m = jax.lax.cond(jnp.linalg.matrix_rank(m) < 3,
    #                  true_fun=lambda x: x + jnp.eye(3) * 1e-10,
    #                  false_fun=lambda x: x,
    #                  operand=m)
    # Always add a tiny multiple of the identity as regularization.
    m_reg = m + jnp.eye(3) * eps
    U, _, Vh = jnp.linalg.svd(m_reg, full_matrices=False)
    # det(U @ Vh) is +1 or -1; scaling the last column of U by it
    # makes the result a proper rotation (determinant +1).
    det = jnp.linalg.det(jnp.matmul(U, Vh))
    return jnp.matmul(jnp.c_[U[:, :-1], U[:, -1] * det], Vh)

jit_compute_svd_rot = jit(compute_svd_rot)

# Test compute_svd_rot; the input is a (3, 3) array.
test_input = jnp.array([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]])
test_output = jit_compute_svd_rot(test_input)
print(test_output)

2024-12-10 23:49:08.560387: W external/xla/xla/service/gpu/ir_emitter_unnested.cc:1171] Unable to parse backend config for custom call: Could not convert JSON string to proto: Expected : between key:value pair.
= true, full_matrice
^
Fall back to parse the raw backend config str.
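For reference, here is the computation `compute_svd_rot` performs, written out as a formula. With $M_\epsilon = M + \epsilon I = U \Sigma V^\top$ (so `Vh` in the code is $V^\top$), the function returns

$$
R = U \,\operatorname{diag}\!\bigl(1,\ 1,\ \det(U V^\top)\bigr)\, V^\top ,
$$

which is the rotation closest to $M_\epsilon$ in Frobenius norm. Since $\det(U V^\top) = \pm 1$, we get $\det R = \det(U V^\top)^2 = 1$, so $R \in SO(3)$. The backend-config string quoted in the warning (`full_matrice...`) comes from the SVD custom call used to obtain $U$ and $V^\top$.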

System info (python version, jaxlib version, accelerator, etc.)

jax:    0.4.37
jaxlib: 0.4.36
numpy:  1.26.4
python: 3.12.2 | packaged by conda-forge | (main, Feb 16 2024, 20:50:58) [GCC 12.3.0]
device info: NVIDIA GeForce RTX 4090-1, 1 local devices
Driver Version: 550.135        CUDA Version: 12.4  
@adam-hartshorne adam-hartshorne added the bug Something isn't working label Dec 10, 2024
@dfm
Collaborator

dfm commented Dec 11, 2024

Thanks for the report. This is a known issue, and I believe it's a warning, not an error. Are you seeing a runtime exception somewhere, or just extra logging?

cc @ezhulenev

@adam-hartshorne
Author

Just the warning / extra logging.
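One way to confirm that the message is only logging noise is to check that the jitted result is still a proper rotation and matches an independent NumPy computation of the same formula. The sketch below assumes NumPy is available; `numpy_reference` and `check_rotation` are hypothetical helper names introduced here, not part of JAX.

```python
import numpy as np
import jax
import jax.numpy as jnp


def compute_svd_rot(m, eps=1e-10):
    # Same computation as in the issue: SVD-based projection onto SO(3).
    m_reg = m + jnp.eye(3) * eps
    U, _, Vh = jnp.linalg.svd(m_reg, full_matrices=False)
    det = jnp.linalg.det(jnp.matmul(U, Vh))
    return jnp.matmul(jnp.c_[U[:, :-1], U[:, -1] * det], Vh)


def numpy_reference(m, eps=1e-10):
    # Hypothetical helper: the same formula in float64 NumPy, used as a reference.
    U, _, Vh = np.linalg.svd(m + np.eye(3) * eps, full_matrices=False)
    d = np.linalg.det(U @ Vh)
    return U @ np.diag([1.0, 1.0, d]) @ Vh


def check_rotation(r, atol=1e-4):
    # Hypothetical helper: a proper rotation satisfies R^T R = I and det(R) = +1.
    r = np.asarray(r, dtype=np.float64)
    orthogonal = np.allclose(r.T @ r, np.eye(3), atol=atol)
    proper = abs(np.linalg.det(r) - 1.0) < atol
    return bool(orthogonal and proper)


# A generic (non-identity) random input, so the check is not trivial.
rng = np.random.default_rng(0)
m = rng.normal(size=(3, 3)).astype(np.float32)

r_jit = np.asarray(jax.jit(compute_svd_rot)(jnp.asarray(m)))
r_ref = numpy_reference(m.astype(np.float64))

print(check_rotation(r_jit))
print(np.allclose(r_jit, r_ref, atol=1e-3))
```

If both lines print `True` on the GPU that emits the warning, the custom call most likely fell back to the raw config string correctly, as the log message says.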

@itk22

itk22 commented Dec 14, 2024

Hi,
Is there any way to suppress this warning?

@superbobry
Collaborator

This should be fixed in the nightly jaxlib. Unfortunately, I don't know if there is a way to suppress the warning otherwise.
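A possible workaround until the fix is released, not confirmed in this thread: the log line has the format of XLA's TSL C++ logging (`W external/xla/...`), which, as an assumption here, honours the `TF_CPP_MIN_LOG_LEVEL` environment variable. Raising it to `2` before JAX is imported should hide C++ `WARNING` messages. Note that this hides *all* such warnings, not just this one. The same effect can be had from the shell with `TF_CPP_MIN_LOG_LEVEL=2 python script.py`.

```python
import os

# Must be set before jax/jaxlib is imported.
# Assumption: TSL logging levels are 0 = show all, 1 = hide INFO,
# 2 = hide INFO and WARNING, 3 = also hide ERROR.
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"

import jax  # noqa: E402
import jax.numpy as jnp  # noqa: E402

# A jitted SVD call, which is the kind of op that triggered the warning on GPU.
singular_values = jax.jit(lambda a: jnp.linalg.svd(a, full_matrices=False)[1])(
    jnp.eye(3)
)
print(singular_values)
```

Upgrading to a jaxlib build that contains the fix, as suggested above, is the cleaner option because it does not silence unrelated warnings.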
