remove turbomind sync interface
lzhangzz committed Dec 27, 2024
1 parent 2382d7e commit 747252c
Showing 2 changed files with 3 additions and 168 deletions.
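For orientation: the "sync interface" being removed is the blocking call style on a turbomind generator instance (end, stream_infer and the use_async switch in the CLI); after this commit call sites drive the coroutine-based methods through an event loop. Below is a minimal, self-contained sketch of that call-site change using a stub class rather than lmdeploy's API; only async_end matches a method name visible in the diff below, everything else is illustrative.

import asyncio


class FakeGenerator:
    """Stub standing in for a turbomind generator instance."""

    def end(self, session_id: int):
        # Old blocking style (the kind of entry point this commit removes).
        print(f'session {session_id} ended (sync)')

    async def async_end(self, session_id: int):
        # Coroutine-based replacement; the caller must drive an event loop.
        print(f'session {session_id} ended (async)')


generator = FakeGenerator()
loop = asyncio.new_event_loop()
# Before: generator.end(session_id)
# After:
loop.run_until_complete(generator.async_end(1))
loop.close()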
33 changes: 3 additions & 30 deletions lmdeploy/turbomind/chat.py
@@ -29,22 +29,6 @@ def input_prompt(model_name):
return '\n'.join(iter(input, sentinel))


-def infer(generator, session_id, input_ids, gen_config, sequence_start, step,
-          stream_output, tokenizer, state):
-    for outputs in generator.stream_infer(session_id=session_id,
-                                          input_ids=input_ids,
-                                          gen_config=gen_config,
-                                          sequence_start=sequence_start,
-                                          sequence_end=False,
-                                          step=step,
-                                          stream_output=stream_output):
-        res, tokens = input_ids + outputs.token_ids, outputs.num_token
-        # decode res
-        response, state = tokenizer.detokenize_incrementally(res, state=state)
-        print(response, end='', flush=True)
-    return tokens
-
-
async def async_infer(generator, session_id, input_ids, gen_config,
sequence_start, step, stream_output, tokenizer, state):
token_ids = input_ids.copy()
@@ -64,8 +48,6 @@ async def async_infer(generator, session_id, input_ids, gen_config,
state=state)
prev_len = tokens
print(response, end='', flush=True)
-        # if 'I' in response:
-        #     await generator.async_cancel()
return tokens


@@ -88,7 +70,6 @@ def main(model_path: str,
stream_output: bool = True,
request_output_len: int = 1024,
chat_template_config: ChatTemplateConfig = None,
-         use_async: bool = True,
**kwargs):
"""An example to perform model inference through the command line
interface.
@@ -183,10 +164,7 @@ def main(model_path: str,
if prompt == 'exit':
exit(0)
elif prompt == 'end':
-            if use_async:
-                loop.run_until_complete(generator.async_end(session_id))
-            else:
-                generator.end(session_id)
+            loop.run_until_complete(generator.async_end(session_id))
nth_round = 1
step = 0
seed = random.getrandbits(64)
@@ -210,15 +188,10 @@ def main(model_path: str,
print(f'{prompt}', end='', flush=True)
state = DetokenizeState(len(input_ids))

-            if use_async:
-                coro = async_infer(generator, session_id, input_ids,
-                                   gen_config, sequence_start, step,
-                                   stream_output, tokenizer, state)
-                tokens = loop.run_until_complete(coro)
-            else:
-                tokens = infer(generator, session_id, input_ids, gen_config,
+            coro = async_infer(generator, session_id, input_ids, gen_config,
                                sequence_start, step, stream_output, tokenizer,
                                state)
+            tokens = loop.run_until_complete(coro)

# update step
step += len(input_ids) + tokens
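What remains in chat.py after this commit is a single pattern: the synchronous REPL keeps its plain while loop but drives the coroutine-based generator through loop.run_until_complete instead of calling a blocking helper. The following is a self-contained sketch of that pattern; fake_stream, the token counts and the prompt lengths are made-up stand-ins, not lmdeploy APIs.

import asyncio


async def fake_stream(prompt_len: int):
    """Stand-in for the engine's async token stream."""
    for i in range(3):
        await asyncio.sleep(0)  # pretend we waited on the engine
        yield list(range(prompt_len + i + 1))  # growing list of token ids


async def async_infer(prompt_len: int) -> int:
    """Consume the stream incrementally, like chat.py's async_infer."""
    tokens = 0
    async for token_ids in fake_stream(prompt_len):
        tokens = len(token_ids) - prompt_len
        print(f'generated {tokens} tokens so far', flush=True)
    return tokens


# Synchronous caller, as in main(): one event loop, run_until_complete per turn.
loop = asyncio.new_event_loop()
for prompt_len in (4, 7):
    total = loop.run_until_complete(async_infer(prompt_len))
    print(f'turn finished with {total} new tokens')
loop.close()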
138 changes: 0 additions & 138 deletions lmdeploy/turbomind/turbomind.py
@@ -4,7 +4,6 @@
import json
import os.path as osp
import sys
import threading
from collections.abc import Sequence
from concurrent.futures import ThreadPoolExecutor
from dataclasses import asdict
@@ -645,143 +644,6 @@ def _get_generation_config(self, cfg: GenerationConfig):
# print (c)
return c

def signal_cb(self):
with self.cond:
self.flag = 1
self.cond.notify()

def end_cb(self, status: int):
print(f'session ended, status = {status}')
self.end_event.set()

def end(self):
self.done_event.wait()
self.end_event = threading.Event()
self.model_inst.end(self.end_cb)
self.end_event.wait()

def cancel(self, session_id: int, blocking: bool = True):
self.model_inst.cancel()
if blocking:
self.done_event.wait()

def stream_infer(self,
session_id,
input_ids,
input_embeddings=None,
input_embedding_ranges=None,
sequence_start: bool = True,
sequence_end: bool = False,
step=0,
stop=False,
gen_config: GenerationConfig = None,
stream_output=False,
**kwargs):
"""Perform model inference.
Args:
session_id (int): the id of a session
input_ids (numpy.ndarray): the token ids of a prompt
input_embeddings (List[numpy.ndarray]): embeddings features
input_embedding_ranges (List[Tuple[int,int]]): the begin/end
offsets of input_embeddings to input_ids
sequence_start (bool): indicator for starting a sequence
sequence_end (bool): indicator for ending a sequence
step (int): the offset of the k/v cache
stop (bool): indicator for cancelling the session
gen_config (GenerationConfig): generation config
stream_output (bool): indicator for stream output
kwargs (dict): kwargs for backward compatibility
"""

gen_cfg = self._get_generation_config(gen_config)

inputs, input_length = self.prepare_inputs(
input_ids=input_ids,
input_embeddings=input_embeddings,
input_embedding_ranges=input_embedding_ranges,
gen_config=gen_config)

inputs = _np_dict_to_tm_dict(inputs)

session = _tm.SessionParam(id=session_id,
step=step,
start=sequence_start,
end=sequence_end,
stop=stop)

self.cond = threading.Condition()
self.flag = 0
self.done_event = threading.Event()

outputs, shared_state = self.model_inst.forward(
inputs, session, gen_cfg, stream_output, self.signal_cb)

outputs = _tm_dict_to_torch_dict(outputs)

output_ids_buf = outputs['output_ids']

out_logprobs = None
finish = False
state = None

output_ids = []
output_len = 0
prev_len = step + input_length[0]

try:
# generator
while True:
with self.cond:
while not self.flag:
self.cond.wait()
self.flag = 0

state = shared_state.consume()
status, seq_len = state.status, state.seq_len

if status in [7, 8]: # TODO: use enum
finish = True
status = 0
elif status:
yield self._get_error_output()
break

if seq_len == prev_len and not finish:
continue

output_ids += output_ids_buf[prev_len:seq_len].tolist()
output_len += seq_len - prev_len

status = ResponseType.FINISH if finish else ResponseType.SUCCESS # noqa
output = EngineOutput(status, output_ids, output_len.item(),
out_logprobs)

prev_len = seq_len

if out_logprobs:
output_token_len = len(output.token_ids)
output.logprobs = out_logprobs[:output_token_len]

yield output

if finish:
break

except Exception as e:
logger.error(e)
yield self._get_error_output()

finally:
with self.cond:
# Contract: `cb` won't be called again if status is non-zero
# wait for status to be set as `finish` or `error`
while not state or state.status == 0:
while not self.flag:
self.cond.wait()
state = shared_state.consume()
self.cond = None

def decode(self,
input_ids,
steps: List[int] = None,
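Most of the deletion above is the synchronous stream_infer, which parked a Python generator on a threading.Condition, let the engine's callback (signal_cb) set a flag and notify it, then consumed the shared state and yielded incremental output until a finish status arrived. Below is a self-contained sketch of that condition-variable/callback shape, with a Python thread standing in for the C++ engine; every name here is illustrative, only the pattern mirrors the deleted code.

import threading
import time

cond = threading.Condition()
flag = False
shared_state = {'seq_len': 0, 'finished': False}


def signal_cb():
    """Producer-side callback: set the flag and wake the consumer."""
    global flag
    with cond:
        flag = True
        cond.notify()


def worker():
    """Stand-in for the engine thread advancing the sequence."""
    for step in range(1, 4):
        time.sleep(0.01)
        shared_state['seq_len'] = step
        shared_state['finished'] = step == 3
        signal_cb()


def stream():
    """Blocking generator: wait on the condition, then yield new output."""
    global flag
    prev = 0
    while True:
        with cond:
            while not flag:
                cond.wait()
            flag = False
        seq_len = shared_state['seq_len']
        finished = shared_state['finished']
        if seq_len > prev:
            yield list(range(prev, seq_len))
            prev = seq_len
        if finished:
            break


threading.Thread(target=worker, daemon=True).start()
for new_tokens in stream():
    print('new tokens:', new_tokens)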
