From 2c82f33051982e90f0a480d9b89593ee63e0df76 Mon Sep 17 00:00:00 2001 From: rohithkrn Date: Thu, 29 Jun 2023 23:00:31 +0000 Subject: [PATCH 1/2] raise inference failure exceptions in default handlers --- engines/python/setup/djl_python/deepspeed.py | 2 +- engines/python/setup/djl_python/fastertransformer.py | 2 +- engines/python/setup/djl_python/huggingface.py | 3 +-- engines/python/setup/djl_python/transformers-neuronx.py | 3 +-- 4 files changed, 4 insertions(+), 6 deletions(-) diff --git a/engines/python/setup/djl_python/deepspeed.py b/engines/python/setup/djl_python/deepspeed.py index 9dad7e7e3..72965c739 100644 --- a/engines/python/setup/djl_python/deepspeed.py +++ b/engines/python/setup/djl_python/deepspeed.py @@ -368,7 +368,7 @@ def inference(self, inputs: Input): outputs.add_property("content-type", "application/json") except Exception as e: logging.exception("DeepSpeed inference failed") - outputs = Output().error((str(e))) + raise e return outputs diff --git a/engines/python/setup/djl_python/fastertransformer.py b/engines/python/setup/djl_python/fastertransformer.py index 16b8edc15..e77e2b2d5 100644 --- a/engines/python/setup/djl_python/fastertransformer.py +++ b/engines/python/setup/djl_python/fastertransformer.py @@ -172,7 +172,7 @@ def inference(self, inputs: Input): outputs.add(generated_text, key=inputs.get_content().key_at(i)) except Exception as e: logging.exception("FasterTransformer inference failed") - outputs = Output().error((str(e))) + raise e return outputs diff --git a/engines/python/setup/djl_python/huggingface.py b/engines/python/setup/djl_python/huggingface.py index 19a6eb860..05cd560de 100644 --- a/engines/python/setup/djl_python/huggingface.py +++ b/engines/python/setup/djl_python/huggingface.py @@ -215,8 +215,7 @@ def inference(self, inputs): offset += input_size[i] except Exception as e: logging.exception("Huggingface inference failed") - # error handling - outputs = Output().error(str(e)) + raise e return outputs diff --git a/engines/python/setup/djl_python/transformers-neuronx.py b/engines/python/setup/djl_python/transformers-neuronx.py index 9d03cfb46..864dc5e17 100644 --- a/engines/python/setup/djl_python/transformers-neuronx.py +++ b/engines/python/setup/djl_python/transformers-neuronx.py @@ -216,8 +216,7 @@ def infer(self, inputs): except Exception as e: logging.exception("TransformerNeuronX inference failed") - outputs = Output().error((str(e))) - return outputs + raise e _service = TransformersNeuronXService() From ff1574a5312e1f65db7b2b077dfadda2b9bf1b7a Mon Sep 17 00:00:00 2001 From: rohithkrn Date: Thu, 29 Jun 2023 23:33:46 +0000 Subject: [PATCH 2/2] do not catch exception in default handlers --- engines/python/setup/djl_python/deepspeed.py | 165 +++++++++--------- .../setup/djl_python/fastertransformer.py | 122 +++++++------ .../python/setup/djl_python/huggingface.py | 132 +++++++------- .../setup/djl_python/transformers-neuronx.py | 85 +++++---- 4 files changed, 244 insertions(+), 260 deletions(-) diff --git a/engines/python/setup/djl_python/deepspeed.py b/engines/python/setup/djl_python/deepspeed.py index 72965c739..224dadfa8 100644 --- a/engines/python/setup/djl_python/deepspeed.py +++ b/engines/python/setup/djl_python/deepspeed.py @@ -279,96 +279,93 @@ def format_input_for_task(self, input_values): return batch_inputs def inference(self, inputs: Input): - try: - content_type = inputs.get_property("Content-Type") - input_data = [] - input_size = [] - model_kwargs = {} - batch = inputs.get_batches() - if content_type is not None and 
content_type.startswith( - "application/json"): - first = True - for item in batch: - json_input = item.get_as_json() - if isinstance(json_input, dict): - input_size.append(len(json_input.get("inputs"))) - input_data.extend( - self.format_input_for_task( - json_input.pop("inputs"))) - if first: - model_kwargs = json_input.pop("parameters", {}) - first = False - else: - if model_kwargs != json_input.pop( - "parameters", {}): - return Output().error( - "In order to enable dynamic batching, all input batches must have the same parameters" - ) + content_type = inputs.get_property("Content-Type") + input_data = [] + input_size = [] + model_kwargs = {} + batch = inputs.get_batches() + if content_type is not None and content_type.startswith( + "application/json"): + first = True + for item in batch: + json_input = item.get_as_json() + if isinstance(json_input, dict): + input_size.append(len(json_input.get("inputs"))) + input_data.extend( + self.format_input_for_task( + json_input.pop("inputs"))) + if first: + model_kwargs = json_input.pop("parameters", {}) + first = False else: - input_size.append(len(json_input)) - input_data.extend(json_input) - else: - for item in batch: - input_size.append(1) - input_data.extend(item.get_as_string()) - - outputs = Output() - if self.enable_streaming: - outputs.add_property("content-type", "application/jsonlines") - if self.enable_streaming == "huggingface": - outputs.add_stream_content( - StreamingUtils.use_hf_default_streamer( - self.model, self.tokenizer, input_data, - self.device, **model_kwargs)) + if model_kwargs != json_input.pop( + "parameters", {}): + return Output().error( + "In order to enable dynamic batching, all input batches must have the same parameters" + ) else: - stream_generator = StreamingUtils.get_stream_generator( - "DeepSpeed") - outputs.add_stream_content( - stream_generator(self.model, self.tokenizer, - input_data, self.device, - **model_kwargs)) - return outputs - if self.task == "text-generation": - tokenized_inputs = self.tokenizer(input_data, - padding=True, - return_tensors="pt").to( - self.device) - with torch.no_grad(): - output_tokens = self.model.generate( - input_ids=tokenized_inputs.input_ids, - attention_mask=tokenized_inputs.attention_mask, - **model_kwargs) - generated_text = self.tokenizer.batch_decode( - output_tokens, skip_special_tokens=True) - outputs.add_property("content-type", "application/json") - offset = 0 - for i in range(inputs.get_batch_size()): - result = [{ - "generated_text": s - } for s in generated_text[offset:offset + input_size[i]]] - outputs.add(result, key=inputs.get_content().key_at(i)) - offset += input_size[i] - return outputs - - result = self.pipeline(input_data, **model_kwargs) + input_size.append(len(json_input)) + input_data.extend(json_input) + else: + for item in batch: + input_size.append(1) + input_data.extend(item.get_as_string()) + + outputs = Output() + if self.enable_streaming: + outputs.add_property("content-type", "application/jsonlines") + if self.enable_streaming == "huggingface": + outputs.add_stream_content( + StreamingUtils.use_hf_default_streamer( + self.model, self.tokenizer, input_data, + self.device, **model_kwargs)) + else: + stream_generator = StreamingUtils.get_stream_generator( + "DeepSpeed") + outputs.add_stream_content( + stream_generator(self.model, self.tokenizer, + input_data, self.device, + **model_kwargs)) + return outputs + if self.task == "text-generation": + tokenized_inputs = self.tokenizer(input_data, + padding=True, + return_tensors="pt").to( + 
self.device) + with torch.no_grad(): + output_tokens = self.model.generate( + input_ids=tokenized_inputs.input_ids, + attention_mask=tokenized_inputs.attention_mask, + **model_kwargs) + generated_text = self.tokenizer.batch_decode( + output_tokens, skip_special_tokens=True) + outputs.add_property("content-type", "application/json") offset = 0 for i in range(inputs.get_batch_size()): - res = result[offset:offset + input_size[i]] - if self.task == "conversational": - res = [{ - "generated_text": s.generated_responses[-1], - "conversation": { - "past_user_inputs": s.past_user_inputs, - "generated_responses": s.generated_responses, - }, - } for s in res] - outputs.add(res, key=inputs.get_content().key_at(i)) + result = [{ + "generated_text": s + } for s in generated_text[offset:offset + input_size[i]]] + outputs.add(result, key=inputs.get_content().key_at(i)) offset += input_size[i] + return outputs - outputs.add_property("content-type", "application/json") - except Exception as e: - logging.exception("DeepSpeed inference failed") - raise e + result = self.pipeline(input_data, **model_kwargs) + offset = 0 + for i in range(inputs.get_batch_size()): + res = result[offset:offset + input_size[i]] + if self.task == "conversational": + res = [{ + "generated_text": s.generated_responses[-1], + "conversation": { + "past_user_inputs": s.past_user_inputs, + "generated_responses": s.generated_responses, + }, + } for s in res] + outputs.add(res, key=inputs.get_content().key_at(i)) + offset += input_size[i] + + outputs.add_property("content-type", "application/json") + return outputs diff --git a/engines/python/setup/djl_python/fastertransformer.py b/engines/python/setup/djl_python/fastertransformer.py index e77e2b2d5..0a3088b21 100644 --- a/engines/python/setup/djl_python/fastertransformer.py +++ b/engines/python/setup/djl_python/fastertransformer.py @@ -109,70 +109,66 @@ def param_mapper(parameters: dict): return parameters def inference(self, inputs: Input): - try: - # TODO: Add support for more content types - input_data = [] - input_size = [] - parameters = {} - batches = inputs.get_batches() - first = True - for item in batches: - input_map = item.get_as_json() - input_text = input_map.pop("inputs", input_map) - if isinstance(input_text, str): - input_text = [input_text] - input_size.append(len(input_text)) - input_data.extend(input_text) - if first: - parameters = input_map.pop("parameters", {}) - first = False - else: - if parameters != input_map.pop("parameters", {}): - return Output().error( - "In order to enable dynamic batching, all input batches must have the same parameters" - ) - - parameters = self.param_mapper(parameters) - max_length = parameters.pop("max_length", 50) - output_len = parameters.pop("max_seq_len", max_length) - if self.use_triton: - output_length = [output_len] * len(input_data) - if self.enable_streaming: - outputs = Output() - outputs.add_property("content-type", - "application/jsonlines") - outputs.add_stream_content( - self.model.stream_generate(input_data, output_length, - **parameters)) - return outputs - result = self.model.pipeline_generate(input_data, - output_length, - **parameters) + # TODO: Add support for more content types + input_data = [] + input_size = [] + parameters = {} + batches = inputs.get_batches() + first = True + for item in batches: + input_map = item.get_as_json() + input_text = input_map.pop("inputs", input_map) + if isinstance(input_text, str): + input_text = [input_text] + input_size.append(len(input_text)) + input_data.extend(input_text) + if 
first: + parameters = input_map.pop("parameters", {}) + first = False else: - if self.is_t5: - result = self.model.pipeline_generate( - input_data, **parameters) - else: - beam_width = parameters.pop("beam_width", 1) - # TODO: remove after fixes in FT python package - result = self.model.pipeline_generate( - input_data, - batch_size=len(input_data), - output_len=output_len, - beam_width=beam_width, - **parameters) - - offset = 0 - outputs = Output() - outputs.add_property("content-type", "application/json") - for i in range(inputs.get_batch_size()): - generated_text = [{ - "generated_text": s - } for s in result[offset:offset + input_size[i]]] - outputs.add(generated_text, key=inputs.get_content().key_at(i)) - except Exception as e: - logging.exception("FasterTransformer inference failed") - raise e + if parameters != input_map.pop("parameters", {}): + return Output().error( + "In order to enable dynamic batching, all input batches must have the same parameters" + ) + + parameters = self.param_mapper(parameters) + max_length = parameters.pop("max_length", 50) + output_len = parameters.pop("max_seq_len", max_length) + if self.use_triton: + output_length = [output_len] * len(input_data) + if self.enable_streaming: + outputs = Output() + outputs.add_property("content-type", + "application/jsonlines") + outputs.add_stream_content( + self.model.stream_generate(input_data, output_length, + **parameters)) + return outputs + result = self.model.pipeline_generate(input_data, + output_length, + **parameters) + else: + if self.is_t5: + result = self.model.pipeline_generate( + input_data, **parameters) + else: + beam_width = parameters.pop("beam_width", 1) + # TODO: remove after fixes in FT python package + result = self.model.pipeline_generate( + input_data, + batch_size=len(input_data), + output_len=output_len, + beam_width=beam_width, + **parameters) + + offset = 0 + outputs = Output() + outputs.add_property("content-type", "application/json") + for i in range(inputs.get_batch_size()): + generated_text = [{ + "generated_text": s + } for s in result[offset:offset + input_size[i]]] + outputs.add(generated_text, key=inputs.get_content().key_at(i)) return outputs diff --git a/engines/python/setup/djl_python/huggingface.py b/engines/python/setup/djl_python/huggingface.py index 05cd560de..0fe59f799 100644 --- a/engines/python/setup/djl_python/huggingface.py +++ b/engines/python/setup/djl_python/huggingface.py @@ -148,74 +148,70 @@ def initialize(self, properties: dict): self.initialized = True def inference(self, inputs): - try: - content_type = inputs.get_property("Content-Type") - accept = inputs.get_property("Accept") - if not accept: - accept = content_type if content_type.startswith( - "tensor/") else "application/json" - elif "*/*" in accept: - accept = "application/json" - - input_data = [] - input_size = [] - parameters = [] - batch = inputs.get_batches() - first = True - for item in batch: - input_map = decode(item, content_type) - input_size.append(len(input_map.get("inputs"))) - _inputs = input_map.pop("inputs", input_map) - if isinstance(_inputs, list): - input_data.extend(_inputs) - else: - input_data.append(_inputs) - if first or self.enable_rolling_batch: - parameters.append(input_map.pop("parameters", {})) - first = False - else: - if parameters != input_map.pop("parameters", {}): - return Output().error( - "In order to enable dynamic batching, all input batches must have the same parameters" - ) - - outputs = Output() - - if self.enable_streaming: - outputs.add_property("content-type", 
"application/jsonlines") - if self.enable_streaming == "huggingface": - outputs.add_stream_content( - StreamingUtils.use_hf_default_streamer( - self.model, self.tokenizer, input_data, - self.device, **parameters[0])) - else: - stream_generator = StreamingUtils.get_stream_generator( - "Accelerate") - outputs.add_stream_content( - stream_generator(self.model, self.tokenizer, - input_data, self.device, - **parameters[0])) - return outputs - elif self.enable_rolling_batch: - result = self.rolling_batch.inference(input_data, parameters) - for i in range(len(batch)): - res = result[i] - outputs.add_as_json(res, batch_index=i) - - return outputs - - prediction = self.hf_pipeline(input_data, **parameters[0]) - - offset = 0 - for i in range(inputs.get_batch_size()): - encode(outputs, - prediction[offset:offset + input_size[i]], - accept, - key=inputs.get_content().key_at(i)) - offset += input_size[i] - except Exception as e: - logging.exception("Huggingface inference failed") - raise e + content_type = inputs.get_property("Content-Type") + accept = inputs.get_property("Accept") + if not accept: + accept = content_type if content_type.startswith( + "tensor/") else "application/json" + elif "*/*" in accept: + accept = "application/json" + + input_data = [] + input_size = [] + parameters = [] + batch = inputs.get_batches() + first = True + for item in batch: + input_map = decode(item, content_type) + input_size.append(len(input_map.get("inputs"))) + _inputs = input_map.pop("inputs", input_map) + if isinstance(_inputs, list): + input_data.extend(_inputs) + else: + input_data.append(_inputs) + if first or self.enable_rolling_batch: + parameters.append(input_map.pop("parameters", {})) + first = False + else: + if parameters != input_map.pop("parameters", {}): + return Output().error( + "In order to enable dynamic batching, all input batches must have the same parameters" + ) + + outputs = Output() + + if self.enable_streaming: + outputs.add_property("content-type", "application/jsonlines") + if self.enable_streaming == "huggingface": + outputs.add_stream_content( + StreamingUtils.use_hf_default_streamer( + self.model, self.tokenizer, input_data, + self.device, **parameters[0])) + else: + stream_generator = StreamingUtils.get_stream_generator( + "Accelerate") + outputs.add_stream_content( + stream_generator(self.model, self.tokenizer, + input_data, self.device, + **parameters[0])) + return outputs + elif self.enable_rolling_batch: + result = self.rolling_batch.inference(input_data, parameters) + for i in range(len(batch)): + res = result[i] + outputs.add_as_json(res, batch_index=i) + + return outputs + + prediction = self.hf_pipeline(input_data, **parameters[0]) + + offset = 0 + for i in range(inputs.get_batch_size()): + encode(outputs, + prediction[offset:offset + input_size[i]], + accept, + key=inputs.get_content().key_at(i)) + offset += input_size[i] return outputs diff --git a/engines/python/setup/djl_python/transformers-neuronx.py b/engines/python/setup/djl_python/transformers-neuronx.py index 864dc5e17..e58674d7d 100644 --- a/engines/python/setup/djl_python/transformers-neuronx.py +++ b/engines/python/setup/djl_python/transformers-neuronx.py @@ -172,51 +172,46 @@ def initialize(self, properties): self.initialized = True def infer(self, inputs): - try: - input_map = inputs.get_as_json() - input_text = input_map.pop("inputs", input_map) - parameters = input_map.pop("parameters", {}) - if isinstance(input_text, str): - input_text = [input_text] - if len(input_text) != self.batch_size: - raise 
ValueError( - f"{self.batch_size} batch size not equal to {len(input_text)} prompt size" - ) - outputs = Output() - model_kwargs = {} - - if self.enable_streaming: - outputs.add_property("content-type", "application/jsonlines") - if self.enable_streaming == "huggingface": - outputs.add_stream_content( - StreamingUtils.use_hf_default_streamer( - self.model, self.tokenizer, input_text, None, - **model_kwargs)) - else: - stream_generator = StreamingUtils.get_stream_generator( - "transformers-neuronx") - model_kwargs["engine"] = "transformers-neuronx" - outputs.add_stream_content( - stream_generator(self.model, self.tokenizer, - input_text, "cpu", **model_kwargs)) - return outputs - - encoded_inputs = self.tokenizer.batch_encode_plus( - input_text, return_tensors="pt", padding=True) - output_tokens = self.model.generate( - input_ids=encoded_inputs.input_ids, - attention_mask=encoded_inputs.attention_mask, - **parameters) - generated_text = self.tokenizer.batch_decode( - output_tokens, skip_special_tokens=True) - - return Output().add([{ - "generated_text": s - } for s in generated_text]) - - except Exception as e: - logging.exception("TransformerNeuronX inference failed") - raise e + input_map = inputs.get_as_json() + input_text = input_map.pop("inputs", input_map) + parameters = input_map.pop("parameters", {}) + if isinstance(input_text, str): + input_text = [input_text] + if len(input_text) != self.batch_size: + raise ValueError( + f"{self.batch_size} batch size not equal to {len(input_text)} prompt size" + ) + outputs = Output() + model_kwargs = {} + + if self.enable_streaming: + outputs.add_property("content-type", "application/jsonlines") + if self.enable_streaming == "huggingface": + outputs.add_stream_content( + StreamingUtils.use_hf_default_streamer( + self.model, self.tokenizer, input_text, None, + **model_kwargs)) + else: + stream_generator = StreamingUtils.get_stream_generator( + "transformers-neuronx") + model_kwargs["engine"] = "transformers-neuronx" + outputs.add_stream_content( + stream_generator(self.model, self.tokenizer, + input_text, "cpu", **model_kwargs)) + return outputs + + encoded_inputs = self.tokenizer.batch_encode_plus( + input_text, return_tensors="pt", padding=True) + output_tokens = self.model.generate( + input_ids=encoded_inputs.input_ids, + attention_mask=encoded_inputs.attention_mask, + **parameters) + generated_text = self.tokenizer.batch_decode( + output_tokens, skip_special_tokens=True) + + return Output().add([{ + "generated_text": s + } for s in generated_text]) _service = TransformersNeuronXService()
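
Note: with the try/except blocks removed, the default handlers no longer build an error Output themselves; exceptions from inference()/infer() now propagate to whatever code invokes the handler. The snippet below is a minimal, hypothetical caller-side sketch (it is not part of this patch and not the actual frontend code) of how a raised exception could be translated back into the error response the handlers used to return; the helper name handle_request and the import path are assumptions, only Output().error(...) and service.inference(...) appear in the patch itself.

import logging

from djl_python.outputs import Output  # import path assumed; Output is the same class used by the handlers above


def handle_request(service, inputs):
    # Hypothetical wrapper: the handler no longer catches exceptions, so the
    # caller is the place where a failure is logged and converted into an
    # error Output, instead of the handler returning Output().error(str(e)).
    try:
        return service.inference(inputs)
    except Exception as e:
        logging.exception("inference failed")
        return Output().error(str(e))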