
Commit

Correct extra token, start preparing docker image for TGI/Jetstream Pt (#93)

* fix(Jetstream Pt): remove extra token in decode

Jetstream's `generate` function returns the input token as the result token; the
actual next token is available in the decode_state, so this change reads it from
there instead (a short illustration follows after this list).

* fix(engine): set batch_size and sequence_length

* fix(Jetstream PT): correct warmup internal params

* test(tgi): added a warmup test

* chore(jetstream pt): check input type in decode

* fix(token selector): seed can be a very large number

Before, we could get an error if the seed was bigger than a 64-bit integer.

* fix(Jetstream PT): handle slot's seed in a clean way

* feat(docker): TGI image now includes Jetstream Pytorch dependencies

This allows testing TGI images with Jetstream Pytorch.

* fix(Jetstream Pt): batch returned in prefill initialized to None

This is required when there are no more tokens generated after prefill.

* feat(Jetstream Pt): speed-up prefill by avoiding redundant compilation

A new slot is created at each prefill request, and its selector is
passed as argument to a jitted function. The problem is that each new
slot has a new signature, even if the contents are the same. The
solution is to wrap that in a singleton slot object for the prefill, so
the compiler will always see the same object and stop recompiling.

* chore(generator): use prefill bucket sizes defined in Jetstream
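
For the extra-token fix in the first item above, a rough illustration of the access pattern, assuming `decode_state.tokens` holds one pending token per slot (the [batch_size, 1] shape is an assumption made for this sketch, not taken from the Jetstream API):

    # Minimal sketch, not the repository code: after the fix, the next token is
    # read from the engine's decode state instead of the token echoed by generate.
    import jax.numpy as jnp

    class FakeDecodeState:
        def __init__(self, tokens):
            self.tokens = tokens

    # Assumed shape: one pending token per slot.
    decode_state = FakeDecodeState(jnp.array([[42], [7]]))

    slot_id = 1
    next_token = decode_state.tokens[slot_id].item()  # 7, the token for slot 1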
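
For the recompilation fix described above, a minimal self-contained sketch of the mechanism (Selector and pick below are invented for illustration and are not part of optimum-tpu): jax.tree_util.Partial keeps the wrapped callable as static pytree metadata, so a bound method of a brand-new object forces a fresh compilation, while a bound method of a long-lived singleton keeps hitting the jit cache.

    import jax
    import jax.numpy as jnp

    class Selector:
        # Stand-in for a per-request Slot exposing a select() method.
        def select(self, logits):
            return jnp.argmax(logits, axis=-1)

    @jax.jit
    def pick(select_fn, logits):
        # select_fn arrives as a jax.tree_util.Partial; the function it wraps is
        # static metadata, so a different bound method triggers a recompilation.
        return select_fn(logits)

    logits = jnp.zeros((1, 8))

    # Recompiles on every iteration: each Selector() yields a distinct bound method.
    for _ in range(3):
        pick(jax.tree_util.Partial(Selector().select), logits)

    # Compiles once: the bound method always belongs to the same singleton object,
    # which is what the PrefillSlot wrapper achieves for prefill.
    singleton = Selector()
    for _ in range(3):
        pick(jax.tree_util.Partial(singleton.select), logits)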
tengomucho committed Sep 17, 2024
1 parent 03b6573 commit 4265e13
Showing 6 changed files with 83 additions and 32 deletions.
5 changes: 4 additions & 1 deletion text-generation-inference/docker/Dockerfile
@@ -109,7 +109,10 @@ COPY . /opt/optimum-tpu

# Install requirements for optimum-tpu, then for TGI then optimum-tpu
RUN python3 -m pip install hf_transfer safetensors==${SAFETENSORS_VERSION} && \
python3 -m pip install -e /opt/optimum-tpu -f https://storage.googleapis.com/libtpu-releases/index.html
python3 -m pip install -e /opt/optimum-tpu[jetstream-pt] \
-f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html \
-f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html \
-f https://storage.googleapis.com/libtpu-releases/index.html

# Install router
COPY --from=builder /usr/src/target/release/text-generation-router /usr/local/bin/text-generation-router
@@ -144,6 +144,10 @@ def create_engine(

env = JetEngineEnvironment(env_data)
model = instantiate_model_from_repo_id(model_path, env)
# Update config with engine data
model.config.batch_size = batch_size
model.config.sequence_length = sequence_length

weight_shardings = model.get_sharding_annotations()
sharded_weights = shard_weights(env, model.state_dict(), weight_shardings)

@@ -9,7 +9,7 @@
import numpy as np
import torch
import torch_xla2
from jetstream.engine.token_utils import pad_tokens, take_nearest_length
from jetstream.engine.token_utils import pad_tokens, take_nearest_length, DEFAULT_PREFILL_BUCKETS
from jetstream_pt.engine import PyTorchEngine
from loguru import logger
from transformers import AutoTokenizer, PreTrainedTokenizerBase
@@ -36,20 +36,6 @@
optimum_logger = logging.getLogger("optimum.tpu")
optimum_logger.setLevel("CRITICAL")

# These will do some bucketing on prefill lengths to avoid too many different sizes
PREFILL_LENGTHS = [
32,
64,
128,
256,
512,
1024,
2048,
4096,
8192,
16384,
32768,
]

class Slot:
"""Represents a slot in a static batch"""
@@ -78,6 +64,7 @@ def clear(self):
self._generated_text = ""
self._next_text = ""
self._truncate = 0
self._seed = 0

@property
def id(self) -> int:
@@ -134,7 +121,7 @@ def assign(self, batch_id: int, request: Request, generation_config: GenerationC
self._generation_config.do_sample = request.parameters.do_sample
self._generation_config.repetition_penalty = request.parameters.repetition_penalty
self._truncate = request.truncate
self.seed = request.parameters.seed
self._seed = request.parameters.seed
# TODO: watermark
self._generation_config.max_new_tokens = request.stopping_parameters.max_new_tokens
self._max_new_tokens = self._generation_config.max_new_tokens
@@ -237,6 +224,20 @@ def next_token(self) -> int:
def empty(self) -> bool:
return len(self._tokens) == 0

@property
def seed(self) -> int:
return self._seed


class PrefillSlot:
def __init__(self):
self._curslot = None

def set(self, slot: Slot):
self._curslot = slot

def select(self, logits: jnp.ndarray) -> int:
return self._curslot.select(logits)

class TpuGeneratorJetStream(Generator):
"""A Generator for models running on TPU, single threaded."""
@@ -267,6 +268,7 @@ def __init__(
self.batch_id = 0
# Note: this index will _never_ be decremented, and that's fine.
self.slot_index = 0
self.prefill_slot = PrefillSlot()

@property
def info(self) -> InfoResponse:
@@ -328,31 +330,30 @@ def warmup(self, batch: Batch) -> int:
# Counter-intuitively, now we ignore the input batch. Instead, we create dummy batches to cover all possible
# batch sizes and sequence lengths.
seq_len = self.model.config.sequence_length
bucket_seq_len = take_nearest_length(PREFILL_LENGTHS, seq_len)
dummy_request = self._create_dummy_request(seq_len)
bucket_seq_len = take_nearest_length(DEFAULT_PREFILL_BUCKETS, seq_len)
decode_done = False
for l in reversed(PREFILL_LENGTHS):
for l in reversed(DEFAULT_PREFILL_BUCKETS):
# Skip all the unsupported lengths
if l > bucket_seq_len:
continue
# Set all truncate values for all requests
dummy_request.truncate = l
dummy_request.stopping_parameters.max_new_tokens = 10
# create a dummy request with the current sequence length
dummy_request = self._create_dummy_request(l)
# We set a small max_new_tokens so that at least one token is generated by prefill and another by decode.
MAX_NEW_TOKENS = 10
dummy_request.stopping_parameters.max_new_tokens = MAX_NEW_TOKENS
warmup_batch = Batch(id=0,
requests=[dummy_request],
size=1,
max_tokens=batch.max_tokens)
logger.debug(f"Warmup for requests, len {l} seq_len {seq_len}")
_generations, next_batch = self.prefill(warmup_batch)
if not decode_done and next_batch is not None:
if next_batch is not None:
self.decode([next_batch])
decode_done = True
self.clear()
if not decode_done:
logger.debug("No decode done during warmup")

self.prefill(batch)
self.clear()
elapsed = time.time() - start
logger.debug(f"Warmup done, took {elapsed:.2f}s")
seq_len = self.engine.env.seq_len
@@ -390,11 +391,13 @@ def _token_encode(self, text: str, max_length: int) -> Tuple[jnp.ndarray, int]:
max_length=max_length,
add_special_tokens=False,
)
# max_prefill_length must be a power of 2
max_prefill_length = take_nearest_length(DEFAULT_PREFILL_BUCKETS, self.model.config.sequence_length)
tokens, true_length = pad_tokens(input_ids[0],
self.tokenizer.bos_token_id,
self.tokenizer.pad_token_id,
is_bos=True,
max_prefill_length=self.model.config.sequence_length,
max_prefill_length=max_prefill_length,
jax_padding=True,
)
return tokens, true_length
@@ -436,6 +439,7 @@ def prefill(self, batch: Batch) -> Tuple[List[Generation], CachedBatch]:
for request in batch.requests:
# Dynamically create a new slot for each request
slot = Slot(self._get_slot_id(), self.tokenizer)
self.prefill_slot.set(slot)
self.slot_index += 1
slot.assign(self.batch_id, request, self.model.generation_config)
logger.debug(f"Request {slot.request_id} assigned to slot {slot.id}")
@@ -452,7 +456,7 @@ def prefill(self, batch: Batch) -> Tuple[List[Generation], CachedBatch]:
)
slot.reset(truncated_input_ids, selector)
# To allow jit'ing the select function, we need to wrap it in a partial
slot_select = jax.tree_util.Partial(slot.select)
slot_select = jax.tree_util.Partial(self.prefill_slot.select)
# Ask for prefill and insert
prefill_results, _result_tokens = self.engine.prefill(
params=self.params,
@@ -469,6 +473,7 @@ def prefill(self, batch: Batch) -> Tuple[List[Generation], CachedBatch]:
self.slots.append(slot)
len_active_slots += 1

batch = None
if len_active_slots > 0:
# Whatever initial batch these requests came from, we always return all pending requests in a single batch
request_ids = [slot.request_id for slot in self.slots if slot.state == Slot.State.READY]
@@ -499,6 +504,13 @@ def decode(self, batches: List[CachedBatch]) -> Tuple[List[Generation], CachedBa
Return:
A list of `Generation` for each request and a `CachedBatch` containing all pending requests.
"""

# In Python we would normally rely on duck typing, but checking the element types here lets us fail fast
# instead of raising an error later and wasting time. Return an empty generation instead.
if any(not isinstance(item, CachedBatch) for item in batches):
logger.error("Unexpected type in decode, expected CachedBatch")
return [], None

# batches contains a list composed of ongoing requests:
# - the batch id returned by the last decode,
# - the batch id(s) returned by the last prefill(s)
@@ -532,7 +544,7 @@ def decode(self, batches: List[CachedBatch]) -> Tuple[List[Generation], CachedBa
# Get the next token.
# Note that for now we ignore is_valid and length as we don't use them, we will re-parse these in post
# generation.
next_token, _is_valid, _length = result_tokens.data[slot.id]
next_token = self.decode_state.tokens[slot.id].item()

if slot.state != Slot.State.READY:
logger.error(f"Unexpected Slot {slot.id} is not ready for decoding, skipping.")
@@ -55,6 +55,8 @@ def __init__(
self.eos_token_ids = eos_token_ids
self.pad_token_id = pad_token_id
self.logits_warper = logits_warper
# Seed needs to fit a 64-bit integer, so we modulo it in case it is bigger (that can happen!)
seed = seed % jnp.iinfo(jnp.int64).max
self.key = jax.random.PRNGKey(seed)

@classmethod
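
The token-selector change above folds the seed into the int64 range before building the PRNG key. A small sketch of the failure mode it guards against (the seed value is arbitrary):

    import jax
    import jax.numpy as jnp

    seed = 2**70 + 12345  # e.g. a client-supplied seed wider than 64 bits

    # Passing the raw value to jax.random.PRNGKey can overflow, so the selector
    # reduces it into the int64 range first, mirroring the change above.
    safe_seed = seed % jnp.iinfo(jnp.int64).max
    key = jax.random.PRNGKey(safe_seed)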
6 changes: 3 additions & 3 deletions text-generation-inference/tests/test_decode.py
@@ -100,12 +100,12 @@ def _test_decode_single(params):
DecodeTestParams(
model_id="meta-llama/Llama-2-7b-hf",
sequence_length=256,
expected_text="\n\nThe clocks were striking thirteen\nThe clocks were striking thirteen\n",
expected_text="\nThe clocks were striking thirteen\nThe clocks were striking thirteen\nThe",
),
DecodeTestParams(
model_id="meta-llama/Meta-Llama-3-8B",
sequence_length=256,
expected_text=" Winston Winston Smith, his chin on his hands, and the clock in the Ministry of Truth, M",
expected_text=" Winston Smith, his chin on his hands, and the clock in the Ministry of Truth, Minit",
),
],
ids=["Llama-2-7b-hf", "Meta-Llama-3-8B"],
@@ -123,7 +123,7 @@ def test_decode_single_jetstream_pytorch_slow(params, do_sample):
DecodeTestParams(
model_id="Maykeye/TinyLLama-v0",
sequence_length=256,
expected_text=" She She had a big and it had a big, blue, and a big, red and a",
expected_text=" She had a big and it had a big, blue, and a big, red and a big",
),
],
ids=["TinyLLama-v0"],
30 changes: 30 additions & 0 deletions text-generation-inference/tests/test_warmup.py
@@ -0,0 +1,30 @@


import pytest
from helpers import create_request, prepare_model
from text_generation_server.auto_generator import AutoGenerator
from text_generation_server.pb.generate_pb2 import Batch

from optimum.tpu.jetstream_pt_support import jetstream_pt_available


def test_warmup_jetstream_pytorch():
if not jetstream_pt_available():
pytest.skip("Jetstream PyTorch is not available")
model_id = "Maykeye/TinyLLama-v0"

# The maximum sequence length of the model is set to 1000, but warmup will round that up to the next power of two
# in prefill (1024).
sequence_length = 1000

model_path = prepare_model(model_id, sequence_length)
input_text = "It was a bright cold day in April, and the clocks were striking thirteen."
max_new_tokens = 20

generator = AutoGenerator.from_pretrained(
model_path, revision="", max_batch_size=1, max_sequence_length=sequence_length
)
request = create_request(id=0, inputs=input_text, max_new_tokens=max_new_tokens, do_sample=False)
batch = Batch(id=0, requests=[request], size=1, max_tokens=sequence_length)
generator.warmup(batch)
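
The comment in the test about rounding 1000 up to 1024 reflects the prefill bucketing the generator now shares with Jetstream. A quick way to check which bucket a configured sequence length maps to (assuming the default buckets are the power-of-two lengths used above):

    from jetstream.engine.token_utils import DEFAULT_PREFILL_BUCKETS, take_nearest_length

    # A model configured with sequence_length=1000 is padded to the nearest
    # prefill bucket, which the warmup test above expects to be 1024.
    bucket = take_nearest_length(DEFAULT_PREFILL_BUCKETS, 1000)
    print(bucket)  # expected: 1024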
