diff --git a/engines/python/setup/djl_python/output_formatter.py b/engines/python/setup/djl_python/output_formatter.py index 30d594ce7..7b97dab37 100644 --- a/engines/python/setup/djl_python/output_formatter.py +++ b/engines/python/setup/djl_python/output_formatter.py @@ -111,7 +111,7 @@ def _json_output_formatter(request_output: TextGenerationOutput): # TODO: Fix this so it is not required. Right now, this call is needed to # advance the token iterator, which is needed for rolling batch to work properly next_token, _, is_last_token = best_sequence.get_next_token() - if not is_last_token: + if not request_output.finished: return "" details = get_details_dict(request_output, include_tokens=True) if details.get("finish_reason") == "error": diff --git a/engines/python/setup/djl_python/tests/test_rolling_batch.py b/engines/python/setup/djl_python/tests/test_rolling_batch.py index 64f0bec16..bc315f7d8 100644 --- a/engines/python/setup/djl_python/tests/test_rolling_batch.py +++ b/engines/python/setup/djl_python/tests/test_rolling_batch.py @@ -47,7 +47,6 @@ def test_json_speculative_decoding(self): req = Request(req_input) req.request_output = TextGenerationOutput(request_id=0, input=req_input) - req.request_output.finished = True req.request_output.set_next_token(Token(244, "He", -0.334532)) req.request_output.set_next_token(Token(576, "llo", -0.123123)) req.request_output.set_next_token(Token(4558, " world", -0.567854, @@ -56,7 +55,7 @@ def test_json_speculative_decoding(self): finish_reason='length') self.assertEqual(req.get_next_token(), "") - self.assertEqual(req.get_next_token(), "") + req.request_output.finished = True self.assertEqual(req.get_next_token(), json.dumps({"generated_text": "Hello world"}))