From bb4346569cfac44e65dc57d87f5f2e50598998a3 Mon Sep 17 00:00:00 2001 From: Ferdinand Mom Date: Wed, 3 May 2023 12:30:10 +0000 Subject: [PATCH] feat(quantize): readapt GPTQ for rwkv --- quantize/tmp_rwkv.py | 166 +++++++++++++++++++++++++------------------ 1 file changed, 97 insertions(+), 69 deletions(-) diff --git a/quantize/tmp_rwkv.py b/quantize/tmp_rwkv.py index 1ed2ad1f..83d0dfd7 100644 --- a/quantize/tmp_rwkv.py +++ b/quantize/tmp_rwkv.py @@ -5,11 +5,14 @@ import os import torch.nn.functional as F -import torch.nn as nn +from collections import OrderedDict import time import math import re +WBITS = 8 +GROUPSIZE = -1 + class GPTQ_RWKV(RWKV): ### begin GPTQ @@ -29,7 +32,6 @@ def __init__(self, weight, name): self.deactivate_add_batch_call = False def add_batch(self, inp): - # After calling fasterquant, we don't want to call add_batch anymore if self.deactivate_add_batch_call: return @@ -37,9 +39,8 @@ def add_batch(self, inp): if len(inp.shape) == 2: inp = inp.unsqueeze(0) - #TODO: is the case with len = 1 still necessary ? - tmp = 1 if len(inp.shape) == 1 else inp.shape[0] - + tmp = inp.shape[0] + # Assume weight come from nn.Linear if len(inp.shape) == 3: inp = inp.reshape((-1, inp.shape[-1])) @@ -52,7 +53,9 @@ def add_batch(self, inp): def fasterquant(self, blocksize=128, percdamp=.01, groupsize=-1, actorder=False): W = self.weight.data.clone() - # Need to transpose here, same reason as in __init__ with self.columns + # OLD: Need to transpose here, same reason as in __init__ with self.columns + # UPDATE: no need to tranpose as we already transpose in my_linear() + # UPDATE2: for rwkv, this is necessary W = W.t() W = W.float() @@ -63,10 +66,11 @@ def fasterquant(self, blocksize=128, percdamp=.01, groupsize=-1, actorder=False) H = self.H del self.H + dead = torch.diag(H) == 0 H[dead, dead] = 1 W[:, dead] = 0 - + if actorder: perm = torch.argsort(torch.diag(H), descending=True) W = W[:, perm] @@ -82,6 +86,11 @@ def fasterquant(self, blocksize=128, percdamp=.01, groupsize=-1, actorder=False) H = torch.cholesky_inverse(H) H = torch.linalg.cholesky(H, upper=True) Hinv = H + + g_idx = [] + scale = [] + zero = [] + now_idx = 1 for i1 in range(0, self.columns, blocksize): i2 = min(i1 + blocksize, self.columns) @@ -101,6 +110,11 @@ def fasterquant(self, blocksize=128, percdamp=.01, groupsize=-1, actorder=False) if (i1 + i) % groupsize == 0: self.quantizer.find_params(W[:, (i1 + i):(i1 + i + groupsize)], weight=True) + if ((i1 + i) // groupsize) - now_idx == -1: + scale.append(self.quantizer.scale) + zero.append(self.quantizer.zero) + now_idx += 1 + q = quantize( w.unsqueeze(1), self.quantizer.scale, self.quantizer.zero, self.quantizer.maxq ).flatten() @@ -116,15 +130,27 @@ def fasterquant(self, blocksize=128, percdamp=.01, groupsize=-1, actorder=False) W[:, i2:] -= Err1.matmul(Hinv[i1:i2, i2:]) + torch.cuda.synchronize() print('time %.2f' % (time.time() - tick)) print('error', torch.sum(Losses).item()) - + + groupsize = groupsize if groupsize != -1 else self.columns + g_idx = [i // groupsize for i in range(self.columns)] + g_idx = torch.tensor(g_idx, dtype=torch.int32, device=Q.device) if actorder: invperm = torch.argsort(perm) Q = Q[:, invperm] + g_idx = g_idx[invperm] self.weight.data = Q.reshape(self.weight.shape).to(self.weight.data.dtype) + + if scale == []: + scale.append(self.quantizer.scale) + zero.append(self.quantizer.zero) + scale = torch.cat(scale,dim=1) + zero = torch.cat(zero,dim=1) + return scale,zero,g_idx ### end GPTQ @@ -134,6 +160,7 @@ def __init__(self, model, strategy): for i in range(self.args.n_layer): assert self.strategy[i].device == "cpu" + #TODO: Change to match my implem def _fill_subset(self, layer_id): # Keep only layer within block layer_id is_weight = re.compile(f'^blocks\.{layer_id}\..*\.weight$') @@ -146,18 +173,18 @@ def _fill_subset(self, layer_id): if is_last_layer: self.subset["head.weight"] = self.w["head.weight"] - + return self.subset + def alloc_gptq(self, layer_id): self.subset = {} self.gptq = {} - self._fill_subset(layer_id) - + self.subset = self._fill_subset(layer_id) + for name in self.subset: self.gptq[name] = self.GPTQ(self.subset[name], name) self.gptq[name].quantizer = Quantizer() - #TODO: add argparse to configure - self.gptq[name].quantizer.configure(bits=4, perchannel=True, sym=False, mse=False, trits=False) + self.gptq[name].quantizer.configure(bits=WBITS, perchannel=True, sym=False, mse=False, trits=False) def free_gptq(self): self.subset = {} @@ -166,11 +193,10 @@ def free_gptq(self): def fasterquant(self, layer_id, quantizers): for name in self.subset: - print(f"Quantizing {name} of layer {layer_id}") - #TODO: add argparse to fastquant - self.gptq[name].fasterquant(percdamp=0.01, groupsize=-1, actorder=False) - # self.gptq[name].fastquant(percdamp=args.percdamp, groupsize=args.groupsize, actorder=args.act_order) - quantizers[name] = self.gptq[name].quantizer + print(layer_id, name) + print('Quantizing ...') + scale,zero,g_idx = self.gptq[name].fasterquant(percdamp=0.01, groupsize=GROUPSIZE, actorder=False) + quantizers[f"linear{layer_id}_w"] = (self.gptq[name].quantizer.cpu(), scale.cpu(), zero.cpu(), g_idx.cpu()) ### end GPTQ_RWKV @@ -326,7 +352,7 @@ def forward_block(self, x, state, i, seq_mode, full_output=False): orx = self.w[f'{att}output.weight_rx'] if wtype == torch.uint8 else x omy = self.w[f'{att}output.weight_my'] if wtype == torch.uint8 else x ory = self.w[f'{att}output.weight_ry'] if wtype == torch.uint8 else x - + x, state[i*5+0], state[i*5+1], state[i*5+2], state[i*5+3] = ATT( x=x, sx=state[i*5+0], aa=state[i*5+1], bb=state[i*5+2], pp=state[i*5+3], ln_w=self.w[f'{bbb}ln1.weight'], ln_b=self.w[f'{bbb}ln1.bias'], @@ -338,12 +364,6 @@ def forward_block(self, x, state, i, seq_mode, full_output=False): rmx=rmx, rrx=rrx, rmy=rmy, rry=rry, omx=omx, orx=orx, omy=omy, ory=ory, ) - - # Deactivate add_batch() after quantization is applied - kw.deactivate_add_batch_call = True - vw.deactivate_add_batch_call = True - rw.deactivate_add_batch_call = True - ow.deactivate_add_batch_call = True if dd.stream: del kw, vw, rw, ow @@ -378,11 +398,6 @@ def forward_block(self, x, state, i, seq_mode, full_output=False): vmx=vmx, vrx=vrx, vmy=vmy, vry=vry, rmx=rmx, rrx=rrx, rmy=rmy, rry=rry, ) - - # Deactivate add_batch() after quantization is applied - kw.deactivate_add_batch_call = True - vw.deactivate_add_batch_call = True - rw.deactivate_add_batch_call = True if dd.stream: del kw, vw, rw @@ -392,7 +407,6 @@ def forward_block(self, x, state, i, seq_mode, full_output=False): x = x / 2 is_last_layer = i == (args.n_layer - 1) - if is_last_layer: dd = self.strategy[args.n_layer] x = x[-1,:] if (seq_mode and (not full_output)) else x @@ -410,63 +424,77 @@ def forward_block(self, x, state, i, seq_mode, full_output=False): ### end RWKV -model = GPTQ_RWKV("./RWKV-4-Pile-169M-20220807-8023.pth", strategy='cpu fp32') - -NSAMPLES=2 -HIDDEN_SIZE=model.args.n_embd -SEQLEN=1024 # cf https://huggingface.co/BlinkDL/rwkv-4-pile-169m - -# train_tokens, test_tokens = get_loaders( -# dataset_name="wikitext2", -# nsamples=NSAMPLES, -# seed=42, -# seqlen=SEQLEN, -# model=model -# ) - -# tokens = torch.cat([inp for inp, _ in train_tokens], dim=0) -tokens = torch.zeros((NSAMPLES, SEQLEN), dtype=torch.int64) -print("tokens.shape", tokens.shape) - -is_last_layer = lambda x: x == (model.args.n_layer - 1) - -start_time = time.time() - -#TODO: Do the same in GPU side -with torch.no_grad(): +@torch.no_grad() +def quantize_gptq_custom(model, tokens): + nsamples = tokens.shape[0] seq_mode = len(tokens) > 1 + is_last_layer = lambda x: x == (model.args.n_layer - 1) + inps = model.w['emb.weight'][tokens if seq_mode else tokens[0]] outs = torch.zeros_like(inps) - quantizers = {} - + for layer_id in range(model.args.n_layer): + + print(f"Quantizing layer {layer_id} ...") model.alloc_gptq(layer_id) - for j in range(NSAMPLES): + for i in range(nsamples): + #TODO: Are outs value normal ? (they look almost all the same) if not is_last_layer(layer_id): - outs[j] = model.forward_block(inps[j], state=None, i=layer_id, seq_mode=seq_mode) + outs[i] = model.forward_block(inps[i], state=None, i=layer_id, seq_mode=seq_mode) else: - _ = model.forward_block(inps[j], state=None, i=layer_id, seq_mode=seq_mode) - + _ = model.forward_block(inps[i], state=None, i=layer_id, seq_mode=seq_mode) + + for gptq_layer in model.gptq.values(): + gptq_layer.deactivate_add_batch_call = True + + tmp = model.w["blocks.0.att.key.weight"] + model.fasterquant(layer_id, quantizers) - for j in range(NSAMPLES): + for i in range(nsamples): if not is_last_layer(layer_id): - outs[j] = model.forward_block(inps[j], state=None, i=layer_id, seq_mode=seq_mode) + outs[i] = model.forward_block(inps[i], state=None, i=layer_id, seq_mode=seq_mode) else: - _ = model.forward_block(inps[j], state=None, i=layer_id, seq_mode=seq_mode) - + _ = model.forward_block(inps[i], state=None, i=layer_id, seq_mode=seq_mode) + + # Assign the quantized weights to the model + for key in model.gptq.keys(): + model.w[key].copy_(model.gptq[key].weight) + model.free_gptq() # We need to pass the outputs of block i as input of block i+1 (except for last block) if not is_last_layer(layer_id): inps, outs = outs, inps -end_time = time.time() + return quantizers + +if __name__ == "__main__": + + model = GPTQ_RWKV("./RWKV-4-Pile-169M-20220807-8023.pth", strategy='cpu fp32') + + NSAMPLES=2 + HIDDEN_SIZE=model.args.n_embd + SEQLEN=1024 # cf https://huggingface.co/BlinkDL/rwkv-4-pile-169m + + train_tokens, test_tokens = get_loaders( + dataset_name="wikitext2", + nsamples=NSAMPLES, + seed=42, + seqlen=SEQLEN, + model=model + ) + + tokens = torch.cat([inp for inp, _ in train_tokens], dim=0) + tokens = torch.zeros((NSAMPLES, SEQLEN), dtype=torch.int64) + print("tokens.shape", tokens.shape) -print(f"Done in {end_time - start_time:.2f} seconds") + import pdb; pdb.set_trace() + # quantizers = quantize_gptq_custom(model, tokens) -# TODO: Do something with quantizers dictionary -# TODO: pack3 save model \ No newline at end of file + # model_pack_custom(model, quantizers, WBITS, GROUPSIZE) + # torch.save(model.state_dict(), "model_quantized_custom.pt") + # print("Done Custom GPTQ") \ No newline at end of file