Bdellabe/awq modifier v3 #1177
Draft: wants to merge 15 commits into main
3 changes: 3 additions & 0 deletions src/llmcompressor/modifiers/awq/__init__.py
@@ -0,0 +1,3 @@
# flake8: noqa

from .base import *
749 changes: 749 additions & 0 deletions src/llmcompressor/modifiers/awq/base.py

Large diffs are not rendered by default.

48 changes: 20 additions & 28 deletions src/llmcompressor/modifiers/smoothquant/base.py
@@ -2,8 +2,9 @@
from typing import Callable, Dict, List, Optional, Tuple, Union

import torch
from compressed_tensors.utils.offload import is_module_offloaded
from accelerate.utils import align_module_device
from loguru import logger
from pydantic import ConfigDict
from torch.nn import Module

from llmcompressor.core import State
@@ -99,13 +100,16 @@ class SmoothQuantModifier(Modifier):
to use the default tensor_module_forward
"""

# Allow arbitrary types because SmoothQuantMapping has fields of type torch.nn.Module
model_config: ConfigDict = ConfigDict(arbitrary_types_allowed=True)

smoothing_strength: float = 0.5
mappings: Optional[List[Union[Tuple, List]]] = None
ignore: Optional[List[str]] = None
num_calibration_steps: Optional[int] = None
calibration_function: Optional[Callable] = None

resolved_mappings_: Optional[List] = None
resolved_mappings_: Optional[List[SmoothQuantMapping]] = None
scales_: Optional[Dict] = None

def on_initialize(self, state: State, **kwargs) -> bool:
@@ -166,7 +170,7 @@ def _infer_mappings_from_model(
)

@handle_mapping_resolution_errors
def _resolve_mappings(self, model: Module) -> List:
def _resolve_mappings(self, model: Module) -> List[SmoothQuantMapping]:
"""
Transforms the list of activations to smooth and their corresponding weights
into SmoothQuantMapping objects, resolving regular expressions.
@@ -289,22 +293,16 @@ def _apply_smoothing(self, model: Module):

@torch.no_grad()
def smooth(module):
offloaded = is_module_offloaded(module)
if offloaded:
module._hf_hook.pre_forward(module)

if module in balance_layers:
module.weight.mul_(scales.view(1, -1))
elif module == smooth_layer:
if module.weight.ndim == 1:
module.weight.div_(scales)
else:
module.weight.div_(scales.view(-1, 1))
if hasattr(module, "bias") and module.bias is not None:
module.bias.div_(scales)

if offloaded:
module._hf_hook.post_forward(module, None)
with align_module_device(module):
if module in balance_layers:
module.weight.mul_(scales.view(1, -1))
elif module == smooth_layer:
if module.weight.ndim == 1:
module.weight.div_(scales)
else:
module.weight.div_(scales.view(-1, 1))
if hasattr(module, "bias") and module.bias is not None:
module.bias.div_(scales)

parent = get_fsdp_parent(mapping.smooth_name, model)
if parent is not None:
@@ -329,15 +327,9 @@ def _calculate_smoothing_scales(
# get the channel-wise dynamic range for each layer to be balanced
weight_scales = []
for layer in balance_layers:
offloaded = is_module_offloaded(layer)
if offloaded:
layer._hf_hook.pre_forward(layer)

scale = layer.weight.abs().max(dim=0, keepdim=True)[0]
weight_scales.append(scale)

if offloaded:
layer._hf_hook.post_forward(layer, None)
with align_module_device(layer):
scale = layer.weight.abs().max(dim=0, keepdim=True)[0]
weight_scales.append(scale)

weight_scales = 2.0 * torch.cat(weight_scales, dim=0).max(dim=0)[0]

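The edits above replace the manual offload bookkeeping (is_module_offloaded plus _hf_hook.pre_forward/post_forward) with accelerate's align_module_device context manager. A minimal sketch of the adopted pattern (editor's example, not part of the diff; assumes an accelerate version that provides align_module_device):

import torch
from accelerate.utils import align_module_device

@torch.no_grad()
def scale_weight_inplace(layer, scales):
    # Inside the context the module's parameters sit on its execution device,
    # even if accelerate has offloaded them, so the in-place division behaves
    # like the removed pre_forward/post_forward hook calls.
    with align_module_device(layer):
        layer.weight.div_(scales.view(-1, 1))
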
1 change: 1 addition & 0 deletions src/llmcompressor/observers/__init__.py
@@ -5,3 +5,4 @@
from .base import *
from .min_max import *
from .mse import *
from .rtn import *
58 changes: 58 additions & 0 deletions src/llmcompressor/observers/rtn.py
@@ -0,0 +1,58 @@
from typing import Any, Optional, Tuple

import torch
from compressed_tensors.quantization.quant_args import QuantizationArgs
from compressed_tensors.quantization.utils import calculate_qparams
from compressed_tensors.utils import deprecated

from llmcompressor.observers.base import Observer
from llmcompressor.pytorch.utils import pseudo_quantize_tensor

__all__ = ["RoundToNearestObserver"]


@Observer.register("rtn")
class RoundToNearestObserver(Observer):
"""
Implements a quantization observer that calculates scale and zero point using
round-to-nearest (RTN) quantization: parameters are derived directly from the
observed tensor (per group when group_size is set) via pseudo_quantize_tensor,
without any moving-average smoothing
"""

def calculate_qparams(
self,
observed: torch.Tensor,
reduce_dims: Optional[Tuple[int]] = None,
tensor_id: Optional[Any] = None,
) -> Tuple[torch.FloatTensor, torch.IntTensor]:
"""
Calculates the scale and zero point for the observed tensor using
round-to-nearest quantization of its values.

:param observed: observed tensor to calculate quantization parameters for
:param reduce_dims: optional tuple of dimensions to reduce along,
returned scale and zero point will be shaped (1,) along the
reduced dimensions
:param tensor_id: Optional id if different ranges of observed tensors are
passed, useful for sharding tensors by group_size
:return: tuple of scale and zero point derived from the observed tensor
"""

_, scales, zp = pseudo_quantize_tensor(
observed,
symmetric=self.quantization_args.symmetric,
bit_width=self.quantization_args.num_bits,
group_size=self.quantization_args.group_size or -1,
)
return (scales, zp)

def get_qparams_along_dim(
self, observed: torch.Tensor, dim: int, tensor_id: Optional[Any] = None
):
"""
Calculate quantization parameters along the specified dimension
"""
reduce_dims = tuple(idx for idx in range(observed.ndim) if idx != dim)
return self.calculate_qparams(
observed, reduce_dims=reduce_dims, tensor_id=tensor_id
)
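
For reference, a minimal sketch of exercising the new observer directly (editor's example, not part of the diff); it assumes the constructor and calculate_qparams call pattern of the existing min-max and MSE observers and the QuantizationArgs field names from compressed-tensors:

import torch
from compressed_tensors.quantization.quant_args import QuantizationArgs
from llmcompressor.observers.rtn import RoundToNearestObserver

# 4-bit asymmetric with group size 128, a typical AWQ-style configuration
args = QuantizationArgs(num_bits=4, symmetric=False, group_size=128, observer="rtn")
observer = RoundToNearestObserver(quantization_args=args)

weight = torch.randn(512, 512)
scales, zero_points = observer.calculate_qparams(weight)
print(scales.shape, zero_points.shape)  # one scale/zero-point column per group of 128
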
103 changes: 103 additions & 0 deletions src/llmcompressor/pytorch/utils/helpers.py
@@ -2,6 +2,8 @@
Utility / helper functions
"""

import functools
import inspect
import os
import random
import re
@@ -85,6 +87,10 @@
"detach",
"adjust_quantization_for_onnx_export",
"get_dependency_order",
"pseudo_quantize_tensor",
"pseudo_dequantize_linear",
"tensor_forward_with_input_args",
"sanitize_kwargs_for_module",
]


@@ -680,6 +686,43 @@ def mask_difference(old_mask: Tensor, new_mask: Tensor) -> Tensor:
return -1.0 * newly_masked + newly_unmasked


def sanitize_kwargs_for_module(
kwargs: Dict[str, Any], module: Module
) -> Dict[str, Any]:
"""
Sanitize the kwargs for a Module by removing any keys that are not
in the signature of the forward method.
:param kwargs: the kwargs to sanitize
:param module: the Module to sanitize the kwargs for
:return: the sanitized kwargs for the callable object
"""
if not isinstance(kwargs, dict):
raise TypeError(f"Expected a dictionary as kwargs, but got {kwargs}")

allowed_params = inspect.signature(module.forward).parameters
return {key: value for key, value in kwargs.items() if key in allowed_params}


def tensor_forward_with_input_args(
module: Module, inputs: Tensor, input_kwargs: Dict[str, Any]
) -> Tensor:
"""
Forward the given inputs through the given module with the given input_kwargs.
This function is a wrapper around tensors_module_forward that ensures that the
input_kwargs are sanitized and passed to the module as keyword arguments during
the forward pass.
:param module: the module to forward the inputs through
:param inputs: the inputs to forward through the module
:param input_kwargs: the keyword arguments to pass to the
module during the forward pass
:return: the output of the module after forwarding the inputs through it
"""
inputs = inputs.to(next(module.parameters()).device)
input_kwargs = sanitize_kwargs_for_module(input_kwargs, module)

return tensors_module_forward(inputs, functools.partial(module, **input_kwargs))
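# Editor's example (not part of this diff): sanitize_kwargs_for_module keeps
# only the kwargs that the target module's forward() accepts, so calibration
# kwargs captured for one layer can be replayed safely against another:
#
#     >>> import torch
#     >>> linear = torch.nn.Linear(8, 8)
#     >>> sanitize_kwargs_for_module({"attention_mask": torch.ones(1, 8)}, linear)
#     {}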


##############################
#
# pytorch module helper functions
@@ -1194,3 +1237,63 @@ def swap_modules(
parent.__setattr__(sections[-1], submodule_to_replace)

return cur


def pseudo_quantize_tensor(
w: torch.Tensor, symmetric: bool = False, bit_width: int = 8, group_size: int = -1
):
"""
Fake-quantize a weight tensor with round-to-nearest (RTN) quantization,
optionally grouped along the last dimension (group_size=-1 quantizes per row).

:return: tuple of (fake-quantized weights, scales, zero points or None)
"""
org_w_shape = w.shape
if group_size > 0:
assert org_w_shape[-1] % group_size == 0, (
f"the last dimension of w ({org_w_shape[-1]}) must be a multiple "
f"of group_size ({group_size})"
)
w = w.reshape(-1, group_size)
assert w.dim() == 2
assert torch.isnan(w).sum() == 0

# zero point quantization
if not symmetric:
max_val = w.amax(dim=1, keepdim=True)
min_val = w.amin(dim=1, keepdim=True)
max_int = 2**bit_width - 1
min_int = 0
scales = (max_val - min_val).clamp(min=1e-5) / max_int
zeros = (-torch.round(min_val / scales)).clamp_(min_int, max_int)
w = (
torch.clamp(torch.round(w / scales) + zeros, min_int, max_int) - zeros
) * scales
zeros = (zeros - 2**(bit_width-1)).view(org_w_shape[0], -1)
else:
max_val = w.abs().amax(dim=1, keepdim=True)
max_val = max_val.clamp(min=1e-5)
max_int = 2 ** (bit_width - 1) - 1
min_int = -(2 ** (bit_width - 1))
scales = max_val / max_int
zeros = None
w = torch.clamp(torch.round(w / scales), min_int, max_int) * scales

assert torch.isnan(scales).sum() == 0
assert torch.isnan(w).sum() == 0

scales = scales.view(org_w_shape[0], -1)
w = w.reshape(org_w_shape)

return w, scales, zeros


def pseudo_dequantize_linear(
w: Module,
scales: torch.Tensor,
zeros: Optional[torch.Tensor] = None,
symmetric: bool = False,
):
# get repeated count
repeat_count = w.weight.data.shape[-1] // scales.shape[-1]
scales = scales.repeat(1, repeat_count).reshape(w.weight.data.shape)

# dequantize
if not symmetric:
zeros = zeros.repeat(1, repeat_count).reshape(w.weight.data.shape)
w = (w.weight.data - zeros) * scales
else:
w = w.weight.data * scales

return w
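
A quick round-trip check of the fake-quantization helper added above (editor's example, not part of the diff; the shapes follow directly from the code shown):

import torch
from llmcompressor.pytorch.utils import pseudo_quantize_tensor

w = torch.randn(256, 512)

# asymmetric INT4 with group size 128 (512 is a multiple of 128)
w_q, scales, zeros = pseudo_quantize_tensor(w, symmetric=False, bit_width=4, group_size=128)

assert w_q.shape == w.shape
assert scales.shape == (256, 512 // 128)  # one scale column per group
print("mean abs error:", (w - w_q).abs().mean().item())
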
1 change: 1 addition & 0 deletions src/llmcompressor/transformers/finetune/data/__init__.py
@@ -8,6 +8,7 @@
from .flickr_30k import Flickr30K
from .gsm8k import GSM8KDataset
from .open_platypus import OpenPlatypusDataset
from .pile import PileValDataset
from .ptb import PtbDataset
from .ultrachat_200k import UltraChatDataset
from .wikitext import WikiTextDataset
27 changes: 27 additions & 0 deletions src/llmcompressor/transformers/finetune/data/pile.py
@@ -0,0 +1,27 @@
from copy import deepcopy
from typing import TYPE_CHECKING

from llmcompressor.transformers.finetune.data import TextGenerationDataset
from llmcompressor.typing import Processor

if TYPE_CHECKING:
from llmcompressor.args import DatasetArguments


@TextGenerationDataset.register(name="mit-han-lab/pile-val-backup", alias="pile_val")
class PileValDataset(TextGenerationDataset):
"""
Child text generation class for the Pile validation dataset
(mit-han-lab/pile-val-backup)

:param data_args: configuration settings for dataset loading
:param split: split from dataset to load, for instance `test` or `train[:5%]`
:param processor: processor or tokenizer to use on dataset
"""

def __init__(self, data_args: "DatasetArguments", split: str, processor: Processor):
data_args = deepcopy(data_args)
data_args.text_column = "text"
data_args.dataset = "mit-han-lab/pile-val-backup"
super().__init__(data_args=data_args, split=split, processor=processor)

def dataset_template(self, sample):
return {"text": sample["text"].strip()}
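
With this registration, the Pile validation set can be selected by alias for calibration. A hedged sketch (editor's example, not part of the diff; the oneshot import path, recipe file, and argument names are illustrative and may differ by llm-compressor version):

from llmcompressor.transformers import oneshot

oneshot(
    model="TinyLlama/TinyLlama-1.1B-Chat-v1.0",
    dataset="pile_val",  # alias registered above
    recipe="awq_recipe.yaml",  # hypothetical recipe file
    num_calibration_samples=256,
    max_seq_length=512,
)
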
5 changes: 4 additions & 1 deletion src/llmcompressor/transformers/finetune/runner.py
@@ -3,6 +3,7 @@
import re
from typing import List, Optional

import datasets
import torch
from loguru import logger
from torch.utils.data import Dataset
@@ -106,7 +107,9 @@ def _get_split_name(inp_str):
)
for split_name, split_str in splits.items():
dataset = self._data_args.dataset
if hasattr(dataset, "column_names") and "input_ids" in dataset.column_names:
if isinstance(dataset, datasets.Dataset) or (
hasattr(dataset, "column_names") and "input_ids" in dataset.column_names
):
# dataset is already tokenized
tokenized_datasets[split_name] = dataset
else:
20 changes: 20 additions & 0 deletions src/llmcompressor/utils/pytorch/module.py
@@ -60,6 +60,7 @@
"get_layers_params",
"get_matching_layer",
"get_no_split_params",
"get_parent_by_name",
]


@@ -338,3 +339,22 @@ def get_no_split_params(module: Module) -> Union[str, List[str]]:
if hasattr(model, "_no_split_modules"):
return model._no_split_modules
return ALL_TARGET


def get_parent_by_name(layer_name: str, model: Module) -> Tuple[str, Module]:
"""
Get the parent layer of a layer by name.
:param layer_name: Name of the layer to find the parent of.
:param model: Model to search for the parent layer.
:return: Tuple containing the name of the parent layer
and the parent layer itself.
"""
if not any(layer_name == name for name, _ in model.named_modules()):
raise ValueError(f"Layer '{layer_name}' not found in model")

parent_name_parts = layer_name.split(".")[:-1]
if not parent_name_parts:
return "", model

parent_name = ".".join(parent_name_parts)
return get_layer(parent_name, model)
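
An editor's example (not part of the diff) of the new helper on a small nested module; with a Hugging Face model the layer name would look like "model.layers.0.self_attn.q_proj":

import torch
from llmcompressor.utils.pytorch.module import get_parent_by_name

model = torch.nn.Sequential(
    torch.nn.Linear(8, 8),
    torch.nn.Sequential(torch.nn.Linear(8, 4)),
)

# layer "1.0" is the inner Linear; its parent is the nested Sequential named "1"
parent_name, parent = get_parent_by_name("1.0", model)
print(parent_name, type(parent).__name__)  # 1 Sequential
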
Empty file.