actorder #16

Closed
wants to merge 17 commits
71 changes: 59 additions & 12 deletions src/llmcompressor/modifiers/quantization/gptq/utils/gptq_wrapper.py
@@ -99,6 +99,23 @@ def compress(
        self.H[dead, dead] = 1
        W[:, dead] = 0

        # actorder must be defined even when the layer has no quantization
        # scheme, since it is checked again after the quantization loop
        actorder = False
        g_idx = None
        if hasattr(self.layer, "quantization_scheme"):
            quant_scheme = self.layer.quantization_scheme
            actorder = quant_scheme.weights.actorder

            if actorder:
                group_size = quant_scheme.weights.group_size
                # sort columns by descending Hessian diagonal so the most
                # activation-sensitive weights are quantized first
                perm = torch.argsort(torch.diag(self.H), descending=True)
                W = W[:, perm]
                self.H = self.H[perm][:, perm]
                invperm = torch.argsort(perm)

                # per-column group assignment, exposed on the layer for
                # downstream packing and reconstruction
                g_idx = torch.Tensor(
                    [perm[i] // group_size for i in range(self.columns)]
                ).to(device=invperm.device)
                self.layer.weight_g_idx.data = g_idx

        Losses = torch.zeros(self.rows, device=self.dev)

        damp = percdamp * torch.mean(torch.diag(self.H))
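For intuition, here is a self-contained sketch of the reordering bookkeeping above (assuming only torch; sizes are illustrative): columns are sorted by the Hessian diagonal, groups are formed contiguously in the sorted order, and the inverse permutation both restores the original layout and yields a per-column group map. Note the g_idx convention in this sketch (group ids read back through invperm) is one common formulation; the exact convention consumed downstream may differ from the diff's.

import torch

torch.manual_seed(0)
columns, group_size = 8, 4
H = torch.rand(columns, columns)
W = torch.rand(3, columns)

# sort columns by descending Hessian diagonal, as in the diff above
perm = torch.argsort(torch.diag(H), descending=True)
W_sorted = W[:, perm]
H_sorted = H[perm][:, perm]
invperm = torch.argsort(perm)

# groups are contiguous in the sorted order; reading them back through
# invperm gives each original column its group assignment
sorted_groups = torch.arange(columns) // group_size
g_idx = sorted_groups[invperm]

assert torch.equal(W_sorted[:, invperm], W)  # invperm restores the layout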
@@ -123,6 +140,10 @@ def compress(
            if preserve_zeros:
                W1_nz_mask = W_nz_mask[:, i1:i2]

            # refresh scales/zero-points at most once per block when
            # activation reordering is enabled
            is_layer_updated_actorder = False
            if hasattr(self.layer, "quantization_scheme"):
                quant_scheme = self.layer.quantization_scheme

            for i in range(count):
                w = W1[:, i]
                d = Hinv1[i, i]
@@ -140,14 +161,20 @@ def compress(
                    q = torch.dequantize(q)
                elif hasattr(self.layer, "quantization_scheme"):
                    quant_scheme = self.layer.quantization_scheme

                    if quant_scheme.weights is not None:
                        # fetch latest correct scale and ZP relevant for any changes
                        # such as activation reordering
                        if not is_layer_updated_actorder:
                            from compressed_tensors.quantization import (
                                update_layer_weight_quant_params,
                            )

                            # drop stale min/max statistics gathered before the
                            # columns were reordered
                            observer = getattr(self.layer, "weight_observer", None)
                            if observer is not None:
                                observer.reset()
                            update_layer_weight_quant_params(self.layer, g_idx)
                            is_layer_updated_actorder = True

                        scale = self.layer.weight_scale
                        zero_point = self.layer.weight_zero_point
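update_layer_weight_quant_params is provided by compressed_tensors, and its exact behavior is out of scope here. As a rough mental model only (recompute_group_scales is a hypothetical helper, not that library's API), refreshing quantization parameters after reordering amounts to re-deriving per-group scales and zero points from the columns each group now owns:

import torch

def recompute_group_scales(weight, g_idx, num_bits=4):
    # illustrative asymmetric min/max calibration over each column group
    q_max = 2**num_bits - 1
    scales, zero_points = [], []
    for g in range(int(g_idx.max().item()) + 1):
        cols = weight[:, g_idx == g]  # columns assigned to group g
        w_min = cols.min(dim=1, keepdim=True).values.clamp(max=0.0)
        w_max = cols.max(dim=1, keepdim=True).values.clamp(min=0.0)
        scale = ((w_max - w_min) / q_max).clamp(min=1e-8)
        scales.append(scale)
        zero_points.append(torch.round(-w_min / scale))
    return torch.cat(scales, dim=1), torch.cat(zero_points, dim=1)

Each resulting column then lines up with the scale[:, group] / zero_point[:, group] lookups performed later in the loop, which is why stale observer statistics must be cleared first.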
@@ -156,6 +183,13 @@ def compress(
                            fake_quantize,
                        )

                        scale = self.layer.weight_scale
                        zero_point = self.layer.weight_zero_point

                        # with no group size set, treat the full input dimension
                        # as one group, which reduces to channelwise quantization
                        group_size = quant_scheme.weights.group_size
                        if group_size is None or group_size == -1:
                            group_size = self.layer.weight.shape[1]

                        strategy = quant_scheme.weights.strategy

                        if strategy == QuantizationStrategy.TENSOR:
@@ -184,12 +218,22 @@
                            # ends up being a channelwise application
                            altered_qargs = copy(quant_scheme.weights)
                            altered_qargs.strategy = QuantizationStrategy.CHANNEL

                            if g_idx is not None:
                                # under activation reordering, a column's group
                                # comes from the explicit g_idx map
                                q = fake_quantize(
                                    q,
                                    scale[:, int(g_idx[column_idx])],
                                    zero_point[:, int(g_idx[column_idx])],
                                    altered_qargs,
                                )
                            else:
                                q = fake_quantize(
                                    q,
                                    scale[:, input_dim_group],
                                    zero_point[:, input_dim_group],
                                    altered_qargs,
                                )

                Q1[:, i] = q
                Losses1[:, i] = (w - q) ** 2 / d**2
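A toy illustration of the branch above (group_for is a hypothetical helper, not part of the PR): without reordering, the column at position column_idx falls in a contiguous group, while under actorder the owning group must be read from the explicit map.

import torch

def group_for(column_idx, g_idx, group_size):
    # mirrors the diff's branch: explicit map under actorder,
    # contiguous blocks otherwise
    if g_idx is not None:
        return int(g_idx[column_idx])
    return column_idx // group_size

group_size = 4
g_idx = torch.tensor([0, 1, 1, 0, 1, 0, 0, 1])  # scrambled by reordering

print(group_for(5, None, group_size))   # 1: contiguous grouping
print(group_for(5, g_idx, group_size))  # 0: grouping dictated by g_idx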
@@ -210,10 +254,13 @@ def compress(
                W[:, i2:] -= w_err * W_nz_mask[:, i2:]
            else:
                W[:, i2:] -= w_err

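Standalone, the preserve_zeros path looks like the sketch below, where w_err stands in for the propagated quantization error computed in the elided lines above: masking the update guarantees that previously pruned (zero) weights stay zero.

import torch

torch.manual_seed(0)
W = torch.randn(4, 10)
W[:, 7] = 0.0                    # pretend column 7 was pruned to zero
W_nz_mask = (W != 0).float()     # 1 wherever weights may still move

i2 = 6
w_err = 0.05 * torch.randn(4, W.shape[1] - i2)

# masked error feedback never re-introduces mass into pruned positions
W[:, i2:] -= w_err * W_nz_mask[:, i2:]
assert torch.all(W[:, 7] == 0)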
logger.info("time %.2f" % (time.time() - tick))
logger.info("error %.2f" % torch.sum(Losses).item())

if actorder:
W = W[:, invperm]
self.H = self.H[perm][:, perm]

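The restore step uses the inverse permutation because self.H was permuted once with perm near the top of compress; indexing with perm a second time would compose the shuffle rather than cancel it. A quick standalone check:

import torch

torch.manual_seed(0)
H = torch.rand(6, 6)
perm = torch.argsort(torch.diag(H), descending=True)
invperm = torch.argsort(perm)

H_sorted = H[perm][:, perm]
# perm[invperm[i]] == i, so indexing with invperm undoes the reordering
assert torch.equal(H_sorted[invperm][:, invperm], H)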
        if isinstance(self.layer, transformers.Conv1D):
            W = W.t()
        W = W.reshape(final_shape).to(final_dtype)