diff --git a/text-generation-inference/tests/test_gpt2.py b/text-generation-inference/tests/test_gpt2.py index a28d5219..26d61a0b 100644 --- a/text-generation-inference/tests/test_gpt2.py +++ b/text-generation-inference/tests/test_gpt2.py @@ -65,7 +65,10 @@ def test_prefill(input_text, token_id, token_text, do_sample, batch_size, model_ def test_decode_multiple(model_path): - generator = TpuGenerator.from_pretrained(model_path, revision="", max_batch_size=1, max_sequence_length=SEQUENCE_LENGTH) + generator = TpuGenerator.from_pretrained(model_path, + revision="", + max_batch_size=2, + max_sequence_length=SEQUENCE_LENGTH) input_text = "Once upon a time" max_new_tokens = 20 # Prefill a single request, remembering the generated token