Apply weight compression after model save to reduce peak RAM during export (huggingface#878)

* Initial commit

* Style

* Adopt tests

* Add no-nncf warning

* Apply suggested changes

* Do not save in fp16 in case of weight compression

* Replace model files right away
nikita-savelyevv committed Aug 30, 2024
1 parent d6e6e1f commit b5998f2
Showing 3 changed files with 123 additions and 103 deletions.
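To put the change in context, here is a minimal usage sketch (not part of this commit; the model id and output directory are illustrative placeholders). With this change, models above the `_MAX_UNCOMPRESSED_SIZE` threshold are still compressed to int8 by default, but the NNCF compression now runs on the already-saved OpenVINO IR files instead of on the in-memory model, lowering peak RAM during export.

# Illustrative sketch only: exporting a large causal LM so that the IR is written
# to disk first and then weight-compressed to int8 by reading it back from disk.
from optimum.intel import OVModelForCausalLM

model = OVModelForCausalLM.from_pretrained(
    "your-org/your-7b-model",  # placeholder model id
    export=True,               # convert the PyTorch checkpoint to OpenVINO IR
    compile=False,
)
model.save_pretrained("exported-model-int8")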
50 changes: 48 additions & 2 deletions optimum/exporters/openvino/__main__.py
@@ -14,7 +14,9 @@

import gc
import logging
import operator
import warnings
from functools import reduce
from pathlib import Path
from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Union

@@ -23,18 +25,20 @@
from transformers import AutoConfig, AutoTokenizer, PreTrainedTokenizerBase
from transformers.utils import is_torch_available

from openvino.runtime import Core, Type, save_model
from optimum.exporters import TasksManager
from optimum.exporters.onnx.base import OnnxConfig
from optimum.exporters.onnx.constants import SDPA_ARCHS_ONNX_EXPORT_NOT_SUPPORTED
from optimum.exporters.openvino.convert import export_from_model
from optimum.intel.utils.import_utils import (
is_nncf_available,
is_openvino_tokenizers_available,
is_openvino_version,
is_transformers_version,
)
from optimum.utils.save_utils import maybe_load_preprocessors

from .utils import clear_class_registry
from .utils import _MAX_UNCOMPRESSED_SIZE, clear_class_registry


if TYPE_CHECKING:
@@ -402,7 +406,7 @@ class StoreAttr(object):
model_name_or_path, subfolder=subfolder, trust_remote_code=trust_remote_code
)

export_from_model(
submodel_paths = export_from_model(
model=model,
output=output,
task=task,
@@ -425,6 +429,48 @@ class StoreAttr(object):
del model
gc.collect()

core = Core()
for submodel_path in submodel_paths:
submodel_path = Path(output) / submodel_path
submodel = core.read_model(submodel_path)

quantization_config = None
if ov_config is None:
num_parameters = 0
for op in submodel.get_ops():
if op.get_type_name() == "Constant" and op.get_element_type() in [Type.f16, Type.f32, Type.bf16]:
num_parameters += reduce(operator.mul, op.shape, 1)
if num_parameters >= _MAX_UNCOMPRESSED_SIZE:
if is_nncf_available():
quantization_config = {"bits": 8, "sym": False}
logger.info("The model weights will be quantized to int8_asym.")
else:
logger.warning(
"The model will be converted with no weights quantization. Quantization of the weights to int8 "
"requires nncf. Please install it with `pip install nncf`"
)
break
else:
quantization_config = ov_config.quantization_config
if quantization_config is None:
continue

if not is_nncf_available():
raise ImportError("Quantization of the weights requires nncf, please install it with `pip install nncf`")

from optimum.intel.openvino.quantization import _weight_only_quantization

_weight_only_quantization(submodel, quantization_config)

compressed_submodel_path = submodel_path.parent / f"{submodel_path.stem}_compressed.xml"
save_model(submodel, compressed_submodel_path, compress_to_fp16=False)
del submodel

submodel_path.unlink()
submodel_path.with_suffix(".bin").unlink()
compressed_submodel_path.rename(submodel_path)
compressed_submodel_path.with_suffix(".bin").rename(submodel_path.with_suffix(".bin"))

# Unpatch modules after GPTQ export
if do_gptq_patching:
torch.cuda.is_available = orig_cuda_check
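For orientation, a condensed, self-contained sketch of the flow the weight-compression block added to main_export above implements (the helper name and path are illustrative; the real code also selects the quantization config and checks that nncf is installed): the saved IR is read back, its weights are compressed with NNCF, and the compressed files replace the original .xml/.bin pair right away.

# Condensed sketch of the compress-after-save flow, assuming openvino and nncf
# are installed and `xml_path` points at an already exported IR file.
import nncf
from pathlib import Path
from openvino.runtime import Core, save_model


def compress_saved_ir(xml_path: Path) -> None:
    model = Core().read_model(xml_path)  # read the exported submodel back from disk
    # int8 asymmetric weight-only compression, i.e. {"bits": 8, "sym": False}
    compressed = nncf.compress_weights(model, mode=nncf.CompressWeightsMode.INT8_ASYM)
    tmp_xml = xml_path.parent / f"{xml_path.stem}_compressed.xml"
    # fp16 compression is skipped because the weights are already int8
    save_model(compressed, tmp_xml, compress_to_fp16=False)
    # replace the original .xml/.bin pair right away
    xml_path.unlink()
    xml_path.with_suffix(".bin").unlink()
    tmp_xml.rename(xml_path)
    tmp_xml.with_suffix(".bin").rename(xml_path.with_suffix(".bin"))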
38 changes: 3 additions & 35 deletions optimum/exporters/openvino/convert.py
@@ -49,7 +49,6 @@
from .model_patcher import patch_model_with_bettertransformer
from .stateful import ensure_export_task_support_stateful, ensure_stateful_is_available, patch_stateful
from .utils import (
_MAX_UNCOMPRESSED_SIZE,
OV_XML_FILE_NAME,
clear_class_registry,
flattenize_inputs,
@@ -76,21 +75,7 @@


def _save_model(model, path: str, ov_config: Optional["OVConfig"] = None, library_name: Optional[str] = None):
compress_to_fp16 = False

if ov_config is not None:
if ov_config.quantization_config:
if not is_nncf_available():
raise ImportError(
"Quantization of the weights to int8 requires nncf, please install it with `pip install nncf`"
)

from optimum.intel.openvino.quantization import _weight_only_quantization

_weight_only_quantization(model, ov_config.quantization_config)

compress_to_fp16 = ov_config.dtype == "fp16"

compress_to_fp16 = ov_config is not None and ov_config.dtype == "fp16"
model = _add_version_info_to_model(model, library_name)
save_model(model, path, compress_to_fp16)

@@ -643,25 +628,6 @@ def export_from_model(
)
logging.disable(logging.NOTSET)

if ov_config is None:
if library_name == "diffusers":
num_parameters = model.unet.num_parameters()
else:
num_parameters = sum(param.numel() for param in list(model.parameters()) if param.requires_grad)

if num_parameters >= _MAX_UNCOMPRESSED_SIZE:
if is_nncf_available():
from ...intel.openvino.configuration import OVConfig

ov_config = OVConfig(quantization_config={"bits": 8, "sym": False})

logger.info("The model weights will be quantized to int8_asym.")
else:
logger.warning(
"The model will be converted with no weights quantization. Quantization of the weights to int8 requires nncf."
"please install it with `pip install nncf`"
)

if library_name != "diffusers":
# Saving the model config and preprocessor as this is needed sometimes.
model.config.save_pretrained(output)
@@ -720,6 +686,8 @@ def export_from_model(
patch_16bit_model=patch_16bit_model,
)

return files_subpaths


def export_tokenizer(
tokenizer,
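With this change, export_from_model hands the exported submodel file names back to the caller so main_export can post-process them. A small hypothetical snippet of how that return value can be consumed (model, output, task, and ov_config are assumed to already be in scope; the printed inspection step is illustrative, not part of the PR):

# Hypothetical caller-side use of the new return value.
from pathlib import Path

from optimum.exporters.openvino.convert import export_from_model

submodel_paths = export_from_model(model=model, output=output, task=task, ov_config=ov_config)
for subpath in submodel_paths:
    ir_xml = Path(output) / subpath                     # e.g. "openvino_model.xml"
    ir_bin = ir_xml.with_suffix(".bin")                 # weights live next to the .xml
    print(ir_xml.name, ir_bin.stat().st_size, "bytes")  # inspect the saved submodel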
138 changes: 72 additions & 66 deletions tests/openvino/test_quantization.py
@@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect

# ruff: noqa

@@ -22,6 +23,7 @@
from enum import Enum
from functools import partial
from typing import Union

import pytest
import evaluate
import numpy as np
@@ -538,76 +540,80 @@ def test_ovmodel_load_with_uncompressed_weights(self, model_cls, model_type):
self.assertEqual(0, num_int8)

def test_ovmodel_load_large_model_with_default_compressed_weights(self):
with unittest.mock.patch("torch.nn.Module.parameters") as model_parameters:
mock_tensor = unittest.mock.Mock()
mock_tensor.numel = lambda: 2000000000
mock_tensor.requires_grad = True
model_parameters.return_value = [mock_tensor]
with unittest.mock.patch("openvino.runtime.ie_api.Core.read_model") as core_patch:
with unittest.mock.patch("optimum.exporters.openvino.convert._save_model") as save_model_patch:
_ = OVModelForCausalLM.from_pretrained(
MODEL_NAMES["llama"], export=True, compile=False, use_cache=False
)
save_model_patch.assert_called_with(
unittest.mock.ANY,
unittest.mock.ANY,
ov_config=OVConfig(quantization_config={"bits": 8}),
library_name="transformers",
)
def main_export_in_stacktrace(*args, **kwargs):
# Compression was called from `main_export`
self.assertTrue(inspect.stack()[5].function == "main_export")

with unittest.mock.patch(
"openvino.runtime.op.Constant.shape", new_callable=unittest.mock.PropertyMock
) as ov_constant_shape:
ov_constant_shape.return_value = (2000000000,)
with unittest.mock.patch(
"nncf.compress_weights", side_effect=main_export_in_stacktrace
) as compress_weights_patch:
_ = OVModelForCausalLM.from_pretrained(
MODEL_NAMES["llama"], export=True, compile=False, use_cache=False
)
compression_params = {
"mode": nncf.CompressWeightsMode.INT8_ASYM,
"ratio": 1.0,
"group_size": -1,
"all_layers": None,
"sensitivity_metric": None,
"dataset": None,
"ignored_scope": nncf.IgnoredScope(),
"awq": None,
"subset_size": 128,
"scale_estimation": None,
}
compress_weights_patch.assert_called_with(
unittest.mock.ANY,
**compression_params,
)

def test_ovmodel_load_large_model_with_uncompressed_weights(self):
with unittest.mock.patch("torch.nn.Module.parameters") as model_parameters:
mock_tensor = unittest.mock.Mock()
mock_tensor.numel = lambda: 2000000000
mock_tensor.requires_grad = True
model_parameters.return_value = [mock_tensor]
with unittest.mock.patch("openvino.runtime.ie_api.Core.read_model") as core_patch:
with unittest.mock.patch("optimum.exporters.openvino.convert._save_model") as save_model_patch:
_ = OVModelForCausalLM.from_pretrained(
MODEL_NAMES["llama"], export=True, load_in_8bit=False, compile=False, use_cache=False
)
save_model_patch.assert_called_with(
unittest.mock.ANY,
unittest.mock.ANY,
ov_config=OVConfig(dtype="auto"),
library_name="transformers",
)
with unittest.mock.patch(
"openvino.runtime.op.Constant.shape", new_callable=unittest.mock.PropertyMock
) as ov_constant_shape:
ov_constant_shape.return_value = (2000000000,)
with unittest.mock.patch("nncf.compress_weights") as compress_weights_patch:
_ = OVModelForCausalLM.from_pretrained(
MODEL_NAMES["llama"], export=True, load_in_8bit=False, compile=False, use_cache=False
)
compress_weights_patch.assert_not_called()

def test_ovmodel_load_large_model_with_additional_quantization_config(self):
with unittest.mock.patch("torch.nn.Module.parameters") as model_parameters:
mock_tensor = unittest.mock.Mock()
mock_tensor.numel = lambda: 2000000000
mock_tensor.requires_grad = True
with unittest.mock.patch("openvino.runtime.ie_api.Core.read_model") as core_patch:
with unittest.mock.patch("optimum.exporters.openvino.convert._save_model") as save_model_patch:
with unittest.mock.patch("nncf.compress_weights") as compress_weights_patch:
_ = OVModelForCausalLM.from_pretrained(
MODEL_NAMES["llama"],
export=True,
compile=False,
use_cache=False,
quantization_config=OVWeightQuantizationConfig(bits=4, sym=True, group_size=-1, ratio=0.8),
)
# quantization will be performed later, using load_model
save_model_patch.assert_called_with(
unittest.mock.ANY,
unittest.mock.ANY,
ov_config=OVConfig(dtype="auto"),
library_name="transformers",
)
compression_params = {
"mode": nncf.CompressWeightsMode.INT4_SYM,
"ratio": 0.8,
"group_size": -1,
"all_layers": None,
"sensitivity_metric": None,
"dataset": None,
"ignored_scope": nncf.IgnoredScope(),
"awq": None,
"subset_size": 128,
"scale_estimation": None,
}
compress_weights_patch.assert_called_with(unittest.mock.ANY, **compression_params)
def main_export_not_in_stacktrace(*args, **kwargs):
# Compression was not called from `main_export`
self.assertTrue(all(frame_info.function != "main_export" for frame_info in inspect.stack()))

with unittest.mock.patch(
"openvino.runtime.op.Constant.shape", new_callable=unittest.mock.PropertyMock
) as ov_constant_shape:
ov_constant_shape.return_value = (2000000000,)
with unittest.mock.patch(
"nncf.compress_weights", side_effect=main_export_not_in_stacktrace
) as compress_weights_patch:
_ = OVModelForCausalLM.from_pretrained(
MODEL_NAMES["llama"],
export=True,
compile=False,
use_cache=False,
quantization_config=OVWeightQuantizationConfig(bits=4, sym=True, group_size=-1, ratio=0.8),
)
compression_params = {
"mode": nncf.CompressWeightsMode.INT4_SYM,
"ratio": 0.8,
"group_size": -1,
"all_layers": None,
"sensitivity_metric": None,
"dataset": None,
"ignored_scope": nncf.IgnoredScope(),
"awq": None,
"subset_size": 128,
"scale_estimation": None,
}
compress_weights_patch.assert_called_with(unittest.mock.ANY, **compression_params)

@parameterized.expand(LOAD_IN_4_BITS_SCOPE)
def test_ovmodel_4bit_dynamic_with_config(self, model_cls, model_name, quantization_config, expected_ov_int4):
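For reviewers, a trimmed illustration of the mocking pattern the updated tests rely on (assumed to run inside a unittest test method, with openvino and nncf installed): Constant.shape is patched so the freshly saved IR appears larger than the compression threshold, and nncf.compress_weights is patched with a side effect that inspects the call stack to verify where compression is now triggered from.

# Trimmed sketch of the test mocking pattern; the export call itself is elided.
import inspect
import unittest.mock


def called_from_main_export(*args, **kwargs):
    # After this change, the compression call should originate from `main_export`.
    assert any(frame.function == "main_export" for frame in inspect.stack())


with unittest.mock.patch(
    "openvino.runtime.op.Constant.shape", new_callable=unittest.mock.PropertyMock
) as constant_shape:
    constant_shape.return_value = (2_000_000_000,)  # every constant reports a huge shape
    with unittest.mock.patch(
        "nncf.compress_weights", side_effect=called_from_main_export
    ) as compress_weights_patch:
        # the export call goes here, e.g. OVModelForCausalLM.from_pretrained(..., export=True)
        pass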
