diff --git a/optimum/exporters/ipex/model_config.py b/optimum/exporters/ipex/model_config.py new file mode 100755 index 000000000..5cd7f1e79 --- /dev/null +++ b/optimum/exporters/ipex/model_config.py @@ -0,0 +1,98 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Optional, Tuple + +from optimum.exporters.onnx.model_configs import ( + FalconOnnxConfig, + GPT2OnnxConfig, + LlamaOnnxConfig, +) +from optimum.utils import DEFAULT_DUMMY_SHAPES +from optimum.utils.input_generators import DummyPastKeyValuesGenerator, DummyTextInputGenerator +from optimum.utils.normalized_config import NormalizedTextConfig + + +DEFAULT_DUMMY_SHAPES["batch_size"] = 1 + + +class IPEXDummyPastKeyValuesGenerator(DummyPastKeyValuesGenerator): + def __init__( + self, + task: str, + normalized_config: NormalizedTextConfig, + batch_size: int = DEFAULT_DUMMY_SHAPES["batch_size"], + sequence_length: int = DEFAULT_DUMMY_SHAPES["sequence_length"], + random_batch_size_range: Optional[Tuple[int, int]] = None, + random_sequence_length_range: Optional[Tuple[int, int]] = None, + **kwargs, + ): + super().__init__( + task=task, + normalized_config=normalized_config, + batch_size=batch_size, + sequence_length=sequence_length, + random_batch_size_range=random_batch_size_range, + random_sequence_length_range=random_sequence_length_range, + ) + self.num_key_value_heads = getattr(normalized_config, "num_key_value_heads", 1) + self.max_position_embeddings = normalized_config.max_position_embeddings + + def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int64", float_dtype: str = "fp32"): + shape_init = (1, self.sequence_length, self.sequence_length, 1) + shape_beam_idx_tmp = (self.max_position_embeddings, self.batch_size) + shape_kv = ( + self.max_position_embeddings, + self.batch_size, + self.num_key_value_heads, + self.hidden_size // self.num_attention_heads, + ) + return [ + ( + self.random_int_tensor(shape_init, max_value=1, framework=framework).contiguous(), + self.random_float_tensor(shape_kv, framework=framework, dtype=float_dtype).contiguous(), + self.random_float_tensor(shape_kv, framework=framework, dtype=float_dtype).contiguous(), + self.random_int_tensor(shape_beam_idx_tmp, max_value=1, framework=framework).contiguous(), + ) + for _ in range(self.num_layers) + ] + + +class IPEXDummyTextInputGenerator(DummyTextInputGenerator): + def __init__( + self, + task: str, + normalized_config: NormalizedTextConfig, + batch_size: int = DEFAULT_DUMMY_SHAPES["batch_size"], + **kwargs, + ): + super().__init__(task, normalized_config, batch_size, **kwargs) + + +class LlamaIPEXConfig(LlamaOnnxConfig): + DUMMY_INPUT_GENERATOR_CLASSES = (IPEXDummyTextInputGenerator, IPEXDummyPastKeyValuesGenerator) + DUMMY_PKV_GENERATOR_CLASS = IPEXDummyPastKeyValuesGenerator + + +class FalconIPEXConfig(FalconOnnxConfig): + DUMMY_INPUT_GENERATOR_CLASSES = (IPEXDummyTextInputGenerator, IPEXDummyPastKeyValuesGenerator) + DUMMY_PKV_GENERATOR_CLASS = IPEXDummyPastKeyValuesGenerator + + +class GPT2IPEXConfig(GPT2OnnxConfig): + DUMMY_INPUT_GENERATOR_CLASSES = (IPEXDummyTextInputGenerator, IPEXDummyPastKeyValuesGenerator) + DUMMY_PKV_GENERATOR_CLASS = IPEXDummyPastKeyValuesGenerator + + +ipex_onnx_config = {"llama": LlamaIPEXConfig, "falcon": FalconIPEXConfig, "gpt2": GPT2IPEXConfig} diff --git a/optimum/intel/ipex/modeling_base.py b/optimum/intel/ipex/modeling_base.py index d6467f76a..739a2f2b4 100644 --- a/optimum/intel/ipex/modeling_base.py +++ b/optimum/intel/ipex/modeling_base.py @@ -23,10 +23,10 @@ import intel_extension_for_pytorch as ipex import torch +import transformers from huggingface_hub import hf_hub_download from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE from intel_extension_for_pytorch.cpu._auto_kernel_selection import _enable_tpp -from intel_extension_for_pytorch.transformers.optimize import get_dummy_input from transformers import ( AutoConfig, AutoModel, @@ -43,20 +43,24 @@ is_torch_xpu_available, ) from transformers.dynamic_module_utils import get_class_from_dynamic_module +from transformers.generation.candidate_generator import _crop_past_key_values from transformers.modeling_outputs import CausalLMOutputWithPast, ModelOutput from transformers.models.auto.auto_factory import _get_model_class as get_model_class from transformers.utils import WEIGHTS_NAME from optimum.exporters import TasksManager +from optimum.exporters.tasks import make_backend_config_constructor_for_task from optimum.modeling_base import OptimizedModel from optimum.utils import NormalizedConfigManager +from ...exporters.ipex.model_config import ipex_onnx_config from ...exporters.ipex.model_patcher import ( _IPEX_EXPORTED_GENERATION_TASKS, _IPEX_MINIMUM_VERSION_FOR_PATCHING, _patch_model, ) -from ..generation.modeling import prepare_jit_inputs +from ..generation.modeling import get_float_type +from ..utils.constant import _TASK_ALIASES from ..utils.import_utils import is_ipex_version, is_torch_version, is_transformers_version from ..utils.modeling_utils import MULTI_QUERY_ATTN_MODELS, recursive_to_device @@ -86,10 +90,35 @@ def _is_patched_with_ipex(model, task): def _prepare_inputs_for_ipex_model(model, task, use_cache): - if task in _IPEX_EXPORTED_GENERATION_TASKS and _is_patched_with_ipex(model, task): - return get_dummy_input(model, return_dict=True) + task = _TASK_ALIASES.get(task, task) + signature = inspect.signature(model.forward) if hasattr(model, "forward") else inspect.signature(model.__call__) + if _is_patched_with_ipex(model, task) and model.config.model_type in ipex_onnx_config: + onnx_config_class = make_backend_config_constructor_for_task( + ipex_onnx_config[model.config.model_type], task=task + ) + else: + onnx_config_class = TasksManager.get_exporter_config_constructor(model=model, exporter="onnx", task=task) + float_dtype = get_float_type(model.dtype) + if "text-generation" in task: + onnx_config = onnx_config_class( + model.config, use_past=use_cache, use_past_in_inputs=use_cache, float_dtype=float_dtype + ) else: - return prepare_jit_inputs(model, task, use_cache) + onnx_config = onnx_config_class(model.config) + + dummy_inputs = onnx_config.generate_dummy_inputs(framework="pt") + + # Check attention_mask shape + if _is_patched_with_ipex(model, task) and model.config.model_type in ipex_onnx_config and use_cache: + past_len = dummy_inputs["past_key_values"][0][0].shape[-2] + input_len = dummy_inputs["input_ids"].shape[-1] + attention_len = dummy_inputs["attention_mask"].shape[-1] + if attention_len != input_len + past_len: + dummy_inputs["attention_mask"] = torch.ones([dummy_inputs["input_ids"].shape[0], input_len + past_len]).to( + dummy_inputs["input_ids"].dtype + ) + + return {key: dummy_inputs[key] for key in signature.parameters if dummy_inputs.get(key, None) is not None} def ipex_jit_trace(model, task, use_cache): @@ -103,11 +132,7 @@ def ipex_jit_trace(model, task, use_cache): sample_inputs = _prepare_inputs_for_ipex_model(model, task, use_cache) model.config.return_dict = False - - if "past_key_values" in sample_inputs: - model.config.use_cache = use_cache - if not use_cache: - sample_inputs.pop("past_key_values") + model.config.use_cache = use_cache # Use Tensor Processing Primitives to accelerate linear, see https://arxiv.org/abs/2104.05755. # Only ipex >= 2.3.0 supports tpp. The tpp is only verified for llm in generation tasks. @@ -372,7 +397,7 @@ def _init_warmup(self): # TODO : add warmup for IPEX exported model if not self._is_ipex_exported: use_cache = "past_key_values" in self.input_names - dummy_inputs = prepare_jit_inputs(self, self.export_feature, use_cache) + dummy_inputs = _prepare_inputs_for_ipex_model(self, self.export_feature, use_cache) if self._device.type != "cpu": dummy_inputs = recursive_to_device(value=dummy_inputs, device=self._device) for _ in range(2): @@ -652,11 +677,28 @@ def _prepare_generation_config( return generation_config, model_kwargs def generate(self, *args, **kwargs): - if self._is_ipex_exported and kwargs.get("assistant_model", None): + if is_ipex_version("<", "2.4.0") and self._is_ipex_exported and kwargs.get("assistant_model", None): raise ValueError( - f"Assisted decoding is not supported for patched models for now, support methods are {_IPEX_EXPORTED_GENERATION_METHODS}" + f"Assisted decoding is not supported for patched models if ipex < 2.4, support methods are {_IPEX_EXPORTED_GENERATION_METHODS}" ) - return super().generate(*args, **kwargs) + # Patch functions to support IAKV cache + if self._is_ipex_exported and kwargs.get("assistant_model", None): + transformers.generation.utils._crop_past_key_values = _ipex_crop_past_key_values + elif self._is_ipex_exported: + transformers.generation.candidate_generator._crop_past_key_values = _ipex_crop_past_key_values + + try: + result = super().generate(*args, **kwargs) + except Exception as e: + transformers.generation.utils._crop_past_key_values = _crop_past_key_values + transformers.generation.candidate_generator._crop_past_key_values = _crop_past_key_values + raise e + + if self._is_ipex_exported and kwargs.get("assistant_model", None): + transformers.generation.utils._crop_past_key_values = _crop_past_key_values + transformers.generation.candidate_generator._crop_past_key_values = _crop_past_key_values + + return result def _ipex_prepare_inputs_for_generation( @@ -736,3 +778,16 @@ def _ipex_reorder_cache( tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past) for layer_past in past_key_values ) + + +def _ipex_crop_past_key_values(model, past_key_values, max_length): + if isinstance(model, IPEXModel) and _is_patched_with_ipex(model, "text-generation"): + new_past_key_values = [] + for i in range(len(past_key_values)): + pkv = [] + pkv.append(past_key_values[i][0][:, :max_length, :max_length, :]) + pkv += [past_key_values[i][_] for _ in range(1, 4)] + new_past_key_values.append(tuple(pkv)) + new_past_key_values = tuple(new_past_key_values) + return new_past_key_values + return _crop_past_key_values(model, past_key_values, max_length) diff --git a/tests/ipex/test_modeling.py b/tests/ipex/test_modeling.py index 01f935292..53c733c4f 100644 --- a/tests/ipex/test_modeling.py +++ b/tests/ipex/test_modeling.py @@ -281,11 +281,10 @@ def test_pipeline(self, model_arch): self.assertEqual(pipe.device, model.device) self.assertTrue(all("This is a sample" in item["generated_text"] for item in outputs)) - # High optimized model llama is not supported assisted decoding for now. @parameterized.expand(SUPPORTED_ARCHITECTURES) def test_assisted_decoding(self, model_arch): - # Patched models are not support assisted decoding for now. - if model_arch in self.IPEX_PATCHED_SUPPORTED_ARCHITECTURES: + # Patched models are not support assisted decoding if ipex < 2.5. + if model_arch in self.IPEX_PATCHED_SUPPORTED_ARCHITECTURES and is_ipex_version("<", "2.4.0"): return model_id = MODEL_NAMES[model_arch] tokenizer = AutoTokenizer.from_pretrained(model_id) @@ -296,11 +295,15 @@ def test_assisted_decoding(self, model_arch): ipex_output_assisted = ipex_model.generate( **tokens, do_sample=False, assistant_model=transformers_model, max_new_tokens=4 ) + ipex_output_assisted_2 = ipex_model.generate( + **tokens, do_sample=False, assistant_model=ipex_model, max_new_tokens=4 + ) transformers_output = transformers_model.generate(**tokens, do_sample=False, max_new_tokens=4) transformers_output_assisted = transformers_model.generate( **tokens, do_sample=False, assistant_model=ipex_model, max_new_tokens=4 ) self.assertTrue(torch.equal(ipex_output, ipex_output_assisted)) + self.assertTrue(torch.equal(ipex_output, ipex_output_assisted_2)) self.assertTrue(torch.equal(transformers_output, transformers_output_assisted)) @parameterized.expand(