fix(warmup): make warmup work for smallest prefill size
Also, add timing checks in warmup test.
tengomucho committed Sep 20, 2024
1 parent 93ca260 · commit a49d3af
Showing 2 changed files with 27 additions and 5 deletions.
```diff
@@ -340,8 +340,8 @@ def warmup(self, batch: Batch) -> int:
             # Skip all the unsupported lengths
             if l > bucket_seq_len:
                 continue
-            # create a dummy request with the current sequence length
-            dummy_request = self._create_dummy_request(l)
+            # create a dummy request with the current sequence length - 1 (so it gets padded up to l)
+            dummy_request = self._create_dummy_request(l - 1)
             # We define a few max_new_tokens to request at least one token (by prefill) and another (by decode).
             MAX_NEW_TOKENS = 10
             dummy_request.stopping_parameters.max_new_tokens = MAX_NEW_TOKENS
```
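The one-token reduction is the heart of the fix: prefill inputs are padded up to a fixed set of bucket lengths, and a request of `l - 1` tokens is guaranteed to be padded up to bucket `l`, whereas a request of exactly `l` tokens sits right on the boundary and (per the commit title) broke warmup for the smallest bucket. A minimal sketch of how such bucket selection typically works, assuming power-of-two buckets; `PREFILL_BUCKETS` and `select_bucket` are illustrative names, not the actual server implementation:

```python
from bisect import bisect_left

# Hypothetical prefill bucket sizes (powers of two up to the max length).
PREFILL_BUCKETS = [16, 32, 64, 128, 256]

def select_bucket(input_length: int) -> int:
    """Return the smallest bucket that can hold input_length tokens."""
    idx = bisect_left(PREFILL_BUCKETS, input_length)
    if idx == len(PREFILL_BUCKETS):
        raise ValueError(f"input length {input_length} exceeds the largest bucket")
    return PREFILL_BUCKETS[idx]

# l - 1 tokens always pad up to bucket l, leaving one token of headroom
# (e.g. for a prepended BOS token) before spilling into the next bucket.
for l in PREFILL_BUCKETS:
    assert select_bucket(l - 1) == l
```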
text-generation-inference/tests/test_warmup.py — 28 changes: 25 additions & 3 deletions
```diff
@@ -1,5 +1,7 @@
 
 
+from time import time
+
 import pytest
 from helpers import create_request, prepare_model
 from text_generation_server.auto_generator import AutoGenerator
```
```diff
@@ -14,17 +16,37 @@ def test_warmup_jetstream_pytorch():
     model_id = "Maykeye/TinyLLama-v0"
 
-    # The maximum sequence length of the model is set to 1000, but warmup will round that up to the next power of two
-    # in prefill (1024).
-    sequence_length = 1000
+    # The maximum sequence length of the model is set to 250, but warmup will round that up to the next power of two
+    # in prefill (256).
+    sequence_length = 250
 
     model_path = prepare_model(model_id, sequence_length)
     input_text = "It was a bright cold day in April, and the clocks were striking thirteen."
     max_new_tokens = 20
 
     generator = AutoGenerator.from_pretrained(
-        model_path, revision="", max_batch_size=1, max_sequence_length=sequence_length
+        model_path, revision="", max_batch_size=2, max_sequence_length=sequence_length
     )
     request = create_request(id=0, inputs=input_text, max_new_tokens=max_new_tokens, do_sample=False)
     batch = Batch(id=0, requests=[request], size=1, max_tokens=sequence_length)
     generator.warmup(batch)
+
+    # Prepare a new request with different settings. Warmup should have triggered compilation, so this can run
+    # quickly.
+    input_text = "What is Deep Learning?"
+    max_new_tokens = 3
+    max_tokens = 13
+    request1 = create_request(id=0, inputs=input_text, max_new_tokens=max_new_tokens, do_sample=False)
+    batch = Batch(id=1, requests=[request1], size=1, max_tokens=max_tokens)
+
+    start = time()
+    _generations, new_batch = generator.prefill(batch)
+    _generations, new_batch = generator.decode([new_batch])
+    end = time()
+
+    # Prefill and decode should take less than 1 second (rather fast).
+    assert end - start < 1.0
+
+
+if __name__ == "__main__":
+    import os
+
+    os.environ["JETSTREAM_PT"] = "1"
+    test_warmup_jetstream_pytorch()
```
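The rounding the updated test comment refers to is a plain next-power-of-two bump, easy to verify for both values that appear in this diff; `next_power_of_two` below is an illustrative helper, not part of the server's API:

```python
def next_power_of_two(n: int) -> int:
    """Smallest power of two greater than or equal to n (n >= 1)."""
    return 1 << (n - 1).bit_length()

assert next_power_of_two(250) == 256    # the new sequence_length
assert next_power_of_two(1000) == 1024  # the value used before this commit
```

The `__main__` guard added at the end also makes the test runnable standalone (`python text-generation-inference/tests/test_warmup.py`); it sets `JETSTREAM_PT=1` itself before calling the test function.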
