Add better token handling under hazardous condition (#1875)
Qing Lan authored May 6, 2024
1 parent 9a8a601 commit ae3b455
Showing 5 changed files with 125 additions and 23 deletions.
22 changes: 10 additions & 12 deletions engines/python/setup/djl_python/rolling_batch/rolling_batch.py
@@ -364,6 +364,9 @@ def set_next_token(self,
details_dict["inputs"] = self.input_text
details_dict["parameters"] = self.original_params
details_dict["prompt_tokens"] = len(self.input_ids)
# Special handling for error case
elif finish_reason == "error":
details_dict["finish_reason"] = finish_reason
generated_text = self.full_text_prefix
if last_token:
generated_text = generated_text + ''.join(self.generated_tokens)
@@ -413,19 +416,14 @@ def try_catch_handling(self, *args, **kwargs):
return func(self, *args, **kwargs)
except Exception:
logging.exception("Rolling batch inference error")
err = {
"data": "",
"last": True,
"step_token_num": 0,
"code": 424,
"error": ""
}
results = []
for i in range(
len(self.active_requests) + len(self.pending_requests)):
results.append(err)
for request in self.active_requests:
token = Token([-1], "", -1, None)
request.set_next_token(token,
last_token=True,
finish_reason="error")
response = self.postprocess_results()
self.reset()
return results
return response

return try_catch_handling
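
For readers skimming the hunk above, here is a self-contained sketch of the revised flow: on any inference failure, each active request is closed out with a sentinel token and finish_reason "error", and the normal postprocessing path still builds a per-request response. The attribute and method names are taken from the hunk; the use of functools.wraps and the empty-string sentinel (standing in for the repository's Token([-1], "", -1, None)) are assumptions for illustration.

import functools
import logging


def stop_on_any_exception(func):
    """Sketch only: wrap a rolling-batch inference call so an exception
    fails each request gracefully instead of dropping the whole batch."""

    @functools.wraps(func)
    def try_catch_handling(self, *args, **kwargs):
        try:
            return func(self, *args, **kwargs)
        except Exception:
            logging.exception("Rolling batch inference error")
            # Mark every active request as finished with an error; the real
            # code uses Token([-1], "", -1, None) as the sentinel token.
            for request in self.active_requests:
                request.set_next_token("", last_token=True,
                                       finish_reason="error")
            response = self.postprocess_results()
            self.reset()
            return response

    return try_catch_handling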

2 changes: 2 additions & 0 deletions engines/python/setup/djl_python/test_model.py
@@ -251,6 +251,8 @@ def __init__(self,
self.service = load_model_service(model_dir, entry_point, "-1")
else:
self.service = ModelService(entry_point, model_dir)
# Reset the module-level service object so unit tests do not reuse a pre-initialized instance
getattr(self.service.module, "_service").initialized = False

def inference(self, inputs: Input) -> Output:
function_name = inputs.get_function_name() if inputs.get_function_name(
25 changes: 25 additions & 0 deletions engines/python/setup/djl_python/tests/rolling_batch/fake_rolling_batch.py
@@ -48,6 +48,9 @@ def __init__(self, model_id_or_path, properties, **kwargs):
self.tokens = self.tokens[:self.total_length]
self.cache = OrderedDict()

def get_tokenizer(self):
return self.tokenizer

def reset(self):
self.cache = OrderedDict()
super().reset()
@@ -92,3 +95,25 @@ def inference(self, input_data, parameters, adapters=None):

def preprocess_requests(self, requests):
raise NotImplementedError("Not implemented for vLLM rolling batcher")


class FakeRollingBatchWithException(FakeRollingBatch):

def __init__(self, model_id_or_path, properties, **kwargs):
super().__init__(model_id_or_path, properties, **kwargs)
self.dead_counter = 0
self.dead_trigger = random.randint(1, 50)

def reset(self):
super().reset()
self.dead_counter = 0
self.dead_trigger = random.randint(1, 50)

@stop_on_any_exception
def inference(self, input_data, parameters, adapters=None):

if self.dead_counter < self.dead_trigger:
self.dead_counter += 1
return super().inference(input_data, parameters, adapters)
else:
raise RuntimeError("Death trigger triggered...")
59 changes: 57 additions & 2 deletions engines/python/setup/djl_python/tests/test_test_model.py
@@ -15,21 +15,24 @@
import unittest
from djl_python.test_model import TestHandler
from djl_python import huggingface
from .rolling_batch.fake_rolling_batch import FakeRollingBatch
from .rolling_batch.fake_rolling_batch import FakeRollingBatch, FakeRollingBatchWithException


def override_rolling_batch(rolling_batch_type: str, is_mpi: bool,
model_config):
return FakeRollingBatch


huggingface.get_rolling_batch_class_from_str = override_rolling_batch
def override_rolling_batch_with_exception(rolling_batch_type: str,
is_mpi: bool, model_config):
return FakeRollingBatchWithException


class TestTestModel(unittest.TestCase):

def test_all_code(self):
model_id = "NousResearch/Nous-Hermes-Llama2-13b"
huggingface.get_rolling_batch_class_from_str = override_rolling_batch
handler = TestHandler(huggingface)
inputs = [{
"inputs": "The winner of oscar this year is",
@@ -62,6 +65,7 @@ def test_with_env(self):
}
for key, value in envs.items():
os.environ[key] = value
huggingface.get_rolling_batch_class_from_str = override_rolling_batch
handler = TestHandler(huggingface)
self.assertEqual(handler.serving_properties["model_id"],
envs["OPTION_MODEL_ID"])
@@ -91,6 +95,7 @@ def test_with_env(self):

def test_all_code_chat(self):
model_id = "TheBloke/Llama-2-7B-Chat-fp16"
huggingface.get_rolling_batch_class_from_str = override_rolling_batch
handler = TestHandler(huggingface)
inputs = [{
"inputs":
@@ -122,6 +127,7 @@ def test_with_env_chat(self):
}
for key, value in envs.items():
os.environ[key] = value
huggingface.get_rolling_batch_class_from_str = override_rolling_batch
handler = TestHandler(huggingface)
self.assertEqual(handler.serving_properties["model_id"],
envs["OPTION_MODEL_ID"])
@@ -147,3 +153,52 @@ def test_with_env_chat(self):

for key in envs.keys():
os.environ[key] = ""

def test_exception_handling(self):
huggingface.get_rolling_batch_class_from_str = override_rolling_batch_with_exception
model_id = "NousResearch/Nous-Hermes-Llama2-13b"
handler = TestHandler(huggingface)
inputs = [{
"inputs": "The winner of oscar this year is",
"parameters": {
"min_new_tokens": 100,
"max_new_tokens": 256,
}
}, {
"inputs": "Hello world",
"parameters": {
"min_new_tokens": 100,
"max_new_tokens": 512,
}
}]
serving_properties = {
"engine": "Python",
"rolling_batch": "auto",
"model_id": model_id
}
result = handler.inference_rolling_batch(
inputs, serving_properties=serving_properties)
for key, value in result.items():
final_dict = json.loads(value)
self.assertEqual(final_dict["details"]["finish_reason"], 'error')
# test streaming
inputs = [{
"inputs": "The winner of oscar this year is",
"parameters": {
"min_new_tokens": 100,
"max_new_tokens": 256,
},
"stream": True,
}, {
"inputs": "Hello world",
"parameters": {
"min_new_tokens": 100,
"max_new_tokens": 512,
},
"stream": True,
}]
result = handler.inference_rolling_batch(
inputs, serving_properties=serving_properties)
for _, value in result.items():
final_dict = json.loads(value.splitlines()[-1])
self.assertEqual(final_dict["details"]["finish_reason"], 'error')
40 changes: 31 additions & 9 deletions tests/integration/llm/client.py
@@ -842,6 +842,28 @@ def log_metrics(response_times):
f.close()


def response_checker(res, message):
if 'content-type' in res.headers.keys():
if 'application/json' == res.headers['content-type']:
output_json = json.loads(message)
if isinstance(output_json,
dict) and "details" in output_json.keys():
if "error" == output_json["details"]["finish_reason"]:
raise RuntimeError(f"Inference failed!")
elif 'application/jsonlines' == res.headers['content-type']:
json_lines = []
for item in message.splitlines():
json_lines.append(json.loads(item))
output_json = json_lines[-1]
if "details" in output_json.keys():
if "error" == output_json["details"]["finish_reason"]:
raise RuntimeError(f"Inference failed!")
else:
logging.info(
f"Skipping content check given non-supported content type {res.headers['content-type']}"
)
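
For context, a small sketch of the two payload shapes response_checker expects; the JSON bodies below are illustrative assumptions, mirroring what the checks above and the unit tests assert rather than a captured response.

import json

# application/json -> a single JSON object per request.
non_streaming = '{"details": {"finish_reason": "error"}}'
assert json.loads(non_streaming)["details"]["finish_reason"] == "error"

# application/jsonlines -> one JSON object per line; the final line carries
# the terminal details, which is why the checker parses only the last line.
streaming = '{"token": {"text": "a"}}\n{"details": {"finish_reason": "error"}}'
assert json.loads(streaming.splitlines()[-1])["details"]["finish_reason"] == "error"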


def test_handler_rolling_batch(model, model_spec):
if model not in model_spec:
raise ValueError(
@@ -858,10 +880,11 @@ def test_handler_rolling_batch(model, model_spec):
if "adapters" in spec:
req["adapters"] = spec.get("adapters")[0]
logging.info(f"req {req}")
res = send_json(req).content.decode("utf-8")
logging.info(f"res: {res}")
if "error" in res and "code" in res and "424" in res:
raise RuntimeError(f"Inference failed!")
res = send_json(req)
message = res.content.decode("utf-8")
logging.info(f"res: {message}")
response_checker(res, message)

# awscurl little benchmark phase
for i, batch_size in enumerate(spec["batch_size"]):
for seq_length in spec["seq_length"]:
@@ -896,10 +919,10 @@ def test_handler_adapters(model, model_spec):
reqs.append(req)
logging.info(f"reqs {reqs}")
for req in reqs:
res = send_json(req).content.decode("utf-8")
logging.info(f"res: {res}")
if "error" in res and "code" in res and "424" in res:
raise RuntimeError(f"Inference failed!")
res = send_json(req)
message = res.content.decode("utf-8")
logging.info(f"res: {message}")
response_checker(res, message)
# awscurl little benchmark phase
for i, batch_size in enumerate(spec["batch_size"]):
for seq_length in spec["seq_length"]:
@@ -1082,7 +1105,6 @@ def test_transformers_neuronx_handler(model, model_spec):


def run(raw_args):

parser = argparse.ArgumentParser(description="Build the LLM configs")
parser.add_argument("handler", help="the handler used in the model")
parser.add_argument("model", help="The name of model")
