Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[PT2] weight_compression #3293

Open
wants to merge 8 commits into
base: develop
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
101 changes: 74 additions & 27 deletions nncf/quantization/algorithms/weight_compression/torch_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Callable, Dict, Iterable, List, Optional, Tuple
from typing import Callable, Dict, Iterable, List, Optional, Tuple, Union

import torch

Expand All @@ -23,6 +23,7 @@
from nncf.common.graph.transformations.commands import TargetType
from nncf.common.graph.transformations.layout import TransformationLayout
from nncf.common.tensor_statistics.statistic_point import StatisticPoint
from nncf.experimental.common.check_feature import is_experimental_torch_tracing_enabled
from nncf.experimental.common.tensor_statistics.collectors import MaxVarianceReducer
from nncf.experimental.common.tensor_statistics.collectors import MeanAbsMaxReducer
from nncf.experimental.common.tensor_statistics.collectors import MeanAggregator
Expand All @@ -35,6 +36,9 @@
from nncf.experimental.common.tensor_statistics.statistics import MeanMagnitudeTensorStatistic
from nncf.experimental.common.tensor_statistics.statistics import MeanVarianceTensorStatistic
from nncf.experimental.common.tensor_statistics.statistics import WCTensorStatistic
from nncf.experimental.torch2.commands import PT2InsertionCommand
from nncf.experimental.torch2.function_hook.nncf_graph.nncf_graph_builder import GraphModelWrapper
from nncf.experimental.torch2.model_transformer import PT2ModelTransformer
from nncf.parameters import CompressWeightsMode
from nncf.quantization.algorithms.smooth_quant.torch_backend import SQMultiply
from nncf.quantization.algorithms.weight_compression.awq_patterns import get_awq_patterns
Expand Down Expand Up @@ -186,8 +190,14 @@ def get_activation_port_id(node: NNCFNode, graph: NNCFGraph) -> int:
return activation_ports[0]

