feat(sanity-check): custom implementation of GPTQ now added
3outeille committed May 2, 2023
1 parent 8a37fb4 commit 4233522
Showing 2 changed files with 271 additions and 13 deletions.
182 changes: 182 additions & 0 deletions quantize/gptq/quant.py
@@ -147,6 +147,188 @@ def make_quant(module, names, bits, groupsize, name=''):
    for name1, child in module.named_children():
        make_quant(child, names, bits, groupsize, name + '.' + name1 if name != '' else name1)

def make_quant_custom(module, names, bits, groupsize, name=''):
    if isinstance(module, QuantLinear):
        return
    for attr in dir(module):
        tmp = getattr(module, attr)
        name1 = name + '.' + attr if name != '' else attr
        if name1 in names:
            bias_name = attr.replace('w', 'b')
            layer_name = attr.replace('w', 'quant')
            setattr(module, layer_name, QuantLinear_custom(bits, groupsize, tmp.shape[0], tmp.shape[1], module.w[bias_name] is not None))
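
# Note (added for clarity, not in the original commit): `names` holds weight attributes of the
# custom model (e.g. "linear0_w"); for each match a QuantLinear_custom module is attached under
# the sibling name "linear0_quant", with bias presence read from the matching "linear0_b" entry
# in module.w.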


class QuantLinear_custom(nn.Module):
    def __init__(self, bits, groupsize, infeatures, outfeatures, bias, kernel_switch_threshold=128, is_cuda=is_cuda):
        super().__init__()
        if bits not in [2, 3, 4, 8]:
            raise NotImplementedError("Only 2,3,4,8 bits are supported.")
        self.infeatures = infeatures
        self.outfeatures = outfeatures
        self.bits = bits
        self.groupsize = groupsize if groupsize != -1 else infeatures
        self.maxq = 2 ** self.bits - 1

        self.register_buffer('qweight', torch.zeros((infeatures // 32 * self.bits, outfeatures), dtype=torch.int32))
        self.register_buffer('qzeros', torch.zeros((math.ceil(infeatures / self.groupsize), outfeatures // 32 * self.bits), dtype=torch.int32))
        self.register_buffer('scales', torch.zeros((math.ceil(infeatures / self.groupsize), outfeatures), dtype=torch.float16))
        self.register_buffer('g_idx', torch.tensor([i // self.groupsize for i in range(infeatures)], dtype=torch.int32))
        if bias:
            self.register_buffer('bias', torch.zeros((outfeatures), dtype=torch.float16))
        else:
            self.bias = None

        # Above kernel_switch_threshold (or without CUDA), the matmul is performed by unpacking the weights and using torch.matmul
        if self.bits in [2, 4, 8]:
            self.register_buffer('wf', torch.tensor(list(range(0, 32, self.bits)), dtype=torch.int32).unsqueeze(0), persistent=False)
        elif self.bits == 3:
            self.register_buffer('wf', torch.tensor([[0, 3, 6, 9, 12, 15, 18, 21, 24, 27, 30, 0],
                                                     [0, 1, 4, 7, 10, 13, 16, 19, 22, 25, 28, 31],
                                                     [0, 2, 5, 8, 11, 14, 17, 20, 23, 26, 29, 0]], dtype=torch.int32).reshape(1, 3, 12), persistent=False)

        self.kernel_switch_threshold = kernel_switch_threshold
        self.is_cuda = is_cuda
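
    # Packed layout (noted here for clarity; inferred from the shift tables above, not part of the
    # original commit): for 2/4/8 bits each int32 word of qweight/qzeros holds 32 // bits values at
    # shifts 0, bits, 2*bits, ... For 3 bits, 32 values span three int32 words (96 bits); the 11th
    # and 22nd values straddle word boundaries, which is why `wf` has three rows of 12 shift amounts
    # and why pack() and forward() stitch those two values back together with extra shifts and masks.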

    def pack(self, weight, bias, scales, zeros, g_idx=None):
        self.g_idx = g_idx.clone() if g_idx is not None else self.g_idx

        scales = scales.t().contiguous()
        zeros = zeros.t().contiguous()
        scale_zeros = zeros * scales
        self.scales = scales.clone().half()
        if bias is not None:
            self.bias = bias.clone().half()

        intweight = []
        for idx in range(self.infeatures):
            intweight.append(torch.round((weight[:, idx] + scale_zeros[self.g_idx[idx]]) / self.scales[self.g_idx[idx]]).to(torch.int)[:, None])
        intweight = torch.cat(intweight, dim=1)
        intweight = intweight.t().contiguous()
        intweight = intweight.numpy().astype(np.uint32)
        qweight = np.zeros(
            (intweight.shape[0] // 32 * self.bits, intweight.shape[1]), dtype=np.uint32
        )
        i = 0
        row = 0
        while row < qweight.shape[0]:
            if self.bits in [2, 4, 8]:
                for j in range(i, i + (32 // self.bits)):
                    qweight[row] |= intweight[j] << (self.bits * (j - i))
                i += 32 // self.bits
                row += 1
            elif self.bits == 3:
                for j in range(i, i + 10):
                    qweight[row] |= intweight[j] << (3 * (j - i))
                i += 10
                qweight[row] |= intweight[i] << 30
                row += 1
                qweight[row] |= (intweight[i] >> 2) & 1
                i += 1
                for j in range(i, i + 10):
                    qweight[row] |= intweight[j] << (3 * (j - i) + 1)
                i += 10
                qweight[row] |= intweight[i] << 31
                row += 1
                qweight[row] |= (intweight[i] >> 1) & 0x3
                i += 1
                for j in range(i, i + 10):
                    qweight[row] |= intweight[j] << (3 * (j - i) + 2)
                i += 10
                row += 1
            else:
                raise NotImplementedError("Only 2,3,4,8 bits are supported.")

        qweight = qweight.astype(np.int32)
        self.qweight = torch.from_numpy(qweight)

        zeros -= 1
        zeros = zeros.numpy().astype(np.uint32)
        qzeros = np.zeros((zeros.shape[0], zeros.shape[1] // 32 * self.bits), dtype=np.uint32)
        i = 0
        col = 0
        while col < qzeros.shape[1]:
            if self.bits in [2, 4, 8]:
                for j in range(i, i + (32 // self.bits)):
                    qzeros[:, col] |= zeros[:, j] << (self.bits * (j - i))
                i += 32 // self.bits
                col += 1
            elif self.bits == 3:
                for j in range(i, i + 10):
                    qzeros[:, col] |= zeros[:, j] << (3 * (j - i))
                i += 10
                qzeros[:, col] |= zeros[:, i] << 30
                col += 1
                qzeros[:, col] |= (zeros[:, i] >> 2) & 1
                i += 1
                for j in range(i, i + 10):
                    qzeros[:, col] |= zeros[:, j] << (3 * (j - i) + 1)
                i += 10
                qzeros[:, col] |= zeros[:, i] << 31
                col += 1
                qzeros[:, col] |= (zeros[:, i] >> 1) & 0x3
                i += 1
                for j in range(i, i + 10):
                    qzeros[:, col] |= zeros[:, j] << (3 * (j - i) + 2)
                i += 10
                col += 1
            else:
                raise NotImplementedError("Only 2,3,4,8 bits are supported.")

        qzeros = qzeros.astype(np.int32)
        self.qzeros = torch.from_numpy(qzeros)
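
        # Note (added for clarity): zero-points are stored as (zero - 1) via the `zeros -= 1` above;
        # the torch fallback in forward() adds 1 back after unpacking before dequantizing.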

    def forward(self, x):
        out_shape = x.shape[:-1] + (self.outfeatures,)
        x = x.reshape(-1, x.shape[-1])
        if self.is_cuda is True and (self.kernel_switch_threshold is False or x.shape[0] < self.kernel_switch_threshold):
            out = torch.zeros((x.shape[0], self.outfeatures), device=x.device, dtype=torch.float32)
            if self.bits == 2:
                quant_cuda.vecquant2matmul(x.float(), self.qweight, out, self.scales.float(), self.qzeros, self.g_idx)
            elif self.bits == 3:
                quant_cuda.vecquant3matmul(x.float(), self.qweight, out, self.scales.float(), self.qzeros, self.g_idx)
            elif self.bits == 4:
                quant_cuda.vecquant4matmul(x.float(), self.qweight, out, self.scales.float(), self.qzeros, self.g_idx)
            elif self.bits == 8:
                quant_cuda.vecquant8matmul(x.float(), self.qweight, out, self.scales.float(), self.qzeros, self.g_idx)
            out = out.half()
        else:
            if self.bits in [2, 4, 8]:
                zeros = torch.bitwise_right_shift(torch.unsqueeze(self.qzeros, 2).expand(-1, -1, 32 // self.bits), self.wf.unsqueeze(0)).to(torch.int16 if self.bits == 8 else torch.int8)
                torch.bitwise_and(zeros, (2 ** self.bits) - 1, out=zeros)

                zeros = zeros + 1
                zeros = zeros.reshape(self.scales.shape)

                weight = torch.bitwise_right_shift(torch.unsqueeze(self.qweight, 1).expand(-1, 32 // self.bits, -1), self.wf.unsqueeze(-1)).to(torch.int16 if self.bits == 8 else torch.int8)
                torch.bitwise_and(weight, (2 ** self.bits) - 1, out=weight)
            elif self.bits == 3:
                zeros = self.qzeros.reshape(self.qzeros.shape[0], self.qzeros.shape[1] // 3, 3, 1).expand(-1, -1, -1, 12)
                zeros = (zeros >> self.wf.unsqueeze(0))
                zeros[:, :, 0, 10] = (zeros[:, :, 0, 10] & 0x3) | ((zeros[:, :, 1, 0] << 2) & 0x4)
                zeros[:, :, 1, 11] = (zeros[:, :, 1, 11] & 0x1) | ((zeros[:, :, 2, 0] << 1) & 0x6)
                zeros = zeros & 0x7
                zeros = torch.cat([zeros[:, :, 0, :11], zeros[:, :, 1, 1:12], zeros[:, :, 2, 1:11]], dim=2)

                zeros = zeros + 1
                zeros = zeros.reshape(self.scales.shape)

                weight = self.qweight.reshape(self.qweight.shape[0] // 3, 3, 1, self.qweight.shape[1]).expand(-1, -1, 12, -1)
                weight = (weight >> self.wf.unsqueeze(-1)) & 0x7
                weight[:, 0, 10] = (weight[:, 0, 10] & 0x3) | ((weight[:, 1, 0] << 2) & 0x4)
                weight[:, 1, 11] = (weight[:, 1, 11] & 0x1) | ((weight[:, 2, 0] << 1) & 0x6)
                weight = weight & 0x7
                weight = torch.cat([weight[:, 0, :11], weight[:, 1, 1:12], weight[:, 2, 1:11]], dim=1)

            weight = weight.reshape(weight.shape[0] * weight.shape[1], weight.shape[2])

            weights = (self.scales[self.g_idx] * (weight - zeros[self.g_idx]))
            out = torch.matmul(x.half(), weights)
        out = out.reshape(out_shape)
        out = out + self.bias if self.bias is not None else out
        return out
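
# Illustrative sketch only (added for this write-up, not part of the commit): a minimal,
# self-contained demonstration of the 2/4/8-bit packing scheme that pack() above uses.
# `_demo_pack_roundtrip` is a hypothetical helper; nothing in the repository calls it.
def _demo_pack_roundtrip(bits=4):
    import numpy as np
    per_word = 32 // bits                                  # e.g. eight 4-bit values per int32
    vals = np.arange(per_word, dtype=np.uint32) % (1 << bits)
    word = np.uint32(0)
    for j, v in enumerate(vals):                           # same shift pattern as pack(): bits * (j - i)
        word |= v << np.uint32(bits * j)
    mask = np.uint32((1 << bits) - 1)
    unpacked = np.array([(word >> np.uint32(bits * j)) & mask for j in range(per_word)], dtype=np.uint32)
    assert (unpacked == vals).all()
    return word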

class QuantLinear(nn.Module):
    def __init__(self, bits, groupsize, infeatures, outfeatures, bias, kernel_switch_threshold=128, is_cuda=is_cuda):
        super().__init__()
102 changes: 89 additions & 13 deletions quantize/gptq/sanity_check_main.py
@@ -4,6 +4,8 @@
import torch
import torch.nn as nn
import torch.optim as optim
from collections import OrderedDict
import torch.nn.functional as F

from sanity_check_utils import seed_everything, MNISTloader, SimpleNet, train, evaluate, SimpleNet_V2
from gptq import *
@@ -34,9 +36,8 @@ def load_quant(model, checkpoint, wbits, groupsize):

    # Don't quantize the last layer because qzeros is empty (I don't know why they create qzeros that way)
    # (gptq.py:L235, second dimension of qzeros is 0 because last layer is 10 for classification)
-    for name in ["linear4"]:
-        if name in layers:
-            del layers[name]
    if "linear4" in layers:
        del layers["linear4"]

    make_quant(model, layers, wbits, groupsize)
    model.load_state_dict(torch.load(checkpoint))
@@ -258,8 +259,8 @@ def fasterquant(self, blocksize=128, percdamp=.01, groupsize=-1, actorder=False)
    ### begin GPTQ_CUSTOM
    def __init__(self, checkpoint_path):
        super().__init__()
        self.load_state_dict(torch.load(checkpoint_path, map_location="cpu"))

    def _fill_subset(self, layer_id):
        is_last_layer = (layer_id == self.nb_layers - 1)
        if is_last_layer:
@@ -292,7 +293,7 @@ def fasterquant(self, layer_id, quantizers):
            print(layer_id, name)
            print('Quantizing ...')
            scale,zero,g_idx = self.gptq[name].fasterquant(percdamp=0.01, groupsize=GROUPSIZE, actorder=False)
-            quantizers[f"linear{layer_id + 1}"] = (self.gptq[name].quantizer.cpu(), scale.cpu(), zero.cpu(), g_idx.cpu())
            quantizers[f"linear{layer_id}_w"] = (self.gptq[name].quantizer.cpu(), scale.cpu(), zero.cpu(), g_idx.cpu())

    ## end GPTQ_CUSTOM

@@ -301,6 +302,19 @@ def my_linear(self, x, weight, bias):
        out = x @ weight.weight + bias
        weight.add_batch(x)
        return out

    def forward(self, x):
        if len(x.shape) == 4:
            x = x.view(x.size(0), -1)

        residual = x
        x = F.relu(self.linear0_quant(x))
        x = self.linear1_quant(x)
        x = F.relu(x) + residual
        x = self.linear2_quant(x)
        x = F.relu(x) + residual
        x = super().my_linear(x, self.linear3_w, self.linear3_b)
        return x
    ## End SimpleNet_V2
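
    # Note (added for clarity, not in the original commit): the *_quant attributes used in forward()
    # are the QuantLinear_custom modules that make_quant_custom() attaches in place of linear0..linear2;
    # linear3 is left un-quantized (its qzeros would be empty, as the load_quant comments explain) and
    # is applied through my_linear() with its original weights.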


@@ -321,9 +335,11 @@ def quantize_gptq_custom(model, train_loader):
    quantizers = {}

    for layer_id in range(nb_layers):
        if not is_last_layer(layer_id):
            print(f"Quantizing layer {layer_id} ...")

            model.alloc_gptq(layer_id)

            for i in range(nsamples):
@@ -342,12 +358,56 @@

    return quantizers


def model_pack_custom(model, quantizers, wbits, groupsize):
-    pass
    # Extract weights and bias from model
    is_weight = re.compile(r'^linear\d+_w$')
    weights, bias = OrderedDict(), OrderedDict()
    for name, param in model.w.items():
        if is_weight.match(name):
            weights[name] = param
        else:
            bias[name] = param

    make_quant_custom(model, quantizers, wbits, groupsize)
    qlayers = find_layers(model, [QuantLinear_custom])

    print('Packing ...')
    for i in range(len(qlayers)):
        name_w, name_b, layer_quant_name = f'linear{i}_w', f'linear{i}_b', f'linear{i}_quant'
        quantizers[name_w],scale,zero,g_idx = quantizers[name_w]
        qlayers[layer_quant_name].pack(weights[name_w], bias[name_b], scale, zero, g_idx)
    print('Done.')
    return model

def load_quant_custom(model, checkpoint, wbits, groupsize):
    print('Loading model ...')
    model = model.eval()
    # Extract weights and bias from model
    is_weight = re.compile(r'^linear\d+_w$')
    weights, bias = OrderedDict(), OrderedDict()
    for name, param in model.w.items():
        if is_weight.match(name):
            weights[name] = param
        else:
            bias[name] = param

    # Create linear layers out of weights and bias
    layers = {}
    for (w_name, w_param), (_, b_param) in zip(weights.items(), bias.items()):
        layers[w_name] = nn.Linear(w_param.shape[1], w_param.shape[0], bias=True)
        layers[w_name].weight.data = w_param
        layers[w_name].bias.data = b_param

    # Don't quantize the last layer because qzeros is empty (I don't know why they create qzeros that way)
    # (gptq.py:L235, second dimension of qzeros is 0 because last layer is 10 for classification)
    if "linear3_w" in layers:
        del layers["linear3_w"]

    make_quant_custom(model, layers, wbits, groupsize)
    model.load_state_dict(torch.load(checkpoint))
    print('Done.')
    return model

-def load_quant_custom(model, quantizers, wbits, groupsize):
-    pass

def assert_parameters(model, model_custom):
    is_weight = re.compile(r'^linear\d+.weight$')
@@ -371,6 +431,7 @@ def assert_parameters(model, model_custom):
parser.add_argument("--eval_gptq", action="store_true")
parser.add_argument("--train_custom", action="store_true")
parser.add_argument("--gptq_custom", action="store_true")
parser.add_argument("--eval_gptq_custom", action="store_true")
parser.add_argument("--pyquant", action="store_true")

    args = parser.parse_args()
@@ -381,7 +442,9 @@ def assert_parameters(model, model_custom):
    criterion = nn.CrossEntropyLoss()
    train_loader, _, _ = MNISTloader(train_val_split=0.95).load()

    #TODO: Do Custom packing
    #TODO: Do custom eval gptq
    #TODO: Is reference GPTQ quantizing bias as well?
    #TODO: Add seed everywhere in GPTQ for reproducibility

    ## ================== REFERENCE ==================
    if args.train:
@@ -430,6 +493,19 @@ def assert_parameters(model, model_custom):
        model_pack_custom(model, quantizers, WBITS, GROUPSIZE)
        torch.save(model.state_dict(), "model_quantized_custom.pt")
        print("Done Custom GPTQ")
    elif args.eval_gptq_custom:
        model = GPTQ_CUSTOM("./model_custom.pt")
        device = torch.device("cuda:0")
        model = load_quant_custom(model, "model_quantized_custom.pt", WBITS, GROUPSIZE)
        model = model.to(device)

        start = time.time()
        val_loss, val_acc = evaluate(device, model, criterion, train_loader)
        end = time.time()

        print(f"wbits = {WBITS} using {device}")
        print(f"val_loss: {val_loss:.3f} \t val_acc: {val_acc:.3f}")
        print(f"Latency: {end - start}")
    ## ================== MISC ==================
    elif args.pyquant:
        # Baseline post-training quantization from Pytorch