✈️ Introduce Jetstream/Pytorch in TGI #88

Merged
merged 20 commits into main from introduce-jetstream-pytorch
Sep 9, 2024

Conversation

tengomucho
Collaborator

What does this PR do?

This allows TGI to be used with the meta-llama/Llama-2-7b-hf model through the Jetstream/Pytorch engine.
This should be the starting point for a more complete integration in the future. It is not yet ready to replace the legacy implementation, in particular because:

  • no other models have been tested, and some work is required for weights conversion;
  • when I tried other Llama2 models they did not work; I still need to investigate the reason, and a fix should follow;
  • further work is needed to simplify the coexistence of torch_xla and Jetstream/Pytorch. For now this feature is optional, in particular because the dependency installation is not straightforward and the two engines (Jetstream Pytorch and Pytorch/XLA) are mutually exclusive at runtime.

Before submitting

  • Did you make sure to update the documentation with your changes?
  • Did you write any new necessary tests?

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

use packaging.version's parse instead of pkg_resources' parse_version.
The custom HfEngine contains functions that allow the prefill and
generate steps to use custom sampling functions.
This implementation is equivalent to the torch_xla one, but uses the
Jetstream/Pytorch engine instead.
This way we can avoid trying to import torch_xla.
This is just a way to provide a factory class method to create a
Jetstream/Pytorch or a Pytorch XLA generator.
There are still some issues related to some fine-tuned models, so for
now just enable only when JETSTREAM_PT is set.
For now it is possible to install the dependencies after optimum-tpu has been
installed, by issuing this command:

pip install "optimum-tpu[jetstream-pt]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
@tengomucho force-pushed the introduce-jetstream-pytorch branch 5 times, most recently from aab4506 to c11d2bb on September 2, 2024 08:47
        )
        return tokens, true_length

    def prefill(self, batch: Batch) -> Tuple[List[Generation], CachedBatch]:

Can you share more insights on where the server takes the request and calls prefill?

Collaborator Author

Sure. prefill receives the request from the model server interface; you can see the code here.
The model server is called by the TGI router. You can see more information about the TGI architecture here.
Let me know if you need any more specific details!
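
To make the flow concrete, here is a rough sketch of that call path (class, method and helper names follow the usual TGI layout and are illustrative, not copied from this repository):

# TGI router (Rust) --gRPC--> model server (Python servicer) --> generator.prefill()
class TextGenerationService:
    """Hypothetical gRPC servicer run by the model server."""

    def __init__(self, generator):
        # e.g. a TpuGeneratorJetStream instance created by the generator factory
        self.generator = generator

    async def Prefill(self, request, context):
        # The router sends a Batch of new requests; the servicer hands it to the
        # generator, which runs the prefill step and returns the first generations
        # plus a CachedBatch handle used by the subsequent Decode calls.
        generations, cached_batch = self.generator.prefill(request.batch)
        return make_prefill_response(generations, cached_batch)  # hypothetical helper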

Collaborator

@allenwang28 left a comment

Thanks Alvaro, great work! I took a first pass and left some fairly minor comments.

        return False
    # Torch XLA should not be imported before torch_xla2 to avoid conflicts.
    if 'torch_xla2' not in sys.modules and 'torch_xla.core' in sys.modules:
        return False
Collaborator

nit: would it make sense to emit a warning here? Like "JETSTREAM_PT is enabled, but torch_xla2 is not installed. Falling back to torch_xla".

Collaborator Author

It's actually a little trickier than that: torch_xla2 cannot be imported after torch_xla has been imported. I will add a warning.
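
For illustration, a minimal sketch of that fallback-with-warning logic (the function name matches the rename discussed below, but the env-var handling, logger and exact checks are assumptions, not the PR's code):

import logging
import os
import sys

logger = logging.getLogger(__name__)


def model_can_use_jetstream_pt(model_path: str) -> bool:
    """Return True only when the Jetstream/Pytorch engine is requested and usable."""
    if os.environ.get("JETSTREAM_PT") is None:
        # The feature is opt-in for now.
        return False
    # torch_xla2 cannot be imported once torch_xla has already been imported,
    # so detect that case early and fall back to the legacy engine.
    if "torch_xla2" not in sys.modules and "torch_xla.core" in sys.modules:
        logger.warning(
            "JETSTREAM_PT is enabled, but torch_xla2 cannot be loaded after "
            "torch_xla has been imported. Falling back to torch_xla."
        )
        return False
    # The real check also inspects the model itself (only some architectures
    # are supported); that part is omitted in this sketch.
    return True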

@@ -0,0 +1,35 @@
from .generator_base import Generator
from .jetstream_pt_support import check
Collaborator

from .jetstream_pt_support import check as should_use_jetstream

or something along these lines could be more descriptive. Possibly you could just change the check def within jetstream_pt_support.

Collaborator Author

@tengomucho Sep 6, 2024

I renamed it model_can_use_jetstream_pt.

            model_path, revision=revision, max_batch_size=max_batch_size, max_sequence_length=max_sequence_length
        )
    else:
        from .generator import TpuGenerator
Collaborator

It would be useful to a user to log 1) when we have successfully loaded Jetstream and 2) when we're falling back to the base generator.
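
Something along these lines, for example (a sketch with assumed names and import paths, mirroring the diff context above rather than reproducing the PR's exact factory):

import logging

logger = logging.getLogger(__name__)


class AutoGenerator:  # hypothetical factory name, for illustration only
    @staticmethod
    def from_pretrained(model_path, revision, max_batch_size, max_sequence_length):
        if model_can_use_jetstream_pt(model_path):
            logger.info("Loading the Jetstream/Pytorch generator.")
            from .jetstream_pt_support.generator import TpuGeneratorJetStream  # assumed path
            generator_class = TpuGeneratorJetStream
        else:
            logger.info("Jetstream/Pytorch is not used, falling back to the torch_xla generator.")
            from .generator import TpuGenerator
            generator_class = TpuGenerator
        return generator_class.from_pretrained(
            model_path, revision=revision, max_batch_size=max_batch_size, max_sequence_length=max_sequence_length
        )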

Collaborator

Rather than re-implement llama/model_exportable.py, could we implement some type of parameter transformation logic instead? That would allow us to directly use jetstream_pt's code

Collaborator Author

That's what I tried to do at first, but if we want to support models as they are defined in transformers, the simplest way is to extract the model parameters from the config file. In the Llama model definition in transformers, some of the original parameters (hidden_dim, multiple_of and ffn_dim_multiplier) were combined into the intermediate_size variable. I could not see a trivial way to go back to the original values; that is why I ended up re-implementing FeedForward and, as a consequence, modifying the other classes that use it. If you can think of a reliable way to get the original parameters back, then I can drop most of this and just use jetstream_pt's code.
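
For context, this is roughly how the original Llama code derives the FFN width (paraphrased from Meta's reference implementation, not code from this PR):

from typing import Optional


def llama_ffn_hidden_dim(dim: int, multiple_of: int, ffn_dim_multiplier: Optional[float]) -> int:
    hidden_dim = 4 * dim
    hidden_dim = int(2 * hidden_dim / 3)
    if ffn_dim_multiplier is not None:
        hidden_dim = int(ffn_dim_multiplier * hidden_dim)
    # Round up to a multiple of `multiple_of`.
    return multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)

# transformers only stores the final result as `intermediate_size`
# (e.g. 11008 for Llama-2-7B, where dim=4096 and multiple_of=256), so several
# (multiple_of, ffn_dim_multiplier) combinations map to the same value and the
# original parameters cannot be recovered unambiguously from the config.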

        return len(self._tokens) == 0


class TpuGeneratorJetStream(Generator):
Collaborator

One general comment (no need for change at this point): since this is essentially re-implementing the responsibility of JetStream's orchestrator as designed, this will lose out on features like disaggregated serving and will likely result in different performance.

Collaborator Author

Yes, I thought about it, and I agree with you that using only the engine and not the orchestrator means we will end up with different performance results. The reason I did this was the API: the engine API is similar to TGI's model_server, while the orchestrator is not meant to be driven through a Python API but rather through gRPC, and its interface is closer to the one in the TGI router. So interfacing the orchestrator with TGI would mean taking the TGI requests, re-encoding them as requests for the Jetstream orchestrator, forwarding them, and then transcoding the responses back. So yes, at some point we might need to look at a way to integrate those, but it seems more complicated and I think we can do that later.

from jetstream_pt import engine


class HfEngine(engine.PyTorchEngine):
Collaborator

General note (no need to respond to this within this PR), Ray support for multi-node currently lives within PyTorchRayEngine. So as is, this won't be able to take advantage of Ray multi-host. A few options:

  1. [within JetStream] Consolidate PyTorchRayEngine with PyTorchEngine - probably preferred since we saw issues arise because of the decoupled design (cc @FanhaiLu1)
  2. [within TGI] Create a RayHfEngine or use some type of mixin

- Added warning when trying to load torch_xla2 after torch_xla
- renamed jetstream_pt_support.check to model_can_use_jetstream_pt

    def __call__(self, logits: jnp.ndarray) -> Tuple[jnp.ndarray, jnp.ndarray]:
        if self.temperature != 1.0:
            logits = logits / self.temperature
Collaborator

qq: what happens if temp = 0?

Collaborator Author

Good question @miladm. In that case the operation will give an array with [inf, -inf] values. The generation will still give some result, though probably not the one you would expect (in my case it behaved as if it were using greedy search).
BTW, you will have the same division in the Jetstream sampling code.
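
A small illustration of the effect in plain JAX (not the PR's sampler code; the epsilon guard mentioned in the comments is only a suggestion):

import jax
import jax.numpy as jnp

logits = jnp.array([2.0, -1.0, 0.5])
scaled = logits / 0.0            # temperature = 0 -> [inf, -inf, inf]
probs = jax.nn.softmax(scaled)   # -> NaNs, because inf - inf appears inside the softmax
# With NaN probabilities the sampled output is unreliable and can look like
# greedy search. Treating temperature == 0 as greedy search explicitly, or
# clamping the temperature to a small epsilon, would avoid the division.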

@tengomucho merged commit fa24cc4 into main on Sep 9, 2024
4 checks passed
@tengomucho deleted the introduce-jetstream-pytorch branch on September 9, 2024 09:59