Llama support #32
Conversation
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
Does "sharded llama models" imply that this could cover 70b as well?
For now I have only tested with an 8B model on a single-host TPU v5e 4x2. I reckon we would need something bigger to test the 70b model.
Do you need help with capacity?
Imported from transformers v4.40.1.
This essentially copies commit 8a4a98d2472b8e0180eb9bd4a1824f983e220811 from optimum-neuron, which fixed the same problem.
LGTM! 👏🏻
@pytest.mark.slow
def test_distributed_model_prefill_llama3_8b():
    _test_distributed_model_prefill("meta-llama/Meta-Llama-3-8B")
Should we start parametrizing the test rather than duplicating it?
I should definitely do a bit of refactoring on the tests to avoid duplication... will do later if you agree.
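For reference, a rough sketch of what a parametrized version might look like, assuming _test_distributed_model_prefill keeps its current signature; the list contents other than the Llama 3 8B id are placeholders, not taken from this PR:

import pytest

# Hypothetical refactor sketch: one parametrized test instead of duplicated
# per-model tests. Only the Llama 3 8B id comes from this PR; other ids from
# the existing tests would be appended to the list.
@pytest.mark.slow
@pytest.mark.parametrize(
    "model_id",
    [
        "meta-llama/Meta-Llama-3-8B",
        # model ids from the other duplicated tests would go here
    ],
)
def test_distributed_model_prefill(model_id):
    _test_distributed_model_prefill(model_id)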
pad_token_id = self.tokenizer.pad_token_id
if pad_token_id is None:
    if isinstance(self.tokenizer.eos_token_id, list):
        pad_token_id = self.tokenizer.eos_token_id[0]
    else:
        pad_token_id = self.tokenizer.eos_token_id
This block doesn't seem to depend on i; should we move it outside the for loop? A rough sketch of what that could look like is below.
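A minimal sketch of the hoisted version, assuming the block currently sits inside a per-request loop; the enclosing method name prefill and the requests argument are placeholders for illustration, not the actual signature:

def prefill(self, requests):  # hypothetical enclosing method, for illustration only
    # Resolve pad_token_id once, before the loop, since it does not depend on i.
    pad_token_id = self.tokenizer.pad_token_id
    if pad_token_id is None:
        # Some tokenizers expose a list of eos token ids; fall back to the first one.
        if isinstance(self.tokenizer.eos_token_id, list):
            pad_token_id = self.tokenizer.eos_token_id[0]
        else:
            pad_token_id = self.tokenizer.eos_token_id

    for i, request in enumerate(requests):
        # per-request logic can now use pad_token_id without recomputing it
        ...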
What does this PR do?
This adds support for sharded Llama models on TPU, tested on a TPU v5e litepod-8.
A test that shows inference with Llama3-8b on TGI has been added.
Before submitting