Skip to content

Commit

Permalink
feat(quantize): readapt GPTQ for rwkv
Browse files Browse the repository at this point in the history
  • Loading branch information
3outeille committed May 3, 2023
1 parent 2e1a70e commit bb43465
Showing 1 changed file with 97 additions and 69 deletions.
166 changes: 97 additions & 69 deletions quantize/tmp_rwkv.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,14 @@

import os
import torch.nn.functional as F
import torch.nn as nn
from collections import OrderedDict
import time
import math
import re

WBITS = 8
GROUPSIZE = -1

class GPTQ_RWKV(RWKV):

### begin GPTQ
Expand All @@ -29,17 +32,15 @@ def __init__(self, weight, name):
self.deactivate_add_batch_call = False

def add_batch(self, inp):

# After calling fasterquant, we don't want to call add_batch anymore
if self.deactivate_add_batch_call:
return

if len(inp.shape) == 2:
inp = inp.unsqueeze(0)

#TODO: is the case with len = 1 still necessary ?
tmp = 1 if len(inp.shape) == 1 else inp.shape[0]

tmp = inp.shape[0]

# Assume weight come from nn.Linear
if len(inp.shape) == 3:
inp = inp.reshape((-1, inp.shape[-1]))
Expand All @@ -52,7 +53,9 @@ def add_batch(self, inp):

def fasterquant(self, blocksize=128, percdamp=.01, groupsize=-1, actorder=False):
W = self.weight.data.clone()
# Need to transpose here, same reason as in __init__ with self.columns
# OLD: Need to transpose here, same reason as in __init__ with self.columns
# UPDATE: no need to tranpose as we already transpose in my_linear()
# UPDATE2: for rwkv, this is necessary
W = W.t()
W = W.float()

Expand All @@ -63,10 +66,11 @@ def fasterquant(self, blocksize=128, percdamp=.01, groupsize=-1, actorder=False)

H = self.H
del self.H

dead = torch.diag(H) == 0
H[dead, dead] = 1
W[:, dead] = 0

if actorder:
perm = torch.argsort(torch.diag(H), descending=True)
W = W[:, perm]
Expand All @@ -82,6 +86,11 @@ def fasterquant(self, blocksize=128, percdamp=.01, groupsize=-1, actorder=False)
H = torch.cholesky_inverse(H)
H = torch.linalg.cholesky(H, upper=True)
Hinv = H

g_idx = []
scale = []
zero = []
now_idx = 1

for i1 in range(0, self.columns, blocksize):
i2 = min(i1 + blocksize, self.columns)
Expand All @@ -101,6 +110,11 @@ def fasterquant(self, blocksize=128, percdamp=.01, groupsize=-1, actorder=False)
if (i1 + i) % groupsize == 0:
self.quantizer.find_params(W[:, (i1 + i):(i1 + i + groupsize)], weight=True)

