Skip to content

Commit

Permalink
[python] move parse input functions to input_parser.py (#2092)
Browse files Browse the repository at this point in the history
  • Loading branch information
sindhuvahinis authored Jul 3, 2024
1 parent de562de commit 1a46376
Show file tree
Hide file tree
Showing 7 changed files with 196 additions and 175 deletions.
22 changes: 12 additions & 10 deletions engines/python/setup/djl_python/chat_completions/chat_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,16 @@
# or in the "LICENSE.txt" file accompanying this file. This file is distributed on an "AS IS"
# BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, express or implied. See the License for
# the specific language governing permissions and limitations under the License.
from typing import Dict

from djl_python.chat_completions.chat_properties import ChatProperties


def is_chat_completions_request(inputs: map) -> bool:
def is_chat_completions_request(inputs: Dict) -> bool:
    """Return True when the request body follows the chat-completions schema."""
    return "messages" in inputs.keys()


def parse_chat_completions_request(inputs: map, is_rolling_batch: bool,
def parse_chat_completions_request(input_map: Dict, is_rolling_batch: bool,
tokenizer):
if not is_rolling_batch:
raise ValueError(
Expand All @@ -28,14 +30,14 @@ def parse_chat_completions_request(inputs: map, is_rolling_batch: bool,
raise AttributeError(
f"Cannot provide chat completion for tokenizer: {tokenizer.__class__}, "
f"please ensure that your tokenizer supports chat templates.")
chat_params = ChatProperties(**inputs)
_param = chat_params.model_dump(by_alias=True, exclude_none=True)
_messages = _param.pop("messages")
_inputs = tokenizer.apply_chat_template(_messages, tokenize=False)
_param[
chat_params = ChatProperties(**input_map)
param = chat_params.model_dump(by_alias=True, exclude_none=True)
messages = param.pop("messages")
inputs = tokenizer.apply_chat_template(messages, tokenize=False)
param[
"do_sample"] = chat_params.temperature is not None and chat_params.temperature > 0.0
_param["details"] = True # Enable details for chat completions
_param[
param["details"] = True # Enable details for chat completions
param[
"output_formatter"] = "jsonlines_chat" if chat_params.stream else "json_chat"

return _inputs, _param
return inputs, param
3 changes: 2 additions & 1 deletion engines/python/setup/djl_python/huggingface.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,8 @@

from djl_python.properties_manager.properties import StreamingEnum, is_rolling_batch_enabled, is_streaming_enabled
from djl_python.properties_manager.hf_properties import HuggingFaceProperties
from djl_python.utils import parse_input_with_formatter, InputFormatConfigs, ParsedInput, rolling_batch_inference
from djl_python.utils import rolling_batch_inference
from djl_python.input_parser import ParsedInput, InputFormatConfigs, parse_input_with_formatter

ARCHITECTURES_2_TASK = {
"TapasForQuestionAnswering": "table-question-answering",
Expand Down
176 changes: 176 additions & 0 deletions engines/python/setup/djl_python/input_parser.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,176 @@
#!/usr/bin/env python
#
# Copyright 2024 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file
# except in compliance with the License. A copy of the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "LICENSE.txt" file accompanying this file. This file is distributed on an "AS IS"
# BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, express or implied. See the License for
# the specific language governing permissions and limitations under the License.
import logging
from dataclasses import dataclass, field
from typing import Any, Callable, List, Optional, Tuple, Union

from djl_python import Input
from djl_python.chat_completions.chat_utils import is_chat_completions_request, parse_chat_completions_request
from djl_python.encode_decode import decode
from djl_python.three_p.three_p_utils import is_3p_request, parse_3p_request


@dataclass
class ParsedInput:
    """Result of parsing a batch of inference requests."""
    # Flattened prompt strings across every request in the batch.
    input_data: List[str]
    # Number of prompts contributed by each request (0 when parsing failed).
    input_size: List[int]
    # Per-prompt generation parameters, aligned with input_data.
    parameters: List[dict]
    # Maps a request's batch index to the error message raised while parsing it.
    errors: dict
    # The raw batch of request items the fields above were parsed from.
    batch: list
    # Per-request flag: True when the request supplied a client-side batch
    # (its "inputs" value was a list). Only meaningful for dynamic batching.
    is_client_side_batch: list = field(default_factory=list)
    # Resolved adapter registry entries (one per prompt), or None when no
    # request in the batch supplied adapters.
    adapters: Optional[list] = None


@dataclass
class InputFormatConfigs:
    """Static configuration controlling how raw requests are parsed."""
    # Whether the engine uses rolling (continuous) batching; enables
    # per-request streaming.
    is_rolling_batch: bool = False
    # Whether adapters (e.g. LoRA) may be attached to requests.
    is_adapters_supported: bool = False
    # Default output formatter (registered name or callable) applied when a
    # request does not specify one.
    output_formatter: Optional[Union[str, Callable]] = None
    # Tokenizer handed to the chat/3p parsers; chat requests require it to
    # expose apply_chat_template.
    tokenizer: Any = None


def parse_input_with_formatter(inputs: Input,
                               input_format_configs: InputFormatConfigs,
                               adapter_registry: Optional[dict] = None
                               ) -> ParsedInput:
    """
    Preprocessing function that extracts information from Input objects.
    :param inputs :(Input) a batch of inputs, each corresponding to a new request
    :param input_format_configs: format configurations for the input.
    :param adapter_registry: mapping of adapter name -> registered adapter entry.
    :return parsed_input: object of data class that contains all parsed input details
    """
    # None default instead of a mutable `{}` default argument, which would be
    # shared across calls.
    if adapter_registry is None:
        adapter_registry = {}

    input_data = []
    input_size = []
    parameters = []
    adapters = []
    errors = {}
    found_adapters = False
    batch = inputs.get_batches()
    # only for dynamic batch
    is_client_side_batch = [False for _ in range(len(batch))]
    for i, item in enumerate(batch):
        try:
            content_type = item.get_property("Content-Type")
            invoke_type = item.get_property("X-Amzn-SageMaker-Forwarded-Api")
            input_map = decode(item, content_type)
            _inputs, _param, is_client_side_batch[i] = _parse_inputs_params(
                input_map, item, input_format_configs, invoke_type)
            if input_format_configs.is_adapters_supported:
                adapters_per_item, found_adapter_per_item = _parse_adapters(
                    _inputs, input_map, item, adapter_registry)
        except Exception as e:  # pylint: disable=broad-except
            # A malformed request must not fail the whole batch: record the
            # error under this request's index and keep going.
            logging.warning(f"Parse input failed: {i}")
            input_size.append(0)
            errors[i] = str(e)
            continue

        input_data.extend(_inputs)
        input_size.append(len(_inputs))

        if input_format_configs.is_adapters_supported:
            adapters.extend(adapters_per_item)
            found_adapters = found_adapter_per_item or found_adapters

        # Replicate this request's parameters once per prompt it contributed.
        # NOTE(review): the same dict object is appended for each prompt, so
        # downstream mutation of one entry affects its siblings — confirm
        # consumers treat parameters as read-only.
        for _ in range(input_size[i]):
            parameters.append(_param)

    if found_adapters:
        # Resolve adapter names to registry entries; unknown names map to None.
        # (`adapters` is always a list here, so no None check is needed.)
        adapter_data = [
            adapter_registry.get(adapter, None) for adapter in adapters
        ]
    else:
        adapter_data = None

    return ParsedInput(input_data=input_data,
                       input_size=input_size,
                       parameters=parameters,
                       errors=errors,
                       batch=batch,
                       is_client_side_batch=is_client_side_batch,
                       adapters=adapter_data)


def _parse_inputs_params(input_map, item, input_format_configs, invoke_type):
    """Extract prompts and generation parameters from one request body.

    Dispatches to the chat-completions or 3p parsers when the request matches
    those schemas; otherwise reads the plain "inputs"/"parameters" schema.

    :param input_map: decoded request body (dict)
    :param item: the raw request item (used for server-side properties)
    :param input_format_configs: parsing configuration
    :param invoke_type: value of the X-Amzn-SageMaker-Forwarded-Api header
    :return: (list of prompts, parameters dict, True when the request supplied
        a client-side batch, i.e. its "inputs" value was a list)
    """
    if is_chat_completions_request(input_map):
        _inputs, _param = parse_chat_completions_request(
            input_map, input_format_configs.is_rolling_batch,
            input_format_configs.tokenizer)
    elif is_3p_request(invoke_type):
        _inputs, _param = parse_3p_request(
            input_map, input_format_configs.is_rolling_batch,
            input_format_configs.tokenizer, invoke_type)
    else:
        # Falls back to the whole body as the prompt when "inputs" is absent.
        _inputs = input_map.pop("inputs", input_map)
        _param = input_map.pop("parameters", {})

    # Add some additional parameters that are necessary.
    # Per request streaming is only supported by rolling batch
    if input_format_configs.is_rolling_batch:
        _param["stream"] = input_map.pop("stream", _param.get("stream", False))

    if "cached_prompt" in input_map:
        _param["cached_prompt"] = input_map.pop("cached_prompt")
    if "seed" not in _param:
        # set server provided seed if seed is not part of request
        if item.contains_key("seed"):
            # NOTE(review): get_as_string yields a str — confirm downstream
            # consumers coerce the seed to int where required.
            _param["seed"] = item.get_as_string(key="seed")
    if "output_formatter" not in _param:
        _param["output_formatter"] = input_format_configs.output_formatter

    if isinstance(_inputs, list):
        return _inputs, _param, True
    else:
        return [_inputs], _param, False


def _parse_adapters(_inputs, input_map, item,
                    adapter_registry) -> Tuple[List, bool]:
    """Resolve and validate the adapter names for one request.

    :param _inputs: prompts parsed from the request (may be a client-side batch)
    :param input_map: decoded request body
    :param item: the raw request item (content/properties may carry adapters)
    :param adapter_registry: registered adapter name -> entry mapping
    :return: (adapter name per prompt, True when the request explicitly
        supplied adapters)
    :raises ValueError: if an adapter is unregistered, or the adapter count
        does not match the prompt count
    """
    adapters_per_item = _fetch_adapters_from_input(input_map, item)
    found_adapter_per_item = False
    if adapters_per_item:
        _validate_adapters(adapters_per_item, adapter_registry)
        found_adapter_per_item = True
    else:
        # inference with just base model.
        adapters_per_item = [""] * len(_inputs)

    if len(_inputs) != len(adapters_per_item):
        raise ValueError(
            "Number of adapters is not equal to the number of inputs")
    return adapters_per_item, found_adapter_per_item


def _fetch_adapters_from_input(input_map: dict, inputs: Input):
    """Collect adapter name(s) for a request, always returning a list.

    Later sources take precedence over earlier ones:
    request body "adapters" -> input content "adapter" (workflow approach)
    -> "adapter" request property (header).
    """
    # Body-level "adapters" key; popped so it is consumed from the map.
    found = input_map.pop("adapters", [])

    # check content, possible in workflow approach
    if inputs.contains_key("adapter"):
        found = inputs.get_as_string("adapter")

    # check properties, possible from header
    properties = inputs.get_properties()
    if "adapter" in properties:
        found = properties["adapter"]

    return found if isinstance(found, list) else [found]


def _validate_adapters(adapters_per_item, adapter_registry):
for adapter_name in adapters_per_item:
if adapter_name and adapter_name not in adapter_registry:
raise ValueError(f"Adapter {adapter_name} is not registered")
3 changes: 2 additions & 1 deletion engines/python/setup/djl_python/tensorrt_llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,8 @@
from djl_python.rolling_batch.trtllm_rolling_batch import TRTLLMRollingBatch
from djl_python.properties_manager.trt_properties import TensorRtLlmProperties
from djl_python.tensorrt_llm_python import TRTLLMPythonService
from djl_python.utils import parse_input_with_formatter, InputFormatConfigs, rolling_batch_inference
from djl_python.utils import rolling_batch_inference
from djl_python.input_parser import InputFormatConfigs, parse_input_with_formatter
from typing import List, Tuple


Expand Down
2 changes: 1 addition & 1 deletion engines/python/setup/djl_python/tensorrt_llm_python.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from djl_python.encode_decode import encode
from djl_python.inputs import Input
from djl_python.outputs import Output
from djl_python.utils import parse_input_with_formatter, InputFormatConfigs
from djl_python.input_parser import InputFormatConfigs, parse_input_with_formatter


def _get_value_based_on_tensor(value, index=None):
Expand Down
3 changes: 2 additions & 1 deletion engines/python/setup/djl_python/transformers_neuronx.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,8 @@
from djl_python.properties_manager.properties import StreamingEnum, is_rolling_batch_enabled
from djl_python.neuron_utils.model_loader import TNXModelLoader, OptimumModelLoader
from djl_python.neuron_utils.utils import task_from_config, build_vllm_rb_properties
from djl_python.utils import InputFormatConfigs, parse_input_with_formatter, rolling_batch_inference
from djl_python.utils import rolling_batch_inference
from djl_python.input_parser import InputFormatConfigs, parse_input_with_formatter
from typing import Tuple, List

model = None
Expand Down
Loading

0 comments on commit 1a46376

Please sign in to comment.