[python] check whether last token is generated for json_output_format… (#2381)

Co-authored-by: Sindhu Somasundaram <[email protected]>
siddvenk and sindhuvahinis authored Sep 12, 2024
1 parent 32326ad commit a7cb6c2
Showing 4 changed files with 169 additions and 63 deletions.
engines/python/setup/djl_python/output_formatter.py (4 changes: 2 additions & 2 deletions)

@@ -110,8 +110,8 @@ def _json_output_formatter(request_output: TextGenerationOutput):
         request_output.best_sequence_index]
     # TODO: Fix this so it is not required. Right now, this call is needed to
     # advance the token iterator, which is needed for rolling batch to work properly
-    next_token, _, _ = best_sequence.get_next_token()
-    if not request_output.finished:
+    next_token, _, is_last_token = best_sequence.get_next_token()
+    if not is_last_token:
         return ""
     details = get_details_dict(request_output, include_tokens=True)
     if details.get("finish_reason") == "error":
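The substance of the fix: _json_output_formatter previously keyed off the request-level finished flag, but with speculative decoding several accepted tokens can already be queued on the sequence before formatting happens, so finished is True while earlier tokens are still being drained from the iterator. Switching to the iterator's own is_last_token flag makes the formatter emit the final JSON exactly once, on the true last token. A minimal sketch of that contract, assuming the (token, is_first_token, is_last_token) tuple shape shown in the diff (the stub class is illustrative, not the real sequence object):

# Illustrative stand-in for the sequence's token iterator (not the real class).
class StubSequence:

    def __init__(self):
        self.tokens = []  # (text, is_last_token) pairs, appended as generated
        self._cursor = 0

    def set_next_token(self, text, is_last_token=False):
        self.tokens.append((text, is_last_token))

    def get_next_token(self):
        # Mirrors the (token, is_first_token, is_last_token) tuple unpacked
        # in _json_output_formatter.
        text, is_last = self.tokens[self._cursor]
        is_first = self._cursor == 0
        self._cursor += 1
        return text, is_first, is_last


seq = StubSequence()
for text in ("He", "llo", " world"):
    seq.set_next_token(text, is_last_token=(text == " world"))

finished = True  # request-level flag: already True before any draining
results = []
for _ in range(3):
    _, _, is_last_token = seq.get_next_token()
    # The old check (if not finished) would emit the payload on every call
    # here; the new check emits "" until the iterator reaches the last token.
    results.append('{"generated_text": "Hello world"}' if is_last_token else "")

assert results == ["", "", '{"generated_text": "Hello world"}']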
engines/python/setup/djl_python/request.py (7 changes: 3 additions & 4 deletions)

@@ -15,7 +15,7 @@

 from djl_python.output_formatter import get_output_formatter, adapt_legacy_output_formatter
 from djl_python.request_io import Token, TextGenerationOutput, TextInput, RequestOutput, RequestInput
-from djl_python.utils import wait_till_generation_finished
+from djl_python.utils import is_streaming


 class Request(object):

@@ -113,9 +113,8 @@ def get_next_token(self) -> str:
         if self.legacy_formatter:
             self.next_token_str = adapt_legacy_output_formatter(
                 self.request_output)
-        elif wait_till_generation_finished(
-                self.request_output.input.parameters):
-            # there is no need for iterators for best_of and num_beams.
+        elif not is_streaming(self.request_output.input.parameters):
+            # there is no need for iterators in non-streaming use-cases
             self.next_token_str = self.output_formatter(self.request_output)
         else:
             best_sequence = self.request_output.sequences[
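This change also broadens the one-shot path: previously only best_of/beam-search requests (via wait_till_generation_finished) bypassed the token iterator, whereas now every non-streaming request does. A rough condensation of the new branch, with is_streaming inlined from the utils.py change below (next_token_str here is hypothetical shorthand, not the real method):

def is_streaming(parameters: dict) -> bool:
    # Same predicate this commit adds to djl_python/utils.py.
    return "stream" in parameters.keys() and parameters.get("stream")


def next_token_str(parameters: dict, formatter, request_output) -> str:
    # Hypothetical condensation of the new elif branch in
    # Request.get_next_token(): a non-streaming request is formatted in one
    # call over the full RequestOutput, with no per-token iterator bookkeeping.
    if not is_streaming(parameters):
        return formatter(request_output)
    raise NotImplementedError("streaming requests advance the token iterator")


# A non-streaming request takes the one-shot path:
assert next_token_str({"max_new_tokens": 256, "stream": False},
                      lambda output: '{"generated_text": "Hello world"}',
                      request_output=None) == '{"generated_text": "Hello world"}'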
engines/python/setup/djl_python/tests/test_rolling_batch.py (212 changes: 155 additions & 57 deletions)

@@ -28,15 +28,38 @@ def test_json_fmt(self):
         req2 = Request(req_input2)
         for req in [req1, req2]:
             req.set_next_token(Token(244, "He", -0.334532))
-            req.get_next_token()
+            self.assertEqual(req.get_next_token(), "")
             req.set_next_token(Token(576, "llo", -0.123123))
-            req.get_next_token()
+            self.assertEqual(req.get_next_token(), "")
             req.set_next_token(Token(4558, " world", -0.567854), True,
                                'length')
             print(req.get_next_token(), end='')
             assert req.get_next_token() == json.dumps(
                 {"generated_text": "Hello world"})

+    def test_json_speculative_decoding(self):
+        req_input = TextInput(
+            request_id=0,
+            input_text="This is a wonderful day",
+            parameters={"max_new_tokens": 256},
+            output_formatter=_json_output_formatter,
+        )
+        req = Request(req_input)
+        req.request_output = TextGenerationOutput(request_id=0,
+                                                  input=req_input)
+        req.request_output.finished = True
+        req.request_output.set_next_token(Token(244, "He", -0.334532))
+        req.request_output.set_next_token(Token(576, "llo", -0.123123))
+        req.request_output.set_next_token(Token(4558, " world", -0.567854,
+                                                True),
+                                          is_last_token=True,
+                                          finish_reason='length')
+
+        self.assertEqual(req.get_next_token(), "")
+        self.assertEqual(req.get_next_token(), "")
+        self.assertEqual(req.get_next_token(),
+                         json.dumps({"generated_text": "Hello world"}))
+
     def test_json_fmt_with_appending(self):
         req_input1 = TextInput(request_id=0,
                                input_text="This is a wonderful day",
@@ -54,9 +77,9 @@ def test_json_fmt_with_appending(self):
         req2 = Request(req_input2)
         for req in [req1, req2]:
             req.set_next_token(Token(244, "He", -0.334532))
-            req.get_next_token()
+            self.assertEqual(req.get_next_token(), "")
             req.set_next_token(Token(576, "llo", -0.123123))
-            req.get_next_token()
+            self.assertEqual(req.get_next_token(), "")
             req.set_next_token(Token(4558, " world", -0.567854), True,
                                'length')
             print(req.get_next_token(), end='')
Expand All @@ -76,9 +99,9 @@ def test_fmt_hf_compat(self):
tgi_compat=True))

req.set_next_token(Token(244, "He", -0.334532))
req.get_next_token()
self.assertEqual(req.get_next_token(), "")
req.set_next_token(Token(576, "llo", -0.123123))
req.get_next_token()
self.assertEqual(req.get_next_token(), "")
req.set_next_token(Token(4558, " world", -0.567854), True, 'length')
final_json = json.loads(req.get_next_token())
print(final_json, end='')
@@ -153,6 +176,47 @@ def test_jsonlines_fmt(self):
                 "generated_text": "Hello world"
             }

+    def test_jsonlines_speculative_decoding(self):
+        request_input = TextInput(request_id=0,
+                                  input_text="This is a wonderful day",
+                                  parameters={"max_new_tokens": 256},
+                                  output_formatter=_jsonlines_output_formatter)
+        req = Request(request_input=request_input)
+        req.request_output = TextGenerationOutput(request_id=0,
+                                                  input=request_input)
+        req.request_output.finished = True
+        req.request_output.set_next_token(Token(244, "He", -0.334532))
+        print(req.get_next_token(), end='')
+        self.assertEqual(
+            {"token": {
+                "id": 244,
+                "text": "He",
+                "log_prob": -0.334532
+            }}, json.loads(req.get_next_token()))
+        req.reset_next_token()
+        req.request_output.set_next_token(Token(576, "llo", -0.123123))
+        print(req.get_next_token(), end='')
+        self.assertEqual(
+            {"token": {
+                "id": 576,
+                "text": "llo",
+                "log_prob": -0.123123
+            }}, json.loads(req.get_next_token()))
+        req.reset_next_token()
+        req.request_output.set_next_token(Token(4558, " world", -0.567854),
+                                          is_last_token=True,
+                                          finish_reason='length')
+        print(req.get_next_token(), end='')
+        self.assertEqual(
+            {
+                "token": {
+                    "id": 4558,
+                    "text": " world",
+                    "log_prob": -0.567854
+                },
+                "generated_text": "Hello world"
+            }, json.loads(req.get_next_token()))
+
     def test_sse_fmt(self):
         request_input = TextInput(request_id=0,
                                   input_text="This is a wonderful day",
@@ -266,11 +330,11 @@ def test_3p_fmt(self):
             Token(8, "day", -0.7),
         ]
         req.set_next_token(Token(244, "He", -0.334532))
-        req.get_next_token()
+        self.assertEqual(req.get_next_token(), "")
         req.set_next_token(Token(244, "llo", -0.123123))
-        req.get_next_token()
+        self.assertEqual(req.get_next_token(), "")
         req.set_next_token(Token(4558, " world", -0.567854))
-        req.get_next_token()
+        self.assertEqual(req.get_next_token(), "")
         req.set_next_token(Token(245, "", -1, True), True, "length")
         output = json.loads(req.get_next_token())
         print(req.get_next_token())
@@ -381,14 +445,15 @@ def test_return_full_text(self):
                       output_formatter=_json_output_formatter))

         req.set_next_token(Token(244, "He", -0.334532))
-        req.get_next_token()
+        self.assertEqual(req.get_next_token(), "")
         req.set_next_token(Token(576, "llo", -0.123123))
-        req.get_next_token()
+        self.assertEqual(req.get_next_token(), "")
         req.set_next_token(Token(4558, " world", -0.567854), True, 'length')

         assert req.get_next_token() == json.dumps(
             {"generated_text": "This is a wonderful dayHello world"})

+    def test_return_full_text_stream(self):
         req = Request(
             TextInput(request_id=0,
                       input_text="This is a wonderful day",
@@ -398,7 +463,21 @@ def test_return_full_text(self):
                       },
                       output_formatter=_jsonlines_output_formatter))
         req.set_next_token(Token(244, "He", -0.334532))
+        self.assertEqual(
+            {"token": {
+                "id": 244,
+                "text": "He",
+                "log_prob": -0.334532,
+            }}, json.loads(req.get_next_token()))
+        req.reset_next_token()
         req.set_next_token(Token(576, "llo", -0.123123))
+        self.assertEqual(
+            {"token": {
+                "id": 576,
+                "text": "llo",
+                "log_prob": -0.123123,
+            }}, json.loads(req.get_next_token()))
+        req.reset_next_token()
         req.set_next_token(Token(4558, " world", -0.567854), True, 'length')
         print(req.get_next_token(), end='')
         assert json.loads(req.get_next_token().splitlines()[-1]) == {
@@ -427,37 +506,39 @@ def test_details(self):
             Token(8, "day", -0.7),
         ]
         req.set_next_token(Token(244, "He", -0.334532))
-        req.get_next_token()
+        self.assertEqual(req.get_next_token(), "")
         req.set_next_token(Token(576, "llo", -0.123123))
-        req.get_next_token()
+        self.assertEqual(req.get_next_token(), "")
         req.set_next_token(Token(4558, " world", -0.567854), True, 'length')

         final_json = json.loads(req.get_next_token())

-        assert final_json == {
-            "generated_text": "Hello world",
-            "details": {
-                'inputs':
-                'This is a wonderful day',
-                "finish_reason":
-                "length",
-                "generated_tokens":
-                3,
-                "tokens": [{
-                    "id": 244,
-                    "text": "He",
-                    "log_prob": -0.334532
-                }, {
-                    "id": 576,
-                    "text": "llo",
-                    "log_prob": -0.123123
-                }, {
-                    "id": 4558,
-                    "text": " world",
-                    "log_prob": -0.567854
-                }]
-            }
-        }
+        self.assertEqual(
+            {
+                "generated_text": "Hello world",
+                "details": {
+                    'inputs':
+                    'This is a wonderful day',
+                    "finish_reason":
+                    "length",
+                    "generated_tokens":
+                    3,
+                    "tokens": [{
+                        "id": 244,
+                        "text": "He",
+                        "log_prob": -0.334532
+                    }, {
+                        "id": 576,
+                        "text": "llo",
+                        "log_prob": -0.123123
+                    }, {
+                        "id": 4558,
+                        "text": " world",
+                        "log_prob": -0.567854
+                    }]
+                }
+            }, final_json)

+    def test_details_stream(self):
         # Jsonlines tests
         req = Request(
             TextInput(request_id=0,
@@ -475,7 +556,23 @@ def test_details(self):
             Token(8, "day", -0.7),
         ]
         req.set_next_token(Token(244, "He", -0.334532))
+        next_token = req.get_next_token()
+        self.assertEqual(
+            {"token": {
+                "id": 244,
+                "text": "He",
+                "log_prob": -0.334532,
+            }}, json.loads(next_token))
+        req.reset_next_token()
         req.set_next_token(Token(576, "llo", -0.123123))
+        next_token = req.get_next_token()
+        self.assertEqual(
+            {"token": {
+                "id": 576,
+                "text": "llo",
+                "log_prob": -0.123123,
+            }}, json.loads(next_token))
+        req.reset_next_token()
         req.set_next_token(Token(4558, " world", -0.567854), True, 'length')
         print(req.get_next_token(), end='')
         assert json.loads(req.get_next_token().splitlines()[-1]) == {
@@ -809,21 +906,22 @@ def custom_fmt(token: Token, first_token: bool, last_token: bool,
         ]
         req.set_next_token(Token(244, "He", -0.334532))
         print(req.get_next_token(), end='')
-        assert json.loads(req.get_next_token()) == {
-            'finish_reason': None,
-            'generated_tokens': 1,
-            'inputs': 'This is a wonderful day',
-            'tokens': [{
-                'id': 244,
-                'log_prob': -0.334532,
-                'text': 'He'
-            }],
-            "parameters": {
-                "max_new_tokens": 256,
-                "details": True
-            },
-            "prompt_tokens": 5
-        }
+        self.assertEqual(
+            json.loads(req.get_next_token()), {
+                'finish_reason': None,
+                'generated_tokens': 1,
+                'inputs': 'This is a wonderful day',
+                'tokens': [{
+                    'id': 244,
+                    'log_prob': -0.334532,
+                    'text': 'He'
+                }],
+                "parameters": {
+                    "max_new_tokens": 256,
+                    "details": True
+                },
+                "prompt_tokens": 5
+            })
         req.reset_next_token()
         req.set_next_token(Token(576, "llo", -0.123123))
         print(req.get_next_token(), end='')
@@ -912,7 +1010,7 @@ def custom_fmt_wait(request_output: TextGenerationOutput):
             elif best_sequence.finish_reason == "error":
                 result["finish_reason"] = best_sequence.finish_reason
             return json.dumps(result) + "\n"
-        return json.dumps("") + "\n"
+        return ""

         parameters = {"max_new_tokens": 256, "details": True, "stream": False}

@@ -922,15 +1020,15 @@ def custom_fmt_wait(request_output: TextGenerationOutput):
                       parameters=parameters,
                       output_formatter=custom_fmt_wait))
         print(req.request_input.parameters)
-        assert req.request_input.parameters == parameters
+        self.assertEqual(req.request_input.parameters, parameters)

         req.set_next_token(Token(244, "He", -0.334532))
         print(req.get_next_token(), end='')
-        assert json.loads(req.get_next_token()) == ""
+        self.assertEqual(req.get_next_token(), "")
         req.reset_next_token()
         req.set_next_token(Token(576, "llo", -0.123123))
         print(req.get_next_token(), end='')
-        assert json.loads(req.get_next_token()) == ""
+        self.assertEqual(req.get_next_token(), "")
         req.reset_next_token()
         req.set_next_token(Token(4558, " world", -0.567854), True, 'length')
         print(req.get_next_token(), end='')
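A pattern worth noting across this test file: bare assert statements and no-op req.get_next_token() calls are systematically replaced with self.assertEqual(req.get_next_token(), ""), which both pins down the contract that non-final tokens format to an empty string under the json output format and yields unittest's descriptive failure diffs instead of a bare AssertionError.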
engines/python/setup/djl_python/utils.py (9 changes: 9 additions & 0 deletions)

@@ -90,6 +90,15 @@ def is_multiple_sequences(parameters: dict) -> bool:
return "n" in parameters.keys() and parameters.get("n") > 1


def is_streaming(parameters: dict) -> bool:
"""
Returns whether token streaming is enabled for the request
:param parameters: parameters dictionary
:return: boolean
"""
return "stream" in parameters.keys() and parameters.get("stream")


def wait_till_generation_finished(parameters):
return is_best_of(parameters) or is_multiple_sequences(
parameters) or is_beam_search(parameters)
Expand Down
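One caveat on the new helper: it returns the raw value of parameters.get("stream") rather than a strict bool, so a missing key and an explicit "stream": False are both falsy, which is what the caller in Request.get_next_token() relies on. A few illustrative calls (hypothetical parameter dicts):

from djl_python.utils import is_streaming  # added by this commit

assert is_streaming({"stream": True}) is True
assert not is_streaming({"stream": False})        # explicit opt-out
assert not is_streaming({"max_new_tokens": 256})  # key absent: short-circuits to False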
