diff --git a/.github/workflows/test_inc.yml b/.github/workflows/test_inc.yml index 7b87a20d0..6e9992348 100644 --- a/.github/workflows/test_inc.yml +++ b/.github/workflows/test_inc.yml @@ -36,8 +36,9 @@ jobs: pip install cmake pip install py-cpuinfo pip install torch==2.3.0 torchaudio==2.3.0 torchvision==0.18 --index-url https://download.pytorch.org/whl/cpu + pip install intel-extension-for-pytorch==2.3.0 + pip install datasets==2.19.0 pip install .[neural-compressor,diffusers,tests] - pip install intel-extension-for-transformers pip install peft - name: Test with Pytest @@ -45,7 +46,5 @@ jobs: pytest tests/neural_compressor/ --ignore tests/neural_compressor/test_ipex.py --durations=0 - name: Test IPEX run: | - pip uninstall -y intel-extension-for-transformers - pip install intel-extension-for-pytorch==2.3.0 pytest tests/neural_compressor/test_ipex.py diff --git a/examples/neural_compressor/language-modeling/README.md b/examples/neural_compressor/language-modeling/README.md index 80d7a25d1..476ac526e 100644 --- a/examples/neural_compressor/language-modeling/README.md +++ b/examples/neural_compressor/language-modeling/README.md @@ -97,4 +97,4 @@ respectively `dynamic`, `static`, `weight_only` or `aware_training`. The flag `--verify_loading` can be passed along to verify that the resulting quantized model can be loaded correctly. -> **_Note:_** `weight_only` quantization_approach requires `neural-compressor` >= 2.3 and `intel-extension-for-transformers` >= 1.3. +> **_Note:_** `weight_only` quantization_approach requires `neural-compressor` > 3.0. diff --git a/examples/neural_compressor/language-modeling/requirements.txt b/examples/neural_compressor/language-modeling/requirements.txt index ec38e83d2..8960d82db 100644 --- a/examples/neural_compressor/language-modeling/requirements.txt +++ b/examples/neural_compressor/language-modeling/requirements.txt @@ -3,5 +3,4 @@ torch >= 1.9 datasets >= 1.8.0 sentencepiece != 0.1.92 protobuf -intel-extension-for-transformers >= 1.3 peft diff --git a/examples/neural_compressor/language-modeling/run_clm.py b/examples/neural_compressor/language-modeling/run_clm.py index e2496d5ba..7e8107219 100644 --- a/examples/neural_compressor/language-modeling/run_clm.py +++ b/examples/neural_compressor/language-modeling/run_clm.py @@ -39,6 +39,7 @@ QuantizationAwareTrainingConfig, WeightPruningConfig, ) +from neural_compressor.transformers import GPTQConfig, RtnConfig from transformers import ( CONFIG_MAPPING, MODEL_FOR_CAUSAL_LM_MAPPING, @@ -57,12 +58,8 @@ from transformers.utils.versions import require_version from optimum.intel.neural_compressor import INCModelForCausalLM, INCQuantizer, INCTrainer -from optimum.intel.utils.import_utils import ITREX_IMPORT_ERROR, is_itrex_available -if is_itrex_available(): - from intel_extension_for_transformers.transformers.utils.config import GPTQConfig, RtnConfig - os.environ["CUDA_VISIBLE_DEVICES"] = "" # Will error if the minimal version of Transformers is not installed. Remove at your own risks. @@ -203,12 +200,8 @@ class OptimizationArguments: metadata={"help": "Whether or not to verify the loading of the quantized model."}, ) bits: str = field( - default="4", - metadata={"help": "Bits number of weight for weight only quantization. 1~8 bits."}, - ) - weight_dtype: str = field( - default="int4_clip", - metadata={"help": "weight dtype for weight only quantization."}, + default=4, + metadata={"help": "Bits number of weight for weight only quantization. 
only support 4 bits now."}, ) group_size: int = field( default=-1, @@ -223,7 +216,6 @@ class OptimizationArguments: metadata={"help": "Scheme for weight only quantization. Choose from 'sym' and 'asym'."}, ) quantization_methodology: str = field( - choices=["rtn", "gptq"], default="rtn", metadata={"help": "Quantization methodology for weight only quantization. Choose from 'rtn' and 'gptq'."}, ) @@ -249,6 +241,11 @@ class OptimizationArguments: metadata={"help": "Calibration dataset sequence max length, this should align with your model config"}, ) + def __post_init__(self): + woq_algorithms = ["rtn", "gptq"] + if self.quantization_methodology not in woq_algorithms: + raise ValueError(f"Value must be one of {woq_algorithms}, got {self.quantization_methodology}") + @dataclass class DataTrainingArguments: @@ -655,13 +652,11 @@ def compute_metrics(eval_preds): else: recipes = {} if optim_args.quantization_approach == "weight_only": - if not is_itrex_available(): - raise ImportError(ITREX_IMPORT_ERROR.format("WeightOnly quantization")) if optim_args.apply_pruning or optim_args.apply_distillation: raise ValueError("Weight only quantization and pruning or distillation cannot be combined.") algorithm_args = { - "weight_dtype": optim_args.weight_dtype, + "bits": optim_args.bits, "sym": optim_args.weight_only_scheme == "sym", "group_size": optim_args.group_size, } @@ -756,10 +751,10 @@ def compute_metrics(eval_preds): trainer.save_metrics("train", metrics) trainer.save_state() - if optim_args.apply_quantization and optim_args.quantization_approach in {"static", "dynamic", "weight_only"}: + if optim_args.apply_quantization and optim_args.quantization_approach in {"static", "dynamic"}: model = trainer.model if isinstance(trainer.model, PreTrainedModel) else trainer.model._model quantizer = INCQuantizer.from_pretrained(model) - if optim_args.quantization_approach in ["static", "weight_only"]: + if optim_args.quantization_approach == "static": num_calibration_samples = min(len(train_dataset), optim_args.num_calibration_samples) train_dataset = train_dataset.select(range(num_calibration_samples)) quantization_config.calibration_sampling_size = num_calibration_samples @@ -776,6 +771,19 @@ def compute_metrics(eval_preds): ) trainer.model = quantizer._quantized_model + if optim_args.apply_quantization and optim_args.quantization_approach == "weight_only": + model = trainer.model if isinstance(trainer.model, PreTrainedModel) else trainer.model._model + num_calibration_samples = min(len(train_dataset), optim_args.num_calibration_samples) + train_dataset = train_dataset.select(range(num_calibration_samples)) + quantization_config.calibration_sampling_size = num_calibration_samples + quantized_model = INCModelForCausalLM.from_pretrained( + model_args.model_name_or_path, quantization_config=quantization_config + ) + if hasattr(quantization_config, "tokenizer"): + quantization_config.tokenizer.save_pretrained(training_args.output_dir) + quantized_model.save_pretrained(training_args.output_dir) + trainer.model = quantized_model + if optim_args.apply_quantization and optim_args.verify_loading: loaded_model = INCModelForCausalLM.from_pretrained(training_args.output_dir) tokens = tokenizer("This is a sample input", return_tensors="pt") diff --git a/optimum/intel/neural_compressor/modeling_base.py b/optimum/intel/neural_compressor/modeling_base.py index a12cfc84e..392d84b47 100644 --- a/optimum/intel/neural_compressor/modeling_base.py +++ b/optimum/intel/neural_compressor/modeling_base.py @@ -23,6 +23,8 @@ from 
huggingface_hub import hf_hub_download from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE from huggingface_hub.utils import EntryNotFoundError +from neural_compressor.transformers import GPTQConfig, RtnConfig +from neural_compressor.transformers.models.modeling_auto import _BaseINCAutoModelClass from neural_compressor.utils.pytorch import load from transformers import ( AutoConfig, @@ -47,8 +49,9 @@ from optimum.intel.generation import BaseModelForCausalLM from ...modeling_base import OptimizedModel -from ..utils.import_utils import _torch_version, is_itrex_available, is_torch_version +from ..utils.import_utils import _torch_version, is_torch_version from .configuration import INCConfig +from .quantization import _weight_only_quantization from .utils import QUANTIZATION_CONFIG_NAME @@ -122,8 +125,85 @@ def _from_pretrained( raise ValueError("You cannot use both `use_auth_token` and `token` arguments at the same time.") token = use_auth_token + quantization_config = kwargs.pop("quantization_config", None) model_path = Path(model_id) is_local = model_path.is_dir() + + # ITREX compatibility + quantization_config_path = None + if is_local: + quantization_config_path = model_path / subfolder / QUANTIZATION_CONFIG_NAME + else: + try: + quantization_config_path = hf_hub_download( + repo_id=model_id, + filename=QUANTIZATION_CONFIG_NAME, + subfolder=subfolder, + token=token, + revision=revision, + cache_dir=cache_dir, + force_download=force_download, + local_files_only=local_files_only, + ) + except EntryNotFoundError: + pass + if quantization_config_path and Path(quantization_config_path).is_file(): + algorithm = getattr(quantization_config, "quant_method", None) + if algorithm in {"rtn", "gptq", "awq", "autoround"}: + raise ValueError( + "This model was obtained through ITREX quantization, support for ITREX models is deprecated since neural-compressor v3.0. " + "To load this model please downgrade both optimum-intel and neural-compressor." + ) + # quantization_config = PretrainedConfig.from_pretrained(quantization_config_path) + # config.quantization_config = quantization_config.to_dict() + + if hasattr(config, "quantization_config"): + if config.quantization_config is None: + raise ValueError( + "The loading of `quantization_config` failed, to load this model please make sure the config is compatible" + ) + else: + try: + logger.info( + "The weight only quantized model loading only supports the same format as GPTQ, such as https://huggingface.co/TheBloke/Llama-2-7B-Chat-GPTQ/tree/main." + ) + _BaseINCAutoModelClass.ORIG_MODEL = cls.auto_model_class + model = _BaseINCAutoModelClass.load_low_bit( + model_id, + subfolder=subfolder, + revision=revision, + cache_dir=cache_dir, + token=token, + local_files_only=local_files_only, + force_download=force_download, + trust_remote_code=trust_remote_code, + config=config, + **kwargs, + ) + logger.info("Saved low bit model loading successfully. Other input args " "will be ignored.") + return model + except Exception as e: + raise RuntimeError(f"The quantized model cannot be loaded. Detailed error: {e}") + if isinstance(quantization_config, (RtnConfig, GPTQConfig)): + logger.info( + "The quantized model parameters will be saved in the same format as GPTQ, here is the sample model https://huggingface.co/TheBloke/Llama-2-7B-Chat-GPTQ/tree/main for details." 
+ ) + model = _weight_only_quantization( + cls.auto_model_class, + model_id, + quantization_config=quantization_config, + token=token, + revision=revision, + force_download=force_download, + cache_dir=cache_dir, + local_files_only=local_files_only, + subfolder=subfolder, + trust_remote_code=trust_remote_code, + **kwargs, + ) + + return cls(model, config=config, model_save_dir=None, **kwargs).model + model_cache_path = None inc_config = None msg = None @@ -165,52 +245,6 @@ def _from_pretrained( model_save_dir = Path(model_cache_path).parent - if is_itrex_available(): - quantization_config_path = None - if is_local: - quantization_config_path = model_path / subfolder / QUANTIZATION_CONFIG_NAME - else: - try: - quantization_config_path = hf_hub_download( - repo_id=model_id, - filename=QUANTIZATION_CONFIG_NAME, - subfolder=subfolder, - token=token, - revision=revision, - cache_dir=cache_dir, - force_download=force_download, - local_files_only=local_files_only, - ) - except EntryNotFoundError: - pass - - if quantization_config_path and Path(quantization_config_path).is_file(): - quantization_config = PretrainedConfig.from_pretrained(quantization_config_path) - algorithm = getattr(quantization_config, "quant_method", None) - if algorithm in {"rtn", "gptq", "awq", "autoround"}: - from intel_extension_for_transformers.transformers.modeling.modeling_auto import ( - _BaseQBitsAutoModelClass, - ) - - _BaseQBitsAutoModelClass.ORIG_MODEL = cls.auto_model_class - - model = _BaseQBitsAutoModelClass.from_pretrained( - pretrained_model_name_or_path=model_id, - token=token, - revision=revision, - force_download=force_download, - cache_dir=cache_dir, - local_files_only=local_files_only, - subfolder=subfolder, - trust_remote_code=trust_remote_code, - use_neural_speed=False, - **kwargs, - ) - - return cls( - model, config=config, model_save_dir=model_save_dir, q_config=quantization_config, **kwargs - ) - try: inc_config = INCConfig.from_pretrained(model_id, subfolder=subfolder, revision=revision) if not is_torch_version("==", inc_config.torch_version): @@ -254,7 +288,7 @@ def _from_pretrained( def _save_pretrained(self, save_directory: Union[str, Path]): if isinstance(self.model, torch.nn.Module): - # For ITREX model + # For INC weight only model if isinstance(self._q_config, PretrainedConfig): self._q_config.to_json_file(os.path.join(save_directory, QUANTIZATION_CONFIG_NAME)) self.model.save_pretrained(save_directory) diff --git a/optimum/intel/neural_compressor/quantization.py b/optimum/intel/neural_compressor/quantization.py index cc57f1f84..6ca9fd661 100644 --- a/optimum/intel/neural_compressor/quantization.py +++ b/optimum/intel/neural_compressor/quantization.py @@ -23,10 +23,12 @@ import torch from datasets import Dataset, load_dataset +from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE from neural_compressor.config import PostTrainingQuantConfig from neural_compressor.model.torch_model import IPEXModel, PyTorchModel from neural_compressor.quantization import fit -from packaging.version import parse +from neural_compressor.transformers import GPTQConfig, RtnConfig +from neural_compressor.transformers.quantization import convert_to_quantized_model, save_low_bit from torch.utils.data import DataLoader, RandomSampler from transformers import ( DataCollator, @@ -40,55 +42,19 @@ from ..utils.constant import _TASK_ALIASES, WEIGHTS_NAME from ..utils.import_utils import ( - ITREX_IMPORT_ERROR, _ipex_version, - _itrex_version, _neural_compressor_version, - _torch_version, is_ipex_version, - 
is_itrex_available, - is_itrex_version, is_neural_compressor_version, - is_torch_version, ) from .configuration import INCConfig -from .modeling_base import ( # noqa - INCModel, - INCModelForMaskedLM, - INCModelForMultipleChoice, - INCModelForQuestionAnswering, - INCModelForSeq2SeqLM, - INCModelForSequenceClassification, - INCModelForTokenClassification, - INCModelForVision2Seq, -) from .utils import ( IPEX_MINIMUM_VERSION, - ITREX_MINIMUM_TORCH_VERSION, - ITREX_MINIMUM_VERSION, NEURAL_COMPRESSOR_MINIMUM_VERSION, - NEURAL_COMPRESSOR_WEIGHT_ONLY_MINIMUM_VERSION, INCDataLoader, ) -if is_itrex_available(): - if is_itrex_version("<", ITREX_MINIMUM_VERSION): - raise ImportError( - f"Found an incompatible version of `intel-extension-for-transformers`. Found version {_itrex_version}, " - f"but only version {ITREX_MINIMUM_VERSION} or higher is supported." - ) - - from intel_extension_for_transformers.transformers.llm.quantization.utils import convert_to_quantized_model - from intel_extension_for_transformers.transformers.modeling.modeling_auto import save_low_bit - from intel_extension_for_transformers.transformers.utils.config import ( - AwqConfig, - GPTQConfig, - ITREXQuantizationConfigMixin, - RtnConfig, - ) - - logger = logging.getLogger(__name__) @@ -153,7 +119,7 @@ def from_pretrained(cls, model: PreTrainedModel, **kwargs): def quantize( self, - quantization_config: Union["PostTrainingQuantConfig", "ITREXQuantizationConfigMixin"], + quantization_config: Union["PostTrainingQuantConfig"], save_directory: Union[str, Path], calibration_dataset: Dataset = None, batch_size: int = 8, @@ -166,7 +132,7 @@ def quantize( Quantize a model given the optimization specifications defined in `quantization_config`. Args: - quantization_config (`Union[PostTrainingQuantConfig, ITREXQuantizationConfigMixin]`): + quantization_config (`Union[PostTrainingQuantConfig]`): The configuration containing the parameters related to quantization. save_directory (`Union[str, Path]`): The directory where the quantized model should be saved. @@ -181,9 +147,6 @@ def quantize( """ save_directory = Path(save_directory) save_directory.mkdir(parents=True, exist_ok=True) - device = kwargs.pop("device", "cpu") - use_cpu = device == torch.device("cpu") or device == "cpu" - use_xpu = device == torch.device("xpu") or device == "xpu" calibration_dataloader = None default_name = WEIGHTS_NAME self._set_task() @@ -204,58 +167,7 @@ def quantize( f"but only version {IPEX_MINIMUM_VERSION} or higher is supported." ) - # ITREX Weight Only Quantization - if not isinstance(quantization_config, PostTrainingQuantConfig): - if is_itrex_version("==", "1.4.2") and ( - is_torch_version("!=", "2.3.0") or parse(_torch_version).local != "cpu" - ): - raise ImportError( - f"Found an incompatible version of `intel-extension-for-transformers` and `torch`. Found version itrex {_itrex_version} and torch {_torch_version}, " - f"but only torch 2.3.0+cpu is compatible with ITREX v1.4.2." - ) - - # check neural-compressor version - if is_neural_compressor_version("<", NEURAL_COMPRESSOR_WEIGHT_ONLY_MINIMUM_VERSION): - raise ImportError( - f"Found an incompatible version of neural-compressor. Found version {_neural_compressor_version}, " - f"but only version {NEURAL_COMPRESSOR_WEIGHT_ONLY_MINIMUM_VERSION} or higher supports weight-only quantization." 
- ) - if not is_itrex_available(): - raise ImportError(ITREX_IMPORT_ERROR.format("Weight only quantization")) - - if is_torch_version("<", ITREX_MINIMUM_TORCH_VERSION): - raise ImportError( - f"Found an incompatible version of `torch`. Found version {_torch_version}, " - f"but only version {ITREX_MINIMUM_TORCH_VERSION} or higher is supported." - ) - - if not isinstance(quantization_config, ITREXQuantizationConfigMixin): - raise TypeError( - "`quantization_config` should either be an instance of `neural_compressor.config.PostTrainingQuantConfig` or " - f"`intel_extension_for_transformers.transformers.utils.config.ITREXQuantizationConfigMixin` but got: {type(quantization_config)} instead." - ) - - if not isinstance(quantization_config, (GPTQConfig, RtnConfig)): - raise ValueError( - f"Weight-only quantization is only support RTN and GPTQ algorithm now! But got {quantization_config}" - ) - - if calibration_dataset is None and isinstance(quantization_config, (GPTQConfig, AwqConfig)): - raise ValueError( - "Weight-only quantization needs a calibration dataset for both GPTQ and AWQ methodologies." - ) - - if calibration_dataset is not None: - calibration_dataloader = self._get_calibration_dataloader( - calibration_dataset=calibration_dataset, - batch_size=batch_size, - remove_unused_columns=remove_unused_columns, - data_collator=data_collator, - use_label=not isinstance(quantization_config, (GPTQConfig)), - ) - quantization_config.calib_dataloader = calibration_dataloader - - elif INCQuantizationMode(quantization_config.approach) == INCQuantizationMode.STATIC: + if INCQuantizationMode(quantization_config.approach) == INCQuantizationMode.STATIC: # Since PyTorch fx trace does not really require an example_inputs, only need calibration_dataset or calibration_fn here. if calibration_dataset is None and self.calibration_fn is None: raise ValueError( @@ -270,56 +182,37 @@ def quantize( data_collator=data_collator, ) - if not isinstance(quantization_config, PostTrainingQuantConfig): - if use_cpu: - # will remove after intel-extension-for-transformers 1.3.3 release. - quantization_config.device = "cpu" - quantization_config.post_init_cpu() - elif use_xpu: - # will remove after intel-extension-for-transformers 1.3.3 release. 
- quantization_config.device = "xpu" - quantization_config.post_init_xpu() - - self._quantized_model = convert_to_quantized_model( - self._original_model, quantization_config, device=quantization_config.device - ) - - self._quantized_model.quantization_config = quantization_config - self._quantized_model.save_pretrained = types.MethodType(save_low_bit, self._quantized_model) - self._quantized_model.save_pretrained(save_directory) + if isinstance(self._original_model.config, PretrainedConfig): + self._original_model.config.backend = quantization_config.backend - else: - if isinstance(self._original_model.config, PretrainedConfig): - self._original_model.config.backend = quantization_config.backend - - compressed_model = fit( - self._original_model, - conf=quantization_config, - calib_dataloader=calibration_dataloader, - eval_func=self.eval_fn, - calib_func=self.calibration_fn, - ) + compressed_model = fit( + self._original_model, + conf=quantization_config, + calib_dataloader=calibration_dataloader, + eval_func=self.eval_fn, + calib_func=self.calibration_fn, + ) - if not hasattr(compressed_model, "_model") or compressed_model._model is None: - raise RuntimeError("Calling `neural_compressor.fit` returned unexpected results") + if not hasattr(compressed_model, "_model") or compressed_model._model is None: + raise RuntimeError("Calling `neural_compressor.fit` returned unexpected results") - if isinstance(self._original_model.config, PretrainedConfig): - # If backend is IPEX, then the quantized model is JIT model which will drop the config attribute, - # so need set config from original_model. - model_config = copy.deepcopy(self._original_model.config) - model_config.torch_dtype = "int8" - if isinstance(compressed_model, IPEXModel): - model_config.torchscript = True - model_config.backend = "ipex" - model_config.save_pretrained(save_directory) + if isinstance(self._original_model.config, PretrainedConfig): + # If backend is IPEX, then the quantized model is JIT model which will drop the config attribute, + # so need set config from original_model. 
+ model_config = copy.deepcopy(self._original_model.config) + model_config.torch_dtype = "int8" + if isinstance(compressed_model, IPEXModel): + model_config.torchscript = True + model_config.backend = "ipex" + model_config.save_pretrained(save_directory) - self._quantized_model = compressed_model._model + self._quantized_model = compressed_model._model - output_path = save_directory.joinpath(file_name or default_name) - # Save the quantized model - self._save_pretrained(compressed_model, output_path) - quantization_config = INCConfig(quantization=quantization_config) - quantization_config.save_pretrained(save_directory) + output_path = save_directory.joinpath(file_name or default_name) + # Save the quantized model + self._save_pretrained(compressed_model, output_path) + quantization_config = INCConfig(quantization=quantization_config) + quantization_config.save_pretrained(save_directory) @staticmethod def _save_pretrained(model: Union[PyTorchModel, IPEXModel], output_path: str): @@ -440,3 +333,77 @@ def _get_calibration_dataloader( def _remove_unused_columns(self, dataset: Dataset): ignored_columns = list(set(dataset.column_names) - set(self._signature_columns)) return dataset.remove_columns(ignored_columns) + + +def _weight_only_quantization( + model_class, + model_id: Union[str, Path], + quantization_config: Union[RtnConfig, GPTQConfig], + token: Optional[Union[bool, str]] = None, + revision: Optional[str] = None, + force_download: bool = False, + cache_dir: str = HUGGINGFACE_HUB_CACHE, + local_files_only: bool = False, + subfolder: str = "", + trust_remote_code: bool = False, + **kwargs, +): + device_map = kwargs.get("device_map", None) + if device_map is None: + device_map = "xpu" if (hasattr(torch, "xpu") and torch.xpu.is_available()) else "cpu" + else: + device_map = device_map.type if isinstance(device_map, torch.device) else device_map + + use_xpu = device_map == torch.device("xpu") or device_map == "xpu" + if use_xpu and (not hasattr(torch, "xpu") or not torch.xpu.is_available()): + raise AssertionError("There is no xpu device in this system!") + + if is_neural_compressor_version("<=", "3.0"): + raise AssertionError("Please install neural_compressor > v3.0") + if is_ipex_version("<", "2.3.1") and use_xpu: + raise AssertionError("Please install intel_extension_for_pytorch >= v2.3.1.") + + loading_kwargs = { + "subfolder": subfolder, + "revision": revision, + "cache_dir": cache_dir, + "token": token, + "local_files_only": local_files_only, + "force_download": force_download, + "trust_remote_code": trust_remote_code, + } + + low_cpu_mem_usage = True + if use_xpu: + try: + # TODO: if low_cpu_mem_uasge is True, gptj will have accuracy issue on CPU device. + model = model_class.from_pretrained( + model_id, low_cpu_mem_usage=low_cpu_mem_usage, device_map="cpu", **loading_kwargs + ) + except NotImplementedError: + logger.info( + "Failed to load models with `low_cpu_mem_usage=True`, will fall to traditional load method resulting in higher memory consumption." 
+ ) + low_cpu_mem_usage = False + model = model_class.from_pretrained(model_id, low_cpu_mem_usage=low_cpu_mem_usage, **loading_kwargs) + quantization_config.update(**{"device": "xpu"}) + quantization_config.post_init_xpu() + else: + model = model_class.from_pretrained(model_id, low_cpu_mem_usage=low_cpu_mem_usage, **loading_kwargs) + quantization_config.post_init_cpu() + + model.config.update({"low_cpu_mem_usage": low_cpu_mem_usage}) + model.eval() + + if (not torch.cuda.is_available() or device_map == "cpu") and model.config.model_type == "chatglm": + model = model.float() + + model = convert_to_quantized_model(model, quantization_config, device=device_map) + quantization_config.remove_redundant_parameters() + model.config.quantization_config = quantization_config + # add quantization_config and save_low_bit to pretrained model dynamically + model.device_map = device_map + model.quantization_config = quantization_config + model.save_pretrained = types.MethodType(save_low_bit, model) + + return model diff --git a/optimum/intel/neural_compressor/utils.py b/optimum/intel/neural_compressor/utils.py index 84c1d6dc2..c7a7ceda7 100644 --- a/optimum/intel/neural_compressor/utils.py +++ b/optimum/intel/neural_compressor/utils.py @@ -32,9 +32,7 @@ NEURAL_COMPRESSOR_MINIMUM_VERSION = "2.1.0" NEURAL_COMPRESSOR_WEIGHT_ONLY_MINIMUM_VERSION = "2.3.0" -IPEX_MINIMUM_VERSION = "2.1.0" -ITREX_MINIMUM_VERSION = "1.4.0" -ITREX_MINIMUM_TORCH_VERSION = "2.2.0" +IPEX_MINIMUM_VERSION = "2.3.1" _HEAD_TO_AUTOMODELS = { diff --git a/optimum/intel/utils/import_utils.py b/optimum/intel/utils/import_utils.py index 8024d2389..2afe80d55 100644 --- a/optimum/intel/utils/import_utils.py +++ b/optimum/intel/utils/import_utils.py @@ -62,16 +62,6 @@ _neural_compressor_available = False -_itrex_available = importlib.util.find_spec("intel_extension_for_transformers") is not None -_itrex_version = "N/A" -if _itrex_available: - try: - _itrex_version = importlib_metadata.version("intel_extension_for_transformers") - logging.warn("`transformers` version >= 4.31 is requirements by intel-extension-for-transformers.") - except importlib_metadata.PackageNotFoundError: - _itrex_available = False - - _ipex_available = importlib.util.find_spec("intel_extension_for_pytorch") is not None _ipex_version = "N/A" if _ipex_available: @@ -177,10 +167,6 @@ def is_neural_compressor_available(): return _neural_compressor_available -def is_itrex_available(): - return _itrex_available - - def is_ipex_available(): return _ipex_available @@ -341,15 +327,6 @@ def is_neural_compressor_version(operation: str, version: str): return compare_versions(parse(_neural_compressor_version), operation, version) -def is_itrex_version(operation: str, version: str): - """ - Compare the current intel_extension_for_transformers version to a given reference with an operation. - """ - if not _itrex_available: - return False - return compare_versions(parse(_itrex_version), operation, version) - - def is_openvino_version(operation: str, version: str): """ Compare the current OpenVINO version to a given reference with an operation. @@ -432,11 +409,6 @@ def is_datasets_version(operation: str, version: str): `pip install neural-compressor`. Please note that you may need to restart your runtime after installation. """ -ITREX_IMPORT_ERROR = """ -{0} requires the intel-extension-for-transformers library but it was not found in your environment. You can install it with pip: -`pip install intel-extension-for-transformers` and `pip install peft`. 
Please note that you may need to restart your runtime after installation. -""" - DATASETS_IMPORT_ERROR = """ {0} requires the datasets library but it was not found in your environment. You can install it with pip: `pip install datasets`. Please note that you may need to restart your runtime after installation. @@ -454,7 +426,6 @@ def is_datasets_version(operation: str, version: str): ("nncf", (is_nncf_available, NNCF_IMPORT_ERROR)), ("openvino", (is_openvino_available, OPENVINO_IMPORT_ERROR)), ("neural_compressor", (is_neural_compressor_available, NEURAL_COMPRESSOR_IMPORT_ERROR)), - ("itrex", (is_itrex_available, ITREX_IMPORT_ERROR)), ("accelerate", (is_accelerate_available, ACCELERATE_IMPORT_ERROR)), ] ) diff --git a/setup.py b/setup.py index cd488f830..bd1939925 100644 --- a/setup.py +++ b/setup.py @@ -59,7 +59,7 @@ QUALITY_REQUIRE = ["black~=23.1", "ruff==0.4.4"] EXTRAS_REQUIRE = { - "neural-compressor": ["neural-compressor>=2.2.0,<3.0", "accelerate", "transformers<4.43"], + "neural-compressor": ["neural-compressor[pt]>3.0", "accelerate", "transformers<=4.43.2"], "openvino": ["openvino>=2023.3", "nncf>=2.11.0", "openvino-tokenizers[transformers]"], "nncf": ["nncf>=2.11.0"], "ipex": ["intel-extension-for-pytorch", "transformers>=4.39,<4.45"], diff --git a/tests/neural_compressor/test_modeling.py b/tests/neural_compressor/test_modeling.py index 0c3e60969..81e6d03dc 100644 --- a/tests/neural_compressor/test_modeling.py +++ b/tests/neural_compressor/test_modeling.py @@ -40,7 +40,6 @@ INCTrainer, ) from optimum.intel.neural_compressor.utils import _HEAD_TO_AUTOMODELS, QUANTIZATION_CONFIG_NAME, WEIGHTS_NAME -from optimum.intel.utils.import_utils import is_itrex_available os.environ["CUDA_VISIBLE_DEVICES"] = "" @@ -147,25 +146,25 @@ def test_compare_with_and_without_past_key_values(self): self.assertEqual(outputs_without_pkv.shape[1], self.GENERATION_LENGTH) self.assertTrue(torch.equal(outputs_with_pkv, outputs_without_pkv)) - @unittest.skipIf(not is_itrex_available(), reason="ITREX not available") - def test_saving_loading_woq_itrex_model(self): - model_name = "echarlaix/tiny-random-PhiForCausalLM" - subfolder = "itrex" - model = INCModelForCausalLM.from_pretrained(model_name, revision="itrex", subfolder=subfolder) - tokenizer = AutoTokenizer.from_pretrained(model_name, revision="itrex") + def test_saving_loading_inc_woq_model(self): + model_name = "TheBloke/TinyLlama-1.1B-Chat-v1.0-GPTQ" + model = INCModelForCausalLM.from_pretrained(model_name, revision="main") + tokenizer = AutoTokenizer.from_pretrained(model_name, revision="main") tokenizer.add_special_tokens({"pad_token": "[PAD]"}) tokens = tokenizer("This is a sample output", return_tensors="pt") + with torch.no_grad(): + outputs = model(**tokens) + with tempfile.TemporaryDirectory() as tmp_dir: - model_save_dir = Path(tmp_dir) / subfolder + model_save_dir = Path(tmp_dir) / "inc_woq" model.save_pretrained(model_save_dir) folder_contents = os.listdir(model_save_dir) self.assertIn(SAFE_WEIGHTS_NAME, folder_contents) self.assertIn(QUANTIZATION_CONFIG_NAME, folder_contents) - loaded_model = INCModelForCausalLM.from_pretrained(tmp_dir, subfolder=subfolder) + loaded_model = INCModelForCausalLM.from_pretrained(model_save_dir) with torch.no_grad(): - outputs = model(**tokens) loaded_outputs = loaded_model(**tokens) self.assertTrue("logits" in loaded_outputs) diff --git a/tests/neural_compressor/test_optimization.py b/tests/neural_compressor/test_optimization.py index c1ac71e93..75f2845c7 100644 --- a/tests/neural_compressor/test_optimization.py 
+++ b/tests/neural_compressor/test_optimization.py @@ -45,7 +45,7 @@ set_seed, ) from utils_tests import MODEL_NAMES, SEED, INCTestMixin, _generate_dataset -from optimum.intel.utils.import_utils import is_torch_version, is_itrex_available +from optimum.intel.utils.import_utils import is_torch_version from optimum.intel import ( INCConfig, @@ -467,50 +467,49 @@ def _compute_metrics(pred): class WeightOnlyQuantizationTest(INCTestMixin): WEIGHT_ONLY_CONFIG = ( - ("rtn", "int4_clip"), - ("rtn", "int8"), - ("gptq", "int4_clip"), + ("rtn", 4), + ("gptq", 4), ) @parameterized.expand(WEIGHT_ONLY_CONFIG) - @unittest.skipIf(not is_itrex_available(), reason="ITREX not available") - def test_weight_only_quantization(self, methodology, weight_dtype): - model_name = "hf-internal-testing/tiny-random-GPTNeoForCausalLM" - - from intel_extension_for_transformers.transformers.utils.config import GPTQConfig, RtnConfig + def test_weight_only_quantization(self, methodology, bits): + from neural_compressor.transformers import GPTQConfig, RtnConfig - bits = 4 if "4" in weight_dtype else 8 + model_name = "hf-internal-testing/tiny-random-GPTNeoForCausalLM" if methodology == "gptq": - # max_input_length can be removed after neural-compressor > v2.5.1 + tokenizer = AutoTokenizer.from_pretrained(model_name) quantization_config = GPTQConfig( - bits=bits, sym=True, damp_percent=0.01, weight_dtype=weight_dtype, max_input_length=128 + bits=bits, + sym=True, + damp_percent=0.01, + desc_act=True, + tokenizer=tokenizer, + n_samples=20, + group_size=8, + batch_size=5, + seq_len=32, + block_size=16, ) else: - quantization_config = RtnConfig(bits=bits, weight_dtype=weight_dtype) + quantization_config = RtnConfig(bits=bits, group_size=8) - model = AutoModelForCausalLM.from_pretrained(model_name) tokenizer = AutoTokenizer.from_pretrained(model_name) tokenizer.add_special_tokens({"pad_token": "[PAD]"}) - quantizer = INCQuantizer.from_pretrained(copy.deepcopy(model), task="text-generation") - calibration_dataset = _generate_dataset(quantizer, tokenizer, num_samples=2) + tokens = tokenizer("This is a sample output", return_tensors="pt") with tempfile.TemporaryDirectory() as tmp_dir: - quantizer.quantize( - quantization_config=quantization_config, - calibration_dataset=calibration_dataset, - save_directory=tmp_dir, - ) - loaded_model = INCModelForCausalLM.from_pretrained(tmp_dir) - - tokens = tokenizer("This is a sample output", return_tensors="pt") + quantized_model = INCModelForCausalLM.from_pretrained(model_name, quantization_config=quantization_config) + with torch.no_grad(): + quantizer_outputs = quantized_model(**tokens) + quantized_model.save_pretrained(tmp_dir) + loaded_model = INCModelForCausalLM.from_pretrained(tmp_dir) with torch.no_grad(): loaded_outputs = loaded_model(**tokens) - # quantizer_outputs = model(**tokens) self.assertTrue("logits" in loaded_outputs) self.assertIsInstance(loaded_outputs.logits, torch.Tensor) self.assertTrue("past_key_values" in loaded_outputs) self.assertIsInstance(loaded_outputs.past_key_values, tuple) - # self.assertTrue(torch.allclose(quantizer_outputs.logits, loaded_outputs.logits, equal_nan=True, atol=1e-4)) + self.assertTrue(torch.allclose(quantizer_outputs.logits, loaded_outputs.logits, equal_nan=True, atol=1e-4))
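
Note for reviewers: below is a minimal sketch of the weight-only quantization flow this patch enables, adapted from the updated `tests/neural_compressor/test_optimization.py`. The model id and the output directory are illustrative placeholders, not values prescribed by the patch; it assumes `neural-compressor[pt] > 3.0` is installed as required by the updated `setup.py`.

```python
# Sketch of the new INC weight-only quantization path (RTN), assuming neural-compressor > 3.0.
import torch
from neural_compressor.transformers import RtnConfig
from transformers import AutoTokenizer

from optimum.intel.neural_compressor import INCModelForCausalLM

model_id = "hf-internal-testing/tiny-random-GPTNeoForCausalLM"  # example model used in the tests
quantization_config = RtnConfig(bits=4, group_size=8)

# Quantize on load, run a forward pass, then save and reload the low-bit checkpoint.
model = INCModelForCausalLM.from_pretrained(model_id, quantization_config=quantization_config)
tokenizer = AutoTokenizer.from_pretrained(model_id)
tokens = tokenizer("This is a sample input", return_tensors="pt")
with torch.no_grad():
    outputs = model(**tokens)

model.save_pretrained("quantized_model")  # hypothetical output directory
loaded_model = INCModelForCausalLM.from_pretrained("quantized_model")
```

The same entry point accepts a `GPTQConfig` (with a tokenizer and calibration parameters such as `n_samples` and `seq_len`) in place of `RtnConfig`, as exercised by the `gptq` case in the updated test.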