minor fixes
sindhuvahinis committed Jul 8, 2024
1 parent 009411d commit bdcfdf8
Showing 10 changed files with 110 additions and 66 deletions.
18 changes: 10 additions & 8 deletions engines/python/setup/djl_python/huggingface.py
@@ -140,7 +140,6 @@ def __init__(self):
self.peft_config = None
self.stopping_criteria_list = None
self.adapter_registry = {}
self.adapters = None
self.hf_configs = None
self.input_format_args = None

@@ -254,14 +253,13 @@ def _dynamic_batch_inference(self, batch: List, errors: Dict,
inputs: Input, outputs: Output,
requests: List):
# Dynamic batching
input_data, input_size = get_input_details(requests, errors, batch)
parameters = requests[0].request_input.server_parameters
        input_data, input_size, parameters, adapters = get_input_details(requests, errors, batch)

if isinstance(self.model, PeftModelForCausalLM):
if self.adapters is None:
if adapters is None:
# Inference with only base model
self.adapters = [""] * len(input_data)
parameters["adapters"] = self.adapters
adapters = [""] * len(input_data)
parameters["adapters"] = adapters
prediction = self.hf_pipeline(input_data, **parameters)
offset = 0
for i, item in enumerate(batch):
@@ -293,20 +291,24 @@ def _streaming_inference(self, batch: List, request_input: RequestInput,
if len(batch) > 1:
raise NotImplementedError(
"Dynamic batch not supported for generic streaming")

parameters = request_input.server_parameters
if isinstance(parameters, list):
parameters = parameters[0]
outputs.add_property("content-type", "application/jsonlines")
if self.hf_configs.enable_streaming.value == StreamingEnum.huggingface.value:
outputs.add_stream_content(
StreamingUtils.use_hf_default_streamer(
self.model, self.tokenizer, request_input.input_text,
self.hf_configs.device, **request_input.server_parameters))
self.hf_configs.device, **parameters))
else:
stream_generator = StreamingUtils.get_stream_generator(
"Accelerate")
outputs.add_stream_content(
stream_generator(self.model, self.tokenizer,
request_input.input_text,
self.hf_configs.device,
**request_input.server_parameters))
**parameters))
return outputs

def get_pipeline(self, task: str, model_id_or_path: str, kwargs):
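The isinstance(parameters, list) check added in _streaming_inference exists because input_parser.py duplicates the parameter dict once per prompt for client-side batching (see that file below), so handlers may receive either a dict or a list of identical dicts. A minimal sketch of that normalization; the helper name and sample values are illustrative, not part of djl_python:

def normalize_server_parameters(server_parameters):
    # Collapse a per-prompt list of parameter dicts back into a single dict.
    if isinstance(server_parameters, list):
        return server_parameters[0]
    return server_parameters

single = {"max_new_tokens": 64}
duplicated = [single.copy(), single.copy()]  # what a two-prompt request produces
assert normalize_server_parameters(single) == {"max_new_tokens": 64}
assert normalize_server_parameters(duplicated) == {"max_new_tokens": 64}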
60 changes: 35 additions & 25 deletions engines/python/setup/djl_python/input_parser.py
@@ -44,6 +44,8 @@ def parse_input_with_formatter(inputs: Input, **kwargs) -> ParsedInput:
request_id_counter = get_req_id_counter(kwargs)
for i, input_item in enumerate(batch):
try:
kwargs["is_rolling_batch"] = is_rolling_batch_enabled(
kwargs.get("configs").rolling_batch)
request_id = request_id_counter.next_id(
) if request_id_counter else i
# TODO: Decide whether it is a text input based on content-type
@@ -70,7 +72,7 @@ def parse_input_with_formatter(inputs: Input, **kwargs) -> ParsedInput:

def get_req_id_counter(kwargs):
req_id_counter = None
if is_rolling_batch_enabled(kwargs.get("configs").rolling_batch):
if kwargs.get("is_rolling_batch"):
req_id_counter = kwargs.get("rolling_batch").req_id_counter
return req_id_counter

@@ -89,24 +91,27 @@ def parse_text_inputs_params(request_input: TextInput, input_item: Input,
invoke_type = input_item.get_property("X-Amzn-SageMaker-Forwarded-Api")
tokenizer = kwargs.get("tokenizer")
if is_chat_completions_request(input_map):
_inputs, _param = parse_chat_completions_request(
inputs, param = parse_chat_completions_request(
input_map, kwargs.get("is_rolling_batch"), tokenizer)
elif is_3p_request(invoke_type):
_inputs, _param = parse_3p_request(input_map,
kwargs.get("is_rolling_batch"),
tokenizer, invoke_type)
inputs, param = parse_3p_request(input_map,
kwargs.get("is_rolling_batch"),
tokenizer, invoke_type)
else:
_inputs = input_map.pop("inputs", input_map)
_param = input_map.pop("parameters", {})

request_input.input_text = _inputs
request_input.parameters = _param
# assign input_ids
if kwargs.get("tokenizer"):
inputs = input_map.pop("inputs", input_map)
param = input_map.pop("parameters", {})

request_input.input_text = inputs
request_input.parameters = param
# assigns input_ids
# TODO: for dynamic batching, or HF pipeline, tokenizer is applied differently.
if kwargs.get("tokenizer") and kwargs.get("is_rolling_batch"):
request_input.input_ids = tokenizer.encode(request_input.input_text)

# TODO: Instead of modifying user parameters, maintain this in server_parameters.
# Added here for backward compatibility
# re-organize the parameters
if is_rolling_batch_enabled(kwargs.get("configs").rolling_batch):
if kwargs.get("is_rolling_batch"):
if "stream" in input_map:
request_input.parameters["stream"] = input_map.pop("stream")
if "cached_prompt" in input_map:
@@ -124,18 +129,17 @@ def add_server_maintained_params(request_input: TextInput, input_item: Input,
if input_item.contains_key("seed"):
request_input.server_parameters["seed"] = input_item.get_as_string(
key="seed")
if not "output_formatter" in request_input.server_parameters:
request_input.server_parameters["output_formatter"] = kwargs.get(
"configs").output_formatter

request_input.output_formatter = request_input.server_parameters.get(
"output_formatter")
# setting the output formatter
if not "output_formatter" in request_input.server_parameters:
request_input.server_parameters["output_formatter"] = kwargs.get("configs").output_formatter

if request_input.output_formatter == "json" or request_input.output_formatter == "sse":
output_formatter = request_input.server_parameters["output_formatter"]
if output_formatter == "json" or output_formatter == "sse":
request_input.tgi_compat = kwargs.get("configs").tgi_compat

# duplicating parameters for client side batching
if isinstance(request_input.input_text, list):
if isinstance(request_input.input_text, list) and len(
request_input.input_text) > 1:
parameters = []
for _ in range(len(request_input.input_text)):
parameters.append(request_input.server_parameters.copy())
@@ -147,22 +151,28 @@ def parse_adapters(request_input: TextInput, input_item: Input,
adapter_registry = kwargs.get("adapter_registry")
    # if the adapter registry exists and is not empty, we assume peft is supported for the incoming request
if adapter_registry:
input_len = len(request_input.input_text) if isinstance(
request_input.input_text, list) else 1
adapters_per_item = _fetch_adapters_from_input(input_map, input_item)
if adapters_per_item:
_validate_adapters(adapters_per_item,
kwargs.get("adapter_registry"))
else:
# inference with just base model.
adapters_per_item = [""] * len(request_input.input_text)
adapters_per_item = [""] * input_len

if len(request_input.input_text) != len(adapters_per_item):
if input_len != len(adapters_per_item):
raise ValueError(
f"Number of adapters is not equal to the number of inputs")
# lookup the adapter registry to get the adapter details of the registered adapter.
request_input.adapters = [
adapters_data = [
kwargs.get("adapter_registry").get(adapter, None)
for adapter in adapter_registry
for adapter in adapters_per_item
]
if len(adapters_data) == 1:
adapters_data = adapters_data[0]

request_input.adapters = adapters_data


def _fetch_adapters_from_input(input_map: dict, input_item: Input):
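parse_adapters now sizes the empty-adapter fallback by input_len, validates the adapter count against the number of prompts, resolves each adapter through the registry, and collapses a single-element result to a scalar. A simplified, runnable sketch of that resolution logic, using a plain dict as a stand-in for the server's adapter registry:

def resolve_adapters(input_text, adapters_per_item, adapter_registry):
    # Number of prompts in this request (client-side batching sends a list).
    input_len = len(input_text) if isinstance(input_text, list) else 1
    if not adapters_per_item:
        adapters_per_item = [""] * input_len  # inference with just the base model
    if input_len != len(adapters_per_item):
        raise ValueError("Number of adapters is not equal to the number of inputs")
    adapters_data = [adapter_registry.get(adapter) for adapter in adapters_per_item]
    return adapters_data[0] if len(adapters_data) == 1 else adapters_data

registry = {"sql-lora": {"path": "/adapters/sql-lora"}}
print(resolve_adapters("a single prompt", ["sql-lora"], registry))  # {'path': '/adapters/sql-lora'}
print(resolve_adapters(["prompt 1", "prompt 2"], [], registry))     # [None, None]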
6 changes: 4 additions & 2 deletions engines/python/setup/djl_python/request.py
@@ -42,9 +42,11 @@ def __init__(self, request_input: TextInput = None):
self.adapter = request_input.adapters

# output formatter
stream = self.request_input.parameters.get("stream", False)
request_input.output_formatter = self.parameters.pop("output_formatter")
request_input.stream = self.parameters.pop("stream", False)
self.output_formatter, self.content_type = get_output_formatter(
request_input.output_formatter, stream, request_input.tgi_compat)
request_input.output_formatter, request_input.stream, request_input.tgi_compat)
request_input.output_formatter = self.output_formatter
self.legacy_formatter = self._is_output_formatter_legacy()

self.request_output = TextGenerationOutput(request_id=self.id,
7 changes: 4 additions & 3 deletions engines/python/setup/djl_python/request_io.py
@@ -129,10 +129,12 @@ class RequestInput:
Attributes:
request_id: The request ID.
output_formatter: Output formatter of the request
parameters: parameters in the request payload, will be used in the output formatter
parameters: parameters in the request payload
server_parameters: parameters that are modified by the built-in handlers to support backend engines.
"""
request_id: int
output_formatter: Union[Callable, str] = None
stream: Optional[bool] = False
parameters: Dict = field(default_factory=lambda: {})
server_parameters: Dict = field(default_factory=lambda: {})
tgi_compat: bool = False
@@ -147,11 +149,10 @@ class TextInput(RequestInput):
adapters: adapter used for the request.
tokenizer: tokenizer used for the request.
"""
input_text: str = None
input_text: Union[str, List[str]] = None
input_ids: List[int] = field(default_factory=lambda: [])
adapters: Optional[Any] = None
tokenizer: Optional[Any] = None
found_adapters: bool = False

def prompt_tokens_length(self) -> int:
return len(self.input_ids)
10 changes: 5 additions & 5 deletions engines/python/setup/djl_python/tensorrt_llm.py
@@ -29,6 +29,7 @@ class TRTLLMService(object):
"""

def __init__(self):
self.input_format_args = None
self.initialized = False
self.trt_configs = None
self.rolling_batch = None
@@ -40,6 +41,7 @@ def initialize(self, properties: dict):
self.rolling_batch = TRTLLMRollingBatch(
self.trt_configs.model_id_or_path, properties, self.trt_configs)
self.tokenizer = self.rolling_batch.get_tokenizer()
self.input_format_args = self.get_input_format_args()
self.initialized = True
return

@@ -54,16 +56,14 @@ def inference(self, inputs: Input) -> Output:
"""
Does preprocessing and sends new requests to the rolling batch script for inference
:param inputs (Input): a batch of inputs, each corresponding to a new request
:param inputs: (Input) a batch of inputs, each corresponding to a new request
:return outputs (Output): a batch of outputs that contain status code, output text, and other information
"""
outputs = Output()
kwargs = self.__dict__
kwargs[
"configs"] = self.trt_configs # TODO: Rename it to configs, so it would uniform in all handlers

parsed_input = parse_input_with_formatter(inputs, **kwargs)
parsed_input = parse_input_with_formatter(inputs,
**self.input_format_args)
if len(parsed_input.requests) == 0:
for i in range(len(parsed_input.batch)):
err = parsed_input.errors.get(i)
17 changes: 8 additions & 9 deletions engines/python/setup/djl_python/tensorrt_llm_python.py
@@ -55,12 +55,12 @@ def _get_generation_result_from_python_backend(generations, inputs_size):
log_prob = curr_cum_log_prob - cum_log_probs[i]
token_result = {
'id':
_get_value_based_on_tensor(generation[i].token_id,
index=0),
_get_value_based_on_tensor(generation[i].token_id,
index=0),
'text':
generation[i].token_text,
generation[i].token_text,
'log_prob':
log_prob if i < len(tokens_results) else curr_cum_log_prob,
log_prob if i < len(tokens_results) else curr_cum_log_prob,
}
cum_log_probs[i] = curr_cum_log_prob
tokens_results[i].append(token_result)
@@ -100,7 +100,7 @@ def get_input_format_args(self):
return {
"configs": self.trt_configs,
"tokenizer":
None, # tokenizer, for chat completions is not supported for python backend.
None, # tokenizer, for chat completions is not supported for python backend.
}

def inference(self, inputs: Input) -> Output:
@@ -121,10 +121,9 @@ def inference(self, inputs: Input) -> Output:
outputs.add(err, key="data", batch_index=i)
return outputs

input_data, input_size = get_input_details(parsed_input.requests,
parsed_input.errors,
parsed_input.batch)
params = parsed_input.requests[0].request_input.server_parameters
input_data, input_size, params, _ = get_input_details(parsed_input.requests,
parsed_input.errors,
parsed_input.batch)

if "output_formatter" in params:
# output formatter is not supported for TensorRT-LLM python backend.
31 changes: 31 additions & 0 deletions engines/python/setup/djl_python/tests/test_input_parser.py
@@ -0,0 +1,31 @@
import unittest

from djl_python.input_parser import parse_input_with_formatter
from djl_python.test_model import create_concurrent_batch_request


class InputParserTest(unittest.TestCase):

def test_input_parameters(self):
inputs = [{
"inputs": "The winner of oscar this year is",
"parameters": {
"max_new_tokens": 50
},
"stream": False
}, {
"inputs": "A little redhood is",
"parameters": {
"min_new_tokens": 51,
"max_new_tokens": 256,

},
"stream": True
}]

serving_properties = {
"rolling_batch": "disable"
}

inputs = create_concurrent_batch_request(inputs, serving_properties=serving_properties)
parsed_input = parse_input_with_formatter(inputs)
10 changes: 5 additions & 5 deletions engines/python/setup/djl_python/tests/test_rolling_batch.py
@@ -125,9 +125,9 @@ def test_jsonlines_fmt(self):
TextInput(request_id=1,
input_text="This is a wonderful day",
parameters={
"max_new_tokens": 256,
"stream": True
}))
"max_new_tokens": 256
},
stream=True))
for req in [req1, req2]:
req.set_next_token(Token(244, "He", -0.334532))
print(req.get_next_token(), end='')
@@ -209,9 +209,9 @@ def test_sse_tgi_compat_fmt(self):
TextInput(request_id=1,
input_text="This is a wonderful day",
parameters={
"max_new_tokens": 256,
"stream": True
"max_new_tokens": 256
},
stream=True,
tgi_compat=True))
req.set_next_token(Token(244, "He", -0.334532))
next_token = req.get_next_token()
9 changes: 3 additions & 6 deletions engines/python/setup/djl_python/transformers_neuronx.py
@@ -219,7 +219,8 @@ def partition(self, properties: dict):
self.initialized = True

def inference(self, inputs: Input) -> Output:
parsed_input = parse_input_with_formatter(inputs, **self.__dict__)
parsed_input = parse_input_with_formatter(inputs,
**self.input_format_args)
errors = parsed_input.errors
requests = parsed_input.requests
outputs = Output()
@@ -229,11 +230,7 @@ def inference(self, inputs: Input) -> Output:
self.rolling_batch)

batch = parsed_input.batch
input_data, input_size = get_input_details(requests, errors, batch)
parameters = parsed_input.requests[0].request_input.server_parameters
# Remove rolling batch default parameters
parameters.pop("output_formatter", None)
parameters.pop("stream", None)
input_data, input_size, parameters, _ = get_input_details(requests, errors, batch)
model_kwargs = {}

prompt_size = len(input_data)
8 changes: 5 additions & 3 deletions engines/python/setup/djl_python/utils.py
@@ -114,8 +114,10 @@ def get_input_details(requests, errors, batch):
input_size = []
adapters = []
idx = 0
request_input = requests[0].request_input
parameters = request_input.server_parameters
parameters = requests[0].request_input.server_parameters
if isinstance(parameters, list):
parameters = parameters[0]

for i in range(len(batch)):
if i in errors:
input_size.append(0)

idx += 1
adapters = adapters if adapters else None
return input_data, input_size, adapters
return input_data, input_size, parameters, adapters
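
get_input_details now also returns the (list-normalized) server parameters, so callers unpack four values instead of reaching into requests[0] themselves. A runnable sketch of the new contract; the stub classes and the simplified body below are illustrative stand-ins, not the real djl_python types:

from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Union

@dataclass
class StubRequestInput:
    input_text: Union[str, List[str]]
    server_parameters: Dict[str, Any] = field(default_factory=dict)
    adapters: Optional[Any] = None

@dataclass
class StubRequest:
    request_input: StubRequestInput

def get_input_details_sketch(requests, errors, batch):
    # Simplified: the real helper also flattens client-side batches and skips errored items.
    parameters = requests[0].request_input.server_parameters
    if isinstance(parameters, list):
        parameters = parameters[0]
    input_data = [r.request_input.input_text for r in requests]
    input_size = [1] * len(requests)
    adapters = [r.request_input.adapters for r in requests
                if r.request_input.adapters is not None] or None
    return input_data, input_size, parameters, adapters

reqs = [StubRequest(StubRequestInput("Hello world", {"max_new_tokens": 32}))]
input_data, input_size, parameters, adapters = get_input_details_sketch(reqs, {}, reqs)
print(input_data, input_size, parameters, adapters)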