if ((i1 + i) // groupsize) - now_idx == -1:
scale.append(self.quantizer.scale)
zero.append(self.quantizer.zero)
now_idx += 1

q = quantize(
w.unsqueeze(1), self.quantizer.scale, self.quantizer.zero, self.quantizer.maxq
).flatten()
Expand All @@ -116,15 +130,27 @@ def fasterquant(self, blocksize=128, percdamp=.01, groupsize=-1, actorder=False)

W[:, i2:] -= Err1.matmul(Hinv[i1:i2, i2:])


torch.cuda.synchronize()
print('time %.2f' % (time.time() - tick))
print('error', torch.sum(Losses).item())


groupsize = groupsize if groupsize != -1 else self.columns
g_idx = [i // groupsize for i in range(self.columns)]
g_idx = torch.tensor(g_idx, dtype=torch.int32, device=Q.device)
if actorder:
invperm = torch.argsort(perm)
Q = Q[:, invperm]
g_idx = g_idx[invperm]

self.weight.data = Q.reshape(self.weight.shape).to(self.weight.data.dtype)

if scale == []:
scale.append(self.quantizer.scale)
zero.append(self.quantizer.zero)
scale = torch.cat(scale,dim=1)
zero = torch.cat(zero,dim=1)
return scale,zero,g_idx

### end GPTQ

Expand All @@ -134,6 +160,7 @@ def __init__(self, model, strategy):
for i in range(self.args.n_layer):
assert self.strategy[i].device == "cpu"

#TODO: Change to match my implem
def _fill_subset(self, layer_id):
# Keep only layer within block layer_id
is_weight = re.compile(f'^blocks\.{layer_id}\..*\.weight$')
Expand All @@ -146,18 +173,18 @@ def _fill_subset(self, layer_id):
if is_last_layer:
self.subset["head.weight"] = self.w["head.weight"]


return self.subset

def alloc_gptq(self, layer_id):
self.subset = {}
self.gptq = {}

self._fill_subset(layer_id)
self.subset = self._fill_subset(layer_id)

for name in self.subset:
self.gptq[name] = self.GPTQ(self.subset[name], name)
self.gptq[name].quantizer = Quantizer()
#TODO: add argparse to configure
self.gptq[name].quantizer.configure(bits=4, perchannel=True, sym=False, mse=False, trits=False)
self.gptq[name].quantizer.configure(bits=WBITS, perchannel=True, sym=False, mse=False, trits=False)

def free_gptq(self):
self.subset = {}
Expand All @@ -166,11 +193,10 @@ def free_gptq(self):
def fasterquant(self, layer_id, quantizers):

for name in self.subset:
print(f"Quantizing {name} of layer {layer_id}")
#TODO: add argparse to fastquant
self.gptq[name].fasterquant(percdamp=0.01, groupsize=-1, actorder=False)
# self.gptq[name].fastquant(percdamp=args.percdamp, groupsize=args.groupsize, actorder=args.act_order)
quantizers[name] = self.gptq[name].quantizer
print(layer_id, name)
print('Quantizing ...')
scale,zero,g_idx = self.gptq[name].fasterquant(percdamp=0.01, groupsize=GROUPSIZE, actorder=False)
quantizers[f"linear{layer_id}_w"] = (self.gptq[name].quantizer.cpu(), scale.cpu(), zero.cpu(), g_idx.cpu())

### end GPTQ_RWKV

Expand Down Expand Up @@ -326,7 +352,7 @@ def forward_block(self, x, state, i, seq_mode, full_output=False):
orx = self.w[f'{att}output.weight_rx'] if wtype == torch.uint8 else x
omy = self.w[f'{att}output.weight_my'] if wtype == torch.uint8 else x
ory = self.w[f'{att}output.weight_ry'] if wtype == torch.uint8 else x

x, state[i*5+0], state[i*5+1], state[i*5+2], state[i*5+3] = ATT(
x=x, sx=state[i*5+0], aa=state[i*5+1], bb=state[i*5+2], pp=state[i*5+3],
ln_w=self.w[f'{bbb}ln1.weight'], ln_b=self.w[f'{bbb}ln1.bias'],
Expand All @@ -338,12 +364,6 @@ def forward_block(self, x, state, i, seq_mode, full_output=False):
rmx=rmx, rrx=rrx, rmy=rmy, rry=rry,
omx=omx, orx=orx, omy=omy, ory=ory,
)

# Deactivate add_batch() after quantization is applied
kw.deactivate_add_batch_call = True
vw.deactivate_add_batch_call = True
rw.deactivate_add_batch_call = True
ow.deactivate_add_batch_call = True

if dd.stream:
del kw, vw, rw, ow
Expand Down Expand Up @@ -378,11 +398,6 @@ def forward_block(self, x, state, i, seq_mode, full_output=False):
vmx=vmx, vrx=vrx, vmy=vmy, vry=vry,
rmx=rmx, rrx=rrx, rmy=rmy, rry=rry,
)

# Deactivate add_batch() after quantization is applied
kw.deactivate_add_batch_call = True
vw.deactivate_add_batch_call = True
rw.deactivate_add_batch_call = True

if dd.stream:
del kw, vw, rw
Expand All @@ -392,7 +407,6 @@ def forward_block(self, x, state, i, seq_mode, full_output=False):
x = x / 2

is_last_layer = i == (args.n_layer - 1)

if is_last_layer:
dd = self.strategy[args.n_layer]
x = x[-1,:] if (seq_mode and (not full_output)) else x
Expand All @@ -410,63 +424,77 @@ def forward_block(self, x, state, i, seq_mode, full_output=False):

### end RWKV

model = GPTQ_RWKV("./RWKV-4-Pile-169M-20220807-8023.pth", strategy='cpu fp32')

NSAMPLES=2
HIDDEN_SIZE=model.args.n_embd
SEQLEN=1024 # cf https://huggingface.co/BlinkDL/rwkv-4-pile-169m

# train_tokens, test_tokens = get_loaders(
# dataset_name="wikitext2",
# nsamples=NSAMPLES,
# seed=42,
# seqlen=SEQLEN,
# model=model
# )

# tokens = torch.cat([inp for inp, _ in train_tokens], dim=0)
tokens = torch.zeros((NSAMPLES, SEQLEN), dtype=torch.int64)
print("tokens.shape", tokens.shape)

is_last_layer = lambda x: x == (model.args.n_layer - 1)

start_time = time.time()

#TODO: Do the same in GPU side
with torch.no_grad():
@torch.no_grad()
def quantize_gptq_custom(model, tokens):
nsamples = tokens.shape[0]
seq_mode = len(tokens) > 1
is_last_layer = lambda x: x == (model.args.n_layer - 1)

inps = model.w['emb.weight'][tokens if seq_mode else tokens[0]]
outs = torch.zeros_like(inps)

quantizers = {}

for layer_id in range(model.args.n_layer):

print(f"Quantizing layer {layer_id} ...")

model.alloc_gptq(layer_id)

for j in range(NSAMPLES):
for i in range(nsamples):
#TODO: Are outs value normal ? (they look almost all the same)
if not is_last_layer(layer_id):
outs[j] = model.forward_block(inps[j], state=None, i=layer_id, seq_mode=seq_mode)
outs[i] = model.forward_block(inps[i], state=None, i=layer_id, seq_mode=seq_mode)
else:
_ = model.forward_block(inps[j], state=None, i=layer_id, seq_mode=seq_mode)

_ = model.forward_block(inps[i], state=None, i=layer_id, seq_mode=seq_mode)

for gptq_layer in model.gptq.values():
gptq_layer.deactivate_add_batch_call = True

tmp = model.w["blocks.0.att.key.weight"]

model.fasterquant(layer_id, quantizers)

for j in range(NSAMPLES):
for i in range(nsamples):
if not is_last_layer(layer_id):
outs[j] = model.forward_block(inps[j], state=None, i=layer_id, seq_mode=seq_mode)
outs[i] = model.forward_block(inps[i], state=None, i=layer_id, seq_mode=seq_mode)
else:
_ = model.forward_block(inps[j], state=None, i=layer_id, seq_mode=seq_mode)

_ = model.forward_block(inps[i], state=None, i=layer_id, seq_mode=seq_mode)

# Assign the quantized weights to the model
for key in model.gptq.keys():
model.w[key].copy_(model.gptq[key].weight)

model.free_gptq()

# We need to pass the outputs of block i as input of block i+1 (except for last block)
if not is_last_layer(layer_id):
inps, outs = outs, inps

end_time = time.time()
return quantizers

if __name__ == "__main__":

model = GPTQ_RWKV("./RWKV-4-Pile-169M-20220807-8023.pth", strategy='cpu fp32')

NSAMPLES=2
HIDDEN_SIZE=model.args.n_embd
SEQLEN=1024 # cf https://huggingface.co/BlinkDL/rwkv-4-pile-169m

train_tokens, test_tokens = get_loaders(
dataset_name="wikitext2",
nsamples=NSAMPLES,
seed=42,
seqlen=SEQLEN,
model=model
)

tokens = torch.cat([inp for inp, _ in train_tokens], dim=0)
tokens = torch.zeros((NSAMPLES, SEQLEN), dtype=torch.int64)
print("tokens.shape", tokens.shape)

print(f"Done in {end_time - start_time:.2f} seconds")
import pdb; pdb.set_trace()
# quantizers = quantize_gptq_custom(model, tokens)

# TODO: Do something with quantizers dictionary
# TODO: pack3 save model
# model_pack_custom(model, quantizers, WBITS, GROUPSIZE)
# torch.save(model.state_dict(), "model_quantized_custom.pt")
# print("Done Custom GPTQ")

0 comments on commit bb43465

Please sign in to comment.