Skip to content

Commit

Permalink
Use Attach TPC to Framework as part of MCT main track (#1308)
Browse files Browse the repository at this point in the history
TPC framework attachment mechanism change.
All the main API functions now get a framework-agnostic target platform description and use the AttachTpcToFramework module to generate a framework-dependent quantization capabilities description.

The main changes include:

- Alignment of the existing TP models with the new schema (v1).
- Remove TP models v2-4 for IMX500 (moved to tests for simplicity of testing).
- Remove existing fw TPCs (no need for them anymore).
- Addition of attach2fw usage in all API functions.
- Some changes to the built-in OperatorSetsNames enum to support additional required operators.
- Massive tests modification and alignment to comply with the new TPC mechanism.
  • Loading branch information
ofirgo authored Jan 6, 2025
1 parent d6ee6d3 commit d895518
Show file tree
Hide file tree
Showing 138 changed files with 2,335 additions and 3,846 deletions.
2 changes: 1 addition & 1 deletion model_compression_toolkit/core/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from model_compression_toolkit.core.common.quantization.debug_config import DebugConfig
from model_compression_toolkit.core.common.quantization import quantization_config
from model_compression_toolkit.core.common.mixed_precision import mixed_precision_quantization_config
from model_compression_toolkit.core.common.quantization.quantization_config import QuantizationConfig, QuantizationErrorMethod, DEFAULTCONFIG
from model_compression_toolkit.core.common.quantization.quantization_config import QuantizationConfig, QuantizationErrorMethod, DEFAULTCONFIG, CustomOpsetLayers
from model_compression_toolkit.core.common.quantization.bit_width_config import BitWidthConfig
from model_compression_toolkit.core.common.quantization.core_config import CoreConfig
from model_compression_toolkit.core.common.mixed_precision.resource_utilization_tools.resource_utilization import ResourceUtilization
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,30 @@
# limitations under the License.
# ==============================================================================

from dataclasses import dataclass, field
from dataclasses import dataclass
import math
from enum import Enum
from typing import Optional, Dict, NamedTuple, List

from model_compression_toolkit import DefaultDict
from model_compression_toolkit.constants import MIN_THRESHOLD


class CustomOpsetLayers(NamedTuple):
    """
    Defines a set of operators from a specific framework, which will be used to configure a custom
    operator set in the TPC.

    Instances of this struct are passed (keyed by opset name) via
    ``QuantizationConfig.custom_tpc_opset_to_layer`` and consumed by the ``AttachTpcToKeras`` /
    ``AttachTpcToPytorch`` attachment mechanism.

    Args:
        operators: a list of framework operators to map to a certain custom opset name.
        attr_mapping: a mapping between an opset name to a mapping between its attributes' general names and names in
            the framework.
    """

    # Framework layer/operator objects that belong to the custom opset.
    operators: List
    # Optional translation of attribute names (general name -> framework-specific name) per opset;
    # None means no attribute renaming is needed.
    attr_mapping: Optional[Dict[str, DefaultDict]] = None


class QuantizationErrorMethod(Enum):
"""
Method for quantization threshold selection:
Expand Down Expand Up @@ -86,6 +103,7 @@ class QuantizationConfig:
concat_threshold_update: bool = False
activation_bias_correction: bool = False
activation_bias_correction_threshold: float = 0.0
custom_tpc_opset_to_layer: Optional[Dict[str, CustomOpsetLayers]] = None


# Default quantization configuration the library use.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from model_compression_toolkit.core.common.mixed_precision.resource_utilization_tools.resource_utilization import ResourceUtilization
from model_compression_toolkit.logger import Logger
from model_compression_toolkit.constants import TENSORFLOW
from model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema import TargetPlatformModel
from model_compression_toolkit.target_platform_capabilities.target_platform import TargetPlatformCapabilities
from model_compression_toolkit.core.common.mixed_precision.resource_utilization_tools.resource_utilization_data import compute_resource_utilization_data
from model_compression_toolkit.verify_packages import FOUND_TF
Expand All @@ -27,6 +28,8 @@
from model_compression_toolkit.core.keras.default_framework_info import DEFAULT_KERAS_INFO
from model_compression_toolkit.core.keras.keras_implementation import KerasImplementation
from tensorflow.keras.models import Model
from model_compression_toolkit.target_platform_capabilities.target_platform.targetplatform2framework.attach2keras import \
AttachTpcToKeras

from model_compression_toolkit import get_target_platform_capabilities

Expand All @@ -36,7 +39,8 @@ def keras_resource_utilization_data(in_model: Model,
representative_data_gen: Callable,
core_config: CoreConfig = CoreConfig(
mixed_precision_config=MixedPrecisionQuantizationConfig()),
target_platform_capabilities: TargetPlatformCapabilities = KERAS_DEFAULT_TPC) -> ResourceUtilization:
target_platform_capabilities: TargetPlatformModel = KERAS_DEFAULT_TPC
) -> ResourceUtilization:
"""
Computes resource utilization data that can be used to calculate the desired target resource utilization
for mixed-precision quantization.
Expand Down Expand Up @@ -78,6 +82,12 @@ def keras_resource_utilization_data(in_model: Model,

fw_impl = KerasImplementation()

# Attach tpc model to framework
attach2keras = AttachTpcToKeras()
target_platform_capabilities = attach2keras.attach(
target_platform_capabilities,
custom_opset2layer=core_config.quantization_config.custom_tpc_opset_to_layer)

return compute_resource_utilization_data(in_model,
representative_data_gen,
core_config,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,19 +17,21 @@

from model_compression_toolkit.logger import Logger
from model_compression_toolkit.constants import PYTORCH
from model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema import TargetPlatformModel
from model_compression_toolkit.target_platform_capabilities.target_platform import TargetPlatformCapabilities
from model_compression_toolkit.core.common.mixed_precision.resource_utilization_tools.resource_utilization import ResourceUtilization
from model_compression_toolkit.core.common.framework_info import FrameworkInfo
from model_compression_toolkit.core.common.mixed_precision.resource_utilization_tools.resource_utilization_data import compute_resource_utilization_data
from model_compression_toolkit.core.common.quantization.core_config import CoreConfig
from model_compression_toolkit.core.common.mixed_precision.mixed_precision_quantization_config import MixedPrecisionQuantizationConfig
from model_compression_toolkit.target_platform_capabilities.constants import DEFAULT_TP_MODEL
from model_compression_toolkit.verify_packages import FOUND_TORCH

if FOUND_TORCH:
from model_compression_toolkit.core.pytorch.default_framework_info import DEFAULT_PYTORCH_INFO
from model_compression_toolkit.core.pytorch.pytorch_implementation import PytorchImplementation
from model_compression_toolkit.target_platform_capabilities.constants import DEFAULT_TP_MODEL
from torch.nn import Module
from model_compression_toolkit.target_platform_capabilities.target_platform.targetplatform2framework.attach2pytorch import \
AttachTpcToPytorch

from model_compression_toolkit import get_target_platform_capabilities

Expand All @@ -39,7 +41,7 @@
def pytorch_resource_utilization_data(in_model: Module,
representative_data_gen: Callable,
core_config: CoreConfig = CoreConfig(),
target_platform_capabilities: TargetPlatformCapabilities = PYTORCH_DEFAULT_TPC
target_platform_capabilities: TargetPlatformModel= PYTORCH_DEFAULT_TPC
) -> ResourceUtilization:
"""
Computes resource utilization data that can be used to calculate the desired target resource utilization for mixed-precision quantization.
Expand Down Expand Up @@ -80,6 +82,12 @@ def pytorch_resource_utilization_data(in_model: Module,

fw_impl = PytorchImplementation()

# Attach tpc model to framework
attach2pytorch = AttachTpcToPytorch()
target_platform_capabilities = (
attach2pytorch.attach(target_platform_capabilities,
custom_opset2layer=core_config.quantization_config.custom_tpc_opset_to_layer))

return compute_resource_utilization_data(in_model,
representative_data_gen,
core_config,
Expand Down
11 changes: 10 additions & 1 deletion model_compression_toolkit/gptq/keras/quantization_facade.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
LR_BIAS_DEFAULT, GPTQ_MOMENTUM, REG_DEFAULT_SLA
from model_compression_toolkit.logger import Logger
from model_compression_toolkit.constants import TENSORFLOW, ACT_HESSIAN_DEFAULT_BATCH_SIZE, GPTQ_HESSIAN_NUM_SAMPLES
from model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema import TargetPlatformModel
from model_compression_toolkit.verify_packages import FOUND_TF
from model_compression_toolkit.core.common.user_info import UserInformation
from model_compression_toolkit.gptq.common.gptq_config import GradientPTQConfig, GPTQHessianScoresConfig, \
Expand All @@ -47,6 +48,8 @@
from model_compression_toolkit.exporter.model_wrapper import get_exportable_keras_model
from model_compression_toolkit import get_target_platform_capabilities
from mct_quantizers.keras.metadata import add_metadata
from model_compression_toolkit.target_platform_capabilities.target_platform.targetplatform2framework.attach2keras import \
AttachTpcToKeras

# As from TF2.9 optimizers package is changed
if version.parse(tf.__version__) < version.parse("2.9"):
Expand Down Expand Up @@ -152,7 +155,7 @@ def keras_gradient_post_training_quantization(in_model: Model, representative_da
gptq_representative_data_gen: Callable = None,
target_resource_utilization: ResourceUtilization = None,
core_config: CoreConfig = CoreConfig(),
target_platform_capabilities: TargetPlatformCapabilities = DEFAULT_KERAS_TPC) -> Tuple[Model, UserInformation]:
target_platform_capabilities: TargetPlatformModel = DEFAULT_KERAS_TPC) -> Tuple[Model, UserInformation]:
"""
Quantize a trained Keras model using post-training quantization. The model is quantized using a
symmetric constraint quantization thresholds (power of two).
Expand Down Expand Up @@ -237,6 +240,12 @@ def keras_gradient_post_training_quantization(in_model: Model, representative_da

fw_impl = GPTQKerasImplemantation()

# Attach tpc model to framework
attach2keras = AttachTpcToKeras()
target_platform_capabilities = attach2keras.attach(
target_platform_capabilities,
custom_opset2layer=core_config.quantization_config.custom_tpc_opset_to_layer)

tg, bit_widths_config, hessian_info_service, scheduling_info = core_runner(in_model=in_model,
representative_data_gen=representative_data_gen,
core_config=core_config,
Expand Down
11 changes: 10 additions & 1 deletion model_compression_toolkit/gptq/pytorch/quantization_facade.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
from model_compression_toolkit.gptq.runner import gptq_runner
from model_compression_toolkit.logger import Logger
from model_compression_toolkit.metadata import create_model_metadata
from model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema import TargetPlatformModel
from model_compression_toolkit.target_platform_capabilities.target_platform import TargetPlatformCapabilities
from model_compression_toolkit.verify_packages import FOUND_TORCH

Expand All @@ -47,6 +48,9 @@
from torch.optim import Adam, Optimizer
from model_compression_toolkit import get_target_platform_capabilities
from mct_quantizers.pytorch.metadata import add_metadata
from model_compression_toolkit.target_platform_capabilities.target_platform.targetplatform2framework.attach2pytorch import \
AttachTpcToPytorch

DEFAULT_PYTORCH_TPC = get_target_platform_capabilities(PYTORCH, DEFAULT_TP_MODEL)

def get_pytorch_gptq_config(n_epochs: int,
Expand Down Expand Up @@ -140,7 +144,7 @@ def pytorch_gradient_post_training_quantization(model: Module,
core_config: CoreConfig = CoreConfig(),
gptq_config: GradientPTQConfig = None,
gptq_representative_data_gen: Callable = None,
target_platform_capabilities: TargetPlatformCapabilities = DEFAULT_PYTORCH_TPC):
target_platform_capabilities: TargetPlatformModel = DEFAULT_PYTORCH_TPC):
"""
Quantize a trained Pytorch module using post-training quantization.
By default, the module is quantized using a symmetric constraint quantization thresholds
Expand Down Expand Up @@ -209,6 +213,11 @@ def pytorch_gradient_post_training_quantization(model: Module,

fw_impl = GPTQPytorchImplemantation()

# Attach tpc model to framework
attach2pytorch = AttachTpcToPytorch()
target_platform_capabilities = attach2pytorch.attach(target_platform_capabilities,
core_config.quantization_config.custom_tpc_opset_to_layer)

# ---------------------- #
# Core Runner
# ---------------------- #
Expand Down
10 changes: 8 additions & 2 deletions model_compression_toolkit/pruning/keras/pruning_facade.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,12 @@

from model_compression_toolkit import get_target_platform_capabilities
from model_compression_toolkit.constants import TENSORFLOW
from model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema import TargetPlatformModel
from model_compression_toolkit.verify_packages import FOUND_TF
from model_compression_toolkit.core.common.mixed_precision.resource_utilization_tools.resource_utilization import ResourceUtilization
from model_compression_toolkit.core.common.pruning.pruner import Pruner
from model_compression_toolkit.core.common.pruning.pruning_config import PruningConfig
from model_compression_toolkit.core.common.pruning.pruning_info import PruningInfo
from model_compression_toolkit.core.common.quantization.bit_width_config import BitWidthConfig
from model_compression_toolkit.core.common.quantization.set_node_quantization_config import set_quantization_configuration_to_graph
from model_compression_toolkit.core.graph_prep_runner import read_model_to_graph
from model_compression_toolkit.logger import Logger
Expand All @@ -35,14 +35,16 @@
from model_compression_toolkit.core.keras.pruning.pruning_keras_implementation import PruningKerasImplementation
from model_compression_toolkit.core.keras.default_framework_info import DEFAULT_KERAS_INFO
from tensorflow.keras.models import Model
from model_compression_toolkit.target_platform_capabilities.target_platform.targetplatform2framework.attach2keras import \
AttachTpcToKeras

DEFAULT_KERAS_TPC = get_target_platform_capabilities(TENSORFLOW, DEFAULT_TP_MODEL)

def keras_pruning_experimental(model: Model,
target_resource_utilization: ResourceUtilization,
representative_data_gen: Callable,
pruning_config: PruningConfig = PruningConfig(),
target_platform_capabilities: TargetPlatformCapabilities = DEFAULT_KERAS_TPC) -> Tuple[Model, PruningInfo]:
target_platform_capabilities: TargetPlatformModel = DEFAULT_KERAS_TPC) -> Tuple[Model, PruningInfo]:
"""
Perform structured pruning on a Keras model to meet a specified target resource utilization.
This function prunes the provided model according to the target resource utilization by grouping and pruning
Expand Down Expand Up @@ -111,6 +113,10 @@ def keras_pruning_experimental(model: Model,
# Instantiate the Keras framework implementation.
fw_impl = PruningKerasImplementation()

# Attach tpc model to framework
attach2keras = AttachTpcToKeras()
target_platform_capabilities = attach2keras.attach(target_platform_capabilities)

# Convert the original Keras model to an internal graph representation.
float_graph = read_model_to_graph(model,
representative_data_gen,
Expand Down
10 changes: 8 additions & 2 deletions model_compression_toolkit/pruning/pytorch/pruning_facade.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,12 @@
from typing import Callable, Tuple
from model_compression_toolkit import get_target_platform_capabilities
from model_compression_toolkit.constants import PYTORCH
from model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema import TargetPlatformModel
from model_compression_toolkit.verify_packages import FOUND_TORCH
from model_compression_toolkit.core.common.mixed_precision.resource_utilization_tools.resource_utilization import ResourceUtilization
from model_compression_toolkit.core.common.pruning.pruner import Pruner
from model_compression_toolkit.core.common.pruning.pruning_config import PruningConfig
from model_compression_toolkit.core.common.pruning.pruning_info import PruningInfo
from model_compression_toolkit.core.common.quantization.bit_width_config import BitWidthConfig
from model_compression_toolkit.core.common.quantization.set_node_quantization_config import set_quantization_configuration_to_graph
from model_compression_toolkit.core.graph_prep_runner import read_model_to_graph
from model_compression_toolkit.logger import Logger
Expand All @@ -38,6 +38,8 @@
PruningPytorchImplementation
from model_compression_toolkit.core.pytorch.default_framework_info import DEFAULT_PYTORCH_INFO
from torch.nn import Module
from model_compression_toolkit.target_platform_capabilities.target_platform.targetplatform2framework.attach2pytorch import \
AttachTpcToPytorch

# Set the default Target Platform Capabilities (TPC) for PyTorch.
DEFAULT_PYOTRCH_TPC = get_target_platform_capabilities(PYTORCH, DEFAULT_TP_MODEL)
Expand All @@ -46,7 +48,7 @@ def pytorch_pruning_experimental(model: Module,
target_resource_utilization: ResourceUtilization,
representative_data_gen: Callable,
pruning_config: PruningConfig = PruningConfig(),
target_platform_capabilities: TargetPlatformCapabilities = DEFAULT_PYOTRCH_TPC) -> \
target_platform_capabilities: TargetPlatformModel = DEFAULT_PYOTRCH_TPC) -> \
Tuple[Module, PruningInfo]:
"""
Perform structured pruning on a Pytorch model to meet a specified target resource utilization.
Expand Down Expand Up @@ -117,6 +119,10 @@ def pytorch_pruning_experimental(model: Module,
# Instantiate the Pytorch framework implementation.
fw_impl = PruningPytorchImplementation()

# Attach TPC to framework
attach2pytorch = AttachTpcToPytorch()
target_platform_capabilities = attach2pytorch.attach(target_platform_capabilities)

# Convert the original Pytorch model to an internal graph representation.
float_graph = read_model_to_graph(model,
representative_data_gen,
Expand Down
Loading

0 comments on commit d895518

Please sign in to comment.