Add support for assisted decoding in ipex 2.4 (huggingface#823)
* support assisted decoding in ipex 2.4

* Update optimum/intel/ipex/modeling_base.py

Co-authored-by: Ella Charlaix <[email protected]>

* fix failing tests

* fix style

* ipex onnx config

* patch before generate and un-patch after generate

* only patch functions in assisted decoding

* try to cache the generation result and un-patch

* raise error

* fix style

* ipex 2.4 supports assisted decoding

* fix inputs

* fix generate

* enable assisted decoding tests

* more tests on assisted decoding

* fix config name

* unpatch target model's generation
jiqing-feng committed Sep 9, 2024
1 parent 8a015a6 commit 5db1ac7
Showing 3 changed files with 173 additions and 17 deletions.
98 changes: 98 additions & 0 deletions optimum/exporters/ipex/model_config.py
@@ -0,0 +1,98 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Optional, Tuple

from optimum.exporters.onnx.model_configs import (
FalconOnnxConfig,
GPT2OnnxConfig,
LlamaOnnxConfig,
)
from optimum.utils import DEFAULT_DUMMY_SHAPES
from optimum.utils.input_generators import DummyPastKeyValuesGenerator, DummyTextInputGenerator
from optimum.utils.normalized_config import NormalizedTextConfig


DEFAULT_DUMMY_SHAPES["batch_size"] = 1


class IPEXDummyPastKeyValuesGenerator(DummyPastKeyValuesGenerator):
def __init__(
self,
task: str,
normalized_config: NormalizedTextConfig,
batch_size: int = DEFAULT_DUMMY_SHAPES["batch_size"],
sequence_length: int = DEFAULT_DUMMY_SHAPES["sequence_length"],
random_batch_size_range: Optional[Tuple[int, int]] = None,
random_sequence_length_range: Optional[Tuple[int, int]] = None,
**kwargs,
):
super().__init__(
task=task,
normalized_config=normalized_config,
batch_size=batch_size,
sequence_length=sequence_length,
random_batch_size_range=random_batch_size_range,
random_sequence_length_range=random_sequence_length_range,
)
self.num_key_value_heads = getattr(normalized_config, "num_key_value_heads", 1)
self.max_position_embeddings = normalized_config.max_position_embeddings

def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int64", float_dtype: str = "fp32"):
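        # IAKV (indirect-access KV) cache layout: each layer carries a 4-tuple of
        # (init, key, value, beam_idx_tmp). The key/value buffers are pre-allocated
        # up to max_position_embeddings instead of growing step by step.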
shape_init = (1, self.sequence_length, self.sequence_length, 1)
shape_beam_idx_tmp = (self.max_position_embeddings, self.batch_size)
shape_kv = (
self.max_position_embeddings,
self.batch_size,
self.num_key_value_heads,
self.hidden_size // self.num_attention_heads,
)
return [
(
self.random_int_tensor(shape_init, max_value=1, framework=framework).contiguous(),
self.random_float_tensor(shape_kv, framework=framework, dtype=float_dtype).contiguous(),
self.random_float_tensor(shape_kv, framework=framework, dtype=float_dtype).contiguous(),
self.random_int_tensor(shape_beam_idx_tmp, max_value=1, framework=framework).contiguous(),
)
for _ in range(self.num_layers)
]


class IPEXDummyTextInputGenerator(DummyTextInputGenerator):
def __init__(
self,
task: str,
normalized_config: NormalizedTextConfig,
batch_size: int = DEFAULT_DUMMY_SHAPES["batch_size"],
**kwargs,
):
super().__init__(task, normalized_config, batch_size, **kwargs)


class LlamaIPEXConfig(LlamaOnnxConfig):
DUMMY_INPUT_GENERATOR_CLASSES = (IPEXDummyTextInputGenerator, IPEXDummyPastKeyValuesGenerator)
DUMMY_PKV_GENERATOR_CLASS = IPEXDummyPastKeyValuesGenerator


class FalconIPEXConfig(FalconOnnxConfig):
DUMMY_INPUT_GENERATOR_CLASSES = (IPEXDummyTextInputGenerator, IPEXDummyPastKeyValuesGenerator)
DUMMY_PKV_GENERATOR_CLASS = IPEXDummyPastKeyValuesGenerator


class GPT2IPEXConfig(GPT2OnnxConfig):
DUMMY_INPUT_GENERATOR_CLASSES = (IPEXDummyTextInputGenerator, IPEXDummyPastKeyValuesGenerator)
DUMMY_PKV_GENERATOR_CLASS = IPEXDummyPastKeyValuesGenerator


ipex_onnx_config = {"llama": LlamaIPEXConfig, "falcon": FalconIPEXConfig, "gpt2": GPT2IPEXConfig}
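
For context, here is a minimal sketch of how these config constructors are consumed to build IAKV-style dummy inputs, mirroring _prepare_inputs_for_ipex_model in the next file. The checkpoint name is an illustrative assumption, not part of this commit:

from transformers import AutoConfig
from optimum.exporters.tasks import make_backend_config_constructor_for_task
from optimum.exporters.ipex.model_config import ipex_onnx_config

config = AutoConfig.from_pretrained("gpt2")  # illustrative checkpoint
# Build an IPEX-specific config for a generation task with the KV cache enabled.
constructor = make_backend_config_constructor_for_task(
    ipex_onnx_config[config.model_type], task="text-generation"
)
onnx_config = constructor(config, use_past=True, use_past_in_inputs=True, float_dtype="fp32")
dummy_inputs = onnx_config.generate_dummy_inputs(framework="pt")
# dummy_inputs["past_key_values"] holds one (init, key, value, beam_idx_tmp) tuple per layer.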
83 changes: 69 additions & 14 deletions optimum/intel/ipex/modeling_base.py
@@ -23,10 +23,10 @@

import intel_extension_for_pytorch as ipex
import torch
import transformers
from huggingface_hub import hf_hub_download
from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE
from intel_extension_for_pytorch.cpu._auto_kernel_selection import _enable_tpp
from intel_extension_for_pytorch.transformers.optimize import get_dummy_input
from transformers import (
AutoConfig,
AutoModel,
@@ -43,20 +43,24 @@
is_torch_xpu_available,
)
from transformers.dynamic_module_utils import get_class_from_dynamic_module
from transformers.generation.candidate_generator import _crop_past_key_values
from transformers.modeling_outputs import CausalLMOutputWithPast, ModelOutput
from transformers.models.auto.auto_factory import _get_model_class as get_model_class
from transformers.utils import WEIGHTS_NAME

from optimum.exporters import TasksManager
from optimum.exporters.tasks import make_backend_config_constructor_for_task
from optimum.modeling_base import OptimizedModel
from optimum.utils import NormalizedConfigManager

from ...exporters.ipex.model_config import ipex_onnx_config
from ...exporters.ipex.model_patcher import (
_IPEX_EXPORTED_GENERATION_TASKS,
_IPEX_MINIMUM_VERSION_FOR_PATCHING,
_patch_model,
)
from ..generation.modeling import prepare_jit_inputs
from ..generation.modeling import get_float_type
from ..utils.constant import _TASK_ALIASES
from ..utils.import_utils import is_ipex_version, is_torch_version, is_transformers_version
from ..utils.modeling_utils import MULTI_QUERY_ATTN_MODELS, recursive_to_device

@@ -86,10 +90,35 @@ def _is_patched_with_ipex(model, task):


def _prepare_inputs_for_ipex_model(model, task, use_cache):
if task in _IPEX_EXPORTED_GENERATION_TASKS and _is_patched_with_ipex(model, task):
return get_dummy_input(model, return_dict=True)
task = _TASK_ALIASES.get(task, task)
signature = inspect.signature(model.forward) if hasattr(model, "forward") else inspect.signature(model.__call__)
if _is_patched_with_ipex(model, task) and model.config.model_type in ipex_onnx_config:
onnx_config_class = make_backend_config_constructor_for_task(
ipex_onnx_config[model.config.model_type], task=task
)
else:
onnx_config_class = TasksManager.get_exporter_config_constructor(model=model, exporter="onnx", task=task)
float_dtype = get_float_type(model.dtype)
if "text-generation" in task:
onnx_config = onnx_config_class(
model.config, use_past=use_cache, use_past_in_inputs=use_cache, float_dtype=float_dtype
)
else:
return prepare_jit_inputs(model, task, use_cache)
onnx_config = onnx_config_class(model.config)

dummy_inputs = onnx_config.generate_dummy_inputs(framework="pt")

# Check that the attention_mask length equals input length + past length
if _is_patched_with_ipex(model, task) and model.config.model_type in ipex_onnx_config and use_cache:
past_len = dummy_inputs["past_key_values"][0][0].shape[-2]
input_len = dummy_inputs["input_ids"].shape[-1]
attention_len = dummy_inputs["attention_mask"].shape[-1]
if attention_len != input_len + past_len:
dummy_inputs["attention_mask"] = torch.ones([dummy_inputs["input_ids"].shape[0], input_len + past_len]).to(
dummy_inputs["input_ids"].dtype
)
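            # Example: with input_len = 2 and past_len = 4, a stale (batch, 2) mask is
            # rebuilt as torch.ones((batch, 6)) so its length equals input_len + past_len.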

return {key: dummy_inputs[key] for key in signature.parameters if dummy_inputs.get(key, None) is not None}


def ipex_jit_trace(model, task, use_cache):
@@ -103,11 +132,7 @@ def ipex_jit_trace(model, task, use_cache):
sample_inputs = _prepare_inputs_for_ipex_model(model, task, use_cache)

model.config.return_dict = False

if "past_key_values" in sample_inputs:
model.config.use_cache = use_cache
if not use_cache:
sample_inputs.pop("past_key_values")
model.config.use_cache = use_cache

# Use Tensor Processing Primitives to accelerate linear layers, see https://arxiv.org/abs/2104.05755.
# Only ipex >= 2.3.0 supports TPP, and TPP is only verified for LLMs in generation tasks.
@@ -372,7 +397,7 @@ def _init_warmup(self):
# TODO : add warmup for IPEX exported model
if not self._is_ipex_exported:
use_cache = "past_key_values" in self.input_names
dummy_inputs = prepare_jit_inputs(self, self.export_feature, use_cache)
dummy_inputs = _prepare_inputs_for_ipex_model(self, self.export_feature, use_cache)
if self._device.type != "cpu":
dummy_inputs = recursive_to_device(value=dummy_inputs, device=self._device)
for _ in range(2):
@@ -652,11 +677,28 @@ def _prepare_generation_config(
return generation_config, model_kwargs

def generate(self, *args, **kwargs):
if self._is_ipex_exported and kwargs.get("assistant_model", None):
if is_ipex_version("<", "2.4.0") and self._is_ipex_exported and kwargs.get("assistant_model", None):
raise ValueError(
f"Assisted decoding is not supported for patched models for now, support methods are {_IPEX_EXPORTED_GENERATION_METHODS}"
f"Assisted decoding is not supported for patched models if ipex < 2.4, support methods are {_IPEX_EXPORTED_GENERATION_METHODS}"
)
return super().generate(*args, **kwargs)
# Patch functions to support the IAKV (indirect-access KV) cache
if self._is_ipex_exported and kwargs.get("assistant_model", None):
transformers.generation.utils._crop_past_key_values = _ipex_crop_past_key_values
elif self._is_ipex_exported:
transformers.generation.candidate_generator._crop_past_key_values = _ipex_crop_past_key_values

try:
result = super().generate(*args, **kwargs)
except Exception as e:
transformers.generation.utils._crop_past_key_values = _crop_past_key_values
transformers.generation.candidate_generator._crop_past_key_values = _crop_past_key_values
raise e

if self._is_ipex_exported and kwargs.get("assistant_model", None):
transformers.generation.utils._crop_past_key_values = _crop_past_key_values
transformers.generation.candidate_generator._crop_past_key_values = _crop_past_key_values

return result


def _ipex_prepare_inputs_for_generation(
@@ -736,3 +778,16 @@ def _ipex_reorder_cache(
tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past)
for layer_past in past_key_values
)


def _ipex_crop_past_key_values(model, past_key_values, max_length):
if isinstance(model, IPEXModel) and _is_patched_with_ipex(model, "text-generation"):
new_past_key_values = []
for i in range(len(past_key_values)):
pkv = []
pkv.append(past_key_values[i][0][:, :max_length, :max_length, :])
pkv += [past_key_values[i][_] for _ in range(1, 4)]
new_past_key_values.append(tuple(pkv))
new_past_key_values = tuple(new_past_key_values)
return new_past_key_values
return _crop_past_key_values(model, past_key_values, max_length)
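
The patch/restore sequence in generate above restores the stock functions on the success path and in the except handler; the same behavior can be expressed with a single try/finally. A minimal sketch of that pattern (the wrapper name is hypothetical):

import transformers
from transformers.generation.candidate_generator import _crop_past_key_values

def generate_with_iakv_crop(model, *args, **kwargs):
    # Swap in the IAKV-aware crop for the duration of this call only.
    transformers.generation.utils._crop_past_key_values = _ipex_crop_past_key_values
    transformers.generation.candidate_generator._crop_past_key_values = _ipex_crop_past_key_values
    try:
        return model.generate(*args, **kwargs)
    finally:
        # Restore the stock implementation whether generation succeeds or raises.
        transformers.generation.utils._crop_past_key_values = _crop_past_key_values
        transformers.generation.candidate_generator._crop_past_key_values = _crop_past_key_values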
9 changes: 6 additions & 3 deletions tests/ipex/test_modeling.py
@@ -281,11 +281,10 @@ def test_pipeline(self, model_arch):
self.assertEqual(pipe.device, model.device)
self.assertTrue(all("This is a sample" in item["generated_text"] for item in outputs))

# The highly optimized llama model does not support assisted decoding for now.
@parameterized.expand(SUPPORTED_ARCHITECTURES)
def test_assisted_decoding(self, model_arch):
# Patched models do not support assisted decoding for now.
if model_arch in self.IPEX_PATCHED_SUPPORTED_ARCHITECTURES:
# Patched models do not support assisted decoding if ipex < 2.4.
if model_arch in self.IPEX_PATCHED_SUPPORTED_ARCHITECTURES and is_ipex_version("<", "2.4.0"):
return
model_id = MODEL_NAMES[model_arch]
tokenizer = AutoTokenizer.from_pretrained(model_id)
@@ -296,11 +295,15 @@ def test_assisted_decoding(self, model_arch):
ipex_output_assisted = ipex_model.generate(
**tokens, do_sample=False, assistant_model=transformers_model, max_new_tokens=4
)
ipex_output_assisted_2 = ipex_model.generate(
**tokens, do_sample=False, assistant_model=ipex_model, max_new_tokens=4
)
transformers_output = transformers_model.generate(**tokens, do_sample=False, max_new_tokens=4)
transformers_output_assisted = transformers_model.generate(
**tokens, do_sample=False, assistant_model=ipex_model, max_new_tokens=4
)
self.assertTrue(torch.equal(ipex_output, ipex_output_assisted))
self.assertTrue(torch.equal(ipex_output, ipex_output_assisted_2))
self.assertTrue(torch.equal(transformers_output, transformers_output_assisted))

@parameterized.expand(
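
For completeness, a minimal end-to-end sketch of assisted decoding with an IPEX-exported model, mirroring the test above; the checkpoint name is an illustrative assumption:

from transformers import AutoModelForCausalLM, AutoTokenizer
from optimum.intel import IPEXModelForCausalLM

model_id = "gpt2"  # illustrative; any model_type listed in ipex_onnx_config
tokenizer = AutoTokenizer.from_pretrained(model_id)
ipex_model = IPEXModelForCausalLM.from_pretrained(model_id, export=True)
assistant = AutoModelForCausalLM.from_pretrained(model_id)

tokens = tokenizer("This is a sample input", return_tensors="pt")
# The assistant model drafts candidate tokens; the IPEX model verifies them.
output = ipex_model.generate(
    **tokens, do_sample=False, assistant_model=assistant, max_new_tokens=4
)
print(tokenizer.decode(output[0], skip_special_tokens=True))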
