diff --git a/engines/python/setup/djl_python/rolling_batch/rolling_batch_vllm_utils.py b/engines/python/setup/djl_python/rolling_batch/rolling_batch_vllm_utils.py index 3d3860da3..586b941d2 100644 --- a/engines/python/setup/djl_python/rolling_batch/rolling_batch_vllm_utils.py +++ b/engines/python/setup/djl_python/rolling_batch/rolling_batch_vllm_utils.py @@ -127,7 +127,9 @@ def update_multiple_sequences(cache, request_output, vllm_request_output): new_logprobs = [] for token_id, logprobs in zip(new_token_ids, new_logprobs_list): new_logprobs.append(logprobs[token_id].logprob) - token_texts.append(logprobs[token_id].decoded_token) + decoded_token = logprobs[token_id].decoded_token if logprobs[ + token_id].decoded_token else "" + token_texts.append(decoded_token) for token_id_key, logprob in logprobs.items(): top_tokens.append( Token(id=token_id_key, @@ -157,10 +159,14 @@ def update_multiple_sequences(cache, request_output, vllm_request_output): sequence_index].cumulative_log_prob = completion_output.cumulative_logprob if new_token_ids: + # During last generation, length of token_texts could be lesser than new_token_ids, since the + # last token could be a special end_token_id, for which token_text would not be returned for SD. + new_tokens_len = min(len(new_token_ids), len(output_token_texts), + len(new_logprobs)) for i, (token_id, token_text, logprob) in enumerate( zip(new_token_ids, output_token_texts, new_logprobs)): token = Token(token_id, token_text, logprob) - is_last_token = i == (len(new_token_ids) - + is_last_token = i == (new_tokens_len - 1) and finish_reason is not None request_output.sequences[sequence_index].set_next_token( token, is_last_token) diff --git a/engines/python/setup/setup.py b/engines/python/setup/setup.py index 9d1b30504..dbe2a8e66 100644 --- a/engines/python/setup/setup.py +++ b/engines/python/setup/setup.py @@ -56,8 +56,8 @@ def run(self): requirements = ['psutil', 'packaging', 'wheel'] test_requirements = [ - 'numpy<2', 'requests', 'Pillow', 'transformers', 'torch', 'einops', - 'accelerate', 'sentencepiece', 'protobuf', "peft", 'yapf', + 'numpy<2', 'requests', 'Pillow', 'transformers==4.43.4', 'torch', + 'einops', 'accelerate', 'sentencepiece', 'protobuf', "peft", 'yapf', 'pydantic>=2.0', "objgraph" ]