From 45a531fc0efab13fc28e560ec19bd50e02c62bdd Mon Sep 17 00:00:00 2001
From: David Yastremsky
Date: Tue, 10 Oct 2023 19:32:14 -0700
Subject: [PATCH] Update client to use iterable client class

---
 samples/client.py | 218 ++++++++++++++++++++++++++--------------------
 1 file changed, 124 insertions(+), 94 deletions(-)

diff --git a/samples/client.py b/samples/client.py
index d29ad1de..93786a07 100755
--- a/samples/client.py
+++ b/samples/client.py
@@ -29,6 +29,7 @@
 import argparse
 import asyncio
 import json
+import queue
 import sys
 from os import system
 
@@ -37,114 +38,141 @@
 from tritonclient.utils import *
 
 
-def create_request(
-    prompt,
-    stream,
-    request_id,
-    sampling_parameters,
-    model_name,
-    send_parameters_as_tensor=True,
-):
-    inputs = []
-    prompt_data = np.array([prompt.encode("utf-8")], dtype=np.object_)
-    try:
-        inputs.append(grpcclient.InferInput("text_input", [1], "BYTES"))
-        inputs[-1].set_data_from_numpy(prompt_data)
-    except Exception as e:
-        print(f"Encountered an error {e}")
-
-    stream_data = np.array([stream], dtype=bool)
-    inputs.append(grpcclient.InferInput("stream", [1], "BOOL"))
-    inputs[-1].set_data_from_numpy(stream_data)
-
-    # Request parameters are not yet supported via BLS. Provide an
-    # optional mechanism to send serialized parameters as an input
-    # tensor until support is added
-
-    if send_parameters_as_tensor:
-        sampling_parameters_data = np.array(
-            [json.dumps(sampling_parameters).encode("utf-8")], dtype=np.object_
+class LLMClient:
+    def __init__(self, flags: argparse.Namespace):
+        self._client = grpcclient.InferenceServerClient(
+            url=flags.url, verbose=flags.verbose
         )
-        inputs.append(grpcclient.InferInput("sampling_parameters", [1], "BYTES"))
-        inputs[-1].set_data_from_numpy(sampling_parameters_data)
-
-    # Add requested outputs
-    outputs = []
-    outputs.append(grpcclient.InferRequestedOutput("text_output"))
-
-    # Issue the asynchronous sequence inference.
-    return {
-        "model_name": model_name,
-        "inputs": inputs,
-        "outputs": outputs,
-        "request_id": str(request_id),
-        "parameters": sampling_parameters,
-    }
-
-
-async def main(FLAGS):
-    model_name = "vllm_opt"
-    sampling_parameters = {"temperature": "0.1", "top_p": "0.95"}
-    stream = FLAGS.streaming_mode
-    with open(FLAGS.input_prompts, "r") as file:
-        print(f"Loading inputs from `{FLAGS.input_prompts}`...")
-        prompts = file.readlines()
-
-    results_dict = {}
-
-    async with grpcclient.InferenceServerClient(
-        url=FLAGS.url, verbose=FLAGS.verbose
-    ) as triton_client:
-        # Request iterator that yields the next request
-        async def async_request_iterator():
-            try:
-                for iter in range(FLAGS.iterations):
-                    for i, prompt in enumerate(prompts):
-                        prompt_id = FLAGS.offset + (len(prompts) * iter) + i
-                        results_dict[str(prompt_id)] = []
-                        yield create_request(
-                            prompt, stream, prompt_id, sampling_parameters, model_name
-                        )
-            except Exception as error:
-                print(f"caught error in request iterator: {error}")
+        self._flags = flags
+        self._loop = asyncio.get_event_loop()
+        self._results_dict = {}
 
+    async def async_request_iterator(self, prompts, sampling_parameters):
+        try:
+            for iter in range(self._flags.iterations):
+                for i, prompt in enumerate(prompts):
+                    prompt_id = self._flags.offset + (len(prompts) * iter) + i
+                    self._results_dict[str(prompt_id)] = []
+                    yield self.create_request(
+                        prompt,
+                        self._flags.streaming_mode,
+                        prompt_id,
+                        sampling_parameters,
+                    )
+        except Exception as error:
+            print(f"Caught an error in the request iterator: {error}")
+
+    async def stream_infer(self, prompts, sampling_parameters):
         try:
             # Start streaming
-            response_iterator = triton_client.stream_infer(
-                inputs_iterator=async_request_iterator(),
-                stream_timeout=FLAGS.stream_timeout,
+            response_iterator = self._client.stream_infer(
+                inputs_iterator=self.async_request_iterator(
+                    prompts, sampling_parameters
+                ),
+                stream_timeout=self._flags.stream_timeout,
             )
-            # Read response from the stream
             async for response in response_iterator:
-                result, error = response
-                if error:
-                    print(f"Encountered error while processing: {error}")
-                else:
-                    output = result.as_numpy("text_output")
-                    for i in output:
-                        results_dict[result.get_response().id].append(i)
-
+                yield response
         except InferenceServerException as error:
            print(error)
            sys.exit(1)
 
-        with open(FLAGS.results_file, "w") as file:
-            for id in results_dict.keys():
-                for result in results_dict[id]:
-                    file.write(result.decode("utf-8"))
-                    file.write("\n")
-                file.write("\n=========\n\n")
-            print(f"Storing results into `{FLAGS.results_file}`...")
+    async def process_stream(self, prompts, sampling_parameters):
+        # Clear results in between process_stream calls
+        self._results_dict = {}
+
+        # Read response from the stream
+        async for response in self.stream_infer(prompts, sampling_parameters):
+            result, error = response
+            if error:
+                print(f"Encountered error while processing: {error}")
+            else:
+                output = result.as_numpy("TEXT")
+                for i in output:
+                    self._results_dict[result.get_response().id].append(i)
+
+    async def run(self):
+        sampling_parameters = {"temperature": "0.1", "top_p": "0.95"}
+        stream = self._flags.streaming_mode
+        with open(self._flags.input_prompts, "r") as file:
+            print(f"Loading inputs from `{self._flags.input_prompts}`...")
+            prompts = file.readlines()
+
+        await self.process_stream(prompts, sampling_parameters)
+
+        with open(self._flags.results_file, "w") as file:
+            for id in self._results_dict.keys():
+                for result in self._results_dict[id]:
+                    file.write(result.decode("utf-8"))
+                    file.write("\n")
+                file.write("\n=========\n\n")
+            print(f"Storing results into `{self._flags.results_file}`...")
+
+        if self._flags.verbose:
+            with open(self._flags.results_file, "r") as file:
+                print(f"\nContents of `{self._flags.results_file}` ===>")
+                print(file.read())
+
+        print("PASS: vLLM example")
+
+    def run_async(self):
+        self._loop.run_until_complete(self.run())
+
+    def create_request(
+        self,
+        prompt,
+        stream,
+        request_id,
+        sampling_parameters,
+        send_parameters_as_tensor=True,
+    ):
+        inputs = []
+        prompt_data = np.array([prompt.encode("utf-8")], dtype=np.object_)
+        try:
+            inputs.append(grpcclient.InferInput("PROMPT", [1], "BYTES"))
+            inputs[-1].set_data_from_numpy(prompt_data)
+        except Exception as error:
+            print(f"Encountered an error during request creation: {error}")
+
+        stream_data = np.array([stream], dtype=bool)
+        inputs.append(grpcclient.InferInput("STREAM", [1], "BOOL"))
+        inputs[-1].set_data_from_numpy(stream_data)
+
+        # Request parameters are not yet supported via BLS. Provide an
+        # optional mechanism to send serialized parameters as an input
+        # tensor until support is added
+
+        if send_parameters_as_tensor:
+            sampling_parameters_data = np.array(
+                [json.dumps(sampling_parameters).encode("utf-8")], dtype=np.object_
+            )
+            inputs.append(grpcclient.InferInput("SAMPLING_PARAMETERS", [1], "BYTES"))
+            inputs[-1].set_data_from_numpy(sampling_parameters_data)
 
-    if FLAGS.verbose:
-        print(f"\nContents of `{FLAGS.results_file}` ===>")
-        system(f"cat {FLAGS.results_file}")
+        # Add requested outputs
+        outputs = []
+        outputs.append(grpcclient.InferRequestedOutput("TEXT"))
 
-    print("PASS: vLLM example")
+        # Issue the asynchronous sequence inference.
+        return {
+            "model_name": self._flags.model,
+            "inputs": inputs,
+            "outputs": outputs,
+            "request_id": str(request_id),
+            "parameters": sampling_parameters,
+        }
 
 
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "-m",
+        "--model",
+        type=str,
+        required=False,
+        default="vllm",
+        help="Model name",
+    )
     parser.add_argument(
         "-v",
         "--verbose",
@@ -159,7 +187,7 @@ async def async_request_iterator():
         type=str,
         required=False,
         default="localhost:8001",
-        help="Inference server URL and it gRPC port. Default is localhost:8001.",
+        help="Inference server URL and its gRPC port. Default is localhost:8001.",
     )
     parser.add_argument(
         "-t",
@@ -206,4 +234,6 @@ async def async_request_iterator():
         help="Enable streaming mode",
     )
     FLAGS = parser.parse_args()
-    asyncio.run(main(FLAGS))
+
+    client = LLMClient(FLAGS)
+    client.run_async()
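
Below is a minimal, hypothetical smoke test for the refactored client; it is not part of the patch above. It assumes client.py is importable from the samples/ directory, that a Triton server with a vLLM model named "vllm" is listening on localhost:8001, and that a prompts.txt file exists with one prompt per line. The flag values are illustrative stand-ins for the sample's parsed command-line arguments, not its actual defaults.

import argparse

from client import LLMClient

# Namespace mirroring the attributes LLMClient reads from its parsed flags.
# The concrete values below are assumptions for illustration only.
flags = argparse.Namespace(
    model="vllm",
    url="localhost:8001",
    verbose=False,
    stream_timeout=None,
    offset=0,
    input_prompts="prompts.txt",
    results_file="results.txt",
    iterations=1,
    streaming_mode=False,
)

# run_async() drives the event loop until every prompt has been sent and the
# decoded responses have been written to results_file.
LLMClient(flags).run_async()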