From c791f5839f8a132d5f9ac5128938a253b6d466a8 Mon Sep 17 00:00:00 2001
From: Siddharth Venkatesan
Date: Wed, 18 Sep 2024 09:55:42 -0700
Subject: [PATCH] [fix][lmi] validate token exists in streaming output
 formatters to handle chunked_prefill correctly

---
 engines/python/setup/djl_python/output_formatter.py | 12 ++++++++++++
 1 file changed, 12 insertions(+)

diff --git a/engines/python/setup/djl_python/output_formatter.py b/engines/python/setup/djl_python/output_formatter.py
index 3689a65d6..30d594ce7 100644
--- a/engines/python/setup/djl_python/output_formatter.py
+++ b/engines/python/setup/djl_python/output_formatter.py
@@ -222,6 +222,10 @@ def _jsonlines_output_formatter(request_output: TextGenerationOutput):
     best_sequence = request_output.sequences[
         request_output.best_sequence_index]
     next_token, _, last_token = best_sequence.get_next_token()
+    # with chunked prefill, we don't generate any tokens until the full prompt has been processed.
+    # that means we sometimes don't have a token to return
+    if next_token is None:
+        return ""
     token_dict = next_token.as_tgi_dict(
     ) if tgi_compat else next_token.as_dict()
     final_dict = {"token": token_dict}
@@ -239,6 +243,10 @@ def _jsonlines_3p_output_formatter(request_output: TextGenerationOutput):
     best_sequence = request_output.sequences[
         request_output.best_sequence_index]
     next_token, first_token, last_token = best_sequence.get_next_token()
+    # with chunked prefill, we don't generate any tokens until the full prompt has been processed.
+    # that means we sometimes don't have a token to return
+    if next_token is None:
+        return ""
     token_details = next_token.as_dict()
     body = {"generation": token_details["text"]}
     num_prompt_tokens = len(
@@ -336,6 +344,10 @@ def _jsonlines_chat_output_formatter(request_output: TextGenerationOutput):
     best_sequence = request_output.sequences[
         request_output.best_sequence_index]
     next_token, first_token, last_token = best_sequence.get_next_token()
+    # with chunked prefill, we don't generate any tokens until the full prompt has been processed.
+    # that means we sometimes don't have a token to return
+    if next_token is None:
+        return ""
     created = int(time.time())
     delta = {"content": next_token.text}