Fix lmi-dist batch handling (deepjavalibrary#887)
xyang16 authored and KexinFeng committed Aug 16, 2023
1 parent b2948e3 commit 7aa877c
Showing 2 changed files with 13 additions and 11 deletions.
10 changes: 7 additions & 3 deletions engines/python/setup/djl_python/huggingface.py
@@ -189,12 +189,13 @@ def inference(self, inputs):
         first = True
         for item in batch:
             input_map = decode(item, content_type)
-            input_size.append(len(input_map.get("inputs")))
             _inputs = input_map.pop("inputs", input_map)
             if isinstance(_inputs, list):
                 input_data.extend(_inputs)
+                input_size.append(len(_inputs))
             else:
                 input_data.append(_inputs)
+                input_size.append(1)
             if first or self.rolling_batch_type:
                 parameters.append(input_map.pop("parameters", {}))
                 first = False
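
Note on the hunk above: the old code recorded len(input_map.get("inputs")) before checking whether "inputs" was a single prompt or a list, so a plain string request was counted by its character length. The patch records the size after the type check: the length of the list, or 1 for a single prompt. Below is a minimal standalone sketch of that bookkeeping in plain Python; the helper name flatten_batch and the dict shape are illustrative, not the DJL handler itself.

# Sketch (illustrative names): each request carries one prompt or a list of
# prompts; prompts are flattened into one batch and input_size records how
# many belong to each request.
from typing import Any, Dict, List, Tuple, Union

def flatten_batch(batch: List[Dict[str, Any]]) -> Tuple[List[str], List[int]]:
    input_data: List[str] = []
    input_size: List[int] = []
    for item in batch:
        _inputs: Union[str, List[str]] = item.get("inputs", "")
        if isinstance(_inputs, list):
            input_data.extend(_inputs)
            input_size.append(len(_inputs))  # this request contributed len(_inputs) prompts
        else:
            input_data.append(_inputs)
            input_size.append(1)             # a single prompt counts as 1, not len(string)
    return input_data, input_size

# Example: two requests, one with two prompts and one with a single prompt.
data, sizes = flatten_batch([{"inputs": ["a", "b"]}, {"inputs": "c"}])
assert data == ["a", "b", "c"] and sizes == [2, 1]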
@@ -208,9 +209,12 @@ def inference(self, inputs):
 
         if self.rolling_batch_type:
             result = self.rolling_batch.inference(input_data, parameters)
-            for i in range(len(batch)):
+            for i in range(inputs.get_batch_size()):
                 res = result[i]
-                outputs.add_as_json(res, batch_index=i)
+                encode(outputs,
+                       res,
+                       accept,
+                       key=inputs.get_content().key_at(i))
 
             return outputs
         elif self.enable_streaming:
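
Note on the hunk above: with a rolling batch the handler receives one result per client request, so the loop now runs over inputs.get_batch_size() and each result is encoded under that client's output key (honoring the requested accept type) instead of being appended with add_as_json. A standalone sketch of the routing idea follows, with plain dicts and JSON standing in for the DJL Input/Output/encode API.

# Sketch (plain Python stand-ins): result i from the rolling batch is
# serialized and stored under client i's key, so each request in the
# dynamic batch receives exactly one response.
import json
from typing import Dict, List

def route_responses(results: List[dict], client_keys: List[str]) -> Dict[str, str]:
    if len(results) != len(client_keys):
        raise ValueError("expected one result per client in the batch")
    return {key: json.dumps(res) for key, res in zip(client_keys, results)}

# Example: a dynamic batch of two clients.
out = route_responses([{"generated_text": "hello"}, {"generated_text": "world"}],
                      ["client-0", "client-1"])
assert out["client-1"] == '{"generated_text": "world"}'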
@@ -95,29 +95,27 @@ def _prefill_and_decode(self, new_batch):
 
             if self.cache:
                 decode_generations, decode_next_batch = self.model.generate_token(self.cache)
-                self.cache = decode_next_batch
                 generations.extend(decode_generations)
 
                 # concatenate with the existing batch of the model
-                self.cache = self.model.batch_type.concatenate([prefill_next_batch, self.cache])
-
-
+                if decode_next_batch:
+                    self.cache = self.model.batch_type.concatenate([prefill_next_batch, decode_next_batch])
+                else:
+                    self.cache = prefill_next_batch
             else:
                 self.cache = prefill_next_batch
         else:
             generations, next_batch = self.model.generate_token(self.cache)
             self.cache = next_batch
 
-        generation_dict = {}
-        for generation in generations:
-            generation_dict[generation.request_id] = generation
+        generation_dict = {generation.request_id: generation for generation in generations}
 
         req_ids = []
         for r in self.pending_requests:
             generation = generation_dict[r.id]
             is_last_token = generation.generated_text is not None
             if not is_last_token:
-                req_ids.append((r.id))
+                req_ids.append(r.id)
             r.set_next_token(generation.token_text, last_token=is_last_token)
 
         # filter the requests that are stopped.
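
Note on the hunk above (the second changed file; going by the commit title and the _prefill_and_decode method, this is the lmi-dist rolling batch handler): generate_token on the cached batch returns no next batch once every cached request has finished, and the old code then assigned that empty value to the cache and tried to concatenate it with the prefill batch. The patch only concatenates when decode_next_batch is truthy and otherwise keeps just the prefill batch. Below is a minimal sketch of that guard, with plain lists standing in for the model's batch type.

# Sketch (illustrative, lists stand in for the lmi-dist batch type): only
# concatenate the freshly prefilled batch with the decoded batch when the
# latter still holds live requests.
from typing import Callable, List, Optional

def merge_cache(prefill_next_batch: List[str],
                decode_next_batch: Optional[List[str]],
                concatenate: Callable[[List[List[str]]], List[str]]) -> List[str]:
    if decode_next_batch:
        # both batches still hold unfinished requests: keep them together
        return concatenate([prefill_next_batch, decode_next_batch])
    # every cached request finished; the prefill batch becomes the new cache
    return prefill_next_batch

# Example with a stand-in for batch_type.concatenate.
concat = lambda batches: [r for b in batches for r in b]
assert merge_cache(["p1"], ["d1"], concat) == ["p1", "d1"]
assert merge_cache(["p1"], None, concat) == ["p1"]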
