Commit 27f7d9d

enable IPEXModelForSeq2SeqLM by torch.compile

jiqing-feng committed Aug 15, 2024
1 parent 7b8eaa6 commit 27f7d9d
Showing 6 changed files with 107 additions and 15 deletions.
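In short: the commit adds an IPEXModelForSeq2SeqLM class whose forward is wrapped with torch.compile (on PyTorch >= 2.5.0) instead of being jit-traced, and registers the new class in the public API, the pipeline table, and the dummy-object fallbacks. A minimal usage sketch, assuming the class follows the same from_pretrained/generate API as the existing IPEX model classes (the prompt and generation arguments are illustrative):

from transformers import AutoTokenizer
from optimum.intel import IPEXModelForSeq2SeqLM

model_id = "google-t5/t5-base"  # default checkpoint registered in pipeline_base.py below
model = IPEXModelForSeq2SeqLM.from_pretrained(model_id, export=True)
tokenizer = AutoTokenizer.from_pretrained(model_id)

inputs = tokenizer("translate English to French: Hello, world!", return_tensors="pt")
# The first generations are slow while torch.compile warms up (see the warning added below).
outputs = model.generate(**inputs, max_new_tokens=32)
print(tokenizer.batch_decode(outputs, skip_special_tokens=True))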
2 changes: 2 additions & 0 deletions optimum/intel/__init__.py
@@ -46,6 +46,7 @@
_import_structure["ipex"] = [
"inference_mode",
"IPEXModelForCausalLM",
"IPEXModelForSeq2SeqLM",
"IPEXModelForSequenceClassification",
"IPEXModelForMaskedLM",
"IPEXModelForTokenClassification",
@@ -194,6 +195,7 @@
IPEXModelForImageClassification,
IPEXModelForMaskedLM,
IPEXModelForQuestionAnswering,
+ IPEXModelForSeq2SeqLM,
IPEXModelForSequenceClassification,
IPEXModelForTokenClassification,
inference_mode,
1 change: 1 addition & 0 deletions optimum/intel/ipex/__init__.py
@@ -19,6 +19,7 @@
IPEXModelForImageClassification,
IPEXModelForMaskedLM,
IPEXModelForQuestionAnswering,
+ IPEXModelForSeq2SeqLM,
IPEXModelForSequenceClassification,
IPEXModelForTokenClassification,
)
100 changes: 85 additions & 15 deletions optimum/intel/ipex/modeling_base.py
@@ -35,6 +35,7 @@
AutoModelForImageClassification,
AutoModelForMaskedLM,
AutoModelForQuestionAnswering,
+ AutoModelForSeq2SeqLM,
AutoModelForSequenceClassification,
AutoModelForTokenClassification,
GenerationConfig,
@@ -43,7 +44,7 @@
is_torch_xpu_available,
)
from transformers.dynamic_module_utils import get_class_from_dynamic_module
- from transformers.modeling_outputs import CausalLMOutputWithPast, ModelOutput
+ from transformers.modeling_outputs import CausalLMOutputWithPast, ModelOutput, Seq2SeqLMOutput
from transformers.models.auto.auto_factory import _get_model_class as get_model_class
from transformers.utils import WEIGHTS_NAME

@@ -113,21 +114,35 @@ def ipex_jit_trace(model, task, use_cache):
# Only ipex >= 2.3.0 supports tpp. The tpp is only verified for llm in generation tasks.
if is_ipex_version(">=", "2.3.0") and task in _IPEX_EXPORTED_GENERATION_TASKS:
_enable_tpp()
- model = ipex.optimize(model.eval(), dtype=model.dtype, inplace=True)
- # Disable repack while jit tracing to reduce memory usage
- ipex._C.disable_jit_linear_repack()
- with torch.no_grad():
-     trace_model = torch.jit.trace(
-         model,
-         example_kwarg_inputs=sample_inputs,
-         strict=False,
-         check_trace=False,
-     )
-     trace_model = torch.jit.freeze(trace_model)
-     trace_model(**sample_inputs)
-     trace_model(**sample_inputs)
-
- return trace_model
+ if task == "text2text-generation":
+     if is_torch_version("<", "2.5.0"):
+         warnings.warn(
+             "The Seq2Seq model will stay in eager mode; please upgrade to PyTorch >= 2.5.0 for better performance with graph mode."
+         )
+     else:
+         # cpp_wrapper makes Inductor invoke the compiled graph through its C++ wrapper, reducing Python overhead.
+         torch._inductor.config.cpp_wrapper = True
+         model.forward = torch.compile(model.forward, dynamic=True)
+         # torch.compile is lazy: compilation only happens on the first calls.
+         warnings.warn("Please run a few warm-up iterations representative of your use case before measuring torch.compile performance.")
+
+     return model
+ else:
+     model = ipex.optimize(model.eval(), dtype=model.dtype, inplace=True)
+     # Disable repack while jit tracing to reduce memory usage
+     ipex._C.disable_jit_linear_repack()
+     with torch.no_grad():
+         trace_model = torch.jit.trace(
+             model,
+             example_kwarg_inputs=sample_inputs,
+             strict=False,
+             check_trace=False,
+         )
+         trace_model = torch.jit.freeze(trace_model)
+         # Run the frozen model twice so the JIT profiling and fusion passes complete.
+         trace_model(**sample_inputs)
+         trace_model(**sample_inputs)
+
+     return trace_model


class IPEXModel(OptimizedModel):
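The behavioral change in the hunk above: for text2text-generation the model keeps its eager Python forward and is compiled with dynamic shapes rather than traced and frozen, so the first calls pay the compilation cost (hence the warm-up warning). A standalone sketch of that pattern (toy function, illustrative only, not from the codebase):

import torch

def toy_forward(x: torch.Tensor) -> torch.Tensor:
    return torch.nn.functional.relu(x @ x.T)

# dynamic=True asks the compiler to generalize over input shapes instead of
# recompiling for every new sequence length.
compiled = torch.compile(toy_forward, dynamic=True)

for seq_len in (8, 16, 32):
    _ = compiled(torch.randn(seq_len, 64))  # early iterations trigger compilation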
@@ -618,6 +633,61 @@ def generate(self, *args, **kwargs):
return super().generate(*args, **kwargs)


+ class IPEXModelForSeq2SeqLM(IPEXModel, GenerationMixin):
+     auto_model_class = AutoModelForSeq2SeqLM
+     export_feature = "text2text-generation"
+     _supports_cache_class = False
+
+     def __init__(
+         self,
+         model,
+         config: PretrainedConfig = None,
+         export: bool = False,
+         model_save_dir: Optional[Union[str, Path, TemporaryDirectory]] = None,
+         use_cache: bool = True,
+         **kwargs,
+     ):
+         # Perform the initial warmup at the end of __init__
+         super().__init__(
+             model, config, export=export, model_save_dir=model_save_dir, warmup=False, use_cache=use_cache
+         )
+         GenerationMixin.__init__(self)
+         self.generation_config = GenerationConfig.from_model_config(self.config)
+         try:
+             self.model_cls = get_class_from_dynamic_module(
+                 self.config.auto_map["AutoModelForSeq2SeqLM"], model_save_dir
+             )
+         except AttributeError:
+             self.model_cls = get_model_class(self.config, AutoModelForSeq2SeqLM._model_mapping)
+
+         if "_reorder_cache" in self.model_cls.__dict__ and isinstance(
+             self.model_cls.__dict__["_reorder_cache"], staticmethod
+         ):
+             self._reorder_cache = self.model_cls._reorder_cache
+         elif "_reorder_cache" in self.model_cls.__dict__:
+             self._reorder_cache = self.model_cls._reorder_cache.__get__(self)
+
+         self.prepare_inputs_for_generation = self.model_cls.prepare_inputs_for_generation.__get__(self)
+
+         if hasattr(self.model_cls, "_convert_to_standard_cache"):
+             self._convert_to_standard_cache = self.model_cls._convert_to_standard_cache
+         if hasattr(self.model_cls, "_convert_to_bloom_cache"):
+             self._convert_to_bloom_cache = self.model_cls._convert_to_bloom_cache
+
+     def get_encoder(self):
+         return self.model.encoder
+
+     def forward(
+         self,
+         input_ids: Optional[torch.LongTensor] = None,
+         attention_mask: Optional[torch.FloatTensor] = None,
+         past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
+         **kwargs,
+     ) -> Union[Tuple[torch.FloatTensor], Seq2SeqLMOutput]:
+         outputs = self._call_model(input_ids, attention_mask, past_key_values=past_key_values, **kwargs)
+         return outputs


def _ipex_prepare_inputs_for_generation(
input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
):
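A note on the __get__ calls in IPEXModelForSeq2SeqLM.__init__ above: prepare_inputs_for_generation and _reorder_cache live on the original transformers model class, so they are bound to the IPEX wrapper instance through Python's descriptor protocol. A minimal sketch of that binding trick (class names here are illustrative):

class Donor:
    def greet(self):
        return f"hello from {type(self).__name__}"

class Recipient:
    pass

r = Recipient()
# Donor.greet is a plain function; __get__(r) turns it into a method bound to r.
r.greet = Donor.greet.__get__(r)
assert r.greet() == "hello from Recipient"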
1 change: 1 addition & 0 deletions optimum/intel/ipex/utils.py
@@ -18,4 +18,5 @@
"text-classification": "IPEXModelForSequenceClassification",
"token-classification": "IPEXModelForTokenClassification",
"question-answering": "IPEXModelForQuestionAnswering",
"text2text-generation": "IPEXModelForSeq2SeqLM",
}
7 changes: 7 additions & 0 deletions optimum/intel/pipelines/pipeline_base.py
@@ -60,6 +60,7 @@
IPEXModelForImageClassification,
IPEXModelForMaskedLM,
IPEXModelForQuestionAnswering,
+ IPEXModelForSeq2SeqLM,
IPEXModelForSequenceClassification,
IPEXModelForTokenClassification,
)
@@ -71,6 +72,12 @@
"default": "gpt2",
"type": "text",
},
"text2text-generation": {
"impl": TextGenerationPipeline,
"class": (IPEXModelForSeq2SeqLM,),
"default": "google-t5/t5-base",
"type": "text",
},
"fill-mask": {
"impl": FillMaskPipeline,
"class": (IPEXModelForMaskedLM,),
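With this entry registered, the feature is reachable through the optimum-intel pipeline factory. A hedged sketch, assuming the ipex accelerator path works for this task as it does for the other tasks in the table:

from optimum.intel.pipelines import pipeline

# "text2text-generation" now resolves to IPEXModelForSeq2SeqLM; the model id
# is the default registered above.
pipe = pipeline("text2text-generation", "google-t5/t5-base", accelerator="ipex")
print(pipe("translate English to French: Hello, world!"))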
11 changes: 11 additions & 0 deletions optimum/intel/utils/dummy_ipex_objects.py
@@ -77,6 +77,17 @@ def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["ipex"])


+ class IPEXModelForSeq2SeqLM(metaclass=DummyObject):
+     _backends = ["ipex"]
+
+     def __init__(self, *args, **kwargs):
+         requires_backends(self, ["ipex"])
+
+     @classmethod
+     def from_pretrained(cls, *args, **kwargs):
+         requires_backends(cls, ["ipex"])


class IPEXModelForQuestionAnswering(metaclass=DummyObject):
_backends = ["ipex"]

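Why the dummy class: importing optimum.intel must succeed even when intel-extension-for-pytorch is not installed, so every public name gets a placeholder that fails loudly on first use. A simplified sketch of the pattern (not the actual transformers DummyObject implementation):

class DummyObject(type):
    # Metaclass: any non-underscore attribute access on the class raises.
    def __getattribute__(cls, key):
        if key.startswith("_"):
            return super().__getattribute__(key)
        raise ImportError(f"{cls.__name__} requires the ipex backend: pip install intel-extension-for-pytorch")

class IPEXModelForSeq2SeqLM(metaclass=DummyObject):
    _backends = ["ipex"]

# The import works; any use raises a helpful error instead of an obscure one.
try:
    IPEXModelForSeq2SeqLM.from_pretrained("google-t5/t5-base")
except ImportError as err:
    print(err)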
