
Support PyTorch XLA as a frontend #945

Open

Kelvin-Ng opened this issue Jul 16, 2023 · 3 comments

Comments


Kelvin-Ng commented Jul 16, 2023

System information

  • Alpa version:
  • Are you willing to contribute it (Yes/No): Maybe

Describe the new feature and the current behavior/state

Most models in the transformers library, like LLaMa, do not have a Flax version. Alpa uses torch.fx.symbolic_trace to convert these models to graphs, but the models contain dynamic (data-dependent) control flow, so symbolic tracing fails on them.

Although the models are dynamic in this sense, PyTorch XLA is able to run them. For example, here is the XLA IR for the LLaMa model from transformers: https://gist.github.com/Kelvin-Ng/6fa6ededead2a42a806a69fd4c932a3e
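
For illustration, here is a minimal sketch (untested, version-dependent) of how an IR dump like the one in the gist can be produced. The tiny LlamaConfig is a placeholder of my own, and the `_XLAC` debug hooks are the internal ones shown in the PyTorch XLA troubleshooting guide:

```python
# Run the model under PyTorch XLA and print the lazy-tensor IR / HLO.
# The tiny LlamaConfig only keeps the snippet small and self-contained.
import torch
import torch_xla
import torch_xla.core.xla_model as xm
from transformers import LlamaConfig, LlamaForCausalLM

device = xm.xla_device()
config = LlamaConfig(hidden_size=256, intermediate_size=512,
                     num_hidden_layers=2, num_attention_heads=4)
model = LlamaForCausalLM(config).to(device).eval()
input_ids = torch.randint(0, config.vocab_size, (1, 128)).to(device)

logits = model(input_ids).logits
# Internal debug hooks from the PyTorch XLA troubleshooting guide:
print(torch_xla._XLAC._get_xla_tensors_text([logits]))  # lazy-tensor IR
print(torch_xla._XLAC._get_xla_tensors_hlo([logits]))   # lowered HLO text
```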

Supporting PyTorch XLA as a frontend will enable the use of many models from the transformers library.

Will this change the current API? How?

Describe alternatives you've considered

  1. Provide an API that accepts a manually supplied XLA IR, so that I can use PyTorch XLA to convert a model to XLA IR and then feed that into Alpa.

  2. Apparently torch.jit.trace also works on these models. Is it possible for Alpa to accept the output of torch.jit.trace instead of torch.fx.symbolic_trace? (A minimal sketch follows below.)
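
A rough, self-contained sketch of alternative 2 on a tiny stand-in config (my assumption of how the trace would be produced; `strict=False` is needed because the model returns a dict-like ModelOutput, and the tracer may emit warnings about values being baked into the trace):

```python
import torch
from transformers import LlamaConfig, LlamaForCausalLM

config = LlamaConfig(hidden_size=256, intermediate_size=512,
                     num_hidden_layers=2, num_attention_heads=4)
model = LlamaForCausalLM(config).eval()
input_ids = torch.randint(0, config.vocab_size, (1, 128))

# strict=False lets the tracer accept the dict-like ModelOutput return value.
traced = torch.jit.trace(model, (input_ids,), strict=False)
print(traced.graph)  # TorchScript IR, not the fx.GraphModule Alpa consumes today
```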

Additional context

@richardliaw

Hey @Kelvin-Ng, I think this makes a lot of sense. Does Alpa primarily need to accept the output of torch.jit.trace instead?


Kelvin-Ng commented Aug 2, 2023

In fact, I was proposing two alternatives: either use PyTorch XLA, or use torch.jit.trace.

I think using PyTorch XLA is the better option because it is the most general: anything that PyTorch can run would be supported. However, that may require more modifications in Alpa, because I see that some of the pipeline parallelism logic operates on the JAX graph, which we don't have for PyTorch XLA.
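
For context, a minimal sketch of Alpa's existing JAX entry point (the toy shapes and loss are illustrative). The parallelization passes, including pipeline parallelism, operate on the JAX computation traced here, which is exactly the piece a PyTorch XLA frontend would not produce:

```python
import alpa
import jax
import jax.numpy as jnp

def loss_fn(params, batch):
    preds = batch["x"] @ params["w"]
    return jnp.mean((preds - batch["y"]) ** 2)

@alpa.parallelize  # traces train_step as a JAX graph before partitioning it
def train_step(params, batch):
    grads = jax.grad(loss_fn)(params, batch)
    return jax.tree_util.tree_map(lambda p, g: p - 1e-2 * g, params, grads)

params = {"w": jnp.zeros((4, 1))}
batch = {"x": jnp.ones((8, 4)), "y": jnp.zeros((8, 1))}
params = train_step(params, batch)
```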

Alternatively, we can use torch.jit.trace instead of the symbolic_trace that Alpa currently uses. I guess we would only need to modify the conversion code (here, I suppose: https://github.com/alpa-projects/alpa/blob/main/alpa/torch/nn/__init__.py#L22). It currently performs the conversion on the code generated by symbolic_trace, and we would need to change it to accept the output of torch.jit.trace. However, I do not fully understand this piece of code, so I am not sure how to do that reliably, especially since the graph produced by torch.jit.trace is more complicated than the one from symbolic_trace, according to the PyTorch documentation. The benefit of this approach is that no modification to the Alpa core code would be necessary.
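
A toy comparison of the two tracers illustrates the gap the conversion code would have to bridge (exact output varies by PyTorch version):

```python
import torch
from torch.fx import symbolic_trace

class Toy(torch.nn.Module):
    def forward(self, x):
        return torch.relu(x) + 1

# fx yields a small graph of high-level, Python-visible ops:
fx_mod = symbolic_trace(Toy())
for node in fx_mod.graph.nodes:
    print(node.op, node.target)  # placeholder / call_function / output

# jit.trace yields lower-level TorchScript IR with aten:: ops:
jit_mod = torch.jit.trace(Toy(), (torch.randn(2),))
print(jit_mod.graph)
```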


gjoliver commented Aug 2, 2023

We should go the TorchXLA route if possible.
I think the current torch.jit implementation was largely a one-off effort and will be hard to maintain and not as scalable.
