Commit

squash
Signed-off-by: Kyle Sayers <[email protected]>
kylesayrs committed Jan 13, 2025
1 parent 0535613 commit 08d700c
Showing 19 changed files with 915 additions and 1,250 deletions.
13 changes: 0 additions & 13 deletions examples/automodelforcausallm/README.md

This file was deleted.

11 changes: 0 additions & 11 deletions examples/automodelforcausallm/run_automodelforcausallm.py

This file was deleted.

359 changes: 95 additions & 264 deletions src/llmcompressor/modifiers/obcq/base.py
@@ -1,27 +1,28 @@
from functools import partial
from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
import contextlib
from typing import Dict, List, Optional, Tuple, Union

import numpy as np
import torch
from compressed_tensors.utils import (
align_module_device,
get_execution_device,
update_offload_parameter,
)
from loguru import logger
from torch.nn import Module
from tqdm import tqdm
from pydantic import Field, PrivateAttr

from llmcompressor.core import State
from llmcompressor.modifiers import Modifier
from llmcompressor.modifiers.obcq.utils.sgpt_wrapper import SparseGptWrapper
from llmcompressor.modifiers.utils.layer_compressor import LayerCompressor
from llmcompressor.modifiers.utils.pytorch_helpers import run_calibration_forward
from llmcompressor.utils.pytorch.module import (
get_layers,
get_no_split_params,
get_prunable_layers,
from llmcompressor.modifiers.obcq.sgpt_mixin import SparsityModifierMixin
from llmcompressor.modifiers.obcq.sgpt_sparsify import (
accumulate_hessian,
make_empty_hessian,
sparsify_weight,
)
from llmcompressor.utils.metric_logging import CompressionLogger

__all__ = ["SparseGPTModifier"]


class SparseGPTModifier(Modifier):
class SparseGPTModifier(SparsityModifierMixin, Modifier):
"""
Modifier for applying the one-shot SparseGPT algorithm to a model
@@ -67,267 +68,97 @@ class SparseGPTModifier(Modifier):
previously pruned model, defaults to False.
"""

sparsity: Union[float, List[float]] = 0.0
sparsity_profile: Optional[str] = None
owl_m: Optional[int] = None
owl_lmbda: Optional[float] = None
# modifier arguments
sparsity: Optional[Union[float, List[float]]] = None
mask_structure: str = "0:0"
sequential_update: Optional[bool] = False
targets: Union[str, List[str], None] = None
owl_m: Optional[int] = None
owl_lmbda: Optional[float] = None # misspelling?
sparsity_profile: Optional[str] = None # deprecated
block_size: int = 128
dampening_frac: Optional[float] = 0.01
preserve_sparsity_mask: bool = False

model: Optional[Any] = None
layer_compressors_: Optional[List[Any]] = None
prunen_: Optional[int] = None
prunem_: Optional[int] = None
compressible_layers_: Optional[List] = None

def on_initialize(self, state: "State", **kwargs) -> bool:
"""
Initialize and run the OBCQ algorithm on the current state
:param state: session state storing input model and calibration data
"""
if self.sparsity == 0.0:
raise ValueError(
"To use the SparseGPTModifier, target sparsity must be > 0.0"
)

modifiable_model = state.model
calibration_dataloader = state.data.calib

if self.targets is None:
# if no targets are provided, default to the modules that shouldn't be
# split by FSDP. For Transformers models this is equivalent to the
# decoder layers (ie LlamaDecoderLayer)
self.targets = get_no_split_params(modifiable_model)

self.initialize_compression(modifiable_model, calibration_dataloader)
self.apply_compression(calibration_dataloader)

return True

def initialize_compression(
preserve_sparsity_mask: bool = False # deprecate?
offload_hessians: bool = False

# data pipeline arguments
sequential_update: Optional[bool] = False # deprecated
module_targets: Union[str, List[str]] = ["Linear"]
targets: Union[str, List[str], None] = None # deprecated, clones sequential_targets
sequential_targets: Union[str, List[str], None] = None
ignore: List[str] = Field(default_factory=list)

# private variables
_prune_n: Optional[int] = PrivateAttr(default=None)
_prune_m: Optional[int] = PrivateAttr(default=None)
_hessians: Dict[torch.nn.Module, torch.Tensor] = PrivateAttr(default_factory=dict)
_num_samples: Dict[torch.nn.Module, int] = PrivateAttr(default_factory=dict)
_module_names: Dict[torch.nn.Module, str] = PrivateAttr(default_factory=dict)
_module_sparsities: Dict[torch.nn.Module, str] = PrivateAttr(default_factory=dict)

def calibrate_module(
self,
model: Module,
dataloader: Optional[Iterable[Tuple[List, Dict[str, Any]]]] = None,
module: torch.nn.Module,
args: Tuple[torch.Tensor, ...],
_output: torch.Tensor,
):
"""
Setup for SparseGPT, initializes the model, device,
and other parameters, also initializes the
compressible layers of the model, and sets the device
:param model: model to initialize for compression
"""
self.model = model
self.compressible_layers_ = self.compressible_layers()
self.layer_compressors_ = []
self._infer_mask_block_size()

if self.sparsity_profile is not None and self.sparsity_profile.lower() == "owl":
logger.info(
"Inferring layer-wise sparsities from "
f"{len(dataloader)} calibration samples..."
)
activations = self._get_activations(dataloader)
self.sparsity = self._infer_layer_sparsity(activations)
self._validate_layerwise_sparsity()

for idx, (name, layer) in enumerate(self.compressible_layers_.items()):
logger.info(f"Preparing {name} for compression")
if isinstance(self.sparsity, Dict):
layer_sparsity = self.sparsity[name]
elif isinstance(self.sparsity, List):
layer_sparsity = self.sparsity[idx]
else: # float
layer_sparsity = self.sparsity
args = self._pruning_arguments(layer_sparsity)
comp_cls = self._compression_class()
compressor = LayerCompressor(comp_cls, self.model, layer, idx, name, args)
if not self.sequential_update:
# add all batch processing hooks before the forward pass
compressor.pre_compress()
self.layer_compressors_.append(compressor)

def compressible_layers(self) -> Dict:
"""
Retrieves the modules corresponding to a list of
compressible layer names
:precondition: self.model is set and is a torch.nn.Module
:return: dictionary of modules to compress
"""
if not isinstance(self.model, Module):
raise ValueError(
"`self.model` must be a PyTorch Module to use "
f"the {self.__class__.__qualname__} modifier but got "
f"{type(self.model)} instead"
# Assume that the first argument is the input
inp = args[0]

# Initialize hessian if not present
if module not in self._num_samples:
device = get_execution_device(module)
self._hessians[module] = make_empty_hessian(module, device=device)
self._num_samples[module] = 0

# Accumulate hessian with input with optional offloading
with self._maybe_onload_hessian(module):
self._hessians[module], self._num_samples[module] = accumulate_hessian(
inp,
module,
self._hessians[module],
self._num_samples[module],
)

return get_layers(self.targets, self.model)

@torch.no_grad()
def apply_compression(
self, dataloader: Optional[Iterable[Tuple[List, Dict[str, Any]]]] = None
) -> Dict:
def on_sequential_batch_end(self):
"""
Run Wanda on the loaded model, using dataloader as calibration data
:param dataloader: calibration data for WANDA
Sparsify modules
TODO: implement with event callback
"""
class_name = self.__class__.__name__.replace("PyTorch", "")
logger.info(
f"Running {class_name} calibration with "
f"{len(dataloader) if dataloader else 0} samples..."
)
if not self.sequential_update:
# in non-sequential mode we run one forward batch for all modules
run_calibration_forward(self.model, dataloader, mask_padding=True)

num_layers = len(self.compressible_layers_)
for idx, layer_compressor in enumerate(self.layer_compressors_):
layer_sparsity = layer_compressor.args["sparsity"]
logger.info(
f"\n===== Compressing layer {idx+1}/{num_layers} "
f"to sparsity {layer_sparsity} ====="
)

# Prune/quantize using SparseGPT
if self.sequential_update:
# in sequential mode we run one forward pass for each module we
# want to compress, this will be really slow but allows compression in
# earlier layers to affect later layers
layer_compressor.pre_compress()
logger.info(f"Calibrating {layer_compressor.name}...")
run_calibration_forward(self.model, dataloader, mask_padding=True)
layer_compressor.compress()
layer_compressor.post_compress()
layer_compressor.revert_layer_wrappers()
torch.cuda.empty_cache()

def _validate_layerwise_sparsity(self):
if isinstance(self.sparsity, float):
# single sparsity will be applied to all layers
return

target_layers = list(self.compressible_layers_.keys())

if len(target_layers) != len(self.sparsity):
raise ValueError(
"Number of layer targets must match the number of sparsities. "
"Received {len(target_layers)} layers and "
f"{len(self.sparsity)} sparsities"
)

def _pruning_arguments(self, sparsity):
"""
Gather the parameters needed for root module compression in a dict
:param sparsity: target sparsity
:return: dict of params for pruning
"""
return {
"sparsity": sparsity,
"prunen": self.prunen_,
"prunem": self.prunem_,
"blocksize": self.block_size,
"percdamp": self.dampening_frac,
"preserve_sparsity_mask": self.preserve_sparsity_mask,
}

def _compression_class(self):
"""
:return: wrapper class used for root modules of this compression class
"""
return SparseGptWrapper

def _infer_mask_block_size(self):
"""
Infer the mask block size from the mask structure.
Parses mask_structure of the form N:M where N, M are integers that
define a custom block shape; and sets prunen_ and prunem_ accordingly.
:post-condition: prunen_ and prunem_ are set
"""
if self.mask_structure is None:
raise ValueError("mask_structure must be defined")

self.prunen_, self.prunem_ = list(map(int, self.mask_structure.split(":")))

def _infer_layer_sparsity(self, activations):
sparsegpt_groups = {}
for name, layer in self.compressible_layers_.items():
prunable_layers = get_prunable_layers(layer)
z = [
m.weight.abs() * activations[f"{name}.{n}"].unsqueeze(0)
for n, m in prunable_layers.items()
]
sparsegpt_groups[name] = torch.cat([item.flatten().cpu() for item in z])

del activations
torch.cuda.empty_cache()

outlier_ratios = {}
for group in sparsegpt_groups:
threshold = torch.mean(sparsegpt_groups[group]) * self.owl_m
outlier_ratios[group] = (
100
* (sparsegpt_groups[group] > threshold).sum().item()
/ sparsegpt_groups[group].numel()
)
outlier_ratios_arr = np.array([outlier_ratios[k] for k in outlier_ratios])
for k in outlier_ratios:
outlier_ratios[k] = (outlier_ratios[k] - outlier_ratios_arr.min()) * (
1
/ (outlier_ratios_arr.max() - outlier_ratios_arr.min())
* self.owl_lmbda
* 2
)
outlier_ratios_arr = np.array([outlier_ratios[k] for k in outlier_ratios])
sparsities = {
k: 1
- (
outlier_ratios[k]
- np.mean(outlier_ratios_arr)
+ (1 - float(self.sparsity))
)
for k in outlier_ratios
}
logger.info(f"OWL sparsities for sp={self.sparsity} are:")
for k in sparsities:
logger.info(f"Sparsity for {k}: {sparsities[k]}")
return sparsities

@torch.no_grad()
def _get_activations(self, data_loader, nsamples=128):
self.model.eval()
acts = {}

def save_acts(module, input, name):
if isinstance(input, tuple):
input = input[0]
if name not in acts:
acts[name] = (
1.0 / nsamples * input.detach().pow(2).sum(dim=(0, 1)).sqrt()
)
else:
acts[name] += (
1.0 / nsamples * input.detach().pow(2).sum(dim=(0, 1)).sqrt()
for module in list(self._num_samples.keys()):
name = self._module_names[module]
sparsity = self._module_sparsities[module]
num_samples = self._num_samples[module]

logger.info(f"Sparsifying {name} using {num_samples} samples")
with (
torch.no_grad(),
align_module_device(module),
CompressionLogger(module) as comp_logger,
):
loss, sparsified_weight = sparsify_weight(
module=module,
hessians_dict=self._hessians,
sparsity=sparsity,
prune_n=self._prune_n,
prune_m=self._prune_m,
block_size=self.block_size,
dampening_frac=self.dampening_frac,
preserve_sparsity_mask=self.preserve_sparsity_mask, # TODO: should we deprecate this? GPTQ just checks against a sparsity threshold
)
comp_logger.set_loss(loss)

update_offload_parameter(module, "weight", sparsified_weight)

for name, mod in self.model.named_modules():
if isinstance(mod, torch.nn.Linear) and "lm_head" not in name:
self.register_hook(mod, partial(save_acts, name=name), "forward_pre")
# self._hessians[module] already deleted by sparsify_weight
del self._num_samples[module]

device = next(self.model.parameters()).device
for batch in tqdm(data_loader):
batch = {k: v.to(device) for k, v in batch.items()}
self.model(**batch)
batch = None
torch.cuda.empty_cache()
@contextlib.contextmanager
def _maybe_onload_hessian(self, module: torch.nn.Module):
if self.offload_hessians:
device = get_execution_device(module)
self._hessians[module] = self._hessians[module].to(device=device)

self.remove_hooks()
yield

return acts
if self.offload_hessians:
if module in self._hessians: # may have been deleted in context
self._hessians[module] = self._hessians[module].to(device="cpu")
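
For orientation, a minimal self-contained sketch (not part of this commit) of the hook-based flow that replaces LayerCompressor in this file: a forward hook accumulates a Hessian per module during calibration, and the statistics are then consumed by sparsify_weight, mirroring calibrate_module and on_sequential_batch_end above. It imports the sgpt_sparsify helpers added later in this diff; the toy Linear layer and random calibration batches are illustrative only.

import torch
from llmcompressor.modifiers.obcq.sgpt_sparsify import (
    accumulate_hessian,
    make_empty_hessian,
    sparsify_weight,
)

linear = torch.nn.Linear(64, 64)
hessians, num_samples = {}, {}

def calibrate(module, args, _output):
    # mirrors calibrate_module: lazily create the Hessian, then accumulate the input
    inp = args[0]
    if module not in num_samples:
        hessians[module] = make_empty_hessian(module, device=inp.device)
        num_samples[module] = 0
    hessians[module], num_samples[module] = accumulate_hessian(
        inp, module, hessians[module], num_samples[module]
    )

handle = linear.register_forward_hook(calibrate)
with torch.no_grad():
    for _ in range(8):  # stand-in for calibration batches
        linear(torch.randn(4, 16, 64))
handle.remove()

# mirrors on_sequential_batch_end: consume the Hessian and overwrite the weight
with torch.no_grad():
    loss, pruned_weight = sparsify_weight(
        module=linear,
        hessians_dict=hessians,  # the entry for `linear` is deleted inside
        sparsity=0.5,
        prune_n=0,
        prune_m=0,
        block_size=128,
        dampening_frac=0.01,
        preserve_sparsity_mask=False,
    )
    linear.weight.data = pruned_weight  # the modifier uses update_offload_parameter instead
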
235 changes: 235 additions & 0 deletions src/llmcompressor/modifiers/obcq/sgpt_mixin.py
@@ -0,0 +1,235 @@
import warnings
from collections import defaultdict
from functools import partial
from typing import List, Tuple, Union

import numpy as np
import torch
from loguru import logger
from pydantic import field_validator, model_validator

from llmcompressor.core import State
from llmcompressor.modifiers import Modifier
from llmcompressor.modifiers.utils.hooks import HooksMixin
from llmcompressor.pipelines.basic import run_pipeline as run_basic
from llmcompressor.pipelines.layer_sequential import (
run_pipeline as run_layer_sequential,
)
from llmcompressor.pipelines.sequential import run_pipeline as run_sequential
from llmcompressor.utils.pytorch.module import (
get_layers,
get_no_split_params,
get_prunable_layers,
match_class,
match_targets,
)


class SparsityModifierMixin(HooksMixin):
@field_validator("sequential_update", mode="before", check_fields=False)
def validate_sequential_update(cls, value: bool) -> bool:
if not value:
warnings.warn(
"`sequential_update=False` is no longer supported, setting "
"sequential_update=True",
DeprecationWarning,
)

return True

@field_validator("sparsity_profile", mode="before", check_fields=False)
def validate_sparsity_profile(cls, value) -> None:
if value is not None:
warnings.warn(
"`sparsity_profile` is deprecated, use `owl_m` and `owl_lmbda`"
)
return None

@model_validator(mode="after")
def validate_model_after(model: "Modifier") -> "Modifier":
sparsity = model.sparsity
owl_m = model.owl_m
owl_lmbda = model.owl_lmbda
mask_structure = model.mask_structure
targets = model.targets
sequential_targets = model.sequential_targets

if (owl_m is not None) ^ (owl_lmbda is not None):
raise ValueError("Must provide both `owl_m` and `owl_lmbda` or neither")

if owl_m is not None and sparsity is not None:
raise ValueError("Cannot provide both sparsity and owl parameters")

if targets is not None:
warnings.warn(
"`targets` is deprecated, use `module_targets` and `sequential_targets`"
)
if sequential_targets is not None:
raise ValueError("Cannot use both `targets` and `sequential_targets`")
model.sequential_targets = targets
model.targets = None

model._prune_n, model._prune_m = model._split_mask_structure(mask_structure)

return model

def on_initialize(self, state: "State", **kwargs) -> bool:
"""
Initialize and run the OBCQ algorithm on the current state
:param state: session state storing input model and calibration data
"""
model = state.model
dataloader = state.data.calib

# infer module and sequential targets
self.sequential_targets = self._infer_sequential_targets(model)

# infer layer sparsities
if self.owl_m is not None and self.owl_lmbda is not None:
logger.info(
"Using OWL to infer target layer-wise sparsities from "
f"{len(dataloader) if dataloader else 0} calibration samples..."
)
self.sparsity = self._infer_owl_layer_sparsity()

# register hooks
for index, (name, layer) in enumerate(
get_layers(self.sequential_targets, model).items()
):
if isinstance(self.sparsity, dict):
layer_sparsity = self.sparsity[name]
elif isinstance(self.sparsity, list):
layer_sparsity = self.sparsity[index]
else:
layer_sparsity = self.sparsity

for name, module in layer.named_modules():
if (
match_targets(module, self.module_targets)[0]
or match_class(module, self.module_targets)[0]
):
self._module_names[module] = name
self._module_sparsities[module] = layer_sparsity
self.register_hook(module, self.calibrate_module, "forward")

# infer and run pipeline
model_name = state.model.__class__.__name__
input_names = state.data.calib.dataset.column_names
unfixable_errors = (torch.OutOfMemoryError, torch._C._LinAlgError)
try:
run_sequential(
state.model,
state.data.calib,
self.sequential_targets,
self.ignore,
self,
)
return True

except Exception as exception:
if isinstance(exception, torch.fx.proxy.TraceError):
warnings.warn(f"Failed to trace {model_name} with inputs {input_names}")
if isinstance(exception, unfixable_errors):
raise exception

warnings.warn("Falling back to layer_sequential pipeline")
try:
run_layer_sequential(
state.model,
state.data.calib,
self.sequential_targets,
self,
)
return True

except Exception as exception:
if isinstance(exception, TypeError):
warnings.warn(f"{model_name} fails layer-wise assumptions")
if isinstance(exception, unfixable_errors):
raise exception

warnings.warn(
"Falling back to basic pipeline, which requires extra memory and "
"may result in decreased accuracy"
)
run_basic(state.model, state.data.calib, self)
return True

return True

def _infer_sequential_targets(
self, model: torch.nn.Module
) -> Union[str, List[str]]:
if self.sequential_targets is None:
return get_no_split_params(model)
if isinstance(self.sequential_targets, str):
return [self.sequential_targets]
return self.sequential_targets

def _infer_owl_layer_sparsity(self, activations):
groups = {}
for name, layer in self.compressible_layers_.items():
prunable_layers = get_prunable_layers(layer)
z = [
m.weight.abs() * activations[f"{name}.{n}"].unsqueeze(0)
for n, m in prunable_layers.items()
]
groups[name] = torch.cat([item.flatten().cpu() for item in z])

del activations
torch.cuda.empty_cache()

outlier_ratios = {}
for group in groups:
threshold = torch.mean(groups[group]) * self.owl_m
outlier_ratios[group] = (
100 * (groups[group] > threshold).sum().item() / groups[group].numel()
)
outlier_ratios_arr = np.array([outlier_ratios[k] for k in outlier_ratios])
for k in outlier_ratios:
outlier_ratios[k] = (outlier_ratios[k] - outlier_ratios_arr.min()) * (
1
/ (outlier_ratios_arr.max() - outlier_ratios_arr.min())
* self.owl_lmbda
* 2
)
outlier_ratios_arr = np.array([outlier_ratios[k] for k in outlier_ratios])
sparsities = {
k: 1
- (
outlier_ratios[k]
- np.mean(outlier_ratios_arr)
+ (1 - float(self.sparsity))
)
for k in outlier_ratios
}
logger.info(f"OWL sparsities for sp={self.sparsity} are:")
for k in sparsities:
logger.info(f"Sparsity for {k}: {sparsities[k]}")
return sparsities

def _get_activations(self, model, dataloader, nsamples=128):
acts = defaultdict(int)

def save_acts(_module, input, name):
nonlocal acts
if isinstance(input, tuple):
input = input[0]
acts[name] += 1.0 / nsamples * input.pow(2).sum(dim=(0, 1)).sqrt()

# TODO: only add hooks to target modules
hooks = set(
self.register_hook(mod, partial(save_acts, name=name), "forward_pre")
for name, mod in model.named_modules()
if isinstance(mod, torch.nn.Linear) and "lm_head" not in name
)
with HooksMixin.disable_hooks(keep=hooks):
run_basic(model, dataloader)
self.remove_hooks(hooks)

return acts

def _split_mask_structure(self, mask_structure: str) -> Tuple[int, int]:
n, m = mask_structure.split(":")
return int(n), int(m)
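
As a quick numeric illustration of _infer_owl_layer_sparsity above (toy layer names and ratios, not taken from the commit): outlier ratios are min-max scaled into [0, 2 * owl_lmbda] and then recentered around the target sparsity, so layers with more activation outliers receive lower sparsity while the mean stays at the target.

import numpy as np

target_sparsity, owl_lmbda = 0.5, 0.08
outlier_ratios = {"layers.0": 1.2, "layers.1": 2.4, "layers.2": 4.0}  # % of weights above threshold

arr = np.array(list(outlier_ratios.values()))
scaled = {
    k: (v - arr.min()) / (arr.max() - arr.min()) * owl_lmbda * 2
    for k, v in outlier_ratios.items()
}
scaled_arr = np.array(list(scaled.values()))
sparsities = {
    k: 1 - (v - scaled_arr.mean() + (1 - target_sparsity))
    for k, v in scaled.items()
}
print(sparsities)  # ~0.576, ~0.508, ~0.416; the mean is exactly 0.5
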
225 changes: 225 additions & 0 deletions src/llmcompressor/modifiers/obcq/sgpt_sparsify.py
@@ -0,0 +1,225 @@
import math
from typing import Dict, Optional, Tuple

import torch
import transformers
from compressed_tensors.quantization.lifecycle.forward import forward_quantize

from llmcompressor.utils import getattr_chain

SGPT_PRECISION = torch.float32


def make_empty_hessian(
module: torch.nn.Module, device: Optional[torch.device] = None
) -> torch.Tensor:
weight = module.weight
num_columns = weight.shape[1]
device = device if device is not None else weight.device
return torch.zeros((num_columns, num_columns), device=device, dtype=SGPT_PRECISION)


def accumulate_hessian(
inp: torch.Tensor,
module: torch.nn.Module,
H: Optional[torch.Tensor] = None,
num_samples: int = 1,
) -> Tuple[torch.Tensor, int]:
inp = inp.to(device=H.device)
if len(inp.shape) == 2:
inp = inp.unsqueeze(0)

num_added = inp.shape[0] # note this is the number of dataset samples, not
# multiplied by the sequence length

if isinstance(module, (torch.nn.Linear, transformers.Conv1D)):
if len(inp.shape) == 3:
inp = inp.reshape((-1, inp.shape[-1]))
inp = inp.t()

if isinstance(module, torch.nn.Conv2d):
unfold = torch.nn.Unfold(
module.kernel_size,
dilation=module.dilation,
padding=module.padding,
stride=module.stride,
)
inp = unfold(inp)
inp = inp.permute([1, 0, 2])
inp = inp.flatten(1)

H *= num_samples / (num_samples + num_added)
num_samples += num_added

inp = inp.to(dtype=SGPT_PRECISION)
inp = math.sqrt(2 / num_samples) * inp
H += inp.matmul(inp.t())

return H, num_samples


def sparsify_weight(
module: torch.nn.Module,
hessians_dict: Dict[torch.nn.Module, torch.Tensor],
sparsity: float,
prune_n: int,
prune_m: int,
block_size: int,
dampening_frac: float,
preserve_sparsity_mask: bool,
) -> torch.Tensor:
"""
Run pruning and quantization(if applicable) on the layer up to the target
sparsity value.
:param module: module with weight being sparsified
:param hessians_dict: dictionary containing the pre-accumulated Hessian for sparsification
:param sparsity: target sparsity to reach for layer
:param prune_n: N for N:M pruning
:param prune_m: M for N:M pruning
:param block_size: Number of columns to compress in one pass
:param dampening_frac: Amount of dampening to apply to H, as a fraction of the
diagonal norm
:param preserve_sparsity_mask: Extend or ignore the base sparsity mask
"""
final_shape = module.weight.shape
final_dtype = module.weight.dtype
W = module.weight.clone()
H = hessians_dict[module] # unfortunately python does not have a `move` keyword
del hessians_dict[module] # so we have to delete the original reference manually

# if this module is quantized, perform RTN quantization before sparsifying
args_loc = "quantization_scheme.weights"
weight_quant_args = getattr_chain(module, args_loc, None)
if weight_quant_args is not None:
W = forward_quantize(module, W, "weight", weight_quant_args)

# standardize shape and dtype
if isinstance(module, torch.nn.Conv2d):
W = W.flatten(1)
elif isinstance(module, transformers.Conv1D):
W.transpose_(0, 1)
W = W.to(dtype=SGPT_PRECISION)
num_rows = W.shape[0]
num_columns = W.shape[1]

# mask dead hessian values
dead = torch.diag(H) == 0
H[dead, dead] = 1
W[:, dead] = 0

# compute inverse hessian in place to save memory
try:
damp = dampening_frac * torch.mean(torch.diag(H))
diag = torch.arange(H.shape[0], device=H.device)
H[diag, diag] += damp
H = torch.linalg.cholesky(H)
H = torch.cholesky_inverse(H)
H = torch.linalg.cholesky(H, upper=True)
Hinv = H
except torch._C._LinAlgError:
raise torch._C._LinAlgError(
"Failed to invert hessian due to numerical instability. Consider "
"increasing GPTQModifier.dampening_frac, increasing the number "
"of calibration samples, or shuffling the calibration dataset"
)

# sparsity mask
# TODO: consider computing sparsity mask in the same way and place as gptq
mask = None
if preserve_sparsity_mask:
# compute existing sparsity mask
mask = torch.where(
W == 0,
torch.tensor(1, dtype=torch.bool),
torch.tensor(0, dtype=torch.bool),
)
current_sparsity = mask.sum() / W.numel()
if current_sparsity > sparsity:
raise ValueError(
"The target sparsity is lower than the sparsity "
"of the base model. Please retry "
"after turning preserve_sparsity_mask=False"
)

losses = torch.zeros(num_rows, device=module.weight.device)

# See section 3.4 of https://arxiv.org/abs/2203.07259
for i1 in range(0, num_columns, block_size):
i2 = min(i1 + block_size, num_columns)
count = i2 - i1

W1 = W[:, i1:i2].clone()
Q1 = torch.zeros_like(W1)
Err1 = torch.zeros_like(W1)
Losses1 = torch.zeros_like(W1)
Hinv1 = Hinv[i1:i2, i1:i2]

if prune_n == 0:
if mask is not None:
mask1 = mask[:, i1:i2]
if int(W1.numel() * sparsity) > mask1.sum():
# target sparsity is higher than base sparsity, extend mask1
tmp = (
(~mask[:, i1:i2])
* W1**2
/ (torch.diag(Hinv1).reshape((1, -1))) ** 2
)
thresh = torch.sort(tmp.flatten())[0][int(tmp.numel() * sparsity)]
mask1 = tmp <= thresh
else:
tmp = W1**2 / (torch.diag(Hinv1).reshape((1, -1))) ** 2
thresh = torch.sort(tmp.flatten())[0][int(tmp.numel() * sparsity)]
mask1 = tmp <= thresh
else:
if mask is not None:
mask1 = mask[:, i1:i2]
else:
mask1 = torch.zeros_like(W1) == 1

for i in range(count):
w = W1[:, i]
d = Hinv1[i, i]

if prune_n != 0 and i % prune_m == 0:
tmp = (
W1[:, i : (i + prune_m)] ** 2
/ (torch.diag(Hinv1)[i : (i + prune_m)].reshape((1, -1))) ** 2
)
if mask is not None:
tmp = tmp * (~mask[:, i : (i + prune_m)])

mask1.scatter_(
1, i + torch.topk(tmp, prune_n, dim=1, largest=False)[1], True
)

q = w.clone()
q[mask1[:, i]] = 0

Q1[:, i] = q
Losses1[:, i] = (w - q) ** 2 / d**2

err1 = (w - q) / d
W1[:, i:] -= err1.unsqueeze(1).matmul(Hinv1[i, i:].unsqueeze(0))
Err1[:, i] = err1

W[:, i1:i2] = Q1
losses += torch.sum(Losses1, 1) / 2

if preserve_sparsity_mask:
# respect the sparsity of other groups
# really not needed, but kept for explicitness
W[:, i2:] -= (~mask[:, i2:]) * Err1.matmul(Hinv[i1:i2, i2:])
else:
W[:, i2:] -= Err1.matmul(Hinv[i1:i2, i2:])

if isinstance(module, transformers.Conv1D):
W.transpose_(0, 1)
W = W.reshape(final_shape).to(final_dtype)

# perform RTN quantization
if weight_quant_args is not None:
W = forward_quantize(module, W, "weight", weight_quant_args)

loss = torch.sum(losses).item()
return loss, W
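
The running update in accumulate_hessian above is algebraically equivalent to computing H = (2 / N) * X X^T over all calibration rows at once. A small standalone sanity sketch of that equivalence (random values, each row treated as one sample):

import torch

X = torch.randn(10, 8)           # 10 activation rows of width 8
H_ref = 2.0 / 10 * X.t() @ X     # closed-form estimate over all rows at once

H, num_samples = torch.zeros(8, 8), 0
for chunk in X.split(5):         # two "batches" of 5 rows, mirroring accumulate_hessian
    num_added = chunk.shape[0]
    H *= num_samples / (num_samples + num_added)
    num_samples += num_added
    inp = (2.0 / num_samples) ** 0.5 * chunk.t()
    H += inp @ inp.t()

assert torch.allclose(H, H_ref, atol=1e-5)
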
Empty file.
226 changes: 0 additions & 226 deletions src/llmcompressor/modifiers/obcq/utils/sgpt_wrapper.py

This file was deleted.

344 changes: 78 additions & 266 deletions src/llmcompressor/modifiers/pruning/wanda/base.py
@@ -1,43 +1,31 @@
from functools import partial
from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
from typing import Dict, List, Optional, Tuple, Union

import numpy as np
import torch
from compressed_tensors.utils import (
align_module_device,
get_execution_device,
update_offload_parameter,
)
from loguru import logger
from torch.nn import Module
from tqdm import tqdm
from pydantic import Field, PrivateAttr

from llmcompressor.core import State
from llmcompressor.modifiers import Modifier
from llmcompressor.modifiers.pruning.wanda.utils.wanda_wrapper import WandaWrapper
from llmcompressor.modifiers.utils.layer_compressor import LayerCompressor
from llmcompressor.modifiers.utils.pytorch_helpers import run_calibration_forward
from llmcompressor.utils.pytorch.module import (
get_layers,
get_no_split_params,
get_prunable_layers,
from llmcompressor.modifiers.obcq.sgpt_mixin import SparsityModifierMixin
from llmcompressor.modifiers.pruning.wanda.utils.wanda_sparsify import (
accumulate_row_scalars,
make_empty_row_scalars,
sparsify_weight,
)
from llmcompressor.utils.metric_logging import CompressionLogger

__all__ = ["WandaPruningModifier"]


class WandaPruningModifier(Modifier):
class WandaPruningModifier(SparsityModifierMixin, Modifier):
"""
Modifier for applying the one-shot WANDA algorithm to a model
from the paper: https://arxiv.org/abs/2306.11695
Lifecycle:
- on_initialize
- initialize_compression()
- compressible_layers()
- LayerCompressor.pre_compress()
- apply_compression()
- run_calibration_forward()
- LayerCompressor.compress()
- LayerCompressor.post_compress()
- LayerCompressor.revert_layer_wrappers()
- on_finalize
:param sparsity: Sparsity to compress model to
:param mask_structure: String to define the structure of the mask to apply.
Must be of the form N:M where N, M are integers that define a custom block
@@ -48,254 +36,78 @@ class WandaPruningModifier(Modifier):
to compress every layer in the model
"""

sparsity: Union[float, List[float]] = 0.0
sparsity_profile: Optional[str] = None
owl_m: Optional[int] = None
owl_lmbda: Optional[float] = None
# sparsity arguments
sparsity: Optional[Union[float, List[float]]] = None
mask_structure: str = "0:0"
sequential_update: Optional[bool] = False
targets: Union[str, List[str], None] = None
model: Optional[Any] = None
layer_compressors_: List = None

compressible_layers_: Optional[List] = None
prunen_: Optional[int] = None
prunem_: Optional[int] = None

def on_initialize(self, state: State, **kwargs) -> bool:
"""
Initialize and run the WANDA algorithm on the current state
:param state: session state storing input model and calibration data
:param kwargs: Unused, kept to conform to the parent method signature
"""
modifiable_model = state.model
calibration_dataloader = state.data.calib

if self.targets is None:
# if no targets are provided, default to the modules that shouldn't be
# split by FSDP. For Transformers models this is equivalent to the
# decoder layers (ie LlamaDecoderLayer)
self.targets = get_no_split_params(modifiable_model)

self.initialize_compression(modifiable_model, calibration_dataloader)
self.apply_compression(calibration_dataloader)

return True

def compressible_layers(self) -> Dict:
"""
Retrieves the modules corresponding to a list of
compressible layer names
:precondition: self.model is set and is a torch.nn.Module
:return: dictionary of modules to compress
"""
if not isinstance(self.model, Module):
raise ValueError(
"`self.model` must be a torch.nn.Module to use "
f"the {self.__class__.__qualname__} modifier but got "
f"{type(self.model)} instead"
)

return get_layers(self.targets, self.model)

def initialize_compression(
owl_m: Optional[int] = None
owl_lmbda: Optional[float] = None # misspelling?
sparsity_profile: Optional[str] = None # deprecated

# data pipeline arguments
module_targets: Union[str, List[str]] = ["Linear"]
targets: Union[str, List[str], None] = None # deprecated, clones sequential_targets
sequential_targets: Union[str, List[str], None] = None
ignore: List[str] = Field(default_factory=list)

# private variables
_prune_n: Optional[int] = PrivateAttr(default=None)
_prune_m: Optional[int] = PrivateAttr(default=None)
_row_scalars: Dict[torch.nn.Module, torch.Tensor] = PrivateAttr(
default_factory=dict
)
_num_samples: Dict[torch.nn.Module, int] = PrivateAttr(default_factory=dict)
_module_names: Dict[torch.nn.Module, str] = PrivateAttr(default_factory=dict)
_module_sparsities: Dict[torch.nn.Module, str] = PrivateAttr(default_factory=dict)

def calibrate_module(
self,
model: Module,
dataloader: Optional[Iterable[Tuple[List, Dict[str, Any]]]] = None,
module: torch.nn.Module,
args: Tuple[torch.Tensor, ...],
_output: torch.Tensor,
):
"""
Setup for WANDA, initializes the model, device,
and other parameters, also initializes the
compressible layers of the model, and sets the device
:param model: model to initialize for compression
"""
self.model = model
self.compressible_layers_ = self.compressible_layers()
self.layer_compressors_ = []
self._infer_mask_block_size()

if self.sparsity_profile is not None and self.sparsity_profile.lower() == "owl":
logger.info(
"Inferring layer-wise sparsities from "
f"{len(dataloader) if dataloader else 0} calibration samples..."
)
activations = self._get_activations(dataloader)
self.sparsity = self._infer_layer_sparsity(activations)
self._validate_layerwise_sparsity()

for idx, (name, layer) in enumerate(self.compressible_layers_.items()):
logger.info(f"Preparing {name} for compression")
if isinstance(self.sparsity, Dict):
layer_sparsity = self.sparsity[name]
elif isinstance(self.sparsity, List):
layer_sparsity = self.sparsity[idx]
else: # float
layer_sparsity = self.sparsity
args = self._pruning_arguments(layer_sparsity)
comp_cls = self._compression_class()
compressor = LayerCompressor(comp_cls, self.model, layer, idx, name, args)
if not self.sequential_update:
# add all batch processing hooks before the forward pass
compressor.pre_compress()
self.layer_compressors_.append(compressor)

@torch.no_grad()
def apply_compression(
self, dataloader: Optional[Iterable[Tuple[List, Dict[str, Any]]]] = None
) -> Dict:
"""
Run Wanda on the loaded model, using dataloader as calibration data
:param dataloader: calibration data for WANDA
"""
class_name = self.__class__.__name__.replace("PyTorch", "")
logger.info(
f"Running {class_name} calibration with " f"{len(dataloader)} samples..."
# Assume that the first argument is the input
inp = args[0]

# Initialize row scalars if not present
if module not in self._num_samples:
device = get_execution_device(module)
self._row_scalars[module] = make_empty_row_scalars(module, device=device)
self._num_samples[module] = 0

# Accumulate scalars using data
self._row_scalars[module], self._num_samples[module] = accumulate_row_scalars(
inp,
module,
self._row_scalars[module],
self._num_samples[module],
)
if not self.sequential_update:
# in non-sequential mode we run one forward batch for all modules
run_calibration_forward(self.model, dataloader, mask_padding=True)

num_layers = len(self.compressible_layers_)
for idx, layer_compressor in enumerate(self.layer_compressors_):
layer_sparsity = layer_compressor.args["sparsity"]
logger.info(
f"\n===== Compressing layer {idx+1}/{num_layers} "
f"to sparsity {layer_sparsity} ====="
)

# Prune/quantize using the layer compressor
if self.sequential_update:
# in sequential mode we run one forward pass for each module we
# want to compress, this will be really slow but allows compression in
# earlier layers to affect later layers
layer_compressor.pre_compress()
logger.info(f"Calibrating {layer_compressor.name}...")
run_calibration_forward(self.model, dataloader, mask_padding=True)
layer_compressor.compress()
layer_compressor.post_compress()
layer_compressor.revert_layer_wrappers()
torch.cuda.empty_cache()

def _validate_layerwise_sparsity(self):
if isinstance(self.sparsity, float):
# single sparsity will be applied to all layers
return

target_layers = list(self.compressible_layers_.keys())

if len(target_layers) != len(self.sparsity):
raise ValueError(
"Number of layer targets must match the number of "
f"sparsities. Got {len(target_layers)} layers and "
f"{len(self.sparsity)} sparsities"
)

def _pruning_arguments(self, sparsity) -> Dict[str, Any]:
"""
Gather the parameters needed for root module compression in a dict
:param sparsity: target sparsity
:return: dict of params for pruning
"""
return {
"sparsity": sparsity,
"prunen": self.prunen_,
"prunem": self.prunem_,
}

def _compression_class(self):
def on_sequential_batch_end(self):
"""
:return: wrapper class used for root modules of this compression class
Sparsify modules
TODO: implement with event callback
"""
return WandaWrapper

def _infer_mask_block_size(self):
"""
Infer the mask block size from the mask structure.
Parses mask_structure of the form N:M where N, M are integers that
define a custom block shape; and sets prunen_ and prunem_ accordingly.
:post-condition: prunen_ and prunem_ are set
"""
if self.mask_structure is None:
raise ValueError("mask_structure must be defined")

self.prunen_, self.prunem_ = list(map(int, self.mask_structure.split(":")))

def _infer_layer_sparsity(self, activations):
wanda = {}
for name, layer in self.compressible_layers_.items():
prunable_layers = get_prunable_layers(layer)
z = [
m.weight.abs() * activations[f"{name}.{n}"].unsqueeze(0)
for n, m in prunable_layers.items()
]
wanda[name] = torch.cat([item.flatten().cpu() for item in z])

del activations
torch.cuda.empty_cache()

outlier_ratios = {}
for group in wanda:
threshold = torch.mean(wanda[group]) * self.owl_m
outlier_ratios[group] = (
100 * (wanda[group] > threshold).sum().item() / wanda[group].numel()
)
outlier_ratios_arr = np.array([outlier_ratios[k] for k in outlier_ratios])
for k in outlier_ratios:
outlier_ratios[k] = (outlier_ratios[k] - outlier_ratios_arr.min()) * (
1
/ (outlier_ratios_arr.max() - outlier_ratios_arr.min())
* self.owl_lmbda
* 2
)
outlier_ratios_arr = np.array([outlier_ratios[k] for k in outlier_ratios])
sparsities = {
k: 1
- (
outlier_ratios[k]
- np.mean(outlier_ratios_arr)
+ (1 - float(self.sparsity))
)
for k in outlier_ratios
}
logger.info(f"OWL sparsities for sp={self.sparsity} are:")
for k in sparsities:
logger.info(f"Sparsity for {k}: {sparsities[k]}")
return sparsities

@torch.no_grad()
def _get_activations(self, data_loader, nsamples=128):
self.model.eval()
acts = {}

def save_acts(module, input, name):
if isinstance(input, tuple):
input = input[0]
if name not in acts:
acts[name] = (
1.0 / nsamples * input.detach().pow(2).sum(dim=(0, 1)).sqrt()
for module in list(self._num_samples.keys()):
name = self._module_names[module]
sparsity = self._module_sparsities[module]
num_samples = self._num_samples[module]

logger.info(f"Sparsifying {name} using {num_samples} samples")
with (
torch.no_grad(),
align_module_device(module),
CompressionLogger(module),
):
sparsified_weight = sparsify_weight(
module=module,
row_scalars_dict=self._row_scalars,
sparsity=sparsity,
prune_n=self._prune_n,
prune_m=self._prune_m,
)
else:
acts[name] += (
1.0 / nsamples * input.detach().pow(2).sum(dim=(0, 1)).sqrt()
)

for name, mod in self.model.named_modules():
if isinstance(mod, torch.nn.Linear) and "lm_head" not in name:
self.register_hook(mod, partial(save_acts, name=name), "forward_pre")

device = next(self.model.parameters()).device
for batch in tqdm(data_loader):
batch = {k: v.to(device) for k, v in batch.items()}
self.model(**batch)
batch = None
torch.cuda.empty_cache()

self.remove_hooks()
update_offload_parameter(module, "weight", sparsified_weight)

return acts
# self._row_scalars[module] already deleted by sparsify_weight
del self._num_samples[module]
111 changes: 111 additions & 0 deletions src/llmcompressor/modifiers/pruning/wanda/utils/wanda_sparsify.py
@@ -0,0 +1,111 @@
from typing import Dict, Optional

import torch
import transformers

WANDA_PRECISION = torch.float32

# TODO: are these needed?
# torch.backends.cuda.matmul.allow_tf32 = False
# torch.backends.cudnn.allow_tf32 = False


def make_empty_row_scalars(
module: torch.nn.Module, device: Optional[torch.device] = None
) -> torch.Tensor:
weight = module.weight
num_columns = weight.shape[1]
device = device if device is not None else weight.device
return torch.zeros(num_columns, device=device)


def accumulate_row_scalars(
inp: torch.Tensor,
module: torch.nn.Module,
row_scalars: Optional[torch.Tensor],
num_samples: int,
):
inp = inp.to(device=row_scalars.device)
if len(inp.shape) == 2:
inp = inp.unsqueeze(0)

num_added = inp.shape[0] # note this is the number of dataset samples, not
# multiplied by the sequence length

if isinstance(module, (torch.nn.Linear, transformers.Conv1D)):
if len(inp.shape) == 3:
inp = inp.reshape((-1, inp.shape[-1]))
inp = inp.t()

if isinstance(module, torch.nn.Conv2d):
unfold = torch.nn.Unfold(
module.kernel_size,
dilation=module.dilation,
padding=module.padding,
stride=module.stride,
)
inp = unfold(inp)
inp = inp.permute([1, 0, 2])
inp = inp.flatten(1)

row_scalars *= num_samples / (num_samples + num_added)
num_samples += num_added

inp = inp.type(WANDA_PRECISION)
row_scalars += torch.norm(inp, p=2, dim=1) ** 2 / num_samples

return row_scalars, num_samples


def sparsify_weight(
module: torch.nn.Module,
row_scalars_dict: Dict[torch.nn.Module, torch.Tensor],
sparsity: float,
prune_n: int,
prune_m: int,
) -> torch.Tensor:
"""
Run pruning on the layer up to the target sparsity value.
:param module: module with weight being sparsified
:param row_scalars_dict: dictionary containing pre-accumulated row scalars
:param sparsity: target sparsity to reach for layer
:param prune_n: N for N:M pruning
:param prune_m: M for N:M pruning
"""
final_shape = module.weight.shape
final_dtype = module.weight.dtype
W = module.weight.data.clone()
if isinstance(module, torch.nn.Conv2d):
W = W.flatten(1)
if isinstance(module, transformers.Conv1D):
W = W.t()
W = W.to(dtype=WANDA_PRECISION)
S = row_scalars_dict[module] # unfortunately python does not have a `move` keyword
del row_scalars_dict[module] # so we have to delete the original reference manually

W_metric = torch.abs(W) * torch.sqrt(S.reshape((1, -1)))

# initialize a mask to be all False
W_mask = torch.zeros_like(W_metric) == 1
if prune_n != 0:
# structured n:m sparsity
for ii in range(W_metric.shape[1]):
if ii % prune_m == 0:
tmp = W_metric[:, ii : (ii + prune_m)].float()
W_mask.scatter_(
1,
ii + torch.topk(tmp, prune_n, dim=1, largest=False)[1],
True,
)
else:
sort_res = torch.sort(W_metric, dim=-1, stable=True)
indices = sort_res[1][:, : int(W_metric.shape[1] * sparsity)]
W_mask.scatter_(1, indices, True)

W[W_mask] = 0.0 # set weights to zero

if isinstance(module, transformers.Conv1D):
W = W.t()

W = W.reshape(final_shape).to(final_dtype)

return W
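
A minimal end-to-end sketch (not part of this commit) of the functional Wanda path defined above, mirroring WandaPruningModifier.calibrate_module and on_sequential_batch_end: accumulate per-column activation norms over a few batches, then prune by the |W| * sqrt(norm) metric. The toy Linear layer and random inputs are illustrative only.

import torch
from llmcompressor.modifiers.pruning.wanda.utils.wanda_sparsify import (
    accumulate_row_scalars,
    make_empty_row_scalars,
    sparsify_weight,
)

linear = torch.nn.Linear(16, 16)
row_scalars = {linear: make_empty_row_scalars(linear)}
num_samples = 0

with torch.no_grad():
    for _ in range(4):  # stand-in calibration batches
        inp = torch.randn(2, 8, 16)
        row_scalars[linear], num_samples = accumulate_row_scalars(
            inp, linear, row_scalars[linear], num_samples
        )

    pruned_weight = sparsify_weight(
        module=linear,
        row_scalars_dict=row_scalars,  # the entry for `linear` is deleted inside
        sparsity=0.5,
        prune_n=0,
        prune_m=0,
    )
    linear.weight.data = pruned_weight  # the modifier uses update_offload_parameter instead

assert (linear.weight == 0).float().mean() >= 0.5  # unstructured 50% sparsity
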
125 changes: 0 additions & 125 deletions src/llmcompressor/modifiers/pruning/wanda/utils/wanda_wrapper.py

This file was deleted.

117 changes: 0 additions & 117 deletions src/llmcompressor/modifiers/utils/compression_wrapper.py

This file was deleted.

40 changes: 30 additions & 10 deletions src/llmcompressor/modifiers/utils/hooks.py
@@ -1,6 +1,6 @@
import contextlib
from functools import wraps
from typing import Any, Callable, ClassVar, List, Union
from typing import Any, Callable, ClassVar, Optional, Set, Union

import torch
from loguru import logger
@@ -29,18 +29,29 @@ class HooksMixin(BaseModel):
- modifier.remove_hooks()
"""

_HOOKS_DISABLED: ClassVar[bool] = False # attached to global HooksMixin
_hooks: List[RemovableHandle] = [] # attached to local subclasses
# attached to global HooksMixin class
_HOOKS_DISABLED: ClassVar[bool] = False
_HOOKS_KEEP_ENABLED: ClassVar[Set[RemovableHandle]] = set()

# attached to local subclasses
_hooks: Set[RemovableHandle] = set()

@classmethod
@contextlib.contextmanager
def disable_hooks(cls):
"""Disable all hooks across all modifiers"""
def disable_hooks(cls, keep: Set[RemovableHandle] = set()):
"""
Disable all hooks across all modifiers. Composing multiple contexts is
equivalent to the union of `keep` arguments
:param keep: optional set of handles to keep enabled
"""
try:
cls._HOOKS_DISABLED = True
cls._HOOKS_KEEP_ENABLED |= keep
yield
finally:
cls._HOOKS_DISABLED = False
cls._HOOKS_KEEP_ENABLED -= keep

def register_hook(
self,
@@ -60,24 +71,33 @@ def register_hook(
Ex. "forward", "forward_pre", "full_backward", "state_dict_post", ""
:param kwargs: keyword arguments to pass to register hook method
"""
handle = None

@wraps(hook)
def wrapped_hook(*args, **kwargs):
if HooksMixin._HOOKS_DISABLED:
nonlocal handle

if (
HooksMixin._HOOKS_DISABLED
and handle not in HooksMixin._HOOKS_KEEP_ENABLED
):
return

return hook(*args, **kwargs)

register_function = getattr(target, f"register_{hook_type}_hook")
handle = register_function(wrapped_hook, **kwargs)
self._hooks.append(handle)
self._hooks.add(handle)
logger.debug(f"{self} added {handle}")

return handle

def remove_hooks(self):
def remove_hooks(self, handles: Optional[Set[RemovableHandle]] = None):
"""Remove all hooks belonging to a modifier"""
for hook in self._hooks:
if handles is None:
handles = self._hooks

for hook in handles:
hook.remove()

self._hooks = []
self._hooks -= handles
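
A condensed usage sketch of the new keep/handles parameters. CountingModifier is a hypothetical HooksMixin subclass used only for illustration; the new tests in tests/llmcompressor/modifiers/utils/test_hooks.py below exercise the same behavior.

import torch
from llmcompressor.modifiers.utils.hooks import HooksMixin

class CountingModifier(HooksMixin):
    calls: int = 0

    def hook(self, module, args, output):
        self.calls += 1

mod = CountingModifier()
layer = torch.nn.Linear(4, 4)
handle = mod.register_hook(layer, mod.hook, "forward")

with HooksMixin.disable_hooks():               # all hooks suppressed
    layer(torch.randn(1, 4))
with HooksMixin.disable_hooks(keep={handle}):  # this handle stays enabled
    layer(torch.randn(1, 4))
layer(torch.randn(1, 4))                       # hooks enabled by default

mod.remove_hooks({handle})                     # remove only this handle
assert mod.calls == 2
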
181 changes: 0 additions & 181 deletions src/llmcompressor/modifiers/utils/layer_compressor.py

This file was deleted.

33 changes: 3 additions & 30 deletions src/llmcompressor/modifiers/utils/pytorch_helpers.py
@@ -1,5 +1,5 @@
from itertools import cycle
from typing import Any, Callable, Dict, List, Optional, Tuple
from typing import Callable, Dict, Optional

import torch
from torch.nn import Module
@@ -9,27 +9,12 @@
from llmcompressor.pytorch.utils import tensors_module_forward, tensors_to_device

__all__ = [
"EarlyStopException",
"apply_pad_mask_to_batch",
"run_calibration_forward",
"is_moe_model",
]


class EarlyStopException(Exception):
"""
Exception for stopping execution of a PyTorch model early, and saving the
inputs of the stopped module offloaded to cpu
:param args: inputs passed to the layer where the exception was raised
:param kwargs: keyword inputs passed to the layer where the exception was raised
"""

def __init__(self, args: Tuple[Any, ...], kwargs: Dict[str, Any]):
self.args = tensors_to_device(args, "cpu")
self.kwargs = kwargs


def apply_pad_mask_to_batch(batch: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
"""
Apply a mask to the input ids of a batch. This is used to zero out
@@ -52,7 +37,7 @@ def run_calibration_forward(
calibration_function: Optional[Callable] = None,
device: Optional[str] = None,
mask_padding: bool = False,
) -> List[torch.Tensor]:
):
"""
Helper function used by one-shot modifiers, runs calibration data through a model to
update modifier statistics and trigger hooks
@@ -64,7 +49,6 @@ def run_calibration_forward(
:param calibration_function: option to pass a custom forward function for model
:param device: option to move the model to a specific device before calibration
:param mask_padding: whether to zero out padding tokens during calibration
:returns: list of last calculated model output if early stopping is triggered
"""
model.eval()

@@ -83,10 +67,6 @@ def run_calibration_forward(
else cycle(calibration_dataloader)
)

# Store any inputs caught from early stopping, used for sequential compression
# of GPTQ, SparseGPT and WANDA
intermediates = []

# run through the calibration data
for batch_idx, batch in enumerate(tqdm(_dataloader)):
if num_calibration_steps and batch_idx >= num_calibration_steps:
@@ -95,19 +75,12 @@ def run_calibration_forward(
batch = apply_pad_mask_to_batch(batch)
batch = tensors_to_device(batch, model_device)
with torch.no_grad():
try:
forward_fn(batch, module=model)
except EarlyStopException as e:
# model was stopped early, save last calculated output and
# move on to next calibration sample
intermediates.append((e.args, e.kwargs))
forward_fn(batch, module=model)

# TODO: not ideal, figure out where we aren't freeing memory instead
# currently without this we run OOM on the 2nd forward pass
torch.cuda.empty_cache()

return intermediates


def is_moe_model(model: Module) -> bool:
"""
5 changes: 2 additions & 3 deletions src/llmcompressor/pipelines/layer_sequential/helpers.py
@@ -119,10 +119,9 @@ def trigger_early_stop_fn(module, args, kwargs):
@dataclass
class EarlyStopException(Exception):
"""
Note: this exception is different from the exception defined in
llmcompressor.modifiers.utils.pytorch_helpers, and will eventually replace it
Dataclass for storing model activations
Attribute names `args` and `kwargs` are reserved for `dataclass`
Note: Attribute names `args` and `kwargs` are reserved for `dataclass.GenericAlias`
"""

_args: Tuple[Any, ...]
7 changes: 4 additions & 3 deletions src/llmcompressor/transformers/tracing/llava.py
@@ -58,7 +58,7 @@ def maybe_install_metadata_image_features(

# TRACING: The shape of inputs_embeds is known. This function compensates for
# the fact that shape inference through `masked_scatter` is not implemented yet
def maybe_install_metadata_inputs_embeds(
def maybe_install_metadata_inputs_embeds_masked(
inputs_embeds_masked: Union[torch.Tensor, HFProxy],
inputs_embeds: Union[torch.Tensor, HFProxy],
special_image_mask: Union[torch.Tensor, HFProxy],
@@ -70,7 +70,7 @@ def maybe_install_metadata_inputs_embeds(
)
inputs_embeds_masked.install_metadata(metadata)

return inputs_embeds
return inputs_embeds_masked


# TRACING: override `__init__` and `forward`
@@ -153,6 +153,7 @@ def forward(
vision_feature_select_strategy=vision_feature_select_strategy,
)

# TRACING: install metadata
image_features = maybe_install_metadata_image_features(
image_features, pixel_values, self.config
)
@@ -223,7 +224,7 @@ def forward(
inputs_embeds_masked = inputs_embeds.masked_scatter(special_image_mask, image_features)

# TRACING: install metadata
inputs_embeds_masked = maybe_install_metadata_inputs_embeds(inputs_embeds_masked, inputs_embeds, special_image_mask, image_features)
inputs_embeds_masked = maybe_install_metadata_inputs_embeds_masked(inputs_embeds_masked, inputs_embeds, special_image_mask, image_features)
inputs_embeds = inputs_embeds_masked

outputs = self.language_model(
33 changes: 33 additions & 0 deletions tests/examples/test_sparse_2of4_quantization_fp8.py
@@ -0,0 +1,33 @@
from pathlib import Path

import pytest

from tests.examples.utils import (
copy_and_run_script,
gen_cmd_fail_message,
requires_gpu_count,
)


@pytest.fixture
def example_dir() -> str:
return "examples/sparse_2of4_quantization_fp8"


@requires_gpu_count(1)
class TestSparse2of4QuantizationFP8:
"""
Tests for examples in the "sparse_2of4_quantization_fp8" example folder.
"""

@pytest.mark.parametrize(("flags"), [[], ["--fp8"]])
def test_blah(self, example_dir: str, tmp_path: Path, flags: list[str]):
"""
Tests for the "llama3_8b_2of4.py" example script.
"""
script_filename = "llama3_8b_2of4.py"
command, result = copy_and_run_script(
tmp_path, example_dir, script_filename, flags=flags
)

assert result.returncode == 0, gen_cmd_fail_message(command, result)
7 changes: 6 additions & 1 deletion tests/examples/utils.py
@@ -68,7 +68,10 @@ def copy_and_run_command(


def copy_and_run_script(
tmp_path: Path, example_dir: str, script_filename: str
tmp_path: Path,
example_dir: str,
script_filename: str,
flags: Optional[list[str]] = None,
) -> Tuple[List[str], CompletedProcess[str]]:
"""
Copies the contents of example_dir (relative to the current working directory) to
@@ -81,6 +84,8 @@ def copy_and_run_script(
:return: subprocess.CompletedProcess object
"""
command = [sys.executable, script_filename]
if flags:
command.extend(flags)
return command, copy_and_run_command(tmp_path, example_dir, command)


93 changes: 93 additions & 0 deletions tests/llmcompressor/modifiers/utils/test_hooks.py
@@ -64,6 +64,27 @@ def test_remove_hooks():
assert mod_a.hook_called and not mod_b.hook_called


def test_remove_hooks_parameterized():
model = DummyModel()

mod_a = ModA()
mod_a_pre_hook = mod_a.register_hook(model.linear1, mod_a.hook, "forward_pre")
mod_a_post_hook = mod_a.register_hook(model.linear1, mod_a.hook, "forward")

mod_b = ModB()
mod_b_pre_hook = mod_b.register_hook(model.linear2, mod_b.hook, "forward_pre")
mod_b_post_hook = mod_b.register_hook(model.linear2, mod_b.hook, "forward")

mod_a.remove_hooks(set([mod_a_post_hook]))
mod_b.remove_hooks(set([mod_b_pre_hook]))

assert len(mod_a._hooks) == 1 and next(iter(mod_a._hooks)) == mod_a_pre_hook
assert len(mod_b._hooks) == 1 and next(iter(mod_b._hooks)) == mod_b_post_hook

model(model.dummy_inputs)
assert mod_a.hook_called and mod_b.hook_called


def test_disable_hooks():
model = DummyModel()

@@ -81,3 +102,75 @@ def test_disable_hooks():
mod_b.hook_called = False
model(model.dummy_inputs)
assert mod_a.hook_called and mod_b.hook_called


def test_disable_hooks_keep():
model = DummyModel()

mod_a = ModA()
handle_a = mod_a.register_hook(model.linear1, mod_a.hook, "forward")

mod_b = ModB()
handle_b = mod_b.register_hook(model.linear2, mod_b.hook, "forward_pre")

with HooksMixin.disable_hooks(keep=set([handle_b])):
model(model.dummy_inputs)
assert not mod_a.hook_called and mod_b.hook_called

mod_a.hook_called = False
mod_b.hook_called = False
with HooksMixin.disable_hooks(keep=set([handle_a])):
model(model.dummy_inputs)
assert mod_a.hook_called and not mod_b.hook_called

mod_a.hook_called = False
mod_b.hook_called = False
model(model.dummy_inputs)
assert mod_a.hook_called and mod_b.hook_called


def test_disable_hooks_composable():
model = DummyModel()

mod_a = ModA()
handle_a = mod_a.register_hook(model.linear1, mod_a.hook, "forward")

mod_b = ModB()
handle_b = mod_b.register_hook(model.linear2, mod_b.hook, "forward_pre")

# composing two keeps
with (
HooksMixin.disable_hooks(keep=set([handle_b])),
HooksMixin.disable_hooks(keep=set([handle_a])),
):
model(model.dummy_inputs)
assert mod_a.hook_called and mod_b.hook_called

mod_a.hook_called = False
mod_b.hook_called = False
model(model.dummy_inputs)
assert mod_a.hook_called and mod_b.hook_called

mod_a.hook_called = False
mod_b.hook_called = False
with HooksMixin.disable_hooks():
model(model.dummy_inputs)
assert not mod_a.hook_called and not mod_b.hook_called

# composing a keep and an empty keep
mod_a.hook_called = False
mod_b.hook_called = False
with HooksMixin.disable_hooks(keep=set([handle_a])), HooksMixin.disable_hooks():
model(model.dummy_inputs)
assert mod_a.hook_called and not mod_b.hook_called

mod_a.hook_called = False
mod_b.hook_called = False
model(model.dummy_inputs)
assert mod_a.hook_called and mod_b.hook_called

mod_a.hook_called = False
mod_b.hook_called = False
with HooksMixin.disable_hooks():
model(model.dummy_inputs)
assert not mod_a.hook_called and not mod_b.hook_called
