✈️ Introduce Jetstream/Pytorch in TGI #88
@@ -0,0 +1,23 @@
```python
import os
import sys


def jetstream_pt_available() -> bool:
    """Check if the necessary imports to use jetstream_pt are available."""
    try:
        # For now Jetstream Pytorch is opt-in, it can be enabled with an ENV variable.
        jetstream_pt_enabled = os.environ.get("JETSTREAM_PT", False) == "1"
        if not jetstream_pt_enabled:
            return False
        # Torch XLA should not be imported before torch_xla2 to avoid conflicts.
        if "torch_xla2" not in sys.modules and "torch_xla.core" in sys.modules:
            return False
        # Import torch_xla2 first!
        import torch_xla2  # noqa: F401, isort:skip

        import jetstream_pt  # noqa: F401

        return True
    except ImportError:
        return False
```
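For context, here is a minimal sketch of how this flag-gated check could be exercised from user code. It assumes `jetstream_pt_available` is re-exported from `optimum.tpu`, as the `from optimum.tpu import jetstream_pt_available` line in a later hunk suggests:

```python
# Minimal usage sketch (assumes optimum.tpu re-exports jetstream_pt_available,
# as a later hunk in this PR suggests).
import os

# Opt in before importing anything that pulls in torch_xla.
os.environ["JETSTREAM_PT"] = "1"

from optimum.tpu import jetstream_pt_available

if jetstream_pt_available():
    print("Jetstream Pytorch backend is available")
else:
    print("Falling back to the Pytorch/XLA path")
```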
@@ -0,0 +1,35 @@
```python
from .generator_base import Generator
from .jetstream_pt_support import check


class AutoGenerator:

    @staticmethod
    def from_pretrained(
        model_path: str, revision: str, max_batch_size: int, max_sequence_length: int
    ) -> Generator:
        """Instantiate a Generator for TPU using Jetstream Pytorch or Pytorch/XLA.

        Args:
            model_path (`str`):
                The path to a local model. This path must also contain a Tokenizer.
            revision (`str`):
                The revision of the model.
            max_batch_size (`int`):
                The maximum batch size.
            max_sequence_length (`int`):
                The maximum sequence length.

        Returns:
            A TpuGenerator.
        """
        if check(model_path):
            from .jetstream_pt_support.generator import TpuGeneratorJetStream
            return TpuGeneratorJetStream.from_pretrained(
                model_path, revision=revision, max_batch_size=max_batch_size, max_sequence_length=max_sequence_length
            )
        else:
            from .generator import TpuGenerator
            return TpuGenerator.from_pretrained(
                model_path, revision=revision, max_batch_size=max_batch_size, max_sequence_length=max_sequence_length
            )
```

Review comment (on the `check` import): Something along these lines could be more descriptive; possibly could just change the name.
Reply: I renamed it.

Review comment (on the fallback branch): It would be useful to a user to log 1) when we have successfully loaded Jetstream and 2) when we're falling back to the base generator.
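A hedged sketch of what that logging suggestion could look like in the method body, reusing `check` and `Generator` from the hunk above; the logger name and message wording are illustrative assumptions, not part of the diff:

```python
# Hypothetical logging sketch for the reviewer's suggestion (not in the diff);
# logger name and messages are assumptions. `check` and `Generator` come from
# the module shown above.
import logging

logger = logging.getLogger(__name__)


def from_pretrained(
    model_path: str, revision: str, max_batch_size: int, max_sequence_length: int
) -> Generator:
    if check(model_path):
        from .jetstream_pt_support.generator import TpuGeneratorJetStream

        logger.info("Jetstream Pytorch is available: using TpuGeneratorJetStream.")
        return TpuGeneratorJetStream.from_pretrained(
            model_path, revision=revision, max_batch_size=max_batch_size, max_sequence_length=max_sequence_length
        )
    logger.info("Falling back to the base TpuGenerator (Pytorch/XLA).")
    from .generator import TpuGenerator

    return TpuGenerator.from_pretrained(
        model_path, revision=revision, max_batch_size=max_batch_size, max_sequence_length=max_sequence_length
    )
```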
@@ -0,0 +1,15 @@
```python
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from .compatibility import check, create_engine
```
@@ -0,0 +1,53 @@
```python
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
from typing import Any

from transformers import AutoConfig

from optimum.tpu import jetstream_pt_available


def check(model_path: str) -> bool:
    """Checks if the model is supported by Jetstream Pytorch on Optimum TPU and if the required dependencies to provide
    the engine are installed.
    """
    config = AutoConfig.from_pretrained(model_path)
    # For now only Llama 2 with tokenizer.model is supported
    if config.model_type != "llama" or not os.path.exists(
        os.path.join(model_path, "tokenizer.model")
    ):
        return False
    if jetstream_pt_available():
        return True
    return False


def create_engine(
    model_path: str,
    batch_size: int,
    sequence_length: int,
    max_input_tokens: int,
    max_output_tokens: int,
) -> Any:
    if not check(model_path):
        # The model is not compatible with Jetstream PyTorch, just exit
        return None

    # Now import engine_loader to prevent importing it at the top when not supported
    from .engine_loader import create_engine
    return create_engine(
        model_path, batch_size, sequence_length, max_input_tokens, max_output_tokens
    )
```
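A short caller-side sketch of these two entry points; the model path and size parameters below are hypothetical, chosen only for illustration:

```python
# Hypothetical usage sketch (path and sizes are made up for illustration).
engine = create_engine(
    "/data/llama-2-7b",  # local path; must contain tokenizer.model per check()
    batch_size=4,
    sequence_length=1024,
    max_input_tokens=512,
    max_output_tokens=512,
)
if engine is None:
    # Either the model is not a supported Llama checkpoint or the Jetstream
    # Pytorch dependencies are missing; the caller falls back to Pytorch/XLA.
    pass
```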
Review comment: nit: would it make sense to emit a warning here? Like "`JETSTREAM_PT` is enabled, but torch_xla2 is not installed. Falling back to torch_xla".
Reply: It's actually a little trickier than that: `torch_xla2` cannot be imported after `torch_xla` has been imported. I will add a warning.
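A hedged sketch of how such a warning could be emitted around the import-order check in `jetstream_pt_available`; the helper name and message wording are assumptions, not taken from the final diff:

```python
# Hypothetical warning sketch (helper name and message are assumptions).
import logging
import sys

logger = logging.getLogger(__name__)


def _xla_import_order_ok() -> bool:
    # torch_xla2 must be imported before torch_xla; if torch_xla.core is already
    # loaded and torch_xla2 is not, Jetstream Pytorch cannot be enabled safely.
    if "torch_xla2" not in sys.modules and "torch_xla.core" in sys.modules:
        logger.warning(
            "JETSTREAM_PT is enabled, but torch_xla has already been imported "
            "before torch_xla2; falling back to the Pytorch/XLA generator."
        )
        return False
    return True
```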