From bc0cb3a1ba2095d94c163121a99398a982e7a19b Mon Sep 17 00:00:00 2001 From: Qing Lan Date: Fri, 8 Sep 2023 17:56:01 -0700 Subject: [PATCH] [Handler] fix device mapping issues (#1065) --- engines/python/setup/djl_python/streaming_utils.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/engines/python/setup/djl_python/streaming_utils.py b/engines/python/setup/djl_python/streaming_utils.py index e6f5ca3ec..b7e80aaaa 100644 --- a/engines/python/setup/djl_python/streaming_utils.py +++ b/engines/python/setup/djl_python/streaming_utils.py @@ -77,6 +77,8 @@ def use_hf_default_streamer(model, tokenizer, inputs, device, **kwargs): input_tokens = tokenizer(inputs, padding=True, return_tensors="pt") if device is not None: input_tokens = input_tokens.to(device) + else: + input_tokens = input_tokens.to(model.device) streamer = HFStreamer(tokenizer, skip_special_tokens=True) generation_kwargs = dict(input_tokens, streamer=streamer, **kwargs) @@ -121,8 +123,9 @@ def _hf_model_stream_generator(model, tokenizer, inputs, device, **kwargs): StreamingUtils.DEFAULT_MAX_NEW_TOKENS) tokenized_inputs = tokenizer(inputs, return_tensors="pt", padding=True) input_ids = tokenized_inputs["input_ids"] - if device is not None: - input_ids = input_ids.to(device) + if device is None: + device = model.device + input_ids = input_ids.to(device) past_key_values = None decoding_method = StreamingUtils._get_decoding_method(**kwargs) @@ -159,8 +162,7 @@ def _hf_model_stream_generator(model, tokenizer, inputs, device, **kwargs): else: raise ValueError(f"Unsupported model class: {generic_model_class}") - if device is not None: - attention_mask = attention_mask.to(device) + attention_mask = attention_mask.to(device) while True: if stop_generation: