GPTQ for RWKV #98

Draft · wants to merge 20 commits into main

Commits:
6e556e5
feat(quantize): measure perplexity on wikitext2
3outeille Apr 18, 2023
bde6374
feat(quantize): add gptq files
3outeille Apr 18, 2023
943af70
feat(quantize): begin to readapt with RWKV
3outeille Apr 18, 2023
629fc9b
breaking(quantize): draft gptq rwkv
3outeille Apr 23, 2023
4a19476
fix(quantize): GPTQ hooks now work with RWKV
3outeille Apr 24, 2023
dba2670
feat(quantize): link fasterquant with RWKV + remove 1D tensor quanti…
3outeille Apr 25, 2023
57079e7
feat(quantize): full gptq pipeline now integrated with RKWV (quite sl…
3outeille Apr 25, 2023
8e78f2d
fix(quantize): add missing part in forward block + support head.weigh…
3outeille Apr 25, 2023
f87df05
feat(sanity-check): begin sanity check for GPTQ on MNIST
3outeille Apr 26, 2023
b77715d
breaking(sanity-check): add save & load option for reference gptq
3outeille Apr 27, 2023
816def4
breaking(sanity-check): enhance with dummy model
3outeille Apr 27, 2023
f141e52
fix(sanity-check): dont quantize last layer for dummy example
3outeille Apr 28, 2023
a1ea882
breaking(sanity-check): adding my implem gptq
3outeille Apr 28, 2023
8a37fb4
fix(sanity-check): training ref and implem now yield same outputs
3outeille Apr 28, 2023
4233522
feat(sanity-check): implem version of gptq now added
3outeille Apr 28, 2023
e74d72a
fix(sanity-check): ref and implem now yield the same results at every…
3outeille May 2, 2023
cf14124
feat(quantize): readapt GPTQ for rwkv
3outeille May 3, 2023
c2bbe64
breaking(gptq): quantizing only 1 layer yield high perplexity
3outeille May 7, 2023
9b9c714
fix(ppl): measure ppl using sliding window
3outeille May 23, 2023
3399ef0
update
3outeille Aug 4, 2023
6 changes: 6 additions & 0 deletions README_quantize.md
@@ -0,0 +1,6 @@
# GPTQ

```bash
pip install -r requirements-quantize.txt
python quantize/gptq/setup_cuda.py install
```
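
If the build succeeds, the extension should be importable from Python. A minimal post-install check — a sketch assuming the extension is registered under the name `quant_cuda`, as in the upstream GPTQ code this PR adapts (the PR's `setup_cuda.py` itself is not shown here):

```python
# Quick sanity check after `setup_cuda.py install`.
# `quant_cuda` is an assumption: the module name used by upstream GPTQ kernels.
import torch
import quant_cuda  # an ImportError here means the extension did not build

assert torch.cuda.is_available(), "the GPTQ kernels require a CUDA device"
print("quant_cuda loaded from:", quant_cuda.__file__)
```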
68 changes: 68 additions & 0 deletions quantize/gptq/datautils.py
@@ -0,0 +1,68 @@
import numpy as np
import torch
import os
import pathlib
import tokenizers
import random
from myRWKV import RWKV

from datasets import load_dataset

def set_seed(seed):
    np.random.seed(seed)
    torch.random.manual_seed(seed)

def get_wikitext2(nsamples, seed, seqlen, model):
    is_rwkv = isinstance(model, RWKV)

    if is_rwkv:
        traindata = load_dataset('wikitext', 'wikitext-2-raw-v1', split='train')
        testdata = load_dataset('wikitext', 'wikitext-2-raw-v1', split='test')

        print('Loading RWKV tokenizer')
        tokenizer_path = pathlib.Path(os.path.abspath(__file__)).parent / '../20B_tokenizer.json'
        tokenizer = tokenizers.Tokenizer.from_file(str(tokenizer_path))
        trainenc = torch.unsqueeze(torch.tensor(tokenizer.encode("\n\n".join(traindata['text'])).ids, dtype=torch.long), 0)
        testenc = torch.unsqueeze(torch.tensor(tokenizer.encode("\n\n".join(testdata['text'])).ids, dtype=torch.long), 0)
    else:
        # traindata = load_dataset('wikitext', 'wikitext-2-raw-v1', split='train')
        # testdata = load_dataset('wikitext', 'wikitext-2-raw-v1', split='test')

        print('Loading tokenizer')
        from transformers import AutoTokenizer
        tokenizer = AutoTokenizer.from_pretrained(model, use_fast=False)
        # trainenc = tokenizer("\n\n".join(traindata['text']), return_tensors='pt')
        # testenc = tokenizer("\n\n".join(testdata['text']), return_tensors='pt')
        # Debug path: a single placeholder sentence stands in for the corpus.
        sentence = "My name is Ferdinand and I live in France"
        trainenc = tokenizer(sentence, return_tensors='pt')
        testenc = tokenizer(sentence, return_tensors='pt')

    random.seed(seed)
    trainloader = []
    shape = trainenc.shape if is_rwkv else trainenc.input_ids.shape
    trainenc = trainenc if is_rwkv else trainenc.input_ids
    # Draw nsamples random windows of seqlen tokens from the encoded stream.
    random_idx = [random.randint(0, shape[1] - seqlen - 1) for _ in range(nsamples)]

    for i in range(nsamples):
        j = random_idx[i] + seqlen
        inp = trainenc[:, random_idx[i]:j]
        tar = inp.clone()
        # Mask every target position but the last with -100 (ignored by the loss).
        tar[:, :-1] = -100
        trainloader.append((inp, tar))
    return trainloader, testenc

def get_loaders(
    dataset_name, nsamples, seed, seqlen, model
):
    if 'wikitext2' in dataset_name:
        return get_wikitext2(nsamples, seed, seqlen, model)
    if 'ptb' in dataset_name:
        raise NotImplementedError('PTB is not supported yet')
        # if 'new' in dataset_name:
        #     return get_ptb_new(nsamples, seed, seqlen, model)
        # return get_ptb(nsamples, seed, seqlen, model)
    if 'c4' in dataset_name:
        raise NotImplementedError('C4 is not supported yet')
        # if 'new' in dataset_name:
        #     return get_c4_new(nsamples, seed, seqlen, model)
        # return get_c4(nsamples, seed, seqlen, model)
    raise ValueError(f'Unknown dataset: {dataset_name}')
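
For context, a sketch of how these loaders are meant to be consumed. The `model` argument is assumed to be a loaded `myRWKV.RWKV` instance (so the RWKV tokenizer branch is taken), and the import path simply mirrors this PR's file layout:

```python
# Hypothetical driver code: draw 16 calibration windows of 1024 tokens each.
from quantize.gptq.datautils import get_loaders, set_seed

set_seed(0)
trainloader, testenc = get_loaders(
    'wikitext2', nsamples=16, seed=0, seqlen=1024, model=model  # model: RWKV
)
inp, tar = trainloader[0]
print(inp.shape)  # (1, 1024); tar masks all but the last position with -100
```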
178 changes: 178 additions & 0 deletions quantize/gptq/gptq.py
@@ -0,0 +1,178 @@
import math
import time

import torch
import torch.nn as nn
import transformers

from .quant import *


DEBUG = False

# Disable TF32 so the Hessian and quantization statistics stay in full fp32.
torch.backends.cuda.matmul.allow_tf32 = False
torch.backends.cudnn.allow_tf32 = False


class GPTQ:
    def __init__(self, layer, name):
        self.name = name
        self.layer = layer
        self.dev = self.layer.weight.device
        W = layer.weight.data.clone()
        if isinstance(self.layer, nn.Conv2d):
            W = W.flatten(1)
        if isinstance(self.layer, transformers.Conv1D):
            W = W.t()
        self.rows = W.shape[0]
        self.columns = W.shape[1]
        # Hessian approximation H = 2/n * sum_i x_i x_i^T, accumulated in add_batch.
        self.H = torch.zeros((self.columns, self.columns), device=self.dev)
        self.nsamples = 0
        # Note: self.quantizer is attached externally before fasterquant is called.

    def add_batch(self, inp, out):
        if DEBUG:
            self.inp1 = inp
            self.out1 = out
        if len(inp.shape) == 2:
            inp = inp.unsqueeze(0)
        tmp = inp.shape[0]
        if isinstance(self.layer, nn.Linear) or isinstance(self.layer, transformers.Conv1D):
            if len(inp.shape) == 3:
                inp = inp.reshape((-1, inp.shape[-1]))
            inp = inp.t()
        if isinstance(self.layer, nn.Conv2d):
            unfold = nn.Unfold(
                self.layer.kernel_size,
                dilation=self.layer.dilation,
                padding=self.layer.padding,
                stride=self.layer.stride
            )
            inp = unfold(inp)
            inp = inp.permute([1, 0, 2])
            inp = inp.flatten(1)
        # Running average: downweight the old Hessian, then add the new batch.
        self.H *= self.nsamples / (self.nsamples + tmp)
        self.nsamples += tmp
        # inp = inp.float()
        inp = math.sqrt(2 / self.nsamples) * inp.float()
        # self.H += 2 / self.nsamples * inp.matmul(inp.t())
        self.H += inp.matmul(inp.t())

    def fasterquant(
        self, blocksize=128, percdamp=.01, groupsize=-1, actorder=False
    ):
        W = self.layer.weight.data.clone()
        if isinstance(self.layer, nn.Conv2d):
            W = W.flatten(1)
        if isinstance(self.layer, transformers.Conv1D):
            W = W.t()
        W = W.float()

        tick = time.time()

        if not self.quantizer.ready():
            self.quantizer.find_params(W, weight=True)

        H = self.H
        del self.H
        # Columns with a zero Hessian diagonal never saw any signal; freeze them at 0.
        dead = torch.diag(H) == 0
        H[dead, dead] = 1
        W[:, dead] = 0

        if actorder:
            # Quantize columns in order of decreasing Hessian diagonal (activation order).
            perm = torch.argsort(torch.diag(H), descending=True)
            W = W[:, perm]
            H = H[perm][:, perm]

        Losses = torch.zeros_like(W)
        Q = torch.zeros_like(W)

        # Dampen the diagonal for numerical stability, then obtain the upper
        # Cholesky factor of the inverse Hessian.
        damp = percdamp * torch.mean(torch.diag(H))
        diag = torch.arange(self.columns, device=self.dev)
        H[diag, diag] += damp
        H = torch.linalg.cholesky(H)
        H = torch.cholesky_inverse(H)
        H = torch.linalg.cholesky(H, upper=True)
        Hinv = H

        g_idx = []
        scale = []
        zero = []
        now_idx = 1

        for i1 in range(0, self.columns, blocksize):
            i2 = min(i1 + blocksize, self.columns)
            count = i2 - i1

            W1 = W[:, i1:i2].clone()
            Q1 = torch.zeros_like(W1)
            Err1 = torch.zeros_like(W1)
            Losses1 = torch.zeros_like(W1)
            Hinv1 = Hinv[i1:i2, i1:i2]

            for i in range(count):
                w = W1[:, i]
                d = Hinv1[i, i]

                if groupsize != -1:
                    if (i1 + i) % groupsize == 0:
                        self.quantizer.find_params(W[:, (i1 + i):(i1 + i + groupsize)], weight=True)

                    if ((i1 + i) // groupsize) - now_idx == -1:
                        scale.append(self.quantizer.scale)
                        zero.append(self.quantizer.zero)
                        now_idx += 1

                q = quantize(
                    w.unsqueeze(1), self.quantizer.scale, self.quantizer.zero, self.quantizer.maxq
                ).flatten()
                Q1[:, i] = q
                Losses1[:, i] = (w - q) ** 2 / d ** 2

                # Propagate the quantization error to the not-yet-quantized columns.
                err1 = (w - q) / d
                W1[:, i:] -= err1.unsqueeze(1).matmul(Hinv1[i, i:].unsqueeze(0))
                Err1[:, i] = err1

            Q[:, i1:i2] = Q1
            Losses[:, i1:i2] = Losses1 / 2

            W[:, i2:] -= Err1.matmul(Hinv[i1:i2, i2:])

            if DEBUG:
                self.layer.weight.data[:, :i2] = Q[:, :i2]
                self.layer.weight.data[:, i2:] = W[:, i2:]
                print(torch.sum((self.layer(self.inp1) - self.out1) ** 2))
                print(torch.sum(Losses))

        torch.cuda.synchronize()
        print('time %.2f' % (time.time() - tick))
        print('error', torch.sum(Losses).item())

        groupsize = groupsize if groupsize != -1 else self.columns
        g_idx = [i // groupsize for i in range(self.columns)]
        g_idx = torch.tensor(g_idx, dtype=torch.int32, device=Q.device)
        if actorder:
            invperm = torch.argsort(perm)
            Q = Q[:, invperm]
            g_idx = g_idx[invperm]

        if isinstance(self.layer, transformers.Conv1D):
            Q = Q.t()
        self.layer.weight.data = Q.reshape(self.layer.weight.shape).to(self.layer.weight.data.dtype)
        if DEBUG:
            print(torch.sum((self.layer(self.inp1) - self.out1) ** 2))

        if not scale:
            scale.append(self.quantizer.scale)
            zero.append(self.quantizer.zero)
        scale = torch.cat(scale, dim=1)
        zero = torch.cat(zero, dim=1)
        return scale, zero, g_idx

    def free(self):
        if DEBUG:
            self.inp1 = None
            self.out1 = None
        self.H = None
        self.Losses = None
        self.Trace = None
        torch.cuda.empty_cache()
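
The class is driven externally: a `Quantizer` (from `quant.py`, pulled in by the star import) is attached, calibration activations are streamed through `add_batch`, then `fasterquant` rewrites the layer weights in place. A sketch of that per-layer loop, following the upstream GPTQ usage; the `Quantizer.configure` call and the captured `calib_inputs` pairs are assumptions, not code from this PR:

```python
# Sketch of the per-layer GPTQ loop (mirrors upstream GPTQ-for-LLaMa usage).
gptq = GPTQ(layer, name='blocks.0.att.key')   # layer: an nn.Linear to quantize
gptq.quantizer = Quantizer()                  # attached externally, from quant.py
gptq.quantizer.configure(4, perchannel=True, sym=False)  # 4-bit, per-channel

for inp, out in calib_inputs:                 # activations captured via hooks
    gptq.add_batch(inp, out)                  # accumulates H = 2/n * sum x x^T

scale, zero, g_idx = gptq.fasterquant(groupsize=128, actorder=True)
gptq.free()                                   # release the Hessian buffers
```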
16 changes: 16 additions & 0 deletions quantize/gptq/modelutils.py
@@ -0,0 +1,16 @@
import torch
import torch.nn as nn


DEV = torch.device('cuda:0')


def find_layers(module, layers=[nn.Conv2d, nn.Linear], name=''):
    # Recursively collect every submodule whose type is in `layers`,
    # keyed by its dotted attribute path inside `module`.
    if type(module) in layers:
        return {name: module}
    res = {}
    for name1, child in module.named_children():
        res.update(find_layers(
            child, layers=layers, name=name + '.' + name1 if name != '' else name1
        ))
    return res
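
A usage sketch; the block and attribute names are illustrative rather than the actual RWKV module paths:

```python
# Hypothetical example: enumerate the quantizable linears in one RWKV block.
import torch.nn as nn
from quantize.gptq.modelutils import find_layers

block = model.blocks[0]  # assumed attribute layout of the loaded RWKV model
for name, lin in find_layers(block, layers=[nn.Linear]).items():
    print(name, tuple(lin.weight.shape))
```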