decode by interval
AllentDan committed Nov 7, 2023
1 parent 120e078 commit 334426e
Showing 1 changed file with 17 additions and 13 deletions.
lmdeploy/serve/turbomind/chatbot.py
@@ -70,15 +70,17 @@ class Chatbot:
         profile_generation (bool): profile token generation or not
     """
 
-    def __init__(self,
-                 tritonserver_addr: str,
-                 model_name: str = '',
-                 ignore_eos: bool = False,
-                 log_level: int = logging.INFO,
-                 display: bool = False,
-                 profile_generation: bool = False,
-                 profile_serving: bool = False,
-                 **model_kwargs):
+    def __init__(
+            self,
+            tritonserver_addr: str,
+            model_name: str = '',
+            ignore_eos: bool = False,
+            log_level: int = logging.INFO,
+            display: bool = False,
+            profile_generation: bool = False,
+            profile_serving: bool = False,
+            decode_interval: int = 5,  # an empirical value
+            **model_kwargs):
         self.tritonserver_addr = tritonserver_addr
         self.model_name = model_name
         if self.model_name == '':
@@ -109,7 +111,7 @@ def __init__(self,
         self.display = display
         self.profile_generation = profile_generation
         self.profile_serving = profile_serving
-        self.interval = 50  # an empirical value
+        self.decode_interval = decode_interval
 
     def stream_infer(self,
                      session_id: int,
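
For orientation, a minimal usage sketch of the new constructor argument (not part of the commit; the server address, model name, and the exact stream_infer call shape below are illustrative assumptions):

from lmdeploy.serve.turbomind.chatbot import Chatbot

# Hypothetical example: detokenize the partial output every 8 generated tokens
# instead of the default 5. The address and model name are placeholders.
chatbot = Chatbot('0.0.0.0:33337',
                  model_name='internlm-chat-7b',
                  display=True,
                  decode_interval=8)

# Assumed call shape: stream_infer yields (status, text, n_token) tuples,
# matching the consumer loop shown further down in this diff.
for status, text, n_token in chatbot.stream_infer(1, 'Hello'):
    print(status, n_token, text)
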
@@ -504,7 +506,7 @@ def _stream_infer(self,
         for status, res, n_token in self.stream_consumer(
                 self.postprocess, que, session, input_tokens, preseq_length,
                 cancel, logger, self.display, self.profile_generation,
-                self.eos_id):
+                self.eos_id, self.decode_interval):
             yield status, res, n_token
 
         producer.join()
@@ -598,7 +600,7 @@ def _stream_producer(tritonserver_addr, session, que, cfg, input_ids,
     @staticmethod
     def stream_consumer(postprocess, res_queue, session, n_input_token,
                         preseq_length, cancel, logger, display,
-                        profile_generation, eos_id):
+                        profile_generation, eos_id, decode_interval):
         """Consume the response from the triton inference server.
 
         Args:
@@ -613,6 +615,7 @@ def stream_consumer(postprocess, res_queue, session, n_input_token,
             display (bool): display the text in the console interface or not
             profile_generation (bool): indicator for profiling token generation
             eos_id (int): eos token id
+            decode_interval (int): number of new tokens between decoding steps
 
         Yields:
             tuple: status, text, generated token number
@@ -666,7 +669,8 @@ def stream_consumer(postprocess, res_queue, session, n_input_token,
                        'postprocessing is ignored during profiling '
                        'token generation', output_ids.shape[-1])
                 continue
-            if not break_flag and output_ids.shape[-1] - n_token < 50:
+            if not break_flag and output_ids.shape[
+                    -1] - n_token < decode_interval:
                 continue
             output_str = postprocess(
                 output_ids, np.array([[n_token]], dtype=np.uint32))
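
Taken together, the change makes the consumer decode in intervals: postprocessing is skipped until at least decode_interval new tokens have arrived since the last decode. A standalone sketch of that pattern (assumed helper names, not the repository's API):

import numpy as np


def decode_by_interval(token_stream, postprocess, decode_interval=5):
    # Sketch only: `token_stream` yields growing (1, n) uint32 token buffers and
    # `postprocess` maps (ids, offsets) -> text, mirroring the call in the diff.
    n_token = 0
    output_ids = np.empty((1, 0), dtype=np.uint32)
    for output_ids in token_stream:
        # Too few new tokens since the last decode: wait for more.
        if output_ids.shape[-1] - n_token < decode_interval:
            continue
        yield postprocess(output_ids, np.array([[n_token]], dtype=np.uint32))
        n_token = output_ids.shape[-1]
    # Flush whatever remains once the stream ends.
    if output_ids.shape[-1] > n_token:
        yield postprocess(output_ids, np.array([[n_token]], dtype=np.uint32))

The smaller default (5, versus the previous hard-coded 50) presumably trades a few extra postprocessing calls for smoother streamed output.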
