From 44f66fb2a05011d0494ccebfe7c3d8aeefe45ea7 Mon Sep 17 00:00:00 2001
From: George Ohashi
Date: Wed, 10 Jul 2024 15:06:03 +0000
Subject: [PATCH] g_idx change

---
 .../quantization/gptq/utils/gptq_wrapper.py   | 62 ++++++++++++++-----
 1 file changed, 45 insertions(+), 17 deletions(-)

diff --git a/src/llmcompressor/modifiers/quantization/gptq/utils/gptq_wrapper.py b/src/llmcompressor/modifiers/quantization/gptq/utils/gptq_wrapper.py
index 460d4457a..407f37732 100644
--- a/src/llmcompressor/modifiers/quantization/gptq/utils/gptq_wrapper.py
+++ b/src/llmcompressor/modifiers/quantization/gptq/utils/gptq_wrapper.py
@@ -111,6 +111,22 @@ def compress(
 
         actorder = False
         invperm = False
+        g_idx = None
+
+        # actorder check
+        if hasattr(self.layer, "quantization_scheme"):
+            quant_scheme = self.layer.quantization_scheme
+            actorder = quant_scheme.weights.actorder
+            if actorder and quant_scheme.weights is not None:
+                perm = torch.argsort(torch.diag(self.H), descending=True)
+                W = W[:, perm]
+                self.H = self.H[perm][:, perm]
+                invperm = torch.argsort(perm)
+
+                g_idx = [i // group_size for i in range(self.columns)]
+                g_idx = g_idx[invperm]
+                self.layer.weight_g_idx.data = g_idx
+
 
         # See section 3.4 of https://arxiv.org/abs/2203.07259
         for i1 in range(0, self.columns, blocksize):
@@ -143,7 +159,7 @@ def compress(
                        q = torch.dequantize(q)
                    elif hasattr(self.layer, "quantization_scheme"):
                        quant_scheme = self.layer.quantization_scheme
-                        actorder = quant_scheme.weights.actorder
+                        # actorder = quant_scheme.weights.actorder
                        if quant_scheme.weights is not None:
                            from compressed_tensors.quantization import QuantizationStrategy
 
@@ -154,26 +170,28 @@ def compress(
                            scale = self.layer.weight_scale
                            zero_point = self.layer.weight_zero_point
 
-                            if actorder:
-                                perm = torch.argsort(torch.diag(self.H), descending=True)
-                                W = W[:, perm]
-                                self.H = self.H[perm][:, perm]
-                                invperm = torch.argsort(perm)
+                            # if actorder:
+                            #     perm = torch.argsort(torch.diag(self.H), descending=True)
+                            #     W = W[:, perm]
+                            #     self.H = self.H[perm][:, perm]
+                            #     invperm = torch.argsort(perm)
+
+                            group_size = quant_scheme.weights.group_size
 
                            if group_size is None or group_size == -1:
                                group_size = self.layer.weight.shape[1]
 
-                            if actorder:
-                                indices = torch.arange(self.columns, device=invperm.device)
-                                g_idx = (perm[indices] // group_size).to(dtype=torch.int32)
-                                g_idx = g_idx[invperm]
-                                self.layer.weight_g_idx.data = g_idx
-                            else:
-                                indices = torch.arange(
-                                    self.columns, device=W.device, dtype=torch.int32
-                                )
-                                g_idx = indices // group_size
+                            # if actorder:
+                            #     indices = torch.arange(self.columns, device=invperm.device)
+                            #     g_idx = (perm[indices] // group_size).to(dtype=torch.int32)
+                            #     g_idx = g_idx[invperm]
+                            #     self.layer.weight_g_idx.data = g_idx
+                            # else:
+                            #     indices = torch.arange(
+                            #         self.columns, device=W.device, dtype=torch.int32
+                            #     )
+                            #     g_idx = indices // group_size
 
                            strategy = quant_scheme.weights.strategy
@@ -203,11 +221,21 @@ def compress(
                                # ends up being a channelwise application
                                altered_qargs = copy(quant_scheme.weights)
                                altered_qargs.strategy = QuantizationStrategy.CHANNEL
+
+                                if actorder:
+                                    perm = torch.argsort(torch.diag(self.H), descending=True)
+                                    W = W[:, perm]
+                                    self.H = self.H[perm][:, perm]
+                                    invperm = torch.argsort(perm)
+
+                                    g_idx = [i // group_size for i in range(self.columns)]
+                                    g_idx = g_idx[invperm]
+                                    self.layer.weight_g_idx.data = g_idx
 
                                # apply g_idx
                                if g_idx is not None:
                                    # scale and zp already transformed by group_size
-                                    # extract first index of group_idze
+                                    # extract first index of group_size
                                    indices_to_extract = torch.arange(
                                        0, g_idx.shape[0], group_size
                                    )
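
Note on the hoisted actorder block: as written in the first hunk, g_idx is built as a
Python list, so the subsequent g_idx[invperm] tensor indexing and the assignment into
self.layer.weight_g_idx.data would fail, and group_size is read before the later
group_size = quant_scheme.weights.group_size assignment runs. Below is a minimal
tensor-based sketch of the same bookkeeping, assuming self.H is the accumulated Hessian
and group_size is already resolved; actorder_bookkeeping is a hypothetical helper, not
part of this patch.

    import torch

    def actorder_bookkeeping(H, num_columns, group_size):
        """Sketch of the actorder bookkeeping this patch hoists to the top
        of compress(): sort columns by descending Hessian diagonal and
        record, for each original column, the quantization group it lands
        in after the permutation."""
        # Columns with a larger Hessian diagonal amplify quantization
        # error the most, so they are quantized first.
        perm = torch.argsort(torch.diag(H), descending=True)
        invperm = torch.argsort(perm)

        # Per-column group index in permuted order, built as a tensor so
        # it supports tensor indexing (the patch uses a Python list here).
        g_idx = torch.arange(num_columns, dtype=torch.int32, device=H.device) // group_size

        # Map g_idx back to original column order before storing it.
        return perm, invperm, g_idx[invperm]

    # Usage mirroring the patch:
    #   perm, invperm, g_idx = actorder_bookkeeping(self.H, self.columns, group_size)
    #   W = W[:, perm]
    #   self.H = self.H[perm][:, perm]
    #   self.layer.weight_g_idx.data = g_idx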