From ca751a04eae7c38a3016620ca80ba44ef91db2ee Mon Sep 17 00:00:00 2001 From: Siddharth Venkatesan Date: Mon, 23 Sep 2024 17:18:43 -0700 Subject: [PATCH] =?UTF-8?q?[fix][lmi][specdec]=20fix=20issue=20with=20json?= =?UTF-8?q?=20output=20formatter=20not=20returnin=E2=80=A6=20(#2403)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- engines/python/setup/djl_python/output_formatter.py | 2 +- engines/python/setup/djl_python/tests/test_rolling_batch.py | 3 +-- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/engines/python/setup/djl_python/output_formatter.py b/engines/python/setup/djl_python/output_formatter.py index 30d594ce7..7b97dab37 100644 --- a/engines/python/setup/djl_python/output_formatter.py +++ b/engines/python/setup/djl_python/output_formatter.py @@ -111,7 +111,7 @@ def _json_output_formatter(request_output: TextGenerationOutput): # TODO: Fix this so it is not required. Right now, this call is needed to # advance the token iterator, which is needed for rolling batch to work properly next_token, _, is_last_token = best_sequence.get_next_token() - if not is_last_token: + if not request_output.finished: return "" details = get_details_dict(request_output, include_tokens=True) if details.get("finish_reason") == "error": diff --git a/engines/python/setup/djl_python/tests/test_rolling_batch.py b/engines/python/setup/djl_python/tests/test_rolling_batch.py index c90087d83..3cdf8f64e 100644 --- a/engines/python/setup/djl_python/tests/test_rolling_batch.py +++ b/engines/python/setup/djl_python/tests/test_rolling_batch.py @@ -47,7 +47,6 @@ def test_json_speculative_decoding(self): req = Request(req_input) req.request_output = TextGenerationOutput(request_id=0, input=req_input) - req.request_output.finished = True req.request_output.set_next_token(Token(244, "He", -0.334532)) req.request_output.set_next_token(Token(576, "llo", -0.123123)) req.request_output.set_next_token(Token(4558, " world", -0.567854, @@ -56,7 +55,7 @@ def test_json_speculative_decoding(self): finish_reason='length') self.assertEqual(req.get_next_token(), "") - self.assertEqual(req.get_next_token(), "") + req.request_output.finished = True self.assertEqual(req.get_next_token(), json.dumps({"generated_text": "Hello world"}))