✈️ Introduce Jetstream/Pytorch in TGI #88

Merged
20 commits, merged Sep 9, 2024

Commits
68c77df
feat(tgi): add functions to load Jetstream Pytorch engine for Llama2
tengomucho Jul 26, 2024
b28ef47
chore(TokenSelector): remove XLA xm rng seed set
tengomucho Aug 9, 2024
c74900e
fix(version): remove warning on deprecated API
tengomucho Aug 13, 2024
be56089
fix(generator): use pad_token_id for padding
tengomucho Aug 19, 2024
6db3c2c
fix(decode): clear unrequested slots
tengomucho Aug 19, 2024
02ffeea
feat(imports): add function to check if Jetstream Pytorch can be used
tengomucho Aug 26, 2024
8e98023
feat(Jetstream): improved support for engine load
tengomucho Aug 26, 2024
e3840e1
feat(TGI): Added Jetstream/Pytorch generator
tengomucho Aug 26, 2024
3ff7197
chore(fsdp v2): avoid importing PretrainedModel
tengomucho Aug 26, 2024
6c9348c
feat(tgi): introduce AutoGenerator
tengomucho Aug 26, 2024
42ebaef
feat(Jetstream PT): Enable support only if env var is set
tengomucho Aug 29, 2024
0af77a4
feat(TGI): use AutoGenerator in model server
tengomucho Aug 29, 2024
3d782ab
feat(package): add optional dependency on Jetstream/Pytorch
tengomucho Aug 29, 2024
33bb7d4
test(Jetstream Pytorch): added a simple decode test
tengomucho Aug 29, 2024
3cc2ff8
test(decode): added a variant with do_sample=True with Jetstream PT
tengomucho Aug 29, 2024
e5e2fd4
fix(README): correct link
tengomucho Aug 29, 2024
aac4237
doc(README): add mention on how to install and enable Pytorch/Jetstream
tengomucho Aug 29, 2024
07a71db
feat(build): make clean removes old TGI builds too
tengomucho Aug 29, 2024
b77a352
review: comply to comments requests
tengomucho Sep 6, 2024
76fbf94
review(AutoGenerator): log if using Jetstream/PT or torch xla
tengomucho Sep 6, 2024
10 changes: 10 additions & 0 deletions .github/workflows/test-pytorch-xla-tpu-tgi.yml
@@ -31,3 +31,13 @@ jobs:
      - name: Build and test TGI server
        run: |
          HF_TOKEN=${{ secrets.HF_TOKEN_OPTIMUM_TPU_CI }} make tgi_test

      # Use a different step to test the Jetstream Pytorch version, to avoid conflicts with torch-xla[tpu]
      - name: Install and test TGI server (Jetstream Pytorch)
        run: |
          pip install -U .[jetstream-pt] \
            -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html \
            -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html \
            -f https://storage.googleapis.com/libtpu-releases/index.html
          JETSTREAM_PT=1 HF_TOKEN=${{ secrets.HF_TOKEN_OPTIMUM_TPU_CI }} python -m \
            pytest -sv text-generation-inference/tests -k jetstream
1 change: 1 addition & 0 deletions Makefile
@@ -40,6 +40,7 @@ $(PACKAGE_DIST) $(PACKAGE_WHEEL): $(PACKAGE_FILES)

clean:
	rm -rf dist
	make -C text-generation-inference/server/ clean

tpu-tgi:
docker build --rm -f text-generation-inference/docker/Dockerfile \
15 changes: 14 additions & 1 deletion README.md
@@ -45,7 +45,20 @@ Other TPU versions will be supported along the way.
As part of the integration, we do support a [text-generation-inference (TGI)](https://github.com/huggingface/optimum-tpu/tree/main/text-generation-inference) backend allowing to deploy and serve
incoming HTTP requests and execute them on Cloud TPUs.

Please see the [TGI specific documentation]() on how to get started
Please see the [TGI specific documentation](text-generation-inference) on how to get started.

### JetStream Pytorch Engine

`optimum-tpu` provides optional support for the JetStream Pytorch engine inside TGI. This support can be installed using the dedicated command:

```shell
pip install "optimum-tpu[jetstream-pt]" \
-f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html \
-f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html \
-f https://storage.googleapis.com/libtpu-releases/index.html
```

To enable the support, export the environment variable `JETSTREAM_PT=1`.
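
A minimal sketch of how to verify the support is active from Python, assuming the extra above is installed on a TPU host (it relies only on the `jetstream_pt_available` helper that `optimum.tpu` exports):

```python
import os

# Opt in before importing anything that could pull in torch_xla.
os.environ["JETSTREAM_PT"] = "1"

from optimum.tpu import jetstream_pt_available

# Expected to print True when the jetstream-pt extra is installed and
# torch_xla has not been imported yet in this process.
print(jetstream_pt_available())
```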

## Training

1 change: 1 addition & 0 deletions optimum/tpu/__init__.py
@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from .jetstream_pt_support import jetstream_pt_available # isort:skip
from .fsdp_v2 import get_fsdp_config, use_fsdp_v2
from .modeling import AutoModelForCausalLM
from .version import VERSION, __version__
18 changes: 13 additions & 5 deletions optimum/tpu/fsdp_v2.py
@@ -15,12 +15,16 @@
"""
Utility functions to provide FSDPv2 configuration for TPU training.
"""
from typing import Dict, List, Union
from typing import Any, Dict, List, Union

from transformers.modeling_utils import PreTrainedModel
from transformers.utils import logging


PreTrainedModel = Any
# NOTE: modeling_utils.PreTrainedModel should be used instead of the alias above, but since it is only needed
# for type hinting, it is not imported here, to avoid pulling in torch_xla.


def use_fsdp_v2():
"""
Enable FSDPv2 for TPU training.
@@ -61,6 +65,7 @@ def _unwrap_model(model: PreTrainedModel) -> PreTrainedModel:
"""
try:
from peft.peft_model import LoraModel, PeftModel

if isinstance(model, PeftModel) and isinstance(model.base_model, LoraModel):
return model.base_model.model
return model
@@ -89,10 +94,13 @@ def get_fsdp_training_args(model: PreTrainedModel) -> Dict:
    if isinstance(model, GemmaForCausalLM) or isinstance(model, HFGemmaForCausalLLM):
        logger = logging.get_logger(__name__)
        from torch_xla import __version__ as xla_version

        if xla_version == "2.3.0":
            logger.warning_once("Fine-tuning Gemma on Pytorch XLA 2.3.0 might raise some issues. In case of any "
                                "issues consider using the nightly version, and report the issue on the optimum-tpu "
                                "GitHub repository: https://github.com/huggingface/optimum-tpu/issues/new.")
            logger.warning_once(
                "Fine-tuning Gemma on Pytorch XLA 2.3.0 might raise some issues. In case of any "
                "issues consider using the nightly version, and report the issue on the optimum-tpu "
                "GitHub repository: https://github.com/huggingface/optimum-tpu/issues/new."
            )
        cls_to_wrap = "GemmaDecoderLayer"
        matched_model = True
    elif model_type == "llama":
2 changes: 0 additions & 2 deletions optimum/tpu/generation/token_selector.py
@@ -3,7 +3,6 @@
from typing import List, Optional, Union

import torch
import torch_xla.core.xla_model as xm
from transformers.generation import (
GenerationConfig,
GenerationMixin,
@@ -53,7 +52,6 @@ def __init__(
        self.eos_token_ids = eos_token_ids
        self.pad_token_id = pad_token_id
        self.logits_warper = logits_warper
        xm.set_rng_state(seed)
        self.generator = torch.Generator()
        self.generator.manual_seed(seed)

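The removed line used to seed the global XLA RNG; token selection now relies only on the per-instance `torch.Generator` seeded just above. A minimal sketch of that pattern (the sampling call is illustrative, not the selector's actual code):

```python
import torch

# A local generator keeps sampling reproducible without touching global RNG state.
generator = torch.Generator()
generator.manual_seed(42)

probs = torch.tensor([0.1, 0.2, 0.3, 0.4])
# Draw one token index, driven only by the local generator.
next_token = torch.multinomial(probs, num_samples=1, generator=generator)
```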
26 changes: 26 additions & 0 deletions optimum/tpu/jetstream_pt_support.py
@@ -0,0 +1,26 @@
import os
import sys

from loguru import logger


def jetstream_pt_available() -> bool:
"""Check if the necessary imports to use jetstream_pt are available.
"""
try:
# For now Jetstream Pytorch is opt-in, it can be enabled with an ENV variable.
jetstream_pt_enabled = os.environ.get("JETSTREAM_PT", False) == "1"
if not jetstream_pt_enabled:
return False
# Torch XLA should not be imported before torch_xla2 to avoid conflicts.
if 'torch_xla2' not in sys.modules and 'torch_xla.core' in sys.modules:
logger.warning("torch_xla2 cannot be imported after torch_xla, disabling Jetstream PyTorch support.")
return False
Review comment (Collaborator):
nit: would it make sense to emit a warning here? Like "JETSTREAM_PT is enabled, but torch_xla2 is not installed. Falling back to torch_xla".

Reply (Collaborator, Author):
It's actually a little trickier than that: torch_xla2 cannot be imported after torch_xla has been imported. I will add a warning.

        # Import torch_xla2 first!
        import torch_xla2  # noqa: F401, isort:skip

        import jetstream_pt  # noqa: F401

        return True
    except ImportError:
        return False
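
For illustration, the import-order guard above means that importing torch_xla before the check disables Jetstream PyTorch support; a minimal sketch, assuming a TPU host with both torch_xla and the jetstream-pt extra installed:

```python
import os

os.environ["JETSTREAM_PT"] = "1"

import torch_xla.core.xla_model as xm  # torch_xla is imported first ...

from optimum.tpu import jetstream_pt_available

# ... so the helper logs a warning and reports False, because torch_xla2
# cannot be imported once torch_xla is already in sys.modules.
assert jetstream_pt_available() is False
```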
4 changes: 2 additions & 2 deletions optimum/tpu/version.py
@@ -12,8 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from pkg_resources import parse_version
from packaging.version import parse


__version__ = "0.1.5"
VERSION = parse_version(__version__)
VERSION = parse(__version__)
10 changes: 8 additions & 2 deletions pyproject.toml
@@ -43,8 +43,8 @@ keywords = [

dependencies = [
"transformers == 4.41.1",
"torch >= 2.3.0, <= 2.4.0",
"torch-xla[tpu] >= 2.3.0, <= 2.4.0",
"torch == 2.4.0",
"torch-xla[tpu] == 2.4.0",
"loguru == 0.6.0",
"sentencepiece == 0.2.0",
]
@@ -58,6 +58,12 @@ build-backend = "setuptools.build_meta"
[project.optional-dependencies]
tests = ["pytest", "safetensors"]
quality = ["black", "ruff", "isort"]
# Jetstream/Pytorch support is experimental for now, and requires installation from a fixed commit.
# Pallas is pulled because it will install a compatible version of jax[tpu].
jetstream-pt = [
"jetstream-pt @ git+https://github.com/google/jetstream-pytorch.git#df92015289953c506004e674d57651b03e4e89f2",
"torch-xla[pallas] == 2.4.0"
]

[project.urls]
Homepage = "https://hf.co/hardware"
4 changes: 3 additions & 1 deletion text-generation-inference/server/Makefile
@@ -12,12 +12,14 @@ clean:

# List static sources to be deployed in the package
src_dir := $(mkfile_dir)/$(pkg_name)
sources := $(wildcard $(src_dir)/*.py)
rwildcard_py = $(wildcard $(1)/*.py) $(foreach d,$(wildcard $(1)/*),$(call rwildcard_py,$d))
sources := $(call rwildcard_py,$(src_dir))
deployed_sources := $(subst $(src_dir), $(pkg_dir), $(sources))

# Static files are just copied

define COPY
mkdir -p $(dir $@)
cp -f $< $@
endef

2 changes: 1 addition & 1 deletion text-generation-inference/server/pyproject.toml
@@ -22,7 +22,7 @@ dependencies = [
]

[tool.setuptools]
packages = ["text_generation_server", "text_generation_server.pb"]
packages = ["text_generation_server", "text_generation_server.pb", "text_generation_server.jetstream_pt_support"]

[tool.setuptools.dynamic]
version = {attr = "text_generation_server.version.__version__"}
@@ -0,0 +1,39 @@
from loguru import logger

from .generator_base import Generator
from .jetstream_pt_support import model_can_use_jetstream_pt


class AutoGenerator:

    @staticmethod
    def from_pretrained(
        model_path: str, revision: str, max_batch_size: int, max_sequence_length: int
    ) -> Generator:
        """Instantiate a Generator for TPU using Jetstream Pytorch or Pytorch/XLA.

        Args:
            model_path (`str`):
                The path to a local model. This path must also contain a Tokenizer.
            revision (`str`):
                The revision of the model.
            max_batch_size (`int`):
                The maximum batch size.
            max_sequence_length (`int`):
                The maximum sequence length.

        Returns:
            A TpuGenerator.
        """
        if model_can_use_jetstream_pt(model_path):
            logger.debug("Using Jetstream PyTorch generator.")
            from .jetstream_pt_support.generator import TpuGeneratorJetStream
            return TpuGeneratorJetStream.from_pretrained(
                model_path, revision=revision, max_batch_size=max_batch_size, max_sequence_length=max_sequence_length
            )
        else:
            logger.debug("Using PyTorch/XLA generator.")
            from .generator import TpuGenerator
Review comment (Collaborator):
It would be useful to a user to log 1) when we have successfully loaded Jetstream and 2) when we're falling back to the base generator.

            return TpuGenerator.from_pretrained(
                model_path, revision=revision, max_batch_size=max_batch_size, max_sequence_length=max_sequence_length
            )
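
For context, a minimal sketch of how a caller might use the new entry point; the import path, local model directory, and sizes below are illustrative assumptions, not taken from this diff:

```python
# Hypothetical usage of AutoGenerator; adjust the import to wherever the class lives.
from text_generation_server.auto_generator import AutoGenerator  # module path assumed

generator = AutoGenerator.from_pretrained(
    "/data/llama-2-7b",          # local directory that must also contain the tokenizer
    revision="main",
    max_batch_size=4,
    max_sequence_length=1024,
)
# Returns a Jetstream PyTorch generator when JETSTREAM_PT=1 and the model is
# supported, otherwise the PyTorch/XLA TpuGenerator.
```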
@@ -590,6 +590,19 @@ def decode(self, batches: List[CachedBatch]) -> Tuple[List[Generation], CachedBa
        # just carry on with decoding. We adopt the id of the first
        # batch in the list as our next batch id.
        next_batch_id = batches[0].id
        request_ids = []
        for batch in batches:
            request_ids += batch.request_ids
        cleared_request_ids = []
        for slot in self.slots:
            if slot.state == slot.State.READY and slot.request_id not in request_ids:
                cleared_request_ids.append(slot.request_id)
                slot.clear()
        if len(cleared_request_ids) > 0:
            logger.info(f"Clearing slot for requests {cleared_request_ids} as they are not requested.")
        active_slots = [slot for slot in self.slots if slot.state == slot.State.READY]
        if len(active_slots) < len(request_ids):
            logger.error("Unable to decode tokens for non-prefilled batches (probably due to a previous failure)")
        # Reconstruct input_ids and attention_mask from slots
        input_ids = None
        attention_mask = None
@@ -608,7 +621,7 @@ def decode(self, batches: List[CachedBatch]) -> Tuple[List[Generation], CachedBa
# Create blank inputs covering all slots (even empty ones)
input_ids = torch.full(
[batch_size, 1],
fill_value=self.tokenizer.eos_token_id,
fill_value=pad_token_id,
dtype=torch.int64,
)
cache_position = torch.zeros([1], dtype=torch.int64)
@@ -0,0 +1,15 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from .compatibility import create_engine, model_can_use_jetstream_pt
@@ -0,0 +1,53 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
from typing import Any

from transformers import AutoConfig

from optimum.tpu import jetstream_pt_available


def model_can_use_jetstream_pt(model_path: str) -> bool:
"""Checks if the model is supported by Jetstream Pytorch on Optimum TPU and if the required dependencies to provide
the engine are installed.
"""
config = AutoConfig.from_pretrained(model_path)
# For now only Llama 2 with tokenizer.model is supported
if config.model_type != "llama" or not os.path.exists(
os.path.join(model_path, "tokenizer.model")
):
return False
if jetstream_pt_available():
return True
return False


def create_engine(
model_path: str,
batch_size: int,
sequence_length: int,
max_input_tokens: int,
max_output_tokens: int,
) -> Any:
if not model_can_use_jetstream_pt(model_path):
# The model is not compatible with Jetstream PyTorch, just exit
return None

# Now import engine_loader to prevent importing it at the top when not supported
from .engine_loader import create_engine
return create_engine(
model_path, batch_size, sequence_length, max_input_tokens, max_output_tokens
)
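
As an illustration of how the two helpers above fit together, a minimal sketch (the model path and sizes are placeholders; `create_engine` returns None whenever Jetstream PT cannot be used):

```python
# Hypothetical driver code for the compatibility helpers shown above.
from text_generation_server.jetstream_pt_support import create_engine, model_can_use_jetstream_pt

model_path = "/data/llama-2-7b"  # local checkpoint that also ships a tokenizer.model file

if model_can_use_jetstream_pt(model_path):
    engine = create_engine(
        model_path,
        batch_size=4,
        sequence_length=1024,
        max_input_tokens=512,
        max_output_tokens=512,
    )
else:
    engine = None  # fall back to the PyTorch/XLA generator
```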