Public API for mesh axis index size #25515

Open
gabbard opened this issue Dec 16, 2024 · 2 comments
Labels
enhancement New feature or request

Comments


gabbard commented Dec 16, 2024

You can use jax.lax.axis_index(axis) to get your index along a mesh axis without needing access to the mesh object. However, there does not seem to be a way to get access to the size of a mesh axis through the public API.

We have been doing:

        tensor_parallel_size = jax.interpreters.pxla.thread_resources.env.shape["TP"]

but jax.interpreters is not part of the public API.

Should such a function be added? Or are we doing something wrong if we think we sometimes need it?

gabbard added the enhancement (New feature or request) label Dec 16, 2024
@yashk2810
Collaborator

No, there is no public API for this yet, but we are working on it.

As a workaround, you can create your own context manager that threads the mesh through your code, and then read the mesh (and its axis sizes) wherever you need it.
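A minimal sketch of what such a context manager might look like, for illustration only. The names `active_mesh` and `mesh_axis_size` are hypothetical and not part of JAX. The sketch assumes only that the mesh object exposes a `.shape` mapping from axis name to size, as `jax.sharding.Mesh` does, so it is written against that interface rather than against JAX itself:

```python
import contextlib
import contextvars

# Hypothetical holder for the "current" mesh. This is user code, not a JAX API.
_current_mesh = contextvars.ContextVar("current_mesh", default=None)


@contextlib.contextmanager
def active_mesh(mesh):
    """Make `mesh` visible to mesh_axis_size() for the duration of the block."""
    token = _current_mesh.set(mesh)
    try:
        yield mesh
    finally:
        # Restore whatever mesh (or None) was active before this block.
        _current_mesh.reset(token)


def mesh_axis_size(axis_name):
    """Return the size of `axis_name` on the currently active mesh."""
    mesh = _current_mesh.get()
    if mesh is None:
        raise RuntimeError("mesh_axis_size() called outside of active_mesh(...)")
    # Assumes `mesh.shape` maps axis names to sizes (true of jax.sharding.Mesh).
    return mesh.shape[axis_name]


# With a real mesh this would be used roughly as:
#   mesh = jax.sharding.Mesh(devices, ("DP", "TP"))
#   with active_mesh(mesh):
#       tensor_parallel_size = mesh_axis_size("TP")
```

Because the lookup is an ordinary Python call, it runs at trace time and returns a plain Python integer, so it only sees the mesh if the traced function is first called inside the `with` block.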

Author

gabbard commented Dec 16, 2024

Thanks, that's a good idea for a workaround.
