Commit
inherit from shared mixin, both runnable
Signed-off-by: Kyle Sayers <[email protected]>
kylesayrs committed Jan 13, 2025

This commit was created on GitHub.com and signed with GitHub’s verified signature.
1 parent 0a0f591 commit d8c3261
Showing 7 changed files with 289 additions and 601 deletions.
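At a high level, this change moves the shared calibration and pipeline logic out of SparseGPTModifier and WandaPruningModifier into a new SparsityModifierMixin, so that both modifiers remain runnable through the same pipelines. A minimal sketch of the resulting class structure, with method bodies elided (the actual implementations are in the diffs below):

from llmcompressor.modifiers import Modifier
from llmcompressor.modifiers.obcq.sgpt_mixin import SparsityModifierMixin

class SparseGPTModifier(SparsityModifierMixin, Modifier):
    # Hessian-based pruning; helpers live in obcq/sgpt_sparsify.py
    def calibrate_module(self, module, args, _output): ...
    def on_sequential_batch_end(self): ...

class WandaPruningModifier(SparsityModifierMixin, Modifier):
    # activation-norm-based pruning; helpers live in pruning/wanda/utils/wanda_sparsify.py
    def calibrate_module(self, module, args, _output): ...
    def on_sequential_batch_end(self): ...

The mixin owns target inference, hook registration, OWL sparsity inference, and the fallback from the sequential pipeline to the layer-sequential and basic pipelines; each subclass only supplies the per-module calibration and compression math.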
211 changes: 21 additions & 190 deletions src/llmcompressor/modifiers/obcq/base.py
@@ -1,43 +1,28 @@
import contextlib
import warnings
from collections import defaultdict
from functools import partial
from typing import Dict, List, Optional, Tuple, Union

import numpy as np
import torch
from compressed_tensors.utils import (
align_module_device,
get_execution_device,
update_offload_parameter,
)
from loguru import logger
from pydantic import PrivateAttr, field_validator
from pydantic import Field, PrivateAttr

from llmcompressor.core import State
from llmcompressor.modifiers import Modifier
from llmcompressor.modifiers.obcq.utils.sgpt_sparsify import (
from llmcompressor.modifiers.obcq.sgpt_mixin import SparsityModifierMixin
from llmcompressor.modifiers.obcq.sgpt_sparsify import (
accumulate_hessian,
make_empty_hessian,
sparsify_weight,
)
from llmcompressor.modifiers.utils.hooks import HooksMixin
from llmcompressor.pipelines.basic import run_pipeline as run_basic
from llmcompressor.pipelines.layer_sequential import (
run_pipeline as run_layer_sequential,
)
from llmcompressor.pipelines.sequential import run_pipeline as run_sequential
from llmcompressor.utils.metric_logging import CompressionLogger
from llmcompressor.utils.pytorch.module import (
get_layers,
get_no_split_params,
get_prunable_layers,
)

__all__ = ["SparseGPTModifier"]


class SparseGPTModifier(Modifier):
class SparseGPTModifier(SparsityModifierMixin, Modifier):
"""
Modifier for applying the one-shot SparseGPT algorithm to a model
@@ -96,114 +81,21 @@ class SparseGPTModifier(Modifier):

# data pipeline arguments
sequential_update: Optional[bool] = False # deprecated
module_targets: Union[str, List[str], None] = None
module_targets: Union[str, List[str]] = ["Linear"]
targets: Union[str, List[str], None] = None # deprecated, clones sequential_targets
sequential_targets: Union[str, List[str], None] = None
ignore: List[str] = Field(default_factory=list)

# private variables
_prune_n: Optional[int] = PrivateAttr(default=None)
_prune_m: Optional[int] = PrivateAttr(default=None)
_hessians: Dict[torch.nn.Module, torch.Tensor] = PrivateAttr(default_factory=dict)
_num_samples: Dict[torch.nn.Module, int] = PrivateAttr(default_factory=dict)
_update_size: Optional[int] = PrivateAttr(default=None)

@field_validator("sequential_update", mode="before")
def validate_sequential_update(cls, value: bool) -> bool:
if not value:
warnings.warn(
"`sequential_update=False` is no longer supported, setting "
"sequential_update=True",
DeprecationWarning,
)

return True

def on_initialize(self, state: "State", **kwargs) -> bool:
"""
Initialize and run the OBCQ algorithm on the current state
:param state: session state storing input model and calibration data
"""
model = state.model
dataloader = state.data.calib

# infer module and sequential targets
self.sequential_targets = self._infer_sequential_targets(model)

# infer layer sparsities
if self.owl_m is not None and self.owl_lmbda is not None:
logger.info(
"Using OWL to infer target layer-wise sparsities from "
f"{len(dataloader) if dataloader else 0} calibration samples..."
)
self.sparsity = self._infer_owl_layer_sparsity()

# register hooks
for index, name, layer in enumerate(get_layers(self.sequential_targets, model)):
if isinstance(self.sparsity, Dict):
layer_sparsity = self.sparsity[name]
elif isinstance(self.sparsity, List):
layer_sparsity = self.sparsity[index]
else:
layer_sparsity = self.sparsity

for name, module in get_prunable_layers(layer):
post_hook = partial(
self.compress_module,
name,
layer_sparsity,
)
self.register_hook(module, post_hook, "forward")

# infer and run pipeline
model_name = state.model.__class__.__name__
input_names = state.data.calib.dataset.column_names
unfixable_errors = (torch.OutOfMemoryError, torch._C._LinAlgError)
try:
run_sequential(
state.model,
self.sequential_targets,
self.ignore,
state.data.calib,
propagate_error=True,
)
return True

except Exception as exception:
if isinstance(exception, torch.fx.proxy.TraceError):
warnings.warn(f"Failed to trace {model_name} with inputs {input_names}")
if isinstance(exception, unfixable_errors):
raise exception

warnings.warn("Falling back to layer_sequential pipeline")
try:
run_layer_sequential(
state.model,
self.sequential_targets,
state.data.calib,
propagate_error=True,
)
return True

except Exception as exception:
if isinstance(exception, TypeError):
warnings.warn(f"{model_name} fails layer-wise assumptions")
if isinstance(exception, unfixable_errors):
raise exception

warnings.warn(
"Falling back to basic pipeline, which requires extra memory and "
"may result in decreased accuracy"
)
run_basic(state.model, state.data.calib)
return True

return True
_module_names: Dict[torch.nn.Module, str] = PrivateAttr(default_factory=dict)
_module_sparsities: Dict[torch.nn.Module, str] = PrivateAttr(default_factory=dict)

def compress_module(
def calibrate_module(
self,
name: str,
sparsity: float,
module: torch.nn.Module,
args: Tuple[torch.Tensor, ...],
_output: torch.Tensor,
@@ -226,9 +118,17 @@ def compress_module(
self._num_samples[module],
)

# After enough samples are accumulated, perform sparsification
if self._num_samples[module] >= self._update_size:
logger.info(f"Sparsifying {name} using {self._num_samples[module]} samples")
def on_sequential_batch_end(self):
"""
Sparsify modules
TODO: implement with event callback
"""
for module in list(self._num_samples.keys()):
name = self._module_names[module]
sparsity = self._module_sparsities[module]
num_samples = self._num_samples[module]

logger.info(f"Sparsifying {name} using {num_samples} samples")
with (
torch.no_grad(),
align_module_device(module),
@@ -242,7 +142,7 @@ def compress_module(
prune_m=self._prune_m,
block_size=self.block_size,
dampening_frac=self.dampening_frac,
preserve_sparsity_mask=self.preserve_sparsity_mask, # TODO: should we deprecate this? GPTQ just checks against a sparsity threshold
preserve_sparsity_mask=self.preserve_sparsity_mask, # TODO: should we deprecate this? GPTQ just checks against a sparsity threshold
)
comp_logger.set_loss(loss)

@@ -262,72 +162,3 @@ def _maybe_onload_hessian(self, module: torch.nn.Module):
if self.offload_hessians:
if module in self._hessians: # may have been deleted in context
self._hessians[module] = self._hessians[module].to(device="cpu")

def _infer_sequential_targets(self, model):
if self.sequential_targets is None:
return get_no_split_params(model)
if isinstance(self.sequential_targets, str):
return [self.sequential_targets]

def _infer_owl_layer_sparsity(self, activations):
groups = {}
for name, layer in self.compressible_layers_.items():
prunable_layers = get_prunable_layers(layer)
z = [
m.weight.abs() * activations[f"{name}.{n}"].unsqueeze(0)
for n, m in prunable_layers.items()
]
groups[name] = torch.cat([item.flatten().cpu() for item in z])

del activations
torch.cuda.empty_cache()

outlier_ratios = {}
for group in groups:
threshold = torch.mean(groups[group]) * self.owl_m
outlier_ratios[group] = (
100 * (groups[group] > threshold).sum().item() / groups[group].numel()
)
outlier_ratios_arr = np.array([outlier_ratios[k] for k in outlier_ratios])
for k in outlier_ratios:
outlier_ratios[k] = (outlier_ratios[k] - outlier_ratios_arr.min()) * (
1
/ (outlier_ratios_arr.max() - outlier_ratios_arr.min())
* self.owl_lmbda
* 2
)
outlier_ratios_arr = np.array([outlier_ratios[k] for k in outlier_ratios])
sparsities = {
k: 1
- (
outlier_ratios[k]
- np.mean(outlier_ratios_arr)
+ (1 - float(self.sparsity))
)
for k in outlier_ratios
}
logger.info(f"OWL sparsities for sp={self.sparsity} are:")
for k in sparsities:
logger.info(f"Sparsity for {k}: {sparsities[k]}")
return sparsities

def _get_activations(self, model, dataloader, nsamples=128):
acts = defaultdict(int)

def save_acts(_module, input, name):
nonlocal acts
if isinstance(input, tuple):
input = input[0]
acts[name] += 1.0 / nsamples * input.pow(2).sum(dim=(0, 1)).sqrt()

# TODO: only add hooks to target modules
hooks = set(
self.register_hook(mod, partial(save_acts, name=name), "forward_pre")
for name, mod in model.named_modules()
if isinstance(mod, torch.nn.Linear) and "lm_head" not in name
)
with HooksMixin.disable_hooks(keep=hooks):
run_basic(model, dataloader)
self.remove_hooks(hooks)

return acts
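
The calibration flow above splits work between a forward hook (calibrate_module, which only accumulates Hessian statistics per sample) and on_sequential_batch_end (which sparsifies every module that saw data once a batch has propagated). A self-contained toy of that control flow is sketched below; the magnitude threshold is a stand-in for the real Hessian-based update in sgpt_sparsify.sparsify_weight, and the class name is made up:

import torch

class ToyCalibrator:
    """Toy sketch of the calibrate_module / on_sequential_batch_end split."""

    def __init__(self, sparsity: float = 0.5):
        self.sparsity = sparsity
        self._hessians = {}      # module -> running input covariance
        self._num_samples = {}   # module -> calibration rows seen so far

    def calibrate_module(self, module, args, _output):
        # forward hook: accumulate statistics only, never touch weights here
        x = args[0].reshape(-1, args[0].shape[-1])
        if module not in self._hessians:
            self._hessians[module] = torch.zeros(x.shape[-1], x.shape[-1])
            self._num_samples[module] = 0
        self._hessians[module] += x.T @ x
        self._num_samples[module] += x.shape[0]

    def on_sequential_batch_end(self):
        # batch-end callback: compress every module that accumulated samples
        for module in list(self._num_samples.keys()):
            w = module.weight.data
            k = max(1, int(w.numel() * self.sparsity))
            threshold = w.abs().flatten().kthvalue(k).values
            module.weight.data = torch.where(w.abs() > threshold, w, torch.zeros_like(w))
            del self._hessians[module], self._num_samples[module]

Registered on each prunable Linear as a forward hook and invoked after each sequential batch, this reproduces the order of operations that SparsityModifierMixin drives through register_hook and the sequential pipelines.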
235 changes: 235 additions & 0 deletions src/llmcompressor/modifiers/obcq/sgpt_mixin.py
@@ -0,0 +1,235 @@
import warnings
from collections import defaultdict
from functools import partial
from typing import List, Tuple, Union

import numpy as np
import torch
from loguru import logger
from pydantic import field_validator, model_validator

from llmcompressor.core import State
from llmcompressor.modifiers import Modifier
from llmcompressor.modifiers.utils.hooks import HooksMixin
from llmcompressor.pipelines.basic import run_pipeline as run_basic
from llmcompressor.pipelines.layer_sequential import (
run_pipeline as run_layer_sequential,
)
from llmcompressor.pipelines.sequential import run_pipeline as run_sequential
from llmcompressor.utils.pytorch.module import (
get_layers,
get_no_split_params,
get_prunable_layers,
match_class,
match_targets,
)


class SparsityModifierMixin(HooksMixin):
@field_validator("sequential_update", mode="before", check_fields=False)
def validate_sequential_update(cls, value: bool) -> bool:
if not value:
warnings.warn(
"`sequential_update=False` is no longer supported, setting "
"sequential_update=True",
DeprecationWarning,
)

return True

@field_validator("sparsity_profile", mode="before", check_fields=False)
def validate_sparsity_profile(cls, value) -> None:
if value is not None:
warnings.warn(
"`sparsity_profile` is deprecated, use `owl_m` and `owl_lmbda`"
)
return None

@model_validator(mode="after")
def validate_model_after(model: "Modifier") -> "Modifier":
sparsity = model.sparsity
owl_m = model.owl_m
owl_lmbda = model.owl_lmbda
mask_structure = model.mask_structure
targets = model.targets
sequential_targets = model.sequential_targets

if (owl_m is not None) ^ (owl_lmbda is not None):
raise ValueError("Must provide both `owl_m` and `owl_lmbda` or neither")

if owl_m is not None and sparsity is not None:
raise ValueError("Cannot provide both sparsity and owl parameters")

if targets is not None:
warnings.warn(
"`targets` is deprecated, use `module_targets` and `sequential_targets`"
)
if sequential_targets is not None:
raise ValueError("Cannot use both `targets` and `sequential_targets`")
model.sequential_targets = targets
model.targets = None

model._prune_n, model._prune_m = model._split_mask_structure(mask_structure)

return model

def on_initialize(self, state: "State", **kwargs) -> bool:
"""
Initialize and run the OBCQ algorithm on the current state
:param state: session state storing input model and calibration data
"""
model = state.model
dataloader = state.data.calib

# infer module and sequential targets
self.sequential_targets = self._infer_sequential_targets(model)

# infer layer sparsities
if self.owl_m is not None and self.owl_lmbda is not None:
logger.info(
"Using OWL to infer target layer-wise sparsities from "
f"{len(dataloader) if dataloader else 0} calibration samples..."
)
self.sparsity = self._infer_owl_layer_sparsity()

# register hooks
for index, (name, layer) in enumerate(
get_layers(self.sequential_targets, model).items()
):
if isinstance(self.sparsity, dict):
layer_sparsity = self.sparsity[name]
elif isinstance(self.sparsity, list):
layer_sparsity = self.sparsity[index]
else:
layer_sparsity = self.sparsity

for name, module in layer.named_modules():
if (
match_targets(module, self.module_targets)[0]
or match_class(module, self.module_targets)[0]
):
self._module_names[module] = name
self._module_sparsities[module] = layer_sparsity
self.register_hook(module, self.calibrate_module, "forward")

# infer and run pipeline
model_name = state.model.__class__.__name__
input_names = state.data.calib.dataset.column_names
unfixable_errors = (torch.OutOfMemoryError, torch._C._LinAlgError)
try:
run_sequential(
state.model,
state.data.calib,
self.sequential_targets,
self.ignore,
self,
)
return True

except Exception as exception:
if isinstance(exception, torch.fx.proxy.TraceError):
warnings.warn(f"Failed to trace {model_name} with inputs {input_names}")
if isinstance(exception, unfixable_errors):
raise exception

warnings.warn("Falling back to layer_sequential pipeline")
try:
run_layer_sequential(
state.model,
state.data.calib,
self.sequential_targets,
self,
)
return True

except Exception as exception:
if isinstance(exception, TypeError):
warnings.warn(f"{model_name} fails layer-wise assumptions")
if isinstance(exception, unfixable_errors):
raise exception

warnings.warn(
"Falling back to basic pipeline, which requires extra memory and "
"may result in decreased accuracy"
)
run_basic(state.model, state.data.calib, self)
return True

return True

def _infer_sequential_targets(
self, model: torch.nn.Module
) -> Union[str, List[str]]:
if self.sequential_targets is None:
return get_no_split_params(model)
if isinstance(self.sequential_targets, str):
return [self.sequential_targets]
return self.sequential_targets

def _infer_owl_layer_sparsity(self, activations):
groups = {}
for name, layer in self.compressible_layers_.items():
prunable_layers = get_prunable_layers(layer)
z = [
m.weight.abs() * activations[f"{name}.{n}"].unsqueeze(0)
for n, m in prunable_layers.items()
]
groups[name] = torch.cat([item.flatten().cpu() for item in z])

del activations
torch.cuda.empty_cache()

outlier_ratios = {}
for group in groups:
threshold = torch.mean(groups[group]) * self.owl_m
outlier_ratios[group] = (
100 * (groups[group] > threshold).sum().item() / groups[group].numel()
)
outlier_ratios_arr = np.array([outlier_ratios[k] for k in outlier_ratios])
for k in outlier_ratios:
outlier_ratios[k] = (outlier_ratios[k] - outlier_ratios_arr.min()) * (
1
/ (outlier_ratios_arr.max() - outlier_ratios_arr.min())
* self.owl_lmbda
* 2
)
outlier_ratios_arr = np.array([outlier_ratios[k] for k in outlier_ratios])
sparsities = {
k: 1
- (
outlier_ratios[k]
- np.mean(outlier_ratios_arr)
+ (1 - float(self.sparsity))
)
for k in outlier_ratios
}
logger.info(f"OWL sparsities for sp={self.sparsity} are:")
for k in sparsities:
logger.info(f"Sparsity for {k}: {sparsities[k]}")
return sparsities

def _get_activations(self, model, dataloader, nsamples=128):
acts = defaultdict(int)

def save_acts(_module, input, name):
nonlocal acts
if isinstance(input, tuple):
input = input[0]
acts[name] += 1.0 / nsamples * input.pow(2).sum(dim=(0, 1)).sqrt()

# TODO: only add hooks to target modules
hooks = set(
self.register_hook(mod, partial(save_acts, name=name), "forward_pre")
for name, mod in model.named_modules()
if isinstance(mod, torch.nn.Linear) and "lm_head" not in name
)
with HooksMixin.disable_hooks(keep=hooks):
run_basic(model, dataloader)
self.remove_hooks(hooks)

return acts

def _split_mask_structure(self, mask_structure: str) -> Tuple[int, int]:
n, m = mask_structure.split(":")
return int(n), int(m)
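
Two parts of the mixin are easy to check numerically. _split_mask_structure parses an "N:M" string ("2:4" becomes (2, 4), i.e. prune 2 of every 4 weights), and _infer_owl_layer_sparsity rescales per-layer outlier ratios so that layers with more outliers are pruned less while the mean sparsity stays at the requested value. A worked example of that rescaling with made-up numbers (sparsity 0.7, owl_lmbda 0.08, three layers):

import numpy as np

s, lmbda = 0.7, 0.08                                 # requested sparsity, owl_lmbda
ratios = np.array([1.0, 3.0, 5.0])                   # raw per-layer outlier percentages
scaled = (ratios - ratios.min()) / (ratios.max() - ratios.min()) * lmbda * 2
# scaled == [0.00, 0.08, 0.16]
sparsities = 1 - (scaled - scaled.mean() + (1 - s))
print(sparsities)                                    # [0.78, 0.70, 0.62]; mean stays 0.70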
Original file line number Diff line number Diff line change
@@ -64,8 +64,8 @@ def sparsify_weight(
sparsity: float,
prune_n: int,
prune_m: int,
blocksize: int,
percdamp: float,
block_size: int,
dampening_frac: float,
preserve_sparsity_mask: bool,
) -> torch.Tensor:
"""
@@ -77,8 +77,8 @@ def sparsify_weight(
:param sparsity: target sparsity to reach for layer
:param prune_n: N for N:M pruning
:param prune_m: M for N:M pruning
:param blocksize: Number of columns to compress in one pass
:param percdamp: Amount of dampening to apply to H, as a fraction of the
:param block_size: Number of columns to compress in one pass
:param dampening_frac: Amount of dampening to apply to H, as a fraction of the
diagonal norm
:param preserve_sparsity_mask: Extend or ignore the base sparsity mask
"""
@@ -110,7 +110,7 @@ def sparsify_weight(

# compute inverse hessian in place to save memory
try:
damp = percdamp * torch.mean(torch.diag(H))
damp = dampening_frac * torch.mean(torch.diag(H))
diag = torch.arange(H.shape[0], device=H.device)
H[diag, diag] += damp
H = torch.linalg.cholesky(H)
@@ -145,8 +145,8 @@ def sparsify_weight(
losses = torch.zeros(num_rows, device=module.weight.device)

# See section 3.4 of https://arxiv.org/abs/2203.07259
for i1 in range(0, num_columns, blocksize):
i2 = min(i1 + blocksize, num_columns)
for i1 in range(0, num_columns, block_size):
i2 = min(i1 + block_size, num_columns)
count = i2 - i1

W1 = W[:, i1:i2].clone()
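
The hunks shown here only rename arguments (blocksize to block_size, percdamp to dampening_frac); the math is unchanged. For reference, the dampening step being renamed adds a fraction of the mean Hessian diagonal before the Cholesky factorization, roughly as follows (toy values):

import torch

H = torch.tensor([[4.0, 1.0],
                  [1.0, 3.0]])                       # toy Hessian accumulated from inputs
dampening_frac = 0.01
damp = dampening_frac * torch.mean(torch.diag(H))    # 0.01 * 3.5 = 0.035
idx = torch.arange(H.shape[0])
H[idx, idx] += damp                                   # regularize the diagonal
L = torch.linalg.cholesky(H)                          # factorization used by the solver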
Empty file.
241 changes: 22 additions & 219 deletions src/llmcompressor/modifiers/pruning/wanda/base.py
@@ -1,42 +1,27 @@
import warnings
from collections import defaultdict
from functools import partial
from typing import Any, Dict, List, Optional, Tuple, Union
from typing import Dict, List, Optional, Tuple, Union

import numpy as np
import torch
from compressed_tensors.utils import (
align_module_device,
get_execution_device,
update_offload_parameter,
)
from loguru import logger
from pydantic import PrivateAttr, field_validator, model_validator
from pydantic import Field, PrivateAttr

from llmcompressor.core import State
from llmcompressor.modifiers import Modifier
from llmcompressor.modifiers.obcq.sgpt_mixin import SparsityModifierMixin
from llmcompressor.modifiers.pruning.wanda.utils.wanda_sparsify import (
accumulate_row_scalars,
make_empty_row_scalars,
sparsify_weight,
)
from llmcompressor.modifiers.utils.hooks import HooksMixin
from llmcompressor.pipelines.basic import run_pipeline as run_basic
from llmcompressor.pipelines.layer_sequential import (
run_pipeline as run_layer_sequential,
)
from llmcompressor.pipelines.sequential import run_pipeline as run_sequential
from llmcompressor.utils.metric_logging import CompressionLogger
from llmcompressor.utils.pytorch.module import (
get_layers,
get_no_split_params,
get_prunable_layers,
)

__all__ = ["WandaPruningModifier"]


class WandaPruningModifier(Modifier):
class WandaPruningModifier(SparsityModifierMixin, Modifier):
"""
Modifier for applying the one-shot WANDA algorithm to a model
from the paper: https://arxiv.org/abs/2306.11695
@@ -59,9 +44,10 @@ class WandaPruningModifier(Modifier):
sparsity_profile: Optional[str] = None # deprecated

# data pipeline arguments
module_targets: Union[str, List[str], None] = None
module_targets: Union[str, List[str]] = ["Linear"]
targets: Union[str, List[str], None] = None # deprecated, clones sequential_targets
sequential_targets: Union[str, List[str], None] = None
ignore: List[str] = Field(default_factory=list)

# private variables
_prune_n: Optional[int] = PrivateAttr(default=None)
@@ -70,130 +56,11 @@ class WandaPruningModifier(Modifier):
default_factory=dict
)
_num_samples: Dict[torch.nn.Module, int] = PrivateAttr(default_factory=dict)
_update_size: Optional[int] = PrivateAttr(default=None)

@field_validator("sparsity_profile", mode="before")
def validate_sparsity_profile(cls, value) -> None:
if value is not None:
warnings.warn(
"`sparsity_profile` is deprecated, use `owl_m` and `owl_lmbda`"
)
return None

@model_validator(mode="after")
def validate_model_after(model: "WandaPruningModifier") -> Dict[str, Any]:
sparsity = model.sparsity
owl_m = model.owl_m
owl_lmbda = model.owl_lmbda
mask_structure = model.mask_structure
targets = model.targets
sequential_targets = model.sequential_targets

if (owl_m is not None) ^ (owl_lmbda is not None):
raise ValueError("Must provide both `owl_m` and `owl_lmbda` or neither")

if owl_m is not None and sparsity is not None:
raise ValueError("Cannot provide both sparsity and owl parameters")

if targets is not None:
warnings.warn(
"`targets` is deprecated, use `module_targets` and `sequential_targets`"
)
if sequential_targets is not None:
raise ValueError("Cannot use both `targets` and `sequential_targets`")
model.sequential_targets = targets

model._prune_n, model._prune_m = mask_structure.split(":")

def on_initialize(self, state: State, **kwargs) -> bool:
"""
Initialize and run the WANDA algorithm on the current state
:param state: session state storing input model and calibration data
:param kwargs: Unused, kept to conform to the parent method signature
"""
model = state.model
dataloader = state.data.calib

# infer module and sequential targets
self.sequential_targets = self._infer_sequential_targets(model)

# infer layer sparsities
if self.owl_m is not None and self.owl_lmbda is not None:
logger.info(
"Using OWL to infer target layer-wise sparsities from "
f"{len(dataloader) if dataloader else 0} calibration samples..."
)
self.sparsity = self._infer_owl_layer_sparsity()

# infer update size
if self._update_size is None:
self._update_size = len(dataloader)

# register hooks
for index, name, layer in enumerate(get_layers(self.sequential_targets, model)):
if isinstance(self.sparsity, Dict):
layer_sparsity = self.sparsity[name]
elif isinstance(self.sparsity, List):
layer_sparsity = self.sparsity[index]
else:
layer_sparsity = self.sparsity

for name, module in get_prunable_layers(layer):
post_hook = partial(
self.compress_module,
name,
layer_sparsity,
)
self.register_hook(module, post_hook, "forward")

# infer and run pipeline
model_name = state.model.__class__.__name__
input_names = state.data.calib.dataset.column_names
unfixable_errors = (torch.OutOfMemoryError, torch._C._LinAlgError)
try:
run_sequential(
state.model,
self.sequential_targets,
self.ignore,
state.data.calib,
propagate_error=True,
)
return True

except Exception as exception:
if isinstance(exception, torch.fx.proxy.TraceError):
warnings.warn(f"Failed to trace {model_name} with inputs {input_names}")
if isinstance(exception, unfixable_errors):
raise exception
_module_names: Dict[torch.nn.Module, str] = PrivateAttr(default_factory=dict)
_module_sparsities: Dict[torch.nn.Module, str] = PrivateAttr(default_factory=dict)

warnings.warn("Falling back to layer_sequential pipeline")
try:
run_layer_sequential(
state.model,
self.sequential_targets,
state.data.calib,
propagate_error=True,
)
return True

except Exception as exception:
if isinstance(exception, TypeError):
warnings.warn(f"{model_name} fails layer-wise assumptions")
if isinstance(exception, unfixable_errors):
raise exception

warnings.warn(
"Falling back to basic pipeline, which requires extra memory and "
"may result in decreased accuracy"
)
run_basic(state.model, state.data.calib)
return True

def compress_module(
def calibrate_module(
self,
name: str,
sparsity: float,
module: torch.nn.Module,
args: Tuple[torch.Tensor, ...],
_output: torch.Tensor,
@@ -215,9 +82,18 @@ def compress_module(
self._num_samples[module],
)

# After enough samples are accumulated, perform sparsification
if self._num_samples[module] >= self._update_size:
logger.info(f"Sparsifying {name} using {self._num_samples[module]} samples")
def on_sequential_batch_end(self):
"""
Sparsify modules
TODO: implement with event callback
"""

for module in list(self._num_samples.keys()):
name = self._module_names[module]
sparsity = self._module_sparsities[module]
num_samples = self._num_samples[module]

logger.info(f"Sparsifying {name} using {num_samples} samples")
with (
torch.no_grad(),
align_module_device(module),
@@ -233,78 +109,5 @@ def compress_module(

update_offload_parameter(module, "weight", sparsified_weight)

# self._hessians[module] already deleted by sparsify_weight
# self._row_scalars[module] already deleted by sparsify_weight
del self._num_samples[module]

def _infer_sequential_targets(self, model):
if self.sequential_targets is None:
return get_no_split_params(model)
if isinstance(self.sequential_targets, str):
return [self.sequential_targets]

def _infer_owl_layer_sparsity(
self, model: torch.nn.Module, dataloader: torch.utils.data.DataLoader
) -> Dict[str, float]:
activations = self._get_activations(model, dataloader)

wanda = {}
for name, layer in self.compressible_layers_.items():
prunable_layers = get_prunable_layers(layer)
z = [
m.weight.abs() * activations[f"{name}.{n}"].unsqueeze(0)
for n, m in prunable_layers.items()
]
wanda[name] = torch.cat([item.flatten().cpu() for item in z])

del activations
torch.cuda.empty_cache()

outlier_ratios = {}
for group in wanda:
threshold = torch.mean(wanda[group]) * self.owl_m
outlier_ratios[group] = (
100 * (wanda[group] > threshold).sum().item() / wanda[group].numel()
)
outlier_ratios_arr = np.array([outlier_ratios[k] for k in outlier_ratios])
for k in outlier_ratios:
outlier_ratios[k] = (outlier_ratios[k] - outlier_ratios_arr.min()) * (
1
/ (outlier_ratios_arr.max() - outlier_ratios_arr.min())
* self.owl_lmbda
* 2
)
outlier_ratios_arr = np.array([outlier_ratios[k] for k in outlier_ratios])
sparsities = {
k: 1
- (
outlier_ratios[k]
- np.mean(outlier_ratios_arr)
+ (1 - float(self.sparsity))
)
for k in outlier_ratios
}
logger.info(f"OWL sparsities for sp={self.sparsity} are:")
for k in sparsities:
logger.info(f"Sparsity for {k}: {sparsities[k]}")
return sparsities

def _get_activations(self, model, dataloader, nsamples=128):
acts = defaultdict(int)

def save_acts(_module, input, name):
nonlocal acts
if isinstance(input, tuple):
input = input[0]
acts[name] += 1.0 / nsamples * input.pow(2).sum(dim=(0, 1)).sqrt()

# TODO: only add hooks to target modules
hooks = set(
self.register_hook(mod, partial(save_acts, name=name), "forward_pre")
for name, mod in model.named_modules()
if isinstance(mod, torch.nn.Linear) and "lm_head" not in name
)
with HooksMixin.disable_hooks(keep=hooks):
run_basic(model, dataloader)
self.remove_hooks(hooks)

return acts
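
In both the old code removed here and the new mixin, _get_activations accumulates a per-input-channel norm of the calibration activations, and weights are scored by weight magnitude times that activation norm, which is the Wanda importance metric from arXiv:2306.11695. A small stand-alone sketch of the metric (shapes made up; the per-row comparison group follows the paper rather than anything shown in this diff):

import torch

X = torch.randn(8, 16, 64)                           # (batch, seq, hidden) calibration inputs
act_norm = X.pow(2).sum(dim=(0, 1)).sqrt()           # per-input-channel L2 norm, shape (64,)
W = torch.randn(32, 64)                              # Linear weight (out_features, in_features)
score = W.abs() * act_norm.unsqueeze(0)              # Wanda importance: |W| * ||x||
k = W.shape[1] // 2                                  # prune half of each output row
thresh = score.kthvalue(k, dim=1, keepdim=True).values
W_pruned = torch.where(score > thresh, W, torch.zeros_like(W))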
8 changes: 4 additions & 4 deletions src/llmcompressor/modifiers/utils/hooks.py
@@ -86,11 +86,11 @@ def wrapped_hook(*args, **kwargs):
return hook(*args, **kwargs)

register_function = getattr(target, f"register_{hook_type}_hook")
handle.value = register_function(wrapped_hook, **kwargs)
self._hooks.add(handle.value)
logger.debug(f"{self} added {handle.value}")
handle = register_function(wrapped_hook, **kwargs)
self._hooks.add(handle)
logger.debug(f"{self} added {handle}")

return handle.value
return handle

def remove_hooks(self, handles: Optional[Set[RemovableHandle]] = None):
"""Remove all hooks belonging to a modifier"""
181 changes: 0 additions & 181 deletions src/llmcompressor/modifiers/utils/layer_compressor.py

This file was deleted.
