Skip to content

Commit

Permalink
add TGI compat feature for rollingbatch (#1866)
Browse files Browse the repository at this point in the history
  • Loading branch information
Qing Lan authored May 6, 2024
1 parent 874cf9d commit 9a8a601
Show file tree
Hide file tree
Showing 3 changed files with 70 additions and 57 deletions.
55 changes: 33 additions & 22 deletions engines/python/setup/djl_python/rolling_batch/rolling_batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,13 @@
# the specific language governing permissions and limitations under the License.
import json
import logging
import os
import time
from abc import ABC, abstractmethod
from typing import List, Union, List, Callable, Optional

FINISH_REASON_MAPPER = ["length", "eos_token", "stop_sequence"]
TGI_COMPAT = False


class Token(object):
Expand Down Expand Up @@ -64,6 +66,8 @@ def _json_output_formatter(token: Token, first_token: bool, last_token: bool,
:return: formatted output
"""
json_encoded_str = f"{{\"generated_text\": \"{generated_text}" if first_token else ""
if first_token and TGI_COMPAT:
json_encoded_str = f"[{json_encoded_str}"
json_encoded_str = f"{json_encoded_str}{json.dumps(token.text, ensure_ascii=False)[1:-1]}"
if last_token:
if details:
Expand All @@ -81,6 +85,8 @@ def _json_output_formatter(token: Token, first_token: bool, last_token: bool,
json_encoded_str = f"{json_encoded_str}\", {details_str}}}"
else:
json_encoded_str = f"{json_encoded_str}\"}}"
if TGI_COMPAT:
json_encoded_str = f"{json_encoded_str}]"

return json_encoded_str

Expand Down Expand Up @@ -140,20 +146,20 @@ def _json_chat_output_formatter(token: Token, first_token: bool,
if parameters.get("logprobs"):
logprobs = {
"content": [{
"token":
t.get("text"),
"logprob":
t.get("log_prob"),
"bytes":
(b := [ord(c)
for c in t.get("text")] if t.get("text") else None),
"top_logprobs": # Currently only support 1 top_logprobs
[{
"token": t.get("text"),
"logprob": t.get("log_prob"),
"bytes": b
}]
} for t in details.get("tokens", [])
"token":
t.get("text"),
"logprob":
t.get("log_prob"),
"bytes":
(b := [ord(c)
for c in t.get("text")] if t.get("text") else None),
"top_logprobs": # Currently only support 1 top_logprobs
[{
"token": t.get("text"),
"logprob": t.get("log_prob"),
"bytes": b
}]
} for t in details.get("tokens", [])
]
}
choice2 = {
Expand Down Expand Up @@ -192,17 +198,17 @@ def _jsonlines_chat_output_formatter(token: Token, first_token: bool,
if parameters.get("logprobs"):
logprobs = {
"content":
[{
"token": token.text,
"logprob": token.log_prob,
"bytes": (b := [ord(c) for c in token.text] if token.text else None),
"top_logprobs": # Currently only support 1 top_logprobs
[{
"token": token.log_prob,
"token": token.text,
"logprob": token.log_prob,
"bytes": b
"bytes": (b := [ord(c) for c in token.text] if token.text else None),
"top_logprobs": # Currently only support 1 top_logprobs
[{
"token": token.log_prob,
"logprob": token.log_prob,
"bytes": b
}]
}]
}]
},
choice = {
"index": 0,
Expand Down Expand Up @@ -445,6 +451,11 @@ def __init__(self, **kwargs):
self.waiting_steps = kwargs.get("waiting_steps", None)
self.current_step = 0
self.default_output_formatter = kwargs.get("output_formatter", None)
# TODO: remove global context through refactoring
global TGI_COMPAT
# TODO: better handling to make it part of properties
TGI_COMPAT = os.environ.get("OPTION_TGI_COMPAT",
"false").lower() == 'true'

def reset(self):
self.pending_requests = []
Expand Down
63 changes: 29 additions & 34 deletions engines/python/setup/djl_python/tests/test_rolling_batch.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import json
import unittest
from djl_python.rolling_batch.rolling_batch import Request, Token, _json_output_formatter, _jsonlines_output_formatter, \
RollingBatch

import djl_python.rolling_batch.rolling_batch
from djl_python.rolling_batch.rolling_batch import Request, Token, _json_output_formatter, _jsonlines_output_formatter


class TestRollingBatch(unittest.TestCase):
Expand Down Expand Up @@ -29,6 +30,32 @@ def test_json_fmt(self):
print(req.get_next_token(), end='')
assert req.get_next_token() == ' world"}'

def test_json_fmt_hf_compat(self):
    """Verify TGI-compat mode wraps the JSON result in a top-level list.

    Flips the module-level ``TGI_COMPAT`` flag on, runs a three-token
    generation through ``_json_output_formatter``, and checks the output
    parses as ``[{"generated_text": ...}]`` (HF/TGI response shape).

    The flag reset is done in a ``finally`` block so a failing assertion
    cannot leak ``TGI_COMPAT = True`` into other tests in the same run.
    """
    djl_python.rolling_batch.rolling_batch.TGI_COMPAT = True
    try:
        req = Request(0,
                      "This is a wonderful day",
                      parameters={
                          "max_new_tokens": 256,
                          "return_full_text": True,
                      },
                      output_formatter=_json_output_formatter)

        final_str = []
        req.set_next_token(Token(244, "He", -0.334532))
        final_str.append(req.get_next_token())
        req.set_next_token(Token(576, "llo", -0.123123))
        final_str.append(req.get_next_token())
        req.set_next_token(Token(4558, " world", -0.567854), True, 'length')
        final_str.append(req.get_next_token())
        final_json = json.loads(''.join(final_str))
        print(final_json, end='')
        # TGI compat: the single result object is wrapped in a list, and
        # return_full_text prepends the prompt to the generated text.
        assert final_json == [{
            "generated_text":
            "This is a wonderful dayHello world",
        }]
    finally:
        # Always restore the global default, even if an assertion failed.
        djl_python.rolling_batch.rolling_batch.TGI_COMPAT = False

def test_jsonlines_fmt(self):
req1 = Request(0,
"This is a wonderful day",
Expand Down Expand Up @@ -277,22 +304,6 @@ def custom_fmt(token: Token, first_token: bool, last_token: bool,
result["finish_reason"] = details["finish_reason"]
return json.dumps(result) + "\n"

class CustomRB(RollingBatch):

def preprocess_requests(self, requests):
pass

def postprocess_results(self):
pass

def inference(self, input_data, parameters):
pass

def get_tokenizer(self):
pass

rb = CustomRB()

req = Request(132,
"This is a wonderful day",
parameters={
Expand Down Expand Up @@ -331,22 +342,6 @@ def custom_fmt(token: Token, first_token: bool, last_token: bool,
result = details
return json.dumps(result) + "\n"

class CustomRB(RollingBatch):

def preprocess_requests(self, requests):
pass

def postprocess_results(self):
pass

def inference(self, input_data, parameters):
pass

def get_tokenizer(self):
pass

rb = CustomRB()

req = Request(132,
"This is a wonderful day",
parameters={
Expand Down
9 changes: 8 additions & 1 deletion engines/python/setup/djl_python/tests/test_test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
# or in the "LICENSE.txt" file accompanying this file. This file is distributed on an "AS IS"
# BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, express or implied. See the License for
# the specific language governing permissions and limitations under the License.
import json
import os
import unittest
from djl_python.test_model import TestHandler
Expand Down Expand Up @@ -49,12 +50,15 @@ def test_all_code(self):
result = handler.inference_rolling_batch(
inputs, serving_properties=serving_properties)
self.assertEqual(len(result), len(inputs))
self.assertTrue(json.loads(result[0]), dict)
self.assertTrue(json.loads(result[1]), dict)

def test_with_env(self):
envs = {
"OPTION_MODEL_ID": "NousResearch/Nous-Hermes-Llama2-13b",
"SERVING_LOAD_MODELS": "test::MPI=/opt/ml/model",
"OPTION_ROLLING_BATCH": "auto"
"OPTION_ROLLING_BATCH": "auto",
"OPTION_TGI_COMPAT": "true"
}
for key, value in envs.items():
os.environ[key] = value
Expand All @@ -78,6 +82,9 @@ def test_with_env(self):
result = handler.inference_rolling_batch(inputs)
self.assertEqual(len(result), len(inputs))
self.assertTrue(len(result[1]) > len(result[0]))
# TGI compat tests
self.assertTrue(json.loads(result[0]), list)
self.assertTrue(json.loads(result[1]), list)

for key in envs.keys():
os.environ[key] = ""
Expand Down

0 comments on commit 9a8a601

Please sign in to comment.