fix(test): multiple decode test requires max_batch_size to be > 1
tengomucho committed Jul 6, 2024
1 parent 4528bcb commit 88a1c60
Showing 1 changed file with 4 additions and 1 deletion.
5 changes: 4 additions & 1 deletion text-generation-inference/tests/test_gpt2.py
@@ -65,7 +65,10 @@ def test_prefill(input_text, token_id, token_text, do_sample, batch_size, model_
 
 
 def test_decode_multiple(model_path):
-    generator = TpuGenerator.from_pretrained(model_path, revision="", max_batch_size=1, max_sequence_length=SEQUENCE_LENGTH)
+    generator = TpuGenerator.from_pretrained(model_path,
+                                             revision="",
+                                             max_batch_size=2,
+                                             max_sequence_length=SEQUENCE_LENGTH)
     input_text = "Once upon a time"
     max_new_tokens = 20
     # Prefill a single request, remembering the generated token
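Why the change: the multiple-decode test keeps one request cached while it prefills and decodes a second one, so the generator has to be configured with at least two concurrent slots. Below is a minimal sketch of that flow, assuming the text-generation-inference prefill/decode generator protocol; the create_request and Batch helpers and their exact signatures are illustrative assumptions, not the repository's actual test code.

    def decode_two_requests(generator, input_text, max_new_tokens):
        # Prefill a first request; it now occupies one slot in the generator's cache.
        first = create_request(id=0, inputs=input_text, max_new_tokens=max_new_tokens)
        _, cached_first = generator.prefill(Batch(id=0, requests=[first], size=1))

        # Prefill a second request while the first is still cached. With
        # max_batch_size=1 there is no free slot for it; max_batch_size=2 makes room.
        second = create_request(id=1, inputs=input_text, max_new_tokens=max_new_tokens)
        _, cached_second = generator.prefill(Batch(id=1, requests=[second], size=1))

        # Decode both cached batches together until the requested tokens are produced.
        generations, _ = generator.decode([cached_first, cached_second])
        return generations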
