From a7cb6c2785779843a2a96d9de8ac0fc34c786e44 Mon Sep 17 00:00:00 2001 From: Siddharth Venkatesan Date: Thu, 12 Sep 2024 15:27:32 -0700 Subject: [PATCH] =?UTF-8?q?[python]=20check=20whether=20last=20token=20is?= =?UTF-8?q?=20generated=20for=20json=5Foutput=5Fformat=E2=80=A6=20(#2381)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Sindhu Somasundaram <56774226+sindhuvahinis@users.noreply.github.com> --- .../setup/djl_python/output_formatter.py | 4 +- engines/python/setup/djl_python/request.py | 7 +- .../djl_python/tests/test_rolling_batch.py | 212 +++++++++++++----- engines/python/setup/djl_python/utils.py | 9 + 4 files changed, 169 insertions(+), 63 deletions(-) diff --git a/engines/python/setup/djl_python/output_formatter.py b/engines/python/setup/djl_python/output_formatter.py index 17709388e..3689a65d6 100644 --- a/engines/python/setup/djl_python/output_formatter.py +++ b/engines/python/setup/djl_python/output_formatter.py @@ -110,8 +110,8 @@ def _json_output_formatter(request_output: TextGenerationOutput): request_output.best_sequence_index] # TODO: Fix this so it is not required. Right now, this call is needed to # advance the token iterator, which is needed for rolling batch to work properly - next_token, _, _ = best_sequence.get_next_token() - if not request_output.finished: + next_token, _, is_last_token = best_sequence.get_next_token() + if not is_last_token: return "" details = get_details_dict(request_output, include_tokens=True) if details.get("finish_reason") == "error": diff --git a/engines/python/setup/djl_python/request.py b/engines/python/setup/djl_python/request.py index ad79685a2..53a08d2cc 100644 --- a/engines/python/setup/djl_python/request.py +++ b/engines/python/setup/djl_python/request.py @@ -15,7 +15,7 @@ from djl_python.output_formatter import get_output_formatter, adapt_legacy_output_formatter from djl_python.request_io import Token, TextGenerationOutput, TextInput, RequestOutput, RequestInput -from djl_python.utils import wait_till_generation_finished +from djl_python.utils import is_streaming class Request(object): @@ -113,9 +113,8 @@ def get_next_token(self) -> str: if self.legacy_formatter: self.next_token_str = adapt_legacy_output_formatter( self.request_output) - elif wait_till_generation_finished( - self.request_output.input.parameters): - # there is no need for iterators for best_of and num_beams. + elif not is_streaming(self.request_output.input.parameters): + # there is no need for iterators in non-streaming use-cases self.next_token_str = self.output_formatter(self.request_output) else: best_sequence = self.request_output.sequences[ diff --git a/engines/python/setup/djl_python/tests/test_rolling_batch.py b/engines/python/setup/djl_python/tests/test_rolling_batch.py index 1e76a7e28..c90087d83 100644 --- a/engines/python/setup/djl_python/tests/test_rolling_batch.py +++ b/engines/python/setup/djl_python/tests/test_rolling_batch.py @@ -28,15 +28,38 @@ def test_json_fmt(self): req2 = Request(req_input2) for req in [req1, req2]: req.set_next_token(Token(244, "He", -0.334532)) - req.get_next_token() + self.assertEqual(req.get_next_token(), "") req.set_next_token(Token(576, "llo", -0.123123)) - req.get_next_token() + self.assertEqual(req.get_next_token(), "") req.set_next_token(Token(4558, " world", -0.567854), True, 'length') print(req.get_next_token(), end='') assert req.get_next_token() == json.dumps( {"generated_text": "Hello world"}) + def test_json_speculative_decoding(self): + req_input = TextInput( + request_id=0, + input_text="This is a wonderful day", + parameters={"max_new_tokens": 256}, + output_formatter=_json_output_formatter, + ) + req = Request(req_input) + req.request_output = TextGenerationOutput(request_id=0, + input=req_input) + req.request_output.finished = True + req.request_output.set_next_token(Token(244, "He", -0.334532)) + req.request_output.set_next_token(Token(576, "llo", -0.123123)) + req.request_output.set_next_token(Token(4558, " world", -0.567854, + True), + is_last_token=True, + finish_reason='length') + + self.assertEqual(req.get_next_token(), "") + self.assertEqual(req.get_next_token(), "") + self.assertEqual(req.get_next_token(), + json.dumps({"generated_text": "Hello world"})) + def test_json_fmt_with_appending(self): req_input1 = TextInput(request_id=0, input_text="This is a wonderful day", @@ -54,9 +77,9 @@ def test_json_fmt_with_appending(self): req2 = Request(req_input2) for req in [req1, req2]: req.set_next_token(Token(244, "He", -0.334532)) - req.get_next_token() + self.assertEqual(req.get_next_token(), "") req.set_next_token(Token(576, "llo", -0.123123)) - req.get_next_token() + self.assertEqual(req.get_next_token(), "") req.set_next_token(Token(4558, " world", -0.567854), True, 'length') print(req.get_next_token(), end='') @@ -76,9 +99,9 @@ def test_fmt_hf_compat(self): tgi_compat=True)) req.set_next_token(Token(244, "He", -0.334532)) - req.get_next_token() + self.assertEqual(req.get_next_token(), "") req.set_next_token(Token(576, "llo", -0.123123)) - req.get_next_token() + self.assertEqual(req.get_next_token(), "") req.set_next_token(Token(4558, " world", -0.567854), True, 'length') final_json = json.loads(req.get_next_token()) print(final_json, end='') @@ -153,6 +176,47 @@ def test_jsonlines_fmt(self): "generated_text": "Hello world" } + def test_jsonlines_speculative_decoding(self): + request_input = TextInput(request_id=0, + input_text="This is a wonderful day", + parameters={"max_new_tokens": 256}, + output_formatter=_jsonlines_output_formatter) + req = Request(request_input=request_input) + req.request_output = TextGenerationOutput(request_id=0, + input=request_input) + req.request_output.finished = True + req.request_output.set_next_token(Token(244, "He", -0.334532)) + print(req.get_next_token(), end='') + self.assertEqual( + {"token": { + "id": 244, + "text": "He", + "log_prob": -0.334532 + }}, json.loads(req.get_next_token())) + req.reset_next_token() + req.request_output.set_next_token(Token(576, "llo", -0.123123)) + print(req.get_next_token(), end='') + self.assertEqual( + {"token": { + "id": 576, + "text": "llo", + "log_prob": -0.123123 + }}, json.loads(req.get_next_token())) + req.reset_next_token() + req.request_output.set_next_token(Token(4558, " world", -0.567854), + is_last_token=True, + finish_reason='length') + print(req.get_next_token(), end='') + self.assertEqual( + { + "token": { + "id": 4558, + "text": " world", + "log_prob": -0.567854 + }, + "generated_text": "Hello world" + }, json.loads(req.get_next_token())) + def test_sse_fmt(self): request_input = TextInput(request_id=0, input_text="This is a wonderful day", @@ -266,11 +330,11 @@ def test_3p_fmt(self): Token(8, "day", -0.7), ] req.set_next_token(Token(244, "He", -0.334532)) - req.get_next_token() + self.assertEqual(req.get_next_token(), "") req.set_next_token(Token(244, "llo", -0.123123)) - req.get_next_token() + self.assertEqual(req.get_next_token(), "") req.set_next_token(Token(4558, " world", -0.567854)) - req.get_next_token() + self.assertEqual(req.get_next_token(), "") req.set_next_token(Token(245, "", -1, True), True, "length") output = json.loads(req.get_next_token()) print(req.get_next_token()) @@ -381,14 +445,15 @@ def test_return_full_text(self): output_formatter=_json_output_formatter)) req.set_next_token(Token(244, "He", -0.334532)) - req.get_next_token() + self.assertEqual(req.get_next_token(), "") req.set_next_token(Token(576, "llo", -0.123123)) - req.get_next_token() + self.assertEqual(req.get_next_token(), "") req.set_next_token(Token(4558, " world", -0.567854), True, 'length') assert req.get_next_token() == json.dumps( {"generated_text": "This is a wonderful dayHello world"}) + def test_return_full_text_stream(self): req = Request( TextInput(request_id=0, input_text="This is a wonderful day", @@ -398,7 +463,21 @@ def test_return_full_text(self): }, output_formatter=_jsonlines_output_formatter)) req.set_next_token(Token(244, "He", -0.334532)) + self.assertEqual( + {"token": { + "id": 244, + "text": "He", + "log_prob": -0.334532, + }}, json.loads(req.get_next_token())) + req.reset_next_token() req.set_next_token(Token(576, "llo", -0.123123)) + self.assertEqual( + {"token": { + "id": 576, + "text": "llo", + "log_prob": -0.123123, + }}, json.loads(req.get_next_token())) + req.reset_next_token() req.set_next_token(Token(4558, " world", -0.567854), True, 'length') print(req.get_next_token(), end='') assert json.loads(req.get_next_token().splitlines()[-1]) == { @@ -427,37 +506,39 @@ def test_details(self): Token(8, "day", -0.7), ] req.set_next_token(Token(244, "He", -0.334532)) - req.get_next_token() + self.assertEqual(req.get_next_token(), "") req.set_next_token(Token(576, "llo", -0.123123)) - req.get_next_token() + self.assertEqual(req.get_next_token(), "") req.set_next_token(Token(4558, " world", -0.567854), True, 'length') - final_json = json.loads(req.get_next_token()) - assert final_json == { - "generated_text": "Hello world", - "details": { - 'inputs': - 'This is a wonderful day', - "finish_reason": - "length", - "generated_tokens": - 3, - "tokens": [{ - "id": 244, - "text": "He", - "log_prob": -0.334532 - }, { - "id": 576, - "text": "llo", - "log_prob": -0.123123 - }, { - "id": 4558, - "text": " world", - "log_prob": -0.567854 - }] - } - } + self.assertEqual( + { + "generated_text": "Hello world", + "details": { + 'inputs': + 'This is a wonderful day', + "finish_reason": + "length", + "generated_tokens": + 3, + "tokens": [{ + "id": 244, + "text": "He", + "log_prob": -0.334532 + }, { + "id": 576, + "text": "llo", + "log_prob": -0.123123 + }, { + "id": 4558, + "text": " world", + "log_prob": -0.567854 + }] + } + }, final_json) + + def test_details_stream(self): # Jsonlines tests req = Request( TextInput(request_id=0, @@ -475,7 +556,23 @@ def test_details(self): Token(8, "day", -0.7), ] req.set_next_token(Token(244, "He", -0.334532)) + next_token = req.get_next_token() + self.assertEqual( + {"token": { + "id": 244, + "text": "He", + "log_prob": -0.334532, + }}, json.loads(next_token)) + req.reset_next_token() req.set_next_token(Token(576, "llo", -0.123123)) + next_token = req.get_next_token() + self.assertEqual( + {"token": { + "id": 576, + "text": "llo", + "log_prob": -0.123123, + }}, json.loads(next_token)) + req.reset_next_token() req.set_next_token(Token(4558, " world", -0.567854), True, 'length') print(req.get_next_token(), end='') assert json.loads(req.get_next_token().splitlines()[-1]) == { @@ -809,21 +906,22 @@ def custom_fmt(token: Token, first_token: bool, last_token: bool, ] req.set_next_token(Token(244, "He", -0.334532)) print(req.get_next_token(), end='') - assert json.loads(req.get_next_token()) == { - 'finish_reason': None, - 'generated_tokens': 1, - 'inputs': 'This is a wonderful day', - 'tokens': [{ - 'id': 244, - 'log_prob': -0.334532, - 'text': 'He' - }], - "parameters": { - "max_new_tokens": 256, - "details": True - }, - "prompt_tokens": 5 - } + self.assertEqual( + json.loads(req.get_next_token()), { + 'finish_reason': None, + 'generated_tokens': 1, + 'inputs': 'This is a wonderful day', + 'tokens': [{ + 'id': 244, + 'log_prob': -0.334532, + 'text': 'He' + }], + "parameters": { + "max_new_tokens": 256, + "details": True + }, + "prompt_tokens": 5 + }) req.reset_next_token() req.set_next_token(Token(576, "llo", -0.123123)) print(req.get_next_token(), end='') @@ -912,7 +1010,7 @@ def custom_fmt_wait(request_output: TextGenerationOutput): elif best_sequence.finish_reason == "error": result["finish_reason"] = best_sequence.finish_reason return json.dumps(result) + "\n" - return json.dumps("") + "\n" + return "" parameters = {"max_new_tokens": 256, "details": True, "stream": False} @@ -922,15 +1020,15 @@ def custom_fmt_wait(request_output: TextGenerationOutput): parameters=parameters, output_formatter=custom_fmt_wait)) print(req.request_input.parameters) - assert req.request_input.parameters == parameters + self.assertEqual(req.request_input.parameters, parameters) req.set_next_token(Token(244, "He", -0.334532)) print(req.get_next_token(), end='') - assert json.loads(req.get_next_token()) == "" + self.assertEqual(req.get_next_token(), "") req.reset_next_token() req.set_next_token(Token(576, "llo", -0.123123)) print(req.get_next_token(), end='') - assert json.loads(req.get_next_token()) == "" + self.assertEqual(req.get_next_token(), "") req.reset_next_token() req.set_next_token(Token(4558, " world", -0.567854), True, 'length') print(req.get_next_token(), end='') diff --git a/engines/python/setup/djl_python/utils.py b/engines/python/setup/djl_python/utils.py index 9cc54b3f6..f59722514 100644 --- a/engines/python/setup/djl_python/utils.py +++ b/engines/python/setup/djl_python/utils.py @@ -90,6 +90,15 @@ def is_multiple_sequences(parameters: dict) -> bool: return "n" in parameters.keys() and parameters.get("n") > 1 +def is_streaming(parameters: dict) -> bool: + """ + Returns whether token streaming is enabled for the request + :param parameters: parameters dictionary + :return: boolean + """ + return "stream" in parameters.keys() and parameters.get("stream") + + def wait_till_generation_finished(parameters): return is_best_of(parameters) or is_multiple_sequences( parameters) or is_beam_search(parameters)