Commit 44f66fb: g_idx change
horheynm committed Jul 10, 2024
1 parent 02abc9b commit 44f66fb
Showing 1 changed file with 45 additions and 17 deletions.
src/llmcompressor/modifiers/quantization/gptq/utils/gptq_wrapper.py (45 additions, 17 deletions)
@@ -111,6 +111,22 @@ def compress(

         actorder = False
         invperm = False
+        g_idx = None
+
+        # actorder check
+        if hasattr(self.layer, "quantization_scheme"):
+            quant_scheme = self.layer.quantization_scheme
+            actorder = quant_scheme.weights is not None and quant_scheme.weights.actorder
+            if actorder:
+                perm = torch.argsort(torch.diag(self.H), descending=True)
+                W = W[:, perm]
+                self.H = self.H[perm][:, perm]
+                invperm = torch.argsort(perm)
+
+                g_idx = torch.arange(self.columns, device=invperm.device, dtype=torch.int32) // group_size
+                g_idx = g_idx[invperm]
+                self.layer.weight_g_idx.data = g_idx
+
+
         # See section 3.4 of https://arxiv.org/abs/2203.07259
         for i1 in range(0, self.columns, blocksize):
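
For context, the act-order step in this first hunk can be reproduced in isolation: columns are sorted by the Hessian diagonal so the most sensitive ones are quantized first, and invperm is the inverse permutation that restores the original order. A minimal sketch with toy tensors (the shapes, X, and the seed are assumptions, not the wrapper's real state; note that group_size is only computed further down in the diff, so the block above relies on it being defined elsewhere in compress()):

    import torch

    torch.manual_seed(0)
    W = torch.randn(4, 8)        # toy weight matrix (rows x columns)
    X = torch.randn(16, 8)       # toy calibration activations
    H = X.T @ X                  # stand-in for the GPTQ Hessian approximation

    # sort columns by sensitivity (largest Hessian diagonal first)
    perm = torch.argsort(torch.diag(H), descending=True)
    W_perm = W[:, perm]
    H_perm = H[perm][:, perm]

    # invperm undoes the reordering: W_perm[:, invperm] == W
    invperm = torch.argsort(perm)
    assert torch.equal(W_perm[:, invperm], W)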
@@ -143,7 +159,7 @@ def compress(
                     q = torch.dequantize(q)
                 elif hasattr(self.layer, "quantization_scheme"):
                     quant_scheme = self.layer.quantization_scheme
-                    actorder = quant_scheme.weights.actorder
+                    # actorder = quant_scheme.weights.actorder
 
                     if quant_scheme.weights is not None:
                         from compressed_tensors.quantization import QuantizationStrategy
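
As an aside, the torch.dequantize(q) call at the top of this hunk is the tail of a quantize/dequantize round trip. A self-contained sketch (the tensor and the quantization parameters here are assumptions for illustration only):

    import torch

    w = torch.randn(4, 4)
    # quantize to int8 with a fixed scale/zero point, then map back to float
    q = torch.quantize_per_tensor(w, scale=0.1, zero_point=0, dtype=torch.qint8)
    w_dq = torch.dequantize(q)   # float again, now carrying the rounding error
    print((w - w_dq).abs().max())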
@@ -154,26 +170,28 @@ def compress(
                         scale = self.layer.weight_scale
                         zero_point = self.layer.weight_zero_point
 
-                        if actorder:
-                            perm = torch.argsort(torch.diag(self.H), descending=True)
-                            W = W[:, perm]
-                            self.H = self.H[perm][:, perm]
-                            invperm = torch.argsort(perm)
+                        # if actorder:
+                        #     perm = torch.argsort(torch.diag(self.H), descending=True)
+                        #     W = W[:, perm]
+                        #     self.H = self.H[perm][:, perm]
+                        #     invperm = torch.argsort(perm)
+
+
 
                         group_size = quant_scheme.weights.group_size
                         if group_size is None or group_size == -1:
                             group_size = self.layer.weight.shape[1]
 
-                        if actorder:
-                            indices = torch.arange(self.columns, device=invperm.device)
-                            g_idx = (perm[indices] // group_size).to(dtype=torch.int32)
-                            g_idx = g_idx[invperm]
-                            self.layer.weight_g_idx.data = g_idx
-                        else:
-                            indices = torch.arange(
-                                self.columns, device=W.device, dtype=torch.int32
-                            )
-                            g_idx = indices // group_size
+                        # if actorder:
+                        #     indices = torch.arange(self.columns, device=invperm.device)
+                        #     g_idx = (perm[indices] // group_size).to(dtype=torch.int32)
+                        #     g_idx = g_idx[invperm]
+                        #     self.layer.weight_g_idx.data = g_idx
+                        # else:
+                        #     indices = torch.arange(
+                        #         self.columns, device=W.device, dtype=torch.int32
+                        #     )
+                        #     g_idx = indices // group_size
 
                         strategy = quant_scheme.weights.strategy
 
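The commented-out branch above (and the live block added in the next hunk) builds a per-column group index. A short sketch with assumed sizes showing what g_idx looks like, with and without act-order:

    import torch

    columns, group_size = 8, 4
    g_idx = torch.arange(columns, dtype=torch.int32) // group_size
    # tensor([0, 0, 0, 0, 1, 1, 1, 1]): column i reads scale/zero point of group g_idx[i]

    # with act-order, map the group ids back to the original column order
    perm = torch.tensor([3, 7, 0, 1, 2, 4, 5, 6])
    invperm = torch.argsort(perm)
    g_idx_actorder = g_idx[invperm]
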
@@ -203,11 +221,21 @@ def compress(
                             # ends up being a channelwise application
                             altered_qargs = copy(quant_scheme.weights)
                             altered_qargs.strategy = QuantizationStrategy.CHANNEL
 
+                            if actorder:
+                                perm = torch.argsort(torch.diag(self.H), descending=True)
+                                W = W[:, perm]
+                                self.H = self.H[perm][:, perm]
+                                invperm = torch.argsort(perm)
+
+                                g_idx = torch.arange(self.columns, device=invperm.device, dtype=torch.int32) // group_size
+                                g_idx = g_idx[invperm]
+                                self.layer.weight_g_idx.data = g_idx
+
                             # apply g_idx
                             if g_idx is not None:
                                 # scale and zp already transformed by group_size
-                                # extract first index of group_idze
+                                # extract first index of group_size
                                 indices_to_extract = torch.arange(
                                     0, g_idx.shape[0], group_size
                                 )
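
The hunk is cut off here, but the stride trick can be shown on its own. A toy sketch (sizes are assumptions): with grouped quantization there is one scale entry per group, and sampling g_idx at a stride of group_size picks out each group's first column.

    import torch

    group_size = 4
    g_idx = torch.arange(12, dtype=torch.int32) // group_size  # 12 columns, 3 groups
    indices_to_extract = torch.arange(0, g_idx.shape[0], group_size)
    print(g_idx[indices_to_extract])  # tensor([0, 1, 2], dtype=torch.int32)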
