refactor inference API
lzhangzz committed Dec 27, 2024
1 parent 3f07733 commit 2382d7e
Showing 4 changed files with 311 additions and 182 deletions.
22 changes: 11 additions & 11 deletions benchmark/profile_pipeline_api.py
@@ -69,14 +69,14 @@ def __init__(self, model_path: str, engine_config, csv: str):
     def process_request(self, requests, concurrency, temperature, top_p, top_k,
                         stream_output):
 
-        stats = OrderedDict(
-            (session_id, None) for session_id in range(len(requests)))
+        stats = OrderedDict((index, None) for index in range(len(requests)))
         prompts = [prompt for prompt, _, _ in requests]
         gen_configs = [
             GenerationConfig(temperature=temperature,
                              top_p=top_p,
                              top_k=top_k,
                              ignore_eos=True,
+                             do_sample=True,
                              max_new_tokens=output_len)
             for _, _, output_len in requests
         ]
@@ -87,31 +87,31 @@ def process_request(self, requests, concurrency, temperature, top_p, top_k,
             for output in self.pipe.stream_infer(prompts,
                                                  gen_configs,
                                                  do_preprocess=False):
-                session_id = output.session_id
+                index = output.index
                 n_token = output.generate_token_len
                 finish_reason = output.finish_reason
-                stats[session_id] = (n_token, finish_reason)
+                stats[index] = (n_token, finish_reason)
                 if finish_reason is not None:
                     pbar.update(1)
         else:
             for output in self.pipe(prompts,
                                     gen_configs,
                                     do_preprocess=False,
                                     use_tqdm=True):
-                session_id = output.session_id
+                index = output.index
                 n_token = output.generate_token_len
                 finish_reason = output.finish_reason
-                stats[session_id] = (n_token, finish_reason)
+                stats[index] = (n_token, finish_reason)
 
         elapsed_time = time.perf_counter() - start
 
         completion_tokens = 0
-        for session_id, (n_token, finish_reason) in stats.items():
+        for index, (n_token, finish_reason) in stats.items():
             assert finish_reason == 'length', \
-                f'unexpected finish_reason of session_id={session_id}, ' \
-                f'prompt={requests[session_id][0]}'
-            assert n_token - 1 <= requests[session_id][-1] <= n_token, \
-                f'request to generate {requests[session_id][-1]} tokens, ' \
+                f'unexpected finish_reason of index={index}, ' \
+                f'prompt={requests[index][0]}'
+            assert n_token - 1 <= requests[index][-1] <= n_token, \
+                f'request to generate {requests[index][-1]} tokens, ' \
                 f'but got {n_token} tokens'
             completion_tokens += n_token
 
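The user-visible effect of this refactor is that pipeline outputs are keyed by the prompt's position in the batch (`output.index`) rather than by a `session_id`. Below is a minimal sketch of consuming the refactored streaming API, assuming the lmdeploy `pipeline` entry point; the model path is a placeholder, and the output field names follow the diff above.

    # Minimal sketch: collect per-prompt stats keyed by batch index.
    # Assumes lmdeploy's `pipeline` entry point; the model path is a placeholder.
    from lmdeploy import GenerationConfig, pipeline

    pipe = pipeline('some/model-path')  # hypothetical model path
    prompts = ['Hello, world', 'Tell me a joke']
    gen_config = GenerationConfig(max_new_tokens=32, do_sample=True)

    stats = {}
    for output in pipe.stream_infer(prompts, gen_config):
        # `output.index` identifies the prompt within this batch,
        # replacing the former `output.session_id`.
        stats[output.index] = (output.generate_token_len, output.finish_reason)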
1 change: 0 additions & 1 deletion lmdeploy/messages.py
@@ -335,7 +335,6 @@ class Response:
     text: str
     generate_token_len: int
     input_token_len: int
-    session_id: int
     finish_reason: Optional[Literal['stop', 'length']] = None
     token_ids: List[int] = field(default_factory=list)
     logprobs: List[Dict[int, float]] = None
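In `lmdeploy/messages.py`, the per-request `session_id` field is dropped from the `Response` dataclass. A sketch of the resulting class is below; the `index: int = 0` field is an assumption, since the hunk shown only removes `session_id`, but the benchmark above reads `output.index`, so the refactor presumably introduces it elsewhere.

    from dataclasses import dataclass, field
    from typing import Dict, List, Literal, Optional

    @dataclass
    class Response:
        text: str
        generate_token_len: int
        input_token_len: int
        finish_reason: Optional[Literal['stop', 'length']] = None
        token_ids: List[int] = field(default_factory=list)
        logprobs: List[Dict[int, float]] = None
        index: int = 0  # assumed: replaces the removed session_id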
