Skip to content

Commit

Permalink
[fix][lmi][specdec] fix issue with json output formatter not returnin… (
Browse files Browse the repository at this point in the history
  • Loading branch information
siddvenk committed Sep 25, 2024
1 parent 2d303d3 commit ca751a0
Show file tree
Hide file tree
Showing 2 changed files with 2 additions and 3 deletions.
2 changes: 1 addition & 1 deletion engines/python/setup/djl_python/output_formatter.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ def _json_output_formatter(request_output: TextGenerationOutput):
# TODO: Fix this so it is not required. Right now, this call is needed to
# advance the token iterator, which is needed for rolling batch to work properly
next_token, _, is_last_token = best_sequence.get_next_token()
if not is_last_token:
if not request_output.finished:
return ""
details = get_details_dict(request_output, include_tokens=True)
if details.get("finish_reason") == "error":
Expand Down
3 changes: 1 addition & 2 deletions engines/python/setup/djl_python/tests/test_rolling_batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,6 @@ def test_json_speculative_decoding(self):
req = Request(req_input)
req.request_output = TextGenerationOutput(request_id=0,
input=req_input)
req.request_output.finished = True
req.request_output.set_next_token(Token(244, "He", -0.334532))
req.request_output.set_next_token(Token(576, "llo", -0.123123))
req.request_output.set_next_token(Token(4558, " world", -0.567854,
Expand All @@ -56,7 +55,7 @@ def test_json_speculative_decoding(self):
finish_reason='length')

self.assertEqual(req.get_next_token(), "")
self.assertEqual(req.get_next_token(), "")
req.request_output.finished = True
self.assertEqual(req.get_next_token(),
json.dumps({"generated_text": "Hello world"}))

Expand Down

0 comments on commit ca751a0

Please sign in to comment.