def get_weight(
self, node_with_weight: NNCFNode, weight_port_id: int, model: torch.nn.Module, graph: NNCFGraph
self,
node_with_weight: NNCFNode,
weight_port_id: int,
model: Union[GraphModelWrapper, torch.nn.Module],
graph: NNCFGraph,
) -> Tensor:
if isinstance(model, GraphModelWrapper):
model = model.model
weight_node = get_const_node(node_with_weight, weight_port_id, graph)
weight_name = weight_node.layer_attributes.name
weight = get_const_data(weight_node, model)
Expand All @@ -197,7 +207,11 @@ def get_weight(
return Tensor(weight)

def get_weight_dtype(
self, node_with_weight: NNCFNode, weight_port_id: int, model: torch.nn.Module, graph: NNCFGraph
self,
node_with_weight: NNCFNode,
weight_port_id: int,
model: Union[GraphModelWrapper, torch.nn.Module],
graph: NNCFGraph,
) -> TensorDataType:
return self.get_weight(node_with_weight, weight_port_id, model, graph).dtype

Expand All @@ -209,7 +223,14 @@ def get_weight_shape(node_with_weight: NNCFNode, weight_port_id: int, graph: NNC
def set_weight(
self, node_with_weight: NNCFNode, weight_port_id: int, model: torch.nn.Module, graph: NNCFGraph, weight: Tensor
):
update_parameter(node_with_weight.node_name, "weight", weight.data, model)
if is_experimental_torch_tracing_enabled():
weight_node = get_const_node(node_with_weight, weight_port_id, graph)
module_name, weight_attr_name = split_const_name(weight_node.layer_attributes.name)
module = get_module_by_name(module_name, model.model)
weight_param = getattr(module, weight_attr_name)
weight_param.data = weight.data
else:
update_parameter(node_with_weight.node_name, "weight", weight.data, model)

def insert_adapters(
self, wc_params: WeightCompressionParameters, lora_A: Tensor, lora_B: Tensor, int8_lora: bool
Expand All @@ -229,13 +250,19 @@ def filter_func(point: StatisticPoint) -> bool:

def transform_model(
self,
model: NNCFNetwork,
model: Union[GraphModelWrapper, torch.nn.Module],
graph: NNCFGraph,
weight_compression_parameters: Iterable[WeightCompressionParameters],
precomputed_scales: Dict[str, Tensor] = None,
precomputed_zero_points: Dict[str, Tensor] = None,
lora_correction_algo: LoraCorrectionAlgorithm = None,
) -> NNCFNetwork:
if isinstance(model, GraphModelWrapper):
model_transformer = PT2ModelTransformer(model)
model = model.model
else:
model_transformer = PTModelTransformer(model)

transformation_layout = TransformationLayout()

for wc_params in weight_compression_parameters:
Expand Down Expand Up @@ -291,38 +318,55 @@ def transform_model(

# sets compressed tensor
# TODO:(AlexanderDokuchaev): update set_const_data
compressed_parameter = torch.nn.Parameter(packed_tensor, requires_grad=False)
module_name, weight_attr_name = split_const_name(weight_name)
module = get_module_by_name(module_name, model)
weight = getattr(module, weight_attr_name)

if not isinstance(weight, torch.nn.Parameter):
msg = f"Weight is not a torch.nn.Parameter in the model by name {weight_name}."
raise nncf.InternalError(msg)

setattr(module, weight_attr_name, compressed_parameter)

consumer_nodes = graph.get_next_nodes(weight_node)
if len(consumer_nodes) > 1:
for c_node in consumer_nodes:
c_module = model.nncf.get_module_by_scope(Scope.from_str(c_node.layer_name))
for name, param in c_module.named_parameters(recurse=False, remove_duplicate=False):
if id(param) == id(weight):
setattr(c_module, name, compressed_parameter)

# register the weight decompression module in the model
decompressor_name = f"weights_decompressor_{weight_node.node_name.replace('.', '_')}"

# inserts the weight decompressor into the model as the post hook on the model weight
transformation_layout.register(
PTSharedFnInsertionCommand(
[PTTargetPoint(TargetType.OPERATOR_POST_HOOK, target_node_name=weight_node.node_name)],
decompressor,
decompressor_name,
if is_experimental_torch_tracing_enabled():
weight.requires_grad = False
weight.data = packed_tensor

transformation_layout.register(
PT2InsertionCommand(
[
PTTargetPoint(
TargetType.OPERATOR_POST_HOOK, target_node_name=weight_node.node_name.replace(".", ":")
)
],
decompressor,
)
)
else:
compressed_parameter = torch.nn.Parameter(packed_tensor, requires_grad=False)

setattr(module, weight_attr_name, compressed_parameter)

consumer_nodes = graph.get_next_nodes(weight_node)
if len(consumer_nodes) > 1:
for c_node in consumer_nodes:
c_module = model.nncf.get_module_by_scope(Scope.from_str(c_node.layer_name))
for name, param in c_module.named_parameters(recurse=False, remove_duplicate=False):
if id(param) == id(weight):
setattr(c_module, name, compressed_parameter)

# register the weight decompression module in the model
decompressor_name = f"weights_decompressor_{weight_node.node_name.replace('.', '_')}"

# inserts the weight decompressor into the model as the post hook on the model weight
transformation_layout.register(
PTSharedFnInsertionCommand(
[PTTargetPoint(TargetType.OPERATOR_POST_HOOK, target_node_name=weight_node.node_name)],
decompressor,
decompressor_name,
)
)
)

# apply transformations
transformed_model = PTModelTransformer(model).transform(transformation_layout)
transformed_model = model_transformer.transform(transformation_layout)

return transformed_model

Expand Down Expand Up @@ -356,6 +400,9 @@ def scale_insertion_command(

sq_multiply = SQMultiply(scale.shape)
sq_multiply.scale = scale

if is_experimental_torch_tracing_enabled():
return PT2InsertionCommand(target_points, sq_multiply)
scale_node_name = f"{source_node.node_name}/awq_mul"
return PTSharedFnInsertionCommand(target_points, sq_multiply, scale_node_name)

Expand Down
6 changes: 4 additions & 2 deletions nncf/quantization/quantize_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -511,7 +511,7 @@ def compress_weights(

if backend == BackendType.TORCH:
from nncf.torch.model_creation import is_wrapped_model
from nncf.torch.model_creation import wrap_model
from nncf.torch.nncf_network import NNCFNetwork
from nncf.torch.quantization.quantize_model import compress_weights_impl as pt_compression_weights_impl

if mode in [CompressWeightsMode.NF4, CompressWeightsMode.E2M1]:
Expand All @@ -529,7 +529,7 @@ def compress_weights(
raise nncf.ParameterNotSupportedError(msg)

if is_wrapped_model(model):
if not model.nncf.trace_parameters:
if isinstance(model, NNCFNetwork) and not model.nncf.trace_parameters:
msg = (
"Tracing capabilities with tracing parameters are required in the PyTorch model "
"for nncf.compress_weights(). Please wrap the model using "
Expand All @@ -541,6 +541,8 @@ def compress_weights(
msg = "Please provide a dataset of at least one element for PyTorch model tracing."
raise nncf.ValidationError(msg)
else:
from nncf.torch.model_creation import wrap_model

example_input = next(iter(dataset.get_inference_data()))
model = wrap_model(model, example_input=example_input, trace_parameters=True)
if mode in (CompressWeightsMode.INT8, CompressWeightsMode.INT8_ASYM, CompressWeightsMode.INT8_SYM):
Expand Down
17 changes: 15 additions & 2 deletions nncf/torch/model_creation.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
from nncf.config.extractors import extract_algorithm_names
from nncf.config.extractors import has_input_info_field
from nncf.config.telemetry_extractors import CompressionStartedFromConfig
from nncf.experimental.common.check_feature import is_experimental_torch_tracing_enabled
from nncf.telemetry import tracked_function
from nncf.telemetry.events import NNCF_PT_CATEGORY
from nncf.telemetry.extractors import FunctionCallTelemetryExtractor
Expand Down Expand Up @@ -350,6 +351,16 @@ def wrap_model(
:param trace_parameters: Whether to trace model parameters. Default is False.
:return: A model wrapped by NNCFNetwork, or by GraphModelWrapper when experimental torch tracing is enabled.
"""
if is_experimental_torch_tracing_enabled():
if not trace_parameters:
msg = "The 'trace_parameters=False' option is not supported in the experimental tracing mode."
raise nncf.InternalError(msg)
from nncf.experimental.torch2.function_hook import wrap_model
from nncf.experimental.torch2.function_hook.nncf_graph.nncf_graph_builder import GraphModelWrapper

wrapped_model = GraphModelWrapper(wrap_model(model), example_input=example_input)
return wrapped_model

if not isinstance(model, torch.nn.Module):
msg = (
f"The provided model type {type(model)} is incompatible. "
Expand All @@ -370,12 +381,14 @@ def wrap_model(

def is_wrapped_model(model: torch.nn.Module) -> bool:
"""
Check that the model was wrapped by NNCFNetwork.
Check that the model was wrapped by NNCFNetwork or GraphModelWrapper.

:param model: A model.
:return: True if the model is wrapped, False otherwise.
"""
return isinstance(model, NNCFNetwork)
from nncf.experimental.torch2.function_hook.nncf_graph.nncf_graph_builder import GraphModelWrapper

return isinstance(model, (NNCFNetwork, GraphModelWrapper))


@tracked_function(
Expand Down
11 changes: 8 additions & 3 deletions nncf/torch/quantization/quantize_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,15 @@
# limitations under the License.

from copy import deepcopy
from typing import Optional
from typing import Optional, Union

import torch

import nncf
from nncf.common.factory import NNCFGraphFactory
from nncf.common.quantization.structs import QuantizationPreset
from nncf.data import Dataset
from nncf.experimental.torch2.function_hook.nncf_graph.nncf_graph_builder import GraphModelWrapper
from nncf.parameters import BackupMode
from nncf.parameters import CompressWeightsMode
from nncf.parameters import ModelType
Expand Down Expand Up @@ -85,7 +86,7 @@ def quantize_impl(


def compress_weights_impl(
model: torch.nn.Module,
model: Union[GraphModelWrapper, torch.nn.Module],
dataset: Dataset,
mode: CompressWeightsMode,
ratio: float,
Expand Down Expand Up @@ -120,4 +121,8 @@ def compress_weights_impl(
advanced_parameters,
)
graph = NNCFGraphFactory.create(model)
return compression_algorithm.apply(model, graph, dataset=dataset)

compressed_model = compression_algorithm.apply(model, graph, dataset=dataset)
if isinstance(compressed_model, GraphModelWrapper):
compressed_model = compressed_model.model
return compressed_model
4 changes: 4 additions & 0 deletions tests/post_training/pipelines/lm_weight_compression.py
Original file line number Diff line number Diff line change
Expand Up @@ -285,7 +285,11 @@ def _dump_model_fp32(self) -> None:
self.model_hf.save_pretrained(self.fp32_model_dir)
self.model_hf._save_config(self.fp32_model_dir)
elif self.backend == BackendType.TORCH:
_need_clean_dict = "forward" not in self.model_hf.__dict__
export_from_model(self.model_hf, self.fp32_model_dir, stateful=False, compression_option="fp32")
if _need_clean_dict and "forward" in self.model_hf.__dict__:
# Workaround for experimental tracing: clean up the overwritten forward (same as in the class method)
del self.model_hf.__dict__["forward"]

def _compress(self):
"""
Expand Down
Loading