From 48e6f05a0046a5c7416bd616168a66cd326909a7 Mon Sep 17 00:00:00 2001
From: Alexander Dokuchaev
Date: Tue, 18 Feb 2025 21:23:13 +0200
Subject: [PATCH 1/6] init

---
 .../weight_compression/torch_backend.py      |  89 ++--
 nncf/quantization/quantize_model.py          |  31 +-
 nncf/torch/quantization/quantize_model.py    |  11 +-
 .../quantization/test_weights_compression.py | 432 ++++++++++++++++++
 4 files changed, 523 insertions(+), 40 deletions(-)
 create mode 100644 tests/torch2/function_hook/quantization/test_weights_compression.py

diff --git a/nncf/quantization/algorithms/weight_compression/torch_backend.py b/nncf/quantization/algorithms/weight_compression/torch_backend.py
index 8cf5422eee2..7822db2504e 100644
--- a/nncf/quantization/algorithms/weight_compression/torch_backend.py
+++ b/nncf/quantization/algorithms/weight_compression/torch_backend.py
@@ -9,7 +9,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import Callable, Dict, Iterable, List, Optional, Tuple
+from typing import Callable, Dict, Iterable, List, Optional, Tuple, Union
 
 import torch
 
@@ -22,6 +22,7 @@
 from nncf.common.graph.transformations.commands import TargetType
 from nncf.common.graph.transformations.layout import TransformationLayout
 from nncf.common.tensor_statistics.statistic_point import StatisticPoint
+from nncf.experimental.common.check_feature import is_experimental_torch_tracing_enabled
 from nncf.experimental.common.tensor_statistics.collectors import MaxVarianceReducer
 from nncf.experimental.common.tensor_statistics.collectors import MeanAbsMaxReducer
 from nncf.experimental.common.tensor_statistics.collectors import MeanAggregator
@@ -34,6 +35,9 @@
 from nncf.experimental.common.tensor_statistics.statistics import MeanMagnitudeTensorStatistic
 from nncf.experimental.common.tensor_statistics.statistics import MeanVarianceTensorStatistic
 from nncf.experimental.common.tensor_statistics.statistics import WCTensorStatistic
+from nncf.experimental.torch2.commands import PT2InsertionCommand
+from nncf.experimental.torch2.function_hook.nncf_graph.nncf_graph_builder import GraphModelWrapper
+from nncf.experimental.torch2.model_transformer import PT2ModelTransformer
 from nncf.parameters import CompressWeightsMode
 from nncf.quantization.algorithms.weight_compression.backend import MixedPrecisionAlgoBackend
 from nncf.quantization.algorithms.weight_compression.backend import WeightCompressionAlgoBackend
@@ -179,8 +183,14 @@ def get_activation_port_id(node: NNCFNode, graph: NNCFGraph) -> int:
         return activation_ports[0]
 
     def get_weight(
-        self, node_with_weight: NNCFNode, weight_port_id: int, model: torch.nn.Module, graph: NNCFGraph
+        self,
+        node_with_weight: NNCFNode,
+        weight_port_id: int,
+        model: Union[GraphModelWrapper, torch.nn.Module],
+        graph: NNCFGraph,
     ) -> Tensor:
+        if isinstance(model, GraphModelWrapper):
+            model = model.model
         weight_node = get_const_node(node_with_weight, weight_port_id, graph)
         weight_name = weight_node.layer_attributes.name
         weight = get_const_data(weight_node, model)
@@ -190,7 +200,11 @@ def get_weight(
         return Tensor(weight)
 
     def get_weight_dtype(
-        self, node_with_weight: NNCFNode, weight_port_id: int, model: torch.nn.Module, graph: NNCFGraph
+        self,
+        node_with_weight: NNCFNode,
+        weight_port_id: int,
+        model: Union[GraphModelWrapper, torch.nn.Module],
+        graph: NNCFGraph,
     ) -> TensorDataType:
         return self.get_weight(node_with_weight, weight_port_id, model, graph).dtype
 
@@ -222,13 +236,19 @@ def filter_func(point: StatisticPoint) -> bool:
 
     def transform_model(
         self,
-        model: NNCFNetwork,
+        model: Union[GraphModelWrapper, torch.nn.Module],
         graph: NNCFGraph,
         weight_compression_parameters: Iterable[WeightCompressionParameters],
         precomputed_scales: Dict[str, Tensor] = None,
         precomputed_zero_points: Dict[str, Tensor] = None,
         lora_correction_algo: LoraCorrectionAlgorithm = None,
     ) -> NNCFNetwork:
+        if isinstance(model, GraphModelWrapper):
+            model_transformer = PT2ModelTransformer(model)
+            model = model.model
+        else:
+            model_transformer = PTModelTransformer(model)
+
         transformation_layout = TransformationLayout()
 
         for wc_params in weight_compression_parameters:
@@ -284,38 +304,55 @@ def transform_model(
 
             # sets compressed tensor
             # TODO:(AlexanderDokuchaev): update set_const_data
-            compressed_parameter = torch.nn.Parameter(packed_tensor, requires_grad=False)
             module_name, weight_attr_name = split_const_name(weight_name)
             module = get_module_by_name(module_name, model)
             weight = getattr(module, weight_attr_name)
+
             if not isinstance(weight, torch.nn.Parameter):
                 msg = f"Weight is not a torch.nn.Parameter in the model by name {weight_name}."
                 raise nncf.InternalError(msg)
 
-            setattr(module, weight_attr_name, compressed_parameter)
-
-            consumer_nodes = graph.get_next_nodes(weight_node)
-            if len(consumer_nodes) > 1:
-                for c_node in consumer_nodes:
-                    c_module = model.nncf.get_module_by_scope(Scope.from_str(c_node.layer_name))
-                    for name, param in c_module.named_parameters(recurse=False, remove_duplicate=False):
-                        if id(param) == id(weight):
-                            setattr(c_module, name, compressed_parameter)
-
-            # registry weight decompression module in the model
-            decompressor_name = f"weights_decompressor_{weight_node.node_name.replace('.', '_')}"
-
-            # inserts the weight decompressor into the model as the post hook on the model weight
-            transformation_layout.register(
-                PTSharedFnInsertionCommand(
-                    [PTTargetPoint(TargetType.OPERATOR_POST_HOOK, target_node_name=weight_node.node_name)],
-                    decompressor,
-                    decompressor_name,
+            if is_experimental_torch_tracing_enabled():
+                weight.requires_grad = False
+                weight.data = packed_tensor
+
+                transformation_layout.register(
+                    PT2InsertionCommand(
+                        [
+                            PTTargetPoint(
+                                TargetType.OPERATOR_POST_HOOK, target_node_name=weight_node.node_name.replace(".", ":")
+                            )
+                        ],
+                        decompressor,
+                    )
+                )
+            else:
+                compressed_parameter = torch.nn.Parameter(packed_tensor, requires_grad=False)
+
+                setattr(module, weight_attr_name, compressed_parameter)
+
+                consumer_nodes = graph.get_next_nodes(weight_node)
+                if len(consumer_nodes) > 1:
+                    for c_node in consumer_nodes:
+                        c_module = model.nncf.get_module_by_scope(Scope.from_str(c_node.layer_name))
+                        for name, param in c_module.named_parameters(recurse=False, remove_duplicate=False):
+                            if id(param) == id(weight):
+                                setattr(c_module, name, compressed_parameter)
+
+                # register the weight decompression module in the model
+                decompressor_name = f"weights_decompressor_{weight_node.node_name.replace('.', '_')}"
+
+                # insert the weight decompressor into the model as a post-hook on the model weight
+                transformation_layout.register(
+                    PTSharedFnInsertionCommand(
+                        [PTTargetPoint(TargetType.OPERATOR_POST_HOOK, target_node_name=weight_node.node_name)],
+                        decompressor,
+                        decompressor_name,
+                    )
                 )
-            )
 
         # apply transformations
-        transformed_model = PTModelTransformer(model).transform(transformation_layout)
+        transformed_model = model_transformer.transform(transformation_layout)
 
         return transformed_model
 
diff --git a/nncf/quantization/quantize_model.py b/nncf/quantization/quantize_model.py
index 9b7cdb6edcb..5fc594f0fb0 100644
--- a/nncf/quantization/quantize_model.py
+++ b/nncf/quantization/quantize_model.py
@@ -510,8 +510,9 @@ def compress_weights(
     compression_weights_impl = None
 
     if backend == BackendType.TORCH:
+        from nncf.experimental.common.check_feature import is_experimental_torch_tracing_enabled
+        from nncf.experimental.torch2.function_hook.nncf_graph.nncf_graph_builder import GraphModelWrapper
         from nncf.torch.model_creation import is_wrapped_model
-        from nncf.torch.model_creation import wrap_model
         from nncf.torch.quantization.quantize_model import compress_weights_impl as pt_compression_weights_impl
 
         if mode in [CompressWeightsMode.NF4, CompressWeightsMode.E2M1]:
@@ -532,21 +533,29 @@ def compress_weights(
             msg = "Torch does not support statistics caching."
             raise nncf.ParameterNotSupportedError(msg)
 
-        if is_wrapped_model(model):
-            if not model.nncf.trace_parameters:
-                msg = (
-                    "Tracing capabilities with tracing parameters are required in the PyTorch model "
-                    "for nncf.compress_weights(). Please wrap the model using "
-                    "nncf.torch.wrap_model(model, example_input, trace_parameters=True) before calling "
-                    "nncf.compress_weights()."
-                )
-                raise nncf.ValidationError(msg)
+        if is_wrapped_model(model) and not model.nncf.trace_parameters:
+            msg = (
+                "Tracing capabilities with tracing parameters are required in the PyTorch model "
+                "for nncf.compress_weights(). Please wrap the model using "
+                "nncf.torch.wrap_model(model, example_input, trace_parameters=True) before calling "
+                "nncf.compress_weights()."
+            )
+            raise nncf.ValidationError(msg)
+        if isinstance(model, GraphModelWrapper):
+            pass
         elif dataset is None:
             msg = "Please provide a dataset of at least one element for PyTorch model tracing."
             raise nncf.ValidationError(msg)
         else:
             example_input = next(iter(dataset.get_inference_data()))
-            model = wrap_model(model, example_input=example_input, trace_parameters=True)
+            if is_experimental_torch_tracing_enabled():
+                from nncf.experimental.torch2.function_hook import wrap_model
+
+                model = GraphModelWrapper(wrap_model(model), example_input=example_input)
+            else:
+                from nncf.torch.model_creation import wrap_model
+
+                model = wrap_model(model, example_input=example_input, trace_parameters=True)
         if mode in (CompressWeightsMode.INT8, CompressWeightsMode.INT8_ASYM, CompressWeightsMode.INT8_SYM):
             dataset = None  # data-aware methods don't support INT8 modes
         compression_weights_impl = pt_compression_weights_impl

diff --git a/nncf/torch/quantization/quantize_model.py b/nncf/torch/quantization/quantize_model.py
index 734f15a1607..f315331766e 100644
--- a/nncf/torch/quantization/quantize_model.py
+++ b/nncf/torch/quantization/quantize_model.py
@@ -10,7 +10,7 @@
 # limitations under the License.
 
 from copy import deepcopy
-from typing import Optional
+from typing import Optional, Union
 
 import torch
 
@@ -18,6 +18,7 @@
 from nncf.common.factory import NNCFGraphFactory
 from nncf.common.quantization.structs import QuantizationPreset
 from nncf.data import Dataset
+from nncf.experimental.torch2.function_hook.nncf_graph.nncf_graph_builder import GraphModelWrapper
 from nncf.parameters import BackupMode
 from nncf.parameters import CompressWeightsMode
 from nncf.parameters import ModelType
@@ -85,7 +86,7 @@ def quantize_impl(
 
 
 def compress_weights_impl(
-    model: torch.nn.Module,
+    model: Union[GraphModelWrapper, torch.nn.Module],
     dataset: Dataset,
     mode: CompressWeightsMode,
     ratio: float,
@@ -120,4 +121,8 @@ def compress_weights_impl(
         advanced_parameters,
     )
     graph = NNCFGraphFactory.create(model)
-    return compression_algorithm.apply(model, graph, dataset=dataset)
+
+    compressed_model = compression_algorithm.apply(model, graph, dataset=dataset)
+    if isinstance(compressed_model, GraphModelWrapper):
+        compressed_model = compressed_model.model
+    return compressed_model
diff --git a/tests/torch2/function_hook/quantization/test_weights_compression.py b/tests/torch2/function_hook/quantization/test_weights_compression.py
new file mode 100644
index 00000000000..85999fa39b3
--- /dev/null
+++ b/tests/torch2/function_hook/quantization/test_weights_compression.py
@@ -0,0 +1,432 @@
+# Copyright (c) 2025 Intel Corporation
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import List
+
+import pytest
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+import nncf
+from nncf import BackupMode
+from nncf import CompressWeightsMode
+from nncf import SensitivityMetric
+from nncf.experimental.torch2.function_hook import get_hook_storage
+from nncf.experimental.torch2.function_hook import wrap_model
+from nncf.experimental.torch2.function_hook.nncf_graph.nncf_graph_builder import GraphModelWrapper
+from nncf.quantization import compress_weights
+from nncf.quantization.advanced_parameters import AdvancedCompressionParameters
+from nncf.tensor import Tensor
+from nncf.tensor import TensorDataType
+from nncf.torch.quantization.layers import INT4AsymmetricWeightsDecompressor
+from nncf.torch.quantization.layers import INT4SymmetricWeightsDecompressor
+from nncf.torch.quantization.layers import INT8AsymmetricWeightsDecompressor
+from nncf.torch.quantization.layers import INT8SymmetricWeightsDecompressor
+from nncf.torch.quantization.quantize_functions import pack_int4
+from nncf.torch.quantization.quantize_functions import pack_uint4
+from nncf.torch.quantization.quantize_functions import unpack_int4
+from nncf.torch.quantization.quantize_functions import unpack_uint4
+from tests.cross_fw.test_templates.template_test_weights_compression import TemplateWeightCompression
+from tests.torch.test_models.synthetic import ShortTransformer
+from tests.torch.test_tensor import cast_to
+
+ALL_SENSITIVITY_METRICS = list(SensitivityMetric)
+
+INT8_MODES = (CompressWeightsMode.INT8_ASYM, CompressWeightsMode.INT8_SYM)
+INT4_MODES = (CompressWeightsMode.INT4_SYM, CompressWeightsMode.INT4_ASYM)
+SUPPORTED_MODES = INT8_MODES + INT4_MODES
+UNSUPPORTED_MODES = (CompressWeightsMode.NF4, CompressWeightsMode.E2M1)
+
+
+class SequentialMatmulModel(nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.main_values = [10000, 1000, 1, 10, 10000]
+        self.layers = nn.ModuleList()
+
+        for _, main_value in enumerate(self.main_values):
+            weights_data = torch.arange(0, 16, dtype=torch.float32).reshape(4, 4)
+            weights_data[-1, -1] = main_value
+            weight_tensor = torch.tensor(weights_data)
+            layer = nn.Linear(4, 4, bias=False)
+            layer.weight = nn.Parameter(weight_tensor.t())
+            self.layers.append(layer)
+
+    def forward(self, x):
+        for layer in self.layers:
+            x = layer(x)
+        return x
+
+    def get_weight_names_in_exec_order(self):
+        return [f"layers:{i}:weight" for i in range(len(self.main_values))]
+
+
+class MatMulModel(torch.nn.Module):
+    def __init__(self, weight: torch.Tensor = torch.ones(size=(256, 256), dtype=torch.float32)):
+        super().__init__()
+        self.w = torch.nn.Parameter(weight)
+
+    def forward(self, input):
+        return input @ self.w
+
+
+class LinearModel(torch.nn.Module):
+    def __init__(self, weight: torch.Tensor = torch.ones(size=(256, 256), dtype=torch.float32)):
+        super().__init__()
+        self.linear = torch.nn.Linear(weight.shape[0], weight.shape[1], False)
+        self.linear.weight = torch.nn.Parameter(weight)
+
+    def forward(self, input):
+        return self.linear(input)
+
+
+class FunctionalModel(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.conv_w = torch.nn.Parameter(torch.ones(size=(5, 3, 3, 3), dtype=torch.float32))
+        self.matmul_w = torch.nn.Parameter(torch.ones(size=(1, 3, 256, 256), dtype=torch.float32))
+        self.conv_tr_w = torch.nn.Parameter(torch.rand(size=(5, 4, 3, 3)))
+        self.nested_matmul = MatMulModel()
+
+    def forward(self, input_):
+        x = input_.to(torch.float32)
+        x = x @ self.matmul_w
+        x = self.nested_matmul(x)
+        x = F.conv2d(x, self.conv_w)
+        x = F.conv_transpose2d(x, self.conv_tr_w)
+        return x
+
+
+class ConvolutionModel(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.conv_regular = torch.nn.Conv2d(in_channels=3, out_channels=16, kernel_size=3)
+        self.max_pool2d = torch.nn.MaxPool2d(kernel_size=2)
+        self.conv_transpose = torch.nn.ConvTranspose2d(in_channels=16, out_channels=8, kernel_size=3)
+        self.conv_depthwise = torch.nn.Conv2d(in_channels=8, out_channels=8, kernel_size=5, groups=8)
+        self.adaptive_avg_pool = torch.nn.AdaptiveAvgPool2d(output_size=1)
+        self.linear = torch.nn.Linear(in_features=8, out_features=8)
+
+    def forward(self, input_):
+        input_ = input_.to(torch.float32)
+        x = self.conv_regular(input_)
+        x = F.relu(x)
+        x.transpose_(2, 3)
+        x = self.max_pool2d(x)
+        y = self.conv_transpose(x)
+        z = F.conv_transpose2d(x, self.conv_transpose.weight)
+        x = y + z
+        x = self.conv_depthwise(x)
+        x = F.conv2d(x, self.conv_depthwise.weight, groups=self.conv_depthwise.groups)
+        x += torch.ones_like(x)
+        x = self.adaptive_avg_pool(x)
+        x = self.linear(x.flatten())
+        return x
+
+
+@pytest.mark.parametrize("mode", SUPPORTED_MODES)
+def test_compress_weights(mode):
+    model = wrap_model(ShortTransformer(8, 16))
+    dtype = torch.int8 if mode == CompressWeightsMode.INT8_SYM else torch.uint8
+
+    input_ids = torch.randint(0, 10, (8,))
+    wrapped_model = GraphModelWrapper(model, example_input=input_ids)
+
+    kwargs = {}
+    if mode in [CompressWeightsMode.INT4_SYM, CompressWeightsMode.INT4_ASYM]:
+        kwargs["group_size"] = 4
+    compressed_model = compress_weights(wrapped_model, mode=mode, **kwargs)
+
+    n_compressed_weights = 0
+    n_target_modules = 0
+
+    for _, module in compressed_model.named_children():
+        if isinstance(module, (torch.nn.Linear, torch.nn.Embedding)):
+            n_target_modules += 1
+            if module.weight.dtype == dtype:
+                n_compressed_weights += 1
+
+    assert n_compressed_weights == n_target_modules
+
+
+@pytest.mark.parametrize("mode", SUPPORTED_MODES)
+def test_compress_weights_functional_model(mode):
+    model = wrap_model(FunctionalModel())
+    decompressor_map = {
+        CompressWeightsMode.INT8_SYM: (INT8SymmetricWeightsDecompressor,),
+        CompressWeightsMode.INT8_ASYM: (INT8AsymmetricWeightsDecompressor,),
+        CompressWeightsMode.INT4_SYM: (INT4SymmetricWeightsDecompressor, INT8AsymmetricWeightsDecompressor),
+        CompressWeightsMode.INT4_ASYM: (INT4AsymmetricWeightsDecompressor, INT8AsymmetricWeightsDecompressor),
+    }
+
+    decompressor_type = decompressor_map[mode]
+
+    input_ids = torch.randint(0, 10, [1, 3, 256, 256])
+    wrapped_model = GraphModelWrapper(model, example_input=input_ids)
+    compressed_model = compress_weights(wrapped_model, mode=mode)
+
+    n_compressed_weights = 0
+    for layer in compressed_model.modules():
+        if isinstance(layer, decompressor_type):
+            n_compressed_weights += 1
+    assert n_compressed_weights == 4
+
+
+def test_compress_weights_conv():
+    model = wrap_model(ConvolutionModel())
+
+    input_ids = torch.randint(0, 10, [1, 3, 300, 300])
+    wrapped_model = GraphModelWrapper(model, example_input=input_ids)
+    compressed_model = compress_weights(wrapped_model)
+
+    n_compressed_weights = 0
+    n_target_modules = 0
+
+    for module in compressed_model.modules():
+        if isinstance(module, (torch.nn.Linear, torch.nn.Conv2d, torch.nn.ConvTranspose2d)):
+            n_target_modules += 1
+            if module.weight.dtype in [torch.uint8, torch.int8]:
+                n_compressed_weights += 1
+
+    assert n_compressed_weights == n_target_modules
+
+
+@pytest.mark.parametrize("mode", SUPPORTED_MODES)
+def test_compress_shared_weights(mocker, mode):
+    model = wrap_model(ShortTransformer(8, 16, share_weights=True))
+    dtype = torch.int8 if mode == CompressWeightsMode.INT8_SYM else torch.uint8
+
+    input_ids = torch.randint(0, 10, (8,))
+    wrapped_model = GraphModelWrapper(model, example_input=input_ids)
+
+    kwargs = {}
+    if mode in [CompressWeightsMode.INT4_SYM, CompressWeightsMode.INT4_ASYM]:
+        kwargs["group_size"] = 4
+    compressed_model = compress_weights(wrapped_model, mode=mode, **kwargs)
+
+    n_compressed_weights = 0
+    n_target_modules = 0
+
+    for module in compressed_model.modules():
+        if isinstance(module, (torch.nn.Linear, torch.nn.Embedding)):
+            n_target_modules += 1
+            if module.weight.dtype == dtype:
+                n_compressed_weights += 1
+
+    assert n_compressed_weights == n_target_modules
+
+    hook_storage = get_hook_storage(compressed_model)
+    decompressed_modules = list(x for x in hook_storage.post_hooks.modules() if not isinstance(x, nn.ModuleDict))
+    assert len(decompressed_modules) == 2
+
+    # TODO(AlexanderDokuchaev): cache for shared weights
+    # check that the weight decompressors are called only once
+    for val in decompressed_modules:
+        mocker.spy(val, "forward")
+    compressed_model(input_ids)
+
+    for val in decompressed_modules:
+        assert val.forward.call_count in [1, 2]  # TODO(AlexanderDokuchaev): cache for shared weights
+
+
+class EmptyModel(torch.nn.Module):
+    def forward(self, input):
+        return input
+
+
+@pytest.mark.parametrize("mode", INT8_MODES)
+@pytest.mark.parametrize(
+    "params",
+    (
+        {"ratio": 0.5},
+        {"group_size": 64},
+        {"all_layers": True},
+        {"all_layers": False},
+        *({"sensitivity_metric": metric} for metric in ALL_SENSITIVITY_METRICS),
+        {"gptq": True},
+        {"awq": True},
+        {"scale_estimation": True},
+        {"lora_correction": True},
+        {"backup_mode": BackupMode.NONE},
+        {"backup_mode": BackupMode.INT8_ASYM},
+        {"backup_mode": BackupMode.INT8_SYM},
+        {"advanced_parameters": AdvancedCompressionParameters(statistics_path="anything")},
+    ),
+)
+def test_raise_error_with_unsupported_params_for_int8(mode, params):
+    dummy_torch_model = EmptyModel()
+    dummy_input = torch.Tensor()
+    wrapped_model = GraphModelWrapper(wrap_model(dummy_torch_model), example_input=dummy_input)
+    with pytest.raises(nncf.ParameterNotSupportedError):
+        compress_weights(wrapped_model, mode=mode, **params)
+
+
+@pytest.mark.parametrize("mode", INT4_MODES)
+@pytest.mark.parametrize(
+    "params",
+    (
+        {"gptq": True},
+        {"awq": True},
+        {"lora_correction": True},
+    ),
+)
+def test_raise_error_with_unsupported_params_for_int4(mode, params):
+    dummy_torch_model = EmptyModel()
+    dummy_input = torch.Tensor()
+    wrapped_model = GraphModelWrapper(wrap_model(dummy_torch_model), example_input=dummy_input)
+    with pytest.raises(nncf.ParameterNotSupportedError):
+        compress_weights(wrapped_model, mode=mode, **params)
+
+
+@pytest.mark.parametrize("mode", UNSUPPORTED_MODES)
+def test_raise_error_with_not_int8(mode):
+    dummy_torch_model = EmptyModel()
+    dummy_input = torch.Tensor()
+    wrapped_model = GraphModelWrapper(wrap_model(dummy_torch_model), example_input=dummy_input)
+    with pytest.raises(nncf.ParameterNotSupportedError):
+        compress_weights(wrapped_model, mode=mode)
+
+
+def test_raise_error_for_statistics_caching():
+    dummy_torch_model = EmptyModel()
+    dummy_input = torch.Tensor()
+    wrapped_model = GraphModelWrapper(wrap_model(dummy_torch_model), example_input=dummy_input)
+    with pytest.raises(nncf.ParameterNotSupportedError):
+        compress_weights(wrapped_model, advanced_parameters=AdvancedCompressionParameters(statistics_path="anything"))
+
+
+class DTypeModel(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.weight = torch.nn.Parameter(torch.ones(size=(3, 3), dtype=torch.float32))
+
+    def forward(self, x):
+        x = x.to(self.weight.dtype)
+        x = x @ self.weight
+        return x
+
+
+def test_get_dtype_attribute_of_parameter():
+    model = DTypeModel()
+    dummy_input = torch.randint(0, 10, [3, 3])
+    wrapped_model = GraphModelWrapper(wrap_model(model), example_input=dummy_input)
+    compressed_model = compress_weights(wrapped_model)
+    assert compressed_model.weight.dtype == torch.uint8
+    compressed_model(dummy_input)
+    assert compressed_model.weight.dtype == torch.uint8
+
+
+@pytest.mark.parametrize("dtype", ("float16", "float32"))
+def test_model_devices_and_precisions(use_cuda, dtype):
+    if use_cuda and not torch.cuda.is_available():
+        pytest.skip("Skipping for CPU-only setups")
+    device = torch.device("cuda" if use_cuda else "cpu")
+    dtype = torch.float16 if dtype == "float16" else torch.float32
+
+    model = MatMulModel().to(device)
+    if dtype == torch.float16:
+        model.half()
+
+    dummy_input = torch.rand((1, 256), dtype=dtype, device=device)
+    wrapped_model = GraphModelWrapper(wrap_model(model), example_input=dummy_input)
+    compressed_model = compress_weights(wrapped_model)
+    result = compressed_model(dummy_input)
+
+    # Scale should always be in float16
+    assert compressed_model.state_dict()["__nncf_hooks.post_hooks.w__0.0._scale"].dtype == torch.float16
+    # Result should be in the precision of the model
+    assert result.dtype == dtype
+
+
+def test_pack_uint4():
+    w_uint8 = torch.randint(0, 15, (4, 4), dtype=torch.uint8)
+    packed_w = pack_uint4(w_uint8)
+    assert packed_w.dtype == torch.uint8
+    assert packed_w.numel() * 2 == w_uint8.numel()
+    unpacked_w = unpack_uint4(packed_w).reshape(w_uint8.shape)
+    assert torch.all(unpacked_w == w_uint8)
+
+
+def test_pack_int4():
+    w_int8 = torch.randint(-8, 7, (4, 4), dtype=torch.int8)
+    packed_w = pack_int4(w_int8)
+    assert packed_w.dtype == torch.uint8
+    assert packed_w.numel() * 2 == w_int8.numel()
+    unpacked_w = unpack_int4(packed_w).reshape(w_int8.shape)
+    assert torch.all(unpacked_w == w_int8)
+
+
+class TestPTTemplateWeightCompression(TemplateWeightCompression):
+    @staticmethod
+    def get_matmul_model() -> torch.nn.Module:
+        return MatMulModel(255 * torch.eye(3, dtype=torch.float32))
+
+    @staticmethod
+    def get_sequential_matmul_model() -> torch.nn.Module:
+        return SequentialMatmulModel()
+
+    @staticmethod
+    def to_tensor(t) -> torch.Tensor:
+        return torch.tensor(t)
+
+    @staticmethod
+    def cast_to(x: torch.Tensor, dtype: TensorDataType) -> torch.Tensor:
+        return cast_to(x, dtype)
+
+    @staticmethod
+    def check_weights(model: torch.nn.Module, ref_ids: List[int]) -> None:
+        all_names = model.get_weight_names_in_exec_order()
+        low_precision_nodes = list(map(lambda i: all_names[i], ref_ids))
+        decompressed_modules = list(
+            x for x in get_hook_storage(model).named_modules() if not isinstance(x[1], nn.ModuleDict)
+        )
+        for op_name, op in decompressed_modules:
+            for name in low_precision_nodes:
+                if name in op_name:
+                    assert isinstance(op, INT4SymmetricWeightsDecompressor)
+
+    @staticmethod
+    def get_model_for_test_scale_estimation():
+        return LinearModel(torch.arange(0, 8 * 16, dtype=torch.float32).reshape(16, 8))
+
+    @staticmethod
+    def get_scale_estimation_ref():
+        return torch.tensor(
+            [
+                [[0.473328]],
+                [[0.929023]],
+                [[1.446527]],
+                [[1.920595]],
+                [[2.517054]],
+                [[3.030102]],
+                [[3.584279]],
+                [[4.043509]],
+                [[4.620008]],
+                [[5.165322]],
+                [[5.710637]],
+                [[6.122581]],
+                [[6.655914]],
+                [[7.237174]],
+                [[7.722580]],
+                [[8.255914]],
+            ]
+        )
+
+    @staticmethod
+    def get_orig_weight(model: torch.nn.Module) -> Tensor:
+        return Tensor(model.linear.weight.data)
+
+    @staticmethod
+    def get_decompressed_weight(compressed_model: torch.nn.Module, input: torch.Tensor) -> Tensor:
+        weight = compressed_model.linear.weight
+        unpacked_w = compressed_model.get_submodule("__nncf_hooks.post_hooks.linear:weight__0.0")(weight)
+        return Tensor(unpacked_w)

From 04dbe64edd72cad8faeccc7aebcded10d474be0e Mon Sep 17 00:00:00 2001
From: Alexander Dokuchaev
Date: Wed, 19 Feb 2025 00:26:39 +0200
Subject: [PATCH 2/6] fix

---
 nncf/quantization/quantize_model.py | 19 ++++++++++---------
 1 file changed, 10 insertions(+), 9 deletions(-)

diff --git a/nncf/quantization/quantize_model.py b/nncf/quantization/quantize_model.py
index 5fc594f0fb0..a676614ee84 100644
--- a/nncf/quantization/quantize_model.py
+++ b/nncf/quantization/quantize_model.py
@@ -533,15 +533,16 @@ def compress_weights(
             msg = "Torch does not support statistics caching."
             raise nncf.ParameterNotSupportedError(msg)
 
-        if is_wrapped_model(model) and not model.nncf.trace_parameters:
-            msg = (
-                "Tracing capabilities with tracing parameters are required in the PyTorch model "
-                "for nncf.compress_weights(). Please wrap the model using "
-                "nncf.torch.wrap_model(model, example_input, trace_parameters=True) before calling "
-                "nncf.compress_weights()."
-            )
-            raise nncf.ValidationError(msg)
-        if isinstance(model, GraphModelWrapper):
+        if is_wrapped_model(model):
+            if not model.nncf.trace_parameters:
+                msg = (
+                    "Tracing capabilities with tracing parameters are required in the PyTorch model "
+                    "for nncf.compress_weights(). Please wrap the model using "
+                    "nncf.torch.wrap_model(model, example_input, trace_parameters=True) before calling "
+                    "nncf.compress_weights()."
+                )
+                raise nncf.ValidationError(msg)
+        elif isinstance(model, GraphModelWrapper):
             pass
         elif dataset is None:
             msg = "Please provide a dataset of at least one element for PyTorch model tracing."
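
Reviewer note: the pack_uint4/pack_int4 round-trip tests added in PATCH 1/6 only pin down the storage contract -- two 4-bit values per uint8 byte, with unpack(pack(w)) restoring w exactly. A self-contained sketch of that contract follows; the low/high-nibble ordering is an assumption for illustration (the tests do not fix the bit layout), and pack_uint4_sketch/unpack_uint4_sketch are hypothetical names, not NNCF APIs.

import torch

def pack_uint4_sketch(w: torch.Tensor) -> torch.Tensor:
    # Pair consecutive 4-bit values: even index -> low nibble, odd index -> high nibble (assumed layout).
    flat = w.flatten()
    return (flat[::2] | (flat[1::2] << 4)).to(torch.uint8)

def unpack_uint4_sketch(packed: torch.Tensor) -> torch.Tensor:
    # Invert the packing: split every byte back into its two nibbles, preserving element order.
    low = packed & 0x0F
    high = (packed >> 4) & 0x0F
    return torch.stack([low, high], dim=-1).flatten()

w = torch.randint(0, 15, (4, 4), dtype=torch.uint8)
packed = pack_uint4_sketch(w)
assert packed.numel() * 2 == w.numel()  # same size assertion as test_pack_uint4
assert torch.all(unpack_uint4_sketch(packed).reshape(w.shape) == w)  # lossless round trip

The signed pack_int4 variant follows the same scheme after mapping int8 values into the unsigned nibble range; the packed uint8 container is also why the tests above count compressed weights by checking module.weight.dtype against torch.uint8/torch.int8.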
From 84c8a796fcef3c1268e1e28acaae1f77f5970fd0 Mon Sep 17 00:00:00 2001
From: Alexander Dokuchaev
Date: Fri, 21 Feb 2025 00:27:10 +0200
Subject: [PATCH 3/6] comment

---
 nncf/quantization/quantize_model.py | 18 +++++-------------
 nncf/torch/model_creation.py        | 15 +++++++++++++--
 2 files changed, 18 insertions(+), 15 deletions(-)

diff --git a/nncf/quantization/quantize_model.py b/nncf/quantization/quantize_model.py
index a676614ee84..0f87946954e 100644
--- a/nncf/quantization/quantize_model.py
+++ b/nncf/quantization/quantize_model.py
@@ -510,9 +510,8 @@ def compress_weights(
     compression_weights_impl = None
 
     if backend == BackendType.TORCH:
-        from nncf.experimental.common.check_feature import is_experimental_torch_tracing_enabled
-        from nncf.experimental.torch2.function_hook.nncf_graph.nncf_graph_builder import GraphModelWrapper
         from nncf.torch.model_creation import is_wrapped_model
+        from nncf.torch.nncf_network import NNCFNetwork
         from nncf.torch.quantization.quantize_model import compress_weights_impl as pt_compression_weights_impl
 
         if mode in [CompressWeightsMode.NF4, CompressWeightsMode.E2M1]:
@@ -534,7 +533,7 @@ def compress_weights(
             raise nncf.ParameterNotSupportedError(msg)
 
         if is_wrapped_model(model):
-            if not model.nncf.trace_parameters:
+            if isinstance(model, NNCFNetwork) and not model.nncf.trace_parameters:
                 msg = (
                     "Tracing capabilities with tracing parameters are required in the PyTorch model "
                     "for nncf.compress_weights(). Please wrap the model using "
@@ -542,21 +541,14 @@ def compress_weights(
                     "nncf.compress_weights()."
                 )
                 raise nncf.ValidationError(msg)
-        elif isinstance(model, GraphModelWrapper):
-            pass
         elif dataset is None:
             msg = "Please provide a dataset of at least one element for PyTorch model tracing."
             raise nncf.ValidationError(msg)
         else:
-            example_input = next(iter(dataset.get_inference_data()))
-            if is_experimental_torch_tracing_enabled():
-                from nncf.experimental.torch2.function_hook import wrap_model
-
-                model = GraphModelWrapper(wrap_model(model), example_input=example_input)
-            else:
-                from nncf.torch.model_creation import wrap_model
+            from nncf.torch.model_creation import wrap_model
 
-                model = wrap_model(model, example_input=example_input, trace_parameters=True)
+            example_input = next(iter(dataset.get_inference_data()))
+            model = wrap_model(model, example_input=example_input, trace_parameters=True)
         if mode in (CompressWeightsMode.INT8, CompressWeightsMode.INT8_ASYM, CompressWeightsMode.INT8_SYM):
             dataset = None  # data-aware methods don't support INT8 modes
         compression_weights_impl = pt_compression_weights_impl
diff --git a/nncf/torch/model_creation.py b/nncf/torch/model_creation.py
index 2dc48b8f17c..d6ca086826e 100644
--- a/nncf/torch/model_creation.py
+++ b/nncf/torch/model_creation.py
@@ -27,6 +27,8 @@
 from nncf.config.extractors import extract_algorithm_names
 from nncf.config.extractors import has_input_info_field
 from nncf.config.telemetry_extractors import CompressionStartedFromConfig
+from nncf.experimental.common.check_feature import is_experimental_torch_tracing_enabled
+from nncf.experimental.torch2.function_hook.nncf_graph.nncf_graph_builder import GraphModelWrapper
 from nncf.telemetry import tracked_function
 from nncf.telemetry.events import NNCF_PT_CATEGORY
 from nncf.telemetry.extractors import FunctionCallTelemetryExtractor
@@ -350,6 +352,15 @@ def wrap_model(
     :param trace_parameters: Whether to trace model parameters. Default is False.
     :return: A model wrapped by NNCFNetwork.
""" + if is_experimental_torch_tracing_enabled(): + if not trace_parameters: + msg = "The 'trace_parameters=False' option is not supported in the experimental tracing mode." + raise nncf.InternalError(msg) + from nncf.experimental.torch2.function_hook import wrap_model + + wrapped_model = GraphModelWrapper(wrap_model(model), example_input=example_input) + return wrapped_model + if not isinstance(model, torch.nn.Module): msg = ( f"The provided model type {type(model)} is incompatible. " @@ -370,12 +381,12 @@ def wrap_model( def is_wrapped_model(model: torch.nn.Module) -> bool: """ - Check that the model was wrapped by NNCFNetwork. + Check that the model was wrapped by NNCFNetwork or GraphModelWrapper. :param model: A model. :return: True if the model is wrapped, False otherwise. """ - return isinstance(model, NNCFNetwork) + return isinstance(model, (NNCFNetwork, GraphModelWrapper)) @tracked_function( From 4e957a8514dedf4dfa05502fbdc52485e16ef88c Mon Sep 17 00:00:00 2001 From: Alexander Dokuchaev Date: Fri, 21 Feb 2025 00:53:37 +0200 Subject: [PATCH 4/6] fix circular import --- nncf/torch/model_creation.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/nncf/torch/model_creation.py b/nncf/torch/model_creation.py index d6ca086826e..7a02a578a94 100644 --- a/nncf/torch/model_creation.py +++ b/nncf/torch/model_creation.py @@ -28,7 +28,6 @@ from nncf.config.extractors import has_input_info_field from nncf.config.telemetry_extractors import CompressionStartedFromConfig from nncf.experimental.common.check_feature import is_experimental_torch_tracing_enabled -from nncf.experimental.torch2.function_hook.nncf_graph.nncf_graph_builder import GraphModelWrapper from nncf.telemetry import tracked_function from nncf.telemetry.events import NNCF_PT_CATEGORY from nncf.telemetry.extractors import FunctionCallTelemetryExtractor @@ -357,6 +356,7 @@ def wrap_model( msg = "The 'trace_parameters=False' option is not supported in the experimental tracing mode." raise nncf.InternalError(msg) from nncf.experimental.torch2.function_hook import wrap_model + from nncf.experimental.torch2.function_hook.nncf_graph.nncf_graph_builder import GraphModelWrapper wrapped_model = GraphModelWrapper(wrap_model(model), example_input=example_input) return wrapped_model @@ -386,6 +386,8 @@ def is_wrapped_model(model: torch.nn.Module) -> bool: :param model: A model. :return: True if the model is wrapped, False otherwise. 
""" + from nncf.experimental.torch2.function_hook.nncf_graph.nncf_graph_builder import GraphModelWrapper + return isinstance(model, (NNCFNetwork, GraphModelWrapper)) From 8fe534e6128bcb712c8943bce8e2a6cba4b8c6ad Mon Sep 17 00:00:00 2001 From: Alexander Dokuchaev Date: Fri, 21 Feb 2025 04:58:41 +0200 Subject: [PATCH 5/6] wa --- tests/post_training/pipelines/lm_weight_compression.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/tests/post_training/pipelines/lm_weight_compression.py b/tests/post_training/pipelines/lm_weight_compression.py index 6e523d246cd..48aba9109d5 100644 --- a/tests/post_training/pipelines/lm_weight_compression.py +++ b/tests/post_training/pipelines/lm_weight_compression.py @@ -285,7 +285,11 @@ def _dump_model_fp32(self) -> None: self.model_hf.save_pretrained(self.fp32_model_dir) self.model_hf._save_config(self.fp32_model_dir) elif self.backend == BackendType.TORCH: + _need_clean_dict = "forward" not in self.model_hf.__dict__ export_from_model(self.model_hf, self.fp32_model_dir, stateful=False, compression_option="fp32") + if _need_clean_dict and "forward" in self.model_hf.__dict__: + # WA for experimental tracing, clean up overwritten forward (same as in class method) + del self.model_hf.__dict__["forward"] def _compress(self): """ From 6f1bca373407c14975e9fe7f9067fa888f0eef87 Mon Sep 17 00:00:00 2001 From: Alexander Dokuchaev Date: Fri, 21 Feb 2025 19:34:14 +0200 Subject: [PATCH 6/6] awq --- .../weight_compression/torch_backend.py | 12 +++- .../quantization/test_weights_compression.py | 58 +++++++++++++++---- 2 files changed, 59 insertions(+), 11 deletions(-) diff --git a/nncf/quantization/algorithms/weight_compression/torch_backend.py b/nncf/quantization/algorithms/weight_compression/torch_backend.py index 9b8b5b05864..7fcb8892d7c 100644 --- a/nncf/quantization/algorithms/weight_compression/torch_backend.py +++ b/nncf/quantization/algorithms/weight_compression/torch_backend.py @@ -223,7 +223,14 @@ def get_weight_shape(node_with_weight: NNCFNode, weight_port_id: int, graph: NNC def set_weight( self, node_with_weight: NNCFNode, weight_port_id: int, model: torch.nn.Module, graph: NNCFGraph, weight: Tensor ): - update_parameter(node_with_weight.node_name, "weight", weight.data, model) + if is_experimental_torch_tracing_enabled(): + weight_node = get_const_node(node_with_weight, weight_port_id, graph) + module_name, weight_attr_name = split_const_name(weight_node.layer_attributes.name) + module = get_module_by_name(module_name, model.model) + weight_param = getattr(module, weight_attr_name) + weight_param.data = weight.data + else: + update_parameter(node_with_weight.node_name, "weight", weight.data, model) def insert_adapters( self, wc_params: WeightCompressionParameters, lora_A: Tensor, lora_B: Tensor, int8_lora: bool @@ -393,6 +400,9 @@ def scale_insertion_command( sq_multiply = SQMultiply(scale.shape) sq_multiply.scale = scale + + if is_experimental_torch_tracing_enabled(): + return PT2InsertionCommand(target_points, sq_multiply) scale_node_name = f"{source_node.node_name}/awq_mul" return PTSharedFnInsertionCommand(target_points, sq_multiply, scale_node_name) diff --git a/tests/torch2/function_hook/quantization/test_weights_compression.py b/tests/torch2/function_hook/quantization/test_weights_compression.py index 85999fa39b3..6082c0277b7 100644 --- a/tests/torch2/function_hook/quantization/test_weights_compression.py +++ b/tests/torch2/function_hook/quantization/test_weights_compression.py @@ -9,7 +9,7 @@ # See the License for the specific language governing 
permissions and # limitations under the License. -from typing import List +from typing import Dict, List import pytest import torch @@ -25,6 +25,7 @@ from nncf.experimental.torch2.function_hook.nncf_graph.nncf_graph_builder import GraphModelWrapper from nncf.quantization import compress_weights from nncf.quantization.advanced_parameters import AdvancedCompressionParameters +from nncf.quantization.algorithms.smooth_quant.torch_backend import SQMultiply from nncf.tensor import Tensor from nncf.tensor import TensorDataType from nncf.torch.quantization.layers import INT4AsymmetricWeightsDecompressor @@ -36,6 +37,8 @@ from nncf.torch.quantization.quantize_functions import unpack_int4 from nncf.torch.quantization.quantize_functions import unpack_uint4 from tests.cross_fw.test_templates.template_test_weights_compression import TemplateWeightCompression +from tests.torch.ptq.test_weights_compression import AWQActLinearModel +from tests.torch.ptq.test_weights_compression import AWQLinearModel from tests.torch.test_models.synthetic import ShortTransformer from tests.torch.test_tensor import cast_to @@ -273,11 +276,7 @@ def test_raise_error_with_unsupported_params_for_int8(mode, params): @pytest.mark.parametrize("mode", INT4_MODES) @pytest.mark.parametrize( "params", - ( - {"gptq": True}, - {"awq": True}, - {"lora_correction": True}, - ), + ({"gptq": True}, {"lora_correction": True}), ) def test_raise_error_with_unsupported_params_for_int4(mode, params): dummy_torch_model = EmptyModel() @@ -374,6 +373,18 @@ def get_matmul_model() -> torch.nn.Module: def get_sequential_matmul_model() -> torch.nn.Module: return SequentialMatmulModel() + @staticmethod + def get_model_for_test_scale_estimation(): + return LinearModel(torch.arange(0, 8 * 16, dtype=torch.float32).reshape(16, 8)) + + @staticmethod + def get_awq_model() -> torch.nn.Module: + return AWQLinearModel() + + @staticmethod + def get_awq_act_model(with_multiply, n_layers): + return AWQActLinearModel(with_multiply=with_multiply, n_layers=n_layers) + @staticmethod def to_tensor(t) -> torch.Tensor: return torch.tensor(t) @@ -394,10 +405,6 @@ def check_weights(model: torch.nn.Module, ref_ids: List[int]) -> None: if name in op_name: assert isinstance(op, INT4SymmetricWeightsDecompressor) - @staticmethod - def get_model_for_test_scale_estimation(): - return LinearModel(torch.arange(0, 8 * 16, dtype=torch.float32).reshape(16, 8)) - @staticmethod def get_scale_estimation_ref(): return torch.tensor( @@ -430,3 +437,34 @@ def get_decompressed_weight(compressed_model: torch.nn.Module, input: torch.Tens weight = compressed_model.linear.weight unpacked_w = compressed_model.get_submodule("__nncf_hooks.post_hooks.linear:weight__0.0")(weight) return Tensor(unpacked_w) + + @staticmethod + def get_ignored_scope_name() -> str: + return "linear6/linear/0" + + @staticmethod + def get_num_int4_nodes(model: torch.nn.Module) -> int: + num = 0 + for op in get_hook_storage(model).modules(): + num += isinstance(op, INT4SymmetricWeightsDecompressor) + return num + + @pytest.fixture(params=INT4_MODES) + def int4_mode(self, request): + return request.param + + @staticmethod + def get_num_multiply_from_awq(model): + awq_num = 0 + for module in model.modules(): + if isinstance(module, SQMultiply): + awq_num += 1 + return awq_num + + @staticmethod + def get_reference_for_test_awq_scale_reference() -> Dict[str, Tensor]: + return { + "linear3/linear/0": Tensor( + torch.tensor([[1.226455, 1.205499, 1.141340, 1.097436, 1.064355, 1.037971, 1.016118, 0.997526]]) + ) + }
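
Reviewer note: an end-to-end sketch of the user-facing flow this series enables. The environment flag name below is an assumption -- the implementation behind is_experimental_torch_tracing_enabled() is not part of these diffs -- so treat the first two lines as illustrative only; the rest follows the patched wrap_model()/compress_weights() code paths.

import os

# Assumed name for the toggle read by is_experimental_torch_tracing_enabled();
# the actual mechanism is defined outside this series.
os.environ["NNCF_EXPERIMENTAL_TORCH_TRACING"] = "1"

import torch
import nncf
from nncf.torch import wrap_model

model = torch.nn.Linear(256, 256)
example_input = torch.rand(1, 256)

# With the experimental mode enabled, wrap_model() returns a GraphModelWrapper
# (PATCH 3/6) and raises nncf.InternalError unless trace_parameters=True.
wrapped = wrap_model(model, example_input=example_input, trace_parameters=True)

# compress_weights() accepts the wrapper directly (is_wrapped_model() now matches
# GraphModelWrapper, PATCH 3/6), and compress_weights_impl() unwraps the result
# (PATCH 1/6): the caller gets back a plain torch.nn.Module whose packed uint8
# weights are dequantized by decompressor post-hooks at inference time.
compressed = nncf.compress_weights(wrapped)
out = compressed(example_input)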