Commit: squash
Signed-off-by: Kyle Sayers <[email protected]>
Showing 19 changed files with 915 additions and 1,250 deletions.
This file was deleted.

This file was deleted.
New file (235 additions):

```python
import warnings
from collections import defaultdict
from functools import partial
from typing import List, Tuple, Union

import numpy as np
import torch
from loguru import logger
from pydantic import field_validator, model_validator

from llmcompressor.core import State
from llmcompressor.modifiers import Modifier
from llmcompressor.modifiers.utils.hooks import HooksMixin
from llmcompressor.pipelines.basic import run_pipeline as run_basic
from llmcompressor.pipelines.layer_sequential import (
    run_pipeline as run_layer_sequential,
)
from llmcompressor.pipelines.sequential import run_pipeline as run_sequential
from llmcompressor.utils.pytorch.module import (
    get_layers,
    get_no_split_params,
    get_prunable_layers,
    match_class,
    match_targets,
)


class SparsityModifierMixin(HooksMixin):
    @field_validator("sequential_update", mode="before", check_fields=False)
    def validate_sequential_update(cls, value: bool) -> bool:
        if not value:
            warnings.warn(
                "`sequential_update=False` is no longer supported, setting "
                "sequential_update=True",
                DeprecationWarning,
            )

        return True

    @field_validator("sparsity_profile", mode="before", check_fields=False)
    def validate_sparsity_profile(cls, value) -> None:
        if value is not None:
            warnings.warn(
                "`sparsity_profile` is deprecated, use `owl_m` and `owl_lmbda`"
            )
        return None

    @model_validator(mode="after")
    def validate_model_after(model: "Modifier") -> "Modifier":
        sparsity = model.sparsity
        owl_m = model.owl_m
        owl_lmbda = model.owl_lmbda
        mask_structure = model.mask_structure
        targets = model.targets
        sequential_targets = model.sequential_targets

        if (owl_m is not None) ^ (owl_lmbda is not None):
            raise ValueError("Must provide both `owl_m` and `owl_lmbda` or neither")

        if owl_m is not None and sparsity is not None:
            raise ValueError("Cannot provide both sparsity and owl parameters")

        if targets is not None:
            warnings.warn(
                "`targets` is deprecated, use `module_targets` and `sequential_targets`"
            )
            if sequential_targets is not None:
                raise ValueError("Cannot use both `targets` and `sequential_targets`")
            model.sequential_targets = targets
            model.targets = None

        model._prune_n, model._prune_m = model._split_mask_structure(mask_structure)

        return model

    def on_initialize(self, state: "State", **kwargs) -> bool:
        """
        Initialize and run the OBCQ algorithm on the current state

        :param state: session state storing input model and calibration data
        """
        model = state.model
        dataloader = state.data.calib

        # infer module and sequential targets
        self.sequential_targets = self._infer_sequential_targets(model)

        # infer layer sparsities
        if self.owl_m is not None and self.owl_lmbda is not None:
            logger.info(
                "Using OWL to infer target layer-wise sparsities from "
                f"{len(dataloader) if dataloader else 0} calibration samples..."
            )
            activations = self._get_activations(model, dataloader)
            self.sparsity = self._infer_owl_layer_sparsity(activations)

        # register hooks
        for index, (layer_name, layer) in enumerate(
            get_layers(self.sequential_targets, model).items()
        ):
            if isinstance(self.sparsity, dict):
                layer_sparsity = self.sparsity[layer_name]
            elif isinstance(self.sparsity, list):
                layer_sparsity = self.sparsity[index]
            else:
                layer_sparsity = self.sparsity

            for name, module in layer.named_modules():
                if (
                    match_targets(module, self.module_targets)[0]
                    or match_class(module, self.module_targets)[0]
                ):
                    self._module_names[module] = name
                    self._module_sparsities[module] = layer_sparsity
                    self.register_hook(module, self.calibrate_module, "forward")

        # infer and run pipeline
        model_name = state.model.__class__.__name__
        input_names = state.data.calib.dataset.column_names
        unfixable_errors = (torch.OutOfMemoryError, torch._C._LinAlgError)
        try:
            run_sequential(
                state.model,
                state.data.calib,
                self.sequential_targets,
                self.ignore,
                self,
            )
            return True

        except Exception as exception:
            if isinstance(exception, torch.fx.proxy.TraceError):
                warnings.warn(f"Failed to trace {model_name} with inputs {input_names}")
            if isinstance(exception, unfixable_errors):
                raise exception

            warnings.warn("Falling back to layer_sequential pipeline")
            try:
                run_layer_sequential(
                    state.model,
                    state.data.calib,
                    self.sequential_targets,
                    self,
                )
                return True

            except Exception as exception:
                if isinstance(exception, TypeError):
                    warnings.warn(f"{model_name} fails layer-wise assumptions")
                if isinstance(exception, unfixable_errors):
                    raise exception

                warnings.warn(
                    "Falling back to basic pipeline, which requires extra memory and "
                    "may result in decreased accuracy"
                )
                run_basic(state.model, state.data.calib, self)
                return True

    def _infer_sequential_targets(
        self, model: torch.nn.Module
    ) -> Union[str, List[str]]:
        if self.sequential_targets is None:
            return get_no_split_params(model)
        if isinstance(self.sequential_targets, str):
            return [self.sequential_targets]
        return self.sequential_targets

    def _infer_owl_layer_sparsity(self, activations):
        groups = {}
        for name, layer in self.compressible_layers_.items():
            prunable_layers = get_prunable_layers(layer)
            z = [
                m.weight.abs() * activations[f"{name}.{n}"].unsqueeze(0)
                for n, m in prunable_layers.items()
            ]
            groups[name] = torch.cat([item.flatten().cpu() for item in z])

        del activations
        torch.cuda.empty_cache()

        outlier_ratios = {}
        for group in groups:
            threshold = torch.mean(groups[group]) * self.owl_m
            outlier_ratios[group] = (
                100 * (groups[group] > threshold).sum().item() / groups[group].numel()
            )
        outlier_ratios_arr = np.array([outlier_ratios[k] for k in outlier_ratios])
        for k in outlier_ratios:
            outlier_ratios[k] = (outlier_ratios[k] - outlier_ratios_arr.min()) * (
                1
                / (outlier_ratios_arr.max() - outlier_ratios_arr.min())
                * self.owl_lmbda
                * 2
            )
        outlier_ratios_arr = np.array([outlier_ratios[k] for k in outlier_ratios])
        sparsities = {
            k: 1
            - (
                outlier_ratios[k]
                - np.mean(outlier_ratios_arr)
                + (1 - float(self.sparsity))
            )
            for k in outlier_ratios
        }
        logger.info(f"OWL sparsities for sp={self.sparsity} are:")
        for k in sparsities:
            logger.info(f"Sparsity for {k}: {sparsities[k]}")
        return sparsities

    def _get_activations(self, model, dataloader, nsamples=128):
        acts = defaultdict(int)

        def save_acts(_module, input, name):
            nonlocal acts
            if isinstance(input, tuple):
                input = input[0]
            acts[name] += 1.0 / nsamples * input.pow(2).sum(dim=(0, 1)).sqrt()

        # TODO: only add hooks to target modules
        hooks = set(
            self.register_hook(mod, partial(save_acts, name=name), "forward_pre")
            for name, mod in model.named_modules()
            if isinstance(mod, torch.nn.Linear) and "lm_head" not in name
        )
        with HooksMixin.disable_hooks(keep=hooks):
            run_basic(model, dataloader)
        self.remove_hooks(hooks)

        return acts

    def _split_mask_structure(self, mask_structure: str) -> Tuple[int, int]:
        n, m = mask_structure.split(":")
        return int(n), int(m)
```
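For intuition, the two pieces of arithmetic the mixin relies on can be exercised standalone: splitting a `mask_structure` string such as `"2:4"` into `(prune_n, prune_m)`, and the OWL rescaling that turns per-layer outlier ratios into per-layer sparsities centered on the global target. The following sketch is not part of the diff; the layer names and ratio values are invented for illustration.

```python
import numpy as np


def split_mask_structure(mask_structure: str) -> tuple[int, int]:
    # "2:4" -> (2, 4): mask the 2 lowest-scoring weights in every group of 4
    n, m = mask_structure.split(":")
    return int(n), int(m)


# Hypothetical per-layer outlier ratios, i.e. the percentage of
# |weight| * activation values exceeding owl_m times the mean,
# as _infer_owl_layer_sparsity would compute them.
outlier_ratios = {"layers.0": 2.0, "layers.1": 5.0, "layers.2": 8.0}
owl_lmbda = 0.08
target_sparsity = 0.5

arr = np.array(list(outlier_ratios.values()))
# min-max normalize the ratios into [0, 2 * owl_lmbda]
scaled = {
    k: (v - arr.min()) / (arr.max() - arr.min()) * owl_lmbda * 2
    for k, v in outlier_ratios.items()
}
scaled_arr = np.array(list(scaled.values()))
# shift around the mean so the average sparsity equals the target
sparsities = {
    k: 1 - (v - scaled_arr.mean() + (1 - target_sparsity))
    for k, v in scaled.items()
}

print(split_mask_structure("2:4"))  # (2, 4)
print(sparsities)  # approximately {'layers.0': 0.58, 'layers.1': 0.50, 'layers.2': 0.42}
```

Layers with more outliers receive lower sparsity, since they are more sensitive to pruning, while the mean sparsity across layers stays at the target.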
New file (225 additions):

```python
import math
from typing import Dict, Optional, Tuple

import torch
import transformers
from compressed_tensors.quantization.lifecycle.forward import forward_quantize

from llmcompressor.utils import getattr_chain

SGPT_PRECISION = torch.float32


def make_empty_hessian(
    module: torch.nn.Module, device: Optional[torch.device] = None
) -> torch.Tensor:
    weight = module.weight
    num_columns = weight.shape[1]
    device = device if device is not None else weight.device
    return torch.zeros((num_columns, num_columns), device=device, dtype=SGPT_PRECISION)


def accumulate_hessian(
    inp: torch.Tensor,
    module: torch.nn.Module,
    H: Optional[torch.Tensor] = None,
    num_samples: int = 1,
) -> Tuple[torch.Tensor, int]:
    inp = inp.to(device=H.device)
    if len(inp.shape) == 2:
        inp = inp.unsqueeze(0)

    # note: this is the number of dataset samples, not multiplied by the
    # sequence length
    num_added = inp.shape[0]

    if isinstance(module, (torch.nn.Linear, transformers.Conv1D)):
        if len(inp.shape) == 3:
            inp = inp.reshape((-1, inp.shape[-1]))
        inp = inp.t()

    if isinstance(module, torch.nn.Conv2d):
        unfold = torch.nn.Unfold(
            module.kernel_size,
            dilation=module.dilation,
            padding=module.padding,
            stride=module.stride,
        )
        inp = unfold(inp)
        inp = inp.permute([1, 0, 2])
        inp = inp.flatten(1)

    H *= num_samples / (num_samples + num_added)
    num_samples += num_added

    inp = inp.to(dtype=SGPT_PRECISION)
    inp = math.sqrt(2 / num_samples) * inp
    H += inp.matmul(inp.t())

    return H, num_samples


def sparsify_weight(
    module: torch.nn.Module,
    hessians_dict: Dict[torch.nn.Module, torch.Tensor],
    sparsity: float,
    prune_n: int,
    prune_m: int,
    block_size: int,
    dampening_frac: float,
    preserve_sparsity_mask: bool,
) -> Tuple[float, torch.Tensor]:
    """
    Run pruning and quantization (if applicable) on the layer up to the target
    sparsity value.

    :param module: module with weight being sparsified
    :param hessians_dict: dictionary containing preaccumulated hessian for
        sparsification
    :param sparsity: target sparsity to reach for layer
    :param prune_n: N for N:M pruning
    :param prune_m: M for N:M pruning
    :param block_size: number of columns to compress in one pass
    :param dampening_frac: amount of dampening to apply to H, as a fraction of the
        diagonal norm
    :param preserve_sparsity_mask: extend or ignore the base sparsity mask
    """
    final_shape = module.weight.shape
    final_dtype = module.weight.dtype
    W = module.weight.clone()
    H = hessians_dict[module]  # unfortunately python does not have a `move` keyword
    del hessians_dict[module]  # so we have to delete the original reference manually

    # if this module is quantized, perform RTN quantization before sparsifying
    args_loc = "quantization_scheme.weights"
    weight_quant_args = getattr_chain(module, args_loc, None)
    if weight_quant_args is not None:
        W = forward_quantize(module, W, "weight", weight_quant_args)

    # standardize shape and dtype
    if isinstance(module, torch.nn.Conv2d):
        W = W.flatten(1)
    elif isinstance(module, transformers.Conv1D):
        W.transpose_(0, 1)
    W = W.to(dtype=SGPT_PRECISION)
    num_rows = W.shape[0]
    num_columns = W.shape[1]

    # mask dead hessian values
    dead = torch.diag(H) == 0
    H[dead, dead] = 1
    W[:, dead] = 0

    # compute inverse hessian in place to save memory
    try:
        damp = dampening_frac * torch.mean(torch.diag(H))
        diag = torch.arange(H.shape[0], device=H.device)
        H[diag, diag] += damp
        H = torch.linalg.cholesky(H)
        H = torch.cholesky_inverse(H)
        H = torch.linalg.cholesky(H, upper=True)
        Hinv = H
    except torch._C._LinAlgError:
        raise torch._C._LinAlgError(
            "Failed to invert hessian due to numerical instability. Consider "
            "increasing GPTQModifier.dampening_frac, increasing the number "
            "of calibration samples, or shuffling the calibration dataset"
        )

    # sparsity mask
    # TODO: consider computing sparsity mask in the same way and place as gptq
    mask = None
    if preserve_sparsity_mask:
        # existing zeros form the base sparsity mask
        mask = W == 0
        current_sparsity = mask.sum() / W.numel()
        if current_sparsity > sparsity:
            raise ValueError(
                "The target sparsity is lower than the sparsity "
                "of the base model. Please retry "
                "after turning preserve_sparsity_mask=False"
            )

    losses = torch.zeros(num_rows, device=module.weight.device)

    # See section 3.4 of https://arxiv.org/abs/2203.07259
    for i1 in range(0, num_columns, block_size):
        i2 = min(i1 + block_size, num_columns)
        count = i2 - i1

        W1 = W[:, i1:i2].clone()
        Q1 = torch.zeros_like(W1)
        Err1 = torch.zeros_like(W1)
        Losses1 = torch.zeros_like(W1)
        Hinv1 = Hinv[i1:i2, i1:i2]

        if prune_n == 0:
            if mask is not None:
                mask1 = mask[:, i1:i2]
                if int(W1.numel() * sparsity) > mask1.sum():
                    # target sparsity is higher than base sparsity, extend mask1
                    tmp = (
                        (~mask[:, i1:i2])
                        * W1**2
                        / (torch.diag(Hinv1).reshape((1, -1))) ** 2
                    )
                    thresh = torch.sort(tmp.flatten())[0][int(tmp.numel() * sparsity)]
                    mask1 = tmp <= thresh
            else:
                tmp = W1**2 / (torch.diag(Hinv1).reshape((1, -1))) ** 2
                thresh = torch.sort(tmp.flatten())[0][int(tmp.numel() * sparsity)]
                mask1 = tmp <= thresh
        else:
            if mask is not None:
                mask1 = mask[:, i1:i2]
            else:
                mask1 = torch.zeros_like(W1) == 1

        for i in range(count):
            w = W1[:, i]
            d = Hinv1[i, i]

            if prune_n != 0 and i % prune_m == 0:
                tmp = (
                    W1[:, i : (i + prune_m)] ** 2
                    / (torch.diag(Hinv1)[i : (i + prune_m)].reshape((1, -1))) ** 2
                )
                if mask is not None:
                    tmp = tmp * (~mask[:, i : (i + prune_m)])

                mask1.scatter_(
                    1, i + torch.topk(tmp, prune_n, dim=1, largest=False)[1], True
                )

            q = w.clone()
            q[mask1[:, i]] = 0

            Q1[:, i] = q
            Losses1[:, i] = (w - q) ** 2 / d**2

            err1 = (w - q) / d
            W1[:, i:] -= err1.unsqueeze(1).matmul(Hinv1[i, i:].unsqueeze(0))
            Err1[:, i] = err1

        W[:, i1:i2] = Q1
        losses += torch.sum(Losses1, 1) / 2

        if preserve_sparsity_mask:
            # respect the sparsity of other groups
            # really not needed, but kept for explicitness
            W[:, i2:] -= (~mask[:, i2:]) * Err1.matmul(Hinv[i1:i2, i2:])
        else:
            W[:, i2:] -= Err1.matmul(Hinv[i1:i2, i2:])

    if isinstance(module, transformers.Conv1D):
        W.transpose_(0, 1)
    W = W.reshape(final_shape).to(final_dtype)

    # perform RTN quantization
    if weight_quant_args is not None:
        W = forward_quantize(module, W, "weight", weight_quant_args)

    loss = torch.sum(losses).item()
    return loss, W
```
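These three helpers are normally driven by calibration hooks, but they can be smoke-tested in isolation. A minimal sketch, assuming the file above is importable as `sparsegpt_sparsify` (a hypothetical name; the module path is not shown in this capture) and that its llmcompressor and compressed-tensors imports resolve: accumulate a Hessian from random calibration batches, then prune a small Linear layer to 50% unstructured sparsity.

```python
import torch

import sparsegpt_sparsify as sgpt  # hypothetical import path for the file above

torch.manual_seed(0)
module = torch.nn.Linear(64, 64)

# accumulate H ~ (2/n) * sum(x x^T) over a few fake calibration batches
H = sgpt.make_empty_hessian(module)
num_samples = 0
for _ in range(8):
    batch = torch.randn(4, 16, 64)  # (batch, seq_len, hidden)
    H, num_samples = sgpt.accumulate_hessian(batch, module, H, num_samples)

loss, W = sgpt.sparsify_weight(
    module,
    {module: H},  # the function pops this entry to free memory
    sparsity=0.5,
    prune_n=0,  # 0 disables N:M structured pruning
    prune_m=0,
    block_size=32,
    dampening_frac=0.01,
    preserve_sparsity_mask=False,
)
module.weight.data = W
print(f"loss={loss:.4f}, sparsity={(W == 0).float().mean():.2f}")  # sparsity ~0.50
```

Note how `accumulate_hessian` rescales `H` by `num_samples / (num_samples + num_added)` before each update, so it always holds a running average regardless of how many batches have been seen.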
Empty file.
This file was deleted.
src/llmcompressor/modifiers/pruning/wanda/utils/wanda_sparsify.py (new file, 111 additions):

```python
from typing import Dict, Optional

import torch
import transformers

WANDA_PRECISION = torch.float32

# TODO: are these needed?
# torch.backends.cuda.matmul.allow_tf32 = False
# torch.backends.cudnn.allow_tf32 = False


def make_empty_row_scalars(
    module: torch.nn.Module, device: Optional[torch.device] = None
) -> torch.Tensor:
    weight = module.weight
    num_columns = weight.shape[1]
    device = device if device is not None else weight.device
    return torch.zeros(num_columns, device=device)


def accumulate_row_scalars(
    inp: torch.Tensor,
    module: torch.nn.Module,
    row_scalars: Optional[torch.Tensor],
    num_samples: int,
):
    inp = inp.to(device=row_scalars.device)
    if len(inp.shape) == 2:
        inp = inp.unsqueeze(0)

    # note: this is the number of dataset samples, not multiplied by the
    # sequence length
    num_added = inp.shape[0]

    if isinstance(module, (torch.nn.Linear, transformers.Conv1D)):
        if len(inp.shape) == 3:
            inp = inp.reshape((-1, inp.shape[-1]))
        inp = inp.t()

    if isinstance(module, torch.nn.Conv2d):
        unfold = torch.nn.Unfold(
            module.kernel_size,
            dilation=module.dilation,
            padding=module.padding,
            stride=module.stride,
        )
        inp = unfold(inp)
        inp = inp.permute([1, 0, 2])
        inp = inp.flatten(1)

    row_scalars *= num_samples / (num_samples + num_added)
    num_samples += num_added

    inp = inp.type(WANDA_PRECISION)
    row_scalars += torch.norm(inp, p=2, dim=1) ** 2 / num_samples

    return row_scalars, num_samples


def sparsify_weight(
    module: torch.nn.Module,
    row_scalars_dict: Dict[torch.nn.Module, torch.Tensor],
    sparsity: float,
    prune_n: int,
    prune_m: int,
) -> torch.Tensor:
    """
    Run pruning on the layer up to the target sparsity value.

    :param module: module with weight being sparsified
    :param row_scalars_dict: dictionary containing preaccumulated activation scalars
    :param sparsity: target sparsity to reach for layer
    :param prune_n: N for N:M pruning
    :param prune_m: M for N:M pruning
    """
    final_shape = module.weight.shape
    final_dtype = module.weight.dtype
    W = module.weight.data.clone()
    if isinstance(module, torch.nn.Conv2d):
        W = W.flatten(1)
    if isinstance(module, transformers.Conv1D):
        W = W.t()
    W = W.to(dtype=WANDA_PRECISION)
    S = row_scalars_dict[module]  # unfortunately python does not have a `move` keyword
    del row_scalars_dict[module]  # so we have to delete the original reference manually

    W_metric = torch.abs(W) * torch.sqrt(S.reshape((1, -1)))

    # initialize a mask to be all False
    W_mask = torch.zeros_like(W_metric, dtype=torch.bool)
    if prune_n != 0:
        # structured n:m sparsity
        for ii in range(W_metric.shape[1]):
            if ii % prune_m == 0:
                tmp = W_metric[:, ii : (ii + prune_m)].float()
                W_mask.scatter_(
                    1,
                    ii + torch.topk(tmp, prune_n, dim=1, largest=False)[1],
                    True,
                )
    else:
        sort_res = torch.sort(W_metric, dim=-1, stable=True)
        indices = sort_res[1][:, : int(W_metric.shape[1] * sparsity)]
        W_mask.scatter_(1, indices, True)

    W[W_mask] = 0.0  # set weights to zero

    if isinstance(module, transformers.Conv1D):
        W = W.t()

    W = W.reshape(final_shape).to(final_dtype)

    return W
```
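The Wanda utilities follow the same accumulate-then-sparsify shape but only need per-column activation norms, so a smoke test is even smaller. A minimal sketch, assuming the file is importable as `wanda_sparsify` (a hypothetical name): accumulate the scalars over fake calibration batches, then apply 2:4 structured pruning to a toy Linear layer.

```python
import torch

import wanda_sparsify as wanda  # hypothetical import path for the file above

torch.manual_seed(0)
module = torch.nn.Linear(32, 32)

# accumulate per-input-channel activation norms over fake calibration batches
S = wanda.make_empty_row_scalars(module)
num_samples = 0
for _ in range(4):
    batch = torch.randn(2, 8, 32)  # (batch, seq_len, hidden)
    S, num_samples = wanda.accumulate_row_scalars(batch, module, S, num_samples)

# 2:4 structured pruning: zero the 2 lowest-scoring weights in every group of 4
W = wanda.sparsify_weight(module, {module: S}, sparsity=0.5, prune_n=2, prune_m=4)
module.weight.data = W
print((W == 0).float().mean())  # tensor(0.5000)
```

Because the metric is `|W| * sqrt(S)` with `S` shared across rows, the mask depends on both weight magnitude and how strongly each input channel is activated, which is what distinguishes Wanda from plain magnitude pruning.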
src/llmcompressor/modifiers/pruning/wanda/utils/wanda_wrapper.py: deleted (125 lines removed).
src/llmcompressor/modifiers/utils/compression_wrapper.py: deleted (117 lines removed).
This file was deleted.
New file (33 additions):

```python
from pathlib import Path

import pytest

from tests.examples.utils import (
    copy_and_run_script,
    gen_cmd_fail_message,
    requires_gpu_count,
)


@pytest.fixture
def example_dir() -> str:
    return "examples/sparse_2of4_quantization_fp8"


@requires_gpu_count(1)
class TestSparse2of4QuantizationFP8:
    """
    Tests for examples in the "sparse_2of4_quantization_fp8" example folder.
    """

    @pytest.mark.parametrize("flags", [[], ["--fp8"]])
    def test_llama3_8b_2of4(self, example_dir: str, tmp_path: Path, flags: list[str]):
        """
        Test for the "llama3_8b_2of4.py" example script.
        """
        script_filename = "llama3_8b_2of4.py"
        command, result = copy_and_run_script(
            tmp_path, example_dir, script_filename, flags=flags
        )

        assert result.returncode == 0, gen_cmd_fail_message(command, result)
```