Skip to content

Commit

Permalink
[Handler] fix device mapping issues (#1065)
Browse files Browse the repository at this point in the history
  • Loading branch information
Qing Lan authored Sep 9, 2023
1 parent 9263842 commit bc0cb3a
Showing 1 changed file with 6 additions and 4 deletions.
10 changes: 6 additions & 4 deletions engines/python/setup/djl_python/streaming_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,8 @@ def use_hf_default_streamer(model, tokenizer, inputs, device, **kwargs):
input_tokens = tokenizer(inputs, padding=True, return_tensors="pt")
if device is not None:
input_tokens = input_tokens.to(device)
else:
input_tokens = input_tokens.to(model.device)

streamer = HFStreamer(tokenizer, skip_special_tokens=True)
generation_kwargs = dict(input_tokens, streamer=streamer, **kwargs)
Expand Down Expand Up @@ -121,8 +123,9 @@ def _hf_model_stream_generator(model, tokenizer, inputs, device, **kwargs):
StreamingUtils.DEFAULT_MAX_NEW_TOKENS)
tokenized_inputs = tokenizer(inputs, return_tensors="pt", padding=True)
input_ids = tokenized_inputs["input_ids"]
if device is not None:
input_ids = input_ids.to(device)
if device is None:
device = model.device
input_ids = input_ids.to(device)

past_key_values = None
decoding_method = StreamingUtils._get_decoding_method(**kwargs)
Expand Down Expand Up @@ -159,8 +162,7 @@ def _hf_model_stream_generator(model, tokenizer, inputs, device, **kwargs):
else:
raise ValueError(f"Unsupported model class: {generic_model_class}")

if device is not None:
attention_mask = attention_mask.to(device)
attention_mask = attention_mask.to(device)

while True:
if stop_generation:
Expand Down

0 comments on commit bc0cb3a

Please sign in to comment.