diff --git a/README_quantize.md b/README_quantize.md new file mode 100644 index 00000000..4a057a44 --- /dev/null +++ b/README_quantize.md @@ -0,0 +1,6 @@ +# GPTQ + +``` +pip install -r requirements-quantize.txt +python quantize/gptq/setup_cuda.py install +``` \ No newline at end of file diff --git a/quantize/gptq/datautils.py b/quantize/gptq/datautils.py new file mode 100644 index 00000000..b8f12ce7 --- /dev/null +++ b/quantize/gptq/datautils.py @@ -0,0 +1,68 @@ +import numpy as np +import torch +import os +import pathlib +import tokenizers +import random +from myRWKV import RWKV + +from datasets import load_dataset + +def set_seed(seed): + np.random.seed(seed) + torch.random.manual_seed(seed) + +def get_wikitext2(nsamples, seed, seqlen, model): + is_rwkv = isinstance(model, RWKV) + + if is_rwkv: + traindata = load_dataset('wikitext', 'wikitext-2-raw-v1', split='train') + testdata = load_dataset('wikitext', 'wikitext-2-raw-v1', split='test') + + print('Loading RWKV tokenizer') + tokenizer_path = pathlib.Path(os.path.abspath(__file__)).parent / '../20B_tokenizer.json' + tokenizer = tokenizers.Tokenizer.from_file(str(tokenizer_path)) + trainenc = torch.unsqueeze(torch.tensor(tokenizer.encode("\n\n".join(traindata['text'])).ids, dtype=torch.long), 0) + testenc = torch.unsqueeze(torch.tensor(tokenizer.encode("\n\n".join(testdata['text'])).ids, dtype=torch.long), 0) + else: + # traindata = load_dataset('wikitext', 'wikitext-2-raw-v1', split='train') + # testdata = load_dataset('wikitext', 'wikitext-2-raw-v1', split='test') + + print('Loading tokenizer') + from transformers import AutoTokenizer + tokenizer = AutoTokenizer.from_pretrained(model, use_fast=False) + # trainenc = tokenizer("\n\n".join(traindata['text']), return_tensors='pt') + # testenc = tokenizer("\n\n".join(testdata['text']), return_tensors='pt') + sentence = "My name is Ferdinand and I live in France" + trainenc = tokenizer("\n\n".join(sentence), return_tensors='pt') + testenc = tokenizer("\n\n".join(sentence), return_tensors='pt') + + random.seed(seed) + trainloader = [] + shape = trainenc.shape if is_rwkv else trainenc.input_ids.shape + trainenc = trainenc if is_rwkv else trainenc.input_ids + random_idx = [random.randint(0, shape[1] - seqlen - 1) for _ in range(nsamples)] + + for i in range(nsamples): + j = random_idx[i] + seqlen + inp = trainenc[:, random_idx[i]:j] + tar = inp.clone() + tar[:, :-1] = -100 + trainloader.append((inp, tar)) + return trainloader, testenc + +def get_loaders( + dataset_name, nsamples, seed, seqlen, model +): + if 'wikitext2' in dataset_name: + return get_wikitext2(nsamples, seed, seqlen, model) + if 'ptb' in dataset_name: + raise NotImplementedError('PTB is not supported yet') + # if 'new' in dataset_name: + # return get_ptb_new(nsamples, seed, seqlen, model) + # return get_ptb(nsamples, seed, seqlen, model) + if 'c4' in dataset_name: + raise NotImplementedError('C4 is not supported yet') + # if 'new' in dataset_name: + # return get_c4_new(nsamples, seed, seqlen, model) + # return get_c4(nsamples, seed, seqlen, model) diff --git a/quantize/gptq/gptq.py b/quantize/gptq/gptq.py new file mode 100644 index 00000000..62ed697f --- /dev/null +++ b/quantize/gptq/gptq.py @@ -0,0 +1,178 @@ +import math +import time + +import torch +import torch.nn as nn +import transformers + +from .quant import * + + +DEBUG = False + +torch.backends.cuda.matmul.allow_tf32 = False +torch.backends.cudnn.allow_tf32 = False + + +class GPTQ: + def __init__(self, layer, name): + self.name = name + self.layer = layer + self.dev = 
self.layer.weight.device + W = layer.weight.data.clone() + if isinstance(self.layer, nn.Conv2d): + W = W.flatten(1) + if isinstance(self.layer, transformers.Conv1D): + W = W.t() + self.rows = W.shape[0] + self.columns = W.shape[1] + self.H = torch.zeros((self.columns, self.columns), device=self.dev) + self.nsamples = 0 + + def add_batch(self, inp, out): + if DEBUG: + self.inp1 = inp + self.out1 = out + if len(inp.shape) == 2: + inp = inp.unsqueeze(0) + tmp = inp.shape[0] + if isinstance(self.layer, nn.Linear) or isinstance(self.layer, transformers.Conv1D): + if len(inp.shape) == 3: + inp = inp.reshape((-1, inp.shape[-1])) + inp = inp.t() + if isinstance(self.layer, nn.Conv2d): + unfold = nn.Unfold( + self.layer.kernel_size, + dilation=self.layer.dilation, + padding=self.layer.padding, + stride=self.layer.stride + ) + inp = unfold(inp) + inp = inp.permute([1, 0, 2]) + inp = inp.flatten(1) + self.H *= self.nsamples / (self.nsamples + tmp) + self.nsamples += tmp + # inp = inp.float() + inp = math.sqrt(2 / self.nsamples) * inp.float() + # self.H += 2 / self.nsamples * inp.matmul(inp.t()) + self.H += inp.matmul(inp.t()) + + def fasterquant( + self, blocksize=128, percdamp=.01, groupsize=-1, actorder=False + ): + W = self.layer.weight.data.clone() + if isinstance(self.layer, nn.Conv2d): + W = W.flatten(1) + if isinstance(self.layer, transformers.Conv1D): + W = W.t() + W = W.float() + + tick = time.time() + + if not self.quantizer.ready(): + self.quantizer.find_params(W, weight=True) + + H = self.H + del self.H + dead = torch.diag(H) == 0 + H[dead, dead] = 1 + W[:, dead] = 0 + + if actorder: + perm = torch.argsort(torch.diag(H), descending=True) + W = W[:, perm] + H = H[perm][:, perm] + + Losses = torch.zeros_like(W) + Q = torch.zeros_like(W) + + damp = percdamp * torch.mean(torch.diag(H)) + diag = torch.arange(self.columns, device=self.dev) + H[diag, diag] += damp + H = torch.linalg.cholesky(H) + H = torch.cholesky_inverse(H) + H = torch.linalg.cholesky(H, upper=True) + Hinv = H + + g_idx = [] + scale = [] + zero = [] + now_idx = 1 + + for i1 in range(0, self.columns, blocksize): + i2 = min(i1 + blocksize, self.columns) + count = i2 - i1 + + W1 = W[:, i1:i2].clone() + Q1 = torch.zeros_like(W1) + Err1 = torch.zeros_like(W1) + Losses1 = torch.zeros_like(W1) + Hinv1 = Hinv[i1:i2, i1:i2] + + for i in range(count): + w = W1[:, i] + d = Hinv1[i, i] + + if groupsize != -1: + if (i1 + i) % groupsize == 0: + self.quantizer.find_params(W[:, (i1 + i):(i1 + i + groupsize)], weight=True) + + if ((i1 + i) // groupsize) - now_idx == -1: + scale.append(self.quantizer.scale) + zero.append(self.quantizer.zero) + now_idx += 1 + + q = quantize( + w.unsqueeze(1), self.quantizer.scale, self.quantizer.zero, self.quantizer.maxq + ).flatten() + Q1[:, i] = q + Losses1[:, i] = (w - q) ** 2 / d ** 2 + + err1 = (w - q) / d + W1[:, i:] -= err1.unsqueeze(1).matmul(Hinv1[i, i:].unsqueeze(0)) + Err1[:, i] = err1 + + Q[:, i1:i2] = Q1 + Losses[:, i1:i2] = Losses1 / 2 + + W[:, i2:] -= Err1.matmul(Hinv[i1:i2, i2:]) + + if DEBUG: + self.layer.weight.data[:, :i2] = Q[:, :i2] + self.layer.weight.data[:, i2:] = W[:, i2:] + print(torch.sum((self.layer(self.inp1) - self.out1) ** 2)) + print(torch.sum(Losses)) + + torch.cuda.synchronize() + print('time %.2f' % (time.time() - tick)) + print('error', torch.sum(Losses).item()) + + groupsize = groupsize if groupsize != -1 else self.columns + g_idx = [i // groupsize for i in range(self.columns)] + g_idx = torch.tensor(g_idx, dtype=torch.int32, device=Q.device) + if actorder: + invperm = 
torch.argsort(perm) + Q = Q[:, invperm] + g_idx = g_idx[invperm] + + if isinstance(self.layer, transformers.Conv1D): + Q = Q.t() + self.layer.weight.data = Q.reshape(self.layer.weight.shape).to(self.layer.weight.data.dtype) + if DEBUG: + print(torch.sum((self.layer(self.inp1) - self.out1) ** 2)) + + if scale == []: + scale.append(self.quantizer.scale) + zero.append(self.quantizer.zero) + scale = torch.cat(scale,dim=1) + zero = torch.cat(zero,dim=1) + return scale,zero,g_idx + + def free(self): + if DEBUG: + self.inp1 = None + self.out1 = None + self.H = None + self.Losses = None + self.Trace = None + torch.cuda.empty_cache() diff --git a/quantize/gptq/modelutils.py b/quantize/gptq/modelutils.py new file mode 100644 index 00000000..0c5d12b1 --- /dev/null +++ b/quantize/gptq/modelutils.py @@ -0,0 +1,16 @@ +import torch +import torch.nn as nn + + +DEV = torch.device('cuda:0') + + +def find_layers(module, layers=[nn.Conv2d, nn.Linear], name=''): + if type(module) in layers: + return {name: module} + res = {} + for name1, child in module.named_children(): + res.update(find_layers( + child, layers=layers, name=name + '.' + name1 if name != '' else name1 + )) + return res diff --git a/quantize/gptq/quant.py b/quantize/gptq/quant.py new file mode 100644 index 00000000..c3a9a7ee --- /dev/null +++ b/quantize/gptq/quant.py @@ -0,0 +1,524 @@ +import numpy as np +import torch +import torch.nn as nn +import math + +def quantize(x, scale, zero, maxq): + if maxq < 0: + return (x > scale / 2).float() * scale + (x < zero / 2).float() * zero + q = torch.clamp(torch.round(x / scale) + zero, 0, maxq) + return scale * (q - zero) + +class Quantizer(nn.Module): + + def __init__(self, shape=1): + super(Quantizer, self).__init__() + self.register_buffer('maxq', torch.tensor(0)) + self.register_buffer('scale', torch.zeros(shape)) + self.register_buffer('zero', torch.zeros(shape)) + + def configure( + self, + bits, perchannel=False, sym=True, + mse=False, norm=2.4, grid=100, maxshrink=.8, + trits=False + ): + + self.maxq = torch.tensor(2 ** bits - 1) + self.perchannel = perchannel + self.sym = sym + self.mse = mse + self.norm = norm + self.grid = grid + self.maxshrink = maxshrink + if trits: + self.maxq = torch.tensor(-1) + + def find_params(self, x, weight=False): + dev = x.device + self.maxq = self.maxq.to(dev) + + shape = x.shape + if self.perchannel: + if weight: + x = x.flatten(1) + else: + if len(shape) == 4: + x = x.permute([1, 0, 2, 3]) + x = x.flatten(1) + if len(shape) == 3: + x = x.reshape((-1, shape[-1])).t() + if len(shape) == 2: + x = x.t() + else: + x = x.flatten().unsqueeze(0) + + tmp = torch.zeros(x.shape[0], device=dev) + xmin = torch.minimum(x.min(1)[0], tmp) + xmax = torch.maximum(x.max(1)[0], tmp) + + if self.sym: + xmax = torch.maximum(torch.abs(xmin), xmax) + tmp = xmin < 0 + if torch.any(tmp): + xmin[tmp] = -xmax[tmp] + tmp = (xmin == 0) & (xmax == 0) + xmin[tmp] = -1 + xmax[tmp] = +1 + + if self.maxq < 0: + self.scale = xmax + self.zero = xmin + else: + self.scale = (xmax - xmin) / self.maxq + if self.sym: + self.zero = torch.full_like(self.scale, (self.maxq + 1) / 2) + else: + self.zero = torch.round(-xmin / self.scale) + + if self.mse: + best = torch.full([x.shape[0]], float('inf'), device=dev) + for i in range(int(self.maxshrink * self.grid)): + p = 1 - i / self.grid + xmin1 = p * xmin + xmax1 = p * xmax + scale1 = (xmax1 - xmin1) / self.maxq + zero1 = torch.round(-xmin1 / scale1) if not self.sym else self.zero + q = quantize(x, scale1.unsqueeze(1), zero1.unsqueeze(1), self.maxq) + q -= x 
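+                    # MSE grid search: after shrinking the range by p, accumulate the
+                    # per-row error |q - x|^norm below and keep scale1/zero1 only for the
+                    # rows where it improves on the best error seen so far.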
+ q.abs_() + q.pow_(self.norm) + err = torch.sum(q, 1) + tmp = err < best + if torch.any(tmp): + best[tmp] = err[tmp] + self.scale[tmp] = scale1[tmp] + self.zero[tmp] = zero1[tmp] + if not self.perchannel: + if weight: + tmp = shape[0] + else: + tmp = shape[1] if len(shape) != 3 else shape[2] + self.scale = self.scale.repeat(tmp) + self.zero = self.zero.repeat(tmp) + + if weight: + shape = [-1] + [1] * (len(shape) - 1) + self.scale = self.scale.reshape(shape) + self.zero = self.zero.reshape(shape) + return + if len(shape) == 4: + self.scale = self.scale.reshape((1, -1, 1, 1)) + self.zero = self.zero.reshape((1, -1, 1, 1)) + if len(shape) == 3: + self.scale = self.scale.reshape((1, 1, -1)) + self.zero = self.zero.reshape((1, 1, -1)) + if len(shape) == 2: + self.scale = self.scale.unsqueeze(0) + self.zero = self.zero.unsqueeze(0) + + def quantize(self, x): + if self.ready(): + return quantize(x, self.scale, self.zero, self.maxq) + return x + + def enabled(self): + return self.maxq > 0 + + def ready(self): + return torch.all(self.scale != 0) + + +try: + import quant_cuda + is_cuda = True +except: + print('CUDA extension not installed.') + is_cuda = False + +def make_quant(module, names, bits, groupsize, name=''): + if isinstance(module, QuantLinear): + return + for attr in dir(module): + tmp = getattr(module, attr) + name1 = name + '.' + attr if name != '' else attr + if name1 in names: + delattr(module, attr) + setattr(module, attr, QuantLinear(bits, groupsize, tmp.in_features, tmp.out_features, tmp.bias is not None)) + for name1, child in module.named_children(): + make_quant(child, names, bits, groupsize, name + '.' + name1 if name != '' else name1) + +def make_quant_custom(module, names, bits, groupsize, name=''): + if isinstance(module, QuantLinear_custom): + return + for attr in dir(module): + tmp = getattr(module, attr) + name1 = name + '.' 
+ attr if name != '' else attr + if name1 in names: + bias = getattr(module, attr.replace('w', 'b')) + layer_name = attr.replace('w', 'quant') + setattr(module, layer_name, QuantLinear_custom(bits, groupsize, tmp.shape[0], tmp.shape[1], bias is not None)) + + +class QuantLinear_custom(nn.Module): + def __init__(self, bits, groupsize, infeatures, outfeatures, bias, kernel_switch_threshold=128, is_cuda=is_cuda): + super().__init__() + if bits not in [2,3,4,8]: + raise NotImplementedError("Only 2,3,4,8 bits are supported.") + self.infeatures = infeatures + self.outfeatures = outfeatures + self.bits = bits + self.groupsize = groupsize if groupsize != -1 else infeatures + self.maxq = 2 ** self.bits - 1 + + self.register_buffer('qweight', torch.zeros((infeatures // 32 * self.bits, outfeatures), dtype=torch.int32)) + self.register_buffer('qzeros', torch.zeros((math.ceil(infeatures / self.groupsize), outfeatures // 32 * self.bits), dtype=torch.int32)) + self.register_buffer('scales', torch.zeros((math.ceil(infeatures / self.groupsize), outfeatures), dtype=torch.float16)) + self.register_buffer('g_idx', torch.tensor([i // self.groupsize for i in range(infeatures)], dtype = torch.int32)) + if bias: + self.register_buffer('bias', torch.zeros((outfeatures),dtype=torch.float16)) + else: + self.bias = None + + # is performed by unpacking the weights and using torch.matmul + if self.bits in [2,4,8]: + self.register_buffer('wf',torch.tensor(list(range(0,32,self.bits)), dtype=torch.int32).unsqueeze(0),persistent=False) + elif self.bits == 3: + self.register_buffer('wf', torch.tensor([[0, 3, 6, 9, 12, 15, 18, 21, 24, 27, 30, 0], + [0, 1, 4, 7, 10, 13, 16, 19, 22, 25, 28, 31], + [0, 2, 5, 8, 11, 14, 17, 20, 23, 26, 29, 0],], dtype=torch.int32).reshape(1,3,12), persistent=False) + + self.kernel_switch_threshold = kernel_switch_threshold + self.is_cuda = is_cuda + + def pack(self, weight, bias, scales, zeros, g_idx = None): + self.g_idx = g_idx.clone() if g_idx is not None else self.g_idx + + scales = scales.t().contiguous() + zeros = zeros.t().contiguous() + scale_zeros = zeros * scales + self.scales = scales.clone().half() + if bias is not None: + self.bias = bias.clone().half() + + intweight = [] + for idx in range(self.infeatures): + #OLD: intweight.append(torch.round((weight.data[:,idx] + scale_zeros[self.g_idx[idx]]) / self.scales[self.g_idx[idx]]).to(torch.int)[:,None]) + intweight.append(torch.round((weight.data[idx, :] + scale_zeros[self.g_idx[idx]]) / self.scales[self.g_idx[idx]]).to(torch.int)[:,None]) + intweight = torch.cat(intweight,dim=1) + intweight = intweight.t().contiguous() + intweight = intweight.numpy().astype(np.uint32) + qweight = np.zeros( + (intweight.shape[0] // 32 * self.bits, intweight.shape[1]), dtype=np.uint32 + ) + i = 0 + row = 0 + while row < qweight.shape[0]: + if self.bits in [2,4,8]: + for j in range(i, i + (32//self.bits)): + qweight[row] |= intweight[j] << (self.bits * (j - i)) + i += 32//self.bits + row += 1 + elif self.bits == 3: + for j in range(i, i + 10): + qweight[row] |= intweight[j] << (3 * (j - i)) + i += 10 + qweight[row] |= intweight[i] << 30 + row += 1 + qweight[row] |= (intweight[i] >> 2) & 1 + i += 1 + for j in range(i, i + 10): + qweight[row] |= intweight[j] << (3 * (j - i) + 1) + i += 10 + qweight[row] |= intweight[i] << 31 + row += 1 + qweight[row] |= (intweight[i] >> 1) & 0x3 + i += 1 + for j in range(i, i + 10): + qweight[row] |= intweight[j] << (3 * (j - i) + 2) + i += 10 + row += 1 + else: + raise NotImplementedError("Only 2,3,4,8 bits are 
supported.") + + qweight = qweight.astype(np.int32) + self.qweight = torch.from_numpy(qweight) + + zeros -= 1 + zeros = zeros.numpy().astype(np.uint32) + qzeros = np.zeros((zeros.shape[0], zeros.shape[1] // 32 * self.bits), dtype=np.uint32) + i = 0 + col = 0 + while col < qzeros.shape[1]: + if self.bits in [2,4,8]: + for j in range(i, i + (32//self.bits)): + qzeros[:, col] |= zeros[:, j] << (self.bits * (j - i)) + i += 32//self.bits + col += 1 + elif self.bits == 3: + for j in range(i, i + 10): + qzeros[:, col] |= zeros[:, j] << (3 * (j - i)) + i += 10 + qzeros[:, col] |= zeros[:, i] << 30 + col += 1 + qzeros[:, col] |= (zeros[:, i] >> 2) & 1 + i += 1 + for j in range(i, i + 10): + qzeros[:, col] |= zeros[:, j] << (3 * (j - i) + 1) + i += 10 + qzeros[:, col] |= zeros[:, i] << 31 + col += 1 + qzeros[:, col] |= (zeros[:, i] >> 1) & 0x3 + i += 1 + for j in range(i, i + 10): + qzeros[:, col] |= zeros[:, j] << (3 * (j - i) + 2) + i += 10 + col += 1 + else: + raise NotImplementedError("Only 2,3,4,8 bits are supported.") + + qzeros = qzeros.astype(np.int32) + self.qzeros = torch.from_numpy(qzeros) + + def forward(self, x): + out_shape = x.shape[:-1] + (self.outfeatures, ) + x = x.reshape(-1,x.shape[-1]) + if self.is_cuda is True and (self.kernel_switch_threshold is False or x.shape[0] < self.kernel_switch_threshold): + out = torch.zeros((x.shape[0], self.outfeatures), device=x.device, dtype=torch.float32) + if self.bits == 2: + quant_cuda.vecquant2matmul(x.float(), self.qweight, out, self.scales.float(), self.qzeros, self.g_idx) + elif self.bits == 3: + quant_cuda.vecquant3matmul(x.float(), self.qweight, out, self.scales.float(), self.qzeros, self.g_idx) + elif self.bits == 4: + quant_cuda.vecquant4matmul(x.float(), self.qweight, out, self.scales.float(), self.qzeros, self.g_idx) + elif self.bits == 8: + quant_cuda.vecquant8matmul(x.float(), self.qweight, out, self.scales.float(), self.qzeros, self.g_idx) + out = out.half() + else: + if self.bits in [2,4,8]: + zeros = torch.bitwise_right_shift(torch.unsqueeze(self.qzeros, 2).expand(-1, -1, 32 // self.bits), self.wf.unsqueeze(0)).to(torch.int16 if self.bits == 8 else torch.int8) + torch.bitwise_and(zeros, (2 ** self.bits) - 1, out=zeros) + + zeros = zeros + 1 + zeros = zeros.reshape(self.scales.shape) + + weight = torch.bitwise_right_shift(torch.unsqueeze(self.qweight, 1).expand(-1, 32 // self.bits, -1), self.wf.unsqueeze(-1)).to(torch.int16 if self.bits == 8 else torch.int8) + torch.bitwise_and(weight,(2 ** self.bits) - 1, out=weight) + elif self.bits == 3: + zeros = self.qzeros.reshape(self.qzeros.shape[0], self.qzeros.shape[1]//3, 3, 1).expand(-1, -1, -1, 12) + zeros = (zeros >> self.wf.unsqueeze(0)) + zeros[:,:,0,10] = (zeros[:,:,0,10]&0x3) | ((zeros[:,:,1,0] << 2)&0x4) + zeros[:,:,1,11] = (zeros[:,:,1,11]&0x1) | ((zeros[:,:,2,0] << 1)&0x6) + zeros = zeros & 0x7 + zeros = torch.cat([zeros[:,:,0,:11], zeros[:,:,1,1:12], zeros[:,:,2,1:11]], dim=2) + + zeros = zeros + 1 + zeros = zeros.reshape(self.scales.shape) + + weight = self.qweight.reshape(self.qweight.shape[0]//3, 3, 1, self.qweight.shape[1]).expand(-1, -1, 12, -1) + weight = (weight >> self.wf.unsqueeze(-1))&0x7 + weight[:,0,10] = (weight[:,0,10]&0x3) | ((weight[:,1,0] << 2)&0x4) + weight[:,1,11] = (weight[:,1,11]&0x1) | ((weight[:,2,0] << 1)&0x6) + weight = weight & 0x7 + weight = torch.cat([weight[:,0,:11], weight[:,1,1:12], weight[:,2,1:11]], dim=1) + + weight = weight.reshape(weight.shape[0] * weight.shape[1], weight.shape[2]) + num_itr = self.g_idx.shape[0]//x.shape[-1] + if num_itr 
== 1: + weights = (self.scales[self.g_idx.long()] * (weight - zeros[self.g_idx.long()])) + else: + num_dim = self.g_idx.shape[0]//num_itr + weights = [] + for i in range(num_itr): + scale_i = self.scales[:,i*num_dim:(i+1)*num_dim] + weight_i = weight[:,i*num_dim:(i+1)*num_dim] + zeros_i = zeros[:,i*num_dim:(i+1)*num_dim] + g_idx_i = self.g_idx[i*num_dim:(i+1)*num_dim] + weights.append(scale_i[g_idx_i.long()] * (weight_i - zeros_i[g_idx_i.long()])) + weights = torch.cat(weights,dim=1) + out = torch.matmul(x.half(), weights) + out = out.reshape(out_shape) + out = out + self.bias if self.bias is not None else out + return out + + +class QuantLinear(nn.Module): + def __init__(self, bits, groupsize, infeatures, outfeatures, bias, kernel_switch_threshold=128, is_cuda=is_cuda): + super().__init__() + if bits not in [2,3,4,8]: + raise NotImplementedError("Only 2,3,4,8 bits are supported.") + self.infeatures = infeatures + self.outfeatures = outfeatures + self.bits = bits + self.groupsize = groupsize if groupsize != -1 else infeatures + self.maxq = 2 ** self.bits - 1 + + self.register_buffer('qweight', torch.zeros((infeatures // 32 * self.bits, outfeatures), dtype=torch.int32)) + self.register_buffer('qzeros', torch.zeros((math.ceil(infeatures / self.groupsize), outfeatures // 32 * self.bits), dtype=torch.int32)) + self.register_buffer('scales', torch.zeros((math.ceil(infeatures / self.groupsize), outfeatures), dtype=torch.float16)) + self.register_buffer('g_idx', torch.tensor([i // self.groupsize for i in range(infeatures)], dtype = torch.int32)) + if bias: + self.register_buffer('bias', torch.zeros((outfeatures),dtype=torch.float16)) + else: + self.bias = None + + # is performed by unpacking the weights and using torch.matmul + if self.bits in [2,4,8]: + self.register_buffer('wf',torch.tensor(list(range(0,32,self.bits)), dtype=torch.int32).unsqueeze(0),persistent=False) + elif self.bits == 3: + self.register_buffer('wf', torch.tensor([[0, 3, 6, 9, 12, 15, 18, 21, 24, 27, 30, 0], + [0, 1, 4, 7, 10, 13, 16, 19, 22, 25, 28, 31], + [0, 2, 5, 8, 11, 14, 17, 20, 23, 26, 29, 0],], dtype=torch.int32).reshape(1,3,12), persistent=False) + + self.kernel_switch_threshold = kernel_switch_threshold + self.is_cuda = is_cuda + + def pack(self, linear, scales, zeros, g_idx = None): + self.g_idx = g_idx.clone() if g_idx is not None else self.g_idx + + scales = scales.t().contiguous() + zeros = zeros.t().contiguous() + scale_zeros = zeros * scales + self.scales = scales.clone().half() + if linear.bias is not None: + self.bias = linear.bias.clone().half() + + intweight = [] + for idx in range(self.infeatures): + intweight.append(torch.round((linear.weight.data[:,idx] + scale_zeros[self.g_idx[idx]]) / self.scales[self.g_idx[idx]]).to(torch.int)[:,None]) + intweight = torch.cat(intweight,dim=1) + intweight = intweight.t().contiguous() + intweight = intweight.numpy().astype(np.uint32) + qweight = np.zeros( + (intweight.shape[0] // 32 * self.bits, intweight.shape[1]), dtype=np.uint32 + ) + i = 0 + row = 0 + while row < qweight.shape[0]: + if self.bits in [2,4,8]: + for j in range(i, i + (32//self.bits)): + qweight[row] |= intweight[j] << (self.bits * (j - i)) + i += 32//self.bits + row += 1 + elif self.bits == 3: + for j in range(i, i + 10): + qweight[row] |= intweight[j] << (3 * (j - i)) + i += 10 + qweight[row] |= intweight[i] << 30 + row += 1 + qweight[row] |= (intweight[i] >> 2) & 1 + i += 1 + for j in range(i, i + 10): + qweight[row] |= intweight[j] << (3 * (j - i) + 1) + i += 10 + qweight[row] |= intweight[i] << 
31 + row += 1 + qweight[row] |= (intweight[i] >> 1) & 0x3 + i += 1 + for j in range(i, i + 10): + qweight[row] |= intweight[j] << (3 * (j - i) + 2) + i += 10 + row += 1 + else: + raise NotImplementedError("Only 2,3,4,8 bits are supported.") + + qweight = qweight.astype(np.int32) + self.qweight = torch.from_numpy(qweight) + + zeros -= 1 + zeros = zeros.numpy().astype(np.uint32) + qzeros = np.zeros((zeros.shape[0], zeros.shape[1] // 32 * self.bits), dtype=np.uint32) + i = 0 + col = 0 + while col < qzeros.shape[1]: + if self.bits in [2,4,8]: + for j in range(i, i + (32//self.bits)): + qzeros[:, col] |= zeros[:, j] << (self.bits * (j - i)) + i += 32//self.bits + col += 1 + elif self.bits == 3: + for j in range(i, i + 10): + qzeros[:, col] |= zeros[:, j] << (3 * (j - i)) + i += 10 + qzeros[:, col] |= zeros[:, i] << 30 + col += 1 + qzeros[:, col] |= (zeros[:, i] >> 2) & 1 + i += 1 + for j in range(i, i + 10): + qzeros[:, col] |= zeros[:, j] << (3 * (j - i) + 1) + i += 10 + qzeros[:, col] |= zeros[:, i] << 31 + col += 1 + qzeros[:, col] |= (zeros[:, i] >> 1) & 0x3 + i += 1 + for j in range(i, i + 10): + qzeros[:, col] |= zeros[:, j] << (3 * (j - i) + 2) + i += 10 + col += 1 + else: + raise NotImplementedError("Only 2,3,4,8 bits are supported.") + + qzeros = qzeros.astype(np.int32) + self.qzeros = torch.from_numpy(qzeros) + + def forward(self, x): + out_shape = x.shape[:-1] + (self.outfeatures, ) + x = x.reshape(-1,x.shape[-1]) + if self.is_cuda is True and (self.kernel_switch_threshold is False or x.shape[0] < self.kernel_switch_threshold): + out = torch.zeros((x.shape[0], self.outfeatures), device=x.device, dtype=torch.float32) + if self.bits == 2: + quant_cuda.vecquant2matmul(x.float(), self.qweight, out, self.scales.float(), self.qzeros, self.g_idx) + elif self.bits == 3: + quant_cuda.vecquant3matmul(x.float(), self.qweight, out, self.scales.float(), self.qzeros, self.g_idx) + elif self.bits == 4: + quant_cuda.vecquant4matmul(x.float(), self.qweight, out, self.scales.float(), self.qzeros, self.g_idx) + elif self.bits == 8: + quant_cuda.vecquant8matmul(x.float(), self.qweight, out, self.scales.float(), self.qzeros, self.g_idx) + out = out.half() + else: + if self.bits in [2,4,8]: + zeros = torch.bitwise_right_shift(torch.unsqueeze(self.qzeros, 2).expand(-1, -1, 32 // self.bits), self.wf.unsqueeze(0)).to(torch.int16 if self.bits == 8 else torch.int8) + torch.bitwise_and(zeros, (2 ** self.bits) - 1, out=zeros) + + zeros = zeros + 1 + zeros = zeros.reshape(self.scales.shape) + + weight = torch.bitwise_right_shift(torch.unsqueeze(self.qweight, 1).expand(-1, 32 // self.bits, -1), self.wf.unsqueeze(-1)).to(torch.int16 if self.bits == 8 else torch.int8) + torch.bitwise_and(weight,(2 ** self.bits) - 1, out=weight) + elif self.bits == 3: + zeros = self.qzeros.reshape(self.qzeros.shape[0], self.qzeros.shape[1]//3, 3, 1).expand(-1, -1, -1, 12) + zeros = (zeros >> self.wf.unsqueeze(0)) + zeros[:,:,0,10] = (zeros[:,:,0,10]&0x3) | ((zeros[:,:,1,0] << 2)&0x4) + zeros[:,:,1,11] = (zeros[:,:,1,11]&0x1) | ((zeros[:,:,2,0] << 1)&0x6) + zeros = zeros & 0x7 + zeros = torch.cat([zeros[:,:,0,:11], zeros[:,:,1,1:12], zeros[:,:,2,1:11]], dim=2) + + zeros = zeros + 1 + zeros = zeros.reshape(self.scales.shape) + + weight = self.qweight.reshape(self.qweight.shape[0]//3, 3, 1, self.qweight.shape[1]).expand(-1, -1, 12, -1) + weight = (weight >> self.wf.unsqueeze(-1))&0x7 + weight[:,0,10] = (weight[:,0,10]&0x3) | ((weight[:,1,0] << 2)&0x4) + weight[:,1,11] = (weight[:,1,11]&0x1) | ((weight[:,2,0] << 1)&0x6) + weight = weight 
& 0x7 + weight = torch.cat([weight[:,0,:11], weight[:,1,1:12], weight[:,2,1:11]], dim=1) + + weight = weight.reshape(weight.shape[0] * weight.shape[1], weight.shape[2]) + num_itr = self.g_idx.shape[0]//x.shape[-1] + if num_itr == 1: + weights = (self.scales[self.g_idx.long()] * (weight - zeros[self.g_idx.long()])) + else: + num_dim = self.g_idx.shape[0]//num_itr + weights = [] + for i in range(num_itr): + scale_i = self.scales[:,i*num_dim:(i+1)*num_dim] + weight_i = weight[:,i*num_dim:(i+1)*num_dim] + zeros_i = zeros[:,i*num_dim:(i+1)*num_dim] + g_idx_i = self.g_idx[i*num_dim:(i+1)*num_dim] + weights.append(scale_i[g_idx_i.long()] * (weight_i - zeros_i[g_idx_i.long()])) + weights = torch.cat(weights,dim=1) + out = torch.matmul(x.half(), weights) + out = out.reshape(out_shape) + out = out + self.bias if self.bias is not None else out + return out + diff --git a/quantize/gptq/quant_cuda.cpp b/quantize/gptq/quant_cuda.cpp new file mode 100644 index 00000000..3200a9f2 --- /dev/null +++ b/quantize/gptq/quant_cuda.cpp @@ -0,0 +1,70 @@ +#include +#include +#include + +void vecquant2matmul_cuda( + torch::Tensor vec, torch::Tensor mat, torch::Tensor mul, + torch::Tensor scales, torch::Tensor zeros, + torch::Tensor g_idx +); + +void vecquant2matmul( + torch::Tensor vec, torch::Tensor mat, torch::Tensor mul, + torch::Tensor scales, torch::Tensor zeros, + torch::Tensor g_idx +) { + const at::cuda::OptionalCUDAGuard device_guard(device_of(vec)); + vecquant2matmul_cuda(vec, mat, mul, scales, zeros, g_idx); +} + +void vecquant3matmul_cuda( + torch::Tensor vec, torch::Tensor mat, torch::Tensor mul, + torch::Tensor scales, torch::Tensor zeros, + torch::Tensor g_idx +); + +void vecquant3matmul( + torch::Tensor vec, torch::Tensor mat, torch::Tensor mul, + torch::Tensor scales, torch::Tensor zeros, + torch::Tensor g_idx +) { + const at::cuda::OptionalCUDAGuard device_guard(device_of(vec)); + vecquant3matmul_cuda(vec, mat, mul, scales, zeros, g_idx); +} + +void vecquant4matmul_cuda( + torch::Tensor vec, torch::Tensor mat, torch::Tensor mul, + torch::Tensor scales, torch::Tensor zeros, + torch::Tensor g_idx +); + +void vecquant4matmul( + torch::Tensor vec, torch::Tensor mat, torch::Tensor mul, + torch::Tensor scales, torch::Tensor zeros, + torch::Tensor g_idx +) { + const at::cuda::OptionalCUDAGuard device_guard(device_of(vec)); + vecquant4matmul_cuda(vec, mat, mul, scales, zeros, g_idx); +} + +void vecquant8matmul_cuda( + torch::Tensor vec, torch::Tensor mat, torch::Tensor mul, + torch::Tensor scales, torch::Tensor zeros, + torch::Tensor g_idx +); + +void vecquant8matmul( + torch::Tensor vec, torch::Tensor mat, torch::Tensor mul, + torch::Tensor scales, torch::Tensor zeros, + torch::Tensor g_idx +) { + const at::cuda::OptionalCUDAGuard device_guard(device_of(vec)); + vecquant8matmul_cuda(vec, mat, mul, scales, zeros, g_idx); +} + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("vecquant2matmul", &vecquant2matmul, "Vector 2-bit Quantized Matrix Multiplication (CUDA)"); + m.def("vecquant3matmul", &vecquant3matmul, "Vector 3-bit Quantized Matrix Multiplication (CUDA)"); + m.def("vecquant4matmul", &vecquant4matmul, "Vector 4-bit Quantized Matrix Multiplication (CUDA)"); + m.def("vecquant8matmul", &vecquant8matmul, "Vector 8-bit Quantized Matrix Multiplication (CUDA)"); +} diff --git a/quantize/gptq/quant_cuda_kernel.cu b/quantize/gptq/quant_cuda_kernel.cu new file mode 100644 index 00000000..60c1dc08 --- /dev/null +++ b/quantize/gptq/quant_cuda_kernel.cu @@ -0,0 +1,509 @@ +#include +#include +#include 
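+// CUDA kernels backing the pybind wrappers in quant_cuda.cpp: one matrix-vector
+// kernel per supported width (2/3/4/8 bits) that dequantizes the packed int32
+// weights on the fly using the per-group scales, zeros, and g_idx.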
+#include +#include + +// atomicAdd for double-precision floating-point numbers on hardware with +// compute capability < 6.0 from: +// https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#atomic-functions +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 600 +__device__ double atomicAdd( + double* address, + double val +) { + unsigned long long int* address_as_ull = (unsigned long long int*)address; + unsigned long long int old = *address_as_ull, assumed; + + do { + assumed = old; + old = atomicCAS( + address_as_ull, + assumed, + __double_as_longlong(val + __longlong_as_double(assumed)) + ); + + // Note: uses integer comparison to avoid hang in case of NaN (since NaN != NaN) + } while (assumed != old); + + return __longlong_as_double(old); +} +#endif + +template +__global__ void VecQuant2MatMulKernel( + const scalar_t* __restrict__ vec, + const int* __restrict__ mat, + scalar_t* __restrict__ mul, + const scalar_t* __restrict__ scales, + const int* __restrict__ zeros, + const int* __restrict__ g_idx, + int batch, + int vec_height, + int height, + int width, + int zero_width +); + +template +__global__ void VecQuant3MatMulKernel( + const scalar_t* __restrict__ vec, + const int* __restrict__ mat, + scalar_t* __restrict__ mul, + const scalar_t* __restrict__ scales, + const int* __restrict__ zeros, + const int* __restrict__ g_idx, + int batch, + int vec_height, + int height, + int width, + int zero_width +); + +template +__global__ void VecQuant4MatMulKernel( + const scalar_t* __restrict__ vec, + const int* __restrict__ mat, + scalar_t* __restrict__ mul, + const scalar_t* __restrict__ scales, + const int* __restrict__ zeros, + const int* __restrict__ g_idx, + int batch, + int vec_height, + int height, + int width, + int zero_width +); + +template +__global__ void VecQuant8MatMulKernel( + const scalar_t* __restrict__ vec, + const int* __restrict__ mat, + scalar_t* __restrict__ mul, + const scalar_t* __restrict__ scales, + const int* __restrict__ zeros, + const int* __restrict__ g_idx, + int batch, + int vec_height, + int height, + int width, + int zero_width +); + +const int BLOCKWIDTH = 256; +const int BLOCKHEIGHT2 = 16; +const int BLOCKHEIGHT3 = 24; +const int BLOCKHEIGHT4 = 32; +const int BLOCKHEIGHT8 = 64; + +__device__ inline unsigned int as_unsigned(int i) { + return *reinterpret_cast(&i); +} + +__device__ inline int as_int(int i) { + return *reinterpret_cast(&i); +} + + +void vecquant2matmul_cuda( + torch::Tensor vec, + torch::Tensor mat, + torch::Tensor mul, + torch::Tensor scales, + torch::Tensor zeros, + torch::Tensor g_idx +) { + int batch = vec.size(0); + int vec_height = vec.size(1); + int height = mat.size(0); + int width = mat.size(1); + int zero_width = zeros.size(1); + + dim3 blocks( + (height + BLOCKHEIGHT2 - 1) / BLOCKHEIGHT2, + (width + BLOCKWIDTH - 1) / BLOCKWIDTH + ); + dim3 threads(BLOCKWIDTH); + + AT_DISPATCH_FLOATING_TYPES( + vec.type(), "vecquant2matmul_cuda", ([&] { + VecQuant2MatMulKernel<<>>( + vec.data(), mat.data(), mul.data(), + scales.data(), zeros.data(), g_idx.data(), + batch, vec_height, height, width, zero_width + ); + }) + ); +} + +template +__global__ void VecQuant2MatMulKernel( + const scalar_t* __restrict__ vec, + const int* __restrict__ mat, + scalar_t* __restrict__ mul, + const scalar_t* __restrict__ scales, + const int* __restrict__ zeros, + const int* __restrict__ g_idx, + int batch, + int vec_height, + int height, + int width, + int zero_width +) { + int h = BLOCKHEIGHT2 * blockIdx.x; + int w = BLOCKWIDTH * blockIdx.y + threadIdx.x; + + 
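+  // Each thread owns one output column w: it unpacks a BLOCKWIDTH-long strip of
+  // 2-bit weights from the packed int32 matrix, dequantizes them with the per-group
+  // scale/zero selected through g_idx, then accumulates a dot product against the
+  // input vector staged in shared memory, writing the result with atomicAdd.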
__shared__ scalar_t blockvec[BLOCKWIDTH]; + int i = width * h + w; + int g_h = h * 16; + int k; + unsigned int g; + scalar_t w_tmp; + + int z_w = w / 16; + int z_mod = (w % 16) * 2; + + float weight[BLOCKWIDTH]; + + for (k = 0; k < BLOCKWIDTH; ++k){ + int k_w = (k / 16); + int k_bit = (k % 16) * 2; + + g = as_int(g_idx[g_h + k]); + scalar_t scale = scales[g * width + w]; + scalar_t zero = scalar_t((as_unsigned(zeros[g * zero_width + z_w]) >> z_mod & 0x3) + 1); + + w_tmp = ((as_unsigned(mat[i + (k_w * width)]) >> k_bit) & 0x3); + + weight[k] = scale * (w_tmp - zero); + } + + scalar_t res; + for (int b = 0; b < batch; ++b){ + res = 0; + + blockvec[threadIdx.x] = vec[b * vec_height + blockIdx.x * BLOCKWIDTH + threadIdx.x]; + __syncthreads(); + for (k = 0; k < BLOCKWIDTH; ++k){ + res += weight[k] * blockvec[k]; + } + atomicAdd(&mul[b * width + w], res); + } +} + +void vecquant3matmul_cuda( + torch::Tensor vec, + torch::Tensor mat, + torch::Tensor mul, + torch::Tensor scales, + torch::Tensor zeros, + torch::Tensor g_idx +) { + int batch = vec.size(0); + int vec_height = vec.size(1); + int height = mat.size(0); + int width = mat.size(1); + int zero_width = zeros.size(1); + + dim3 blocks( + (height + BLOCKHEIGHT3 - 1) / BLOCKHEIGHT3, + (width + BLOCKWIDTH - 1) / BLOCKWIDTH + ); + dim3 threads(BLOCKWIDTH); + + AT_DISPATCH_FLOATING_TYPES( + vec.type(), "vecquant3matmul_cuda", ([&] { + VecQuant3MatMulKernel<<>>( + vec.data(), mat.data(), mul.data(), + scales.data(), zeros.data(), g_idx.data(), + batch, vec_height, height, width, zero_width + ); + }) + ); +} + +template +__global__ void VecQuant3MatMulKernel( + const scalar_t* __restrict__ vec, + const int* __restrict__ mat, + scalar_t* __restrict__ mul, + const scalar_t* __restrict__ scales, + const int* __restrict__ zeros, + const int* __restrict__ g_idx, + int batch, + int vec_height, + int height, + int width, + int zero_width +) { + int h = BLOCKHEIGHT3 * blockIdx.x; + int w = BLOCKWIDTH * blockIdx.y + threadIdx.x; + + __shared__ scalar_t blockvec[BLOCKWIDTH]; + int i = width * h + w; + int g_h = (h / 3) * 32; + int k; + unsigned int g; + scalar_t w_tmp; + + int z_w = (w / 32) * 3; + int z_mod = w % 32; + int z_bit; + unsigned int z_tmp; + if (z_mod != 10){ + if (z_mod != 21){ + z_bit = z_mod; + if (z_bit > 21){ + z_bit -= 22; + z_bit *= 3; + z_bit += 2; + z_w += 2; + } else if (z_bit > 10){ + z_bit -= 11; + z_bit *= 3; + z_bit += 1; + z_w += 1; + } else { + z_bit *= 3; + } + } else { + z_w += 1; + } + } + + float weight[BLOCKWIDTH]; + + for (k = 0; k < BLOCKWIDTH; ++k){ + int k_w = (k / 32) * 3; + int k_mod = k % 32; + int k_bit; + + if (k_mod != 10){ + if (k_mod != 21){ + k_bit = k_mod; + if (k_bit > 21){ + k_bit -= 22; + k_bit *= 3; + k_bit += 2; + k_w += 2; + } else if (k_bit > 10){ + k_bit -= 11; + k_bit *= 3; + k_bit += 1; + k_w += 1; + } else { + k_bit *= 3; + } + } else { + k_w += 1; + } + } + + g = as_int(g_idx[g_h + k]); + scalar_t scale = scales[g * width + w]; + scalar_t zero; + if (z_mod == 10) { + z_tmp = (as_unsigned(zeros[g * zero_width + z_w]) >> 30) | ((as_unsigned(zeros[g * zero_width + (z_w + 1)]) << 2) & 0x4); + zero = scalar_t((z_tmp) + 1); + } else if (z_mod == 21){ + z_tmp = (as_unsigned(zeros[g * zero_width + z_w]) >> 31) | ((as_unsigned(zeros[g * zero_width + (z_w + 1)]) << 1) & 0x6); + zero = scalar_t((z_tmp) + 1); + } else { + zero = scalar_t(((as_unsigned(zeros[g * zero_width + z_w]) >> z_bit) & 0x7) + 1); + } + + if (k_mod == 10) { + w_tmp = (as_unsigned(mat[i + (k_w * width)]) >> 30) | ((as_unsigned(mat[i + ((k_w + 
1)* width)]) << 2) & 0x4); + } else if (k_mod == 21){ + w_tmp = (as_unsigned(mat[i + (k_w * width)]) >> 31) | ((as_unsigned(mat[i + ((k_w + 1)* width)]) << 1) & 0x6); + } else { + w_tmp = ((as_unsigned(mat[i + (k_w * width)]) >> k_bit) & 0x7); + } + weight[k] = scale * (w_tmp - zero); + } + + scalar_t res; + for (int b = 0; b < batch; ++b){ + res = 0; + + blockvec[threadIdx.x] = vec[b * vec_height + blockIdx.x * BLOCKWIDTH + threadIdx.x]; + __syncthreads(); + for (k = 0; k < BLOCKWIDTH; ++k){ + res += weight[k] * blockvec[k]; + } + atomicAdd(&mul[b * width + w], res); + } +} + +void vecquant4matmul_cuda( + torch::Tensor vec, + torch::Tensor mat, + torch::Tensor mul, + torch::Tensor scales, + torch::Tensor zeros, + torch::Tensor g_idx +) { + int batch = vec.size(0); + int vec_height = vec.size(1); + int height = mat.size(0); + int width = mat.size(1); + int zero_width = zeros.size(1); + + dim3 blocks( + (height + BLOCKHEIGHT4 - 1) / BLOCKHEIGHT4, + (width + BLOCKWIDTH - 1) / BLOCKWIDTH + ); + dim3 threads(BLOCKWIDTH); + + AT_DISPATCH_FLOATING_TYPES( + vec.type(), "vecquant4matmul_cuda", ([&] { + VecQuant4MatMulKernel<<>>( + vec.data(), mat.data(), mul.data(), + scales.data(), zeros.data(), g_idx.data(), + batch, vec_height, height, width, zero_width + ); + }) + ); +} + +template +__global__ void VecQuant4MatMulKernel( + const scalar_t* __restrict__ vec, + const int* __restrict__ mat, + scalar_t* __restrict__ mul, + const scalar_t* __restrict__ scales, + const int* __restrict__ zeros, + const int* __restrict__ g_idx, + int batch, + int vec_height, + int height, + int width, + int zero_width +) { + int h = BLOCKHEIGHT4 * blockIdx.x; + int w = BLOCKWIDTH * blockIdx.y + threadIdx.x; + + __shared__ scalar_t blockvec[BLOCKWIDTH]; + int i = width * h + w; + int g_h = h * 8; + int k; + unsigned int g; + scalar_t w_tmp; + + + int z_w = w / 8; + int z_mod = (w % 8) * 4; + + float weight[BLOCKWIDTH]; + + for (k = 0; k < BLOCKWIDTH; ++k){ + int k_w = (k / 8); + int k_bit = (k % 8) * 4; + + g = as_int(g_idx[g_h + k]); + scalar_t scale = scales[g * width + w]; + scalar_t zero = scalar_t(((as_unsigned(zeros[g * zero_width + z_w]) >> z_mod) & 0xF) + 1); + + w_tmp = ((as_unsigned(mat[i + (k_w * width)]) >> k_bit) & 0xF); + + weight[k] = scale * (w_tmp - zero); + } + + scalar_t res; + for (int b = 0; b < batch; ++b){ + res = 0; + + blockvec[threadIdx.x] = vec[b * vec_height + blockIdx.x * BLOCKWIDTH + threadIdx.x]; + __syncthreads(); + for (k = 0; k < BLOCKWIDTH; ++k){ + res += weight[k] * blockvec[k]; + } + atomicAdd(&mul[b * width + w], res); + } +} + +void vecquant8matmul_cuda( + torch::Tensor vec, + torch::Tensor mat, + torch::Tensor mul, + torch::Tensor scales, + torch::Tensor zeros, + torch::Tensor g_idx +) { + int batch = vec.size(0); + int vec_height = vec.size(1); + int height = mat.size(0); + int width = mat.size(1); + int zero_width = zeros.size(1); + + dim3 blocks( + (height + BLOCKHEIGHT8 - 1) / BLOCKHEIGHT8, + (width + BLOCKWIDTH - 1) / BLOCKWIDTH + ); + dim3 threads(BLOCKWIDTH); + + AT_DISPATCH_FLOATING_TYPES( + vec.type(), "vecquant8matmul_cuda", ([&] { + VecQuant8MatMulKernel<<>>( + vec.data(), mat.data(), mul.data(), + scales.data(), zeros.data(), g_idx.data(), + batch, vec_height, height, width, zero_width + ); + }) + ); +} + +template +__global__ void VecQuant8MatMulKernel( + const scalar_t* __restrict__ vec, + const int* __restrict__ mat, + scalar_t* __restrict__ mul, + const scalar_t* __restrict__ scales, + const int* __restrict__ zeros, + const int* __restrict__ g_idx, + int batch, + 
int vec_height, + int height, + int width, + int zero_width +) { + int h = BLOCKHEIGHT8 * blockIdx.x; + int w = BLOCKWIDTH * blockIdx.y + threadIdx.x; + + __shared__ scalar_t blockvec[BLOCKWIDTH]; + int i = width * h + w; + int g_h = h * 4; + int k; + unsigned int g; + scalar_t w_tmp; + + int z_w = w / 4; + int z_mod = (w % 4) * 8; + + float weight[BLOCKWIDTH]; + + for (k = 0; k < BLOCKWIDTH; ++k){ + int k_w = (k / 4); + int k_bit = (k % 4) * 8; + + g = as_int(g_idx[g_h + k]); + scalar_t scale = scales[g * width + w]; + scalar_t zero = scalar_t(((as_unsigned(zeros[g * zero_width + z_w]) >> z_mod) & 0xFF) + 1); + + w_tmp = ((as_unsigned(mat[i + (k_w * width)]) >> k_bit) & 0xFF); + + weight[k] = scale * (w_tmp - zero); + } + + scalar_t res; + for (int b = 0; b < batch; ++b){ + res = 0; + + blockvec[threadIdx.x] = vec[b * vec_height + blockIdx.x * BLOCKWIDTH + threadIdx.x]; + __syncthreads(); + for (k = 0; k < BLOCKWIDTH; ++k){ + res += weight[k] * blockvec[k]; + } + atomicAdd(&mul[b * width + w], res); + } +} diff --git a/quantize/gptq/sanity_check_main.py b/quantize/gptq/sanity_check_main.py new file mode 100644 index 00000000..d96ee24f --- /dev/null +++ b/quantize/gptq/sanity_check_main.py @@ -0,0 +1,538 @@ +import argparse +import time +import re +import torch +import torch.nn as nn +import torch.optim as optim +from collections import OrderedDict +import torch.nn.functional as F + +from sanity_check_utils import seed_everything, MNISTloader, SimpleNet, train, evaluate, SimpleNet_V2 +from gptq import * +from modelutils import * +from quant import * + +WBITS = 8 +GROUPSIZE = -1 + +## =============== REFERENCE =============== +def model_pack(model, quantizers, wbits, groupsize): + layers = find_layers(model) + layers = {n: layers[n] for n in quantizers} + make_quant(model, quantizers, wbits, groupsize) + qlayers = find_layers(model, [QuantLinear]) + print('Packing ...') + for name in qlayers: + print(name) + quantizers[name],scale,zero,g_idx = quantizers[name] + qlayers[name].pack(layers[name], scale, zero, g_idx) + print('Done.') + return model + +def load_quant(model, checkpoint, wbits, groupsize): + print('Loading model ...') + model = model.eval() + layers = find_layers(model) + + # Don't quantize the last layer because qzeros is empty (I don't know why they create qzeros that way) + # (gptq.py:L235, second dimension of qzeros is 0 because last layer is 10 for classification) + if "linear4" in layers: + del layers["linear4"] + + make_quant(model, layers, wbits, groupsize) + model.load_state_dict(torch.load(checkpoint)) + print('Done.') + return model + +@torch.no_grad() +def quantize_gptq(model, train_loader): + quantizers = {} + layers = list(model.modules())[1:] + layers = [l for l in layers if isinstance(l, nn.Linear)] + is_last_layer = lambda x: x == (len(layers) - 1) + + nsamples = len(train_loader.dataset) + batch_size = train_loader.batch_size + + inps = torch.zeros((nsamples, model.N), dtype=torch.float) + for i, (inp, _) in enumerate(train_loader): + inps[i*batch_size:(i+1)*batch_size] = inp.view(-1, 32*32) + outs = torch.zeros_like(inps) + + for layer_id in range(len(layers)): + + if not is_last_layer(layer_id): + + layer = layers[layer_id] + + subset = find_layers(layer) + gptq = {} + + print(f"Quantizing layer {layer_id} ...") + for name in subset: + gptq[name] = GPTQ(subset[name], name) + gptq[name].quantizer = Quantizer() + gptq[name].quantizer.configure(bits=WBITS, perchannel=True, sym=False, mse=False, trits=False) + + def add_batch(name): + def tmp(_, inp, out): + 
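+                    # Forward hook: every calibration input seen by this layer is fed
+                    # into the corresponding GPTQ object to accumulate its Hessian estimate.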
gptq[name].add_batch(inp[0].data, out.data) + return tmp + + handles = [] + + for name in subset: + handles.append(subset[name].register_forward_hook(add_batch(name))) + + for i in range(nsamples): + outs[i] = layer(inps[i]) + + for h in handles: h.remove() + + for name in subset: + print(i, name) + print('Quantizing ...') + scale,zero,g_idx = gptq[name].fasterquant(percdamp=0.01, groupsize=GROUPSIZE, actorder=False) + quantizers[f"linear{layer_id + 1}"] = (gptq[name].quantizer.cpu(), scale.cpu(), zero.cpu(), g_idx.cpu()) + gptq[name].free() + + for i in range(nsamples): + outs[i] = layer(inps[i]) + + del layer + del gptq + torch.cuda.empty_cache() + + inps, outs = outs, inps + + return quantizers + +## =============== OUR IMPLEMENTATION =============== +class GPTQ_CUSTOM(SimpleNet_V2): + + ### begin GPTQ + class GPTQ: + def __init__(self, weight, name): + #TODO: Remove name, only used for debugging + self.name = name + self.weight = weight.clone() + self.dev = weight.device + # In GPTQ, they use nn.Linear(x) which performs x @ w.T but in RWKV, we perform x @ w instead + # Problem is self.H is a square matrix which depends on self.columns = W.shape[1] in the original code + # But if we keep it that way, this will break self.H += inp.matmul(inp.t()) because inp.shape[1] != W.shape[1] + # Thus, we have to use self.W.shape[0] instead + self.columns = self.weight.shape[0] + self.H = torch.zeros((self.columns, self.columns), device=self.dev) + self.nsamples = 0 + self.deactivate_add_batch_call = False + + def add_batch(self, inp): + # After calling fasterquant, we don't want to call add_batch anymore + if self.deactivate_add_batch_call: + return + + if len(inp.shape) == 2: + inp = inp.unsqueeze(0) + + tmp = inp.shape[0] + + # Assume weight come from nn.Linear + inp = inp.t() + self.H *= self.nsamples / (self.nsamples + tmp) + self.nsamples += tmp + inp = math.sqrt(2 / self.nsamples) * inp.float() + self.H += inp.matmul(inp.t()) + + def fasterquant(self, blocksize=128, percdamp=.01, groupsize=-1, actorder=False): + W = self.weight.data.clone() + # OLD: Need to transpose here, same reason as in __init__ with self.columns + # UPDATE: no need to tranpose as we already transpose in my_linear() + # W = W.t() + W = W.float() + + tick = time.time() + + if not self.quantizer.ready(): + self.quantizer.find_params(W, weight=True) + + H = self.H + del self.H + + dead = torch.diag(H) == 0 + H[dead, dead] = 1 + W[:, dead] = 0 + + if actorder: + perm = torch.argsort(torch.diag(H), descending=True) + W = W[:, perm] + H = H[perm][:, perm] + + Losses = torch.zeros_like(W) + Q = torch.zeros_like(W) + + damp = percdamp * torch.mean(torch.diag(H)) + diag = torch.arange(self.columns, device=self.dev) + H[diag, diag] += damp + H = torch.linalg.cholesky(H) + H = torch.cholesky_inverse(H) + H = torch.linalg.cholesky(H, upper=True) + Hinv = H + + g_idx = [] + scale = [] + zero = [] + now_idx = 1 + + for i1 in range(0, self.columns, blocksize): + i2 = min(i1 + blocksize, self.columns) + count = i2 - i1 + + W1 = W[:, i1:i2].clone() + Q1 = torch.zeros_like(W1) + Err1 = torch.zeros_like(W1) + Losses1 = torch.zeros_like(W1) + Hinv1 = Hinv[i1:i2, i1:i2] + + for i in range(count): + w = W1[:, i] + d = Hinv1[i, i] + + if groupsize != -1: + if (i1 + i) % groupsize == 0: + self.quantizer.find_params(W[:, (i1 + i):(i1 + i + groupsize)], weight=True) + + if ((i1 + i) // groupsize) - now_idx == -1: + scale.append(self.quantizer.scale) + zero.append(self.quantizer.zero) + now_idx += 1 + + q = quantize( + w.unsqueeze(1), 
self.quantizer.scale, self.quantizer.zero, self.quantizer.maxq + ).flatten() + Q1[:, i] = q + Losses1[:, i] = (w - q) ** 2 / d ** 2 + + err1 = (w - q) / d + W1[:, i:] -= err1.unsqueeze(1).matmul(Hinv1[i, i:].unsqueeze(0)) + Err1[:, i] = err1 + + Q[:, i1:i2] = Q1 + Losses[:, i1:i2] = Losses1 / 2 + + W[:, i2:] -= Err1.matmul(Hinv[i1:i2, i2:]) + + + torch.cuda.synchronize() + print('time %.2f' % (time.time() - tick)) + print('error', torch.sum(Losses).item()) + + groupsize = groupsize if groupsize != -1 else self.columns + g_idx = [i // groupsize for i in range(self.columns)] + g_idx = torch.tensor(g_idx, dtype=torch.int32, device=Q.device) + if actorder: + invperm = torch.argsort(perm) + Q = Q[:, invperm] + g_idx = g_idx[invperm] + + self.weight.data = Q.reshape(self.weight.shape).to(self.weight.data.dtype) + + if scale == []: + scale.append(self.quantizer.scale) + zero.append(self.quantizer.zero) + scale = torch.cat(scale,dim=1) + zero = torch.cat(zero,dim=1) + return scale,zero,g_idx + + ### end GPTQ + + ### begin GPTQ_CUSTOM + def __init__(self, checkpoint_path): + super().__init__() + self.load_state_dict(torch.load(checkpoint_path, map_location="cpu")) + + def _fill_subset(self, layer_id): + is_last_layer = (layer_id == self.nb_layers - 1) + if is_last_layer: + return {} + # Keep only layer within block layer_id + is_weight = re.compile(f'^linear{layer_id}_w$') + + for name in dir(self): + if is_weight.match(name): + self.subset[name] = getattr(self, name) + return self.subset + + def alloc_gptq(self, layer_id): + self.subset = {} + self.gptq = {} + + self.subset = self._fill_subset(layer_id) + + for name in self.subset: + self.gptq[name] = self.GPTQ(self.subset[name], name) + self.gptq[name].quantizer = Quantizer() + self.gptq[name].quantizer.configure(bits=WBITS, perchannel=True, sym=False, mse=False, trits=False) + + def free_gptq(self): + self.subset = {} + self.gptq = {} + + def fasterquant(self, layer_id, quantizers): + + for name in self.subset: + print(layer_id, name) + print('Quantizing ...') + scale,zero,g_idx = self.gptq[name].fasterquant(percdamp=0.01, groupsize=GROUPSIZE, actorder=False) + quantizers[f"linear{layer_id}_w"] = (self.gptq[name].quantizer.cpu(), scale.cpu(), zero.cpu(), g_idx.cpu()) + + ## end GPTQ_CUSTOM + + ## Begin SimpleNet_V2 + def my_linear(self, x, weight, bias): + # out = x @ weight.weight.T + bias # Use version below as it is more stable + out = F.linear(x, weight.weight, bias) + weight.add_batch(x) + return out + + def forward(self, x): + if len(x.shape) == 4: + x = x.view(x.size(0), -1) + + residual = x + #TODO: maybe we would need to transpose weight when building linear0_quant ? 
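+        # After model_pack_custom, linear{0,1,2}_quant are QuantLinear_custom modules
+        # (packed int weights, dequantized on the fly via the CUDA kernels or a PyTorch
+        # unpack + matmul); linear3 is left unquantized and goes through my_linear.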
+ x = F.relu(self.linear0_quant(x)) + x = self.linear1_quant(x) + x = F.relu(x) + residual + x = self.linear2_quant(x) + x = F.relu(x) + residual + x = super().my_linear(x, self.linear3_w, self.linear3_b) + return x + ## End SimpleNet_V2 + + +@torch.no_grad() +def quantize_gptq_custom(model, train_loader): + nb_layers = model.nb_layers + is_last_layer = lambda x: x == (nb_layers - 1) + + nsamples = len(train_loader.dataset) + batch_size = train_loader.batch_size + + inps = torch.zeros((nsamples, model.N), dtype=torch.float) + for i, (inp, _) in enumerate(train_loader): + inps[i*batch_size:(i+1)*batch_size] = inp.view(-1, 32*32) + outs = torch.zeros_like(inps) + + quantizers = {} + + for layer_id in range(nb_layers): + + if not is_last_layer(layer_id): + + print(f"Quantizing layer {layer_id} ...") + + bias = getattr(model, f"linear{layer_id}_b") + + model.alloc_gptq(layer_id) + + for i in range(nsamples): + outs[i] = model.my_linear(inps[i], model.gptq[f"linear{layer_id}_w"], bias) + + model.gptq[f"linear{layer_id}_w"].deactivate_add_batch_call = True + + model.fasterquant(layer_id, quantizers) + + for i in range(nsamples): + outs[i] = model.my_linear(inps[i], model.gptq[f"linear{layer_id}_w"], bias) + + setattr(model, f"linear{layer_id}_w", nn.Parameter(model.gptq[f"linear{layer_id}_w"].weight)) + model.free_gptq() + + inps, outs = outs, inps + + return quantizers + +def model_pack_custom(model, quantizers, wbits, groupsize): + # Extract weights and bias from model + is_weight, is_bias = re.compile(r'^linear\d+_w$'), re.compile(r'^linear\d+_b$') + weights, bias = OrderedDict(), OrderedDict() + + for attr in dir(model): + if is_weight.match(attr): + weights[attr] = getattr(model, attr) + elif is_bias.match(attr): + bias[attr] = getattr(model, attr) + + make_quant_custom(model, quantizers, wbits, groupsize) + qlayers = find_layers(model, [QuantLinear_custom]) + + print('Packing ...') + for i in range(len(qlayers)): + name_w, name_b, layer_quant_name = f'linear{i}_w', f'linear{i}_b', f'linear{i}_quant' + quantizers[name_w],scale,zero,g_idx = quantizers[name_w] + qlayers[layer_quant_name].pack(weights[name_w], bias[name_b], scale, zero, g_idx) + print('Done.') + return model + +def load_quant_custom(model, checkpoint, wbits, groupsize): + print('Loading model ...') + model = model.eval() + # Extract weights and bias from model + is_weight, is_bias = re.compile(r'^linear\d+_w$'), re.compile(r'^linear\d+_b$') + weights, bias = OrderedDict(), OrderedDict() + for attr in dir(model): + if is_weight.match(attr): + weights[attr] = getattr(model, attr) + elif is_bias.match(attr): + bias[attr] = getattr(model, attr) + + # Create linear layer out of weights and bias + layers = {} + for (w_name, w_param), (_, b_param) in zip(weights.items(), bias.items()): + layers[w_name] = nn.Linear(w_param.shape[1], w_param.shape[0], bias=True) + layers[w_name].weight.data = w_param + layers[w_name].bias.data = b_param + + # Don't quantize the last layer because qzeros is empty (I don't know why they create qzeros that way) + # (gptq.py:L235, second dimension of qzeros is 0 because last layer is 10 for classification) + if "linear3_w" in layers: + del layers["linear3_w"] + + make_quant_custom(model, layers, wbits, groupsize) + model.load_state_dict(torch.load(checkpoint)) + print('Done.') + return model + + +def assert_parameters(model, model_custom): + is_weight = re.compile(r'^linear\d+.weight$') + weights, bias = {}, {} + for name, param in model.named_parameters(): + if is_weight.match(name): + weights[name] = 
param + else: + bias[name] = param + + for i, (name, param) in enumerate(weights.items()): + assert torch.allclose(param, model_custom.state_dict()[f"linear{i}_w"]) + + for i, (name, param) in enumerate(bias.items()): + assert torch.allclose(param, model_custom.state_dict()[f"linear{i}_b"]) + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--train", action="store_true") + parser.add_argument("--gptq", action="store_true") + parser.add_argument("--eval_gptq", action="store_true") + parser.add_argument("--train_custom", action="store_true") + parser.add_argument("--gptq_custom", action="store_true") + parser.add_argument("--eval_gptq_custom", action="store_true") + parser.add_argument("--pyquant", action="store_true") + + args = parser.parse_args() + + seed_everything(42) + lr = 0.02 + num_epochs = 5 + criterion = nn.CrossEntropyLoss() + train_loader, _, _ = MNISTloader(train_val_split=0.95).load() + + ## ================== REFERENCE ================== + if args.train: + model = SimpleNet() + optimizer = optim.Adam(model.parameters(), lr) + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + model = model.to(device) + train(num_epochs, model, optimizer, criterion, train_loader, device) + torch.save(model.state_dict(), "model.pt") + elif args.gptq: + model = SimpleNet() + device = torch.device("cpu") + model.load_state_dict(torch.load("./model.pt", map_location="cpu")) + model = model.to(device) + quantizers = quantize_gptq(model, train_loader) + model_pack(model, quantizers, WBITS, GROUPSIZE) + torch.save(model.state_dict(), "model_quantized.pt") + print("Done GPTQ") + + elif args.eval_gptq: + model = SimpleNet() + device = torch.device("cuda:0") + model = load_quant(model, "model_quantized.pt", WBITS, GROUPSIZE) + model = model.to(device) + + start = time.time() + val_loss, val_acc = evaluate(device, model, criterion, train_loader) + end = time.time() + + print(f"wbits = {WBITS} using {device}") + print(f"val_loss: {val_loss:.3f} \t val_acc: {val_acc:.3f}") + print(f"Latency: {end - start}") + ## ================== CUSTOM ================== + elif args.train_custom: + model = SimpleNet_V2() + optimizer = optim.Adam(model.parameters(), lr) + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + model = model.to(device) + train(num_epochs, model, optimizer, criterion, train_loader, device) + torch.save(model.state_dict(), "model_custom.pt") + elif args.gptq_custom: + device = torch.device("cpu") + model = GPTQ_CUSTOM("./model_custom.pt") + model = model.to(device) + quantizers = quantize_gptq_custom(model, train_loader) + model_pack_custom(model, quantizers, WBITS, GROUPSIZE) + torch.save(model.state_dict(), "model_quantized_custom.pt") + print("Done Custom GPTQ") + elif args.eval_gptq_custom: + model = GPTQ_CUSTOM("./model_custom.pt") + device = torch.device("cuda:0") + model = load_quant_custom(model, "model_quantized_custom.pt", WBITS, GROUPSIZE) + model = model.to(device) + + start = time.time() + val_loss, val_acc = evaluate(device, model, criterion, train_loader) + end = time.time() + + print(f"wbits = {WBITS} using {device}") + print(f"val_loss: {val_loss:.3f} \t val_acc: {val_acc:.3f}") + print(f"Latency: {end - start}") + ## ================== MISC ================== + elif args.pyquant: + # Baseline post-training quantization from Pytorch + model = SimpleNet() + device = torch.device("cpu") + model.load_state_dict(torch.load("./model.pt")) + model.eval() + model.qconfig = 
torch.ao.quantization.get_default_qconfig('x86') + model_prepared = torch.ao.quantization.prepare(model) + + for inputs, labels in train_loader: + inputs, labels = inputs.to(device), labels.to(device) + outputs = model_prepared.forward_pyquant(inputs) + + model_quant = torch.ao.quantization.convert(model_prepared) + + start_q = time.time() + val_loss_q, val_acc_q = evaluate(device, model_quant, criterion, train_loader, is_pyquant=True) + end_q = time.time() + + print("Pytorch post-training quantization INT8") + print(model_quant) + print(f"val_loss_q: {val_loss_q:.3f} \t val_acc_q:{val_acc_q:.3f}") + print(f"Latency: {end_q - start_q}") + else: + # Evaluate float 32 + model = SimpleNet() + device = torch.device("cpu") + model.load_state_dict(torch.load("./model.pt", map_location="cpu")) + model = model.to(device) + + # Evaluate float 32 + start = time.time() + val_loss, val_acc = evaluate(device, model, criterion, train_loader) + end = time.time() + + print("Floating point FP32") + print(f"val_loss: {val_loss:.3f} \t val_acc: {val_acc:.3f}") + print(f"Latency: {end - start}") diff --git a/quantize/gptq/sanity_check_utils.py b/quantize/gptq/sanity_check_utils.py new file mode 100644 index 00000000..22e4a1f3 --- /dev/null +++ b/quantize/gptq/sanity_check_utils.py @@ -0,0 +1,317 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +import os +import random +import numpy as np +import torch +from torch.utils.data import DataLoader, random_split +from torchvision import datasets, transforms +import math +import struct + +def seed_everything(seed: int): + random.seed(seed) + os.environ["PYTHONHASHSEED"] = str(seed) + np.random.seed(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + torch.backends.cudnn.deterministic = True + torch.backends.cudnn.benchmark = True + +# Model +class SimpleNet(nn.Module): + def __init__(self, num_classes=10): + super(SimpleNet, self).__init__() + seed_everything(42) + + self.N = 32 * 32 + self.linear1 = nn.Linear(in_features=self.N, out_features=self.N) + self.linear2 = nn.Linear(in_features=self.N, out_features=self.N) + self.linear3 = nn.Linear(in_features=self.N, out_features=self.N) + self.linear4 = nn.Linear(in_features=self.N, out_features=num_classes) + + self.quant = torch.ao.quantization.QuantStub() + self.dequant = torch.ao.quantization.DeQuantStub() + self.skip_add = nn.quantized.FloatFunctional() + + def forward(self, x): + if len(x.shape) == 4: + x = x.view(x.size(0), -1) + + residual = x + x = F.relu(self.linear1(x)) + x = self.linear2(x) + x = F.relu(x) + residual + x = self.linear3(x) + x = F.relu(x) + residual + x = self.linear4(x) + return x + + def forward_pyquant(self, x): + + if len(x.shape) == 4: + x = x.view(x.size(0), -1) + + x = self.quant(x) + + residual = x + + x = F.relu(self.linear1(x)) + x = self.linear2(x) + x = self.skip_add.add(F.relu(x), residual) + x = self.linear3(x) + x = self.skip_add.add(F.relu(x), residual) + x = self.linear4(x) + + x = self.dequant(x) + + return x + +class SimpleNet_V2(nn.Module): + def __init__(self, num_classes=10): + super(SimpleNet_V2, self).__init__() + seed_everything(42) + self.N = 32 * 32 + + self.linear0_w = nn.Parameter(torch.nn.init.kaiming_uniform_(torch.empty(self.N, self.N), a=math.sqrt(5))) + fan_in, _ = torch.nn.init._calculate_fan_in_and_fan_out(self.linear0_w) + bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0 + self.linear0_b = nn.Parameter(torch.nn.init.uniform_(torch.empty(self.N), -bound, bound)) + + self.linear1_w = 
nn.Parameter(torch.nn.init.kaiming_uniform_(torch.empty(self.N, self.N), a=math.sqrt(5))) + fan_in, _ = torch.nn.init._calculate_fan_in_and_fan_out(self.linear1_w) + bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0 + self.linear1_b = nn.Parameter(torch.nn.init.uniform_(torch.empty(self.N), -bound, bound)) + + self.linear2_w = nn.Parameter(torch.nn.init.kaiming_uniform_(torch.empty(self.N, self.N), a=math.sqrt(5))) + fan_in, _ = torch.nn.init._calculate_fan_in_and_fan_out(self.linear2_w) + bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0 + self.linear2_b = nn.Parameter(torch.nn.init.uniform_(torch.empty(self.N), -bound, bound)) + + self.linear3_w = nn.Parameter(torch.nn.init.kaiming_uniform_(torch.empty(num_classes, self.N), a=math.sqrt(5))) + fan_in, _ = torch.nn.init._calculate_fan_in_and_fan_out(self.linear3_w) + bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0 + self.linear3_b = nn.Parameter(torch.nn.init.uniform_(torch.empty(num_classes), -bound, bound)) + + self.w = {} + self.nb_layers = 4 + + def my_linear(self, x, weight, bias): + # return x @ weight.t() + bias. + # Although this is the same, they yield different results as here: https://discuss.pytorch.org/t/differences-between-implementations/129237 + return F.linear(x, weight, bias) + + def forward(self, x): + if len(x.shape) == 4: + x = x.view(x.size(0), -1) + + residual = x + x = F.relu(self.my_linear(x, self.linear0_w, self.linear0_b)) + x = self.my_linear(x, self.linear1_w, self.linear1_b) + x = F.relu(x) + residual + x = self.my_linear(x, self.linear2_w, self.linear2_b) + x = F.relu(x) + residual + x = self.my_linear(x, self.linear3_w, self.linear3_b) + return x + +# Dataset + +class MNISTloader: + def __init__( + self, + batch_size: int = 100, + data_dir: str = "./data/", + num_workers: int = 0, + pin_memory: bool = False, + shuffle: bool = False, + train_val_split: float = 0.1, + ): + self.batch_size = batch_size + self.data_dir = data_dir + self.num_workers = num_workers + self.pin_memory = pin_memory + self.shuffle = shuffle + self.train_val_split = train_val_split + + self.setup() + + def setup(self): + transform = transforms.Compose( + [ + transforms.Resize((32, 32)), + transforms.ToTensor(), + transforms.Normalize(mean=[0.5], std=[0.5]), + ] + ) + + self.train_dataset = datasets.MNIST( + self.data_dir, train=True, download=True, transform=transform + ) + val_split = int(len(self.train_dataset) * self.train_val_split) + train_split = len(self.train_dataset) - val_split + + self.train_dataset, self.val_dataset = random_split( + self.train_dataset, [train_split, val_split] + ) + self.test_dataset = datasets.MNIST( + self.data_dir, train=False, download=True, transform=transform + ) + + print( + "Image Shape: {}".format(self.train_dataset[0][0].numpy().shape), + end="\n\n", + ) + print("Training Set: {} samples".format(len(self.train_dataset))) + print("Validation Set: {} samples".format(len(self.val_dataset))) + print("Test Set: {} samples".format(len(self.test_dataset))) + + def load(self): + train_loader = DataLoader( + dataset=self.train_dataset, + batch_size=self.batch_size, + num_workers=self.num_workers, + pin_memory=self.pin_memory, + shuffle=self.shuffle, + ) + + val_loader = DataLoader( + dataset=self.val_dataset, + batch_size=self.batch_size, + num_workers=self.num_workers, + pin_memory=self.pin_memory, + shuffle=self.shuffle, + ) + + test_loader = DataLoader( + dataset=self.test_dataset, + batch_size=self.batch_size, + num_workers=self.num_workers, + pin_memory=self.pin_memory, + shuffle=self.shuffle, + 
) + + return train_loader, val_loader, test_loader + +# Train + evaluate +def evaluate(device, model, criterion, val_loader, is_pyquant=False): + + val_loss_running, val_acc_running = 0, 0 + + model.eval().cuda() if (device.type == "cuda") else model.eval().cpu() + + with torch.no_grad(): + for inputs, labels in val_loader: + inputs, labels = inputs.to(device), labels.to(device) + if is_pyquant: + outputs = model.forward_pyquant(inputs) + else: + outputs = model(inputs) + loss = criterion(outputs, labels) + _, predictions = torch.max(outputs, dim=1) + val_loss_running += loss.item() * inputs.shape[0] + val_acc_running += torch.sum(predictions == labels.data) + + val_loss = val_loss_running / len(val_loader.sampler) + val_acc = val_acc_running / len(val_loader.sampler) + + return val_loss, val_acc + + +def train(num_epochs, model, optimizer, criterion, train_loader, device): + + model.train().cuda() if (device.type == "cuda") else model.train().cpu() + + for epoch in range(num_epochs): + + train_loss_running, train_acc_running = 0, 0 + + for inputs, labels in train_loader: + + inputs, labels = inputs.to(device), labels.to(device) + + optimizer.zero_grad() + + outputs = model(inputs) + + _, predictions = torch.max(outputs, dim=1) + loss = criterion(outputs, labels) + + loss.backward() + optimizer.step() + + train_loss_running += loss.item() * inputs.shape[0] + train_acc_running += torch.sum(predictions == labels.data) + + train_loss = train_loss_running / len(train_loader.sampler) + train_acc = train_acc_running / len(train_loader.sampler) + + info = "Epoch: {:3}/{} \t train_loss: {:.3f} \t train_acc: {:.3f}" + print(info.format(epoch + 1, num_epochs, train_loss, train_acc)) + +def write_bin(filename, array): + from functools import reduce + # Force endianess: https://stackoverflow.com/questions/23831422/what-endianness-does-python-use-to-write-into-files + dtype_to_format = { + np.int8: 'i', + np.int16: 'i', + np.int32: 'i', + np.int64: 'i', + np.unsignedinteger: 'I', + np.float16: 'f', + np.float32: 'f', + np.float64: 'f', + np.double: 'd' + } + fmt = dtype_to_format[array.dtype.type] + shapes = [shape for shape in array.shape] + # n, c, h, w = array.shape + with open(filename, "wb") as f: + # number of dim + f.write(struct.pack('I', len(shapes))) + for shape in shapes: + f.write(struct.pack('I', shape)) + f.write(struct.pack('c', bytes(fmt, 'utf-8'))) + f.write(struct.pack(f"{fmt}"*(reduce(lambda x, y: x * y, shapes)), *array.flatten(order="C").tolist())) + +def read_bin(filename): + # https://qiita.com/madaikiteruyo/items/dadc99aa29f7eae0cdd0 + format_to_byte = { + 'c': 1, + 'i': 4, + 'I': 4, + 'f': 4, + 'd': 8 + } + + data = [] + dims, fmt = None, None + with open(filename, "rb") as f: + # read row and col (np.int = 4 bytes) + byte = f.read(format_to_byte['i']) + + if byte == b'': + raise Exception("read_bin: Empty binary") + else: + nb_dim = struct.unpack('I', byte) + + # Read dims + byte = f.read(nb_dim[0] * format_to_byte['I']) + dims = struct.unpack('I'*nb_dim[0], byte) + # Read character format + byte = f.read(1) + if byte == b'': + raise Exception("read_bin: Empty binary") + else: + fmt = chr(struct.unpack('c', byte)[0][0]) + + if len(fmt) != 1: raise Exception("read_bin: No format dumped in binary") + + while True: + byte = f.read(format_to_byte[fmt]) + if byte == b'': + break + else: + data.append(struct.unpack(fmt, byte)[0]) + + return np.array(data).reshape(*dims) \ No newline at end of file diff --git a/quantize/gptq/setup_cuda.py b/quantize/gptq/setup_cuda.py new file mode 
100644 index 00000000..25231709 --- /dev/null +++ b/quantize/gptq/setup_cuda.py @@ -0,0 +1,14 @@ +from setuptools import setup, Extension +from torch.utils import cpp_extension + +# get current path +import os +current_path = os.path.dirname(os.path.abspath(__file__)) + +setup( + name='quant_cuda', + ext_modules=[cpp_extension.CUDAExtension( + 'quant_cuda', [f'{current_path}/quant_cuda.cpp', f'{current_path}/quant_cuda_kernel.cu'] + )], + cmdclass={'build_ext': cpp_extension.BuildExtension} +) \ No newline at end of file diff --git a/quantize/measure_perplexity.py b/quantize/measure_perplexity.py new file mode 100644 index 00000000..c1759f15 --- /dev/null +++ b/quantize/measure_perplexity.py @@ -0,0 +1,71 @@ +import torch +from tqdm import tqdm +from transformers import AutoModelForCausalLM, AutoTokenizer +from datasets import load_dataset +from myRWKV import RWKV +# from rwkv.model import RWKV +import torch.nn as nn + +device = "cpu" + +# Model +model = RWKV("./RWKV-4-Pile-169M-20220807-8023.pth", strategy='cpu fp32') +# model = RWKV("./1sample_quantized.pth", strategy='cpu fp32') +# Dataset +tokenizer = AutoTokenizer.from_pretrained("RWKV/rwkv-4-169m-pile") +# sentence = "My name is Bob" +# encodings = tokenizer("\n\n".join(sentence), return_tensors='pt') +# ctx_len = 5 +# stride = ctx_len // 2 +# seq_len = encodings.input_ids.size(1) + +test = load_dataset("wikitext", "wikitext-2-raw-v1", split="test") +tokenizer = AutoTokenizer.from_pretrained("RWKV/rwkv-4-169m-pile") +encodings = tokenizer("\n\n".join(test["text"]), return_tensors="pt") +ctx_len = 1024 +stride = ctx_len // 2 +seq_len = encodings.input_ids.size(1) + +nlls = [] +prev_end_loc = 0 +logits, state = None, None +loss_fct = nn.CrossEntropyLoss() + +# for begin_loc in tqdm(range(0, seq_len, stride)): +for begin_loc in tqdm(range(0, stride * 3, stride)): + end_loc = min(begin_loc + ctx_len, seq_len) + trg_len = end_loc - prev_end_loc # may be different from stride on last loop + input_ids = encodings.input_ids[:, begin_loc:end_loc].to(device) + target_ids = input_ids.clone() + target_ids[:, :-trg_len] = -100 + + full_logits = torch.zeros((input_ids.size(1), model.w["emb.weight"].shape[0])) + + with torch.no_grad(): + for i in range(input_ids.size(1)): + logits, state = model.forward([input_ids[0, i]], state) + full_logits[i, :] = logits + + # loss is calculated using CrossEntropyLoss which averages over valid labels + # N.B. the model only calculates loss over trg_len - 1 labels, because it internally shifts the labels + # to the left by 1. 
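+    # For example, on the first window trg_len equals the window length, so with
+    # input_ids = [t0, t1, t2, t3] the scored pairs are (logits at t0 -> t1), (t1 -> t2), (t2 -> t3),
+    # i.e. trg_len - 1 supervised positions; on later, overlapping windows the context labels were
+    # already set to -100 above and are ignored by CrossEntropyLoss (its default ignore_index).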
+ labels = target_ids + labels = labels.to(full_logits.device) + # Shift so that tokens < n predict n + shift_logits = full_logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + # Flatten the tokens + neg_log_likelihood = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)) + + nlls.append(neg_log_likelihood) + + prev_end_loc = end_loc + if end_loc == seq_len: + break + +print(f"nlls: {torch.stack(nlls)}") +mean_nll = torch.stack(nlls).mean() +if mean_nll.is_cuda: + mean_nll = mean_nll.cpu().float() +ppl = torch.exp(mean_nll) +print(f"Perplexity: {ppl}") diff --git a/quantize/myRWKV.py b/quantize/myRWKV.py new file mode 100644 index 00000000..7e7e64b7 --- /dev/null +++ b/quantize/myRWKV.py @@ -0,0 +1,769 @@ +######################################################################################################## +# The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM +######################################################################################################## + +import types, gc, os, time, re +import torch +from torch.nn import functional as F +torch.backends.cudnn.benchmark = True +torch.backends.cudnn.allow_tf32 = True +torch.backends.cuda.matmul.allow_tf32 = True +current_path = os.path.dirname(os.path.abspath(__file__)) + +######################################################################################################## + +if os.environ.get('RWKV_JIT_ON') != '0': + os.environ["RWKV_JIT_ON"] = '1' + MyModule = torch.jit.ScriptModule + MyFunction = torch.jit.script_method + MyStatic = torch.jit.script +else: + MyModule = torch.nn.Module + def __nop(ob): + return ob + MyFunction = __nop + MyStatic = __nop + +if os.environ.get('RWKV_CUDA_ON') == '1': + from torch.utils.cpp_extension import load + load( + name=f"wkv_cuda", + sources=[f"{current_path}/cuda/wrapper.cpp", f"{current_path}/cuda/operators.cu"], + verbose=True, + extra_cuda_cflags=["-t 4", "-std=c++17", "--use_fast_math", "-O3", "--extra-device-vectorization"], + is_python_module=False) + + @MyStatic + def cuda_wkv(T: int, C: int, w, u, k, v, aa, bb, pp): + assert 1 * C % min(C, 32) == 0 + assert k.dtype == torch.float16 + w = w.contiguous() + u = u.contiguous() + k = k.contiguous() + v = v.contiguous() + y = torch.empty((T, C), device=w.device, memory_format=torch.contiguous_format, dtype=torch.float16) + torch.ops.rwkv.wkv_forward(1, T, C, w, u, k, v, y, aa, bb, pp) + return y, aa, bb, pp + @MyStatic + def cuda_mm8_seq(B: int, N: int, M: int, x, w, mx, rx, my, ry): + assert x.dtype == mx.dtype == rx.dtype == my.dtype == ry.dtype == torch.float16 + assert w.dtype == torch.uint8 + assert x.shape == [B, N] + assert w.shape == [N, M] + assert rx.shape == mx.shape == [M] + assert ry.shape == my.shape == [N, 1] + y = torch.empty((B, M), device=w.device, dtype=torch.float16) + torch.ops.rwkv.mm8_seq(B, N, M, x, w, mx, rx, my, ry, y) + return y + @MyStatic + def cuda_mm8_one(N: int, M: int, x, w, mx, rx, my, ry): + assert x.dtype == mx.dtype == rx.dtype == my.dtype == ry.dtype == torch.float16 + assert w.dtype == torch.uint8 + assert x.shape == [N] + assert w.shape == [N, M] + assert rx.shape == mx.shape == [M] + assert ry.shape == my.shape == [N, 1] + y = torch.zeros((M,), device=w.device, dtype=torch.float32) + torch.ops.rwkv.mm8_one(N, M, x, w, mx, rx, my, ry, y) + return y.to(dtype=torch.float16) +else: + os.environ["RWKV_CUDA_ON"] = '0' + +######################################################################################################## + +class 
RWKV(MyModule): + def __init__(self, model, strategy, verbose = True, convert_and_save_and_exit = None): + super().__init__() + if verbose: + prxxx = lambda *args, **kwargs: print(*args, **kwargs) + else: + prxxx = lambda *args, **kwargs: None + + STRATEGY_REGEX = r"^(?:(?:^|->) *(?:cuda(?::[\d]+)?|cpu|mps) (?:fp(?:16|32)|bf16)(?:i8|i4|i3)?(?: \*[\d]+\+?)? *)+$" + if not re.match(STRATEGY_REGEX, strategy): + raise ValueError("Invalid strategy. Please read https://pypi.org/project/rwkv/") + + strategy = ('->'.join([x.strip() for x in strategy.split('->')])).replace('->', ' -> ') + self.args = types.SimpleNamespace() + args = self.args + args.MODEL_NAME = model + args.strategy_string = strategy + + # Rescale for fp16 mode: set x = x/2 every X layer (to avoid fp16 overflow) + self.RESCALE_LAYER = 6 if 'fp16' in strategy else 0 + prxxx(f'RWKV_JIT_ON {os.environ["RWKV_JIT_ON"]} RWKV_CUDA_ON {os.environ["RWKV_CUDA_ON"]} RESCALE_LAYER {self.RESCALE_LAYER}\n') + + args.MODEL_NAME = args.MODEL_NAME.strip() + if not args.MODEL_NAME.endswith('.pth'): + args.MODEL_NAME += '.pth' + prxxx(f'Loading {args.MODEL_NAME} ...') + with torch.no_grad(): + obj = torch.load(args.MODEL_NAME, map_location='cpu') # load model to CPU first + import pdb; pdb.set_trace() + if isinstance(obj, list): # GPTQ + self.w_quant = obj[0] + self.w = obj[1] + gc.collect() + w = self.w + ALREADY_CONVERTED = True + else: + self.w_quant = {} + self.w = obj + gc.collect() + w = self.w + + ALREADY_CONVERTED = False + if '_strategy' in w: + ALREADY_CONVERTED = True + assert convert_and_save_and_exit == None # you should only convert a raw model + prxxx(f"Converted model: strategy {w['_strategy']}, version {w['_version']}\n") + assert w['_strategy'] == args.strategy_string # if you are using a new strategy, re-convert the model + assert float(w['_version']) >= 0.7 # sometimes you should re-convert using latest convert_model.py + assert w['_rescale_layer'] == self.RESCALE_LAYER + del w['_strategy'] + del w['_version'] + del w['_rescale_layer'] + + args.n_embd = w['emb.weight'].shape[1] + args.n_layer = 0 + keys = list(w.keys()) + for x in keys: + layer_id = int(x.split('.')[1]) if ('blocks.' 
in x) else 0 + args.n_layer = max(args.n_layer, layer_id+1) + + ####################### Compute strategy + + s = [x.strip().split(' ') for x in strategy.split('->')] + plan = [0] * len(s) + stream_i = -1 + stream_count = 0 + to_allocate = args.n_layer + 1 + allocated = 0 + free_slots = 0 + for i in range(len(s)): + si = s[i] + si1 = si[1] + if si1.startswith('fp32'): si[1] = [torch.float] + elif si1.startswith('fp16'): si[1] = [torch.float16] + elif si1.startswith('bf16'): si[1] = [torch.bfloat16] + if si1.endswith('i8'): si[1] += [torch.uint8] + else: si[1] += [si[1][0]] + if len(si) > 2: + ss = si[2] + assert ss.startswith('*') + if ss.endswith('+'): + plan[i] = int(ss[1:-1]) + stream_i = i + else: + plan[i] = int(ss[1:]) + allocated += plan[i] + if allocated >= to_allocate: + plan[i] += to_allocate - allocated + break + else: + free_slots += 1 + if stream_i < 0: + if free_slots > 0 and to_allocate > allocated: + for i in range(len(s)): + if plan[i] == 0: + plan[i] = (to_allocate - allocated) // free_slots + allocated += plan[i] + free_slots -= 1 + if to_allocate > allocated: + plan[len(s)-1] += to_allocate - allocated + else: + if to_allocate > allocated: + stream_count = to_allocate - allocated + plan[stream_i] += stream_count + prxxx(f'Strategy: (total {args.n_layer}+1={args.n_layer+1} layers)') + for i in range(len(s)): + ss = s[i] + if i != stream_i: + prxxx(f'* {ss[0]} {str(ss[1]).replace("torch.","")}, store {plan[i]} layers') + else: + prxxx(f'* {ss[0]} {str(ss[1]).replace("torch.","")}, store {plan[i]-stream_count} layers, stream {stream_count} layers') + plan[i] += (0 if i == 0 else plan[i-1]) + self.strategy = [None] * (args.n_layer + 1) + strategy = self.strategy + + for n in range(args.n_layer + 1): + for i in range(len(s)): + if n < plan[i]: + strategy[n] = types.SimpleNamespace() + strategy[n].device = s[i][0] + strategy[n].atype = s[i][1][0] + strategy[n].wtype = s[i][1][1] + strategy[n].stream = False + if i == stream_i and n >= (plan[i] - stream_count): + strategy[n].stream = True + break + prxxx(f"{n}-{strategy[n].device}-{str(strategy[n].atype).replace('torch.','')}-{str(strategy[n].wtype).replace('torch.','')}{'-stream' if strategy[n].stream else ''}",end=' ') + prxxx() + + ####################### Load weights to self.w + if not ALREADY_CONVERTED: + try: # precompute embedding + w['emb.weight'] = F.layer_norm(w['emb.weight'], (args.n_embd,), weight=w['blocks.0.ln0.weight'], bias=w['blocks.0.ln0.bias']) + except: + w['emb.weight'] = F.layer_norm(w['emb.weight'].float(), (args.n_embd,), weight=w['blocks.0.ln0.weight'].float(), bias=w['blocks.0.ln0.bias'].float()) + del w['blocks.0.ln0.weight'] + del w['blocks.0.ln0.bias'] + + print_need_newline = False + keys = list(w.keys()) + for x in keys: + w[x].requires_grad = False + layer_id = int(x.split('.')[1]) if ('blocks.' in x) else 0 + if ('ln_out.' in x) or ('head.' 
in x): + layer_id = args.n_layer + dd = strategy[layer_id] + DEVICE = dd.device + ATYPE = dd.atype + WTYPE = dd.wtype + + if not ALREADY_CONVERTED: + if self.RESCALE_LAYER > 0: + if 'att.output.weight' in x: + w[x] = w[x] / (2 ** int(layer_id // self.RESCALE_LAYER)) + if 'ffn.value.weight' in x: + w[x] = w[x] / (2 ** int(layer_id // self.RESCALE_LAYER)) + + if '.time_' in x: + w[x] = w[x].squeeze() + if 'key.weight' in x or 'value.weight' in x or 'receptance.weight' in x or 'output.weight' in x or 'head.weight' in x: + w[x] = w[x].t() + + if '.time_decay' in x: # need fp32 for this + w[x] = -torch.exp(w[x].float()) + elif '.time_first' in x: # need fp32 for this + w[x] = w[x].float() + else: + if (len(w[x].shape) == 2) and ('emb' not in x): + if WTYPE != torch.uint8: + w[x] = w[x].to(dtype=WTYPE) + else: + w[x] = w[x].float() + + if w[x].shape[0] > w[x].shape[1]: + w[x+'_my'] = torch.amin(w[x], dim=1).unsqueeze(1) + w[x] = w[x] - w[x+'_my'] + w[x+'_mx'] = torch.amin(w[x], dim=0) + w[x] = w[x] - w[x+'_mx'] + w[x+'_rx'] = torch.amax(w[x], dim=0) + w[x] = w[x] / w[x+'_rx'] + w[x+'_ry'] = torch.amax(w[x], dim=1).unsqueeze(1) + w[x] = w[x] / w[x+'_ry'] + else: + w[x+'_mx'] = torch.amin(w[x], dim=0) + w[x] = w[x] - w[x+'_mx'] + w[x+'_my'] = torch.amin(w[x], dim=1).unsqueeze(1) + w[x] = w[x] - w[x+'_my'] + w[x+'_rx'] = torch.amax(w[x], dim=0) + w[x] = w[x] / w[x+'_rx'] + w[x+'_ry'] = torch.amax(w[x], dim=1).unsqueeze(1) + w[x] = w[x] / w[x+'_ry'] + + w[x] = torch.clip(torch.floor(w[x] * 256), min=0, max=255).to(dtype=torch.uint8) + w[x+'_mx'] = w[x+'_mx'].to(dtype=ATYPE).contiguous() + w[x+'_rx'] = (w[x+'_rx'] / 16).to(dtype=ATYPE).contiguous() + w[x+'_my'] = w[x+'_my'].to(dtype=ATYPE).contiguous() + w[x+'_ry'] = (w[x+'_ry'] / 16).to(dtype=ATYPE).contiguous() + else: + w[x] = w[x].to(dtype=ATYPE) + + # GPTQ + if self.w_quant != {}: + if (len(w[x].shape) == 2) and ('emb' not in x): + if WTYPE != torch.uint8: + w[x] = w[x].to(dtype=WTYPE) + else: + w[x+'_mx'] = w[x+'_mx'].to(dtype=ATYPE).contiguous() + w[x+'_rx'] = (w[x+'_rx'] / 16).to(dtype=ATYPE).contiguous() + w[x+'_my'] = w[x+'_my'].to(dtype=ATYPE).contiguous() + w[x+'_ry'] = (w[x+'_ry'] / 16).to(dtype=ATYPE).contiguous() + else: + w[x] = w[x].to(dtype=ATYPE) + + if convert_and_save_and_exit == None: + if 'emb.' in x: + w[x] = w[x].contiguous() + elif (dd.stream) and (x.endswith('key.weight') or x.endswith('value.weight') or x.endswith('receptance.weight') or x.endswith('output.weight')): + try: + w[x] = w[x].contiguous().pin_memory() # if you see "CUDA error: out of memory" here, that's out of CPU RAM, not VRAM. Get more RAM :) + except: + print('Note: You are running out of RAM. Get more CPU RAM. 
Now this will run much slower.') + elif DEVICE != 'cpu': + w[x] = w[x].to(device=DEVICE).contiguous() + + if (dd.stream) or (DEVICE != 'cpu'): + try: + w[x+'_mx'] = w[x+'_mx'].to(device=DEVICE).contiguous() + w[x+'_rx'] = w[x+'_rx'].to(device=DEVICE).contiguous() + w[x+'_my'] = w[x+'_my'].to(device=DEVICE).contiguous() + w[x+'_ry'] = w[x+'_ry'].to(device=DEVICE).contiguous() + except: + pass + + if 'ffn.value.weight' in x: + gc.collect() + if 'cuda' in args.strategy_string: + torch.cuda.empty_cache() + + shape = [i for i in w[x].shape if i != 1] + if len(shape) > 1: + shape = f" {str(shape[0]).rjust(5)} {str(shape[1]).rjust(5)}" + else: + shape = f" {str(shape[0]).rjust(5)} " + if layer_id == 0 or layer_id >= args.n_layer-1: + if print_need_newline: + prxxx('\n', end = '') + print_need_newline = False + dt = str(w[x].dtype).replace('torch.', '') + dt = dt.replace('float32', 'f32').replace('bfloat16', 'bf16').replace('float16', 'f16').replace('uint8', 'i8') + prxxx(x.ljust(32), dt.rjust(4), str(w[x].device).rjust(8), shape, ' (pinned)' if w[x].is_pinned() else '') + else: + print_need_newline = True + prxxx('.', end = '', flush = True) + + if convert_and_save_and_exit: + w['_strategy'] = args.strategy_string + w['_rescale_layer'] = self.RESCALE_LAYER + w['_version'] = '0.7' + if not convert_and_save_and_exit.endswith('.pth'): + convert_and_save_and_exit += '.pth' + prxxx(f'Saving to {convert_and_save_and_exit}...') + torch.save(w, convert_and_save_and_exit) + prxxx(f'Converted and saved. Now this will exit.') + exit(0) + + gc.collect() + if 'cuda' in args.strategy_string: + torch.cuda.empty_cache() + + if os.environ.get('RWKV_CUDA_ON') == '1': + @MyFunction + def mm8_seq(self, x, w, mx, rx, my, ry): + B, N, M = x.shape[0], w.shape[0], w.shape[1] + return cuda_mm8_seq(B, N, M, x, w, mx, rx, my, ry) + @MyFunction + def mm8_one(self, x, w, mx, rx, my, ry): + N, M = w.shape[0], w.shape[1] + return cuda_mm8_one(N, M, x, w, mx, rx, my, ry) + else: + # @MyFunction + def mm8_seq(self, x, w, mx, rx, my, ry): + return x @ ((w.to(dtype=x.dtype) + 0.5) * ry * rx + my + mx) + + # @MyFunction + def mm8_one(self, x, w, mx, rx, my, ry): + return x @ ((w.to(dtype=x.dtype) + 0.5) * ry * rx + my + mx) + + ######################################################################################################## + + # @MyFunction + def ffn_one(self, x, sx, ln_w, ln_b, k_mix, r_mix, kw, vw, rw, kmx, krx, kmy, kry, vmx, vrx, vmy, vry, rmx, rrx, rmy, rry): + xx = F.layer_norm(x, (x.shape[-1],), weight=ln_w, bias=ln_b) + kx = xx * k_mix + sx * (1 - k_mix) + rx = xx * r_mix + sx * (1 - r_mix) + + # r = torch.sigmoid(rx @ rw) + # vx = torch.square(torch.relu(kx @ kw)) + # out = r * (vx @ vw) + r = torch.sigmoid(self._trigger_gptq(rx, weight=rw[0], name=rw[1])) + vx = torch.square(torch.relu(self._trigger_gptq(kx, weight=kw[0], name=kw[1]))) + out = r * (self._trigger_gptq(vx, weight=vw[0], name=vw[1])) + return x + out, xx + + # @MyFunction + def ffn_one_i8(self, x, sx, ln_w, ln_b, k_mix, r_mix, kw, vw, rw, kmx, krx, kmy, kry, vmx, vrx, vmy, vry, rmx, rrx, rmy, rry): + xx = F.layer_norm(x, (x.shape[-1],), weight=ln_w, bias=ln_b) + kx = xx * k_mix + sx * (1 - k_mix) + rx = xx * r_mix + sx * (1 - r_mix) + + r = torch.sigmoid(self.mm8_one(rx, rw, rmx, rrx, rmy, rry)) + vx = torch.square(torch.relu(self.mm8_one(kx, kw, kmx, krx, kmy, kry))) + out = r * (self.mm8_one(vx, vw, vmx, vrx, vmy, vry)) + return x + out, xx + + ######################################################################################################## 
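+    # Note on the GPTQ path used by the *_one / *_seq kernels: forward() passes every projection
+    # weight as a (tensor, name) pair, for example
+    #     kw = w['blocks.0.ffn.key.weight'], 'blocks.0.ffn.key.weight'
+    # (layer 0 is just an illustration), so _trigger_gptq(x, weight=kw[0], name=kw[1]) can look the
+    # name up in self.w_quant and dequantize on the fly, falling back to a plain x @ weight when no
+    # quantized entry exists for that name (see _trigger_gptq further below).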
+ + # @MyFunction + def ffn_seq(self, x, sx, ln_w, ln_b, k_mix, r_mix, kw, vw, rw, kmx, krx, kmy, kry, vmx, vrx, vmy, vry, rmx, rrx, rmy, rry): + xx = F.layer_norm(x, (x.shape[-1],), weight=ln_w, bias=ln_b) + sx = torch.cat((sx.unsqueeze(0), xx[:-1,:])) + kx = xx * k_mix + sx * (1 - k_mix) + rx = xx * r_mix + sx * (1 - r_mix) + + # r = torch.sigmoid(rx @ rw) + r = torch.sigmoid(self._trigger_gptq(rx, weight=rw[0], name=rw[1])) + # vx = torch.square(torch.relu(kx @ kw)) + vx = torch.square(torch.relu(self._trigger_gptq(kx, weight=kw[0], name=kw[1]))) + # out = r * (vx @ vw) + out = r * self._trigger_gptq(vx, weight=vw[0], name=vw[1]) + return x + out, xx[-1,:] + + # @MyFunction + def ffn_seq_i8(self, x, sx, ln_w, ln_b, k_mix, r_mix, kw, vw, rw, kmx, krx, kmy, kry, vmx, vrx, vmy, vry, rmx, rrx, rmy, rry): + xx = F.layer_norm(x, (x.shape[-1],), weight=ln_w, bias=ln_b) + sx = torch.cat((sx.unsqueeze(0), xx[:-1,:])) + kx = xx * k_mix + sx * (1 - k_mix) + rx = xx * r_mix + sx * (1 - r_mix) + + r = torch.sigmoid(self.mm8_seq(rx, rw, rmx, rrx, rmy, rry)) + vx = torch.square(torch.relu(self.mm8_seq(kx, kw, kmx, krx, kmy, kry))) + out = r * (self.mm8_seq(vx, vw, vmx, vrx, vmy, vry)) + return x + out, xx[-1,:] + + ######################################################################################################## + + # @MyFunction + def att_one(self, x, sx, aa, bb, pp, ln_w, ln_b, k_mix, v_mix, r_mix, t_decay, t_first, kw, vw, rw, ow, kmx, krx, kmy, kry, vmx, vrx, vmy, vry, rmx, rrx, rmy, rry, omx, orx, omy, ory): + xx = F.layer_norm(x, (x.shape[-1],), weight=ln_w, bias=ln_b) + kx = xx * k_mix + sx * (1 - k_mix) + vx = xx * v_mix + sx * (1 - v_mix) + rx = xx * r_mix + sx * (1 - r_mix) + + # r = torch.sigmoid(rx @ rw) + # k = (kx @ kw).float() + # v = (vx @ vw).float() + + r = torch.sigmoid(self._trigger_gptq(rx, weight=rw[0], name=rw[1])) + k = self._trigger_gptq(kx, weight=kw[0], name=kw[1]).float() + v = self._trigger_gptq(vx, weight=vw[0], name=vw[1]).float() + + ww = t_first + k + p = torch.maximum(pp, ww) + e1 = torch.exp(pp - p) + e2 = torch.exp(ww - p) + wkv = ((e1 * aa + e2 * v) / (e1 * bb + e2)).to(dtype=x.dtype) + ww = t_decay + pp + p = torch.maximum(ww, k) + e1 = torch.exp(ww - p) + e2 = torch.exp(k - p) + + # out = (r * wkv) @ ow + out = self._trigger_gptq(r * wkv, weight=ow[0], name=ow[1]) + return x + out, xx, e1 * aa + e2 * v, e1 * bb + e2, p + + # @MyFunction + def att_one_i8(self, x, sx, aa, bb, pp, ln_w, ln_b, k_mix, v_mix, r_mix, t_decay, t_first, kw, vw, rw, ow, kmx, krx, kmy, kry, vmx, vrx, vmy, vry, rmx, rrx, rmy, rry, omx, orx, omy, ory): + xx = F.layer_norm(x, (x.shape[-1],), weight=ln_w, bias=ln_b) + kx = xx * k_mix + sx * (1 - k_mix) + vx = xx * v_mix + sx * (1 - v_mix) + rx = xx * r_mix + sx * (1 - r_mix) + + r = torch.sigmoid(self.mm8_one(rx, rw, rmx, rrx, rmy, rry)) + k = (self.mm8_one(kx, kw, kmx, krx, kmy, kry)).float() + v = (self.mm8_one(vx, vw, vmx, vrx, vmy, vry)).float() + + ww = t_first + k + p = torch.maximum(pp, ww) + e1 = torch.exp(pp - p) + e2 = torch.exp(ww - p) + wkv = ((e1 * aa + e2 * v) / (e1 * bb + e2)).to(dtype=x.dtype) + ww = t_decay + pp + p = torch.maximum(ww, k) + e1 = torch.exp(ww - p) + e2 = torch.exp(k - p) + + out = self.mm8_one(r * wkv, ow, omx, orx, omy, ory) + return x + out, xx, e1 * aa + e2 * v, e1 * bb + e2, p + + ######################################################################################################## + + # @MyFunction + def att_seq(self, x, sx, aa, bb, pp, ln_w, ln_b, k_mix, v_mix, r_mix, t_decay, t_first, kw, vw, rw, 
ow, kmx, krx, kmy, kry, vmx, vrx, vmy, vry, rmx, rrx, rmy, rry, omx, orx, omy, ory):
+        xx = F.layer_norm(x, (x.shape[-1],), weight=ln_w, bias=ln_b)
+        sx = torch.cat((sx.unsqueeze(0), xx[:-1,:]))
+        kx = xx * k_mix + sx * (1 - k_mix)
+        vx = xx * v_mix + sx * (1 - v_mix)
+        rx = xx * r_mix + sx * (1 - r_mix)
+
+        # r = torch.sigmoid(rx @ rw)
+        # k = (kx @ kw).float()
+        # v = (vx @ vw).float()
+        r = torch.sigmoid(self._trigger_gptq(rx, weight=rw[0], name=rw[1]))
+        k = (self._trigger_gptq(kx, weight=kw[0], name=kw[1])).float()
+        v = (self._trigger_gptq(vx, weight=vw[0], name=vw[1])).float()
+
+        T = x.shape[0]
+        for t in range(T):
+            kk = k[t]
+            vv = v[t]
+            ww = t_first + kk
+            p = torch.maximum(pp, ww)
+            e1 = torch.exp(pp - p)
+            e2 = torch.exp(ww - p)
+            sx[t] = ((e1 * aa + e2 * vv) / (e1 * bb + e2)).to(dtype=x.dtype)
+            ww = t_decay + pp
+            p = torch.maximum(ww, kk)
+            e1 = torch.exp(ww - p)
+            e2 = torch.exp(kk - p)
+            aa = e1 * aa + e2 * vv
+            bb = e1 * bb + e2
+            pp = p
+        # out = (r * sx) @ ow
+        out = self._trigger_gptq(r * sx, weight=ow[0], name=ow[1])
+        return x + out, xx[-1,:], aa, bb, pp
+
+    # @MyFunction
+    def att_seq_i8(self, x, sx, aa, bb, pp, ln_w, ln_b, k_mix, v_mix, r_mix, t_decay, t_first, kw, vw, rw, ow, kmx, krx, kmy, kry, vmx, vrx, vmy, vry, rmx, rrx, rmy, rry, omx, orx, omy, ory):
+        xx = F.layer_norm(x, (x.shape[-1],), weight=ln_w, bias=ln_b)
+        sx = torch.cat((sx.unsqueeze(0), xx[:-1,:]))
+        kx = xx * k_mix + sx * (1 - k_mix)
+        vx = xx * v_mix + sx * (1 - v_mix)
+        rx = xx * r_mix + sx * (1 - r_mix)
+
+        r = torch.sigmoid(self.mm8_seq(rx, rw, rmx, rrx, rmy, rry))
+        k = self.mm8_seq(kx, kw, kmx, krx, kmy, kry).float()
+        v = self.mm8_seq(vx, vw, vmx, vrx, vmy, vry).float()
+
+        T = x.shape[0]
+        for t in range(T):
+            kk = k[t]
+            vv = v[t]
+            ww = t_first + kk
+            p = torch.maximum(pp, ww)
+            e1 = torch.exp(pp - p)
+            e2 = torch.exp(ww - p)
+            sx[t] = ((e1 * aa + e2 * vv) / (e1 * bb + e2)).to(dtype=x.dtype)
+            ww = t_decay + pp
+            p = torch.maximum(ww, kk)
+            e1 = torch.exp(ww - p)
+            e2 = torch.exp(kk - p)
+            aa = e1 * aa + e2 * vv
+            bb = e1 * bb + e2
+            pp = p
+        out = self.mm8_seq(r * sx, ow, omx, orx, omy, ory)
+        return x + out, xx[-1,:], aa, bb, pp
+
+    ########################################################################################################
+
+    if os.environ["RWKV_CUDA_ON"] == '1':
+        @MyFunction
+        def cuda_att_seq(self, x, sx, aa, bb, pp, ln_w, ln_b, k_mix, v_mix, r_mix, t_decay, t_first, kw, vw, rw, ow, kmx, krx, kmy, kry, vmx, vrx, vmy, vry, rmx, rrx, rmy, rry, omx, orx, omy, ory):
+            T, C = x.size()
+            xx = F.layer_norm(x, (C,), weight=ln_w, bias=ln_b)
+            sx = torch.cat((sx.unsqueeze(0), xx[:-1,:]))
+            kx = xx * k_mix + sx * (1 - k_mix)
+            vx = xx * v_mix + sx * (1 - v_mix)
+            rx = xx * r_mix + sx * (1 - r_mix)
+
+            r = torch.sigmoid(rx @ rw)
+            k = kx @ kw
+            v = vx @ vw
+            y, aa, bb, pp = cuda_wkv(T, C, t_decay, t_first, k, v, aa, bb, pp)
+
+            out = (r * y) @ ow
+            return x + out, xx[-1,:], aa, bb, pp
+
+        @MyFunction
+        def cuda_att_seq_i8(self, x, sx, aa, bb, pp, ln_w, ln_b, k_mix, v_mix, r_mix, t_decay, t_first, kw, vw, rw, ow, kmx, krx, kmy, kry, vmx, vrx, vmy, vry, rmx, rrx, rmy, rry, omx, orx, omy, ory):
+            T, C = x.size()
+            xx = F.layer_norm(x, (C,), weight=ln_w, bias=ln_b)
+            sx = torch.cat((sx.unsqueeze(0), xx[:-1,:]))
+            kx = xx * k_mix + sx * (1 - k_mix)
+            vx = xx * v_mix + sx * (1 - v_mix)
+            rx = xx * r_mix + sx * (1 - r_mix)
+
+            r = torch.sigmoid(self.mm8_seq(rx, rw, rmx, rrx, rmy, rry))
+            k = self.mm8_seq(kx, kw, kmx, krx, kmy, kry)
+            v = self.mm8_seq(vx, vw, vmx, vrx, vmy, vry)
+            y, aa,
bb, pp = cuda_wkv(T, C, t_decay, t_first, k, v, aa, bb, pp) + + out = self.mm8_seq(r * y, ow, omx, orx, omy, ory) + return x + out, xx[-1,:], aa, bb, pp + + ######################################################################################################## + + def _trigger_gptq(self, x, weight, name): + + # GPTQ + if name in self.w_quant.keys(): + w_quant = self.w_quant[name] + # Only work on CUDA because of the use of fp16 (not available on CPU) + if w_quant.bits in [2, 4, 8]: + w_quant.scales = w_quant.scales.to(x.device) + out_shape = x.shape[:-1] + (w_quant.outfeatures, ) + x = x.reshape(-1,x.shape[-1]) + + zeros = torch.bitwise_right_shift(torch.unsqueeze(w_quant.qzeros, 2).expand(-1, -1, 32 // w_quant.bits), w_quant.wf.unsqueeze(0)).to(x.device, torch.int16 if w_quant.bits == 8 else torch.int8) + torch.bitwise_and(zeros, (2 ** w_quant.bits) - 1, out=zeros) + + zeros = zeros + 1 + zeros = zeros.reshape(w_quant.scales.shape) + + new_weight = torch.bitwise_right_shift(torch.unsqueeze(w_quant.qweight, 1).expand(-1, 32 // w_quant.bits, -1), w_quant.wf.unsqueeze(-1)).to(x.device, torch.int16 if w_quant.bits == 8 else torch.int8) + torch.bitwise_and(new_weight,(2 ** w_quant.bits) - 1, out=new_weight) + + new_weight = new_weight.reshape(new_weight.shape[0] * new_weight.shape[1], new_weight.shape[2]) + num_itr = w_quant.g_idx.shape[0]//x.shape[-1] + if num_itr == 1: + weights = (w_quant.scales[w_quant.g_idx.long()] * (new_weight - zeros[w_quant.g_idx.long()])) + else: + num_dim = w_quant.g_idx.shape[0]//num_itr + weights = [] + for i in range(num_itr): + scale_i = w_quant.scales[:,i*num_dim:(i+1)*num_dim] + weight_i = new_weight[:,i*num_dim:(i+1)*num_dim] + zeros_i = zeros[:,i*num_dim:(i+1)*num_dim] + g_idx_i = w_quant.g_idx[i*num_dim:(i+1)*num_dim] + weights.append(scale_i[g_idx_i.long()] * (weight_i - zeros_i[g_idx_i.long()])) + weights = torch.cat(weights,dim=1) + + # out = torch.matmul(x.half(), weights) + out = torch.matmul(x, weights.to(x.dtype)) + out = out.reshape(out_shape) + return out + else: + return x @ weight + + def forward(self, tokens, state, full_output=False): + with torch.no_grad(): + w = self.w + args = self.args + + if state == None: + state = [None] * args.n_layer * 5 + for i in range(args.n_layer): # state: 0=att_xx 1=att_aa 2=att_bb 3=att_pp 4=ffn_xx + dd = self.strategy[i] + dev = dd.device + atype = dd.atype + state[i*5+0] = torch.zeros(args.n_embd, dtype=atype, requires_grad=False, device=dev).contiguous() + state[i*5+1] = torch.zeros(args.n_embd, dtype=torch.float, requires_grad=False, device=dev).contiguous() + state[i*5+2] = torch.zeros(args.n_embd, dtype=torch.float, requires_grad=False, device=dev).contiguous() + state[i*5+3] = torch.zeros(args.n_embd, dtype=torch.float, requires_grad=False, device=dev).contiguous() - 1e30 + state[i*5+4] = torch.zeros(args.n_embd, dtype=atype, requires_grad=False, device=dev).contiguous() + + seq_mode = len(tokens) > 1 + + x = w['emb.weight'][tokens if seq_mode else tokens[0]] + + for i in range(args.n_layer): + bbb = f'blocks.{i}.' + att = f'blocks.{i}.att.' + ffn = f'blocks.{i}.ffn.' 
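+                # Per-layer state layout (see the comment where state is initialised above): block i
+                # owns five slots, state[i*5+0]=att_xx, state[i*5+1]=att_aa, state[i*5+2]=att_bb,
+                # state[i*5+3]=att_pp (running max term, initialised to -1e30), state[i*5+4]=ffn_xx,
+                # so, for example, layer 2 reads and writes state[10:15].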
+ + dd = self.strategy[i] + dev = dd.device + atype = dd.atype + wtype = dd.wtype + if seq_mode: + if 'cuda' in str(dev) and os.environ["RWKV_CUDA_ON"] == '1': + ATT = self.cuda_att_seq if wtype != torch.uint8 else self.cuda_att_seq_i8 + else: + ATT = self.att_seq if wtype != torch.uint8 else self.att_seq_i8 + FFN = self.ffn_seq if wtype != torch.uint8 else self.ffn_seq_i8 + else: + ATT = self.att_one if wtype != torch.uint8 else self.att_one_i8 + FFN = self.ffn_one if wtype != torch.uint8 else self.ffn_one_i8 + + x = x.to(dtype=atype, device=dev) + + kw = w[f'{att}key.weight'], f'{att}key.weight' + vw = w[f'{att}value.weight'], f'{att}value.weight' + rw = w[f'{att}receptance.weight'], f'{att}receptance.weight' + ow = w[f'{att}output.weight'], f'{att}output.weight' + + # kw = w[f'{att}key.weight'] + # vw = w[f'{att}value.weight'] + # rw = w[f'{att}receptance.weight'] + # ow = w[f'{att}output.weight'] + + if dd.stream: + kw = kw.to(device=dev, non_blocking=True) + vw = vw.to(device=dev, non_blocking=True) + rw = rw.to(device=dev, non_blocking=True) + ow = ow.to(device=dev, non_blocking=True) + + kmx = w[f'{att}key.weight_mx'] if wtype == torch.uint8 else x + krx = w[f'{att}key.weight_rx'] if wtype == torch.uint8 else x + kmy = w[f'{att}key.weight_my'] if wtype == torch.uint8 else x + kry = w[f'{att}key.weight_ry'] if wtype == torch.uint8 else x + vmx = w[f'{att}value.weight_mx'] if wtype == torch.uint8 else x + vrx = w[f'{att}value.weight_rx'] if wtype == torch.uint8 else x + vmy = w[f'{att}value.weight_my'] if wtype == torch.uint8 else x + vry = w[f'{att}value.weight_ry'] if wtype == torch.uint8 else x + rmx = w[f'{att}receptance.weight_mx'] if wtype == torch.uint8 else x + rrx = w[f'{att}receptance.weight_rx'] if wtype == torch.uint8 else x + rmy = w[f'{att}receptance.weight_my'] if wtype == torch.uint8 else x + rry = w[f'{att}receptance.weight_ry'] if wtype == torch.uint8 else x + omx = w[f'{att}output.weight_mx'] if wtype == torch.uint8 else x + orx = w[f'{att}output.weight_rx'] if wtype == torch.uint8 else x + omy = w[f'{att}output.weight_my'] if wtype == torch.uint8 else x + ory = w[f'{att}output.weight_ry'] if wtype == torch.uint8 else x + + x, state[i*5+0], state[i*5+1], state[i*5+2], state[i*5+3] = ATT( + x, state[i*5+0], state[i*5+1], state[i*5+2], state[i*5+3], + w[f'{bbb}ln1.weight'], w[f'{bbb}ln1.bias'], + w[f'{att}time_mix_k'], w[f'{att}time_mix_v'], w[f'{att}time_mix_r'], + w[f'{att}time_decay'], w[f'{att}time_first'], + kw, vw, rw, ow, + kmx, krx, kmy, kry, + vmx, vrx, vmy, vry, + rmx, rrx, rmy, rry, + omx, orx, omy, ory, + ) + if dd.stream: + del kw, vw, rw, ow + + kw = w[f'{ffn}key.weight'], f'{ffn}key.weight' + vw = w[f'{ffn}value.weight'], f'{ffn}value.weight' + rw = w[f'{ffn}receptance.weight'], f'{ffn}receptance.weight' + # kw = w[f'{ffn}key.weight'] + # vw = w[f'{ffn}value.weight'] + # rw = w[f'{ffn}receptance.weight'] + + if dd.stream: + kw = kw.to(device=dev, non_blocking=True) + vw = vw.to(device=dev, non_blocking=True) + rw = rw.to(device=dev, non_blocking=True) + kmx = w[f'{ffn}key.weight_mx'] if wtype == torch.uint8 else x + krx = w[f'{ffn}key.weight_rx'] if wtype == torch.uint8 else x + kmy = w[f'{ffn}key.weight_my'] if wtype == torch.uint8 else x + kry = w[f'{ffn}key.weight_ry'] if wtype == torch.uint8 else x + vmx = w[f'{ffn}value.weight_mx'] if wtype == torch.uint8 else x + vrx = w[f'{ffn}value.weight_rx'] if wtype == torch.uint8 else x + vmy = w[f'{ffn}value.weight_my'] if wtype == torch.uint8 else x + vry = w[f'{ffn}value.weight_ry'] if wtype == 
torch.uint8 else x + rmx = w[f'{ffn}receptance.weight_mx'] if wtype == torch.uint8 else x + rrx = w[f'{ffn}receptance.weight_rx'] if wtype == torch.uint8 else x + rmy = w[f'{ffn}receptance.weight_my'] if wtype == torch.uint8 else x + rry = w[f'{ffn}receptance.weight_ry'] if wtype == torch.uint8 else x + x, state[i*5+4] = FFN( + x, state[i*5+4], + w[f'{bbb}ln2.weight'], w[f'{bbb}ln2.bias'], + w[f'{ffn}time_mix_k'], w[f'{ffn}time_mix_r'], + kw, vw, rw, + kmx, krx, kmy, kry, + vmx, vrx, vmy, vry, + rmx, rrx, rmy, rry, + ) + if dd.stream: + del kw, vw, rw + + if self.RESCALE_LAYER > 0: + if (i+1) % self.RESCALE_LAYER == 0: + x = x / 2 + + dd = self.strategy[args.n_layer] + x = x[-1,:] if (seq_mode and (not full_output)) else x + x = x.to(dtype=dd.atype, device=dd.device) + + x = F.layer_norm(x, (args.n_embd,), weight=w['ln_out.weight'], bias=w['ln_out.bias']) + if w['head.weight'].dtype != torch.uint8: + x = x @ w['head.weight'] + # FIXME: WTF check load model and see why + # (Pdb++) w_quant.scales.shape + # torch.Size([1, 50277]) + # (Pdb++) zeros.shape + # torch.Size([1, 12568, 4]) + # (Pdb++) 12568 * 4 + # 50272 + # x = self._trigger_gptq(x, w["head.weight"], "head.weight") + else: + if seq_mode and full_output: + x = self.mm8_seq(x, w['head.weight'], w['head.weight_mx'], w['head.weight_rx'], w['head.weight_my'], w['head.weight_ry']) + else: + x = self.mm8_one(x, w['head.weight'], w['head.weight_mx'], w['head.weight_rx'], w['head.weight_my'], w['head.weight_ry']) + return x.float(), state \ No newline at end of file diff --git a/quantize/tmp_rwkv.py b/quantize/tmp_rwkv.py new file mode 100644 index 00000000..d550e196 --- /dev/null +++ b/quantize/tmp_rwkv.py @@ -0,0 +1,572 @@ +from myRWKV import RWKV +from gptq.datautils import * +from gptq.quant import Quantizer, quantize + +import os +import torch.nn.functional as F +from collections import OrderedDict +import time +import math +import re +from gptq.gptq import QuantLinear_custom + +WBITS = 8 +GROUPSIZE = -1 + +class GPTQ_RWKV(RWKV): + + ### begin GPTQ + class GPTQ: + def __init__(self, weight, name): + #TODO: Remove name, only used for debugging + self.name = name + self.weight = weight.clone() + self.dev = weight.device + # In GPTQ, they use nn.Linear(x) which performs x @ w.T but in RWKV, we perform x @ w instead + # Problem is self.H is a square matrix which depends on self.columns = W.shape[1] in the original code + # But if we keep it that way, this will break self.H += inp.matmul(inp.t()) because inp.shape[1] != W.shape[1] + # Thus, we have to use self.W.shape[0] instead + self.columns = self.weight.shape[0] + self.H = torch.zeros((self.columns, self.columns), device=self.dev) + self.nsamples = 0 + self.deactivate_add_batch_call = False + + def add_batch(self, inp): + # After calling fasterquant, we don't want to call add_batch anymore + if self.deactivate_add_batch_call: + return + + if len(inp.shape) == 2: + inp = inp.unsqueeze(0) + + tmp = inp.shape[0] + + # Assume weight come from nn.Linear + if len(inp.shape) == 3: + inp = inp.reshape((-1, inp.shape[-1])) + inp = inp.t() + + self.H *= self.nsamples / (self.nsamples + tmp) + self.nsamples += tmp + inp = math.sqrt(2 / self.nsamples) * inp.float() + self.H += inp.matmul(inp.t()) + + def fasterquant(self, blocksize=128, percdamp=.01, groupsize=-1, actorder=False): + # if self.name == "blocks.0.ffn.value.weight": + # if self.name == "blocks.0.ffn.key.weight": + # import pdb; pdb.set_trace() + W = self.weight.data.clone() + # OLD: Need to transpose here, same reason as in __init__ 
with self.columns + # UPDATE: no need to tranpose as we already transpose in my_linear() + # UPDATE2: for rwkv, this is necessary + W = W.t() + W = W.float() + + tick = time.time() + + if not self.quantizer.ready(): + self.quantizer.find_params(W, weight=True) + + H = self.H + del self.H + + dead = torch.diag(H) == 0 + H[dead, dead] = 1 + W[:, dead] = 0 + + if actorder: + perm = torch.argsort(torch.diag(H), descending=True) + W = W[:, perm] + H = H[perm][:, perm] + + Losses = torch.zeros_like(W) + Q = torch.zeros_like(W) + + damp = percdamp * torch.mean(torch.diag(H)) + diag = torch.arange(self.columns, device=self.dev) + H[diag, diag] += damp + H = torch.linalg.cholesky(H) + H = torch.cholesky_inverse(H) + H = torch.linalg.cholesky(H, upper=True) + Hinv = H + + g_idx = [] + scale = [] + zero = [] + now_idx = 1 + + for i1 in range(0, self.columns, blocksize): + i2 = min(i1 + blocksize, self.columns) + count = i2 - i1 + + W1 = W[:, i1:i2].clone() + Q1 = torch.zeros_like(W1) + Err1 = torch.zeros_like(W1) + Losses1 = torch.zeros_like(W1) + Hinv1 = Hinv[i1:i2, i1:i2] + + for i in range(count): + w = W1[:, i] + d = Hinv1[i, i] + + if groupsize != -1: + if (i1 + i) % groupsize == 0: + self.quantizer.find_params(W[:, (i1 + i):(i1 + i + groupsize)], weight=True) + + if ((i1 + i) // groupsize) - now_idx == -1: + scale.append(self.quantizer.scale) + zero.append(self.quantizer.zero) + now_idx += 1 + + q = quantize( + w.unsqueeze(1), self.quantizer.scale, self.quantizer.zero, self.quantizer.maxq + ).flatten() + Q1[:, i] = q + Losses1[:, i] = (w - q) ** 2 / d ** 2 + err1 = (w - q) / d + W1[:, i:] -= err1.unsqueeze(1).matmul(Hinv1[i, i:].unsqueeze(0)) + Err1[:, i] = err1 + + Q[:, i1:i2] = Q1 + Losses[:, i1:i2] = Losses1 / 2 + + W[:, i2:] -= Err1.matmul(Hinv[i1:i2, i2:]) + + + torch.cuda.synchronize() + print('time %.2f' % (time.time() - tick)) + print('error', torch.sum(Losses).item()) + + groupsize = groupsize if groupsize != -1 else self.columns + g_idx = [i // groupsize for i in range(self.columns)] + g_idx = torch.tensor(g_idx, dtype=torch.int32, device=Q.device) + if actorder: + invperm = torch.argsort(perm) + Q = Q[:, invperm] + g_idx = g_idx[invperm] + + self.weight.data = Q.reshape(self.weight.shape).to(self.weight.data.dtype) + + if scale == []: + scale.append(self.quantizer.scale) + zero.append(self.quantizer.zero) + scale = torch.cat(scale,dim=1) + zero = torch.cat(zero,dim=1) + return scale,zero,g_idx + + ### end GPTQ + + ### begin GPTQ_RWKV + def __init__(self, checkpoint_path, strategy): + super().__init__(checkpoint_path, strategy) + for i in range(self.args.n_layer): + assert self.strategy[i].device == "cpu" + + def _fill_subset(self, layer_id): + # Keep only layer within block layer_id + + #TODO: Uncomment me when quantizing 1 layer works + is_weight = re.compile(f'^blocks\.{layer_id}\..*\.weight$') + # is_weight = re.compile("blocks.0.att.key.weight") + for name in self.w.keys(): + if is_weight.match(name): + if len(self.w[name].shape) == 1: continue #TODO: Skip 1D tensors for now + self.subset[name] = self.w[name] + + # TODO: Uncomment me when quantizing 1 layer works + # is_last_layer = (layer_id == self.args.n_layer - 1) + # if is_last_layer: + # self.subset["head.weight"] = self.w["head.weight"] + + return self.subset + + def alloc_gptq(self, layer_id): + self.subset = {} + self.gptq = {} + + self.subset = self._fill_subset(layer_id) + + for name in self.subset: + self.gptq[name] = self.GPTQ(self.subset[name], name) + self.gptq[name].quantizer = Quantizer() + 
self.gptq[name].quantizer.configure(bits=WBITS, perchannel=True, sym=False, mse=False, trits=False) + + def free_gptq(self): + self.subset = {} + self.gptq = {} + + def fasterquant(self, layer_id, quantizers): + + for name in self.subset: + print(layer_id, name) + print('Quantizing ...') + scale,zero,g_idx = self.gptq[name].fasterquant(percdamp=0.01, groupsize=GROUPSIZE, actorder=False) + quantizers[name] = (self.gptq[name].quantizer.cpu(), scale.cpu(), zero.cpu(), g_idx.cpu()) + + ### end GPTQ_RWKV + + ### begin RWKV + def att_one(self, x, sx, aa, bb, pp, ln_w, ln_b, k_mix, v_mix, r_mix, t_decay, t_first, kw, vw, rw, ow, kmx, krx, kmy, kry, vmx, vrx, vmy, vry, rmx, rrx, rmy, rry, omx, orx, omy, ory): + + xx = F.layer_norm(x, (x.shape[-1],), weight=ln_w, bias=ln_b) + kx = xx * k_mix + sx * (1 - k_mix) + vx = xx * v_mix + sx * (1 - v_mix) + rx = xx * r_mix + sx * (1 - r_mix) + + # r = torch.sigmoid(rx @ rw) + # v = (vx @ vw).float() + + r = torch.sigmoid(rx @ rw.weight) + rw.add_batch(rx) + k = (kx @ kw.weight).float() + kw.add_batch(kx) + v = (vx @ vw.weight).float() + vw.add_batch(vx) + + ww = t_first + k + p = torch.maximum(pp, ww) + e1 = torch.exp(pp - p) + e2 = torch.exp(ww - p) + wkv = ((e1 * aa + e2 * v) / (e1 * bb + e2)).to(dtype=x.dtype) + ww = t_decay + pp + p = torch.maximum(ww, k) + e1 = torch.exp(ww - p) + e2 = torch.exp(k - p) + + # out = (r * wkv) @ ow + out = (r * wkv) @ ow.weight + ow.add_batch(r * wkv) + return x + out, xx, e1 * aa + e2 * v, e1 * bb + e2, p + + def att_seq(self, x, sx, aa, bb, pp, ln_w, ln_b, k_mix, v_mix, r_mix, t_decay, t_first, kw, vw, rw, ow, kmx, krx, kmy, kry, vmx, vrx, vmy, vry, rmx, rrx, rmy, rry, omx, orx, omy, ory): + xx = F.layer_norm(x, (x.shape[-1],), weight=ln_w, bias=ln_b) + sx = torch.cat((sx.unsqueeze(0), xx[:-1,:])) + kx = xx * k_mix + sx * (1 - k_mix) + vx = xx * v_mix + sx * (1 - v_mix) + rx = xx * r_mix + sx * (1 - r_mix) + + # r = torch.sigmoid(rx @ rw) + # k = (kx @ kw).float() + # v = (vx @ vw).float() + + r = torch.sigmoid(rx @ rw.weight) + rw.add_batch(rx) + k = (kx @ kw.weight).float() + kw.add_batch(kx) + v = (vx @ vw.weight).float() + vw.add_batch(vx) + + T = x.shape[0] + for t in range(T): + kk = k[t] + vv = v[t] + ww = t_first + kk + p = torch.maximum(pp, ww) + e1 = torch.exp(pp - p) + e2 = torch.exp(ww - p) + sx[t] = ((e1 * aa + e2 * vv) / (e1 * bb + e2)).to(dtype=x.dtype) + ww = t_decay + pp + p = torch.maximum(ww, kk) + e1 = torch.exp(ww - p) + e2 = torch.exp(kk - p) + aa = e1 * aa + e2 * vv + bb = e1 * bb + e2 + pp = p + # out = (r * sx) @ ow + out = (r * sx) @ ow.weight + ow.add_batch(r * sx) + return x + out, xx[-1,:], aa, bb, pp + + def ffn_one(self, x, sx, ln_w, ln_b, k_mix, r_mix, kw, vw, rw, kmx, krx, kmy, kry, vmx, vrx, vmy, vry, rmx, rrx, rmy, rry): + xx = F.layer_norm(x, (x.shape[-1],), weight=ln_w, bias=ln_b) + kx = xx * k_mix + sx * (1 - k_mix) + rx = xx * r_mix + sx * (1 - r_mix) + + # r = torch.sigmoid(rx @ rw) + # vx = torch.square(torch.relu(kx @ kw)) + # out = r * (vx @ vw) + + r = torch.sigmoid(rx @ rw.weight) + rw.add_batch(rx) + vx = torch.square(torch.relu(kx @ kw.weight)) + kw.add_batch(kx) + out = r * (vx @ vw.weight) + vw.add_batch(vx) + + return x + out, xx + + def ffn_seq(self, x, sx, ln_w, ln_b, k_mix, r_mix, kw, vw, rw, kmx, krx, kmy, kry, vmx, vrx, vmy, vry, rmx, rrx, rmy, rry): + xx = F.layer_norm(x, (x.shape[-1],), weight=ln_w, bias=ln_b) + sx = torch.cat((sx.unsqueeze(0), xx[:-1,:])) + kx = xx * k_mix + sx * (1 - k_mix) + rx = xx * r_mix + sx * (1 - r_mix) + + # r = torch.sigmoid(rx @ rw) + 
# vx = torch.square(torch.relu(kx @ kw)) + # out = r * (vx @ vw) + + r = torch.sigmoid(rx @ rw.weight) + rw.add_batch(rx) + vx = torch.square(torch.relu(kx @ kw.weight)) + kw.add_batch(kx) + out = r * (vx @ vw.weight) + vw.add_batch(vx) + return x + out, xx[-1,:] + + def forward_block(self, x, state, i, seq_mode, full_output=False): + with torch.no_grad(): + args = self.args + + if state == None: + state = [None] * args.n_layer * 5 + dd = self.strategy[i] + dev = dd.device + atype = dd.atype + state[i*5+0] = torch.zeros(args.n_embd, dtype=atype, requires_grad=False, device=dev).contiguous() + state[i*5+1] = torch.zeros(args.n_embd, dtype=torch.float, requires_grad=False, device=dev).contiguous() + state[i*5+2] = torch.zeros(args.n_embd, dtype=torch.float, requires_grad=False, device=dev).contiguous() + state[i*5+3] = torch.zeros(args.n_embd, dtype=torch.float, requires_grad=False, device=dev).contiguous() - 1e30 + state[i*5+4] = torch.zeros(args.n_embd, dtype=atype, requires_grad=False, device=dev).contiguous() + + bbb = f'blocks.{i}.' + att = f'blocks.{i}.att.' + ffn = f'blocks.{i}.ffn.' + dd = self.strategy[i] + dev = dd.device + atype = dd.atype + wtype = dd.wtype + + if seq_mode: + if 'cuda' in str(dev) and os.environ["RWKV_CUDA_ON"] == '1': + ATT = self.cuda_att_seq if wtype != torch.uint8 else self.cuda_att_seq_i8 + else: + ATT = self.att_seq if wtype != torch.uint8 else self.att_seq_i8 + FFN = self.ffn_seq if wtype != torch.uint8 else self.ffn_seq_i8 + else: + ATT = self.att_one if wtype != torch.uint8 else self.att_one_i8 + FFN = self.ffn_one if wtype != torch.uint8 else self.ffn_one_i8 + + x = x.to(dtype=atype, device=dev) + + kw = self.gptq[f'{att}key.weight'] + vw = self.gptq[f'{att}value.weight'] + rw = self.gptq[f'{att}receptance.weight'] + ow = self.gptq[f'{att}output.weight'] + + # kw = self.w[f'{att}key.weight'] + # vw = self.w[f'{att}value.weight'] + # rw = self.w[f'{att}receptance.weight'] + # ow = self.w[f'{att}output.weight'] + + if dd.stream: + kw = kw.to(device=dev, non_blocking=True) + vw = vw.to(device=dev, non_blocking=True) + rw = rw.to(device=dev, non_blocking=True) + ow = ow.to(device=dev, non_blocking=True) + + kmx = self.w[f'{att}key.weight_mx'] if wtype == torch.uint8 else x + krx = self.w[f'{att}key.weight_rx'] if wtype == torch.uint8 else x + kmy = self.w[f'{att}key.weight_my'] if wtype == torch.uint8 else x + kry = self.w[f'{att}key.weight_ry'] if wtype == torch.uint8 else x + vmx = self.w[f'{att}value.weight_mx'] if wtype == torch.uint8 else x + vrx = self.w[f'{att}value.weight_rx'] if wtype == torch.uint8 else x + vmy = self.w[f'{att}value.weight_my'] if wtype == torch.uint8 else x + vry = self.w[f'{att}value.weight_ry'] if wtype == torch.uint8 else x + rmx = self.w[f'{att}receptance.weight_mx'] if wtype == torch.uint8 else x + rrx = self.w[f'{att}receptance.weight_rx'] if wtype == torch.uint8 else x + rmy = self.w[f'{att}receptance.weight_my'] if wtype == torch.uint8 else x + rry = self.w[f'{att}receptance.weight_ry'] if wtype == torch.uint8 else x + omx = self.w[f'{att}output.weight_mx'] if wtype == torch.uint8 else x + orx = self.w[f'{att}output.weight_rx'] if wtype == torch.uint8 else x + omy = self.w[f'{att}output.weight_my'] if wtype == torch.uint8 else x + ory = self.w[f'{att}output.weight_ry'] if wtype == torch.uint8 else x + + x, state[i*5+0], state[i*5+1], state[i*5+2], state[i*5+3] = ATT( + x=x, sx=state[i*5+0], aa=state[i*5+1], bb=state[i*5+2], pp=state[i*5+3], + ln_w=self.w[f'{bbb}ln1.weight'], ln_b=self.w[f'{bbb}ln1.bias'], + 
k_mix=self.w[f'{att}time_mix_k'], v_mix=self.w[f'{att}time_mix_v'], r_mix=self.w[f'{att}time_mix_r'], + t_decay=self.w[f'{att}time_decay'], t_first=self.w[f'{att}time_first'], + kw=kw, vw=vw, rw=rw, ow=ow, + kmx=kmx, krx=krx, kmy=kmy, kry=kry, + vmx=vmx, vrx=vrx, vmy=vmy, vry=vry, + rmx=rmx, rrx=rrx, rmy=rmy, rry=rry, + omx=omx, orx=orx, omy=omy, ory=ory, + ) + + if dd.stream: + del kw, vw, rw, ow + + # kw = self.w[f'{ffn}key.weight'] + # vw = self.w[f'{ffn}value.weight'] + # rw = self.w[f'{ffn}receptance.weight'] + + kw = self.gptq[f'{ffn}key.weight'] + vw = self.gptq[f'{ffn}value.weight'] + rw = self.gptq[f'{ffn}receptance.weight'] + + if dd.stream: + kw = kw.to(device=dev, non_blocking=True) + vw = vw.to(device=dev, non_blocking=True) + rw = rw.to(device=dev, non_blocking=True) + + kmx = self.w[f'{ffn}key.weight_mx'] if wtype == torch.uint8 else x + krx = self.w[f'{ffn}key.weight_rx'] if wtype == torch.uint8 else x + kmy = self.w[f'{ffn}key.weight_my'] if wtype == torch.uint8 else x + kry = self.w[f'{ffn}key.weight_ry'] if wtype == torch.uint8 else x + vmx = self.w[f'{ffn}value.weight_mx'] if wtype == torch.uint8 else x + vrx = self.w[f'{ffn}value.weight_rx'] if wtype == torch.uint8 else x + vmy = self.w[f'{ffn}value.weight_my'] if wtype == torch.uint8 else x + vry = self.w[f'{ffn}value.weight_ry'] if wtype == torch.uint8 else x + rmx = self.w[f'{ffn}receptance.weight_mx'] if wtype == torch.uint8 else x + rrx = self.w[f'{ffn}receptance.weight_rx'] if wtype == torch.uint8 else x + rmy = self.w[f'{ffn}receptance.weight_my'] if wtype == torch.uint8 else x + rry = self.w[f'{ffn}receptance.weight_ry'] if wtype == torch.uint8 else x + + x, state[i*5+4] = FFN( + x=x, sx=state[i*5+4], + ln_w=self.w[f'{bbb}ln2.weight'], ln_b=self.w[f'{bbb}ln2.bias'], + k_mix=self.w[f'{ffn}time_mix_k'], r_mix=self.w[f'{ffn}time_mix_r'], + kw=kw, vw=vw, rw=rw, + kmx=kmx, krx=krx, kmy=kmy, kry=kry, + vmx=vmx, vrx=vrx, vmy=vmy, vry=vry, + rmx=rmx, rrx=rrx, rmy=rmy, rry=rry, + ) + + if dd.stream: + del kw, vw, rw + + if self.RESCALE_LAYER > 0: + if (i+1) % self.RESCALE_LAYER == 0: + x = x / 2 + + is_last_layer = i == (args.n_layer - 1) + if is_last_layer: + dd = self.strategy[args.n_layer] + x = x[-1,:] if (seq_mode and (not full_output)) else x + x = x.to(dtype=dd.atype, device=dd.device) + + #TODO: ln_out.weight is 1D tensor + x = F.layer_norm(x, (args.n_embd,), weight=self.w['ln_out.weight'], bias=self.w['ln_out.bias']) + + if self.w['head.weight'].dtype != torch.uint8: + x = x @ self.w['head.weight'] + #TODO: uncommenbt me when quantizing 1 layer work + # x = x @ self.gptq['head.weight'].weight + # self.gptq['head.weight'].add_batch(x) + + return x.float() + + ### end RWKV + +@torch.no_grad() +def quantize_gptq_custom(model, tokens): + nsamples = tokens.shape[0] + seq_mode = len(tokens) > 1 + is_last_layer = lambda x: x == (model.args.n_layer - 1) + + inps = model.w['emb.weight'][tokens if seq_mode else tokens[0]] + outs = torch.zeros_like(inps) + quantizers = {} + + # for layer_id in range(model.args.n_layer): + for layer_id in range(1): + + print(f"Quantizing layer {layer_id} ...") + + model.alloc_gptq(layer_id) + + for i in range(nsamples): + if not is_last_layer(layer_id): + outs[i] = model.forward_block(inps[i], state=None, i=layer_id, seq_mode=seq_mode) + else: + _ = model.forward_block(inps[i], state=None, i=layer_id, seq_mode=seq_mode) + + for gptq_layer in model.gptq.values(): + gptq_layer.deactivate_add_batch_call = True + + model.fasterquant(layer_id, quantizers) + # 
model.gptq["blocks.0.ffn.value.weight"].weight
+
+        for i in range(nsamples):
+            if not is_last_layer(layer_id):
+                outs[i] = model.forward_block(inps[i], state=None, i=layer_id, seq_mode=seq_mode)
+            else:
+                _ = model.forward_block(inps[i], state=None, i=layer_id, seq_mode=seq_mode)
+
+        # Assign the quantized weights to the model
+        for key in model.gptq.keys():
+            model.w[key].copy_(model.gptq[key].weight)
+
+        model.free_gptq()
+
+        # We need to pass the outputs of block i as input of block i+1 (except for last block)
+        if not is_last_layer(layer_id):
+            inps, outs = outs, inps
+
+    return quantizers
+
+def model_pack_custom(model, quantizers, wbits, groupsize):
+
+    weights = OrderedDict()
+
+    # is_weight = re.compile('^blocks\.\d+(\.[a-z]+[0-9]?)*\.weight$')
+    # for name in model.w.keys():
+    #     if is_weight.match(name):
+    #         if len(model.w[name].shape) == 1: continue #TODO: Skip 1D tensors for now
+    #         weights[name] = model.w[name]
+
+    for name in quantizers.keys():
+        if len(model.w[name].shape) == 1: continue
+        weights[name] = model.w[name]
+
+    #TODO: uncomment me when done
+    # weights["head.weight"] = model.w["head.weight"]
+
+    assert set(quantizers) - set(weights) == set(), "Quantizers and weights don't match"
+    assert set(weights) - set(quantizers) == set(), "Quantizers and weights don't match"
+
+    # Replace each layer with QuantLinear
+    model.w_quant = {}
+    for key, value in model.w.items():
+        if key in quantizers.keys():
+            #FIXME: So far, we don't quantize ln0 and ln1 (which have a bias) because they are 1D tensors
+            bias = None
+            model.w_quant[key] = QuantLinear_custom(wbits, groupsize, value.shape[0], value.shape[1], bias)
+
+    # Fill QuantLinear
+    print('Packing ...')
+    for key in model.w_quant.keys():
+        _, scale, zero, g_idx = quantizers[key]
+        bias = None
+        model.w_quant[key].pack(weights[key], bias, scale, zero, g_idx)
+    print('Done.')
+    return model
+
+if __name__ == "__main__":
+
+    model = GPTQ_RWKV("./RWKV-4-Pile-169M-20220807-8023.pth", strategy='cpu fp32')
+
+    NSAMPLES=128
+    HIDDEN_SIZE=model.args.n_embd
+    SEQLEN=1024 # cf https://huggingface.co/BlinkDL/rwkv-4-pile-169m
+
+    print("Loading data ...")
+    train_tokens, test_tokens = get_loaders(
+        dataset_name="wikitext2",
+        nsamples=NSAMPLES,
+        seed=42,
+        seqlen=SEQLEN,
+        model=model
+    )
+
+    tokens = torch.cat([inp for inp, _ in train_tokens], dim=0)
+    print("tokens.shape", tokens.shape)
+
+    print("Quantizing ...")
+    quantizers = quantize_gptq_custom(model, tokens)
+    model = model_pack_custom(model, quantizers, WBITS, GROUPSIZE)
+    print("Saving ...")
+    torch.save([model.w_quant, model.w], "1sample_quantized.pth")
+
+
+    print("Done Custom GPTQ")
\ No newline at end of file
diff --git a/requirements-quantize.txt b/requirements-quantize.txt
new file mode 100644
index 00000000..352f3a9a
--- /dev/null
+++ b/requirements-quantize.txt
@@ -0,0 +1,12 @@
+rwkv==0.7.3
+-f https://download.pytorch.org/whl/cu117/torch_stable.html
+torch==1.13.1+cu117
+transformers
+datasets
+ninja
+tokenizers>=0.13.2
+prompt_toolkit
+# Debug
+pdbpp
+line_profiler
+torchvision==0.14.1+cu117
\ No newline at end of file