[python] Fix new logprobs computation in vllm_utils (#2146)
sindhuvahinis authored Jul 4, 2024
1 parent 1482ace commit 46e05cb
Showing 1 changed file with 2 additions and 2 deletions.
```diff
@@ -116,8 +116,8 @@ def update_multiple_sequences(cache, request_output, vllm_request_output):
             cur_len] if prev_len < cur_len else completion_output.logprobs
         new_logprobs = []
         for token_id, logprobs in zip(new_token_ids, new_logprobs_list):
+            new_logprobs.append(logprobs[token_id].logprob)
             for token_id_key, logprob in logprobs.items():
-                new_logprobs.append(logprobs[token_id].logprob)
                 top_tokens.append(
                     Token(id=token_id_key,
                           text=logprob.decoded_token,
@@ -137,7 +137,7 @@ def update_multiple_sequences(cache, request_output, vllm_request_output):
         for i, (token_id, token_text, logprob) in enumerate(
                 zip(new_token_ids, output_token_texts, new_logprobs)):
             token = Token(token_id, token_text, logprob)
-            is_last_token = i == (len(new_logprobs) -
+            is_last_token = i == (len(new_token_ids) -
                                   1) and finish_reason is not None
             request_output.sequences[sequence_index].set_next_token(
                 token, is_last_token)
```
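
For context, here is a minimal sketch (not the repository's code) of what the first hunk changes. In vLLM, `CompletionOutput.logprobs` holds one dict per generated token, mapping candidate token ids to logprob entries, with the sampled token's id among the keys. Before the fix, the sampled token's logprob was appended inside the loop over those candidates, so `new_logprobs` gained one entry per candidate instead of one per token. The `Logprob` stand-in, the `collect_logprobs` helper, and the sample values below are illustrative assumptions, not names taken from the diff:

```python
# Minimal reproduction of the bookkeeping bug, using a stand-in for
# vLLM's per-token logprob entry (assumed fields: logprob, decoded_token).
from dataclasses import dataclass


@dataclass
class Logprob:
    logprob: float
    decoded_token: str


# One dict per generated token: candidate token id -> Logprob entry.
# The sampled token's id is always present among the keys.
new_token_ids = [11, 42]
new_logprobs_list = [
    {11: Logprob(-0.1, "Hello"), 7: Logprob(-2.3, "Hi")},
    {42: Logprob(-0.5, " world"), 99: Logprob(-1.9, " there")},
]


def collect_logprobs(fixed: bool) -> list:
    """Hypothetical helper mirroring the loop in update_multiple_sequences."""
    new_logprobs = []
    for token_id, logprobs in zip(new_token_ids, new_logprobs_list):
        if fixed:
            # After the fix: one logprob per generated token.
            new_logprobs.append(logprobs[token_id].logprob)
        for token_id_key, logprob in logprobs.items():
            if not fixed:
                # Before the fix: the sampled token's logprob was appended
                # once per top-k candidate, inflating the list.
                new_logprobs.append(logprobs[token_id].logprob)
            # (The real loop also builds top_tokens here; omitted.)
    return new_logprobs


print(collect_logprobs(fixed=False))  # [-0.1, -0.1, -0.5, -0.5] (one per candidate)
print(collect_logprobs(fixed=True))   # [-0.1, -0.5] (parallels new_token_ids)
```

The second hunk follows from the same bug: `zip(new_token_ids, output_token_texts, new_logprobs)` stops at its shortest input, so with the inflated list the index `i` could never reach `len(new_logprobs) - 1` and `is_last_token` was never set on a finished sequence. Comparing against `len(new_token_ids)` ties the last-token check to the tokens actually being emitted.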
