Describe the new feature and the current behavior/state
Most models in the transformers library, like LLaMa, do not have a Flax version. Alpa uses torch.fx.symbolic_trace to convert these models to graphs, but these models contain dynamic, data-dependent control flow, so symbolic tracing fails. Although the models have dynamic behavior, PyTorch XLA is actually able to run them. For example, here is the XLA IR for the LLaMa model from transformers: https://gist.github.com/Kelvin-Ng/6fa6ededead2a42a806a69fd4c932a3e
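For concreteness, here is a minimal, hypothetical illustration (not taken from transformers) of the kind of data-dependent control flow that breaks torch.fx.symbolic_trace:

```python
import torch
import torch.fx

class DynamicModule(torch.nn.Module):
    def forward(self, x):
        # fx traces with Proxy objects; branching on a tensor value forces
        # bool(Proxy), which raises a TraceError.
        if x.sum() > 0:
            return x * 2
        return x - 1

try:
    torch.fx.symbolic_trace(DynamicModule())
except torch.fx.proxy.TraceError as err:
    print(f"symbolic_trace failed: {err}")
```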
Supporting PyTorch XLA as a frontend will enable the use of many models from the transformers library.
Will this change the current API? How?
Describe alternatives you've considered
Provide an API that allows manually supplying an XLA IR, so that one can use PyTorch XLA to convert the models to XLA IR and then feed that into Alpa.
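As a rough sketch of what supplying the IR could look like: PyTorch XLA builds a lazy IR as the model runs, and its internal debug helpers can dump the HLO for given output tensors. The `_XLAC` module below is internal and version-dependent, and the Linear layer is just a stand-in for a real model:

```python
import torch
import torch_xla
import torch_xla.core.xla_model as xm

device = xm.xla_device()
model = torch.nn.Linear(128, 64).to(device)   # stand-in for a real model
inputs = torch.randn(1, 128, device=device)
outputs = model(inputs)

# Dump the HLO text built lazily for these tensors. An Alpa API could
# accept something like this directly; note _XLAC is internal, so the
# exact helper may differ across torch_xla versions.
print(torch_xla._XLAC._get_xla_tensors_hlo([outputs]))
```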
Apparently torch.jit.trace also works. Is it possible for Alpa to accept the output of torch.jit.trace instead of torch.fx.symbolic_trace?
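For comparison, a minimal sketch of torch.jit.trace on the same kind of dynamic module: tracing actually executes the model on example inputs, so it succeeds, but it records only the branch taken (PyTorch emits a TracerWarning about this):

```python
import torch

class DynamicModule(torch.nn.Module):
    def forward(self, x):
        if x.sum() > 0:
            return x * 2
        return x - 1

# trace() runs the module and records the executed ops; the data-dependent
# branch is resolved at trace time, so only one path is captured.
traced = torch.jit.trace(DynamicModule(), torch.randn(4))
print(traced.graph)  # TorchScript IR with the traced branch baked in
```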
Additional context
In fact, I was proposing two alternatives: either use PyTorch XLA, or use torch.jit.trace.
I think using PyTorch XLA is the better option because it should be the most general -- anything that PyTorch can run will be supported. However, it may require more modifications in Alpa, because I see that some of the pipeline parallelism logic operates on the JAX graph, which we don't have for PyTorch XLA.
Or we can use torch.jit.trace (instead of the symbolic_trace that Alpa currently uses). I guess we only need to modify the conversion code (here, I suppose: https://github.com/alpa-projects/alpa/blob/main/alpa/torch/nn/__init__.py#L22). It currently does the conversion on the code generated by symbolic_trace, and we would need to change it to support torch.jit.trace. However, I do not fully understand this piece of code, so I am not sure how to do that reliably, especially because the code generated by torch.jit.trace is more complicated than that of symbolic_trace, according to the PyTorch documentation (see the sketch below). The benefit of this approach, though, is that no modification to the Alpa core code is necessary.
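To illustrate the difference (this is not Alpa's conversion code, just a hypothetical comparison): torch.fx produces a small, Python-level node list, while torch.jit.trace produces lower-level TorchScript IR that a converter would have to handle:

```python
import torch
import torch.fx

class Tiny(torch.nn.Module):
    def forward(self, x):
        return torch.relu(x + 1)

# torch.fx: a few high-level nodes (placeholder / call_function / output).
for node in torch.fx.symbolic_trace(Tiny()).graph.nodes:
    print(node.op, node.target)

# torch.jit.trace: TorchScript IR nodes with aten-level kinds, plus extra
# bookkeeping (constants, type info) that the conversion code would need
# to understand.
traced = torch.jit.trace(Tiny(), torch.randn(2))
for node in traced.graph.nodes():
    print(node.kind())
```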
We should go the TorchXLA route if possible.
I think the current torch.jit implementation was largely a one-off effort and will be hard to maintain / not as scalable.