AOT errors when setting compile options #24869

Open
man2machine opened this issue Nov 12, 2024 · 0 comments
Labels
bug Something isn't working

Comments

@man2machine

Description

The following code reproduces the error:

import jax
import jax.numpy as jnp
import jax.lib.xla_bridge as xb
import numpy as np
from jax.sharding import Mesh

# Arrange all local devices into a (num_devices, 1) mesh.
num_devices = len(jax.devices())
device_mesh = np.array(jax.devices(), dtype=np.object_).reshape((num_devices, 1))
mesh = Mesh(devices=device_mesh, axis_names=('gpu_index', 'host_index'))  # type: ignore

def f(
    x: jax.Array
) -> tuple[jax.Array, jax.Array]:

    W = jnp.ones((4, 4))  # type: ignore
    z = jnp.dot(x, W)  # type: ignore

    return z, W


x = jnp.ones((8, 4))  # type: ignore

backend = xb.get_backend()

# Build an XLA CompileOptions object with replica/partition counts,
# the device assignment, and auto-SPMD partitioning enabled.
options = xb.get_compile_options(  # type: ignore
    num_replicas=device_mesh.shape[0],
    num_partitions=device_mesh.shape[1],
    device_assignment=device_mesh,
    use_auto_spmd_partitioning=True,
    auto_spmd_partitioning_mesh_shape=list(device_mesh.shape),
    auto_spmd_partitioning_mesh_ids=[d.id for d in device_mesh.flatten()]
)

input_dtype_struct = jax.ShapeDtypeStruct(x.shape, x.dtype)  # type: ignore

# Lower with an abstract input, then pass the CompileOptions object to compile().
f_new = jax.jit(f).lower(input_dtype_struct).compile(compiler_options=options)  # type: ignore

Running this fails with the following error:

  File ".../site-packages/jax/_src/interpreters/pxla.py", line 2786, in from_hlo
    compiler_options.keys()) if compiler_options is not None else None
    ^^^^^^^^^^^^^^^^^^^^^
AttributeError: 'jaxlib.xla_extension.CompileOptions' object has no attribute 'keys'

If I pass the compiler options as a dictionary instead, I get:

jaxlib.xla_extension.XlaRuntimeError: INVALID_ARGUMENT: No such compile option: 'num_replicas'
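For comparison, here is a minimal sketch of the form that `compiler_options` appears to expect: a dict mapping XLA option names (strings) to values, rather than an `xla_client.CompileOptions` object. This reading is an assumption inferred from the two errors above (the JAX code calls `.keys()` on the argument, and `num_replicas` is rejected as an unknown option name). `xla_embed_ir_in_executable` is used only as an example of an option name that XLA recognizes; the sketch runs on whatever default backend is available.

```python
import jax
import jax.numpy as jnp


def f(x: jax.Array) -> tuple[jax.Array, jax.Array]:
    W = jnp.ones((4, 4))
    return jnp.dot(x, W), W


x_struct = jax.ShapeDtypeStruct((8, 4), jnp.float32)
lowered = jax.jit(f).lower(x_struct)

# Assumption: keys are XLA option names, values are plain str/bool/number.
compiled = lowered.compile(compiler_options={"xla_embed_ir_in_executable": True})

# A key that is not an XLA option is rejected when compiling, consistent
# with the 'num_replicas' error reported above.
try:
    lowered.compile(compiler_options={"num_replicas": "8"})
    rejected = False
except Exception as err:  # XlaRuntimeError in jaxlib 0.4.31
    rejected = "No such compile option" in str(err)

z, W = compiled(jnp.ones((8, 4)))
print(z.shape)  # (8, 4)
```

This only illustrates the accepted dict shape; it does not show how to request auto-SPMD partitioning or a replica/partition layout through the AOT API, since those settings are not options of this kind.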

System info (python version, jaxlib version, accelerator, etc.)

jax:    0.4.31
jaxlib: 0.4.31
numpy:  1.26.4
python: 3.12.7 | packaged by conda-forge | (main, Oct  4 2024, 16:05:46) [GCC 13.3.0]
jax.devices (8 total, 8 local): [CudaDevice(id=0) CudaDevice(id=1) ... CudaDevice(id=6) CudaDevice(id=7)]
process_count: 1
platform: uname_result(system='Linux', node='hgx', release='5.4.0-144-generic', version='#161-Ubuntu SMP Fri Feb 3 14:49:04 UTC 2023', machine='x86_64')


$ nvidia-smi
Tue Nov 12 14:03:05 2024       
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.129.03             Driver Version: 535.129.03   CUDA Version: 12.2     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|=========================================+======================+======================|
|   0  NVIDIA A100-SXM4-80GB          On  | 00000000:07:00.0 Off |                    0 |
| N/A   31C    P0              66W / 350W |    429MiB / 81920MiB |      0%      Default |
|                                         |                      |             Disabled |
+-----------------------------------------+----------------------+----------------------+
|   1  NVIDIA A100-SXM4-80GB          On  | 00000000:0A:00.0 Off |                    0 |
| N/A   28C    P0              64W / 350W |    429MiB / 81920MiB |      0%      Default |
|                                         |                      |             Disabled |
+-----------------------------------------+----------------------+----------------------+
|   2  NVIDIA A100-SXM4-80GB          On  | 00000000:44:00.0 Off |                    0 |
| N/A   29C    P0              69W / 350W |    429MiB / 81920MiB |      0%      Default |
|                                         |                      |             Disabled |
+-----------------------------------------+----------------------+----------------------+
|   3  NVIDIA A100-SXM4-80GB          On  | 00000000:4A:00.0 Off |                    0 |
| N/A   32C    P0              67W / 350W |    429MiB / 81920MiB |      0%      Default |
|                                         |                      |             Disabled |
+-----------------------------------------+----------------------+----------------------+
|   4  NVIDIA A100-SXM4-80GB          On  | 00000000:84:00.0 Off |                    0 |
| N/A   32C    P0              67W / 350W |    429MiB / 81920MiB |      0%      Default |
|                                         |                      |             Disabled |
+-----------------------------------------+----------------------+----------------------+
|   5  NVIDIA A100-SXM4-80GB          On  | 00000000:8A:00.0 Off |                    0 |
| N/A   29C    P0              65W / 350W |    429MiB / 81920MiB |      0%      Default |
|                                         |                      |             Disabled |
+-----------------------------------------+----------------------+----------------------+
|   6  NVIDIA A100-SXM4-80GB          On  | 00000000:C0:00.0 Off |                    0 |
| N/A   29C    P0              71W / 350W |    429MiB / 81920MiB |      0%      Default |
|                                         |                      |             Disabled |
+-----------------------------------------+----------------------+----------------------+
|   7  NVIDIA A100-SXM4-80GB          On  | 00000000:C3:00.0 Off |                    0 |
| N/A   31C    P0              63W / 350W |    429MiB / 81920MiB |      0%      Default |
|                                         |                      |             Disabled |
+-----------------------------------------+----------------------+----------------------+
                                                                                         
+---------------------------------------------------------------------------------------+
| Processes:                                                                            |
|  GPU   GI   CI        PID   Type   Process name                            GPU Memory |
|        ID   ID                                                             Usage      |
|=======================================================================================|
|    0   N/A  N/A    404443      C   python                                      416MiB |
|    1   N/A  N/A    404443      C   python                                      416MiB |
|    2   N/A  N/A    404443      C   python                                      416MiB |
|    3   N/A  N/A    404443      C   python                                      416MiB |
|    4   N/A  N/A    404443      C   python                                      416MiB |
|    5   N/A  N/A    404443      C   python                                      416MiB |
|    6   N/A  N/A    404443      C   python                                      416MiB |
|    7   N/A  N/A    404443      C   python                                      416MiB |
+---------------------------------------------------------------------------------------+
@man2machine man2machine added the bug Something isn't working label Nov 12, 2024