
Refactor async engine & turbomind IO #2968

Merged · 44 commits · Jan 10, 2025

Commits
7b8a841  refactor  (lzhangzz, Nov 26, 2024)
160ba9e  Merge remote-tracking branch 'origin/main' into refactor-1  (lzhangzz, Nov 29, 2024)
382f92b  async interface  (lzhangzz, Dec 5, 2024)
9c56be8  Merge remote-tracking branch 'origin/main' into refactor-3  (lzhangzz, Dec 5, 2024)
2cf49bd  update perf metrics & adaptive tokens per tick  (lzhangzz, Dec 11, 2024)
aa5573d  wait-free  (lzhangzz, Dec 12, 2024)
6378aaa  refactor gateway  (lzhangzz, Dec 19, 2024)
9812538  optimize throughput  (lzhangzz, Dec 21, 2024)
8baa784  add cancel cb  (lzhangzz, Dec 21, 2024)
1bc68d1  simplify async engine  (lzhangzz, Dec 21, 2024)
f220762  simplify async engine  (lzhangzz, Dec 22, 2024)
31c6223  fix end session  (lzhangzz, Dec 22, 2024)
b3d15b1  faster synchronization  (lzhangzz, Dec 23, 2024)
c6fd260  fix async engine  (lzhangzz, Dec 23, 2024)
8fa85dc  refactor async engine  (lzhangzz, Dec 24, 2024)
3f07733  fix semaphore  (lzhangzz, Dec 25, 2024)
2382d7e  refactor inference API  (lzhangzz, Dec 27, 2024)
747252c  remove turbomind sync interface  (lzhangzz, Dec 27, 2024)
54df9f1  Merge remote-tracking branch 'origin/main' into refactor-3  (lzhangzz, Dec 27, 2024)
5266f27  fix msvc build  (lzhangzz, Jan 1, 2025)
33ad2be  fix msvc build  (lzhangzz, Jan 1, 2025)
1c20608  fix msvc build  (lzhangzz, Jan 1, 2025)
6d1d209  Merge remote-tracking branch 'origin/main' into refactor-3  (lzhangzz, Jan 6, 2025)
43020b5  add extra outputs  (lzhangzz, Jan 6, 2025)
8412518  skip stop tokens  (lzhangzz, Jan 7, 2025)
3409742  exit gracefully  (lzhangzz, Jan 7, 2025)
21a7553  cancel all tasks atexit  (lzhangzz, Jan 7, 2025)
49701df  refactor profiler  (lzhangzz, Jan 7, 2025)
f4b37af  fix id2step for api server  (lzhangzz, Jan 7, 2025)
2644fb7  save csv  (lzhangzz, Jan 8, 2025)
6029a2e  fix interactive  (lzhangzz, Jan 8, 2025)
50fdb68  fix lint  (lzhangzz, Jan 8, 2025)
e2ed1a2  fix generate_token_len  (lzhangzz, Jan 8, 2025)
21432bf  fix async_end  (lzhangzz, Jan 8, 2025)
ad0e07c  update pipeline ut  (lzhangzz, Jan 8, 2025)
4186da5  fix ignore eos  (lzhangzz, Jan 8, 2025)
bee78b6  minor  (lzhangzz, Jan 8, 2025)
5f02cad  refactor profile pipeline api  (lzhangzz, Jan 8, 2025)
1965327  fix stop ids  (lzhangzz, Jan 9, 2025)
7b513cb  fix duplication  (lzhangzz, Jan 9, 2025)
2e3a17d  control output range of logits & last hidden states  (lzhangzz, Jan 10, 2025)
80108df  fix lint & typo  (lzhangzz, Jan 10, 2025)
6c2f901  fix blank response  (lzhangzz, Jan 10, 2025)
31b01f1  export batch & num prompts  (lzhangzz, Jan 10, 2025)

2 changes: 1 addition & 1 deletion autotest/interface/pipeline/test_pipeline_func.py
@@ -408,7 +408,7 @@ def run_pipeline_testcase(config, model, backend, file_name):
result = True
for i in range(2):
result &= response[i].finish_reason == 'length'
- result &= response[i].session_id == i
+ result &= response[i].index == i
save_pipeline_common_log(config, file_name, result, response)
del pipe
torch.cuda.empty_cache()
4 changes: 2 additions & 2 deletions autotest/utils/pipeline_chat.py
@@ -235,7 +235,7 @@ def assert_pipeline_single_stream_return(output, logprobs_num: int = 0):

def assert_pipeline_batch_stream_return(output, size: int = 1):
for i in range(size):
- output_list = [item for item in output if item.session_id == i]
+ output_list = [item for item in output if item.index == i]
result, msg = assert_pipeline_single_stream_return(output_list)
if not result:
return result, msg
@@ -249,7 +249,7 @@ def assert_pipeline_single_element(output,
result = True
result &= output.generate_token_len > 0
result &= output.input_token_len > 0
- result &= output.session_id >= 0
+ result &= output.index >= 0
if is_last:
result &= len(output.text) >= 0
result &= output.finish_reason in ['stop', 'length']
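Both test files above track the same user-visible change: batched pipeline outputs are now matched to their prompts by `index` (the position of the prompt in the input batch) rather than by `session_id`. Below is a minimal sketch of consuming streamed output under the new field, assuming the pipeline API as it appears in this PR; the model path, prompts, and token budget are illustrative placeholders.

```python
from lmdeploy import GenerationConfig, pipeline

# Placeholder model path; any model supported by lmdeploy's pipeline would do.
pipe = pipeline('internlm/internlm2_5-7b-chat')

prompts = ['Hi, please introduce yourself', 'Shanghai is']
gen_configs = [GenerationConfig(max_new_tokens=64) for _ in prompts]

# Each streamed chunk carries `index`, the position of its prompt in the batch
# (previously exposed as `session_id`), plus `generate_token_len` and
# `finish_reason`, as asserted in the updated tests above.
progress = {i: (0, None) for i in range(len(prompts))}
for output in pipe.stream_infer(prompts, gen_configs):
    progress[output.index] = (output.generate_token_len, output.finish_reason)

for i, (n_token, finish_reason) in progress.items():
    print(f'prompt {i}: {n_token} tokens, finish_reason={finish_reason}')
```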
116 changes: 54 additions & 62 deletions benchmark/profile_pipeline_api.py
@@ -1,11 +1,8 @@
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
- import csv
import json
import os
import random
- import time
- from collections import OrderedDict
from typing import List, Tuple

from tqdm import tqdm
Expand All @@ -14,6 +11,10 @@
from lmdeploy import (GenerationConfig, PytorchEngineConfig,
TurbomindEngineConfig, pipeline)
from lmdeploy.cli.utils import ArgumentHelper, DefaultsAndTypesHelpFormatter
+ from lmdeploy.profiler import Profiler, Session
+ from lmdeploy.utils import get_logger

+ logger = get_logger('lmdeploy')


def sample_requests(dataset_path: str, num_requests: int,
@@ -66,91 +67,70 @@ def __init__(self, model_path: str, engine_config, csv: str):

self.csv = csv

- def process_request(self, requests, concurrency, temperature, top_p, top_k,
-                     stream_output):
+ def process_request(self, requests, profiler: Profiler, temperature, top_p,
+                     top_k, stream_output):

- stats = OrderedDict(
-     (session_id, None) for session_id in range(len(requests)))
prompts = [prompt for prompt, _, _ in requests]
gen_configs = [
GenerationConfig(temperature=temperature,
top_p=top_p,
top_k=top_k,
ignore_eos=True,
+ do_sample=True,
[Collaborator review comment on the `do_sample=True` line] why do_sample True?

max_new_tokens=output_len)
for _, _, output_len in requests
]

- start = time.perf_counter()
+ sess: List[Session] = []
+ for _, input_len, output_len in requests:
+     sess.append(profiler.new_session(input_len, output_len))
+
+ def _to_status(finish_reason):
+     if finish_reason == 'length':
+         return Session.SUCCESS
+     else:
+         return Session.FAIL
+
+ profiler.start()
+
+ for s in sess:
+     s.tick(0)

if stream_output:
pbar = tqdm(total=len(requests))
for output in self.pipe.stream_infer(prompts,
gen_configs,
do_preprocess=False):
- session_id = output.session_id
+ index = output.index
n_token = output.generate_token_len
finish_reason = output.finish_reason
- stats[session_id] = (n_token, finish_reason)
+ sess[index].tick(n_token)
+ if finish_reason is not None:
+     sess[index].finish(_to_status(finish_reason))
pbar.update(1)
pbar.close()
else:
for output in self.pipe(prompts,
gen_configs,
do_preprocess=False,
use_tqdm=True):
- session_id = output.session_id
+ index = output.index
n_token = output.generate_token_len
finish_reason = output.finish_reason
- stats[session_id] = (n_token, finish_reason)
-
- elapsed_time = time.perf_counter() - start
-
- completion_tokens = 0
- for session_id, (n_token, finish_reason) in stats.items():
-     assert finish_reason == 'length', \
-         f'unexpected finish_reason of session_id={session_id}, ' \
-         f'prompt={requests[session_id][0]}'
-     assert n_token - 1 <= requests[session_id][-1] <= n_token, \
-         f'request to generate {requests[session_id][-1]} tokens, ' \
-         f'but got {n_token} tokens'
-     completion_tokens += n_token
-
- prompt_tokens = 0
- for _, input_len, _ in requests:
-     prompt_tokens += input_len
-
- completion_token_throughput = completion_tokens / elapsed_time
- total_token_throughput = (prompt_tokens +
-                           completion_tokens) / elapsed_time
- rps = len(requests) / elapsed_time
- rpm = rps * 60
-
- print(f'\n{"-" * 50}\nconcurrency: {concurrency}\n'
-       f'elapsed_time: {elapsed_time:.3f}s\n')
-
- print(
-     f'number of prompts: {len(requests)}\n'
-     f'number of prompt tokens: {prompt_tokens:.0f}\n'
-     f'number of completion tokens: {completion_tokens:.0f}\n'
-     f'token throughput (completion token): {completion_token_throughput:.3f} token/s\n' # noqa
-     f'token throughput (prompt + completion token): {total_token_throughput:.3f} token/s\n' # noqa
-     f'RPS (request per second): {rps:.3f} req/s\n'
-     f'RPM (request per minute): {rpm:.3f} req/min\n'
-     f'{"-" * 50}\n')
-
- if self.csv:
-     with open(self.csv, 'w') as csvfile:
-         writer = csv.writer(csvfile)
-         writer.writerow([
-             'batch', 'num_promts', 'RPS', 'RPM',
-             'throughput(out tok/s)', 'throughput(total tok/s)'
-         ])
-         writer.writerow([
-             concurrency,
-             len(requests), f'{rps:.3f}', f'{rpm:.3f}',
-             f'{completion_token_throughput:.3f}',
-             f'{total_token_throughput:.3f}'
-         ])
+ sess[index].tick(n_token)
+ sess[index].finish(_to_status(finish_reason))

+ profiler.finish()
+
+ # report first failure
+ for i, s in enumerate(sess):
+     if s.status != Session.SUCCESS or s.ns[-1] < s.req_output_len:
+         logger.error(
+             f'Request {i} failed with {s.ns[-1]}/{s.req_output_len} tokens generated'  # noqa: E501
+         )
+         logger.error(f'Prompt: {prompts[i]}')
+         logger.warning('Got failed requests, metrics may be invalid')
+         break


def parse_args():
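The hand-rolled stats dict, timing, and CSV writing in `process_request` are replaced above by the profiler introduced in this PR: each request gets a `Session` from `profiler.new_session(input_len, output_len)`, progress is recorded with `tick(cumulative_token_count)`, and completion with `finish(Session.SUCCESS)` or `finish(Session.FAIL)`. Below is a condensed sketch of that lifecycle using only calls visible in this diff; the request sizes and the simulated token loop are placeholders for real streaming output.

```python
from lmdeploy.profiler import Profiler, Session

# (input_len, requested output_len) per request; illustrative values.
requests = [(128, 32), (64, 32)]

# Arguments mirror the call in main() below: stream-output flag and latency percentiles.
profiler = Profiler(True, [50, 75, 95, 99])
sessions = [profiler.new_session(in_len, out_len) for in_len, out_len in requests]

profiler.start()
for s in sessions:
    s.tick(0)  # mark the request as started, 0 tokens generated so far

# In profile_pipeline_api.py these ticks come from pipe.stream_infer() chunks:
# tick(output.generate_token_len) per chunk, finish(...) once finish_reason is set.
for s, (_, out_len) in zip(sessions, requests):
    for n_token in range(1, out_len + 1):
        s.tick(n_token)            # cumulative generated-token count
    s.finish(Session.SUCCESS)      # finish_reason == 'length' maps to SUCCESS

profiler.finish()
```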
@@ -252,13 +232,25 @@ def main():
requests = sample_requests(args.dataset, args.num_prompts,
engine.tokenizer)

+ profiler = Profiler(args.stream_output, [50, 75, 95, 99])

engine.process_request(requests,
+ profiler,
temperature=args.temperature,
top_p=args.top_p,
top_k=args.top_k,
- concurrency=args.concurrency,
stream_output=args.stream_output)

+ hyperparams = [('Concurrency', args.concurrency),
+                ('Stream output', str(args.stream_output).lower())]
+
+ profiler.compute_metrics()
+ profiler.summarize(title='Profile Pipeline API', hyperparams=hyperparams)
+
+ if args.csv:
+     profiler.save_csv(args.csv, (('batch', args.concurrency),
+                                  ('num_prompts', args.num_prompts)))


if __name__ == '__main__':
main()
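On the reporting side, `main()` now delegates aggregation and output to the profiler: `compute_metrics()` after the run, `summarize(title=..., hyperparams=...)` for the console report, and `save_csv(path, extra_columns)` for the CSV that was previously written by hand inside `process_request`. Below is a small end-to-end sketch restricted to the calls visible in this diff; the synthetic sessions and the `profile.csv` path are placeholders, not real benchmark data.

```python
from lmdeploy.profiler import Profiler, Session

concurrency, num_prompts = 2, 2
profiler = Profiler(True, [50, 75, 95, 99])  # stream-output flag, percentiles

# Synthetic stand-in for a real benchmark run (see the lifecycle sketch above).
sessions = [profiler.new_session(128, 16) for _ in range(num_prompts)]
profiler.start()
for s in sessions:
    s.tick(0)
    for n_token in range(1, 17):
        s.tick(n_token)
    s.finish(Session.SUCCESS)
profiler.finish()

# Aggregate per-session ticks, print a summary, and write one CSV row with
# the extra identifying columns, mirroring main() above.
profiler.compute_metrics()
profiler.summarize(title='Profile Pipeline API',
                   hyperparams=[('Concurrency', concurrency),
                                ('Stream output', 'true')])
profiler.save_csv('profile.csv', (('batch', concurrency),
                                  ('num_prompts', num_prompts)))
```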