Skip to content

Commit

Permalink
fix benchmark serving computation mistake
Browse files Browse the repository at this point in the history
  • Loading branch information
AllentDan committed Oct 31, 2023
1 parent 56942c4 commit 4c1fde7
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 6 deletions.
1 change: 1 addition & 0 deletions benchmark/profile_restful_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,7 @@ def read_dataset(tokenizer_path: str, dataset_path: str, samples: int,
print(f'elapsed time for read data: '
f'{round(time.perf_counter() - start, 2)} s')

print('start tokenization. This takes a while, please wait...')
start = time.perf_counter()
tokenizer = Tokenizer(tokenizer_path)
prompts_token_lens = [len(tokenizer.encode(prompt)) for prompt in prompts]
Expand Down
13 changes: 7 additions & 6 deletions benchmark/profile_serving.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@ def read_dataset(tokenizer_path: str, dataset_path: str, samples: int,
completions = [completion for _, completion in dataset]
print(f'elapsed time for read data: '
f'{round(time.perf_counter() - start, 2)} s')
print('start tokenization. This takes a while, please wait...')

start = time.perf_counter()
tokenizer = Tokenizer(tokenizer_path)
Expand Down Expand Up @@ -124,7 +125,6 @@ def main(tritonserver_addr: str,
res_que = mp.Queue()

procs = []
_start = time.perf_counter()
for i in range(concurrency):
chatbot = Chatbot(tritonserver_addr=tritonserver_addr,
display=False,
Expand All @@ -134,13 +134,18 @@ def main(tritonserver_addr: str,
proc = mp.Process(target=infer,
args=(chatbot, i + 1, req_que, res_que))
procs.append(proc)
proc.start()

# read data and put it to queue
n_req = read_dataset(tokenizer_path, dataset_path, samples, session_len,
req_que)
for i in range(concurrency):
req_que.put([None, None, None])
_start = time.perf_counter()
for proc in procs:
proc.start()
for proc in procs:
proc.join()
_end = time.perf_counter()

stats = []
for i in range(concurrency):
Expand All @@ -150,7 +155,6 @@ def main(tritonserver_addr: str,
f'stats: \n{_stats}\n{"-" * 50}\n')
stats.append(np.array(_stats))

_end = time.perf_counter()
elapsed_time = _end - _start

stats = np.concatenate(stats).reshape(-1, 3)
Expand All @@ -170,9 +174,6 @@ def main(tritonserver_addr: str,
f'req throughput: {req_throughput:.3f} req/s\n'
f'{"-" * 50}\n')

for proc in procs:
proc.join()


if __name__ == '__main__':
fire.Fire(main)

0 comments on commit 4c1fde7

Please sign in to comment.