Skip to content

Commit

Permalink
Use two queues for two threads
Browse files (browse the repository at this point in the history)
  • Loading branch information
AllentDan committed Feb 7, 2024
1 parent 29696fa commit ae42806
Showing 1 changed file with 15 additions and 7 deletions.
22 changes: 15 additions & 7 deletions lmdeploy/turbomind/turbomind.py
Original file line number Diff line number Diff line change
Expand Up @@ -487,6 +487,7 @@ def __init__(self, tm_model: TurboMind, cuda_stream_id: int = 0):
self.model_insts = model_insts
self.que = Queue()
self.aque = asyncio.Queue()
self.afinal = asyncio.Queue()
self.threads = [None] * self.gpu_count

def _create_model_instance(self, device_id, model_insts):
Expand Down Expand Up @@ -529,7 +530,7 @@ def _func(device_id, enque_output):
output = self.model_insts[device_id].forward(
inputs, instance_comm)
if enque_output:
self.aque.put_nowait((True, output))
self.afinal.put_nowait((True, output))

for device_id in range(self.gpu_count):
t = Thread(target=_func,
Expand Down Expand Up @@ -732,6 +733,7 @@ async def async_stream_infer(self,
"""
# start forward thread
self.aque = asyncio.Queue()
self.afinal = asyncio.Queue()
if stream_output and not stop:
self.model_insts[0].register_callback(self._async_forward_callback)

Expand All @@ -752,11 +754,7 @@ async def async_stream_infer(self,

seq_start = input_lengths + input_lengths.new_tensor(step)

finish = False
# generator
while not finish:
finish, tm_outputs = await self.aque.get()

def _yield_outputs(finish, tm_outputs):
outputs = _tm_dict_to_torch_dict(tm_outputs)

output_ids = outputs['output_ids'][:, 0, :]
Expand All @@ -780,7 +778,17 @@ async def async_stream_infer(self,
outputs = (status, output[:-1].tolist(), len_)
else:
outputs = (status, output.tolist(), len_)
yield outputs
return outputs

finish = False
# generator
while not finish:
while self.aque.qsize() > 1:
self.aque.get_nowait()
finish, tm_outputs = await self.aque.get()
while self.afinal.qsize() > 0:
finish, tm_outputs = await self.afinal.get()
yield _yield_outputs(finish, tm_outputs)

for t in self.threads:
t.join()
Expand Down

0 comments on commit ae42806

Please sign in to comment.