Update client to use iterable client class
dyastremsky committed Oct 11, 2023
1 parent 0cd3d91 commit 45a531f
Showing 1 changed file with 124 additions and 94 deletions.
218 changes: 124 additions & 94 deletions samples/client.py
@@ -29,6 +29,7 @@
import argparse
import asyncio
import json
import queue

[Code scanning / CodeQL notice: Import of 'queue' is not used.]
import sys
from os import system

[Code scanning / CodeQL notice: Import of 'system' is not used.]

@@ -37,114 +38,141 @@
from tritonclient.utils import *


def create_request(
prompt,
stream,
request_id,
sampling_parameters,
model_name,
send_parameters_as_tensor=True,
):
inputs = []
prompt_data = np.array([prompt.encode("utf-8")], dtype=np.object_)
try:
inputs.append(grpcclient.InferInput("text_input", [1], "BYTES"))
inputs[-1].set_data_from_numpy(prompt_data)
except Exception as e:
print(f"Encountered an error {e}")

stream_data = np.array([stream], dtype=bool)
inputs.append(grpcclient.InferInput("stream", [1], "BOOL"))
inputs[-1].set_data_from_numpy(stream_data)

# Request parameters are not yet supported via BLS. Provide an
# optional mechanism to send serialized parameters as an input
# tensor until support is added

if send_parameters_as_tensor:
sampling_parameters_data = np.array(
[json.dumps(sampling_parameters).encode("utf-8")], dtype=np.object_
class LLMClient:
def __init__(self, flags: argparse.Namespace):
self._client = grpcclient.InferenceServerClient(
url=flags.url, verbose=flags.verbose
)
inputs.append(grpcclient.InferInput("sampling_parameters", [1], "BYTES"))
inputs[-1].set_data_from_numpy(sampling_parameters_data)

# Add requested outputs
outputs = []
outputs.append(grpcclient.InferRequestedOutput("text_output"))

# Issue the asynchronous sequence inference.
return {
"model_name": model_name,
"inputs": inputs,
"outputs": outputs,
"request_id": str(request_id),
"parameters": sampling_parameters,
}


async def main(FLAGS):
model_name = "vllm_opt"
sampling_parameters = {"temperature": "0.1", "top_p": "0.95"}
stream = FLAGS.streaming_mode
with open(FLAGS.input_prompts, "r") as file:
print(f"Loading inputs from `{FLAGS.input_prompts}`...")
prompts = file.readlines()

results_dict = {}

async with grpcclient.InferenceServerClient(
url=FLAGS.url, verbose=FLAGS.verbose
) as triton_client:
# Request iterator that yields the next request
async def async_request_iterator():
try:
for iter in range(FLAGS.iterations):
for i, prompt in enumerate(prompts):
prompt_id = FLAGS.offset + (len(prompts) * iter) + i
results_dict[str(prompt_id)] = []
yield create_request(
prompt, stream, prompt_id, sampling_parameters, model_name
)
except Exception as error:
print(f"caught error in request iterator: {error}")
self._flags = flags
self._loop = asyncio.get_event_loop()
self._results_dict = {}

async def async_request_iterator(self, prompts, sampling_parameters):
try:
for iter in range(self._flags.iterations):
for i, prompt in enumerate(prompts):
prompt_id = self._flags.offset + (len(prompts) * iter) + i
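# Illustration (not in the original diff): with offset 0 and two prompts,
# iteration 0 yields ids 0 and 1, iteration 1 yields ids 2 and 3.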
self._results_dict[str(prompt_id)] = []
yield self.create_request(
prompt,
self._flags.streaming_mode,
prompt_id,
sampling_parameters,
)
except Exception as error:
print(f"Caught an error in the request iterator: {error}")

async def stream_infer(self, prompts, sampling_parameters):
try:
# Start streaming
response_iterator = triton_client.stream_infer(
inputs_iterator=async_request_iterator(),
stream_timeout=FLAGS.stream_timeout,
response_iterator = self._client.stream_infer(
inputs_iterator=self.async_request_iterator(
prompts, sampling_parameters
),
stream_timeout=self._flags.stream_timeout,
)
# Read response from the stream
async for response in response_iterator:
result, error = response
if error:
print(f"Encountered error while processing: {error}")
else:
output = result.as_numpy("text_output")
for i in output:
results_dict[result.get_response().id].append(i)

yield response
except InferenceServerException as error:
print(error)
sys.exit(1)

with open(FLAGS.results_file, "w") as file:
for id in results_dict.keys():
for result in results_dict[id]:
file.write(result.decode("utf-8"))
file.write("\n")
file.write("\n=========\n\n")
print(f"Storing results into `{FLAGS.results_file}`...")
async def process_stream(self, prompts, sampling_parameters):
# Clear results in between process_stream calls
self._results_dict = {}

# Read response from the stream
async for response in self.stream_infer(prompts, sampling_parameters):
result, error = response
if error:
print(f"Encountered error while processing: {error}")
else:
output = result.as_numpy("TEXT")
for i in output:
self._results_dict[result.get_response().id].append(i)

async def run(self):
sampling_parameters = {"temperature": "0.1", "top_p": "0.95"}
stream = self._flags.streaming_mode

[Code scanning / CodeQL notice: Variable stream is not used.]
with open(self._flags.input_prompts, "r") as file:
print(f"Loading inputs from `{self._flags.input_prompts}`...")
prompts = file.readlines()

await self.process_stream(prompts, sampling_parameters)

with open(self._flags.results_file, "w") as file:
for id in self._results_dict.keys():
for result in self._results_dict[id]:
file.write(result.decode("utf-8"))
file.write("\n")
file.write("\n=========\n\n")
print(f"Storing results into `{self._flags.results_file}`...")

if self._flags.verbose:
with open(self._flags.results_file, "r") as file:
print(f"\nContents of `{self._flags.results_file}` ===>")
print(file.read())

print("PASS: vLLM example")

def run_async(self):
self._loop.run_until_complete(self.run())

def create_request(
self,
prompt,
stream,
request_id,
sampling_parameters,
send_parameters_as_tensor=True,
):
inputs = []
prompt_data = np.array([prompt.encode("utf-8")], dtype=np.object_)
try:
inputs.append(grpcclient.InferInput("PROMPT", [1], "BYTES"))
inputs[-1].set_data_from_numpy(prompt_data)
except Exception as error:
print(f"Encountered an error during request creation: {error}")

stream_data = np.array([stream], dtype=bool)
inputs.append(grpcclient.InferInput("STREAM", [1], "BOOL"))
inputs[-1].set_data_from_numpy(stream_data)

# Request parameters are not yet supported via BLS. Provide an
# optional mechanism to send serialized parameters as an input
# tensor until support is added

if send_parameters_as_tensor:
sampling_parameters_data = np.array(
[json.dumps(sampling_parameters).encode("utf-8")], dtype=np.object_
)
inputs.append(grpcclient.InferInput("SAMPLING_PARAMETERS", [1], "BYTES"))
inputs[-1].set_data_from_numpy(sampling_parameters_data)
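# For example, the sampling parameters built in run() above serialize to the
# JSON string '{"temperature": "0.1", "top_p": "0.95"}', sent as a
# one-element BYTES tensor.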

if FLAGS.verbose:
print(f"\nContents of `{FLAGS.results_file}` ===>")
system(f"cat {FLAGS.results_file}")
# Add requested outputs
outputs = []
outputs.append(grpcclient.InferRequestedOutput("TEXT"))

print("PASS: vLLM example")
# Issue the asynchronous sequence inference.
return {
"model_name": self._flags.model,
"inputs": inputs,
"outputs": outputs,
"request_id": str(request_id),
"parameters": sampling_parameters,
}


if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"-m",
"--model",
type=str,
required=False,
default="vllm",
help="Model name",
)
parser.add_argument(
"-v",
"--verbose",
@@ -159,7 +187,7 @@ async def async_request_iterator():
type=str,
required=False,
default="localhost:8001",
help="Inference server URL and it gRPC port. Default is localhost:8001.",
help="Inference server URL and its gRPC port. Default is localhost:8001.",
)
parser.add_argument(
"-t",
@@ -206,4 +234,6 @@ async def async_request_iterator():
help="Enable streaming mode",
)
FLAGS = parser.parse_args()
asyncio.run(main(FLAGS))

client = LLMClient(FLAGS)
client.run_async()
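
For reference, a minimal sketch of driving the new LLMClient class programmatically instead of through the command-line entry point above. The Namespace attributes mirror the flag names the class body reads; the ones defined in the collapsed portion of the parser (offset, iterations, input_prompts, results_file, stream_timeout, streaming_mode) are assumed to keep those names, and the file paths are placeholders.

import argparse

# Hypothetical flag values mirroring the attributes the class reads
# (model, url, verbose, stream_timeout, offset, iterations,
# streaming_mode, input_prompts, results_file).
flags = argparse.Namespace(
    model="vllm",
    url="localhost:8001",
    verbose=False,
    stream_timeout=None,
    offset=0,
    iterations=1,
    streaming_mode=False,
    input_prompts="prompts.txt",   # one prompt per line
    results_file="results.txt",    # generated text is written here
)

client = LLMClient(flags)
client.run_async()  # drives the asyncio event loop until all prompts finish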
