Commit

[python] Fix chat completions logprobs output (#1712)
xyang16 authored Apr 1, 2024
1 parent bbd9042 commit 449c75f
Showing 2 changed files with 55 additions and 23 deletions.
engines/python/setup/djl_python/rolling_batch/rolling_batch.py (50 additions & 21 deletions)
@@ -12,6 +12,7 @@
 # the specific language governing permissions and limitations under the License.
 import json
 import logging
+import time
 from abc import ABC, abstractmethod
 from typing import List, Union, List, Callable, Optional

@@ -110,6 +111,7 @@ def _json_chat_output_formatter(token: Token, first_token: bool,
     :return: formatted output
     """
+    created = int(time.time())
     choice1 = {
         "index": 0,
         "message": {
@@ -120,19 +122,37 @@
     response1 = {
         "id": f"chatcmpl-{id}",
         "object": "chat.completion",
+        "created": created,
         "choices": [choice1]  # Currently only support 1 choice
     }
-    json_encoded_str = f"{json.dumps(response1, ensure_ascii=False)[:-6]}" if first_token else ""
+    json_encoded_str = f"{json.dumps(response1, ensure_ascii=False)[:-5]}" if first_token else ""
     json_encoded_str = f"{json_encoded_str}{json.dumps(token.text, ensure_ascii=False)[1:-1]}"
     if last_token:
+        logprobs = None
+        parameters = details.get("parameters", {})
+        if parameters.get("logprobs"):
+            logprobs = [
+                {
+                    "token":
+                    t.get("text"),
+                    "logprob":
+                    t.get("log_prob"),
+                    "bytes":
+                    (b := [ord(c)
+                           for c in t.get("text")] if t.get("text") else None),
+                    "top_logprobs":  # Currently only support 1 top_logprobs
+                    [{
+                        "token": t.get("text"),
+                        "logprob": t.get("log_prob"),
+                        "bytes": b
+                    }]
+                } for t in details.get("tokens", [])
+            ]
         choice2 = {
             "logprobs": {
-                "content": [{
-                    "logprob": token.log_prob,
-                    "token": token.text
-                }]
+                "content": logprobs
             },
-            "finish_reason": details.get("finish_reason", None)
+            "finish_reason": details.get("finish_reason")
         }
         prompt_tokens = int(details.get("prompt_tokens", 0))
         completion_tokens = int(details.get("generated_tokens", 0))
@@ -141,7 +161,7 @@ def _json_chat_output_formatter(token: Token, first_token: bool,
         "usage": {
             "prompt_tokens": prompt_tokens,
             "completion_tokens": completion_tokens,
-            "total_tokens": prompt_tokens + completion_tokens
+            "total_tokens": (prompt_tokens + completion_tokens)
         }
     }
     json_encoded_str = f"{json_encoded_str}\"}}, {json.dumps(response2, ensure_ascii=False)[14:]}"
@@ -156,23 +176,36 @@ def _jsonlines_chat_output_formatter(token: Token, first_token: bool,
     :return: formatted output
     """
+    created = int(time.time())
+    logprobs = None
+    parameters = details.get("parameters", {})
+    if parameters.get("logprobs"):
+        logprobs = {
+            "content":
+            [{
+                "token": token.text,
+                "logprob": token.log_prob,
+                "bytes": (b := [ord(c) for c in token.text] if token.text else None),
+                "top_logprobs":  # Currently only support 1 top_logprobs
+                [{
+                    "token": token.log_prob,
+                    "logprob": token.log_prob,
+                    "bytes": b
+                }]
+            }]
+        },
     choice = {
         "index": 0,
         "delta": {
             "role": "assistant",
-            "content": generated_tokens
-        },
-        "logprobs": {
-            "content": [{
-                "logprob": token.log_prob,
-                "token": token.text
-            }]
+            "content": token.text
         },
-        "finish_reason": details.get("finish_reason", None)
+        "logprobs": logprobs,
+        "finish_reason": details.get("finish_reason")
     }
     response = {
         "id": f"chatcmpl-{id}",
         "object": "chat.completion.chunk",
+        "created": created,
         "choices": [choice]  # Currently only support 1 choice
     }
     json_encoded_str = json.dumps(response, ensure_ascii=False) + "\n"
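
For the streaming path, each line written to the client is one `chat.completion.chunk` object. A minimal sketch of what a chunk might look like once logprobs are requested; the id, token, and numbers are invented, and this mirrors rather than reproduces the formatter above:

```python
import json
import time

# Hypothetical single streamed line; all values are illustrative only.
token_text, token_logprob = " world", -0.5678
chunk = {
    "id": "chatcmpl-12345",  # made-up request id
    "object": "chat.completion.chunk",
    "created": int(time.time()),
    "choices": [{
        "index": 0,
        "delta": {"role": "assistant", "content": token_text},
        "logprobs": {
            "content": [{
                "token": token_text,
                "logprob": token_logprob,
                "bytes": [ord(c) for c in token_text],
                "top_logprobs": [{"token": token_text,
                                  "logprob": token_logprob,
                                  "bytes": [ord(c) for c in token_text]}],
            }]
        },
        "finish_reason": None,
    }],
}
print(json.dumps(chunk, ensure_ascii=False))
```
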
@@ -314,14 +347,10 @@ def set_next_token(self,
             details_dict["generated_tokens"] = len(self.token_cache)
             details_dict["input_text"] = self.input_text
             details_dict["parameters"] = self.parameters
+            details_dict["prompt_tokens"] = len(self.input_ids)
         generated_text = self.full_text_prefix
         if last_token:
             generated_text = generated_text + ''.join(self.generated_tokens)
-            if self.details:
-                details_dict["finish_reason"] = finish_reason
-                details_dict["tokens"] = self.token_cache
-                details_dict["generated_tokens"] = len(self.token_cache)
-                details_dict["prompt_tokens"] = len(self.input_ids)
         if self.output_formatter is None:
             self.next_token_str = next_token.text
         else:  # output only supports size one now
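
Reading this last hunk, `prompt_tokens` now appears to be recorded alongside the other detail fields on every token rather than only when the final token arrives, which is what lets the formatters above read `details.get("prompt_tokens", 0)`. A sketch of the resulting details dict, with invented values:

```python
# Illustrative details dict when details are requested; values are made up.
details_dict = {
    "tokens": [],                   # token cache, assumed filled elsewhere
    "generated_tokens": 2,
    "input_text": "Hello",
    "parameters": {"max_new_tokens": 256, "details": True},
    "prompt_tokens": 7,             # now recorded on every call
}
```
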
engines/python/setup/djl_python/tests/test_rolling_batch.py (5 additions & 2 deletions)
@@ -331,7 +331,8 @@ def get_tokenizer(self):
             "parameters": {
                 "max_new_tokens": 256,
                 "details": True
-            }
+            },
+            "prompt_tokens": 7
         }
         req.set_next_token(Token(576, "llo", -0.123123))
         print(req.get_next_token(), end='')
@@ -357,7 +358,9 @@ def get_tokenizer(self):
             "parameters": {
                 "max_new_tokens": 256,
                 "details": True
-            }
+            },
+            "prompt_tokens":
+            7
         }
         req.set_next_token(Token(4558, " world", -0.567854), True, 'length')
         print(req.get_next_token(), end='')
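
The updated test expectations assume seven prompt tokens, which lines up with the `prompt_tokens` bookkeeping change in rolling_batch.py. For context, a request that would exercise the new logprobs path might look like the following; the field names follow the OpenAI-style schema these formatters emulate and are assumptions, not taken from this diff:

```python
# Hypothetical chat completions request body enabling logprobs.
request_body = {
    "messages": [{"role": "user", "content": "Say hello"}],
    "max_tokens": 256,
    # assumed to flow into details["parameters"]["logprobs"], the flag the
    # formatters check before building the logprobs block
    "logprobs": True,
}
```
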
