diff --git a/python/llm/src/ipex_llm/transformers/npu_generate.py b/python/llm/src/ipex_llm/transformers/npu_generate.py
new file mode 100644
index 00000000000..480175f78fd
--- /dev/null
+++ b/python/llm/src/ipex_llm/transformers/npu_generate.py
@@ -0,0 +1,237 @@
+#
+# Copyright 2016 The BigDL Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# Some parts of this file are adapted from
+# https://github.com/huggingface/transformers/blob/main/src/transformers/generation
+# /utils.py
+#
+
+from typing import Any, Callable, Dict, List, Optional, Tuple
+import os
+import torch
+import time
+import numpy as np
+import random
+import logging
+from transformers import GenerationConfig, LogitsProcessorList, StoppingCriteriaList
+from ipex_llm.transformers.speculative import greedy, deepmind_sample, logits_to_probs,\
+    _crop_past_key_values, _prepare_generate_args
+from ipex_llm.utils.common import invalidInputError
+from transformers.modeling_outputs import CausalLMOutputWithPast
+
+logger = logging.getLogger("ipex_llm.npu")
+
+# patch GenerationMixin.generate
+from transformers import GenerationMixin
+original_generate = GenerationMixin.generate
+query_group_size = 16
+
+
+@torch.no_grad()
+def generate(
+    self,
+    inputs: Optional[torch.Tensor] = None,
+    generation_config: Optional[GenerationConfig] = None,
+    logits_processor: Optional[LogitsProcessorList] = None,
+    stopping_criteria: Optional[StoppingCriteriaList] = None,
+    prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]]=None,
+    synced_gpus: Optional[bool] = None,
+    assistant_model: Optional["PreTrainedModel"] = None,
+    streamer: Optional["BaseStreamer"] = None,
+    **kwargs,
+):
+    if kwargs.get("num_beams", None) not in [None, 1]:
+        return original_generate(self,
+                                 inputs=inputs,
+                                 generation_config=generation_config,
+                                 logits_processor=logits_processor,
+                                 stopping_criteria=stopping_criteria,
+                                 prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
+                                 synced_gpus=synced_gpus,
+                                 assistant_model=assistant_model,
+                                 streamer=streamer,
+                                 **kwargs)
+    else:
+        return self.npu_generate(inputs=inputs,
+                                 generation_config=generation_config,
+                                 streamer=streamer,
+                                 logits_processor=logits_processor,
+                                 stopping_criteria=stopping_criteria,
+                                 prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
+                                 **kwargs)
+
+
+GenerationMixin.generate = generate
+
+
+def clear_benchmarks(self):
+    self.first_token_time = None
+    self.last_token_time = []
+    self.encoder_time = 0
+
+
+@torch.no_grad()
+def npu_generate(self,
+                 inputs: Optional[torch.Tensor] = None,
+                 generation_config: Optional[GenerationConfig] = None,
+                 logits_processor: Optional[LogitsProcessorList] = None,
+                 stopping_criteria: Optional[StoppingCriteriaList] = None,
+                 streamer: Optional["BaseStreamer"] = None,
+                 **sampling_kwargs):
+    input_ids, generation_config, logits_processor, stopping_criteria, \
+        model_kwargs = _prepare_generate_args(self, inputs, generation_config,
+                                              logits_processor=logits_processor,
+                                              stopping_criteria=stopping_criteria,
+                                              streamer=streamer,
+                                              generate_attention_mask=True,
+                                              **sampling_kwargs)
+
+    step = 0
+    max_new_tokens = generation_config.max_new_tokens
+    attn_mask = model_kwargs.get("attention_mask", None)
+
+    if self.config.model_type == "qwen2":
+        from transformers.generation.logits_process import MinNewTokensLengthLogitsProcessor
+        decoder_logits_processor = LogitsProcessorList([item for item in logits_processor
+                                                        if not isinstance(item, MinNewTokensLengthLogitsProcessor)])
+    else:
+        decoder_logits_processor = LogitsProcessorList()
+
+    clear_benchmarks(self)
+
+    input_len = input_ids.shape[1]
+
+    eos_token_id_set = None
+    if generation_config.eos_token_id is not None:
+        if isinstance(generation_config.eos_token_id, list):
+            eos_token_id_set = set(generation_config.eos_token_id)
+        else:
+            eos_token_id_set = set([generation_config.eos_token_id])
+
+    while True:
+        if step >= max_new_tokens:
+            break
+
+        tic = time.perf_counter()
+
+        if step == 0:
+            model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
+            output = self(**model_inputs,
+                          return_dict=True)
+            # output = CausalLMOutputWithPast(
+            #     loss=None,
+            #     logits=torch.empty([1, 1, self.config.vocab_size]),
+            #     past_key_values=None,
+            #     hidden_states=None,
+            #     attentions=None,
+            # )
+            logits = output['logits'][:, -1:]
+            logits[:, -1, :] = logits_processor(input_ids, logits[:, -1, :])
+            if generation_config.do_sample:
+                output_ids, prob_list = deepmind_sample(logits,
+                                                        top_k=generation_config.top_k,
+                                                        top_p=generation_config.top_p,
+                                                        temperature=generation_config.temperature)
+            else:
+                output_ids = torch.argmax(logits, dim=-1)
+            input_ids = torch.cat((input_ids, output_ids), dim=-1)
+        else:
+            model_inputs = {
+                "input_ids": input_ids[:, -1:],
+                "past_key_values": past_key_values,
+                # "position_ids": model_kwargs["position_ids"],
+                "use_cache": True,
+                "attention_mask": attn_mask,
+            }
+
+            output = self(**model_inputs,
+                          return_dict=True)
+            # output = CausalLMOutputWithPast(
+            #     loss=None,
+            #     logits=torch.empty([1, 1, self.config.vocab_size]),
+            #     past_key_values=None,
+            #     hidden_states=None,
+            #     attentions=None,
+            # )
+            t1 = time.perf_counter()
+
+            logits = output['logits']
+
+            logits[:, -1, :] = decoder_logits_processor(input_ids,
+                                                        logits[:, -1, :])
+            t2 = time.perf_counter()
+
+            if generation_config.do_sample:
+                output_ids, prob_list = deepmind_sample(logits,
+                                                        top_k=generation_config.top_k,
+                                                        top_p=generation_config.top_p,
+                                                        temperature=generation_config.temperature)
+                output_ids = output_ids.transpose(0, 1)
+            else:
+                output_ids = torch.argmax(logits, dim=-1)
+
+            t3 = time.perf_counter()
+
+            input_ids = torch.cat((input_ids, output_ids), dim=-1)
+            t4 = time.perf_counter()
+
+        step += 1
+
+        past_key_values = output.past_key_values
+        if attn_mask is not None:
+            attn_mask = torch.cat(
+                [attn_mask, attn_mask.new_ones((attn_mask.shape[0], 1))], dim=-1
+            )
+
+        toc = time.perf_counter()
+
+        if self.first_token_time is None:
+            self.first_token_time = toc - tic
+        else:
+            self.last_token_time.append(toc - tic)
+            print(f"Prepare input & dummy output: {(t1 - tic)*1000} ms, update attn mask: {(toc - t4)*1000} ms, "
+                  f"argmax: {(t3 - t2)*1000} ms, cat input id: {(t4 - t3) * 1000} ms, logits processor: {(t2 - t1)*1000} ms,"
+                  f" total generate: {(toc - t1)*1000} ms")
+
+        # Stop on eos and remove content after eos
+        if eos_token_id_set is not None:
+            output_ids_list = output_ids[0].tolist()
+            first_eos_idx = -1
+            for out_idx, out_id in enumerate(output_ids_list):
+                if out_id in eos_token_id_set:
+                    first_eos_idx = out_idx
+                    break
+            if first_eos_idx > -1:
+                if streamer is not None:
+                    streamer.put(output_ids[:(first_eos_idx + 1)].cpu())
+                step -= (len(output_ids_list) - first_eos_idx - 1)
+                break
+        if streamer is not None:
+            streamer.put(output_ids.cpu())
+
+    step = min(step, max_new_tokens)
+    self.n_token_generated = step
+
+    print(f"=========First token cost {self.first_token_time:.4f} s=========")
+    if len(self.last_token_time) > 1:
+        self.first_cost = self.first_token_time
+        self.rest_cost_mean = np.mean(self.last_token_time)
+        print(f"=========Rest tokens cost average {self.rest_cost_mean:.4f} s "
+              f"({len(self.last_token_time)} tokens in all)=========")
+
+    if streamer is not None:
+        streamer.end()
+
+    return input_ids[:, : input_len + step]
diff --git a/python/llm/src/ipex_llm/transformers/npu_model.py b/python/llm/src/ipex_llm/transformers/npu_model.py
index 56ca664cac9..519fd71b6a3 100644
--- a/python/llm/src/ipex_llm/transformers/npu_model.py
+++ b/python/llm/src/ipex_llm/transformers/npu_model.py
@@ -179,6 +179,9 @@ def from_pretrained(cls, *args, **kwargs):
                 transpose_value_cache=transpose_value_cache,
             )
             model.save_low_bit = types.MethodType(save_low_bit, model)
+            if model.config.model_type in ["qwen2", "llama"]:
+                from ipex_llm.transformers.npu_generate import npu_generate
+                model.npu_generate = types.MethodType(npu_generate, model)
         else:
             from ipex_llm.transformers.npu_models.convert import optimize_llm
             optimize_llm(model)
diff --git a/python/llm/src/ipex_llm/transformers/speculative.py b/python/llm/src/ipex_llm/transformers/speculative.py
index 4600e99fefc..5e59b735ff2 100644
--- a/python/llm/src/ipex_llm/transformers/speculative.py
+++ b/python/llm/src/ipex_llm/transformers/speculative.py
@@ -526,7 +526,9 @@ def _crop_past_key_values(self, past_key_values, new_cache_size, _enable_ipex=Fa
     return past_key_values
 
 
-def _prepare_generate_args(self, inputs, generation_config, streamer=None, **sampling_kwargs):
+def _prepare_generate_args(self, inputs, generation_config, streamer=None, logits_processor=None,
+                           stopping_criteria=None, generate_attention_mask=False,
+                           **sampling_kwargs):
     if generation_config is None:
         generation_config = self.generation_config
 
@@ -551,8 +553,9 @@ def _prepare_generate_args(self, inputs, generation_config, streamer=None, **sam
         generation_config.pad_token_id = eos_token_id
 
     # 2. Set generation parameters if not already defined
-    logits_processor = LogitsProcessorList()
-    stopping_criteria = StoppingCriteriaList()
+    logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
+    stopping_criteria = stopping_criteria if stopping_criteria is not None \
+        else StoppingCriteriaList()
 
     # 3. Define model inputs
     # inputs_tensor has to be defined
@@ -574,15 +577,17 @@ def _prepare_generate_args(self, inputs, generation_config, streamer=None, **sam
     # else:
     #     model_kwargs["use_cache"] = generation_config.use_cache
 
-    # accepts_attention_mask = "attention_mask" in set(
-    #     inspect.signature(self.forward).parameters.keys())
-    # requires_attention_mask = "encoder_outputs" not in model_kwargs
+    if generate_attention_mask:
+        import inspect
+        accepts_attention_mask = "attention_mask" in set(
+            inspect.signature(self.forward).parameters.keys())
+        requires_attention_mask = "encoder_outputs" not in model_kwargs
 
-    # if model_kwargs.get("attention_mask", None) is None and \
-    #         requires_attention_mask and accepts_attention_mask:
-    #     model_kwargs["attention_mask"] = self._prepare_attention_mask_for_generation(
-    #         inputs_tensor, generation_config.pad_token_id, generation_config.eos_token_id
-    #     )
+        if model_kwargs.get("attention_mask", None) is None and \
+                requires_attention_mask and accepts_attention_mask:
+            model_kwargs["attention_mask"] = self._prepare_attention_mask_for_generation(
+                inputs_tensor, generation_config.pad_token_id, generation_config.eos_token_id
+            )
 
     # decoder-only models should use left-padding for generation
     if not self.config.is_encoder_decoder:
@@ -605,6 +610,8 @@ def _prepare_generate_args(self, inputs, generation_config, streamer=None, **sam
     # 5. Prepare `input_ids` which will be used for auto-regressive generation
     input_ids = inputs_tensor if model_input_name == "input_ids" else model_kwargs.pop("input_ids")
 
+    # Skip max length
+
     if streamer is not None:
         streamer.put(input_ids.cpu())
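
Usage note (not part of the patch): once this change is applied, the patched GenerationMixin.generate dispatches any call with num_beams unset or equal to 1 to npu_generate, so an NPU-loaded qwen2/llama model can be driven with the ordinary transformers generate() API. The sketch below is a minimal illustration under assumptions: the checkpoint path, the load_in_low_bit value, and the generation settings are placeholders, not taken from this diff.

# Hypothetical usage sketch for the patched generate() path (assumed checkpoint and kwargs).
import torch
from transformers import AutoTokenizer
from ipex_llm.transformers.npu_model import AutoModelForCausalLM

model_path = "meta-llama/Llama-2-7b-chat-hf"  # assumed checkpoint; replace with a local path
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)

# load_in_low_bit="sym_int4" is illustrative; per the npu_model.py hunk, from_pretrained
# attaches npu_generate as a bound method for qwen2/llama models.
model = AutoModelForCausalLM.from_pretrained(model_path,
                                             load_in_low_bit="sym_int4",
                                             trust_remote_code=True)

input_ids = tokenizer("What is AI?", return_tensors="pt").input_ids
with torch.inference_mode():
    # num_beams defaults to 1, so this call is routed through the new NPU decode loop
    output = model.generate(input_ids, max_new_tokens=32, do_sample=False)
print(tokenizer.decode(output[0], skip_special_tokens=True))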