Migrate torch_xla.device() to torch.device('xla') #9253

Draft: wants to merge 3 commits into master from 9252

Conversation

@ghpvnist ghpvnist (Collaborator) commented May 27, 2025

fixes #9252

@yaoshiang yaoshiang (Collaborator) left a comment

Approved pending successful CI/CD tests. Thank you for finally getting this fixed! As a separate issue, we should update all our docs to reflect this (the Markdown files, not just the docstrings, which I think this PR covers). THANK YOU!!!

@ghpvnist ghpvnist force-pushed the 9252 branch 5 times, most recently from 2a8f9e9 to 0c01883 on June 9, 2025 at 20:38
@ghpvnist ghpvnist force-pushed the 9252 branch 8 times, most recently from b9ed0df to 2fd8ab0 on June 11, 2025 at 22:46
@ghpvnist ghpvnist (Collaborator, Author)

So, after some digging I realized that the semantics of torch.device('xla') are not the same as torch_xla.device() in many cases. See the torch.device docs. torch.device('xla') returns a device without an ordinal like xla:X; per the docs, a device without an ordinal stands for the current device of that type (for GPUs, it resolves to torch.cuda.current_device()). torch_xla.device(), however, behaves differently depending on the context:

1. With the XLA_USE_SPMD env var set, torch_xla.device() will always return xla:0.
2. Otherwise, torch_xla.device() returns torch.device(torch_xla._XLAC._xla_get_default_device()), which can be xla:0 or another ordinal depending on single- vs multi-process context.
3. If called as torch_xla.device(X) (equivalent to torch.device('xla:X')), it returns xla:X.

That said, I don't think we can fully remove the torch_xla.device API in favor of e.g. torch.device('xla:X'), because torch_xla.device is a wrapper with more functionality than the torch.device API.
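A minimal sketch of the difference described above (not part of this PR; assumes torch_xla is installed and that torch_xla.device accepts an optional ordinal, as noted in the thread):

```python
import torch
import torch_xla

# torch.device('xla') only constructs a device object; no ordinal is resolved.
d_plain = torch.device('xla')      # device(type='xla'), index is None
d_fixed = torch.device('xla:0')    # device(type='xla', index=0)

# torch_xla.device() resolves the ordinal for the calling process:
#   - with XLA_USE_SPMD set, it always returns xla:0
#   - otherwise it returns the process-local default device (xla:0, xla:1, ...)
d_auto = torch_xla.device()

# With an explicit ordinal the two APIs agree:
assert torch_xla.device(0) == torch.device('xla:0')
```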

@yaoshiang yaoshiang (Collaborator)

Thanks for this good analysis. Clearly the APIs are not the same, but do they have enough expressiveness to do the same thing? Setting an ordinal for the device number appears to be supported. The one area that does not appear supported is SPMD. I wonder if torch.device("gspmd") would be a better way to express it than torch_xla.device(), which always returns "xla:0" in that mode.

There is room to get this right. The closest analogue I can think of is torch.device("meta"); similarly, "gspmd" might represent a virtual device. That said, if we are going to redo the gSPMD stuff to align with DTensor, we should just worry about it then.

@ghpvnist ghpvnist (Collaborator, Author)

If you know which ordinal you are querying for, both have the same expressiveness. Apart from SPMD, torch.device(torch_xla._XLAC._xla_get_default_device()) cannot be expressed with the torch.device API, since torch.device requires knowing which ordinal you want (without one, it returns a device with no ordinal). By default in a non-SPMD context, torch_xla.device() returns both the device and its ordinal without you explicitly querying for it. This is where the semantics of the two APIs differ.
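A small illustration of that point (a sketch under the assumption of a working single-process torch_xla setup; the printed ordinal is only an example):

```python
import torch
import torch_xla

# torch.device('xla') carries no ordinal unless one is spelled out.
print(torch.device('xla').index)    # None
print(torch.device('xla:1').index)  # 1

# torch_xla.device() fills in the process-local ordinal for you.
print(torch_xla.device().index)     # e.g. 0 in a single-process run
```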

@yaoshiang yaoshiang (Collaborator)

Thanks. I don't know what the setter for torch_xla._XLAC._xla_get_default_device() is, but the fact that it lives in _XLAC and is not properly exposed outside of it is already a red flag that this part of the API is incomplete...

I wonder if this would allow us to have a get and set default device that is idiomatic.

https://docs.pytorch.org/docs/stable/generated/torch.set_default_device.html
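For reference, a hedged sketch of the idiomatic pattern that link describes; torch.set_default_device exists in PyTorch 2.0+ and torch.get_default_device in 2.3+, and whether an 'xla' device would be accepted there is an open question, not something this thread establishes:

```python
import torch

# torch.set_default_device changes the device used when tensors are created
# without an explicit device argument. 'meta' is used here only because it
# works without an accelerator; 'xla:0' would be the hypothetical torch_xla analogue.
torch.set_default_device('meta')
x = torch.ones(3)
print(x.device)                    # meta

# torch.get_default_device() reads the current default back (PyTorch 2.3+).
print(torch.get_default_device())  # meta
```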
