diff --git a/engines/python/setup/djl_python/rolling_batch/rolling_batch.py b/engines/python/setup/djl_python/rolling_batch/rolling_batch.py index ee1e548d6..fe2110c59 100644 --- a/engines/python/setup/djl_python/rolling_batch/rolling_batch.py +++ b/engines/python/setup/djl_python/rolling_batch/rolling_batch.py @@ -12,11 +12,13 @@ # the specific language governing permissions and limitations under the License. import json import logging +import os import time from abc import ABC, abstractmethod from typing import List, Union, List, Callable, Optional FINISH_REASON_MAPPER = ["length", "eos_token", "stop_sequence"] +TGI_COMPAT = False class Token(object): @@ -64,6 +66,8 @@ def _json_output_formatter(token: Token, first_token: bool, last_token: bool, :return: formatted output """ json_encoded_str = f"{{\"generated_text\": \"{generated_text}" if first_token else "" + if first_token and TGI_COMPAT: + json_encoded_str = f"[{json_encoded_str}" json_encoded_str = f"{json_encoded_str}{json.dumps(token.text, ensure_ascii=False)[1:-1]}" if last_token: if details: @@ -81,6 +85,8 @@ def _json_output_formatter(token: Token, first_token: bool, last_token: bool, json_encoded_str = f"{json_encoded_str}\", {details_str}}}" else: json_encoded_str = f"{json_encoded_str}\"}}" + if TGI_COMPAT: + json_encoded_str = f"{json_encoded_str}]" return json_encoded_str @@ -140,20 +146,20 @@ def _json_chat_output_formatter(token: Token, first_token: bool, if parameters.get("logprobs"): logprobs = { "content": [{ - "token": - t.get("text"), - "logprob": - t.get("log_prob"), - "bytes": - (b := [ord(c) - for c in t.get("text")] if t.get("text") else None), - "top_logprobs": # Currently only support 1 top_logprobs - [{ - "token": t.get("text"), - "logprob": t.get("log_prob"), - "bytes": b - }] - } for t in details.get("tokens", []) + "token": + t.get("text"), + "logprob": + t.get("log_prob"), + "bytes": + (b := [ord(c) + for c in t.get("text")] if t.get("text") else None), + "top_logprobs": # 
Currently only support 1 top_logprobs + [{ + "token": t.get("text"), + "logprob": t.get("log_prob"), + "bytes": b + }] + } for t in details.get("tokens", []) ] } choice2 = { @@ -192,17 +198,17 @@ def _jsonlines_chat_output_formatter(token: Token, first_token: bool, if parameters.get("logprobs"): logprobs = { "content": - [{ - "token": token.text, - "logprob": token.log_prob, - "bytes": (b := [ord(c) for c in token.text] if token.text else None), - "top_logprobs": # Currently only support 1 top_logprobs [{ - "token": token.log_prob, + "token": token.text, "logprob": token.log_prob, - "bytes": b + "bytes": (b := [ord(c) for c in token.text] if token.text else None), + "top_logprobs": # Currently only support 1 top_logprobs + [{ + "token": token.text, + "logprob": token.log_prob, + "bytes": b + }] }] - }] }, choice = { "index": 0, @@ -445,6 +451,11 @@ def __init__(self, **kwargs): self.waiting_steps = kwargs.get("waiting_steps", None) self.current_step = 0 self.default_output_formatter = kwargs.get("output_formatter", None) + # TODO: remove global context through refactoring + global TGI_COMPAT + # TODO: better handling to make it part of properties + TGI_COMPAT = os.environ.get("OPTION_TGI_COMPAT", + "false").lower() == 'true' def reset(self): self.pending_requests = [] diff --git a/engines/python/setup/djl_python/tests/test_rolling_batch.py b/engines/python/setup/djl_python/tests/test_rolling_batch.py index bc218ec4f..866cc8185 100644 --- a/engines/python/setup/djl_python/tests/test_rolling_batch.py +++ b/engines/python/setup/djl_python/tests/test_rolling_batch.py @@ -1,7 +1,8 @@ import json import unittest -from djl_python.rolling_batch.rolling_batch import Request, Token, _json_output_formatter, _jsonlines_output_formatter, \ - RollingBatch + +import djl_python.rolling_batch.rolling_batch +from djl_python.rolling_batch.rolling_batch import Request, Token, _json_output_formatter, _jsonlines_output_formatter class TestRollingBatch(unittest.TestCase): @@ -29,6 
+30,32 @@ def test_json_fmt(self): print(req.get_next_token(), end='') assert req.get_next_token() == ' world"}' + def test_json_fmt_hf_compat(self): + djl_python.rolling_batch.rolling_batch.TGI_COMPAT = True + self.addCleanup(setattr, djl_python.rolling_batch.rolling_batch, "TGI_COMPAT", False) + + req = Request(0, + "This is a wonderful day", + parameters={ + "max_new_tokens": 256, + "return_full_text": True, + }, + output_formatter=_json_output_formatter) + + final_str = [] + req.set_next_token(Token(244, "He", -0.334532)) + final_str.append(req.get_next_token()) + req.set_next_token(Token(576, "llo", -0.123123)) + final_str.append(req.get_next_token()) + req.set_next_token(Token(4558, " world", -0.567854), True, 'length') + final_str.append(req.get_next_token()) + final_json = json.loads(''.join(final_str)) + print(final_json, end='') + assert final_json == [{ + "generated_text": + "This is a wonderful dayHello world", + }] + def test_jsonlines_fmt(self): req1 = Request(0, "This is a wonderful day", @@ -277,22 +304,6 @@ def custom_fmt(token: Token, first_token: bool, last_token: bool, result["finish_reason"] = details["finish_reason"] return json.dumps(result) + "\n" - class CustomRB(RollingBatch): - - def preprocess_requests(self, requests): - pass - - def postprocess_results(self): - pass - - def inference(self, input_data, parameters): - pass - - def get_tokenizer(self): - pass - - rb = CustomRB() - req = Request(132, "This is a wonderful day", parameters={ @@ -331,22 +342,6 @@ def custom_fmt(token: Token, first_token: bool, last_token: bool, result = details return json.dumps(result) + "\n" - class CustomRB(RollingBatch): - - def preprocess_requests(self, requests): - pass - - def postprocess_results(self): - pass - - def inference(self, input_data, parameters): - pass - - def get_tokenizer(self): - pass - - rb = CustomRB() - req = Request(132, "This is a wonderful day", parameters={ diff --git a/engines/python/setup/djl_python/tests/test_test_model.py 
b/engines/python/setup/djl_python/tests/test_test_model.py index 660a8fb6a..6dca392fa 100644 --- a/engines/python/setup/djl_python/tests/test_test_model.py +++ b/engines/python/setup/djl_python/tests/test_test_model.py @@ -10,6 +10,7 @@ # or in the "LICENSE.txt" file accompanying this file. This file is distributed on an "AS IS" # BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, express or implied. See the License for # the specific language governing permissions and limitations under the License. +import json import os import unittest from djl_python.test_model import TestHandler @@ -49,12 +50,15 @@ def test_all_code(self): result = handler.inference_rolling_batch( inputs, serving_properties=serving_properties) self.assertEqual(len(result), len(inputs)) + self.assertIsInstance(json.loads(result[0]), dict) + self.assertIsInstance(json.loads(result[1]), dict) def test_with_env(self): envs = { "OPTION_MODEL_ID": "NousResearch/Nous-Hermes-Llama2-13b", "SERVING_LOAD_MODELS": "test::MPI=/opt/ml/model", - "OPTION_ROLLING_BATCH": "auto" + "OPTION_ROLLING_BATCH": "auto", + "OPTION_TGI_COMPAT": "true" } for key, value in envs.items(): os.environ[key] = value @@ -78,6 +82,9 @@ def test_with_env(self): result = handler.inference_rolling_batch(inputs) self.assertEqual(len(result), len(inputs)) self.assertTrue(len(result[1]) > len(result[0])) + # TGI compat tests + self.assertIsInstance(json.loads(result[0]), list) + self.assertIsInstance(json.loads(result[1]), list) for key in envs.keys(): os.environ[key] = ""