feat(sanity-check): add pack and load for gptq implem
3outeille committed Apr 28, 2023
1 parent 8a37fb4 commit 5278821
Showing 2 changed files with 92 additions and 8 deletions.
16 changes: 16 additions & 0 deletions quantize/gptq/quant.py
@@ -147,6 +147,22 @@ def make_quant(module, names, bits, groupsize, name=''):
    for name1, child in module.named_children():
        make_quant(child, names, bits, groupsize, name + '.' + name1 if name != '' else name1)

+def make_quant_custom(module, names, bits, groupsize, name=''):
+    if isinstance(module, QuantLinear):
+        return
+    for attr in dir(module):
+        tmp = getattr(module, attr)
+        name1 = name + '.' + attr if name != '' else attr
+        if name1 in names:
+            delattr(module, attr)
+            bias_name = attr.replace('w', 'b')
+            setattr(module, attr, QuantLinear(bits, groupsize, tmp.shape[0], tmp.shape[1], module.w[bias_name] is not None))
+    #TODO: No recursive
+    # for name1, child in module.named_children():
+    #     make_quant_custom(child, names, bits, groupsize, name + '.' + name1 if name != '' else name1)
+
+
class QuantLinear(nn.Module):
    def __init__(self, bits, groupsize, infeatures, outfeatures, bias, kernel_switch_threshold=128, is_cuda=is_cuda):
        super().__init__()
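For orientation, a minimal usage sketch for make_quant_custom (hypothetical, not part of the commit). It assumes the custom model stores its parameters in an OrderedDict model.w whose zero-indexed keys linear{i}_w / linear{i}_b are also exposed as attributes, which is what the attr.replace('w', 'b') bias lookup relies on:

    # Hypothetical sketch: quantize every weight attribute of a
    # SimpleNet_V2-style model except the final classifier.
    names = ["linear0_w", "linear1_w", "linear2_w"]
    make_quant_custom(model, names, WBITS, GROUPSIZE)
    # Each matched attribute is replaced in place by a QuantLinear sized from
    # the weight tensor's shape; the bias flag is read from model.w['linear{i}_b'].
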
84 changes: 76 additions & 8 deletions quantize/gptq/sanity_check_main.py
@@ -4,6 +4,7 @@
import torch
import torch.nn as nn
import torch.optim as optim
+from collections import OrderedDict

from sanity_check_utils import seed_everything, MNISTloader, SimpleNet, train, evaluate, SimpleNet_V2
from gptq import *
@@ -34,9 +35,8 @@ def load_quant(model, checkpoint, wbits, groupsize):

    # Don't quantize the last layer because qzeros is empty (I don't know why they create qzeros that way)
    # (gptq.py:L235, second dimension of qzeros is 0 because last layer is 10 for classification)
-    for name in ["linear4"]:
-        if name in layers:
-            del layers[name]
+    if "linear4" in layers:
+        del layers["linear4"]

    make_quant(model, layers, wbits, groupsize)
    model.load_state_dict(torch.load(checkpoint))
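
A plausible reading of the empty-qzeros comment above (an inference, not stated in the commit): the reference QuantLinear allocates qzeros with an integer-division second dimension, roughly outfeatures // 32 * bits, so a 10-way classifier yields a zero-width tensor that cannot be packed or reloaded:

    # Illustrative arithmetic only; the exact formula lives at gptq.py:L235.
    outfeatures, bits = 10, 4
    qzeros_dim1 = outfeatures // 32 * bits   # 10 // 32 == 0, so qzeros is empty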
@@ -292,7 +292,7 @@ def fasterquant(self, layer_id, quantizers):
            print(layer_id, name)
            print('Quantizing ...')
            scale,zero,g_idx = self.gptq[name].fasterquant(percdamp=0.01, groupsize=GROUPSIZE, actorder=False)
-            quantizers[f"linear{layer_id + 1}"] = (self.gptq[name].quantizer.cpu(), scale.cpu(), zero.cpu(), g_idx.cpu())
+            quantizers[f"linear{layer_id}_w"] = (self.gptq[name].quantizer.cpu(), scale.cpu(), zero.cpu(), g_idx.cpu())

    ## end GPTQ_CUSTOM
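
This key rename lines the quantizers dict up with the custom model's zero-indexed attribute names (the reference model's layers are one-indexed, hence the old linear{layer_id + 1}), so the new pack path below can index quantizers, layers, and qlayers with one and the same string. An illustrative entry, with hypothetical values:

    # layer_id 0 now maps to the key "linear0_w" rather than "linear1";
    # every element of the stored tuple has been moved to the CPU.
    quantizer, scale, zero, g_idx = quantizers["linear0_w"]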

@@ -344,10 +344,61 @@ def quantize_gptq_custom(model, train_loader):


def model_pack_custom(model, quantizers, wbits, groupsize):
-    pass
+    # Extract weights and bias from model
+    is_weight = re.compile(r'^linear\d+_w$')
+    weights, bias = OrderedDict(), OrderedDict()
+    for name, param in model.w.items():
+        if is_weight.match(name):
+            weights[name] = param
+        else:
+            bias[name] = param
+
+    # Create linear layer out of weights and bias
+    layers = {}
+    for (w_name, w_param), (_, b_param) in zip(weights.items(), bias.items()):
+        layers[w_name] = nn.Linear(w_param.shape[1], w_param.shape[0], bias=True)
+        layers[w_name].weight.data = w_param
+        layers[w_name].bias.data = b_param
+
+    make_quant_custom(model, quantizers, wbits, groupsize)
+    qlayers = find_layers(model, [QuantLinear])
+
+    print('Packing ...')
+    for name in qlayers:
+        print(name)
+        quantizers[name],scale,zero,g_idx = quantizers[name]
+        qlayers[name].pack(layers[name], scale, zero, g_idx)
+    print('Done.')
+    return model
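
One design note on model_pack_custom: zip(weights.items(), bias.items()) pairs entries purely by insertion order, so it assumes model.w interleaves each linear{i}_w with its matching linear{i}_b. A self-contained sketch of that assumption, with hypothetical shapes:

    # Hypothetical check of the _w/_b pairing convention (not from the commit).
    import re, torch
    from collections import OrderedDict

    is_weight = re.compile(r'^linear\d+_w$')
    w = OrderedDict([
        ("linear0_w", torch.randn(128, 784)), ("linear0_b", torch.randn(128)),
        ("linear1_w", torch.randn(10, 128)),  ("linear1_b", torch.randn(10)),
    ])
    weights = OrderedDict((k, v) for k, v in w.items() if is_weight.match(k))
    bias = OrderedDict((k, v) for k, v in w.items() if not is_weight.match(k))
    # zip pairs each weight with the bias that follows it in insertion order:
    for (w_name, w_param), (b_name, b_param) in zip(weights.items(), bias.items()):
        assert b_name == w_name.replace("_w", "_b")
        assert b_param.shape[0] == w_param.shape[0]   # nn.Linear out_features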

+def load_quant_custom(model, checkpoint, wbits, groupsize):
+    print('Loading model ...')
+    model = model.eval()
+    # Extract weights and bias from model
+    is_weight = re.compile(r'^linear\d+_w$')
+    weights, bias = OrderedDict(), OrderedDict()
+    for name, param in model.w.items():
+        if is_weight.match(name):
+            weights[name] = param
+        else:
+            bias[name] = param
+
+    # Create linear layer out of weights and bias
+    layers = {}
+    for (w_name, w_param), (_, b_param) in zip(weights.items(), bias.items()):
+        layers[w_name] = nn.Linear(w_param.shape[1], w_param.shape[0], bias=True)
+        layers[w_name].weight.data = w_param
+        layers[w_name].bias.data = b_param
+
+    # Don't quantize the last layer because qzeros is empty (I don't know why they create qzeros that way)
+    # (gptq.py:L235, second dimension of qzeros is 0 because last layer is 10 for classification)
+    if "linear3_w" in layers:
+        del layers["linear3_w"]
+    make_quant_custom(model, layers, wbits, groupsize)
+    model.load_state_dict(torch.load(checkpoint))
+    print('Done.')
+    return model
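
Note that the skipped name differs from the reference load_quant above: the custom model is zero-indexed, so its final 10-way classification layer is linear3_w, whereas the reference model calls the same layer linear4. Both are skipped for the same empty-qzeros reason.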

-def load_quant_custom(model, quantizers, wbits, groupsize):
-    pass

def assert_parameters(model, model_custom):
    is_weight = re.compile(r'^linear\d+.weight$')
@@ -371,6 +422,7 @@ def assert_parameters(model, model_custom):
parser.add_argument("--eval_gptq", action="store_true")
parser.add_argument("--train_custom", action="store_true")
parser.add_argument("--gptq_custom", action="store_true")
parser.add_argument("--eval_gptq_custom", action="store_true")
parser.add_argument("--pyquant", action="store_true")

args = parser.parse_args()
@@ -381,7 +433,9 @@ def assert_parameters(model, model_custom):
    criterion = nn.CrossEntropyLoss()
    train_loader, _, _ = MNISTloader(train_val_split=0.95).load()

    #TODO: Do Custom packing
    #TODO: Do custom eval gptq
+    #TODO: Is reference GPTQ quantizing bias as well ?
+    #TODO: Add seed everywhere in GPT for reproducibility

    ## ================== REFERENCE ==================
    if args.train:
@@ -430,6 +484,20 @@ def assert_parameters(model, model_custom):
        model_pack_custom(model, quantizers, WBITS, GROUPSIZE)
        torch.save(model.state_dict(), "model_quantized_custom.pt")
        print("Done Custom GPTQ")
+    elif args.eval_gptq_custom:
+        model = GPTQ_CUSTOM("./model_custom.pt")
+        device = torch.device("cuda:0")
+        model = load_quant_custom(model, "model_quantized_custom.pt", WBITS, GROUPSIZE)
+        model = model.to(device)
+
+        #TODO: Fix eval
+        # start = time.time()
+        # val_loss, val_acc = evaluate(device, model, criterion, train_loader)
+        # end = time.time()
+
+        # print(f"wbits = {WBITS} using {device}")
+        # print(f"val_loss: {val_loss:.3f} \t val_acc: {val_acc:.3f}")
+        # print(f"Latency: {end - start}")
    ## ================== MISC ==================
    elif args.pyquant:
        # Baseline post-training quantization from Pytorch
