From 6e556e504d0e8252f7c3c6dd6ad1ab01ea403e94 Mon Sep 17 00:00:00 2001 From: Ferdinand Mom Date: Tue, 18 Apr 2023 14:00:11 +0000 Subject: [PATCH 01/20] feat(quantize): measure perplexity on wikitext2 --- quantize/measure_perplexity.py | 111 +++++++++++++++++++++++++++++++++ requirements-quantize.txt | 10 +++ 2 files changed, 121 insertions(+) create mode 100644 quantize/measure_perplexity.py create mode 100644 requirements-quantize.txt diff --git a/quantize/measure_perplexity.py b/quantize/measure_perplexity.py new file mode 100644 index 00000000..e0b5be29 --- /dev/null +++ b/quantize/measure_perplexity.py @@ -0,0 +1,111 @@ +# Measures perplexity and per-token latency of an RWKV model on a given text file. +# Perplexity is defined here as exp() of average cross-entropy loss. +# Usage: python measure_perplexity.py RWKV-4-Pile-169M-20220807-8023.pth wikitext2 2048 + +import os +import time +import pathlib +import argparse +import tokenizers +import torch +from typing import List +from rwkv.model import RWKV + +def parse_args(): + parser = argparse.ArgumentParser(description='Measure perplexity and per-token latency of an RWKV model on a given text file') + parser.add_argument('model_path', help='Path to model checkpoint file') + parser.add_argument('dataset_path', type=str, choices=['wikitext2', 'ptb', 'c4'], help='Where to extract calibration data from.') + parser.add_argument('nsamples', help='How many samples', type=int, default=4096) + return parser.parse_args() + +args = parse_args() + +def get_wikitext2(nsamples): + from datasets import load_dataset + testdata = load_dataset('wikitext', 'wikitext-2-raw-v1', split='test') + + print('Loading 20B tokenizer (RWKV)') + tokenizer_path: pathlib.Path = pathlib.Path(os.path.abspath(__file__)).parent / '20B_tokenizer.json' + tokenizer: tokenizers.Tokenizer = tokenizers.Tokenizer.from_file(str(tokenizer_path)) + + print('Loading text') + test_text: str = "\n\n".join(testdata['text']) + test_tokens = torch.tensor(tokenizer.encode(test_text).ids, dtype=torch.long) + print(f'{len(test_tokens)} test tokens in the text') + + import random + random.seed(42) + # Randomly select a sample of nsamples tokens + i = random.randint(0, len(test_tokens) - nsamples) + return tokenizer, test_tokens[i:i+nsamples] + +def get_loaders(dataset_path, nsamples): + if 'wikitext2' in dataset_path: + return get_wikitext2(nsamples) + else: + # https://github.com/IST-DASLab/gptq/blob/main/datautils.py + raise NotImplementedError("Only wikitext2 is supported for now") + +tokenizer, test_tokens = get_loaders(args.dataset_path, args.nsamples) + +def format_loss(loss: torch.Tensor) -> str: + return str(['%.3f' % (loss[i].item(),) for i in range(len(loss))]).replace('\'', '')[1:-1] + +def format_loss_with_perplexity(loss: torch.Tensor) -> str: + return f'loss [{format_loss(loss)}], perplexity {"%.3f" % (torch.exp(loss[0]).item(),)}' + +# --- +device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') +# device=torch.device('cpu') + +model = RWKV(model=args.model_path, strategy='cuda fp16i8') + +logits, state = None, None +loss_sum: torch.Tensor = torch.tensor([0.0], device=device) +loss_count: int = 0 +token_count = len(test_tokens) +run_count = token_count - 1 +# Ignore 20% of the tokens to let the model warmup +ignore_first_n_tokens = int(token_count * 0.2) +start: float = time.time() + +for i in range(run_count): + token: int = test_tokens[i] + target: int = test_tokens[i + 1] + + logits, state = model.forward([token], None if i == 0 else state) + + if 
ignore_first_n_tokens == 0 or i + 1 >= ignore_first_n_tokens: + losses = torch.tensor([ + torch.nn.functional.cross_entropy(logits, torch.tensor(target, dtype=torch.long, device=device), reduction='none').item() + ] + , device=device) + + loss_sum += losses + loss_count += 1 + + if i % 100 == 0: + avg_loss_so_far = loss_sum / loss_count + + duration: float = time.time() - start + duration_per_token: float = duration / (i + 1) + runs_remaining: int = run_count - i - 1 + duration_remaining: int = int(runs_remaining * duration_per_token) + + print(f'Token #{i}/{token_count}, ' + f'{int(100.0 * i / token_count)}%, ' + f'ETA {duration_remaining // 60} m {duration_remaining % 60} s', end='') + + if loss_count > 0: + print(f', averages so far: {format_loss_with_perplexity(avg_loss_so_far)}') + else: + print() + +print() +print(f'Average latency: {int((time.time() - start) * 1000 / run_count)} ms per token') + +print() +print(f'Model: {os.path.basename(args.model_path)}, ' + f'data: {os.path.basename(args.dataset_path)} with {token_count} tokens, ' + f'Ignored first {ignore_first_n_tokens} tokens, ' + f'averages: {format_loss_with_perplexity(loss_sum / loss_count)}') diff --git a/requirements-quantize.txt b/requirements-quantize.txt new file mode 100644 index 00000000..006cd3b8 --- /dev/null +++ b/requirements-quantize.txt @@ -0,0 +1,10 @@ +rwkv==0.7.3 +-f https://download.pytorch.org/whl/cu117/torch_stable.html +torch==1.13.1+cu117 +transformers +datasets +ninja +tokenizers>=0.13.2 +prompt_toolkit +# Debug +pdbpp \ No newline at end of file From bde6374df6bb256e1bc943537894c920fd3e84be Mon Sep 17 00:00:00 2001 From: Ferdinand Mom Date: Tue, 18 Apr 2023 14:29:01 +0000 Subject: [PATCH 02/20] feat(quantize): add gptq files --- README_quantize.md | 6 + quantize/gptq/datautils.py | 175 +++++++++++ quantize/gptq/gptq.py | 157 ++++++++++ quantize/gptq/modelutils.py | 16 + quantize/gptq/opt.py | 471 +++++++++++++++++++++++++++++ quantize/gptq/quant.py | 212 +++++++++++++ quantize/gptq/quant_cuda.cpp | 34 +++ quantize/gptq/quant_cuda_kernel.cu | 244 +++++++++++++++ quantize/gptq/setup_cuda.py | 14 + 9 files changed, 1329 insertions(+) create mode 100644 README_quantize.md create mode 100644 quantize/gptq/datautils.py create mode 100644 quantize/gptq/gptq.py create mode 100644 quantize/gptq/modelutils.py create mode 100644 quantize/gptq/opt.py create mode 100644 quantize/gptq/quant.py create mode 100644 quantize/gptq/quant_cuda.cpp create mode 100644 quantize/gptq/quant_cuda_kernel.cu create mode 100644 quantize/gptq/setup_cuda.py diff --git a/README_quantize.md b/README_quantize.md new file mode 100644 index 00000000..4a057a44 --- /dev/null +++ b/README_quantize.md @@ -0,0 +1,6 @@ +# GPTQ + +``` +pip install -r requirements-quantize.txt +python quantize/gptq/setup_cuda.py install +``` \ No newline at end of file diff --git a/quantize/gptq/datautils.py b/quantize/gptq/datautils.py new file mode 100644 index 00000000..193953c5 --- /dev/null +++ b/quantize/gptq/datautils.py @@ -0,0 +1,175 @@ +import numpy as np +import torch + + +def set_seed(seed): + np.random.seed(seed) + torch.random.manual_seed(seed) + + +def get_wikitext2(nsamples, seed, seqlen, model): + from datasets import load_dataset + traindata = load_dataset('wikitext', 'wikitext-2-raw-v1', split='train') + testdata = load_dataset('wikitext', 'wikitext-2-raw-v1', split='test') + + from transformers import AutoTokenizer + tokenizer = AutoTokenizer.from_pretrained(model, use_fast=False) + trainenc = tokenizer("\n\n".join(traindata['text']), 
return_tensors='pt') + testenc = tokenizer("\n\n".join(testdata['text']), return_tensors='pt') + + import random + random.seed(seed) + trainloader = [] + for _ in range(nsamples): + i = random.randint(0, trainenc.input_ids.shape[1] - seqlen - 1) + j = i + seqlen + inp = trainenc.input_ids[:, i:j] + tar = inp.clone() + tar[:, :-1] = -100 + trainloader.append((inp, tar)) + return trainloader, testenc + +def get_ptb(nsamples, seed, seqlen, model): + from datasets import load_dataset + traindata = load_dataset('ptb_text_only', 'penn_treebank', split='train') + valdata = load_dataset('ptb_text_only', 'penn_treebank', split='validation') + + from transformers import AutoTokenizer + tokenizer = AutoTokenizer.from_pretrained(model, use_fast=False) + trainenc = tokenizer("\n\n".join(traindata['sentence']), return_tensors='pt') + testenc = tokenizer("\n\n".join(valdata['sentence']), return_tensors='pt') + + import random + random.seed(seed) + trainloader = [] + for _ in range(nsamples): + i = random.randint(0, trainenc.input_ids.shape[1] - seqlen - 1) + j = i + seqlen + inp = trainenc.input_ids[:, i:j] + tar = inp.clone() + tar[:, :-1] = -100 + trainloader.append((inp, tar)) + return trainloader, testenc + +def get_c4(nsamples, seed, seqlen, model): + from datasets import load_dataset + traindata = load_dataset( + 'allenai/c4', 'allenai--c4', data_files={'train': 'en/c4-train.00000-of-01024.json.gz'}, split='train' + ) + valdata = load_dataset( + 'allenai/c4', 'allenai--c4', data_files={'validation': 'en/c4-validation.00000-of-00008.json.gz'}, split='validation' + ) + + from transformers import AutoTokenizer + tokenizer = AutoTokenizer.from_pretrained(model, use_fast=False) + + import random + random.seed(seed) + trainloader = [] + for _ in range(nsamples): + while True: + i = random.randint(0, len(traindata) - 1) + trainenc = tokenizer(traindata[i]['text'], return_tensors='pt') + if trainenc.input_ids.shape[1] >= seqlen: + break + i = random.randint(0, trainenc.input_ids.shape[1] - seqlen - 1) + j = i + seqlen + inp = trainenc.input_ids[:, i:j] + tar = inp.clone() + tar[:, :-1] = -100 + trainloader.append((inp, tar)) + + import random + random.seed(0) + valenc = [] + for _ in range(256): + while True: + i = random.randint(0, len(valdata) - 1) + tmp = tokenizer(valdata[i]['text'], return_tensors='pt') + if tmp.input_ids.shape[1] >= seqlen: + break + i = random.randint(0, tmp.input_ids.shape[1] - seqlen - 1) + j = i + seqlen + valenc.append(tmp.input_ids[:, i:j]) + valenc = torch.hstack(valenc) + class TokenizerWrapper: + def __init__(self, input_ids): + self.input_ids = input_ids + valenc = TokenizerWrapper(valenc) + + return trainloader, valenc + +def get_ptb_new(nsamples, seed, seqlen, model): + from datasets import load_dataset + traindata = load_dataset('ptb_text_only', 'penn_treebank', split='train') + testdata = load_dataset('ptb_text_only', 'penn_treebank', split='test') + + from transformers import AutoTokenizer + tokenizer = AutoTokenizer.from_pretrained(model, use_fast=False) + trainenc = tokenizer(" ".join(traindata['sentence']), return_tensors='pt') + testenc = tokenizer(" ".join(testdata['sentence']), return_tensors='pt') + + import random + random.seed(seed) + trainloader = [] + for _ in range(nsamples): + i = random.randint(0, trainenc.input_ids.shape[1] - seqlen - 1) + j = i + seqlen + inp = trainenc.input_ids[:, i:j] + tar = inp.clone() + tar[:, :-1] = -100 + trainloader.append((inp, tar)) + return trainloader, testenc + +def get_c4_new(nsamples, seed, seqlen, model): + from 
datasets import load_dataset + traindata = load_dataset( + 'allenai/c4', 'allenai--c4', data_files={'train': 'en/c4-train.00000-of-01024.json.gz'}, split='train' + ) + valdata = load_dataset( + 'allenai/c4', 'allenai--c4', data_files={'validation': 'en/c4-validation.00000-of-00008.json.gz'}, split='validation' + ) + + from transformers import AutoTokenizer + tokenizer = AutoTokenizer.from_pretrained(model, use_fast=False) + + import random + random.seed(seed) + trainloader = [] + for _ in range(nsamples): + while True: + i = random.randint(0, len(traindata) - 1) + trainenc = tokenizer(traindata[i]['text'], return_tensors='pt') + if trainenc.input_ids.shape[1] >= seqlen: + break + i = random.randint(0, trainenc.input_ids.shape[1] - seqlen - 1) + j = i + seqlen + inp = trainenc.input_ids[:, i:j] + tar = inp.clone() + tar[:, :-1] = -100 + trainloader.append((inp, tar)) + + valenc = tokenizer(' '.join(valdata[:1100]['text']), return_tensors='pt') + valenc = valenc.input_ids[:, :(256 * seqlen)] + + class TokenizerWrapper: + def __init__(self, input_ids): + self.input_ids = input_ids + valenc = TokenizerWrapper(valenc) + + return trainloader, valenc + + +def get_loaders( + name, nsamples=128, seed=0, seqlen=2048, model='' +): + if 'wikitext2' in name: + return get_wikitext2(nsamples, seed, seqlen, model) + if 'ptb' in name: + if 'new' in name: + return get_ptb_new(nsamples, seed, seqlen, model) + return get_ptb(nsamples, seed, seqlen, model) + if 'c4' in name: + if 'new' in name: + return get_c4_new(nsamples, seed, seqlen, model) + return get_c4(nsamples, seed, seqlen, model) diff --git a/quantize/gptq/gptq.py b/quantize/gptq/gptq.py new file mode 100644 index 00000000..2bae9786 --- /dev/null +++ b/quantize/gptq/gptq.py @@ -0,0 +1,157 @@ +import math +import time + +import torch +import torch.nn as nn +import transformers + +from quant import * + + +DEBUG = False + +torch.backends.cuda.matmul.allow_tf32 = False +torch.backends.cudnn.allow_tf32 = False + + +class GPTQ: + + def __init__(self, layer): + self.layer = layer + self.dev = self.layer.weight.device + W = layer.weight.data.clone() + if isinstance(self.layer, nn.Conv2d): + W = W.flatten(1) + if isinstance(self.layer, transformers.Conv1D): + W = W.t() + self.rows = W.shape[0] + self.columns = W.shape[1] + self.H = torch.zeros((self.columns, self.columns), device=self.dev) + self.nsamples = 0 + + def add_batch(self, inp, out): + if DEBUG: + self.inp1 = inp + self.out1 = out + if len(inp.shape) == 2: + inp = inp.unsqueeze(0) + tmp = inp.shape[0] + if isinstance(self.layer, nn.Linear) or isinstance(self.layer, transformers.Conv1D): + if len(inp.shape) == 3: + inp = inp.reshape((-1, inp.shape[-1])) + inp = inp.t() + if isinstance(self.layer, nn.Conv2d): + unfold = nn.Unfold( + self.layer.kernel_size, + dilation=self.layer.dilation, + padding=self.layer.padding, + stride=self.layer.stride + ) + inp = unfold(inp) + inp = inp.permute([1, 0, 2]) + inp = inp.flatten(1) + self.H *= self.nsamples / (self.nsamples + tmp) + self.nsamples += tmp + # inp = inp.float() + inp = math.sqrt(2 / self.nsamples) * inp.float() + # self.H += 2 / self.nsamples * inp.matmul(inp.t()) + self.H += inp.matmul(inp.t()) + + def fasterquant( + self, blocksize=128, percdamp=.01, groupsize=-1, actorder=False + ): + W = self.layer.weight.data.clone() + if isinstance(self.layer, nn.Conv2d): + W = W.flatten(1) + if isinstance(self.layer, transformers.Conv1D): + W = W.t() + W = W.float() + + tick = time.time() + + if not self.quantizer.ready(): + self.quantizer.find_params(W, 
weight=True) + + H = self.H + del self.H + dead = torch.diag(H) == 0 + H[dead, dead] = 1 + W[:, dead] = 0 + + if actorder: + perm = torch.argsort(torch.diag(H), descending=True) + W = W[:, perm] + H = H[perm][:, perm] + + Losses = torch.zeros_like(W) + Q = torch.zeros_like(W) + + damp = percdamp * torch.mean(torch.diag(H)) + diag = torch.arange(self.columns, device=self.dev) + H[diag, diag] += damp + H = torch.linalg.cholesky(H) + H = torch.cholesky_inverse(H) + H = torch.linalg.cholesky(H, upper=True) + Hinv = H + + for i1 in range(0, self.columns, blocksize): + i2 = min(i1 + blocksize, self.columns) + count = i2 - i1 + + W1 = W[:, i1:i2].clone() + Q1 = torch.zeros_like(W1) + Err1 = torch.zeros_like(W1) + Losses1 = torch.zeros_like(W1) + Hinv1 = Hinv[i1:i2, i1:i2] + + for i in range(count): + w = W1[:, i] + d = Hinv1[i, i] + + if groupsize != -1: + if (i1 + i) % groupsize == 0: + self.quantizer.find_params(W[:, (i1 + i):(i1 + i + groupsize)], weight=True) + + q = quantize( + w.unsqueeze(1), self.quantizer.scale, self.quantizer.zero, self.quantizer.maxq + ).flatten() + Q1[:, i] = q + Losses1[:, i] = (w - q) ** 2 / d ** 2 + + err1 = (w - q) / d + W1[:, i:] -= err1.unsqueeze(1).matmul(Hinv1[i, i:].unsqueeze(0)) + Err1[:, i] = err1 + + Q[:, i1:i2] = Q1 + Losses[:, i1:i2] = Losses1 / 2 + + W[:, i2:] -= Err1.matmul(Hinv[i1:i2, i2:]) + + if DEBUG: + self.layer.weight.data[:, :i2] = Q[:, :i2] + self.layer.weight.data[:, i2:] = W[:, i2:] + print(torch.sum((self.layer(self.inp1) - self.out1) ** 2)) + print(torch.sum(Losses)) + + torch.cuda.synchronize() + print('time %.2f' % (time.time() - tick)) + print('error', torch.sum(Losses).item()) + + if actorder: + invperm = torch.argsort(perm) + Q = Q[:, invperm] + + if isinstance(self.layer, transformers.Conv1D): + Q = Q.t() + self.layer.weight.data = Q.reshape(self.layer.weight.shape).to(self.layer.weight.data.dtype) + if DEBUG: + print(torch.sum((self.layer(self.inp1) - self.out1) ** 2)) + + def free(self): + if DEBUG: + self.inp1 = None + self.out1 = None + self.H = None + self.Losses = None + self.Trace = None + torch.cuda.empty_cache() diff --git a/quantize/gptq/modelutils.py b/quantize/gptq/modelutils.py new file mode 100644 index 00000000..0c5d12b1 --- /dev/null +++ b/quantize/gptq/modelutils.py @@ -0,0 +1,16 @@ +import torch +import torch.nn as nn + + +DEV = torch.device('cuda:0') + + +def find_layers(module, layers=[nn.Conv2d, nn.Linear], name=''): + if type(module) in layers: + return {name: module} + res = {} + for name1, child in module.named_children(): + res.update(find_layers( + child, layers=layers, name=name + '.' 
+ name1 if name != '' else name1 + )) + return res diff --git a/quantize/gptq/opt.py b/quantize/gptq/opt.py new file mode 100644 index 00000000..60112b24 --- /dev/null +++ b/quantize/gptq/opt.py @@ -0,0 +1,471 @@ +import time + +import torch +import torch.nn as nn + +from gptq import * +from modelutils import * +from quant import * + + +def get_opt(model): + import torch + def skip(*args, **kwargs): + pass + torch.nn.init.kaiming_uniform_ = skip + torch.nn.init.uniform_ = skip + torch.nn.init.normal_ = skip + from transformers import OPTForCausalLM + model = OPTForCausalLM.from_pretrained(model, torch_dtype='auto') + model.seqlen = model.config.max_position_embeddings + return model + +@torch.no_grad() +def opt_sequential(model, dataloader, dev): + print('Starting ...') + + use_cache = model.config.use_cache + model.config.use_cache = False + layers = model.model.decoder.layers + + model.model.decoder.embed_tokens = model.model.decoder.embed_tokens.to(dev) + model.model.decoder.embed_positions = model.model.decoder.embed_positions.to(dev) + if hasattr(model.model.decoder, 'project_out') and model.model.decoder.project_out: + model.model.decoder.project_out = model.model.decoder.project_out.to(dev) + if hasattr(model.model.decoder, 'project_in') and model.model.decoder.project_in: + model.model.decoder.project_in = model.model.decoder.project_in.to(dev) + layers[0] = layers[0].to(dev) + + dtype = next(iter(model.parameters())).dtype + inps = torch.zeros( + (args.nsamples, model.seqlen, model.config.hidden_size), dtype=dtype, device=dev + ) + cache = {'i': 0, 'attention_mask': None} + + class Catcher(nn.Module): + def __init__(self, module): + super().__init__() + self.module = module + def forward(self, inp, **kwargs): + inps[cache['i']] = inp + cache['i'] += 1 + cache['attention_mask'] = kwargs['attention_mask'] + raise ValueError + layers[0] = Catcher(layers[0]) + for batch in dataloader: + try: + model(batch[0].to(dev)) + except ValueError: + pass + layers[0] = layers[0].module + + layers[0] = layers[0].cpu() + model.model.decoder.embed_tokens = model.model.decoder.embed_tokens.cpu() + model.model.decoder.embed_positions = model.model.decoder.embed_positions.cpu() + if hasattr(model.model.decoder, 'project_out') and model.model.decoder.project_out: + model.model.decoder.project_out = model.model.decoder.project_out.cpu() + if hasattr(model.model.decoder, 'project_in') and model.model.decoder.project_in: + model.model.decoder.project_in = model.model.decoder.project_in.cpu() + torch.cuda.empty_cache() + + outs = torch.zeros_like(inps) + attention_mask = cache['attention_mask'] + + print('Ready.') + + quantizers = {} + for i in range(len(layers)): + layer = layers[i].to(dev) + + subset = find_layers(layer) + gptq = {} + for name in subset: + gptq[name] = GPTQ(subset[name]) + gptq[name].quantizer = Quantizer() + gptq[name].quantizer.configure( + args.wbits, perchannel=True, sym=args.sym, mse=False, trits=args.trits + ) + + def add_batch(name): + def tmp(_, inp, out): + gptq[name].add_batch(inp[0].data, out.data) + return tmp + handles = [] + for name in subset: + handles.append(subset[name].register_forward_hook(add_batch(name))) + for j in range(args.nsamples): + outs[j] = layer(inps[j].unsqueeze(0), attention_mask=attention_mask)[0] + for h in handles: + h.remove() + + for name in subset: + print(i, name) + print('Quantizing ...') + gptq[name].fasterquant(percdamp=args.percdamp, groupsize=args.groupsize, actorder=args.act_order) + quantizers['model.decoder.layers.%d.%s' % (i, name)] = 
gptq[name].quantizer + gptq[name].free() + for j in range(args.nsamples): + outs[j] = layer(inps[j].unsqueeze(0), attention_mask=attention_mask)[0] + + layers[i] = layer.cpu() + del layer + del gptq + torch.cuda.empty_cache() + + inps, outs = outs, inps + + model.config.use_cache = use_cache + + return quantizers + +@torch.no_grad() +def opt_eval(model, testenc, dev): + print('Evaluating ...') + + testenc = testenc.input_ids + nsamples = testenc.numel() // model.seqlen + + use_cache = model.config.use_cache + model.config.use_cache = False + layers = model.model.decoder.layers + + model.model.decoder.embed_tokens = model.model.decoder.embed_tokens.to(dev) + model.model.decoder.embed_positions = model.model.decoder.embed_positions.to(dev) + if hasattr(model.model.decoder, 'project_out') and model.model.decoder.project_out: + model.model.decoder.project_out = model.model.decoder.project_out.to(dev) + if hasattr(model.model.decoder, 'project_in') and model.model.decoder.project_in: + model.model.decoder.project_in = model.model.decoder.project_in.to(dev) + layers[0] = layers[0].to(dev) + + dtype = next(iter(model.parameters())).dtype + inps = torch.zeros( + (nsamples, model.seqlen, model.config.hidden_size), dtype=dtype, device=dev + ) + cache = {'i': 0, 'attention_mask': None} + + class Catcher(nn.Module): + def __init__(self, module): + super().__init__() + self.module = module + def forward(self, inp, **kwargs): + inps[cache['i']] = inp + cache['i'] += 1 + cache['attention_mask'] = kwargs['attention_mask'] + raise ValueError + layers[0] = Catcher(layers[0]) + for i in range(nsamples): + batch = testenc[:, (i * model.seqlen):((i + 1) * model.seqlen)].to(dev) + try: + model(batch) + except ValueError: + pass + layers[0] = layers[0].module + + layers[0] = layers[0].cpu() + model.model.decoder.embed_tokens = model.model.decoder.embed_tokens.cpu() + model.model.decoder.embed_positions = model.model.decoder.embed_positions.cpu() + if hasattr(model.model.decoder, 'project_out') and model.model.decoder.project_out: + model.model.decoder.project_out = model.model.decoder.project_out.cpu() + if hasattr(model.model.decoder, 'project_in') and model.model.decoder.project_in: + model.model.decoder.project_in = model.model.decoder.project_in.cpu() + torch.cuda.empty_cache() + + outs = torch.zeros_like(inps) + attention_mask = cache['attention_mask'] + + for i in range(len(layers)): + print(i) + layer = layers[i].to(dev) + + if args.nearest: + subset = find_layers(layer) + for name in subset: + quantizer = Quantizer() + quantizer.configure( + args.wbits, perchannel=True, sym=args.sym, mse=False + ) + W = subset[name].weight.data + quantizer.find_params(W, weight=True) + subset[name].weight.data = quantize( + W, quantizer.scale, quantizer.zero, quantizer.maxq + ).to(next(iter(layer.parameters())).dtype) + + for j in range(nsamples): + outs[j] = layer(inps[j].unsqueeze(0), attention_mask=attention_mask)[0] + layers[i] = layer.cpu() + del layer + torch.cuda.empty_cache() + inps, outs = outs, inps + + if model.model.decoder.final_layer_norm is not None: + model.model.decoder.final_layer_norm = model.model.decoder.final_layer_norm.to(dev) + if model.model.decoder.project_out is not None: + model.model.decoder.project_out = model.model.decoder.project_out.to(dev) + model.lm_head = model.lm_head.to(dev) + + testenc = testenc.to(dev) + nlls = [] + for i in range(nsamples): + hidden_states = inps[i].unsqueeze(0) + if model.model.decoder.final_layer_norm is not None: + hidden_states = 
model.model.decoder.final_layer_norm(hidden_states) + if model.model.decoder.project_out is not None: + hidden_states = model.model.decoder.project_out(hidden_states) + lm_logits = model.lm_head(hidden_states) + shift_logits = lm_logits[:, :-1, :].contiguous() + shift_labels = testenc[ + :, (i * model.seqlen):((i + 1) * model.seqlen) + ][:, 1:] + loss_fct = nn.CrossEntropyLoss() + loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)) + neg_log_likelihood = loss.float() * model.seqlen + nlls.append(neg_log_likelihood) + ppl = torch.exp(torch.stack(nlls).sum() / (nsamples * model.seqlen)) + print(ppl.item()) + + model.config.use_cache = use_cache + +# TODO: perform packing on GPU +def opt_pack3(model, quantizers): + layers = find_layers(model) + layers = {n: layers[n] for n in quantizers} + make_quant3(model, quantizers, faster=args.faster_kernel) + qlayers = find_layers(model, [Quant3Linear]) + print('Packing ...') + for name in qlayers: + print(name) + quantizers[name] = quantizers[name].cpu() + qlayers[name].pack(layers[name], quantizers[name].scale, quantizers[name].zero) + print('Done.') + return model + +def load_quant3(model, checkpoint): + from transformers import OPTConfig, OPTForCausalLM + config = OPTConfig.from_pretrained(model) + def noop(*args, **kwargs): + pass + torch.nn.init.kaiming_uniform_ = noop + torch.nn.init.uniform_ = noop + torch.nn.init.normal_ = noop + + torch.set_default_dtype(torch.half) + transformers.modeling_utils._init_weights = False + torch.set_default_dtype(torch.half) + model = OPTForCausalLM(config) + torch.set_default_dtype(torch.float) + model = model.eval() + layers = find_layers(model) + for name in ['model.decoder.project_out', 'model.decoder.project_in', 'lm_head']: + if name in layers: + del layers[name] + make_quant3(model, layers, faster=args.faster_kernel) + + print('Loading model ...') + model.load_state_dict(torch.load(checkpoint)) + model.seqlen = model.config.max_position_embeddings + print('Done.') + + return model + +def opt_multigpu(model, gpus): + model.model.decoder.embed_tokens = model.model.decoder.embed_tokens.to(gpus[0]) + model.model.decoder.embed_positions = model.model.decoder.embed_positions.to(gpus[0]) + if hasattr(model.model.decoder, 'project_in') and model.model.decoder.project_in: + model.model.decoder.project_in = model.model.decoder.project_in.to(gpus[0]) + if hasattr(model.model.decoder, 'project_out') and model.model.decoder.project_out: + model.model.decoder.project_out = model.model.decoder.project_out.to(gpus[-1]) + if hasattr(model.model.decoder, 'final_layer_norm') and model.model.decoder.final_layer_norm: + model.model.decoder.final_layer_norm = model.model.decoder.final_layer_norm.to(gpus[-1]) + import copy + model.lm_head = copy.deepcopy(model.lm_head).to(gpus[-1]) + + cache = {'mask': None} + + class MoveModule(nn.Module): + def __init__(self, module): + super().__init__() + self.module = module + self.dev = next(iter(self.module.parameters())).device + def forward(self, *inp, **kwargs): + inp = list(inp) + if inp[0].device != self.dev: + inp[0] = inp[0].to(self.dev) + if cache['mask'] is None or cache['mask'].device != self.dev: + cache['mask'] = kwargs['attention_mask'].to(self.dev) + kwargs['attention_mask'] = cache['mask'] + tmp = self.module(*inp, **kwargs) + return tmp + + layers = model.model.decoder.layers + pergpu = math.ceil(len(layers) / len(gpus)) + for i in range(len(layers)): + layers[i] = MoveModule(layers[i].to(gpus[i // pergpu])) + + model.gpus = gpus + +def 
benchmark(model, input_ids, check=False): + input_ids = input_ids.to(model.gpus[0] if hasattr(model, 'gpus') else DEV) + torch.cuda.synchronize() + + cache = {'past': None} + def clear_past(i): + def tmp(layer, inp, out): + if cache['past']: + cache['past'][i] = None + return tmp + for i, layer in enumerate(model.model.decoder.layers): + layer.register_forward_hook(clear_past(i)) + + print('Benchmarking ...') + + if check: + loss = nn.CrossEntropyLoss() + tot = 0. + + def sync(): + if hasattr(model, 'gpus'): + for gpu in model.gpus: + torch.cuda.synchronize(gpu) + else: + torch.cuda.synchronize() + with torch.no_grad(): + attention_mask = torch.ones((1, input_ids.numel()), device=DEV) + times = [] + for i in range(input_ids.numel()): + tick = time.time() + out = model( + input_ids[:, i].reshape(-1), + past_key_values=cache['past'], + attention_mask=attention_mask[:, :(i + 1)].reshape((1, -1)) + ) + sync() + times.append(time.time() - tick) + print(i, times[-1]) + if check and i != input_ids.numel() - 1: + tot += loss(out.logits[0].to(DEV), input_ids[:, (i + 1)].to(DEV)).float() + cache['past'] = list(out.past_key_values) + del out + sync() + import numpy as np + print('Median:', np.median(times)) + if check: + print('PPL:', torch.exp(tot / (input_ids.numel() - 1)).item()) + + +if __name__ == '__main__': + import argparse + from datautils import * + + parser = argparse.ArgumentParser() + + parser.add_argument( + 'model', type=str, + help='OPT model to load; pass `facebook/opt-X`.' + ) + parser.add_argument( + 'dataset', type=str, choices=['wikitext2', 'ptb', 'c4'], + help='Where to extract calibration data from.' + ) + parser.add_argument( + '--seed', + type=int, default=0, help='Seed for sampling the calibration data.' + ) + parser.add_argument( + '--nsamples', type=int, default=128, + help='Number of calibration data samples.' + ) + parser.add_argument( + '--percdamp', type=float, default=.01, + help='Percent of the average Hessian diagonal to use for dampening.' + ) + parser.add_argument( + '--nearest', action='store_true', + help='Whether to run the RTN baseline.' + ) + parser.add_argument( + '--wbits', type=int, default=16, choices=[2, 3, 4, 16], + help='#bits to use for quantization; use 16 for evaluating base model.' + ) + parser.add_argument( + '--trits', action='store_true', + help='Whether to use trits for quantization.' + ) + parser.add_argument( + '--groupsize', type=int, default=-1, + help='Groupsize to use for quantization; default uses full row.' + ) + parser.add_argument( + '--sym', action='store_true', + help='Whether to perform symmetric quantization.' + ) + parser.add_argument( + '--save', type=str, default='', + help='Save quantized checkpoint under this name.' + ) + parser.add_argument( + '--load', type=str, default='', + help='Load quantized model.' + ) + parser.add_argument( + '--benchmark', type=int, default=0, + help='Number of tokens to use for benchmarking.' + ) + parser.add_argument( + '--check', action='store_true', + help='Whether to compute perplexity during benchmarking for verification.' + ) + parser.add_argument( + '--new-eval', action='store_true', + help='Whether to use the new PTB and C4 eval.' + ) + parser.add_argument( + '--faster-kernel', action='store_true', + help='Whether to use the new faster kernel for benchmarking.' 
+ ) + parser.add_argument( + '--act-order', action='store_true', + help='Whether to apply the activation order GPTQ heuristic' + ) + + args = parser.parse_args() + + if args.load: + model = load_quant3(args.model, args.load) + else: + model = get_opt(args.model) + model.eval() + + dataloader, testloader = get_loaders( + args.dataset, nsamples=args.nsamples, seed=args.seed, model=args.model, seqlen=model.seqlen + ) + + if args.wbits < 16 and not args.nearest: + tick = time.time() + quantizers = opt_sequential(model, dataloader, DEV) + print(time.time() - tick) + + if args.benchmark: + gpus = [torch.device('cuda:%d' % i) for i in range(torch.cuda.device_count())] + if len(gpus) > 1: + opt_multigpu(model, gpus) + else: + model = model.to(DEV) + if args.benchmark: + input_ids = next(iter(dataloader))[0][:, :args.benchmark] + benchmark(model, input_ids, check=args.check) + if args.load: + exit() + + datasets = ['wikitext2', 'ptb', 'c4'] + if args.new_eval: + datasets = ['wikitext2', 'ptb-new', 'c4-new'] + for dataset in datasets: + dataloader, testloader = get_loaders( + dataset, seed=args.seed, model=args.model, seqlen=model.seqlen + ) + print(dataset) + opt_eval(model, testloader, DEV) + + if args.save: + opt_pack3(model, quantizers) + torch.save(model.state_dict(), args.save) diff --git a/quantize/gptq/quant.py b/quantize/gptq/quant.py new file mode 100644 index 00000000..77c27e00 --- /dev/null +++ b/quantize/gptq/quant.py @@ -0,0 +1,212 @@ +import numpy as np +import torch +import torch.nn as nn + + +def quantize(x, scale, zero, maxq): + if maxq < 0: + return (x > scale / 2).float() * scale + (x < zero / 2).float() * zero + q = torch.clamp(torch.round(x / scale) + zero, 0, maxq) + return scale * (q - zero) + +class Quantizer(nn.Module): + + def __init__(self, shape=1): + super(Quantizer, self).__init__() + self.register_buffer('maxq', torch.tensor(0)) + self.register_buffer('scale', torch.zeros(shape)) + self.register_buffer('zero', torch.zeros(shape)) + + def configure( + self, + bits, perchannel=False, sym=True, + mse=False, norm=2.4, grid=100, maxshrink=.8, + trits=False + ): + self.maxq = torch.tensor(2 ** bits - 1) + self.perchannel = perchannel + self.sym = sym + self.mse = mse + self.norm = norm + self.grid = grid + self.maxshrink = maxshrink + if trits: + self.maxq = torch.tensor(-1) + + def find_params(self, x, weight=False): + dev = x.device + self.maxq = self.maxq.to(dev) + + shape = x.shape + if self.perchannel: + if weight: + x = x.flatten(1) + else: + if len(shape) == 4: + x = x.permute([1, 0, 2, 3]) + x = x.flatten(1) + if len(shape) == 3: + x = x.reshape((-1, shape[-1])).t() + if len(shape) == 2: + x = x.t() + else: + x = x.flatten().unsqueeze(0) + + tmp = torch.zeros(x.shape[0], device=dev) + xmin = torch.minimum(x.min(1)[0], tmp) + xmax = torch.maximum(x.max(1)[0], tmp) + + if self.sym: + xmax = torch.maximum(torch.abs(xmin), xmax) + tmp = xmin < 0 + if torch.any(tmp): + xmin[tmp] = -xmax[tmp] + tmp = (xmin == 0) & (xmax == 0) + xmin[tmp] = -1 + xmax[tmp] = +1 + + if self.maxq < 0: + self.scale = xmax + self.zero = xmin + else: + self.scale = (xmax - xmin) / self.maxq + if self.sym: + self.zero = torch.full_like(self.scale, (self.maxq + 1) / 2) + else: + self.zero = torch.round(-xmin / self.scale) + + if self.mse: + best = torch.full([x.shape[0]], float('inf'), device=dev) + for i in range(int(self.maxshrink * self.grid)): + p = 1 - i / self.grid + xmin1 = p * xmin + xmax1 = p * xmax + scale1 = (xmax1 - xmin1) / self.maxq + zero1 = torch.round(-xmin1 / scale1) if not 
self.sym else self.zero + q = quantize(x, scale1.unsqueeze(1), zero1.unsqueeze(1), self.maxq) + q -= x + q.abs_() + q.pow_(self.norm) + err = torch.sum(q, 1) + tmp = err < best + if torch.any(tmp): + best[tmp] = err[tmp] + self.scale[tmp] = scale1[tmp] + self.zero[tmp] = zero1[tmp] + if not self.perchannel: + if weight: + tmp = shape[0] + else: + tmp = shape[1] if len(shape) != 3 else shape[2] + self.scale = self.scale.repeat(tmp) + self.zero = self.zero.repeat(tmp) + + if weight: + shape = [-1] + [1] * (len(shape) - 1) + self.scale = self.scale.reshape(shape) + self.zero = self.zero.reshape(shape) + return + if len(shape) == 4: + self.scale = self.scale.reshape((1, -1, 1, 1)) + self.zero = self.zero.reshape((1, -1, 1, 1)) + if len(shape) == 3: + self.scale = self.scale.reshape((1, 1, -1)) + self.zero = self.zero.reshape((1, 1, -1)) + if len(shape) == 2: + self.scale = self.scale.unsqueeze(0) + self.zero = self.zero.unsqueeze(0) + + def quantize(self, x): + if self.ready(): + return quantize(x, self.scale, self.zero, self.maxq) + return x + + def enabled(self): + return self.maxq > 0 + + def ready(self): + return torch.all(self.scale != 0) + + +try: + import quant_cuda +except: + print('CUDA extension not installed.') + +# Assumes layer is perfectly divisible into 1024 * 1024 blocks +class Quant3Linear(nn.Module): + + def __init__(self, infeatures, outfeatures, faster=False): + super().__init__() + self.register_buffer('zeros', torch.zeros((outfeatures, 1))) + self.register_buffer('scales', torch.zeros((outfeatures, 1))) + self.register_buffer('bias', torch.zeros(outfeatures)) + self.register_buffer( + 'qweight', torch.zeros((infeatures // 32 * 3, outfeatures), dtype=torch.int) + ) + self.faster = faster + + def pack(self, linear, scales, zeros): + self.zeros = zeros * scales + self.scales = scales.clone() + self.bias = linear.bias.clone() + + intweight = torch.round((linear.weight.data + self.zeros) / self.scales).to(torch.int) + intweight = intweight.t().contiguous() + intweight = intweight.numpy().astype(np.uint32) + qweight = np.zeros( + (intweight.shape[0] // 32 * 3, intweight.shape[1]), dtype=np.uint32 + ) + i = 0 + row = 0 + while row < qweight.shape[0]: + for j in range(i, i + 10): + qweight[row] |= intweight[j] << (3 * (j - i)) + i += 10 + qweight[row] |= intweight[i] << 30 + row += 1 + qweight[row] |= (intweight[i] >> 2) & 1 + i += 1 + for j in range(i, i + 10): + qweight[row] |= intweight[j] << (3 * (j - i) + 1) + i += 10 + qweight[row] |= intweight[i] << 31 + row += 1 + qweight[row] |= (intweight[i] >> 1) & 0x3 + i += 1 + for j in range(i, i + 10): + qweight[row] |= intweight[j] << (3 * (j - i) + 2) + i += 10 + row += 1 + + qweight = qweight.astype(np.int32) + self.qweight = torch.from_numpy(qweight) + + def forward(self, x): + if x.shape[-1] == x.numel(): + outshape = list(x.shape) + y = self.bias.clone() + outshape[-1] = self.bias.numel() + dtype = x.dtype + if self.faster: + x = x.half() + quant_cuda.vecquant3matmul_faster(x, self.qweight, y, self.scales, self.zeros) + else: + x = x.float() + quant_cuda.vecquant3matmul(x, self.qweight, y, self.scales, self.zeros) + y = y.to(dtype) + return y.reshape(outshape) + raise ValueError('Only supports a single token currently.') + +def make_quant3(module, names, name='', faster=False): + if isinstance(module, Quant3Linear): + return + for attr in dir(module): + tmp = getattr(module, attr) + name1 = name + '.' 
+ attr if name != '' else attr
+        if name1 in names:
+            setattr(
+                module, attr, Quant3Linear(tmp.in_features, tmp.out_features, faster=faster)
+            )
+    for name1, child in module.named_children():
+        make_quant3(child, names, name + '.' + name1 if name != '' else name1, faster=faster)
diff --git a/quantize/gptq/quant_cuda.cpp b/quantize/gptq/quant_cuda.cpp
new file mode 100644
index 00000000..1bf08941
--- /dev/null
+++ b/quantize/gptq/quant_cuda.cpp
@@ -0,0 +1,34 @@
+#include <torch/all.h>
+#include <torch/python.h>
+#include <c10/cuda/CUDAGuard.h>
+
+void vecquant3matmul_cuda(
+  torch::Tensor vec, torch::Tensor mat, torch::Tensor mul,
+  torch::Tensor scales, torch::Tensor zeros
+);
+
+void vecquant3matmul_faster_cuda(
+  torch::Tensor vec, torch::Tensor mat, torch::Tensor mul,
+  torch::Tensor scales, torch::Tensor zeros
+);
+
+void vecquant3matmul(
+  torch::Tensor vec, torch::Tensor mat, torch::Tensor mul,
+  torch::Tensor scales, torch::Tensor zeros
+) {
+  const at::cuda::OptionalCUDAGuard device_guard(device_of(vec));
+  vecquant3matmul_cuda(vec, mat, mul, scales, zeros);
+}
+
+void vecquant3matmul_faster(
+  torch::Tensor vec, torch::Tensor mat, torch::Tensor mul,
+  torch::Tensor scales, torch::Tensor zeros
+) {
+  const at::cuda::OptionalCUDAGuard device_guard(device_of(vec));
+  vecquant3matmul_faster_cuda(vec, mat, mul, scales, zeros);
+}
+
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+  m.def("vecquant3matmul", &vecquant3matmul, "Vector 3-bit Quantized Matrix Multiplication (CUDA)");
+  m.def("vecquant3matmul_faster", &vecquant3matmul_faster, "Vector 3-bit Quantized Matrix Multiplication (CUDA), faster version");
+}
diff --git a/quantize/gptq/quant_cuda_kernel.cu b/quantize/gptq/quant_cuda_kernel.cu
new file mode 100644
index 00000000..101167f0
--- /dev/null
+++ b/quantize/gptq/quant_cuda_kernel.cu
@@ -0,0 +1,244 @@
+#include <torch/all.h>
+#include <torch/python.h>
+#include <cuda.h>
+#include <cuda_runtime.h>
+#include <cuda_fp16.h>
+
+template <typename scalar_t>
+__global__ void VecQuant3MatMulKernel(
+    const scalar_t* __restrict__ vec,
+    const int* __restrict__ mat,
+    scalar_t* __restrict__ mul,
+    const scalar_t* __restrict__ scales,
+    const scalar_t* __restrict__ zeros,
+    int height,
+    int width
+);
+
+__global__ void VecQuant3MatMulKernelFaster(
+    const half2* __restrict__ vec,
+    const int* __restrict__ mat,
+    float* __restrict__ mul,
+    const float* __restrict__ scales,
+    const float* __restrict__ zeros,
+    int height,
+    int width
+);
+
+const int BLOCKWIDTH = 256;
+const int BLOCKHEIGHT = 24;
+
+void vecquant3matmul_cuda(
+  torch::Tensor vec,
+  torch::Tensor mat,
+  torch::Tensor mul,
+  torch::Tensor scales,
+  torch::Tensor zeros
+) {
+  int height = mat.size(0);
+  int width = mat.size(1);
+
+  dim3 blocks(
+    (height + BLOCKHEIGHT - 1) / BLOCKHEIGHT,
+    (width + BLOCKWIDTH - 1) / BLOCKWIDTH
+  );
+  dim3 threads(BLOCKWIDTH);
+
+  AT_DISPATCH_FLOATING_TYPES(
+    vec.type(), "vecquant3matmul_cuda", ([&] {
+      VecQuant3MatMulKernel<<<blocks, threads>>>(
+        vec.data<scalar_t>(), mat.data<int>(), mul.data<scalar_t>(),
+        scales.data<scalar_t>(), zeros.data<scalar_t>(),
+        height, width
+      );
+    })
+  );
+}
+
+void vecquant3matmul_faster_cuda(
+  torch::Tensor vec,
+  torch::Tensor mat,
+  torch::Tensor mul,
+  torch::Tensor scales,
+  torch::Tensor zeros
+) {
+  int height = mat.size(0);
+  int width = mat.size(1);
+
+  dim3 blocks(
+    (height + BLOCKHEIGHT - 1) / BLOCKHEIGHT,
+    (width + BLOCKWIDTH - 1) / BLOCKWIDTH
+  );
+  dim3 threads(BLOCKWIDTH);
+
+  VecQuant3MatMulKernelFaster<<<blocks, threads>>>(
+    (half2*) vec.data_ptr(),
+    mat.data_ptr<int>(),
+    mul.data_ptr<float>(),
+    scales.data_ptr<float>(),
+    zeros.data_ptr<float>(),
+    height, width
+  );
+}
+
+__device__ inline unsigned int as_unsigned(int i) {
+  return *reinterpret_cast<unsigned int*>(&i);
+}
+
+template <typename scalar_t>
+__global__ 
void VecQuant3MatMulKernel( + const scalar_t* __restrict__ vec, + const int* __restrict__ mat, + scalar_t* __restrict__ mul, + const scalar_t* __restrict__ scales, + const scalar_t* __restrict__ zeros, + int height, + int width +) { + int row = BLOCKHEIGHT * blockIdx.x; + int col = BLOCKWIDTH * blockIdx.y + threadIdx.x; + + __shared__ scalar_t blockvec[BLOCKWIDTH]; + blockvec[threadIdx.x] = vec[(row / BLOCKHEIGHT) * BLOCKWIDTH + threadIdx.x]; + __syncthreads(); + + scalar_t scale = scales[col]; + scalar_t zero = zeros[col]; + + scalar_t res = 0; + int i = width * row + col; + int k = 0; + + unsigned int tmp1; + unsigned int tmp2; + unsigned int tmp; + + while (k < BLOCKWIDTH) { + tmp1 = as_unsigned(mat[i]); + res += (scale * scalar_t((tmp1 >> 0) & 0x7) - zero) * blockvec[k + 0]; + res += (scale * scalar_t((tmp1 >> 3) & 0x7) - zero) * blockvec[k + 1]; + res += (scale * scalar_t((tmp1 >> 6) & 0x7) - zero) * blockvec[k + 2]; + res += (scale * scalar_t((tmp1 >> 9) & 0x7) - zero) * blockvec[k + 3]; + res += (scale * scalar_t((tmp1 >> 12) & 0x7) - zero) * blockvec[k + 4]; + res += (scale * scalar_t((tmp1 >> 15) & 0x7) - zero) * blockvec[k + 5]; + res += (scale * scalar_t((tmp1 >> 18) & 0x7) - zero) * blockvec[k + 6]; + res += (scale * scalar_t((tmp1 >> 21) & 0x7) - zero) * blockvec[k + 7]; + res += (scale * scalar_t((tmp1 >> 24) & 0x7) - zero) * blockvec[k + 8]; + res += (scale * scalar_t((tmp1 >> 27) & 0x7) - zero) * blockvec[k + 9]; + i += width; + tmp2 = as_unsigned(mat[i]); + tmp = (tmp1 >> 30) | ((tmp2 << 2) & 0x4); + tmp2 >>= 1; + res += (scale * scalar_t(tmp) - zero) * blockvec[k + 10]; + k += 11; + res += (scale * scalar_t((tmp2 >> 0) & 0x7) - zero) * blockvec[k + 0]; + res += (scale * scalar_t((tmp2 >> 3) & 0x7) - zero) * blockvec[k + 1]; + res += (scale * scalar_t((tmp2 >> 6) & 0x7) - zero) * blockvec[k + 2]; + res += (scale * scalar_t((tmp2 >> 9) & 0x7) - zero) * blockvec[k + 3]; + res += (scale * scalar_t((tmp2 >> 12) & 0x7) - zero) * blockvec[k + 4]; + res += (scale * scalar_t((tmp2 >> 15) & 0x7) - zero) * blockvec[k + 5]; + res += (scale * scalar_t((tmp2 >> 18) & 0x7) - zero) * blockvec[k + 6]; + res += (scale * scalar_t((tmp2 >> 21) & 0x7) - zero) * blockvec[k + 7]; + res += (scale * scalar_t((tmp2 >> 24) & 0x7) - zero) * blockvec[k + 8]; + res += (scale * scalar_t((tmp2 >> 27) & 0x7) - zero) * blockvec[k + 9]; + i += width; + tmp1 = as_unsigned(mat[i]); + tmp = (tmp2 >> 30) | ((tmp1 << 1) & 0x6); + tmp1 >>= 2; + res += (scale * scalar_t(tmp) - zero) * blockvec[k + 10]; + k += 11; + res += (scale * scalar_t((tmp1 >> 0) & 0x7) - zero) * blockvec[k + 0]; + res += (scale * scalar_t((tmp1 >> 3) & 0x7) - zero) * blockvec[k + 1]; + res += (scale * scalar_t((tmp1 >> 6) & 0x7) - zero) * blockvec[k + 2]; + res += (scale * scalar_t((tmp1 >> 9) & 0x7) - zero) * blockvec[k + 3]; + res += (scale * scalar_t((tmp1 >> 12) & 0x7) - zero) * blockvec[k + 4]; + res += (scale * scalar_t((tmp1 >> 15) & 0x7) - zero) * blockvec[k + 5]; + res += (scale * scalar_t((tmp1 >> 18) & 0x7) - zero) * blockvec[k + 6]; + res += (scale * scalar_t((tmp1 >> 21) & 0x7) - zero) * blockvec[k + 7]; + res += (scale * scalar_t((tmp1 >> 24) & 0x7) - zero) * blockvec[k + 8]; + res += (scale * scalar_t((tmp1 >> 27) & 0x7) - zero) * blockvec[k + 9]; + i += width; + k += 10; + } + + atomicAdd(&mul[col], res); +} + +__global__ void VecQuant3MatMulKernelFaster( + const half2* __restrict__ vec, + const int* __restrict__ mat, + float* __restrict__ mul, + const float* __restrict__ scales, + const float* __restrict__ zeros, + int 
height, + int width +) { + const int blockwidth2 = BLOCKWIDTH / 2; + + int row = BLOCKHEIGHT * blockIdx.x; + int col = BLOCKWIDTH * blockIdx.y + threadIdx.x; + + __shared__ half2 blockvec[blockwidth2]; + if (threadIdx.x < blockwidth2) + blockvec[threadIdx.x] = vec[(row / BLOCKHEIGHT) * blockwidth2 + threadIdx.x]; + + __shared__ half2 deq2[64][32]; + int val = threadIdx.x / 32; + int off = threadIdx.x % 32; + for (; val < 64; val += BLOCKWIDTH / 32) { + deq2[val][off] = __halves2half2( + __int2half_rn(val & 0x7), __int2half_rn(val >> 3) + ); + } + + half2 scale = __float2half2_rn(scales[col]); + half2 zero = __float2half2_rn(-zeros[col]); + + int i = width * row + col; + int k = 0; + + float res = 0; + half2 res2; + + unsigned int tmp1; + unsigned int tmp2; + unsigned int tmp; + + __syncthreads(); + + while (k < blockwidth2) { + res2 = {}; + tmp1 = as_unsigned(mat[i]); + res2 = __hfma2(__hfma2(deq2[(tmp1 >> 0) & 0x3f][off], scale, zero), blockvec[k + 0], res2); + res2 = __hfma2(__hfma2(deq2[(tmp1 >> 6) & 0x3f][off], scale, zero), blockvec[k + 1], res2); + res2 = __hfma2(__hfma2(deq2[(tmp1 >> 12) & 0x3f][off], scale, zero), blockvec[k + 2], res2); + res2 = __hfma2(__hfma2(deq2[(tmp1 >> 18) & 0x3f][off], scale, zero), blockvec[k + 3], res2); + res2 = __hfma2(__hfma2(deq2[(tmp1 >> 24) & 0x3f][off], scale, zero), blockvec[k + 4], res2); + i += width; + tmp2 = as_unsigned(mat[i]); + tmp = (tmp1 >> 30) | ((tmp2 << 2) & 0x3c); + res2 = __hfma2(__hfma2(deq2[tmp][off], scale, zero), blockvec[k + 5], res2); + tmp2 >>= 4; + k += 6; + res2 = __hfma2(__hfma2(deq2[(tmp2 >> 0) & 0x3f][off], scale, zero), blockvec[k + 0], res2); + res2 = __hfma2(__hfma2(deq2[(tmp2 >> 6) & 0x3f][off], scale, zero), blockvec[k + 1], res2); + res2 = __hfma2(__hfma2(deq2[(tmp2 >> 12) & 0x3f][off], scale, zero), blockvec[k + 2], res2); + res2 = __hfma2(__hfma2(deq2[(tmp2 >> 18) & 0x3f][off], scale, zero), blockvec[k + 3], res2); + i += width; + tmp1 = as_unsigned(mat[i]); + tmp = (tmp2 >> 24) | ((tmp1 << 4) & 0x30); + res2 = __hfma2(__hfma2(deq2[tmp][off], scale, zero), blockvec[k + 4], res2); + tmp1 >>= 2; + k += 5; + res2 = __hfma2(__hfma2(deq2[(tmp1 >> 0) & 0x3f][off], scale, zero), blockvec[k + 0], res2); + res2 = __hfma2(__hfma2(deq2[(tmp1 >> 6) & 0x3f][off], scale, zero), blockvec[k + 1], res2); + res2 = __hfma2(__hfma2(deq2[(tmp1 >> 12) & 0x3f][off], scale, zero), blockvec[k + 2], res2); + res2 = __hfma2(__hfma2(deq2[(tmp1 >> 18) & 0x3f][off], scale, zero), blockvec[k + 3], res2); + res2 = __hfma2(__hfma2(deq2[(tmp1 >> 24) & 0x3f][off], scale, zero), blockvec[k + 4], res2); + i += width; + k += 5; + res += __half2float(res2.x) + __half2float(res2.y); + } + + atomicAdd(&mul[col], res); +} diff --git a/quantize/gptq/setup_cuda.py b/quantize/gptq/setup_cuda.py new file mode 100644 index 00000000..25231709 --- /dev/null +++ b/quantize/gptq/setup_cuda.py @@ -0,0 +1,14 @@ +from setuptools import setup, Extension +from torch.utils import cpp_extension + +# get current path +import os +current_path = os.path.dirname(os.path.abspath(__file__)) + +setup( + name='quant_cuda', + ext_modules=[cpp_extension.CUDAExtension( + 'quant_cuda', [f'{current_path}/quant_cuda.cpp', f'{current_path}/quant_cuda_kernel.cu'] + )], + cmdclass={'build_ext': cpp_extension.BuildExtension} +) \ No newline at end of file From 943af70adb056e89e3c6503310dde17b4946597f Mon Sep 17 00:00:00 2001 From: Ferdinand Mom Date: Tue, 18 Apr 2023 19:41:49 +0000 Subject: [PATCH 03/20] feat(quantize): begin to readapt with RWKV --- quantize/compress_rwkv.py | 239 
+++++++++++++++++++ quantize/gptq/datautils.py | 314 +++++++++++++------------ quantize/gptq/gptq.py | 2 +- quantize/gptq/opt.py | 471 ------------------------------------- quantize/opt.py | 216 +++++++++++++++++ 5 files changed, 625 insertions(+), 617 deletions(-) create mode 100644 quantize/compress_rwkv.py delete mode 100644 quantize/gptq/opt.py create mode 100644 quantize/opt.py diff --git a/quantize/compress_rwkv.py b/quantize/compress_rwkv.py new file mode 100644 index 00000000..dbe17a87 --- /dev/null +++ b/quantize/compress_rwkv.py @@ -0,0 +1,239 @@ +import time +import torch +import torch.nn as nn + +from rwkv.model import RWKV +from gptq.gptq import * +from gptq.modelutils import * +from gptq.quant import * +from gptq.datautils import * + +# TODO: perform packing on GPU +def opt_pack3(model, quantizers): + layers = find_layers(model) + layers = {n: layers[n] for n in quantizers} + make_quant3(model, quantizers, faster=args.faster_kernel) + qlayers = find_layers(model, [Quant3Linear]) + print('Packing ...') + for name in qlayers: + print(name) + quantizers[name] = quantizers[name].cpu() + qlayers[name].pack(layers[name], quantizers[name].scale, quantizers[name].zero) + print('Done.') + return model + +@torch.no_grad() +def quantize_model(model, train_tokens, device): + print('Starting ...') + + use_cache = model.config.use_cache + model.config.use_cache = False + layers = model.model.decoder.layers + + # Load layer to device + model.model.decoder.embed_tokens = model.model.decoder.embed_tokens.to(device) + model.model.decoder.embed_positions = model.model.decoder.embed_positions.to(device) + if hasattr(model.model.decoder, 'project_out') and model.model.decoder.project_out: + model.model.decoder.project_out = model.model.decoder.project_out.to(device) + if hasattr(model.model.decoder, 'project_in') and model.model.decoder.project_in: + model.model.decoder.project_in = model.model.decoder.project_in.to(device) + layers[0] = layers[0].to(device) + + dtype = next(iter(model.parameters())).dtype + inps = torch.zeros( + (args.nsamples, model.seqlen, model.config.hidden_size), dtype=dtype, device=device + ) + cache = {'i': 0, 'attention_mask': None} + + class Catcher(nn.Module): + def __init__(self, module): + super().__init__() + self.module = module + def forward(self, inp, **kwargs): + inps[cache['i']] = inp + cache['i'] += 1 + cache['attention_mask'] = kwargs['attention_mask'] + raise ValueError + layers[0] = Catcher(layers[0]) + + for batch in train_tokens: + try: + # model(batch[0].to(device)) + # IndexError: invalid index of a 0-dim tensor. 
+ # Use `tensor.item()` in Python or `tensor.item()` + # in C++ to convert a 0-dim tensor to a number + model(batch[0].to(device)) + except ValueError: + pass + layers[0] = layers[0].module + + layers[0] = layers[0].cpu() + model.model.decoder.embed_tokens = model.model.decoder.embed_tokens.cpu() + model.model.decoder.embed_positions = model.model.decoder.embed_positions.cpu() + if hasattr(model.model.decoder, 'project_out') and model.model.decoder.project_out: + model.model.decoder.project_out = model.model.decoder.project_out.cpu() + if hasattr(model.model.decoder, 'project_in') and model.model.decoder.project_in: + model.model.decoder.project_in = model.model.decoder.project_in.cpu() + torch.cuda.empty_cache() + + outs = torch.zeros_like(inps) + attention_mask = cache['attention_mask'] + + print('Ready.') + + quantizers = {} + for i in range(len(layers)): + layer = layers[i].to(device) + + subset = find_layers(layer) + gptq = {} + for name in subset: + gptq[name] = GPTQ(subset[name]) + gptq[name].quantizer = Quantizer() + gptq[name].quantizer.configure( + args.wbits, perchannel=True, sym=args.sym, mse=False, trits=args.trits + ) + + def add_batch(name): + def tmp(_, inp, out): + gptq[name].add_batch(inp[0].data, out.data) + return tmp + handles = [] + for name in subset: + handles.append(subset[name].register_forward_hook(add_batch(name))) + for j in range(args.nsamples): + outs[j] = layer(inps[j].unsqueeze(0), attention_mask=attention_mask)[0] + for h in handles: + h.remove() + + for name in subset: + print(i, name) + print('Quantizing ...') + gptq[name].fasterquant(percdamp=args.percdamp, groupsize=args.groupsize, actorder=args.act_order) + quantizers['model.decoder.layers.%d.%s' % (i, name)] = gptq[name].quantizer + gptq[name].free() + for j in range(args.nsamples): + outs[j] = layer(inps[j].unsqueeze(0), attention_mask=attention_mask)[0] + + layers[i] = layer.cpu() + del layer + del gptq + torch.cuda.empty_cache() + + inps, outs = outs, inps + + model.config.use_cache = use_cache + + return quantizers + +if __name__ == "__main__": + import argparse + + parser = argparse.ArgumentParser() + + parser.add_argument( + '--model_path', type=str, + help='Path to model checkpoint file.' + ) + + parser.add_argument( + '--dataset_name', type=str, choices=['wikitext2', 'ptb', 'c4'], + help='Where to extract calibration data from.' + ) + + parser.add_argument( + '--wbits', type=int, default=16, choices=[2, 3, 4, 16], + help='#bits to use for quantization; use 16 for evaluating base model.' + ) + + parser.add_argument( + '--groupsize', type=int, default=-1, + help='Groupsize to use for quantization; default uses full row.' + ) + + parser.add_argument( + '--save', type=str, default='', + help='Save quantized checkpoint under this name.' + ) + + # ==== DEFAULT ==== + parser.add_argument( + '--seed', + type=int, default=0, help='Seed for sampling the calibration data.' + ) + parser.add_argument( + '--nsamples', type=int, default=128, + help='Number of calibration data samples.' + ) + parser.add_argument( + '--percdamp', type=float, default=.01, + help='Percent of the average Hessian diagonal to use for dampening.' + ) + + parser.add_argument( + '--nearest', action='store_true', + help='Whether to run the RTN baseline.' + ) + + parser.add_argument( + '--trits', action='store_true', + help='Whether to use trits for quantization.' + ) + + parser.add_argument( + '--sym', action='store_true', + help='Whether to perform symmetric quantization.' 
+ ) + + parser.add_argument( + '--faster-kernel', action='store_true', + help='Whether to use the new faster kernel for benchmarking.' + ) + + parser.add_argument( + '--act-order', action='store_true', + help='Whether to apply the activation order GPTQ heuristic' + ) + + args = parser.parse_args() + + # FIXME: Seems like quantization with OPT is not working in CPU mode + # device = torch.device('cpu') + device = torch.device('cuda:0') + + # Model + # model = model = RWKV(args.model_path, strategy='cpu fp32') + + def skip(*args, **kwargs): pass + torch.nn.init.kaiming_uniform_ = skip + torch.nn.init.uniform_ = skip + torch.nn.init.normal_ = skip + from transformers import OPTForCausalLM + model = OPTForCausalLM.from_pretrained("facebook/opt-125m", torch_dtype='auto') + model.seqlen = model.config.max_position_embeddings + + model.eval() + + # Dataset + train_tokens, test_tokens = get_loaders( + dataset_name=args.dataset_name, + nsamples=args.nsamples, + seed=args.seed, + seqlen=model.seqlen, + model="facebook/opt-125m", + # model=None + ) + + print(f'{len(train_tokens)} train tokens in the text') + print(f'{len(test_tokens)} test tokens in the text') + + if args.wbits < 16 and not args.nearest: + start_time = time.time() + quantizers = quantize_model(model, train_tokens, device) + end_time = time.time() + print('Quantization time: ', end_time - start_time) + + if args.save: + print('Saving quantized model to ', args.save) + opt_pack3(model, quantizers) + torch.save(model.state_dict(), args.save) \ No newline at end of file diff --git a/quantize/gptq/datautils.py b/quantize/gptq/datautils.py index 193953c5..4ed1e39b 100644 --- a/quantize/gptq/datautils.py +++ b/quantize/gptq/datautils.py @@ -1,175 +1,199 @@ import numpy as np import torch +import os +import pathlib +import tokenizers +import random +from datasets import load_dataset def set_seed(seed): np.random.seed(seed) torch.random.manual_seed(seed) - def get_wikitext2(nsamples, seed, seqlen, model): - from datasets import load_dataset - traindata = load_dataset('wikitext', 'wikitext-2-raw-v1', split='train') - testdata = load_dataset('wikitext', 'wikitext-2-raw-v1', split='test') - - from transformers import AutoTokenizer - tokenizer = AutoTokenizer.from_pretrained(model, use_fast=False) - trainenc = tokenizer("\n\n".join(traindata['text']), return_tensors='pt') - testenc = tokenizer("\n\n".join(testdata['text']), return_tensors='pt') - - import random + is_rwkv = True if model is None else False + + if is_rwkv: + traindata = load_dataset('wikitext', 'wikitext-2-raw-v1', split='train') + testdata = load_dataset('wikitext', 'wikitext-2-raw-v1', split='test') + + print('Loading RWKV tokenizer') + tokenizer_path = pathlib.Path(os.path.abspath(__file__)).parent / '../20B_tokenizer.json' + tokenizer = tokenizers.Tokenizer.from_file(str(tokenizer_path)) + trainenc = torch.unsqueeze(torch.tensor(tokenizer.encode("\n\n".join(traindata['text'])).ids, dtype=torch.long), 0) + testenc = torch.unsqueeze(torch.tensor(tokenizer.encode("\n\n".join(testdata['text'])).ids, dtype=torch.long), 0) + else: + traindata = load_dataset('wikitext', 'wikitext-2-raw-v1', split='train') + testdata = load_dataset('wikitext', 'wikitext-2-raw-v1', split='test') + + print('Loading tokenizer') + from transformers import AutoTokenizer + tokenizer = AutoTokenizer.from_pretrained(model, use_fast=False) + trainenc = tokenizer("\n\n".join(traindata['text']), return_tensors='pt') + testenc = tokenizer("\n\n".join(testdata['text']), return_tensors='pt') + random.seed(seed) 
trainloader = [] + shape = trainenc.shape if is_rwkv else trainenc.input_ids.shape + trainenc = trainenc if is_rwkv else trainenc.input_ids + for _ in range(nsamples): - i = random.randint(0, trainenc.input_ids.shape[1] - seqlen - 1) + i = random.randint(0, shape[1] - seqlen - 1) j = i + seqlen - inp = trainenc.input_ids[:, i:j] + inp = trainenc[:, i:j] tar = inp.clone() tar[:, :-1] = -100 trainloader.append((inp, tar)) return trainloader, testenc def get_ptb(nsamples, seed, seqlen, model): - from datasets import load_dataset - traindata = load_dataset('ptb_text_only', 'penn_treebank', split='train') - valdata = load_dataset('ptb_text_only', 'penn_treebank', split='validation') - - from transformers import AutoTokenizer - tokenizer = AutoTokenizer.from_pretrained(model, use_fast=False) - trainenc = tokenizer("\n\n".join(traindata['sentence']), return_tensors='pt') - testenc = tokenizer("\n\n".join(valdata['sentence']), return_tensors='pt') - - import random - random.seed(seed) - trainloader = [] - for _ in range(nsamples): - i = random.randint(0, trainenc.input_ids.shape[1] - seqlen - 1) - j = i + seqlen - inp = trainenc.input_ids[:, i:j] - tar = inp.clone() - tar[:, :-1] = -100 - trainloader.append((inp, tar)) - return trainloader, testenc + raise NotImplementedError('PTB not implemented yet') + # from datasets import load_dataset + # traindata = load_dataset('ptb_text_only', 'penn_treebank', split='train') + # valdata = load_dataset('ptb_text_only', 'penn_treebank', split='validation') + + # from transformers import AutoTokenizer + # tokenizer = AutoTokenizer.from_pretrained(model, use_fast=False) + # trainenc = tokenizer("\n\n".join(traindata['sentence']), return_tensors='pt') + # testenc = tokenizer("\n\n".join(valdata['sentence']), return_tensors='pt') + + # import random + # random.seed(seed) + # trainloader = [] + # for _ in range(nsamples): + # i = random.randint(0, trainenc.input_ids.shape[1] - seqlen - 1) + # j = i + seqlen + # inp = trainenc.input_ids[:, i:j] + # tar = inp.clone() + # tar[:, :-1] = -100 + # trainloader.append((inp, tar)) + # return trainloader, testenc def get_c4(nsamples, seed, seqlen, model): - from datasets import load_dataset - traindata = load_dataset( - 'allenai/c4', 'allenai--c4', data_files={'train': 'en/c4-train.00000-of-01024.json.gz'}, split='train' - ) - valdata = load_dataset( - 'allenai/c4', 'allenai--c4', data_files={'validation': 'en/c4-validation.00000-of-00008.json.gz'}, split='validation' - ) - - from transformers import AutoTokenizer - tokenizer = AutoTokenizer.from_pretrained(model, use_fast=False) - - import random - random.seed(seed) - trainloader = [] - for _ in range(nsamples): - while True: - i = random.randint(0, len(traindata) - 1) - trainenc = tokenizer(traindata[i]['text'], return_tensors='pt') - if trainenc.input_ids.shape[1] >= seqlen: - break - i = random.randint(0, trainenc.input_ids.shape[1] - seqlen - 1) - j = i + seqlen - inp = trainenc.input_ids[:, i:j] - tar = inp.clone() - tar[:, :-1] = -100 - trainloader.append((inp, tar)) - - import random - random.seed(0) - valenc = [] - for _ in range(256): - while True: - i = random.randint(0, len(valdata) - 1) - tmp = tokenizer(valdata[i]['text'], return_tensors='pt') - if tmp.input_ids.shape[1] >= seqlen: - break - i = random.randint(0, tmp.input_ids.shape[1] - seqlen - 1) - j = i + seqlen - valenc.append(tmp.input_ids[:, i:j]) - valenc = torch.hstack(valenc) - class TokenizerWrapper: - def __init__(self, input_ids): - self.input_ids = input_ids - valenc = TokenizerWrapper(valenc) - 
- return trainloader, valenc + raise NotImplementedError('C4 not implemented yet') + # from datasets import load_dataset + # traindata = load_dataset( + # 'allenai/c4', 'allenai--c4', data_files={'train': 'en/c4-train.00000-of-01024.json.gz'}, split='train' + # ) + # valdata = load_dataset( + # 'allenai/c4', 'allenai--c4', data_files={'validation': 'en/c4-validation.00000-of-00008.json.gz'}, split='validation' + # ) + + # from transformers import AutoTokenizer + # tokenizer = AutoTokenizer.from_pretrained(model, use_fast=False) + + # import random + # random.seed(seed) + # trainloader = [] + # for _ in range(nsamples): + # while True: + # i = random.randint(0, len(traindata) - 1) + # trainenc = tokenizer(traindata[i]['text'], return_tensors='pt') + # if trainenc.input_ids.shape[1] >= seqlen: + # break + # i = random.randint(0, trainenc.input_ids.shape[1] - seqlen - 1) + # j = i + seqlen + # inp = trainenc.input_ids[:, i:j] + # tar = inp.clone() + # tar[:, :-1] = -100 + # trainloader.append((inp, tar)) + + # import random + # random.seed(0) + # valenc = [] + # for _ in range(256): + # while True: + # i = random.randint(0, len(valdata) - 1) + # tmp = tokenizer(valdata[i]['text'], return_tensors='pt') + # if tmp.input_ids.shape[1] >= seqlen: + # break + # i = random.randint(0, tmp.input_ids.shape[1] - seqlen - 1) + # j = i + seqlen + # valenc.append(tmp.input_ids[:, i:j]) + # valenc = torch.hstack(valenc) + # class TokenizerWrapper: + # def __init__(self, input_ids): + # self.input_ids = input_ids + # valenc = TokenizerWrapper(valenc) + + # return trainloader, valenc def get_ptb_new(nsamples, seed, seqlen, model): - from datasets import load_dataset - traindata = load_dataset('ptb_text_only', 'penn_treebank', split='train') - testdata = load_dataset('ptb_text_only', 'penn_treebank', split='test') - - from transformers import AutoTokenizer - tokenizer = AutoTokenizer.from_pretrained(model, use_fast=False) - trainenc = tokenizer(" ".join(traindata['sentence']), return_tensors='pt') - testenc = tokenizer(" ".join(testdata['sentence']), return_tensors='pt') - - import random - random.seed(seed) - trainloader = [] - for _ in range(nsamples): - i = random.randint(0, trainenc.input_ids.shape[1] - seqlen - 1) - j = i + seqlen - inp = trainenc.input_ids[:, i:j] - tar = inp.clone() - tar[:, :-1] = -100 - trainloader.append((inp, tar)) - return trainloader, testenc + raise NotImplementedError('PTB not implemented yet') + # from datasets import load_dataset + # traindata = load_dataset('ptb_text_only', 'penn_treebank', split='train') + # testdata = load_dataset('ptb_text_only', 'penn_treebank', split='test') + + # from transformers import AutoTokenizer + # tokenizer = AutoTokenizer.from_pretrained(model, use_fast=False) + # trainenc = tokenizer(" ".join(traindata['sentence']), return_tensors='pt') + # testenc = tokenizer(" ".join(testdata['sentence']), return_tensors='pt') + + # import random + # random.seed(seed) + # trainloader = [] + # for _ in range(nsamples): + # i = random.randint(0, trainenc.input_ids.shape[1] - seqlen - 1) + # j = i + seqlen + # inp = trainenc.input_ids[:, i:j] + # tar = inp.clone() + # tar[:, :-1] = -100 + # trainloader.append((inp, tar)) + # return trainloader, testenc def get_c4_new(nsamples, seed, seqlen, model): - from datasets import load_dataset - traindata = load_dataset( - 'allenai/c4', 'allenai--c4', data_files={'train': 'en/c4-train.00000-of-01024.json.gz'}, split='train' - ) - valdata = load_dataset( - 'allenai/c4', 'allenai--c4', data_files={'validation': 
'en/c4-validation.00000-of-00008.json.gz'}, split='validation' - ) - - from transformers import AutoTokenizer - tokenizer = AutoTokenizer.from_pretrained(model, use_fast=False) - - import random - random.seed(seed) - trainloader = [] - for _ in range(nsamples): - while True: - i = random.randint(0, len(traindata) - 1) - trainenc = tokenizer(traindata[i]['text'], return_tensors='pt') - if trainenc.input_ids.shape[1] >= seqlen: - break - i = random.randint(0, trainenc.input_ids.shape[1] - seqlen - 1) - j = i + seqlen - inp = trainenc.input_ids[:, i:j] - tar = inp.clone() - tar[:, :-1] = -100 - trainloader.append((inp, tar)) - - valenc = tokenizer(' '.join(valdata[:1100]['text']), return_tensors='pt') - valenc = valenc.input_ids[:, :(256 * seqlen)] - - class TokenizerWrapper: - def __init__(self, input_ids): - self.input_ids = input_ids - valenc = TokenizerWrapper(valenc) - - return trainloader, valenc + raise NotImplementedError('C4 not implemented yet') + # from datasets import load_dataset + # traindata = load_dataset( + # 'allenai/c4', 'allenai--c4', data_files={'train': 'en/c4-train.00000-of-01024.json.gz'}, split='train' + # ) + # valdata = load_dataset( + # 'allenai/c4', 'allenai--c4', data_files={'validation': 'en/c4-validation.00000-of-00008.json.gz'}, split='validation' + # ) + + # from transformers import AutoTokenizer + # tokenizer = AutoTokenizer.from_pretrained(model, use_fast=False) + + # import random + # random.seed(seed) + # trainloader = [] + # for _ in range(nsamples): + # while True: + # i = random.randint(0, len(traindata) - 1) + # trainenc = tokenizer(traindata[i]['text'], return_tensors='pt') + # if trainenc.input_ids.shape[1] >= seqlen: + # break + # i = random.randint(0, trainenc.input_ids.shape[1] - seqlen - 1) + # j = i + seqlen + # inp = trainenc.input_ids[:, i:j] + # tar = inp.clone() + # tar[:, :-1] = -100 + # trainloader.append((inp, tar)) + + # valenc = tokenizer(' '.join(valdata[:1100]['text']), return_tensors='pt') + # valenc = valenc.input_ids[:, :(256 * seqlen)] + + # class TokenizerWrapper: + # def __init__(self, input_ids): + # self.input_ids = input_ids + # valenc = TokenizerWrapper(valenc) + + # return trainloader, valenc def get_loaders( - name, nsamples=128, seed=0, seqlen=2048, model='' + dataset_name, nsamples, seed, seqlen, model ): - if 'wikitext2' in name: + if 'wikitext2' in dataset_name: return get_wikitext2(nsamples, seed, seqlen, model) - if 'ptb' in name: - if 'new' in name: - return get_ptb_new(nsamples, seed, seqlen, model) - return get_ptb(nsamples, seed, seqlen, model) - if 'c4' in name: - if 'new' in name: - return get_c4_new(nsamples, seed, seqlen, model) - return get_c4(nsamples, seed, seqlen, model) + if 'ptb' in dataset_name: + raise NotImplementedError('PTB is not supported yet') + # if 'new' in dataset_name: + # return get_ptb_new(nsamples, seed, seqlen, model) + # return get_ptb(nsamples, seed, seqlen, model) + if 'c4' in dataset_name: + raise NotImplementedError('C4 is not supported yet') + # if 'new' in dataset_name: + # return get_c4_new(nsamples, seed, seqlen, model) + # return get_c4(nsamples, seed, seqlen, model) diff --git a/quantize/gptq/gptq.py b/quantize/gptq/gptq.py index 2bae9786..ae857e58 100644 --- a/quantize/gptq/gptq.py +++ b/quantize/gptq/gptq.py @@ -5,7 +5,7 @@ import torch.nn as nn import transformers -from quant import * +from .quant import * DEBUG = False diff --git a/quantize/gptq/opt.py b/quantize/gptq/opt.py deleted file mode 100644 index 60112b24..00000000 --- a/quantize/gptq/opt.py +++ /dev/null @@ 
-1,471 +0,0 @@ -import time - -import torch -import torch.nn as nn - -from gptq import * -from modelutils import * -from quant import * - - -def get_opt(model): - import torch - def skip(*args, **kwargs): - pass - torch.nn.init.kaiming_uniform_ = skip - torch.nn.init.uniform_ = skip - torch.nn.init.normal_ = skip - from transformers import OPTForCausalLM - model = OPTForCausalLM.from_pretrained(model, torch_dtype='auto') - model.seqlen = model.config.max_position_embeddings - return model - -@torch.no_grad() -def opt_sequential(model, dataloader, dev): - print('Starting ...') - - use_cache = model.config.use_cache - model.config.use_cache = False - layers = model.model.decoder.layers - - model.model.decoder.embed_tokens = model.model.decoder.embed_tokens.to(dev) - model.model.decoder.embed_positions = model.model.decoder.embed_positions.to(dev) - if hasattr(model.model.decoder, 'project_out') and model.model.decoder.project_out: - model.model.decoder.project_out = model.model.decoder.project_out.to(dev) - if hasattr(model.model.decoder, 'project_in') and model.model.decoder.project_in: - model.model.decoder.project_in = model.model.decoder.project_in.to(dev) - layers[0] = layers[0].to(dev) - - dtype = next(iter(model.parameters())).dtype - inps = torch.zeros( - (args.nsamples, model.seqlen, model.config.hidden_size), dtype=dtype, device=dev - ) - cache = {'i': 0, 'attention_mask': None} - - class Catcher(nn.Module): - def __init__(self, module): - super().__init__() - self.module = module - def forward(self, inp, **kwargs): - inps[cache['i']] = inp - cache['i'] += 1 - cache['attention_mask'] = kwargs['attention_mask'] - raise ValueError - layers[0] = Catcher(layers[0]) - for batch in dataloader: - try: - model(batch[0].to(dev)) - except ValueError: - pass - layers[0] = layers[0].module - - layers[0] = layers[0].cpu() - model.model.decoder.embed_tokens = model.model.decoder.embed_tokens.cpu() - model.model.decoder.embed_positions = model.model.decoder.embed_positions.cpu() - if hasattr(model.model.decoder, 'project_out') and model.model.decoder.project_out: - model.model.decoder.project_out = model.model.decoder.project_out.cpu() - if hasattr(model.model.decoder, 'project_in') and model.model.decoder.project_in: - model.model.decoder.project_in = model.model.decoder.project_in.cpu() - torch.cuda.empty_cache() - - outs = torch.zeros_like(inps) - attention_mask = cache['attention_mask'] - - print('Ready.') - - quantizers = {} - for i in range(len(layers)): - layer = layers[i].to(dev) - - subset = find_layers(layer) - gptq = {} - for name in subset: - gptq[name] = GPTQ(subset[name]) - gptq[name].quantizer = Quantizer() - gptq[name].quantizer.configure( - args.wbits, perchannel=True, sym=args.sym, mse=False, trits=args.trits - ) - - def add_batch(name): - def tmp(_, inp, out): - gptq[name].add_batch(inp[0].data, out.data) - return tmp - handles = [] - for name in subset: - handles.append(subset[name].register_forward_hook(add_batch(name))) - for j in range(args.nsamples): - outs[j] = layer(inps[j].unsqueeze(0), attention_mask=attention_mask)[0] - for h in handles: - h.remove() - - for name in subset: - print(i, name) - print('Quantizing ...') - gptq[name].fasterquant(percdamp=args.percdamp, groupsize=args.groupsize, actorder=args.act_order) - quantizers['model.decoder.layers.%d.%s' % (i, name)] = gptq[name].quantizer - gptq[name].free() - for j in range(args.nsamples): - outs[j] = layer(inps[j].unsqueeze(0), attention_mask=attention_mask)[0] - - layers[i] = layer.cpu() - del layer - del gptq - 
torch.cuda.empty_cache() - - inps, outs = outs, inps - - model.config.use_cache = use_cache - - return quantizers - -@torch.no_grad() -def opt_eval(model, testenc, dev): - print('Evaluating ...') - - testenc = testenc.input_ids - nsamples = testenc.numel() // model.seqlen - - use_cache = model.config.use_cache - model.config.use_cache = False - layers = model.model.decoder.layers - - model.model.decoder.embed_tokens = model.model.decoder.embed_tokens.to(dev) - model.model.decoder.embed_positions = model.model.decoder.embed_positions.to(dev) - if hasattr(model.model.decoder, 'project_out') and model.model.decoder.project_out: - model.model.decoder.project_out = model.model.decoder.project_out.to(dev) - if hasattr(model.model.decoder, 'project_in') and model.model.decoder.project_in: - model.model.decoder.project_in = model.model.decoder.project_in.to(dev) - layers[0] = layers[0].to(dev) - - dtype = next(iter(model.parameters())).dtype - inps = torch.zeros( - (nsamples, model.seqlen, model.config.hidden_size), dtype=dtype, device=dev - ) - cache = {'i': 0, 'attention_mask': None} - - class Catcher(nn.Module): - def __init__(self, module): - super().__init__() - self.module = module - def forward(self, inp, **kwargs): - inps[cache['i']] = inp - cache['i'] += 1 - cache['attention_mask'] = kwargs['attention_mask'] - raise ValueError - layers[0] = Catcher(layers[0]) - for i in range(nsamples): - batch = testenc[:, (i * model.seqlen):((i + 1) * model.seqlen)].to(dev) - try: - model(batch) - except ValueError: - pass - layers[0] = layers[0].module - - layers[0] = layers[0].cpu() - model.model.decoder.embed_tokens = model.model.decoder.embed_tokens.cpu() - model.model.decoder.embed_positions = model.model.decoder.embed_positions.cpu() - if hasattr(model.model.decoder, 'project_out') and model.model.decoder.project_out: - model.model.decoder.project_out = model.model.decoder.project_out.cpu() - if hasattr(model.model.decoder, 'project_in') and model.model.decoder.project_in: - model.model.decoder.project_in = model.model.decoder.project_in.cpu() - torch.cuda.empty_cache() - - outs = torch.zeros_like(inps) - attention_mask = cache['attention_mask'] - - for i in range(len(layers)): - print(i) - layer = layers[i].to(dev) - - if args.nearest: - subset = find_layers(layer) - for name in subset: - quantizer = Quantizer() - quantizer.configure( - args.wbits, perchannel=True, sym=args.sym, mse=False - ) - W = subset[name].weight.data - quantizer.find_params(W, weight=True) - subset[name].weight.data = quantize( - W, quantizer.scale, quantizer.zero, quantizer.maxq - ).to(next(iter(layer.parameters())).dtype) - - for j in range(nsamples): - outs[j] = layer(inps[j].unsqueeze(0), attention_mask=attention_mask)[0] - layers[i] = layer.cpu() - del layer - torch.cuda.empty_cache() - inps, outs = outs, inps - - if model.model.decoder.final_layer_norm is not None: - model.model.decoder.final_layer_norm = model.model.decoder.final_layer_norm.to(dev) - if model.model.decoder.project_out is not None: - model.model.decoder.project_out = model.model.decoder.project_out.to(dev) - model.lm_head = model.lm_head.to(dev) - - testenc = testenc.to(dev) - nlls = [] - for i in range(nsamples): - hidden_states = inps[i].unsqueeze(0) - if model.model.decoder.final_layer_norm is not None: - hidden_states = model.model.decoder.final_layer_norm(hidden_states) - if model.model.decoder.project_out is not None: - hidden_states = model.model.decoder.project_out(hidden_states) - lm_logits = model.lm_head(hidden_states) - shift_logits = 
lm_logits[:, :-1, :].contiguous() - shift_labels = testenc[ - :, (i * model.seqlen):((i + 1) * model.seqlen) - ][:, 1:] - loss_fct = nn.CrossEntropyLoss() - loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)) - neg_log_likelihood = loss.float() * model.seqlen - nlls.append(neg_log_likelihood) - ppl = torch.exp(torch.stack(nlls).sum() / (nsamples * model.seqlen)) - print(ppl.item()) - - model.config.use_cache = use_cache - -# TODO: perform packing on GPU -def opt_pack3(model, quantizers): - layers = find_layers(model) - layers = {n: layers[n] for n in quantizers} - make_quant3(model, quantizers, faster=args.faster_kernel) - qlayers = find_layers(model, [Quant3Linear]) - print('Packing ...') - for name in qlayers: - print(name) - quantizers[name] = quantizers[name].cpu() - qlayers[name].pack(layers[name], quantizers[name].scale, quantizers[name].zero) - print('Done.') - return model - -def load_quant3(model, checkpoint): - from transformers import OPTConfig, OPTForCausalLM - config = OPTConfig.from_pretrained(model) - def noop(*args, **kwargs): - pass - torch.nn.init.kaiming_uniform_ = noop - torch.nn.init.uniform_ = noop - torch.nn.init.normal_ = noop - - torch.set_default_dtype(torch.half) - transformers.modeling_utils._init_weights = False - torch.set_default_dtype(torch.half) - model = OPTForCausalLM(config) - torch.set_default_dtype(torch.float) - model = model.eval() - layers = find_layers(model) - for name in ['model.decoder.project_out', 'model.decoder.project_in', 'lm_head']: - if name in layers: - del layers[name] - make_quant3(model, layers, faster=args.faster_kernel) - - print('Loading model ...') - model.load_state_dict(torch.load(checkpoint)) - model.seqlen = model.config.max_position_embeddings - print('Done.') - - return model - -def opt_multigpu(model, gpus): - model.model.decoder.embed_tokens = model.model.decoder.embed_tokens.to(gpus[0]) - model.model.decoder.embed_positions = model.model.decoder.embed_positions.to(gpus[0]) - if hasattr(model.model.decoder, 'project_in') and model.model.decoder.project_in: - model.model.decoder.project_in = model.model.decoder.project_in.to(gpus[0]) - if hasattr(model.model.decoder, 'project_out') and model.model.decoder.project_out: - model.model.decoder.project_out = model.model.decoder.project_out.to(gpus[-1]) - if hasattr(model.model.decoder, 'final_layer_norm') and model.model.decoder.final_layer_norm: - model.model.decoder.final_layer_norm = model.model.decoder.final_layer_norm.to(gpus[-1]) - import copy - model.lm_head = copy.deepcopy(model.lm_head).to(gpus[-1]) - - cache = {'mask': None} - - class MoveModule(nn.Module): - def __init__(self, module): - super().__init__() - self.module = module - self.dev = next(iter(self.module.parameters())).device - def forward(self, *inp, **kwargs): - inp = list(inp) - if inp[0].device != self.dev: - inp[0] = inp[0].to(self.dev) - if cache['mask'] is None or cache['mask'].device != self.dev: - cache['mask'] = kwargs['attention_mask'].to(self.dev) - kwargs['attention_mask'] = cache['mask'] - tmp = self.module(*inp, **kwargs) - return tmp - - layers = model.model.decoder.layers - pergpu = math.ceil(len(layers) / len(gpus)) - for i in range(len(layers)): - layers[i] = MoveModule(layers[i].to(gpus[i // pergpu])) - - model.gpus = gpus - -def benchmark(model, input_ids, check=False): - input_ids = input_ids.to(model.gpus[0] if hasattr(model, 'gpus') else DEV) - torch.cuda.synchronize() - - cache = {'past': None} - def clear_past(i): - def tmp(layer, inp, out): - if 
cache['past']: - cache['past'][i] = None - return tmp - for i, layer in enumerate(model.model.decoder.layers): - layer.register_forward_hook(clear_past(i)) - - print('Benchmarking ...') - - if check: - loss = nn.CrossEntropyLoss() - tot = 0. - - def sync(): - if hasattr(model, 'gpus'): - for gpu in model.gpus: - torch.cuda.synchronize(gpu) - else: - torch.cuda.synchronize() - with torch.no_grad(): - attention_mask = torch.ones((1, input_ids.numel()), device=DEV) - times = [] - for i in range(input_ids.numel()): - tick = time.time() - out = model( - input_ids[:, i].reshape(-1), - past_key_values=cache['past'], - attention_mask=attention_mask[:, :(i + 1)].reshape((1, -1)) - ) - sync() - times.append(time.time() - tick) - print(i, times[-1]) - if check and i != input_ids.numel() - 1: - tot += loss(out.logits[0].to(DEV), input_ids[:, (i + 1)].to(DEV)).float() - cache['past'] = list(out.past_key_values) - del out - sync() - import numpy as np - print('Median:', np.median(times)) - if check: - print('PPL:', torch.exp(tot / (input_ids.numel() - 1)).item()) - - -if __name__ == '__main__': - import argparse - from datautils import * - - parser = argparse.ArgumentParser() - - parser.add_argument( - 'model', type=str, - help='OPT model to load; pass `facebook/opt-X`.' - ) - parser.add_argument( - 'dataset', type=str, choices=['wikitext2', 'ptb', 'c4'], - help='Where to extract calibration data from.' - ) - parser.add_argument( - '--seed', - type=int, default=0, help='Seed for sampling the calibration data.' - ) - parser.add_argument( - '--nsamples', type=int, default=128, - help='Number of calibration data samples.' - ) - parser.add_argument( - '--percdamp', type=float, default=.01, - help='Percent of the average Hessian diagonal to use for dampening.' - ) - parser.add_argument( - '--nearest', action='store_true', - help='Whether to run the RTN baseline.' - ) - parser.add_argument( - '--wbits', type=int, default=16, choices=[2, 3, 4, 16], - help='#bits to use for quantization; use 16 for evaluating base model.' - ) - parser.add_argument( - '--trits', action='store_true', - help='Whether to use trits for quantization.' - ) - parser.add_argument( - '--groupsize', type=int, default=-1, - help='Groupsize to use for quantization; default uses full row.' - ) - parser.add_argument( - '--sym', action='store_true', - help='Whether to perform symmetric quantization.' - ) - parser.add_argument( - '--save', type=str, default='', - help='Save quantized checkpoint under this name.' - ) - parser.add_argument( - '--load', type=str, default='', - help='Load quantized model.' - ) - parser.add_argument( - '--benchmark', type=int, default=0, - help='Number of tokens to use for benchmarking.' - ) - parser.add_argument( - '--check', action='store_true', - help='Whether to compute perplexity during benchmarking for verification.' - ) - parser.add_argument( - '--new-eval', action='store_true', - help='Whether to use the new PTB and C4 eval.' - ) - parser.add_argument( - '--faster-kernel', action='store_true', - help='Whether to use the new faster kernel for benchmarking.' 
- ) - parser.add_argument( - '--act-order', action='store_true', - help='Whether to apply the activation order GPTQ heuristic' - ) - - args = parser.parse_args() - - if args.load: - model = load_quant3(args.model, args.load) - else: - model = get_opt(args.model) - model.eval() - - dataloader, testloader = get_loaders( - args.dataset, nsamples=args.nsamples, seed=args.seed, model=args.model, seqlen=model.seqlen - ) - - if args.wbits < 16 and not args.nearest: - tick = time.time() - quantizers = opt_sequential(model, dataloader, DEV) - print(time.time() - tick) - - if args.benchmark: - gpus = [torch.device('cuda:%d' % i) for i in range(torch.cuda.device_count())] - if len(gpus) > 1: - opt_multigpu(model, gpus) - else: - model = model.to(DEV) - if args.benchmark: - input_ids = next(iter(dataloader))[0][:, :args.benchmark] - benchmark(model, input_ids, check=args.check) - if args.load: - exit() - - datasets = ['wikitext2', 'ptb', 'c4'] - if args.new_eval: - datasets = ['wikitext2', 'ptb-new', 'c4-new'] - for dataset in datasets: - dataloader, testloader = get_loaders( - dataset, seed=args.seed, model=args.model, seqlen=model.seqlen - ) - print(dataset) - opt_eval(model, testloader, DEV) - - if args.save: - opt_pack3(model, quantizers) - torch.save(model.state_dict(), args.save) diff --git a/quantize/opt.py b/quantize/opt.py new file mode 100644 index 00000000..16a776f5 --- /dev/null +++ b/quantize/opt.py @@ -0,0 +1,216 @@ +import time + +import torch +import torch.nn as nn + +from gptq.gptq import * +from gptq.modelutils import * +from gptq.quant import * + + +def get_opt(model): + import torch + def skip(*args, **kwargs): + pass + torch.nn.init.kaiming_uniform_ = skip + torch.nn.init.uniform_ = skip + torch.nn.init.normal_ = skip + from transformers import OPTForCausalLM + model = OPTForCausalLM.from_pretrained(model, torch_dtype='auto') + model.seqlen = model.config.max_position_embeddings + return model + +@torch.no_grad() +def opt_sequential(model, dataloader, dev): + print('Starting ...') + + use_cache = model.config.use_cache + model.config.use_cache = False + layers = model.model.decoder.layers + + model.model.decoder.embed_tokens = model.model.decoder.embed_tokens.to(dev) + model.model.decoder.embed_positions = model.model.decoder.embed_positions.to(dev) + if hasattr(model.model.decoder, 'project_out') and model.model.decoder.project_out: + model.model.decoder.project_out = model.model.decoder.project_out.to(dev) + if hasattr(model.model.decoder, 'project_in') and model.model.decoder.project_in: + model.model.decoder.project_in = model.model.decoder.project_in.to(dev) + layers[0] = layers[0].to(dev) + + dtype = next(iter(model.parameters())).dtype + inps = torch.zeros( + (args.nsamples, model.seqlen, model.config.hidden_size), dtype=dtype, device=dev + ) + cache = {'i': 0, 'attention_mask': None} + + class Catcher(nn.Module): + def __init__(self, module): + super().__init__() + self.module = module + def forward(self, inp, **kwargs): + inps[cache['i']] = inp + cache['i'] += 1 + cache['attention_mask'] = kwargs['attention_mask'] + raise ValueError + layers[0] = Catcher(layers[0]) + for batch in dataloader: + try: + model(batch[0].to(dev)) + except ValueError: + pass + layers[0] = layers[0].module + + layers[0] = layers[0].cpu() + model.model.decoder.embed_tokens = model.model.decoder.embed_tokens.cpu() + model.model.decoder.embed_positions = model.model.decoder.embed_positions.cpu() + if hasattr(model.model.decoder, 'project_out') and model.model.decoder.project_out: + 
model.model.decoder.project_out = model.model.decoder.project_out.cpu() + if hasattr(model.model.decoder, 'project_in') and model.model.decoder.project_in: + model.model.decoder.project_in = model.model.decoder.project_in.cpu() + torch.cuda.empty_cache() + + outs = torch.zeros_like(inps) + attention_mask = cache['attention_mask'] + + print('Ready.') + + quantizers = {} + for i in range(len(layers)): + layer = layers[i].to(dev) + + subset = find_layers(layer) + gptq = {} + for name in subset: + gptq[name] = GPTQ(subset[name]) + gptq[name].quantizer = Quantizer() + gptq[name].quantizer.configure( + args.wbits, perchannel=True, sym=args.sym, mse=False, trits=args.trits + ) + + def add_batch(name): + def tmp(_, inp, out): + gptq[name].add_batch(inp[0].data, out.data) + return tmp + handles = [] + for name in subset: + handles.append(subset[name].register_forward_hook(add_batch(name))) + for j in range(args.nsamples): + outs[j] = layer(inps[j].unsqueeze(0), attention_mask=attention_mask)[0] + for h in handles: + h.remove() + + for name in subset: + print(i, name) + print('Quantizing ...') + gptq[name].fasterquant(percdamp=args.percdamp, groupsize=args.groupsize, actorder=args.act_order) + quantizers['model.decoder.layers.%d.%s' % (i, name)] = gptq[name].quantizer + gptq[name].free() + for j in range(args.nsamples): + outs[j] = layer(inps[j].unsqueeze(0), attention_mask=attention_mask)[0] + + layers[i] = layer.cpu() + del layer + del gptq + torch.cuda.empty_cache() + + inps, outs = outs, inps + + model.config.use_cache = use_cache + + return quantizers + +if __name__ == '__main__': + import argparse + from gptq.datautils import * + + parser = argparse.ArgumentParser() + + parser.add_argument( + 'model', type=str, + help='OPT model to load; pass `facebook/opt-X`.' + ) + parser.add_argument( + 'dataset', type=str, choices=['wikitext2', 'ptb', 'c4'], + help='Where to extract calibration data from.' + ) + parser.add_argument( + '--seed', + type=int, default=0, help='Seed for sampling the calibration data.' + ) + parser.add_argument( + '--nsamples', type=int, default=128, + help='Number of calibration data samples.' + ) + parser.add_argument( + '--percdamp', type=float, default=.01, + help='Percent of the average Hessian diagonal to use for dampening.' + ) + parser.add_argument( + '--nearest', action='store_true', + help='Whether to run the RTN baseline.' + ) + parser.add_argument( + '--wbits', type=int, default=16, choices=[2, 3, 4, 16], + help='#bits to use for quantization; use 16 for evaluating base model.' + ) + parser.add_argument( + '--trits', action='store_true', + help='Whether to use trits for quantization.' + ) + parser.add_argument( + '--groupsize', type=int, default=-1, + help='Groupsize to use for quantization; default uses full row.' + ) + parser.add_argument( + '--sym', action='store_true', + help='Whether to perform symmetric quantization.' + ) + parser.add_argument( + '--save', type=str, default='', + help='Save quantized checkpoint under this name.' + ) + parser.add_argument( + '--load', type=str, default='', + help='Load quantized model.' + ) + parser.add_argument( + '--benchmark', type=int, default=0, + help='Number of tokens to use for benchmarking.' + ) + parser.add_argument( + '--check', action='store_true', + help='Whether to compute perplexity during benchmarking for verification.' + ) + parser.add_argument( + '--new-eval', action='store_true', + help='Whether to use the new PTB and C4 eval.' 
+ ) + parser.add_argument( + '--faster-kernel', action='store_true', + help='Whether to use the new faster kernel for benchmarking.' + ) + parser.add_argument( + '--act-order', action='store_true', + help='Whether to apply the activation order GPTQ heuristic' + ) + + args = parser.parse_args() + + model = get_opt(args.model) + model.eval() + + dataloader, testloader = get_loaders( + args.dataset, nsamples=args.nsamples, seed=args.seed, model=args.model, seqlen=model.seqlen + ) + + print(f'{len(dataloader)} train tokens in the text') + print(f'{len(testloader)} test tokens in the text') + + + if args.wbits < 16 and not args.nearest: + tick = time.time() + quantizers = opt_sequential(model, dataloader, DEV) + print(time.time() - tick) + + # if args.save: + # opt_pack3(model, quantizers) + # torch.save(model.state_dict(), args.save) From 629fc9b34de9f84d598c9d754f1588147dab72dd Mon Sep 17 00:00:00 2001 From: Ferdinand Mom Date: Sun, 23 Apr 2023 12:50:23 +0000 Subject: [PATCH 04/20] breaking(quantize): draft gptq rwkv --- quantize/tmp_rwkv.py | 322 +++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 322 insertions(+) create mode 100644 quantize/tmp_rwkv.py diff --git a/quantize/tmp_rwkv.py b/quantize/tmp_rwkv.py new file mode 100644 index 00000000..c87c0efc --- /dev/null +++ b/quantize/tmp_rwkv.py @@ -0,0 +1,322 @@ + +from rwkv.model import RWKV +from gptq.gptq import * +from gptq.datautils import * +import os +import torch.nn.functional as F +import gc +import re + +if os.environ.get('RWKV_JIT_ON') != '0': + os.environ["RWKV_JIT_ON"] = '1' + MyModule = torch.jit.ScriptModule + MyFunction = torch.jit.script_method + MyStatic = torch.jit.script +else: + MyModule = torch.nn.Module + def __nop(ob): + return ob + MyFunction = __nop + MyStatic = __nop + +class GPTQ_RWKV(RWKV): + + ### begin GPTQ + class GPTQ: + def __init__(): + pass + + def add_batch(self): + pass + + def fasterquant(self): + pass + + def __init__(self, model, strategy): + super().__init__(model, strategy) + + self.subset = {} + self.gptq = {} + ### end GPTQ + + def _filter_layer_within_block(self, layer_id, model): + + def _create_layer(model, name): + if len(model.w[name].shape) == 1: + #TODO: maybe reshape (-1, 1) ? 
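+                # Sketch of the intent here: 1D parameters (e.g. LayerNorm
+                # weights) are lifted to a 1 x N matrix so they can be wrapped
+                # in an nn.Linear like the 2D projection weights; whether
+                # (1, -1) or (-1, 1) is the right orientation for GPTQ is
+                # still open, hence the TODO above.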
+ w = model.w[name].reshape(1, -1) + layer = nn.Linear(*w.shape, bias=False) + layer.weight = nn.Parameter(w) + else: + layer = nn.Linear(*model.w[name].shape, bias=False) + layer.weight = nn.Parameter(model.w[name]) + return layer + + res = {} + dd = model.strategy[layer_id] + dev = dd.device + + for name in model.w.keys(): + if re.match(f'^blocks\.{layer_id}\..*\.weight$', name): + layer = _create_layer(model, name) + print(f"{name} = {model.w[name].shape}") + + if re.match(f'^blocks\.{layer_id}\.(?:att|ffn)\.(?:key|value|output|receptance)\.weight$', name): + layer = layer.to(device=dev, non_blocking=True) + + res[name] = layer + + return res + + def alloc_gptq(self, layer_id, subset): + + self.subset = self.__filter_layer_within_block(layer_id, model) + + for name in subset: + self.gptq[name] = GPTQ(subset[name]) + self.gptq[name].quantizer = Quantizer() + self.gptq[name].quantizer.configure(bits=4, perchannel=True, sym=False, mse=False, trits=False) + + def free_gptq(self): + del self.subset + del self.gptq + gc.collect() + + def fasterquant(self, layer_id, quantizers): + for name in self.subset: + print(f"Quantizing {name} of layer {layer_id}") + #TODO: add argparse to fasterquand + self.gptq[name].fastquant(percdamp=0.01, groupsize=-1, actorder=False) + # self.gptq[name].fastquant(percdamp=args.percdamp, groupsize=args.groupsize, actorder=args.act_order) + quantizers[name] = self.gptq[name].quantizer + + @MyFunction + def att_one(self, x, sx, aa, bb, pp, ln_w, ln_b, k_mix, v_mix, r_mix, t_decay, t_first, kw, vw, rw, ow, kmx, krx, kmy, kry, vmx, vrx, vmy, vry, rmx, rrx, rmy, rry, omx, orx, omy, ory): + + xx = F.layer_norm(x, (x.shape[-1],), weight=ln_w, bias=ln_b) + kx = xx * k_mix + sx * (1 - k_mix) + vx = xx * v_mix + sx * (1 - v_mix) + rx = xx * r_mix + sx * (1 - r_mix) + + r = torch.sigmoid(rx @ rw) + k = (kx @ kw).float() + # k = (kx @ kw.weight).float() + # kw.add_batch(kx) + v = (vx @ vw).float() + + ww = t_first + k + p = torch.maximum(pp, ww) + e1 = torch.exp(pp - p) + e2 = torch.exp(ww - p) + wkv = ((e1 * aa + e2 * v) / (e1 * bb + e2)).to(dtype=x.dtype) + ww = t_decay + pp + p = torch.maximum(ww, k) + e1 = torch.exp(ww - p) + e2 = torch.exp(k - p) + + out = (r * wkv) @ ow + return x + out, xx, e1 * aa + e2 * v, e1 * bb + e2, p + + @MyFunction + def att_seq(self, x, sx, aa, bb, pp, ln_w, ln_b, k_mix, v_mix, r_mix, t_decay, t_first, kw, vw, rw, ow, kmx, krx, kmy, kry, vmx, vrx, vmy, vry, rmx, rrx, rmy, rry, omx, orx, omy, ory): + xx = F.layer_norm(x, (x.shape[-1],), weight=ln_w, bias=ln_b) + sx = torch.cat((sx.unsqueeze(0), xx[:-1,:])) + kx = xx * k_mix + sx * (1 - k_mix) + vx = xx * v_mix + sx * (1 - v_mix) + rx = xx * r_mix + sx * (1 - r_mix) + + r = torch.sigmoid(rx @ rw) + k = (kx @ kw).float() + v = (vx @ vw).float() + + T = x.shape[0] + for t in range(T): + kk = k[t] + vv = v[t] + ww = t_first + kk + p = torch.maximum(pp, ww) + e1 = torch.exp(pp - p) + e2 = torch.exp(ww - p) + sx[t] = ((e1 * aa + e2 * vv) / (e1 * bb + e2)).to(dtype=x.dtype) + ww = t_decay + pp + p = torch.maximum(ww, kk) + e1 = torch.exp(ww - p) + e2 = torch.exp(kk - p) + aa = e1 * aa + e2 * vv + bb = e1 * bb + e2 + pp = p + out = (r * sx) @ ow + return x + out, xx[-1,:], aa, bb, pp + + @MyFunction + def ffn_seq(self, x, sx, ln_w, ln_b, k_mix, r_mix, kw, vw, rw, kmx, krx, kmy, kry, vmx, vrx, vmy, vry, rmx, rrx, rmy, rry): + xx = F.layer_norm(x, (x.shape[-1],), weight=ln_w, bias=ln_b) + sx = torch.cat((sx.unsqueeze(0), xx[:-1,:])) + kx = xx * k_mix + sx * (1 - k_mix) + rx = xx * r_mix + sx * (1 - r_mix) 
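+        # kx and rx are the RWKV "token shift": per-channel blends of the
+        # current layer-normed activation xx and the previous token's (sx).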
+ + r = torch.sigmoid(rx @ rw) + vx = torch.square(torch.relu(kx @ kw)) + out = r * (vx @ vw) + return x + out, xx[-1,:] + + def forward_block(self, x, state, i, seq_mode, full_output=False): + with torch.no_grad(): + args = self.args + + if state == None: + state = [None] * args.n_layer * 5 + dd = self.strategy[i] + dev = dd.device + atype = dd.atype + state[i*5+0] = torch.zeros(args.n_embd, dtype=atype, requires_grad=False, device=dev).contiguous() + state[i*5+1] = torch.zeros(args.n_embd, dtype=torch.float, requires_grad=False, device=dev).contiguous() + state[i*5+2] = torch.zeros(args.n_embd, dtype=torch.float, requires_grad=False, device=dev).contiguous() + state[i*5+3] = torch.zeros(args.n_embd, dtype=torch.float, requires_grad=False, device=dev).contiguous() - 1e30 + state[i*5+4] = torch.zeros(args.n_embd, dtype=atype, requires_grad=False, device=dev).contiguous() + + bbb = f'blocks.{i}.' + att = f'blocks.{i}.att.' + ffn = f'blocks.{i}.ffn.' + dd = self.strategy[i] + dev = dd.device + atype = dd.atype + wtype = dd.wtype + + if seq_mode: + if 'cuda' in str(dev) and os.environ["RWKV_CUDA_ON"] == '1': + ATT = self.cuda_att_seq if wtype != torch.uint8 else self.cuda_att_seq_i8 + else: + ATT = self.att_seq if wtype != torch.uint8 else self.att_seq_i8 + FFN = self.ffn_seq if wtype != torch.uint8 else self.ffn_seq_i8 + else: + ATT = self.att_one if wtype != torch.uint8 else self.att_one_i8 + FFN = self.ffn_one if wtype != torch.uint8 else self.ffn_one_i8 + + x = x.to(dtype=atype, device=dev) + + kw = self.gptq[f'{att}key.weight'] + vw = self.gptq[f'{att}value.weight'] + rw = self.gptq[f'{att}receptance.weight'] + ow = self.gptq[f'{att}output.weight'] + + kmx = self.w[f'{att}key.weight_mx'] if wtype == torch.uint8 else x + krx = self.w[f'{att}key.weight_rx'] if wtype == torch.uint8 else x + kmy = self.w[f'{att}key.weight_my'] if wtype == torch.uint8 else x + kry = self.w[f'{att}key.weight_ry'] if wtype == torch.uint8 else x + vmx = self.w[f'{att}value.weight_mx'] if wtype == torch.uint8 else x + vrx = self.w[f'{att}value.weight_rx'] if wtype == torch.uint8 else x + vmy = self.w[f'{att}value.weight_my'] if wtype == torch.uint8 else x + vry = self.w[f'{att}value.weight_ry'] if wtype == torch.uint8 else x + rmx = self.w[f'{att}receptance.weight_mx'] if wtype == torch.uint8 else x + rrx = self.w[f'{att}receptance.weight_rx'] if wtype == torch.uint8 else x + rmy = self.w[f'{att}receptance.weight_my'] if wtype == torch.uint8 else x + rry = self.w[f'{att}receptance.weight_ry'] if wtype == torch.uint8 else x + omx = self.w[f'{att}output.weight_mx'] if wtype == torch.uint8 else x + orx = self.w[f'{att}output.weight_rx'] if wtype == torch.uint8 else x + omy = self.w[f'{att}output.weight_my'] if wtype == torch.uint8 else x + ory = self.w[f'{att}output.weight_ry'] if wtype == torch.uint8 else x + + x, state[i*5+0], state[i*5+1], state[i*5+2], state[i*5+3] = ATT( + x=x, sx=state[i*5+0], aa=state[i*5+1], bb=state[i*5+2], pp=state[i*5+3], + ln_w=self.gptq[f'{bbb}ln1.weight'], ln_b=self.w[f'{bbb}ln1.bias'], + k_mix=self.w[f'{att}time_mix_k'], v_mix=self.w[f'{att}time_mix_v'], r_mix=self.w[f'{att}time_mix_r'], + t_decay=self.w[f'{att}time_decay'], t_first=self.w[f'{att}time_first'], + kw=kw, vw=vw, rw=rw, pw=ow, + kmx=kmx, krx=krx, kmy=kmy, kry=kry, + vmx=vmx, vrx=vrx, vmy=vmy, vry=vry, + rmx=rmx, rrx=rrx, rmy=rmy, rry=rry, + omx=omx, orx=orx, omy=omy, ory=ory, + ) + + if dd.stream: + del kw, vw, rw, ow + + kw = self.gptq[f'{ffn}key.weight'] + vw = self.gptq[f'{ffn}value.weight'] + rw = 
self.gptq[f'{ffn}receptance.weight'] + if dd.stream: + kw = kw.to(device=dev, non_blocking=True) + vw = vw.to(device=dev, non_blocking=True) + rw = rw.to(device=dev, non_blocking=True) + + kmx = self.w[f'{ffn}key.weight_mx'] if wtype == torch.uint8 else x + krx = self.w[f'{ffn}key.weight_rx'] if wtype == torch.uint8 else x + kmy = self.w[f'{ffn}key.weight_my'] if wtype == torch.uint8 else x + kry = self.w[f'{ffn}key.weight_ry'] if wtype == torch.uint8 else x + vmx = self.w[f'{ffn}value.weight_mx'] if wtype == torch.uint8 else x + vrx = self.w[f'{ffn}value.weight_rx'] if wtype == torch.uint8 else x + vmy = self.w[f'{ffn}value.weight_my'] if wtype == torch.uint8 else x + vry = self.w[f'{ffn}value.weight_ry'] if wtype == torch.uint8 else x + rmx = self.w[f'{ffn}receptance.weight_mx'] if wtype == torch.uint8 else x + rrx = self.w[f'{ffn}receptance.weight_rx'] if wtype == torch.uint8 else x + rmy = self.w[f'{ffn}receptance.weight_my'] if wtype == torch.uint8 else x + rry = self.w[f'{ffn}receptance.weight_ry'] if wtype == torch.uint8 else x + x, state[i*5+4] = FFN( + x=x, sx=state[i*5+4], + ln_w=self.gptq[f'{bbb}ln2.weight'], ln_b=self.w[f'{bbb}ln2.bias'], + k_mix=self.w[f'{ffn}time_mix_k'], r_mix=self.w[f'{ffn}time_mix_r'], + kw=kw, vw=vw, rw=rw, + kmx=kmx, krx=krx, kmy=kmy, kry=kry, + vmx=vmx, vrx=vrx, vmy=vmy, vry=vry, + rmx=rmx, rrx=rrx, rmy=rmy, rry=rry, + ) + + if dd.stream: + del kw, vw, rw + + if self.RESCALE_LAYER > 0: + if (i+1) % self.RESCALE_LAYER == 0: + x = x / 2 + + dd = self.strategy[args.n_layer] + x = x[-1,:] if (seq_mode and (not full_output)) else x + x = x.to(dtype=dd.atype, device=dd.device) + + #TODO: Add GPTQ support for head & ln_out + x = F.layer_norm(x, (args.n_embd,), weight=self.w['ln_out.weight'], bias=self.w['ln_out.bias']) + if self.w['head.weight'].dtype != torch.uint8: + x = x @ self.w['head.weight'] + else: + if seq_mode and full_output: + x = self.mm8_seq(x, self.w['head.weight'], self.w['head.weight_mx'], self.w['head.weight_rx'], self.w['head.weight_my'], self.w['head.weight_ry']) + else: + x = self.mm8_one(x, self.w['head.weight'], self.w['head.weight_mx'], self.w['head.weight_rx'], self.w['head.weight_my'], self.w['head.weight_ry']) + + return x.float(), state + + +NSAMPLES=1 +HIDDEN_SIZE=768 +SEQLEN=HIDDEN_SIZE # TODO: this is chosen by the model + +train_tokens, test_tokens = get_loaders( + dataset_name="wikitext2", + nsamples=NSAMPLES, + seed=42, + seqlen=SEQLEN, + model=None +) + +tokens = [inp.squeeze() for inp, _ in train_tokens] + +model = GPTQ_RWKV("./RWKV-4-Pile-169M-20220807-8023.pth", strategy='cpu fp32') + +with torch.no_grad(): + seq_mode = len(tokens) > 1 + x = model.w['emb.weight'][tokens if seq_mode else tokens[0]] + + quantizers = {} + + for layer_id in range(model.args.n_layer): + + model.alloc_gptq(layer_id, model) + + for j in range(NSAMPLES): + _ = model.forward_block(x[j].unsqueeze(0), state=None, i=layer_id, seq_mode=seq_mode, full_output=full_output) + + model.fasterquant(layer_id, quantizers) + + model.free_gptq() + +# TODO: create a function that check if all weights were properly quantized From 4a194766bdd01bbb6e6f313bef24cf019c9bcec1 Mon Sep 17 00:00:00 2001 From: Ferdinand Mom Date: Mon, 24 Apr 2023 19:08:28 +0000 Subject: [PATCH 05/20] fix(quantize): GPTQ hooks now work with RWKV --- quantize/tmp_rwkv.py | 221 ++++++++++++++++++++++++++----------------- 1 file changed, 135 insertions(+), 86 deletions(-) diff --git a/quantize/tmp_rwkv.py b/quantize/tmp_rwkv.py index c87c0efc..dc1695f5 100644 --- a/quantize/tmp_rwkv.py +++ 
b/quantize/tmp_rwkv.py @@ -1,108 +1,115 @@ from rwkv.model import RWKV -from gptq.gptq import * from gptq.datautils import * +from gptq.quant import Quantizer + import os import torch.nn.functional as F +import torch.nn as nn import gc +import math import re -if os.environ.get('RWKV_JIT_ON') != '0': - os.environ["RWKV_JIT_ON"] = '1' - MyModule = torch.jit.ScriptModule - MyFunction = torch.jit.script_method - MyStatic = torch.jit.script -else: - MyModule = torch.nn.Module - def __nop(ob): - return ob - MyFunction = __nop - MyStatic = __nop - class GPTQ_RWKV(RWKV): ### begin GPTQ class GPTQ: - def __init__(): - pass + def __init__(self, weight, name): + #TODO: Remove name, only used for debugging + self.name = name + self.weight = weight.clone() + self.dev = weight.device + # In GPTQ, they use nn.Linear(x) which performs x @ w.T but in RWKV, we perform x @ w instead + # Problem is self.H is a square matrix which depends on self.columns = W.shape[1] in the original code + # But if we keep it that way, this will break self.H += inp.matmul(inp.t()) because inp.shape[1] != W.shape[1] + # Thus, we have to use self.W.shape[0] instead + self.columns = self.weight.shape[0] + self.H = torch.zeros((self.columns, self.columns), device=self.dev) + self.nsamples = 0 + + def add_batch(self, inp): + if len(inp.shape) == 2: + inp = inp.unsqueeze(0) + + tmp = 1 if len(inp.shape) == 1 else inp.shape[0] - def add_batch(self): - pass + # Assume weight come from nn.Linear + if len(inp.shape) == 3: + inp = inp.reshape((-1, inp.shape[-1])) + inp = inp.t() + + self.H *= self.nsamples / (self.nsamples + tmp) + self.nsamples += tmp + inp = math.sqrt(2 / self.nsamples) * inp.float() + self.H += inp.matmul(inp.t()) def fasterquant(self): pass + ### end GPTQ + ### begin GPTQ_RWKV def __init__(self, model, strategy): super().__init__(model, strategy) - + #TODO: add assert to only quantize in CPU FP32 mode self.subset = {} self.gptq = {} - ### end GPTQ - def _filter_layer_within_block(self, layer_id, model): - - def _create_layer(model, name): - if len(model.w[name].shape) == 1: - #TODO: maybe reshape (-1, 1) ? 
- w = model.w[name].reshape(1, -1) - layer = nn.Linear(*w.shape, bias=False) - layer.weight = nn.Parameter(w) - else: - layer = nn.Linear(*model.w[name].shape, bias=False) - layer.weight = nn.Parameter(model.w[name]) - return layer - - res = {} - dd = model.strategy[layer_id] + def _fill_subset(self, layer_id): + # Keep only layer within block layer_id + dd = self.strategy[layer_id] dev = dd.device - for name in model.w.keys(): + for name in self.w.keys(): if re.match(f'^blocks\.{layer_id}\..*\.weight$', name): - layer = _create_layer(model, name) - print(f"{name} = {model.w[name].shape}") + tensor = self.w[name] + print(f"{name} = {self.w[name].shape}") if re.match(f'^blocks\.{layer_id}\.(?:att|ffn)\.(?:key|value|output|receptance)\.weight$', name): - layer = layer.to(device=dev, non_blocking=True) - - res[name] = layer + tensor = tensor.to(device=dev, non_blocking=True) - return res + self.subset[name] = tensor - def alloc_gptq(self, layer_id, subset): - - self.subset = self.__filter_layer_within_block(layer_id, model) + def alloc_gptq(self, layer_id): + + self._fill_subset(layer_id) - for name in subset: - self.gptq[name] = GPTQ(subset[name]) + for name in self.subset: + self.gptq[name] = self.GPTQ(self.subset[name], name) self.gptq[name].quantizer = Quantizer() + #TODO: add argparse to configure self.gptq[name].quantizer.configure(bits=4, perchannel=True, sym=False, mse=False, trits=False) def free_gptq(self): - del self.subset - del self.gptq + if len(self.subset) > 0: del self.subset + if len(self.gptq) > 0: del self.gptq gc.collect() def fasterquant(self, layer_id, quantizers): for name in self.subset: print(f"Quantizing {name} of layer {layer_id}") - #TODO: add argparse to fasterquand + #TODO: add argparse to fastquant self.gptq[name].fastquant(percdamp=0.01, groupsize=-1, actorder=False) # self.gptq[name].fastquant(percdamp=args.percdamp, groupsize=args.groupsize, actorder=args.act_order) quantizers[name] = self.gptq[name].quantizer + # TODO: may be free gptq here to save memory - @MyFunction + ### end GPTQ_RWKV + + ### begin RWKV def att_one(self, x, sx, aa, bb, pp, ln_w, ln_b, k_mix, v_mix, r_mix, t_decay, t_first, kw, vw, rw, ow, kmx, krx, kmy, kry, vmx, vrx, vmy, vry, rmx, rrx, rmy, rry, omx, orx, omy, ory): - xx = F.layer_norm(x, (x.shape[-1],), weight=ln_w, bias=ln_b) + xx = F.layer_norm(x, (x.shape[-1],), weight=ln_w.weight, bias=ln_b) + ln_w.add_batch(x) kx = xx * k_mix + sx * (1 - k_mix) vx = xx * v_mix + sx * (1 - v_mix) rx = xx * r_mix + sx * (1 - r_mix) - r = torch.sigmoid(rx @ rw) - k = (kx @ kw).float() - # k = (kx @ kw.weight).float() - # kw.add_batch(kx) - v = (vx @ vw).float() + r = torch.sigmoid(rx @ rw.weight) + rw.add_batch(rx) + k = (kx @ kw.weight).float() + kw.add_batch(kx) + v = (vx @ vw.weight).float() + vw.add_batch(vx) ww = t_first + k p = torch.maximum(pp, ww) @@ -114,20 +121,24 @@ def att_one(self, x, sx, aa, bb, pp, ln_w, ln_b, k_mix, v_mix, r_mix, t_decay, t e1 = torch.exp(ww - p) e2 = torch.exp(k - p) - out = (r * wkv) @ ow + out = (r * wkv) @ ow.weight + ow.add_batch((r * wkv)) return x + out, xx, e1 * aa + e2 * v, e1 * bb + e2, p - @MyFunction def att_seq(self, x, sx, aa, bb, pp, ln_w, ln_b, k_mix, v_mix, r_mix, t_decay, t_first, kw, vw, rw, ow, kmx, krx, kmy, kry, vmx, vrx, vmy, vry, rmx, rrx, rmy, rry, omx, orx, omy, ory): - xx = F.layer_norm(x, (x.shape[-1],), weight=ln_w, bias=ln_b) + xx = F.layer_norm(x, (x.shape[-1],), weight=ln_w.weight, bias=ln_b) + ln_w.add_batch(x) sx = torch.cat((sx.unsqueeze(0), xx[:-1,:])) kx = xx * k_mix + sx * (1 - 
k_mix) vx = xx * v_mix + sx * (1 - v_mix) rx = xx * r_mix + sx * (1 - r_mix) - r = torch.sigmoid(rx @ rw) - k = (kx @ kw).float() - v = (vx @ vw).float() + r = torch.sigmoid(rx @ rw.weight) + rw.add_batch(rx) + k = (kx @ kw.weight).float() + kw.add_batch(kx) + v = (vx @ vw.weight).float() + vw.add_batch(vx) T = x.shape[0] for t in range(T): @@ -145,19 +156,49 @@ def att_seq(self, x, sx, aa, bb, pp, ln_w, ln_b, k_mix, v_mix, r_mix, t_decay, t aa = e1 * aa + e2 * vv bb = e1 * bb + e2 pp = p - out = (r * sx) @ ow + out = (r * sx) @ ow.weight + ow.add_batch((r * sx)) return x + out, xx[-1,:], aa, bb, pp - @MyFunction + def ffn_one(self, x, sx, ln_w, ln_b, k_mix, r_mix, kw, vw, rw, kmx, krx, kmy, kry, vmx, vrx, vmy, vry, rmx, rrx, rmy, rry): + xx = F.layer_norm(x, (x.shape[-1],), weight=ln_w.weight, bias=ln_b) + ln_w.add_batch(x) + kx = xx * k_mix + sx * (1 - k_mix) + rx = xx * r_mix + sx * (1 - r_mix) + + r = torch.sigmoid(rx @ rw.weight) + rw.add_batch(rx) + vx = torch.square(torch.relu(kx @ kw.weight)) + kw.add_batch(kx) + out = r * (vx @ vw.weight) + vw.add_batch(vx) + return x + out, xx + def ffn_seq(self, x, sx, ln_w, ln_b, k_mix, r_mix, kw, vw, rw, kmx, krx, kmy, kry, vmx, vrx, vmy, vry, rmx, rrx, rmy, rry): - xx = F.layer_norm(x, (x.shape[-1],), weight=ln_w, bias=ln_b) + # x = (2048, 768) + xx = F.layer_norm(x, (x.shape[-1],), weight=ln_w.weight, bias=ln_b) + # xx = (2048, 768) + ln_w.add_batch(x) sx = torch.cat((sx.unsqueeze(0), xx[:-1,:])) + # sx = (2048, 768) kx = xx * k_mix + sx * (1 - k_mix) rx = xx * r_mix + sx * (1 - r_mix) - - r = torch.sigmoid(rx @ rw) - vx = torch.square(torch.relu(kx @ kw)) - out = r * (vx @ vw) + # kx = (2048, 768) + # rx = (2048, 768) + + r = torch.sigmoid(rx @ rw.weight) + # r = (2048, 768) + rw.add_batch(rx) + print("kx: ", kx.shape) + print("kw.weight: ", kw.weight.shape) + vx = torch.square(torch.relu(kx @ kw.weight)) + # vx = (2048, 3072) + # kx: (2048, 768) + # kw.weight: (768, 3072) + # vx: (2048, 3072) + kw.add_batch(kx) + out = r * (vx @ vw.weight) + vw.add_batch(vx) return x + out, xx[-1,:] def forward_block(self, x, state, i, seq_mode, full_output=False): @@ -222,7 +263,7 @@ def forward_block(self, x, state, i, seq_mode, full_output=False): ln_w=self.gptq[f'{bbb}ln1.weight'], ln_b=self.w[f'{bbb}ln1.bias'], k_mix=self.w[f'{att}time_mix_k'], v_mix=self.w[f'{att}time_mix_v'], r_mix=self.w[f'{att}time_mix_r'], t_decay=self.w[f'{att}time_decay'], t_first=self.w[f'{att}time_first'], - kw=kw, vw=vw, rw=rw, pw=ow, + kw=kw, vw=vw, rw=rw, ow=ow, kmx=kmx, krx=krx, kmy=kmy, kry=kry, vmx=vmx, vrx=vrx, vmy=vmy, vry=vry, rmx=rmx, rrx=rrx, rmy=rmy, rry=rry, @@ -284,39 +325,47 @@ def forward_block(self, x, state, i, seq_mode, full_output=False): x = self.mm8_one(x, self.w['head.weight'], self.w['head.weight_mx'], self.w['head.weight_rx'], self.w['head.weight_my'], self.w['head.weight_ry']) return x.float(), state + + ### end RWKV - -NSAMPLES=1 +NSAMPLES=2 HIDDEN_SIZE=768 -SEQLEN=HIDDEN_SIZE # TODO: this is chosen by the model +SEQLEN=2048 # TODO: this is chosen by the model -train_tokens, test_tokens = get_loaders( - dataset_name="wikitext2", - nsamples=NSAMPLES, - seed=42, - seqlen=SEQLEN, - model=None -) +# train_tokens, test_tokens = get_loaders( +# dataset_name="wikitext2", +# nsamples=NSAMPLES, +# seed=42, +# seqlen=SEQLEN, +# model=None +# ) -tokens = [inp.squeeze() for inp, _ in train_tokens] +# tokens = torch.cat([inp for inp, _ in train_tokens], dim=0) +tokens = torch.zeros((NSAMPLES, SEQLEN), dtype=torch.int64) +print("tokens.shape", tokens.shape) model 
= GPTQ_RWKV("./RWKV-4-Pile-169M-20220807-8023.pth", strategy='cpu fp32') with torch.no_grad(): seq_mode = len(tokens) > 1 x = model.w['emb.weight'][tokens if seq_mode else tokens[0]] - + quantizers = {} for layer_id in range(model.args.n_layer): - model.alloc_gptq(layer_id, model) + model.alloc_gptq(layer_id) + # TODO: call add_batch() for each layer inside att_seq etc function for j in range(NSAMPLES): - _ = model.forward_block(x[j].unsqueeze(0), state=None, i=layer_id, seq_mode=seq_mode, full_output=full_output) + _ = model.forward_block(x[j], state=None, i=layer_id, seq_mode=seq_mode) + + # model.fasterquant(layer_id, quantizers) - model.fasterquant(layer_id, quantizers) + # model.free_gptq() - model.free_gptq() + #TODO: Since we quantize per block, we should pass the outputs of block 0 to input of block 1 ? + # inps, outs = outs, inps # TODO: create a function that check if all weights were properly quantized +print("Done") \ No newline at end of file From dba26705fd710958837cb3ff1080abef5bc1d251 Mon Sep 17 00:00:00 2001 From: Ferdinand Mom Date: Tue, 25 Apr 2023 08:09:52 +0000 Subject: [PATCH 06/20] feat(quantize): link fasterquant with RWKV + remove 1D tensor quantization for now --- quantize/tmp_rwkv.py | 127 ++++++++++++++++++++++++++++++++++--------- 1 file changed, 101 insertions(+), 26 deletions(-) diff --git a/quantize/tmp_rwkv.py b/quantize/tmp_rwkv.py index dc1695f5..e45c055b 100644 --- a/quantize/tmp_rwkv.py +++ b/quantize/tmp_rwkv.py @@ -1,11 +1,12 @@ from rwkv.model import RWKV from gptq.datautils import * -from gptq.quant import Quantizer +from gptq.quant import Quantizer, quantize import os import torch.nn.functional as F import torch.nn as nn +import time import gc import math import re @@ -31,6 +32,7 @@ def add_batch(self, inp): if len(inp.shape) == 2: inp = inp.unsqueeze(0) + #TODO: is the case with len = 1 still necessary ? 
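+            # The lines below keep the GPTQ Hessian estimate
+            # H ~= (2 / nsamples) * sum_i x_i x_i^T as a running average:
+            # the old H is rescaled by nsamples / (nsamples + tmp) before the
+            # new batch's contribution is added.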
tmp = 1 if len(inp.shape) == 1 else inp.shape[0] # Assume weight come from nn.Linear @@ -43,16 +45,88 @@ def add_batch(self, inp): inp = math.sqrt(2 / self.nsamples) * inp.float() self.H += inp.matmul(inp.t()) - def fasterquant(self): - pass + def fasterquant(self, blocksize=128, percdamp=.01, groupsize=-1, actorder=False): + W = self.weight.data.clone() + # Need to transpose here, same reason as in __init__ with self.columns + W = W.t() + W = W.float() + + tick = time.time() + + if not self.quantizer.ready(): + self.quantizer.find_params(W, weight=True) + + H = self.H + del self.H + dead = torch.diag(H) == 0 + H[dead, dead] = 1 + W[:, dead] = 0 + + if actorder: + perm = torch.argsort(torch.diag(H), descending=True) + W = W[:, perm] + H = H[perm][:, perm] + + Losses = torch.zeros_like(W) + Q = torch.zeros_like(W) + + damp = percdamp * torch.mean(torch.diag(H)) + diag = torch.arange(self.columns, device=self.dev) + H[diag, diag] += damp + H = torch.linalg.cholesky(H) + H = torch.cholesky_inverse(H) + H = torch.linalg.cholesky(H, upper=True) + Hinv = H + + for i1 in range(0, self.columns, blocksize): + i2 = min(i1 + blocksize, self.columns) + count = i2 - i1 + + W1 = W[:, i1:i2].clone() + Q1 = torch.zeros_like(W1) + Err1 = torch.zeros_like(W1) + Losses1 = torch.zeros_like(W1) + Hinv1 = Hinv[i1:i2, i1:i2] + + for i in range(count): + w = W1[:, i] + d = Hinv1[i, i] + + if groupsize != -1: + if (i1 + i) % groupsize == 0: + self.quantizer.find_params(W[:, (i1 + i):(i1 + i + groupsize)], weight=True) + + q = quantize( + w.unsqueeze(1), self.quantizer.scale, self.quantizer.zero, self.quantizer.maxq + ).flatten() + Q1[:, i] = q + Losses1[:, i] = (w - q) ** 2 / d ** 2 + + err1 = (w - q) / d + W1[:, i:] -= err1.unsqueeze(1).matmul(Hinv1[i, i:].unsqueeze(0)) + Err1[:, i] = err1 + + Q[:, i1:i2] = Q1 + Losses[:, i1:i2] = Losses1 / 2 + + W[:, i2:] -= Err1.matmul(Hinv[i1:i2, i2:]) + + torch.cuda.synchronize() + print('time %.2f' % (time.time() - tick)) + print('error', torch.sum(Losses).item()) + + if actorder: + invperm = torch.argsort(perm) + Q = Q[:, invperm] + + self.weight.data = Q.reshape(self.weight.shape).to(self.weight.data.dtype) + ### end GPTQ ### begin GPTQ_RWKV def __init__(self, model, strategy): super().__init__(model, strategy) #TODO: add assert to only quantize in CPU FP32 mode - self.subset = {} - self.gptq = {} def _fill_subset(self, layer_id): # Keep only layer within block layer_id @@ -62,14 +136,21 @@ def _fill_subset(self, layer_id): for name in self.w.keys(): if re.match(f'^blocks\.{layer_id}\..*\.weight$', name): tensor = self.w[name] - print(f"{name} = {self.w[name].shape}") + + #TODO: Skip 1D tensors for now + if len(tensor.shape) == 1: + continue + print(f"{name} = {self.w[name].shape}") + if re.match(f'^blocks\.{layer_id}\.(?:att|ffn)\.(?:key|value|output|receptance)\.weight$', name): tensor = tensor.to(device=dev, non_blocking=True) self.subset[name] = tensor def alloc_gptq(self, layer_id): + self.subset = {} + self.gptq = {} self._fill_subset(layer_id) @@ -80,15 +161,15 @@ def alloc_gptq(self, layer_id): self.gptq[name].quantizer.configure(bits=4, perchannel=True, sym=False, mse=False, trits=False) def free_gptq(self): - if len(self.subset) > 0: del self.subset - if len(self.gptq) > 0: del self.gptq - gc.collect() + self.subset = {} + self.gptq = {} def fasterquant(self, layer_id, quantizers): + for name in self.subset: print(f"Quantizing {name} of layer {layer_id}") #TODO: add argparse to fastquant - self.gptq[name].fastquant(percdamp=0.01, groupsize=-1, actorder=False) + 
self.gptq[name].fasterquant(percdamp=0.01, groupsize=-1, actorder=False) # self.gptq[name].fastquant(percdamp=args.percdamp, groupsize=args.groupsize, actorder=args.act_order) quantizers[name] = self.gptq[name].quantizer # TODO: may be free gptq here to save memory @@ -98,8 +179,7 @@ def fasterquant(self, layer_id, quantizers): ### begin RWKV def att_one(self, x, sx, aa, bb, pp, ln_w, ln_b, k_mix, v_mix, r_mix, t_decay, t_first, kw, vw, rw, ow, kmx, krx, kmy, kry, vmx, vrx, vmy, vry, rmx, rrx, rmy, rry, omx, orx, omy, ory): - xx = F.layer_norm(x, (x.shape[-1],), weight=ln_w.weight, bias=ln_b) - ln_w.add_batch(x) + xx = F.layer_norm(x, (x.shape[-1],), weight=ln_w, bias=ln_b) kx = xx * k_mix + sx * (1 - k_mix) vx = xx * v_mix + sx * (1 - v_mix) rx = xx * r_mix + sx * (1 - r_mix) @@ -126,8 +206,7 @@ def att_one(self, x, sx, aa, bb, pp, ln_w, ln_b, k_mix, v_mix, r_mix, t_decay, t return x + out, xx, e1 * aa + e2 * v, e1 * bb + e2, p def att_seq(self, x, sx, aa, bb, pp, ln_w, ln_b, k_mix, v_mix, r_mix, t_decay, t_first, kw, vw, rw, ow, kmx, krx, kmy, kry, vmx, vrx, vmy, vry, rmx, rrx, rmy, rry, omx, orx, omy, ory): - xx = F.layer_norm(x, (x.shape[-1],), weight=ln_w.weight, bias=ln_b) - ln_w.add_batch(x) + xx = F.layer_norm(x, (x.shape[-1],), weight=ln_w, bias=ln_b) sx = torch.cat((sx.unsqueeze(0), xx[:-1,:])) kx = xx * k_mix + sx * (1 - k_mix) vx = xx * v_mix + sx * (1 - v_mix) @@ -161,8 +240,7 @@ def att_seq(self, x, sx, aa, bb, pp, ln_w, ln_b, k_mix, v_mix, r_mix, t_decay, t return x + out, xx[-1,:], aa, bb, pp def ffn_one(self, x, sx, ln_w, ln_b, k_mix, r_mix, kw, vw, rw, kmx, krx, kmy, kry, vmx, vrx, vmy, vry, rmx, rrx, rmy, rry): - xx = F.layer_norm(x, (x.shape[-1],), weight=ln_w.weight, bias=ln_b) - ln_w.add_batch(x) + xx = F.layer_norm(x, (x.shape[-1],), weight=ln_w, bias=ln_b) kx = xx * k_mix + sx * (1 - k_mix) rx = xx * r_mix + sx * (1 - r_mix) @@ -176,9 +254,8 @@ def ffn_one(self, x, sx, ln_w, ln_b, k_mix, r_mix, kw, vw, rw, kmx, krx, kmy, kr def ffn_seq(self, x, sx, ln_w, ln_b, k_mix, r_mix, kw, vw, rw, kmx, krx, kmy, kry, vmx, vrx, vmy, vry, rmx, rrx, rmy, rry): # x = (2048, 768) - xx = F.layer_norm(x, (x.shape[-1],), weight=ln_w.weight, bias=ln_b) + xx = F.layer_norm(x, (x.shape[-1],), weight=ln_w, bias=ln_b) # xx = (2048, 768) - ln_w.add_batch(x) sx = torch.cat((sx.unsqueeze(0), xx[:-1,:])) # sx = (2048, 768) kx = xx * k_mix + sx * (1 - k_mix) @@ -189,8 +266,6 @@ def ffn_seq(self, x, sx, ln_w, ln_b, k_mix, r_mix, kw, vw, rw, kmx, krx, kmy, kr r = torch.sigmoid(rx @ rw.weight) # r = (2048, 768) rw.add_batch(rx) - print("kx: ", kx.shape) - print("kw.weight: ", kw.weight.shape) vx = torch.square(torch.relu(kx @ kw.weight)) # vx = (2048, 3072) # kx: (2048, 768) @@ -260,7 +335,7 @@ def forward_block(self, x, state, i, seq_mode, full_output=False): x, state[i*5+0], state[i*5+1], state[i*5+2], state[i*5+3] = ATT( x=x, sx=state[i*5+0], aa=state[i*5+1], bb=state[i*5+2], pp=state[i*5+3], - ln_w=self.gptq[f'{bbb}ln1.weight'], ln_b=self.w[f'{bbb}ln1.bias'], + ln_w=self.w[f'{bbb}ln1.weight'], ln_b=self.w[f'{bbb}ln1.bias'], k_mix=self.w[f'{att}time_mix_k'], v_mix=self.w[f'{att}time_mix_v'], r_mix=self.w[f'{att}time_mix_r'], t_decay=self.w[f'{att}time_decay'], t_first=self.w[f'{att}time_first'], kw=kw, vw=vw, rw=rw, ow=ow, @@ -295,7 +370,7 @@ def forward_block(self, x, state, i, seq_mode, full_output=False): rry = self.w[f'{ffn}receptance.weight_ry'] if wtype == torch.uint8 else x x, state[i*5+4] = FFN( x=x, sx=state[i*5+4], - ln_w=self.gptq[f'{bbb}ln2.weight'], ln_b=self.w[f'{bbb}ln2.bias'], + 
ln_w=self.w[f'{bbb}ln2.weight'], ln_b=self.w[f'{bbb}ln2.bias'], k_mix=self.w[f'{ffn}time_mix_k'], r_mix=self.w[f'{ffn}time_mix_r'], kw=kw, vw=vw, rw=rw, kmx=kmx, krx=krx, kmy=kmy, kry=kry, @@ -346,6 +421,7 @@ def forward_block(self, x, state, i, seq_mode, full_output=False): model = GPTQ_RWKV("./RWKV-4-Pile-169M-20220807-8023.pth", strategy='cpu fp32') +#TODO: Do the same in GPU side with torch.no_grad(): seq_mode = len(tokens) > 1 x = model.w['emb.weight'][tokens if seq_mode else tokens[0]] @@ -356,13 +432,12 @@ def forward_block(self, x, state, i, seq_mode, full_output=False): model.alloc_gptq(layer_id) - # TODO: call add_batch() for each layer inside att_seq etc function for j in range(NSAMPLES): _ = model.forward_block(x[j], state=None, i=layer_id, seq_mode=seq_mode) + + model.fasterquant(layer_id, quantizers) - # model.fasterquant(layer_id, quantizers) - - # model.free_gptq() + model.free_gptq() #TODO: Since we quantize per block, we should pass the outputs of block 0 to input of block 1 ? # inps, outs = outs, inps From 57079e774537e969f479550ccb5ce28dcd729e8b Mon Sep 17 00:00:00 2001 From: Ferdinand Mom Date: Tue, 25 Apr 2023 09:30:14 +0000 Subject: [PATCH 07/20] feat(quantize): full gptq pipeline now integrated with RKWV (quite slow for some layer + need tests) --- quantize/tmp_rwkv.py | 76 +++++++++++++++++++++++++++----------------- 1 file changed, 47 insertions(+), 29 deletions(-) diff --git a/quantize/tmp_rwkv.py b/quantize/tmp_rwkv.py index e45c055b..b4697b64 100644 --- a/quantize/tmp_rwkv.py +++ b/quantize/tmp_rwkv.py @@ -27,8 +27,14 @@ def __init__(self, weight, name): self.columns = self.weight.shape[0] self.H = torch.zeros((self.columns, self.columns), device=self.dev) self.nsamples = 0 + self.deactivate_add_batch_call = False def add_batch(self, inp): + + # After calling fasterquant, we don't want to call add_batch anymore + if self.deactivate_add_batch_call: + return + if len(inp.shape) == 2: inp = inp.unsqueeze(0) @@ -253,30 +259,20 @@ def ffn_one(self, x, sx, ln_w, ln_b, k_mix, r_mix, kw, vw, rw, kmx, krx, kmy, kr return x + out, xx def ffn_seq(self, x, sx, ln_w, ln_b, k_mix, r_mix, kw, vw, rw, kmx, krx, kmy, kry, vmx, vrx, vmy, vry, rmx, rrx, rmy, rry): - # x = (2048, 768) xx = F.layer_norm(x, (x.shape[-1],), weight=ln_w, bias=ln_b) - # xx = (2048, 768) sx = torch.cat((sx.unsqueeze(0), xx[:-1,:])) - # sx = (2048, 768) kx = xx * k_mix + sx * (1 - k_mix) rx = xx * r_mix + sx * (1 - r_mix) - # kx = (2048, 768) - # rx = (2048, 768) r = torch.sigmoid(rx @ rw.weight) - # r = (2048, 768) rw.add_batch(rx) vx = torch.square(torch.relu(kx @ kw.weight)) - # vx = (2048, 3072) - # kx: (2048, 768) - # kw.weight: (768, 3072) - # vx: (2048, 3072) kw.add_batch(kx) out = r * (vx @ vw.weight) vw.add_batch(vx) return x + out, xx[-1,:] - def forward_block(self, x, state, i, seq_mode, full_output=False): + def forward_block(self, x, state, i, seq_mode, is_last_layer, full_output=False): with torch.no_grad(): args = self.args @@ -344,6 +340,11 @@ def forward_block(self, x, state, i, seq_mode, full_output=False): rmx=rmx, rrx=rrx, rmy=rmy, rry=rry, omx=omx, orx=orx, omy=omy, ory=ory, ) + + kw.deactivate_add_batch_call = True + vw.deactivate_add_batch_call = True + rw.deactivate_add_batch_call = True + ow.deactivate_add_batch_call = True if dd.stream: del kw, vw, rw, ow @@ -378,6 +379,11 @@ def forward_block(self, x, state, i, seq_mode, full_output=False): rmx=rmx, rrx=rrx, rmy=rmy, rry=rry, ) + # Deactivate add_batch() after quantization is applied + kw.deactivate_add_batch_call = True + 
vw.deactivate_add_batch_call = True + rw.deactivate_add_batch_call = True + if dd.stream: del kw, vw, rw @@ -385,21 +391,22 @@ def forward_block(self, x, state, i, seq_mode, full_output=False): if (i+1) % self.RESCALE_LAYER == 0: x = x / 2 - dd = self.strategy[args.n_layer] - x = x[-1,:] if (seq_mode and (not full_output)) else x - x = x.to(dtype=dd.atype, device=dd.device) - - #TODO: Add GPTQ support for head & ln_out - x = F.layer_norm(x, (args.n_embd,), weight=self.w['ln_out.weight'], bias=self.w['ln_out.bias']) - if self.w['head.weight'].dtype != torch.uint8: - x = x @ self.w['head.weight'] - else: - if seq_mode and full_output: - x = self.mm8_seq(x, self.w['head.weight'], self.w['head.weight_mx'], self.w['head.weight_rx'], self.w['head.weight_my'], self.w['head.weight_ry']) + if is_last_layer: + dd = self.strategy[args.n_layer] + x = x[-1,:] if (seq_mode and (not full_output)) else x + x = x.to(dtype=dd.atype, device=dd.device) + + #TODO: Add GPTQ support for head & ln_out + x = F.layer_norm(x, (args.n_embd,), weight=self.w['ln_out.weight'], bias=self.w['ln_out.bias']) + if self.w['head.weight'].dtype != torch.uint8: + x = x @ self.w['head.weight'] else: - x = self.mm8_one(x, self.w['head.weight'], self.w['head.weight_mx'], self.w['head.weight_rx'], self.w['head.weight_my'], self.w['head.weight_ry']) + if seq_mode and full_output: + x = self.mm8_seq(x, self.w['head.weight'], self.w['head.weight_mx'], self.w['head.weight_rx'], self.w['head.weight_my'], self.w['head.weight_ry']) + else: + x = self.mm8_one(x, self.w['head.weight'], self.w['head.weight_mx'], self.w['head.weight_rx'], self.w['head.weight_my'], self.w['head.weight_ry']) - return x.float(), state + return x.float() ### end RWKV @@ -420,11 +427,13 @@ def forward_block(self, x, state, i, seq_mode, full_output=False): print("tokens.shape", tokens.shape) model = GPTQ_RWKV("./RWKV-4-Pile-169M-20220807-8023.pth", strategy='cpu fp32') +is_last_layer = [False] * (model.args.n_layer - 1) + [True] #TODO: Do the same in GPU side with torch.no_grad(): seq_mode = len(tokens) > 1 - x = model.w['emb.weight'][tokens if seq_mode else tokens[0]] + inps = model.w['emb.weight'][tokens if seq_mode else tokens[0]] + outs = torch.zeros_like(inps) quantizers = {} @@ -433,14 +442,23 @@ def forward_block(self, x, state, i, seq_mode, full_output=False): model.alloc_gptq(layer_id) for j in range(NSAMPLES): - _ = model.forward_block(x[j], state=None, i=layer_id, seq_mode=seq_mode) - + if not is_last_layer[layer_id]: + outs[j] = model.forward_block(inps[j], state=None, i=layer_id, seq_mode=seq_mode, is_last_layer=is_last_layer[layer_id]) + else: + _ = model.forward_block(inps[j], state=None, i=layer_id, seq_mode=seq_mode, is_last_layer=is_last_layer[layer_id]) + model.fasterquant(layer_id, quantizers) + for j in range(NSAMPLES): + if not is_last_layer[layer_id]: + outs[j] = model.forward_block(inps[j], state=None, i=layer_id, seq_mode=seq_mode, is_last_layer=is_last_layer[layer_id]) + else: + _ = model.forward_block(inps[j], state=None, i=layer_id, seq_mode=seq_mode, is_last_layer=is_last_layer[layer_id]) model.free_gptq() - #TODO: Since we quantize per block, we should pass the outputs of block 0 to input of block 1 ? 
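# The driver being rewritten here follows the usual sequential GPTQ recipe: gather
# statistics for block i, quantize it, re-run the calibration inputs through the now
# quantized block, and hand those outputs to block i+1. A minimal, self-contained
# sketch of that control flow, with plain nn.Linear blocks and a stand-in rounding
# "quantizer" (fake_quantize_ is illustrative only, not this repo's fasterquant):
import torch
import torch.nn as nn

def fake_quantize_(linear, bits=4):
    w = linear.weight.data
    scale = (w.max() - w.min()) / (2 ** bits - 1)
    zero = torch.round(-w.min() / scale)
    q = torch.clamp(torch.round(w / scale) + zero, 0, 2 ** bits - 1)
    linear.weight.data = scale * (q - zero)

blocks = [nn.Linear(8, 8) for _ in range(3)]
inps = torch.randn(16, 8)               # calibration activations entering block 0
outs = torch.zeros_like(inps)
with torch.no_grad():
    for i, block in enumerate(blocks):
        _ = block(inps)                 # stats pass (add_batch() hooks in here)
        fake_quantize_(block)           # stands in for fasterquant()
        outs = block(inps)              # outputs of the quantized block
        if i < len(blocks) - 1:
            inps, outs = outs, inps     # block i+1 calibrates on quantized outputs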
- # inps, outs = outs, inps + if not is_last_layer[layer_id]: + # We need to pass the outputs of block i as input of block i+1 (except for last block) + inps, outs = outs, inps # TODO: create a function that check if all weights were properly quantized print("Done") \ No newline at end of file From 8e78f2d405ee9d66270bbda276825680de3ef8c2 Mon Sep 17 00:00:00 2001 From: Ferdinand Mom Date: Tue, 25 Apr 2023 13:14:17 +0000 Subject: [PATCH 08/20] fix(quantize): add missing part in forward block + support head.weight quantization --- quantize/gptq/datautils.py | 11 +++-- quantize/tmp_rwkv.py | 96 +++++++++++++++++++++----------------- 2 files changed, 58 insertions(+), 49 deletions(-) diff --git a/quantize/gptq/datautils.py b/quantize/gptq/datautils.py index 4ed1e39b..cd296d3c 100644 --- a/quantize/gptq/datautils.py +++ b/quantize/gptq/datautils.py @@ -4,6 +4,7 @@ import pathlib import tokenizers import random +from rwkv.model import RWKV from datasets import load_dataset @@ -12,7 +13,7 @@ def set_seed(seed): torch.random.manual_seed(seed) def get_wikitext2(nsamples, seed, seqlen, model): - is_rwkv = True if model is None else False + is_rwkv = isinstance(model, RWKV) if is_rwkv: traindata = load_dataset('wikitext', 'wikitext-2-raw-v1', split='train') @@ -37,11 +38,11 @@ def get_wikitext2(nsamples, seed, seqlen, model): trainloader = [] shape = trainenc.shape if is_rwkv else trainenc.input_ids.shape trainenc = trainenc if is_rwkv else trainenc.input_ids + random_idx = [random.randint(0, shape[1] - seqlen - 1) for _ in range(nsamples)] - for _ in range(nsamples): - i = random.randint(0, shape[1] - seqlen - 1) - j = i + seqlen - inp = trainenc[:, i:j] + for i in range(nsamples): + j = random_idx[i] + seqlen + inp = trainenc[:, random_idx[i]:j] tar = inp.clone() tar[:, :-1] = -100 trainloader.append((inp, tar)) diff --git a/quantize/tmp_rwkv.py b/quantize/tmp_rwkv.py index b4697b64..1ed2ad1f 100644 --- a/quantize/tmp_rwkv.py +++ b/quantize/tmp_rwkv.py @@ -7,7 +7,6 @@ import torch.nn.functional as F import torch.nn as nn import time -import gc import math import re @@ -132,27 +131,21 @@ def fasterquant(self, blocksize=128, percdamp=.01, groupsize=-1, actorder=False) ### begin GPTQ_RWKV def __init__(self, model, strategy): super().__init__(model, strategy) - #TODO: add assert to only quantize in CPU FP32 mode + for i in range(self.args.n_layer): + assert self.strategy[i].device == "cpu" def _fill_subset(self, layer_id): # Keep only layer within block layer_id - dd = self.strategy[layer_id] - dev = dd.device - - for name in self.w.keys(): - if re.match(f'^blocks\.{layer_id}\..*\.weight$', name): - tensor = self.w[name] - - #TODO: Skip 1D tensors for now - if len(tensor.shape) == 1: - continue - - print(f"{name} = {self.w[name].shape}") - - if re.match(f'^blocks\.{layer_id}\.(?:att|ffn)\.(?:key|value|output|receptance)\.weight$', name): - tensor = tensor.to(device=dev, non_blocking=True) - - self.subset[name] = tensor + is_weight = re.compile(f'^blocks\.{layer_id}\..*\.weight$') + for name in self.w.keys(): + if is_weight.match(name): + if len(self.w[name].shape) == 1: continue #TODO: Skip 1D tensors for now + self.subset[name] = self.w[name] + + is_last_layer = (layer_id == self.args.n_layer - 1) + if is_last_layer: + self.subset["head.weight"] = self.w["head.weight"] + def alloc_gptq(self, layer_id): self.subset = {} @@ -178,7 +171,6 @@ def fasterquant(self, layer_id, quantizers): self.gptq[name].fasterquant(percdamp=0.01, groupsize=-1, actorder=False) # 
self.gptq[name].fastquant(percdamp=args.percdamp, groupsize=args.groupsize, actorder=args.act_order) quantizers[name] = self.gptq[name].quantizer - # TODO: may be free gptq here to save memory ### end GPTQ_RWKV @@ -272,7 +264,7 @@ def ffn_seq(self, x, sx, ln_w, ln_b, k_mix, r_mix, kw, vw, rw, kmx, krx, kmy, kr vw.add_batch(vx) return x + out, xx[-1,:] - def forward_block(self, x, state, i, seq_mode, is_last_layer, full_output=False): + def forward_block(self, x, state, i, seq_mode, full_output=False): with torch.no_grad(): args = self.args @@ -312,6 +304,12 @@ def forward_block(self, x, state, i, seq_mode, is_last_layer, full_output=False) rw = self.gptq[f'{att}receptance.weight'] ow = self.gptq[f'{att}output.weight'] + if dd.stream: + kw = kw.to(device=dev, non_blocking=True) + vw = vw.to(device=dev, non_blocking=True) + rw = rw.to(device=dev, non_blocking=True) + ow = ow.to(device=dev, non_blocking=True) + kmx = self.w[f'{att}key.weight_mx'] if wtype == torch.uint8 else x krx = self.w[f'{att}key.weight_rx'] if wtype == torch.uint8 else x kmy = self.w[f'{att}key.weight_my'] if wtype == torch.uint8 else x @@ -341,6 +339,7 @@ def forward_block(self, x, state, i, seq_mode, is_last_layer, full_output=False) omx=omx, orx=orx, omy=omy, ory=ory, ) + # Deactivate add_batch() after quantization is applied kw.deactivate_add_batch_call = True vw.deactivate_add_batch_call = True rw.deactivate_add_batch_call = True @@ -352,6 +351,7 @@ def forward_block(self, x, state, i, seq_mode, is_last_layer, full_output=False) kw = self.gptq[f'{ffn}key.weight'] vw = self.gptq[f'{ffn}value.weight'] rw = self.gptq[f'{ffn}receptance.weight'] + if dd.stream: kw = kw.to(device=dev, non_blocking=True) vw = vw.to(device=dev, non_blocking=True) @@ -391,43 +391,46 @@ def forward_block(self, x, state, i, seq_mode, is_last_layer, full_output=False) if (i+1) % self.RESCALE_LAYER == 0: x = x / 2 - if is_last_layer: + is_last_layer = i == (args.n_layer - 1) + + if is_last_layer: dd = self.strategy[args.n_layer] x = x[-1,:] if (seq_mode and (not full_output)) else x x = x.to(dtype=dd.atype, device=dd.device) - #TODO: Add GPTQ support for head & ln_out + #TODO: ln_out.weight is 1D tensor x = F.layer_norm(x, (args.n_embd,), weight=self.w['ln_out.weight'], bias=self.w['ln_out.bias']) + if self.w['head.weight'].dtype != torch.uint8: - x = x @ self.w['head.weight'] - else: - if seq_mode and full_output: - x = self.mm8_seq(x, self.w['head.weight'], self.w['head.weight_mx'], self.w['head.weight_rx'], self.w['head.weight_my'], self.w['head.weight_ry']) - else: - x = self.mm8_one(x, self.w['head.weight'], self.w['head.weight_mx'], self.w['head.weight_rx'], self.w['head.weight_my'], self.w['head.weight_ry']) + x = x @ self.gptq['head.weight'].weight + self.gptq['head.weight'].add_batch(x) + self.gptq['head.weight'].deactivate_add_batch_call = True return x.float() ### end RWKV +model = GPTQ_RWKV("./RWKV-4-Pile-169M-20220807-8023.pth", strategy='cpu fp32') + NSAMPLES=2 -HIDDEN_SIZE=768 -SEQLEN=2048 # TODO: this is chosen by the model +HIDDEN_SIZE=model.args.n_embd +SEQLEN=1024 # cf https://huggingface.co/BlinkDL/rwkv-4-pile-169m # train_tokens, test_tokens = get_loaders( # dataset_name="wikitext2", # nsamples=NSAMPLES, # seed=42, # seqlen=SEQLEN, -# model=None +# model=model # ) # tokens = torch.cat([inp for inp, _ in train_tokens], dim=0) tokens = torch.zeros((NSAMPLES, SEQLEN), dtype=torch.int64) print("tokens.shape", tokens.shape) -model = GPTQ_RWKV("./RWKV-4-Pile-169M-20220807-8023.pth", strategy='cpu fp32') -is_last_layer = [False] * 
(model.args.n_layer - 1) + [True] +is_last_layer = lambda x: x == (model.args.n_layer - 1) + +start_time = time.time() #TODO: Do the same in GPU side with torch.no_grad(): @@ -442,23 +445,28 @@ def forward_block(self, x, state, i, seq_mode, is_last_layer, full_output=False) model.alloc_gptq(layer_id) for j in range(NSAMPLES): - if not is_last_layer[layer_id]: - outs[j] = model.forward_block(inps[j], state=None, i=layer_id, seq_mode=seq_mode, is_last_layer=is_last_layer[layer_id]) + if not is_last_layer(layer_id): + outs[j] = model.forward_block(inps[j], state=None, i=layer_id, seq_mode=seq_mode) else: - _ = model.forward_block(inps[j], state=None, i=layer_id, seq_mode=seq_mode, is_last_layer=is_last_layer[layer_id]) + _ = model.forward_block(inps[j], state=None, i=layer_id, seq_mode=seq_mode) model.fasterquant(layer_id, quantizers) for j in range(NSAMPLES): - if not is_last_layer[layer_id]: - outs[j] = model.forward_block(inps[j], state=None, i=layer_id, seq_mode=seq_mode, is_last_layer=is_last_layer[layer_id]) + if not is_last_layer(layer_id): + outs[j] = model.forward_block(inps[j], state=None, i=layer_id, seq_mode=seq_mode) else: - _ = model.forward_block(inps[j], state=None, i=layer_id, seq_mode=seq_mode, is_last_layer=is_last_layer[layer_id]) + _ = model.forward_block(inps[j], state=None, i=layer_id, seq_mode=seq_mode) + model.free_gptq() - if not is_last_layer[layer_id]: - # We need to pass the outputs of block i as input of block i+1 (except for last block) + # We need to pass the outputs of block i as input of block i+1 (except for last block) + if not is_last_layer(layer_id): inps, outs = outs, inps -# TODO: create a function that check if all weights were properly quantized -print("Done") \ No newline at end of file +end_time = time.time() + +print(f"Done in {end_time - start_time:.2f} seconds") + +# TODO: Do something with quantizers dictionary +# TODO: pack3 save model \ No newline at end of file From f87df05395289d9e3685eec4c68b134c95c7c747 Mon Sep 17 00:00:00 2001 From: Ferdinand Mom Date: Wed, 26 Apr 2023 09:55:19 +0000 Subject: [PATCH 09/20] feat(sanity-check): begin sanity check for GPTQ on MNIST --- quantize/gptq/sanity_check_main.py | 118 +++++++++++++++++++ quantize/gptq/sanity_check_utils.py | 168 ++++++++++++++++++++++++++++ 2 files changed, 286 insertions(+) create mode 100644 quantize/gptq/sanity_check_main.py create mode 100644 quantize/gptq/sanity_check_utils.py diff --git a/quantize/gptq/sanity_check_main.py b/quantize/gptq/sanity_check_main.py new file mode 100644 index 00000000..f95a522a --- /dev/null +++ b/quantize/gptq/sanity_check_main.py @@ -0,0 +1,118 @@ +import argparse +import time +import numpy as np +import torch +import torch.nn as nn +import torch.optim as optim + +from sanity_check_utils import seed_everything, MNISTloader, SimpleNet, train, evaluate +from gptq import * +from modelutils import * +from quant import * + +def quantize_gptq(model, train_loader, device): + quantizers = {} + layers = list(model.modules())[1:] + is_last_layer = lambda x: x == (len(layers) - 1) + + nsamples = len(train_loader.dataset) + batch_size = train_loader.batch_size + + inps = torch.zeros((nsamples, model.N), dtype=torch.float) + for i, (inp, _) in enumerate(train_loader): + inps[i*batch_size:(i+1)*batch_size] = inp.view(-1, 32*32) + outs = torch.zeros_like(inps) + + for layer_id in range(len(layers)): + layer = layers[layer_id] + + subset = find_layers(layer) + gptq = {} + + for name in subset: + gptq[name] = GPTQ(subset[name], name) + gptq[name].quantizer = 
Quantizer() + # TODO: 8 bits quantize so that we can compare with pytorch post-training quantization + gptq[name].quantizer.configure(bits=8, perchannel=True, sym=False, mse=False, trits=False) + + def add_batch(name): + def tmp(_, inp, out): + gptq[name].add_batch(inp[0].data, out.data) + return tmp + + handles = [] + + for name in subset: + handles.append(subset[name].register_forward_hook(add_batch(name))) + + for i in range(nsamples): + if not is_last_layer(layer_id): + outs[i] = layer(inps[i]) + else: + _ = layer(inps[i]) + + for h in handles: h.remove() + + for name in subset: + print(i, name) + print('Quantizing ...') + gptq[name].fasterquant(percdamp=0.1, groupsize=-1, actorder=False) + quantizers['model.decoder.layers.%d.%s' % (i, name)] = gptq[name].quantizer + gptq[name].free() + + + for i in range(nsamples): + if not is_last_layer(layer_id): + outs[i] = layer(inps[i]) + else: + _ = layer(inps[i]) + + del layer + del gptq + torch.cuda.empty_cache() + + if not is_last_layer(layer_id): + inps, outs = outs, inps + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--train", action="store_true") + parser.add_argument("--gptq", action="store_true") + parser.add_argument("--pyquant", action="store_true") + + args = parser.parse_args() + + seed_everything(42) + lr = 0.02 + num_epochs = 5 + model = SimpleNet() + optimizer = optim.Adam(model.parameters(), lr) + criterion = nn.CrossEntropyLoss() + train_loader, _, _ = MNISTloader(train_val_split=0.95).load() + + if args.train: + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + model = model.to(device) + train(num_epochs, model, optimizer, criterion, train_loader, device) + torch.save(model.state_dict(), "model.pt") + elif args.gptq: + device = torch.device("cpu") + model.load_state_dict(torch.load("./model.pt", map_location="cpu")) + model = model.to(device) + quantize_gptq(model, train_loader, device) + elif args.pyquant: + pass + else: + device = torch.device("cpu") + model.load_state_dict(torch.load("./model.pt", map_location="cpu")) + model = model.to(device) + + # Evaluate float 32 + start = time.time() + val_loss, val_acc = evaluate(device, model, criterion, train_loader) + end = time.time() + + print("Floating point FP32") + print(f"val_loss: {val_loss:.3f} \t val_acc: {val_acc:.3f}") + print(f"Latency: {end - start}") diff --git a/quantize/gptq/sanity_check_utils.py b/quantize/gptq/sanity_check_utils.py new file mode 100644 index 00000000..08bd120d --- /dev/null +++ b/quantize/gptq/sanity_check_utils.py @@ -0,0 +1,168 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +import os +import random +import numpy as np +import torch +from torch.utils.data import DataLoader, random_split +from torchvision import datasets, transforms + +def seed_everything(seed: int): + random.seed(seed) + os.environ["PYTHONHASHSEED"] = str(seed) + np.random.seed(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + torch.backends.cudnn.deterministic = True + torch.backends.cudnn.benchmark = True + +# Model +class SimpleNet(nn.Module): + def __init__(self, num_classes=10, init_weights=True): + super(SimpleNet, self).__init__() + self.N = 32 * 32 + self.linear1 = nn.Linear(in_features=32 * 32, out_features=self.N) + self.linear2 = nn.Linear(in_features=self.N, out_features=self.N) + self.linear3 = nn.Linear(in_features=self.N, out_features=self.N) + self.linear4 = nn.Linear(in_features=self.N, out_features=num_classes) + + def forward(self, x): + # Assume input 
is already flattened + x = F.relu(self.linear1(x)) + x = F.relu(self.linear2(x)) + x = F.relu(self.linear3(x)) + x = self.linear4(x) + return x + +# Dataset + +class MNISTloader: + def __init__( + self, + batch_size: int = 100, + data_dir: str = "./data/", + num_workers: int = 0, + pin_memory: bool = False, + shuffle: bool = False, + train_val_split: float = 0.1, + ): + self.batch_size = batch_size + self.data_dir = data_dir + self.num_workers = num_workers + self.pin_memory = pin_memory + self.shuffle = shuffle + self.train_val_split = train_val_split + + self.setup() + + def setup(self): + transform = transforms.Compose( + [ + transforms.Resize((32, 32)), + transforms.ToTensor(), + transforms.Normalize(mean=[0.5], std=[0.5]), + ] + ) + + self.train_dataset = datasets.MNIST( + self.data_dir, train=True, download=True, transform=transform + ) + val_split = int(len(self.train_dataset) * self.train_val_split) + train_split = len(self.train_dataset) - val_split + + self.train_dataset, self.val_dataset = random_split( + self.train_dataset, [train_split, val_split] + ) + self.test_dataset = datasets.MNIST( + self.data_dir, train=False, download=True, transform=transform + ) + + print( + "Image Shape: {}".format(self.train_dataset[0][0].numpy().shape), + end="\n\n", + ) + print("Training Set: {} samples".format(len(self.train_dataset))) + print("Validation Set: {} samples".format(len(self.val_dataset))) + print("Test Set: {} samples".format(len(self.test_dataset))) + + def load(self): + train_loader = DataLoader( + dataset=self.train_dataset, + batch_size=self.batch_size, + num_workers=self.num_workers, + pin_memory=self.pin_memory, + shuffle=self.shuffle, + ) + + val_loader = DataLoader( + dataset=self.val_dataset, + batch_size=self.batch_size, + num_workers=self.num_workers, + pin_memory=self.pin_memory, + shuffle=self.shuffle, + ) + + test_loader = DataLoader( + dataset=self.test_dataset, + batch_size=self.batch_size, + num_workers=self.num_workers, + pin_memory=self.pin_memory, + shuffle=self.shuffle, + ) + + return train_loader, val_loader, test_loader + +# Train + evaluate + +def evaluate(device, model, criterion, val_loader): + + val_loss_running, val_acc_running = 0, 0 + + model.eval().cuda() if (device.type == "cuda") else model.eval().cpu() + + with torch.no_grad(): + for inputs, labels in val_loader: + inputs, labels = inputs.to(device), labels.to(device) + outputs = model(inputs) + loss = criterion(outputs, labels) + _, predictions = torch.max(outputs, dim=1) + val_loss_running += loss.item() * inputs.shape[0] + val_acc_running += torch.sum(predictions == labels.data) + + val_loss = val_loss_running / len(val_loader.sampler) + val_acc = val_acc_running / len(val_loader.sampler) + + return val_loss, val_acc + + +def train(num_epochs, model, optimizer, criterion, train_loader, device): + + model.train().cuda() if (device.type == "cuda") else model.train().cpu() + + for epoch in range(num_epochs): + + train_loss_running, train_acc_running = 0, 0 + + for inputs, labels in train_loader: + + inputs, labels = inputs.to(device), labels.to(device) + + optimizer.zero_grad() + + outputs = model(inputs) + + _, predictions = torch.max(outputs, dim=1) + loss = criterion(outputs, labels) + + loss.backward() + optimizer.step() + + train_loss_running += loss.item() * inputs.shape[0] + train_acc_running += torch.sum(predictions == labels.data) + + train_loss = train_loss_running / len(train_loader.sampler) + train_acc = train_acc_running / len(train_loader.sampler) + + info = "Epoch: {:3}/{} \t 
train_loss: {:.3f} \t train_acc: {:.3f}" + print(info.format(epoch + 1, num_epochs, train_loss, train_acc)) From b77715dc5b51d3b2f9e9b0551d22b8b1d045befa Mon Sep 17 00:00:00 2001 From: Ferdinand Mom Date: Thu, 27 Apr 2023 08:11:16 +0000 Subject: [PATCH 10/20] breaking(sanity-check): add save & load option for reference gptq --- quantize/gptq/gptq.py | 33 +- quantize/gptq/quant.py | 251 ++++++++---- quantize/gptq/quant_cuda.cpp | 58 ++- quantize/gptq/quant_cuda_kernel.cu | 595 ++++++++++++++++++++-------- quantize/gptq/sanity_check_main.py | 57 ++- quantize/gptq/sanity_check_utils.py | 6 +- requirements-quantize.txt | 4 +- 7 files changed, 738 insertions(+), 266 deletions(-) diff --git a/quantize/gptq/gptq.py b/quantize/gptq/gptq.py index ae857e58..4cb03c85 100644 --- a/quantize/gptq/gptq.py +++ b/quantize/gptq/gptq.py @@ -5,7 +5,7 @@ import torch.nn as nn import transformers -from .quant import * +from quant import * DEBUG = False @@ -15,8 +15,8 @@ class GPTQ: - - def __init__(self, layer): + def __init__(self, layer, name): + self.name = name self.layer = layer self.dev = self.layer.weight.device W = layer.weight.data.clone() @@ -77,7 +77,7 @@ def fasterquant( dead = torch.diag(H) == 0 H[dead, dead] = 1 W[:, dead] = 0 - + if actorder: perm = torch.argsort(torch.diag(H), descending=True) W = W[:, perm] @@ -93,6 +93,11 @@ def fasterquant( H = torch.cholesky_inverse(H) H = torch.linalg.cholesky(H, upper=True) Hinv = H + + g_idx = [] + scale = [] + zero = [] + now_idx = 1 for i1 in range(0, self.columns, blocksize): i2 = min(i1 + blocksize, self.columns) @@ -112,6 +117,11 @@ def fasterquant( if (i1 + i) % groupsize == 0: self.quantizer.find_params(W[:, (i1 + i):(i1 + i + groupsize)], weight=True) + if ((i1 + i) // groupsize) - now_idx == -1: + scale.append(self.quantizer.scale) + zero.append(self.quantizer.zero) + now_idx += 1 + q = quantize( w.unsqueeze(1), self.quantizer.scale, self.quantizer.zero, self.quantizer.maxq ).flatten() @@ -136,17 +146,28 @@ def fasterquant( torch.cuda.synchronize() print('time %.2f' % (time.time() - tick)) print('error', torch.sum(Losses).item()) - + + groupsize = groupsize if groupsize != -1 else self.columns + g_idx = [i // groupsize for i in range(self.columns)] + g_idx = torch.tensor(g_idx, dtype=torch.int32, device=Q.device) if actorder: invperm = torch.argsort(perm) Q = Q[:, invperm] + g_idx = g_idx[invperm] if isinstance(self.layer, transformers.Conv1D): Q = Q.t() self.layer.weight.data = Q.reshape(self.layer.weight.shape).to(self.layer.weight.data.dtype) if DEBUG: print(torch.sum((self.layer(self.inp1) - self.out1) ** 2)) - + + if scale == []: + scale.append(self.quantizer.scale) + zero.append(self.quantizer.zero) + scale = torch.cat(scale,dim=1) + zero = torch.cat(zero,dim=1) + return scale,zero,g_idx + def free(self): if DEBUG: self.inp1 = None diff --git a/quantize/gptq/quant.py b/quantize/gptq/quant.py index 77c27e00..e301c521 100644 --- a/quantize/gptq/quant.py +++ b/quantize/gptq/quant.py @@ -1,7 +1,7 @@ import numpy as np import torch import torch.nn as nn - +import math def quantize(x, scale, zero, maxq): if maxq < 0: @@ -22,7 +22,8 @@ def configure( bits, perchannel=False, sym=True, mse=False, norm=2.4, grid=100, maxshrink=.8, trits=False - ): + ): + self.maxq = torch.tensor(2 ** bits - 1) self.perchannel = perchannel self.sym = sym @@ -66,14 +67,14 @@ def find_params(self, x, weight=False): xmax[tmp] = +1 if self.maxq < 0: - self.scale = xmax - self.zero = xmin + self.scale = xmax + self.zero = xmin else: - self.scale = (xmax - xmin) / self.maxq 
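# The (scale, zero) pair that find_params() produces here parameterizes a plain
# asymmetric affine code: q = clamp(round(x / scale) + zero, 0, maxq) is what gets
# stored, and scale * (q - zero) is what QuantLinear reconstructs at matmul time.
# A minimal round trip with made-up values (4-bit purely for illustration):
import torch

bits = 4
maxq = 2 ** bits - 1
x = torch.tensor([-0.9, -0.1, 0.0, 0.4, 1.3])
scale = (x.max() - x.min()) / maxq
zero = torch.round(-x.min() / scale)
q = torch.clamp(torch.round(x / scale) + zero, 0, maxq)   # integer codes, 0..15
x_hat = scale * (q - zero)                                # dequantized values
assert (x - x_hat).abs().max() <= scale / 2 + 1e-6        # worst case off by half a step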
- if self.sym: - self.zero = torch.full_like(self.scale, (self.maxq + 1) / 2) - else: - self.zero = torch.round(-xmin / self.scale) + self.scale = (xmax - xmin) / self.maxq + if self.sym: + self.zero = torch.full_like(self.scale, (self.maxq + 1) / 2) + else: + self.zero = torch.round(-xmin / self.scale) if self.mse: best = torch.full([x.shape[0]], float('inf'), device=dev) @@ -127,86 +128,190 @@ def enabled(self): def ready(self): return torch.all(self.scale != 0) - try: import quant_cuda + is_cuda = True except: print('CUDA extension not installed.') + is_cuda = False -# Assumes layer is perfectly divisible into 1024 * 1024 blocks -class Quant3Linear(nn.Module): +def make_quant(module, names, bits, groupsize, name=''): + if isinstance(module, QuantLinear): + return + for attr in dir(module): + tmp = getattr(module, attr) + name1 = name + '.' + attr if name != '' else attr + if name1 in names: + delattr(module, attr) + setattr(module, attr, QuantLinear(bits, groupsize, tmp.in_features, tmp.out_features, tmp.bias is not None)) + for name1, child in module.named_children(): + make_quant(child, names, bits, groupsize, name + '.' + name1 if name != '' else name1) - def __init__(self, infeatures, outfeatures, faster=False): +class QuantLinear(nn.Module): + def __init__(self, bits, groupsize, infeatures, outfeatures, bias, kernel_switch_threshold=128, is_cuda=is_cuda): super().__init__() - self.register_buffer('zeros', torch.zeros((outfeatures, 1))) - self.register_buffer('scales', torch.zeros((outfeatures, 1))) - self.register_buffer('bias', torch.zeros(outfeatures)) - self.register_buffer( - 'qweight', torch.zeros((infeatures // 32 * 3, outfeatures), dtype=torch.int) - ) - self.faster = faster + if bits not in [2,3,4,8]: + raise NotImplementedError("Only 2,3,4,8 bits are supported.") + self.infeatures = infeatures + self.outfeatures = outfeatures + self.bits = bits + self.groupsize = groupsize if groupsize != -1 else infeatures + self.maxq = 2 ** self.bits - 1 - def pack(self, linear, scales, zeros): - self.zeros = zeros * scales - self.scales = scales.clone() - self.bias = linear.bias.clone() + self.register_buffer('qweight', torch.zeros((infeatures // 32 * self.bits, outfeatures), dtype=torch.int32)) + self.register_buffer('qzeros', torch.zeros((math.ceil(infeatures / self.groupsize), outfeatures // 32 * self.bits), dtype=torch.int32)) + self.register_buffer('scales', torch.zeros((math.ceil(infeatures / self.groupsize), outfeatures), dtype=torch.float16)) + self.register_buffer('g_idx', torch.tensor([i // self.groupsize for i in range(infeatures)], dtype = torch.int32)) + if bias: + self.register_buffer('bias', torch.zeros((outfeatures),dtype=torch.float16)) + else: + self.bias = None + + # is performed by unpacking the weights and using torch.matmul + if self.bits in [2,4,8]: + self.register_buffer('wf',torch.tensor(list(range(0,32,self.bits)), dtype=torch.int32).unsqueeze(0),persistent=False) + elif self.bits == 3: + self.register_buffer('wf', torch.tensor([[0, 3, 6, 9, 12, 15, 18, 21, 24, 27, 30, 0], + [0, 1, 4, 7, 10, 13, 16, 19, 22, 25, 28, 31], + [0, 2, 5, 8, 11, 14, 17, 20, 23, 26, 29, 0],], dtype=torch.int32).reshape(1,3,12), persistent=False) + + self.kernel_switch_threshold = kernel_switch_threshold + self.is_cuda = is_cuda - intweight = torch.round((linear.weight.data + self.zeros) / self.scales).to(torch.int) + def pack(self, linear, scales, zeros, g_idx = None): + self.g_idx = g_idx.clone() if g_idx is not None else self.g_idx + + scales = scales.t().contiguous() + zeros = 
zeros.t().contiguous() + scale_zeros = zeros * scales + self.scales = scales.clone().half() + if linear.bias is not None: + self.bias = linear.bias.clone().half() + + intweight = [] + for idx in range(self.infeatures): + intweight.append(torch.round((linear.weight.data[:,idx] + scale_zeros[self.g_idx[idx]]) / self.scales[self.g_idx[idx]]).to(torch.int)[:,None]) + intweight = torch.cat(intweight,dim=1) intweight = intweight.t().contiguous() intweight = intweight.numpy().astype(np.uint32) qweight = np.zeros( - (intweight.shape[0] // 32 * 3, intweight.shape[1]), dtype=np.uint32 + (intweight.shape[0] // 32 * self.bits, intweight.shape[1]), dtype=np.uint32 ) i = 0 row = 0 while row < qweight.shape[0]: - for j in range(i, i + 10): - qweight[row] |= intweight[j] << (3 * (j - i)) - i += 10 - qweight[row] |= intweight[i] << 30 - row += 1 - qweight[row] |= (intweight[i] >> 2) & 1 - i += 1 - for j in range(i, i + 10): - qweight[row] |= intweight[j] << (3 * (j - i) + 1) - i += 10 - qweight[row] |= intweight[i] << 31 - row += 1 - qweight[row] |= (intweight[i] >> 1) & 0x3 - i += 1 - for j in range(i, i + 10): - qweight[row] |= intweight[j] << (3 * (j - i) + 2) - i += 10 - row += 1 - + if self.bits in [2,4,8]: + for j in range(i, i + (32//self.bits)): + qweight[row] |= intweight[j] << (self.bits * (j - i)) + i += 32//self.bits + row += 1 + elif self.bits == 3: + for j in range(i, i + 10): + qweight[row] |= intweight[j] << (3 * (j - i)) + i += 10 + qweight[row] |= intweight[i] << 30 + row += 1 + qweight[row] |= (intweight[i] >> 2) & 1 + i += 1 + for j in range(i, i + 10): + qweight[row] |= intweight[j] << (3 * (j - i) + 1) + i += 10 + qweight[row] |= intweight[i] << 31 + row += 1 + qweight[row] |= (intweight[i] >> 1) & 0x3 + i += 1 + for j in range(i, i + 10): + qweight[row] |= intweight[j] << (3 * (j - i) + 2) + i += 10 + row += 1 + else: + raise NotImplementedError("Only 2,3,4,8 bits are supported.") + qweight = qweight.astype(np.int32) self.qweight = torch.from_numpy(qweight) - - def forward(self, x): - if x.shape[-1] == x.numel(): - outshape = list(x.shape) - y = self.bias.clone() - outshape[-1] = self.bias.numel() - dtype = x.dtype - if self.faster: - x = x.half() - quant_cuda.vecquant3matmul_faster(x, self.qweight, y, self.scales, self.zeros) + + zeros -= 1; + zeros = zeros.numpy().astype(np.uint32) + qzeros = np.zeros((zeros.shape[0], zeros.shape[1] // 32 * self.bits), dtype=np.uint32) + i = 0 + col = 0 + while col < qzeros.shape[1]: + if self.bits in [2,4,8]: + for j in range(i, i + (32//self.bits)): + qzeros[:, col] |= zeros[:, j] << (self.bits * (j - i)) + i += 32//self.bits + col += 1 + elif self.bits == 3: + for j in range(i, i + 10): + qzeros[:, col] |= zeros[:, j] << (3 * (j - i)) + i += 10 + qzeros[:, col] |= zeros[:, i] << 30 + col += 1 + qzeros[:, col] |= (zeros[:, i] >> 2) & 1 + i += 1 + for j in range(i, i + 10): + qzeros[:, col] |= zeros[:, j] << (3 * (j - i) + 1) + i += 10 + qzeros[:, col] |= zeros[:, i] << 31 + col += 1 + qzeros[:, col] |= (zeros[:, i] >> 1) & 0x3 + i += 1 + for j in range(i, i + 10): + qzeros[:, col] |= zeros[:, j] << (3 * (j - i) + 2) + i += 10 + col += 1 else: - x = x.float() - quant_cuda.vecquant3matmul(x, self.qweight, y, self.scales, self.zeros) - y = y.to(dtype) - return y.reshape(outshape) - raise ValueError('Only supports a single token currently.') - -def make_quant3(module, names, name='', faster=False): - if isinstance(module, Quant3Linear): - return - for attr in dir(module): - tmp = getattr(module, attr) - name1 = name + '.' 
+ attr if name != '' else attr - if name1 in names: - setattr( - module, attr, Quant3Linear(tmp.in_features, tmp.out_features, faster=faster) - ) - for name1, child in module.named_children(): - make_quant3(child, names, name + '.' + name1 if name != '' else name1, faster=faster) + raise NotImplementedError("Only 2,3,4,8 bits are supported.") + + qzeros = qzeros.astype(np.int32) + self.qzeros = torch.from_numpy(qzeros) + + def forward(self, x): + out_shape = x.shape[:-1] + (self.outfeatures, ) + x = x.reshape(-1,x.shape[-1]) + if self.is_cuda is True and (self.kernel_switch_threshold is False or x.shape[0] < self.kernel_switch_threshold): + out = torch.zeros((x.shape[0], self.outfeatures), device=x.device, dtype=torch.float32) + if self.bits == 2: + quant_cuda.vecquant2matmul(x.float(), self.qweight, out, self.scales.float(), self.qzeros, self.g_idx) + elif self.bits == 3: + quant_cuda.vecquant3matmul(x.float(), self.qweight, out, self.scales.float(), self.qzeros, self.g_idx) + elif self.bits == 4: + quant_cuda.vecquant4matmul(x.float(), self.qweight, out, self.scales.float(), self.qzeros, self.g_idx) + elif self.bits == 8: + quant_cuda.vecquant8matmul(x.float(), self.qweight, out, self.scales.float(), self.qzeros, self.g_idx) + out = out.half() + else: + if self.bits in [2,4,8]: + zeros = torch.bitwise_right_shift(torch.unsqueeze(self.qzeros, 2).expand(-1, -1, 32 // self.bits), self.wf.unsqueeze(0)).to(torch.int16 if self.bits == 8 else torch.int8) + torch.bitwise_and(zeros, (2 ** self.bits) - 1, out=zeros) + + zeros = zeros + 1 + zeros = zeros.reshape(self.scales.shape) + + weight = torch.bitwise_right_shift(torch.unsqueeze(self.qweight, 1).expand(-1, 32 // self.bits, -1), self.wf.unsqueeze(-1)).to(torch.int16 if self.bits == 8 else torch.int8) + torch.bitwise_and(weight,(2 ** self.bits) - 1, out=weight) + elif self.bits == 3: + zeros = self.qzeros.reshape(self.qzeros.shape[0], self.qzeros.shape[1]//3, 3, 1).expand(-1, -1, -1, 12) + zeros = (zeros >> self.wf.unsqueeze(0)) + zeros[:,:,0,10] = (zeros[:,:,0,10]&0x3) | ((zeros[:,:,1,0] << 2)&0x4) + zeros[:,:,1,11] = (zeros[:,:,1,11]&0x1) | ((zeros[:,:,2,0] << 1)&0x6) + zeros = zeros & 0x7 + zeros = torch.cat([zeros[:,:,0,:11], zeros[:,:,1,1:12], zeros[:,:,2,1:11]], dim=2) + + zeros = zeros + 1 + zeros = zeros.reshape(self.scales.shape) + + weight = self.qweight.reshape(self.qweight.shape[0]//3, 3, 1, self.qweight.shape[1]).expand(-1, -1, 12, -1) + weight = (weight >> self.wf.unsqueeze(-1))&0x7 + weight[:,0,10] = (weight[:,0,10]&0x3) | ((weight[:,1,0] << 2)&0x4) + weight[:,1,11] = (weight[:,1,11]&0x1) | ((weight[:,2,0] << 1)&0x6) + weight = weight & 0x7 + weight = torch.cat([weight[:,0,:11], weight[:,1,1:12], weight[:,2,1:11]], dim=1) + + weight = weight.reshape(weight.shape[0] * weight.shape[1], weight.shape[2]) + + weights = (self.scales[self.g_idx] * (weight - zeros[self.g_idx])) + out = torch.matmul(x.half(), weights) + out = out.reshape(out_shape) + out = out + self.bias if self.bias is not None else out + return out diff --git a/quantize/gptq/quant_cuda.cpp b/quantize/gptq/quant_cuda.cpp index 1bf08941..3200a9f2 100644 --- a/quantize/gptq/quant_cuda.cpp +++ b/quantize/gptq/quant_cuda.cpp @@ -2,33 +2,69 @@ #include #include -void vecquant3matmul_cuda( +void vecquant2matmul_cuda( torch::Tensor vec, torch::Tensor mat, torch::Tensor mul, - torch::Tensor scales, torch::Tensor zeros -); + torch::Tensor scales, torch::Tensor zeros, + torch::Tensor g_idx +); + +void vecquant2matmul( + torch::Tensor vec, torch::Tensor mat, torch::Tensor mul, + 
torch::Tensor scales, torch::Tensor zeros, + torch::Tensor g_idx +) { + const at::cuda::OptionalCUDAGuard device_guard(device_of(vec)); + vecquant2matmul_cuda(vec, mat, mul, scales, zeros, g_idx); +} -void vecquant3matmul_faster_cuda( +void vecquant3matmul_cuda( torch::Tensor vec, torch::Tensor mat, torch::Tensor mul, - torch::Tensor scales, torch::Tensor zeros + torch::Tensor scales, torch::Tensor zeros, + torch::Tensor g_idx ); void vecquant3matmul( torch::Tensor vec, torch::Tensor mat, torch::Tensor mul, - torch::Tensor scales, torch::Tensor zeros + torch::Tensor scales, torch::Tensor zeros, + torch::Tensor g_idx ) { const at::cuda::OptionalCUDAGuard device_guard(device_of(vec)); - vecquant3matmul_cuda(vec, mat, mul, scales, zeros); + vecquant3matmul_cuda(vec, mat, mul, scales, zeros, g_idx); } -void vecquant3matmul_faster( +void vecquant4matmul_cuda( + torch::Tensor vec, torch::Tensor mat, torch::Tensor mul, + torch::Tensor scales, torch::Tensor zeros, + torch::Tensor g_idx +); + +void vecquant4matmul( + torch::Tensor vec, torch::Tensor mat, torch::Tensor mul, + torch::Tensor scales, torch::Tensor zeros, + torch::Tensor g_idx +) { + const at::cuda::OptionalCUDAGuard device_guard(device_of(vec)); + vecquant4matmul_cuda(vec, mat, mul, scales, zeros, g_idx); +} + +void vecquant8matmul_cuda( + torch::Tensor vec, torch::Tensor mat, torch::Tensor mul, + torch::Tensor scales, torch::Tensor zeros, + torch::Tensor g_idx +); + +void vecquant8matmul( torch::Tensor vec, torch::Tensor mat, torch::Tensor mul, - torch::Tensor scales, torch::Tensor zeros + torch::Tensor scales, torch::Tensor zeros, + torch::Tensor g_idx ) { const at::cuda::OptionalCUDAGuard device_guard(device_of(vec)); - vecquant3matmul_faster_cuda(vec, mat, mul, scales, zeros); + vecquant8matmul_cuda(vec, mat, mul, scales, zeros, g_idx); } PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("vecquant2matmul", &vecquant2matmul, "Vector 2-bit Quantized Matrix Multiplication (CUDA)"); m.def("vecquant3matmul", &vecquant3matmul, "Vector 3-bit Quantized Matrix Multiplication (CUDA)"); - m.def("vecquant3matmul_faster", &vecquant3matmul_faster, "Vector 3-bit Quantized Matrix Multiplication (CUDA), faster version"); + m.def("vecquant4matmul", &vecquant4matmul, "Vector 4-bit Quantized Matrix Multiplication (CUDA)"); + m.def("vecquant8matmul", &vecquant8matmul, "Vector 8-bit Quantized Matrix Multiplication (CUDA)"); } diff --git a/quantize/gptq/quant_cuda_kernel.cu b/quantize/gptq/quant_cuda_kernel.cu index 101167f0..60c1dc08 100644 --- a/quantize/gptq/quant_cuda_kernel.cu +++ b/quantize/gptq/quant_cuda_kernel.cu @@ -4,42 +4,209 @@ #include #include +// atomicAdd for double-precision floating-point numbers on hardware with +// compute capability < 6.0 from: +// https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#atomic-functions +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 600 +__device__ double atomicAdd( + double* address, + double val +) { + unsigned long long int* address_as_ull = (unsigned long long int*)address; + unsigned long long int old = *address_as_ull, assumed; + + do { + assumed = old; + old = atomicCAS( + address_as_ull, + assumed, + __double_as_longlong(val + __longlong_as_double(assumed)) + ); + + // Note: uses integer comparison to avoid hang in case of NaN (since NaN != NaN) + } while (assumed != old); + + return __longlong_as_double(old); +} +#endif + +template +__global__ void VecQuant2MatMulKernel( + const scalar_t* __restrict__ vec, + const int* __restrict__ mat, + scalar_t* __restrict__ mul, + const 
scalar_t* __restrict__ scales, + const int* __restrict__ zeros, + const int* __restrict__ g_idx, + int batch, + int vec_height, + int height, + int width, + int zero_width +); + template __global__ void VecQuant3MatMulKernel( const scalar_t* __restrict__ vec, const int* __restrict__ mat, scalar_t* __restrict__ mul, const scalar_t* __restrict__ scales, - const scalar_t* __restrict__ zeros, + const int* __restrict__ zeros, + const int* __restrict__ g_idx, + int batch, + int vec_height, int height, - int width + int width, + int zero_width ); -__global__ void VecQuant3MatMulKernelFaster( - const half2* __restrict__ vec, - const int* __restrict__ mat, - float* __restrict__ mul, - const float* __restrict__ scales, - const float* __restrict__ zeros, +template +__global__ void VecQuant4MatMulKernel( + const scalar_t* __restrict__ vec, + const int* __restrict__ mat, + scalar_t* __restrict__ mul, + const scalar_t* __restrict__ scales, + const int* __restrict__ zeros, + const int* __restrict__ g_idx, + int batch, + int vec_height, int height, - int width + int width, + int zero_width +); + +template +__global__ void VecQuant8MatMulKernel( + const scalar_t* __restrict__ vec, + const int* __restrict__ mat, + scalar_t* __restrict__ mul, + const scalar_t* __restrict__ scales, + const int* __restrict__ zeros, + const int* __restrict__ g_idx, + int batch, + int vec_height, + int height, + int width, + int zero_width ); const int BLOCKWIDTH = 256; -const int BLOCKHEIGHT = 24; +const int BLOCKHEIGHT2 = 16; +const int BLOCKHEIGHT3 = 24; +const int BLOCKHEIGHT4 = 32; +const int BLOCKHEIGHT8 = 64; + +__device__ inline unsigned int as_unsigned(int i) { + return *reinterpret_cast(&i); +} + +__device__ inline int as_int(int i) { + return *reinterpret_cast(&i); +} + + +void vecquant2matmul_cuda( + torch::Tensor vec, + torch::Tensor mat, + torch::Tensor mul, + torch::Tensor scales, + torch::Tensor zeros, + torch::Tensor g_idx +) { + int batch = vec.size(0); + int vec_height = vec.size(1); + int height = mat.size(0); + int width = mat.size(1); + int zero_width = zeros.size(1); + + dim3 blocks( + (height + BLOCKHEIGHT2 - 1) / BLOCKHEIGHT2, + (width + BLOCKWIDTH - 1) / BLOCKWIDTH + ); + dim3 threads(BLOCKWIDTH); + + AT_DISPATCH_FLOATING_TYPES( + vec.type(), "vecquant2matmul_cuda", ([&] { + VecQuant2MatMulKernel<<>>( + vec.data(), mat.data(), mul.data(), + scales.data(), zeros.data(), g_idx.data(), + batch, vec_height, height, width, zero_width + ); + }) + ); +} + +template +__global__ void VecQuant2MatMulKernel( + const scalar_t* __restrict__ vec, + const int* __restrict__ mat, + scalar_t* __restrict__ mul, + const scalar_t* __restrict__ scales, + const int* __restrict__ zeros, + const int* __restrict__ g_idx, + int batch, + int vec_height, + int height, + int width, + int zero_width +) { + int h = BLOCKHEIGHT2 * blockIdx.x; + int w = BLOCKWIDTH * blockIdx.y + threadIdx.x; + + __shared__ scalar_t blockvec[BLOCKWIDTH]; + int i = width * h + w; + int g_h = h * 16; + int k; + unsigned int g; + scalar_t w_tmp; + + int z_w = w / 16; + int z_mod = (w % 16) * 2; + + float weight[BLOCKWIDTH]; + + for (k = 0; k < BLOCKWIDTH; ++k){ + int k_w = (k / 16); + int k_bit = (k % 16) * 2; + + g = as_int(g_idx[g_h + k]); + scalar_t scale = scales[g * width + w]; + scalar_t zero = scalar_t((as_unsigned(zeros[g * zero_width + z_w]) >> z_mod & 0x3) + 1); + + w_tmp = ((as_unsigned(mat[i + (k_w * width)]) >> k_bit) & 0x3); + + weight[k] = scale * (w_tmp - zero); + } + + scalar_t res; + for (int b = 0; b < batch; ++b){ + res = 0; + + 
blockvec[threadIdx.x] = vec[b * vec_height + blockIdx.x * BLOCKWIDTH + threadIdx.x]; + __syncthreads(); + for (k = 0; k < BLOCKWIDTH; ++k){ + res += weight[k] * blockvec[k]; + } + atomicAdd(&mul[b * width + w], res); + } +} void vecquant3matmul_cuda( torch::Tensor vec, torch::Tensor mat, torch::Tensor mul, torch::Tensor scales, - torch::Tensor zeros + torch::Tensor zeros, + torch::Tensor g_idx ) { + int batch = vec.size(0); + int vec_height = vec.size(1); int height = mat.size(0); int width = mat.size(1); + int zero_width = zeros.size(1); dim3 blocks( - (height + BLOCKHEIGHT - 1) / BLOCKHEIGHT, + (height + BLOCKHEIGHT3 - 1) / BLOCKHEIGHT3, (width + BLOCKWIDTH - 1) / BLOCKWIDTH ); dim3 threads(BLOCKWIDTH); @@ -48,197 +215,295 @@ void vecquant3matmul_cuda( vec.type(), "vecquant3matmul_cuda", ([&] { VecQuant3MatMulKernel<<>>( vec.data(), mat.data(), mul.data(), - scales.data(), zeros.data(), - height, width + scales.data(), zeros.data(), g_idx.data(), + batch, vec_height, height, width, zero_width ); }) ); } -void vecquant3matmul_faster_cuda( +template +__global__ void VecQuant3MatMulKernel( + const scalar_t* __restrict__ vec, + const int* __restrict__ mat, + scalar_t* __restrict__ mul, + const scalar_t* __restrict__ scales, + const int* __restrict__ zeros, + const int* __restrict__ g_idx, + int batch, + int vec_height, + int height, + int width, + int zero_width +) { + int h = BLOCKHEIGHT3 * blockIdx.x; + int w = BLOCKWIDTH * blockIdx.y + threadIdx.x; + + __shared__ scalar_t blockvec[BLOCKWIDTH]; + int i = width * h + w; + int g_h = (h / 3) * 32; + int k; + unsigned int g; + scalar_t w_tmp; + + int z_w = (w / 32) * 3; + int z_mod = w % 32; + int z_bit; + unsigned int z_tmp; + if (z_mod != 10){ + if (z_mod != 21){ + z_bit = z_mod; + if (z_bit > 21){ + z_bit -= 22; + z_bit *= 3; + z_bit += 2; + z_w += 2; + } else if (z_bit > 10){ + z_bit -= 11; + z_bit *= 3; + z_bit += 1; + z_w += 1; + } else { + z_bit *= 3; + } + } else { + z_w += 1; + } + } + + float weight[BLOCKWIDTH]; + + for (k = 0; k < BLOCKWIDTH; ++k){ + int k_w = (k / 32) * 3; + int k_mod = k % 32; + int k_bit; + + if (k_mod != 10){ + if (k_mod != 21){ + k_bit = k_mod; + if (k_bit > 21){ + k_bit -= 22; + k_bit *= 3; + k_bit += 2; + k_w += 2; + } else if (k_bit > 10){ + k_bit -= 11; + k_bit *= 3; + k_bit += 1; + k_w += 1; + } else { + k_bit *= 3; + } + } else { + k_w += 1; + } + } + + g = as_int(g_idx[g_h + k]); + scalar_t scale = scales[g * width + w]; + scalar_t zero; + if (z_mod == 10) { + z_tmp = (as_unsigned(zeros[g * zero_width + z_w]) >> 30) | ((as_unsigned(zeros[g * zero_width + (z_w + 1)]) << 2) & 0x4); + zero = scalar_t((z_tmp) + 1); + } else if (z_mod == 21){ + z_tmp = (as_unsigned(zeros[g * zero_width + z_w]) >> 31) | ((as_unsigned(zeros[g * zero_width + (z_w + 1)]) << 1) & 0x6); + zero = scalar_t((z_tmp) + 1); + } else { + zero = scalar_t(((as_unsigned(zeros[g * zero_width + z_w]) >> z_bit) & 0x7) + 1); + } + + if (k_mod == 10) { + w_tmp = (as_unsigned(mat[i + (k_w * width)]) >> 30) | ((as_unsigned(mat[i + ((k_w + 1)* width)]) << 2) & 0x4); + } else if (k_mod == 21){ + w_tmp = (as_unsigned(mat[i + (k_w * width)]) >> 31) | ((as_unsigned(mat[i + ((k_w + 1)* width)]) << 1) & 0x6); + } else { + w_tmp = ((as_unsigned(mat[i + (k_w * width)]) >> k_bit) & 0x7); + } + weight[k] = scale * (w_tmp - zero); + } + + scalar_t res; + for (int b = 0; b < batch; ++b){ + res = 0; + + blockvec[threadIdx.x] = vec[b * vec_height + blockIdx.x * BLOCKWIDTH + threadIdx.x]; + __syncthreads(); + for (k = 0; k < BLOCKWIDTH; ++k){ + res += weight[k] * 
blockvec[k]; + } + atomicAdd(&mul[b * width + w], res); + } +} + +void vecquant4matmul_cuda( torch::Tensor vec, torch::Tensor mat, torch::Tensor mul, torch::Tensor scales, - torch::Tensor zeros + torch::Tensor zeros, + torch::Tensor g_idx ) { + int batch = vec.size(0); + int vec_height = vec.size(1); int height = mat.size(0); int width = mat.size(1); + int zero_width = zeros.size(1); dim3 blocks( - (height + BLOCKHEIGHT - 1) / BLOCKHEIGHT, + (height + BLOCKHEIGHT4 - 1) / BLOCKHEIGHT4, (width + BLOCKWIDTH - 1) / BLOCKWIDTH ); dim3 threads(BLOCKWIDTH); - VecQuant3MatMulKernelFaster<<>>( - (half2*) vec.data_ptr(), - mat.data_ptr(), - mul.data_ptr(), - scales.data_ptr(), - zeros.data_ptr(), - height, width + AT_DISPATCH_FLOATING_TYPES( + vec.type(), "vecquant4matmul_cuda", ([&] { + VecQuant4MatMulKernel<<>>( + vec.data(), mat.data(), mul.data(), + scales.data(), zeros.data(), g_idx.data(), + batch, vec_height, height, width, zero_width + ); + }) ); } -__device__ inline unsigned int as_unsigned(int i) { - return *reinterpret_cast(&i); -} - template -__global__ void VecQuant3MatMulKernel( +__global__ void VecQuant4MatMulKernel( const scalar_t* __restrict__ vec, const int* __restrict__ mat, scalar_t* __restrict__ mul, const scalar_t* __restrict__ scales, - const scalar_t* __restrict__ zeros, + const int* __restrict__ zeros, + const int* __restrict__ g_idx, + int batch, + int vec_height, int height, - int width + int width, + int zero_width ) { - int row = BLOCKHEIGHT * blockIdx.x; - int col = BLOCKWIDTH * blockIdx.y + threadIdx.x; - + int h = BLOCKHEIGHT4 * blockIdx.x; + int w = BLOCKWIDTH * blockIdx.y + threadIdx.x; + __shared__ scalar_t blockvec[BLOCKWIDTH]; - blockvec[threadIdx.x] = vec[(row / BLOCKHEIGHT) * BLOCKWIDTH + threadIdx.x]; - __syncthreads(); - - scalar_t scale = scales[col]; - scalar_t zero = zeros[col]; - - scalar_t res = 0; - int i = width * row + col; - int k = 0; - - unsigned int tmp1; - unsigned int tmp2; - unsigned int tmp; - - while (k < BLOCKWIDTH) { - tmp1 = as_unsigned(mat[i]); - res += (scale * scalar_t((tmp1 >> 0) & 0x7) - zero) * blockvec[k + 0]; - res += (scale * scalar_t((tmp1 >> 3) & 0x7) - zero) * blockvec[k + 1]; - res += (scale * scalar_t((tmp1 >> 6) & 0x7) - zero) * blockvec[k + 2]; - res += (scale * scalar_t((tmp1 >> 9) & 0x7) - zero) * blockvec[k + 3]; - res += (scale * scalar_t((tmp1 >> 12) & 0x7) - zero) * blockvec[k + 4]; - res += (scale * scalar_t((tmp1 >> 15) & 0x7) - zero) * blockvec[k + 5]; - res += (scale * scalar_t((tmp1 >> 18) & 0x7) - zero) * blockvec[k + 6]; - res += (scale * scalar_t((tmp1 >> 21) & 0x7) - zero) * blockvec[k + 7]; - res += (scale * scalar_t((tmp1 >> 24) & 0x7) - zero) * blockvec[k + 8]; - res += (scale * scalar_t((tmp1 >> 27) & 0x7) - zero) * blockvec[k + 9]; - i += width; - tmp2 = as_unsigned(mat[i]); - tmp = (tmp1 >> 30) | ((tmp2 << 2) & 0x4); - tmp2 >>= 1; - res += (scale * scalar_t(tmp) - zero) * blockvec[k + 10]; - k += 11; - res += (scale * scalar_t((tmp2 >> 0) & 0x7) - zero) * blockvec[k + 0]; - res += (scale * scalar_t((tmp2 >> 3) & 0x7) - zero) * blockvec[k + 1]; - res += (scale * scalar_t((tmp2 >> 6) & 0x7) - zero) * blockvec[k + 2]; - res += (scale * scalar_t((tmp2 >> 9) & 0x7) - zero) * blockvec[k + 3]; - res += (scale * scalar_t((tmp2 >> 12) & 0x7) - zero) * blockvec[k + 4]; - res += (scale * scalar_t((tmp2 >> 15) & 0x7) - zero) * blockvec[k + 5]; - res += (scale * scalar_t((tmp2 >> 18) & 0x7) - zero) * blockvec[k + 6]; - res += (scale * scalar_t((tmp2 >> 21) & 0x7) - zero) * blockvec[k + 7]; - res += (scale * 
scalar_t((tmp2 >> 24) & 0x7) - zero) * blockvec[k + 8]; - res += (scale * scalar_t((tmp2 >> 27) & 0x7) - zero) * blockvec[k + 9]; - i += width; - tmp1 = as_unsigned(mat[i]); - tmp = (tmp2 >> 30) | ((tmp1 << 1) & 0x6); - tmp1 >>= 2; - res += (scale * scalar_t(tmp) - zero) * blockvec[k + 10]; - k += 11; - res += (scale * scalar_t((tmp1 >> 0) & 0x7) - zero) * blockvec[k + 0]; - res += (scale * scalar_t((tmp1 >> 3) & 0x7) - zero) * blockvec[k + 1]; - res += (scale * scalar_t((tmp1 >> 6) & 0x7) - zero) * blockvec[k + 2]; - res += (scale * scalar_t((tmp1 >> 9) & 0x7) - zero) * blockvec[k + 3]; - res += (scale * scalar_t((tmp1 >> 12) & 0x7) - zero) * blockvec[k + 4]; - res += (scale * scalar_t((tmp1 >> 15) & 0x7) - zero) * blockvec[k + 5]; - res += (scale * scalar_t((tmp1 >> 18) & 0x7) - zero) * blockvec[k + 6]; - res += (scale * scalar_t((tmp1 >> 21) & 0x7) - zero) * blockvec[k + 7]; - res += (scale * scalar_t((tmp1 >> 24) & 0x7) - zero) * blockvec[k + 8]; - res += (scale * scalar_t((tmp1 >> 27) & 0x7) - zero) * blockvec[k + 9]; - i += width; - k += 10; + int i = width * h + w; + int g_h = h * 8; + int k; + unsigned int g; + scalar_t w_tmp; + + + int z_w = w / 8; + int z_mod = (w % 8) * 4; + + float weight[BLOCKWIDTH]; + + for (k = 0; k < BLOCKWIDTH; ++k){ + int k_w = (k / 8); + int k_bit = (k % 8) * 4; + + g = as_int(g_idx[g_h + k]); + scalar_t scale = scales[g * width + w]; + scalar_t zero = scalar_t(((as_unsigned(zeros[g * zero_width + z_w]) >> z_mod) & 0xF) + 1); + + w_tmp = ((as_unsigned(mat[i + (k_w * width)]) >> k_bit) & 0xF); + + weight[k] = scale * (w_tmp - zero); } - atomicAdd(&mul[col], res); + scalar_t res; + for (int b = 0; b < batch; ++b){ + res = 0; + + blockvec[threadIdx.x] = vec[b * vec_height + blockIdx.x * BLOCKWIDTH + threadIdx.x]; + __syncthreads(); + for (k = 0; k < BLOCKWIDTH; ++k){ + res += weight[k] * blockvec[k]; + } + atomicAdd(&mul[b * width + w], res); + } } -__global__ void VecQuant3MatMulKernelFaster( - const half2* __restrict__ vec, - const int* __restrict__ mat, - float* __restrict__ mul, - const float* __restrict__ scales, - const float* __restrict__ zeros, - int height, - int width +void vecquant8matmul_cuda( + torch::Tensor vec, + torch::Tensor mat, + torch::Tensor mul, + torch::Tensor scales, + torch::Tensor zeros, + torch::Tensor g_idx ) { - const int blockwidth2 = BLOCKWIDTH / 2; + int batch = vec.size(0); + int vec_height = vec.size(1); + int height = mat.size(0); + int width = mat.size(1); + int zero_width = zeros.size(1); - int row = BLOCKHEIGHT * blockIdx.x; - int col = BLOCKWIDTH * blockIdx.y + threadIdx.x; + dim3 blocks( + (height + BLOCKHEIGHT8 - 1) / BLOCKHEIGHT8, + (width + BLOCKWIDTH - 1) / BLOCKWIDTH + ); + dim3 threads(BLOCKWIDTH); - __shared__ half2 blockvec[blockwidth2]; - if (threadIdx.x < blockwidth2) - blockvec[threadIdx.x] = vec[(row / BLOCKHEIGHT) * blockwidth2 + threadIdx.x]; + AT_DISPATCH_FLOATING_TYPES( + vec.type(), "vecquant8matmul_cuda", ([&] { + VecQuant8MatMulKernel<<>>( + vec.data(), mat.data(), mul.data(), + scales.data(), zeros.data(), g_idx.data(), + batch, vec_height, height, width, zero_width + ); + }) + ); +} - __shared__ half2 deq2[64][32]; - int val = threadIdx.x / 32; - int off = threadIdx.x % 32; - for (; val < 64; val += BLOCKWIDTH / 32) { - deq2[val][off] = __halves2half2( - __int2half_rn(val & 0x7), __int2half_rn(val >> 3) - ); +template +__global__ void VecQuant8MatMulKernel( + const scalar_t* __restrict__ vec, + const int* __restrict__ mat, + scalar_t* __restrict__ mul, + const scalar_t* __restrict__ scales, + 
const int* __restrict__ zeros, + const int* __restrict__ g_idx, + int batch, + int vec_height, + int height, + int width, + int zero_width +) { + int h = BLOCKHEIGHT8 * blockIdx.x; + int w = BLOCKWIDTH * blockIdx.y + threadIdx.x; + + __shared__ scalar_t blockvec[BLOCKWIDTH]; + int i = width * h + w; + int g_h = h * 4; + int k; + unsigned int g; + scalar_t w_tmp; + + int z_w = w / 4; + int z_mod = (w % 4) * 8; + + float weight[BLOCKWIDTH]; + + for (k = 0; k < BLOCKWIDTH; ++k){ + int k_w = (k / 4); + int k_bit = (k % 4) * 8; + + g = as_int(g_idx[g_h + k]); + scalar_t scale = scales[g * width + w]; + scalar_t zero = scalar_t(((as_unsigned(zeros[g * zero_width + z_w]) >> z_mod) & 0xFF) + 1); + + w_tmp = ((as_unsigned(mat[i + (k_w * width)]) >> k_bit) & 0xFF); + + weight[k] = scale * (w_tmp - zero); } - half2 scale = __float2half2_rn(scales[col]); - half2 zero = __float2half2_rn(-zeros[col]); - - int i = width * row + col; - int k = 0; - - float res = 0; - half2 res2; - - unsigned int tmp1; - unsigned int tmp2; - unsigned int tmp; - - __syncthreads(); - - while (k < blockwidth2) { - res2 = {}; - tmp1 = as_unsigned(mat[i]); - res2 = __hfma2(__hfma2(deq2[(tmp1 >> 0) & 0x3f][off], scale, zero), blockvec[k + 0], res2); - res2 = __hfma2(__hfma2(deq2[(tmp1 >> 6) & 0x3f][off], scale, zero), blockvec[k + 1], res2); - res2 = __hfma2(__hfma2(deq2[(tmp1 >> 12) & 0x3f][off], scale, zero), blockvec[k + 2], res2); - res2 = __hfma2(__hfma2(deq2[(tmp1 >> 18) & 0x3f][off], scale, zero), blockvec[k + 3], res2); - res2 = __hfma2(__hfma2(deq2[(tmp1 >> 24) & 0x3f][off], scale, zero), blockvec[k + 4], res2); - i += width; - tmp2 = as_unsigned(mat[i]); - tmp = (tmp1 >> 30) | ((tmp2 << 2) & 0x3c); - res2 = __hfma2(__hfma2(deq2[tmp][off], scale, zero), blockvec[k + 5], res2); - tmp2 >>= 4; - k += 6; - res2 = __hfma2(__hfma2(deq2[(tmp2 >> 0) & 0x3f][off], scale, zero), blockvec[k + 0], res2); - res2 = __hfma2(__hfma2(deq2[(tmp2 >> 6) & 0x3f][off], scale, zero), blockvec[k + 1], res2); - res2 = __hfma2(__hfma2(deq2[(tmp2 >> 12) & 0x3f][off], scale, zero), blockvec[k + 2], res2); - res2 = __hfma2(__hfma2(deq2[(tmp2 >> 18) & 0x3f][off], scale, zero), blockvec[k + 3], res2); - i += width; - tmp1 = as_unsigned(mat[i]); - tmp = (tmp2 >> 24) | ((tmp1 << 4) & 0x30); - res2 = __hfma2(__hfma2(deq2[tmp][off], scale, zero), blockvec[k + 4], res2); - tmp1 >>= 2; - k += 5; - res2 = __hfma2(__hfma2(deq2[(tmp1 >> 0) & 0x3f][off], scale, zero), blockvec[k + 0], res2); - res2 = __hfma2(__hfma2(deq2[(tmp1 >> 6) & 0x3f][off], scale, zero), blockvec[k + 1], res2); - res2 = __hfma2(__hfma2(deq2[(tmp1 >> 12) & 0x3f][off], scale, zero), blockvec[k + 2], res2); - res2 = __hfma2(__hfma2(deq2[(tmp1 >> 18) & 0x3f][off], scale, zero), blockvec[k + 3], res2); - res2 = __hfma2(__hfma2(deq2[(tmp1 >> 24) & 0x3f][off], scale, zero), blockvec[k + 4], res2); - i += width; - k += 5; - res += __half2float(res2.x) + __half2float(res2.y); + scalar_t res; + for (int b = 0; b < batch; ++b){ + res = 0; + + blockvec[threadIdx.x] = vec[b * vec_height + blockIdx.x * BLOCKWIDTH + threadIdx.x]; + __syncthreads(); + for (k = 0; k < BLOCKWIDTH; ++k){ + res += weight[k] * blockvec[k]; + } + atomicAdd(&mul[b * width + w], res); } - - atomicAdd(&mul[col], res); } diff --git a/quantize/gptq/sanity_check_main.py b/quantize/gptq/sanity_check_main.py index f95a522a..475332f7 100644 --- a/quantize/gptq/sanity_check_main.py +++ b/quantize/gptq/sanity_check_main.py @@ -10,6 +10,9 @@ from modelutils import * from quant import * +WBITS = 8 +GROUPSIZE = -1 + def 
quantize_gptq(model, train_loader, device): quantizers = {} layers = list(model.modules())[1:] @@ -32,8 +35,7 @@ def quantize_gptq(model, train_loader, device): for name in subset: gptq[name] = GPTQ(subset[name], name) gptq[name].quantizer = Quantizer() - # TODO: 8 bits quantize so that we can compare with pytorch post-training quantization - gptq[name].quantizer.configure(bits=8, perchannel=True, sym=False, mse=False, trits=False) + gptq[name].quantizer.configure(bits=WBITS, perchannel=True, sym=False, mse=False, trits=False) def add_batch(name): def tmp(_, inp, out): @@ -56,11 +58,10 @@ def tmp(_, inp, out): for name in subset: print(i, name) print('Quantizing ...') - gptq[name].fasterquant(percdamp=0.1, groupsize=-1, actorder=False) - quantizers['model.decoder.layers.%d.%s' % (i, name)] = gptq[name].quantizer + scale,zero,g_idx = gptq[name].fasterquant(percdamp=0.1, groupsize=GROUPSIZE, actorder=False) + quantizers[f"linear{layer_id + 1}"] = (gptq[name].quantizer.cpu(), scale.cpu(), zero.cpu(), g_idx.cpu()) gptq[name].free() - for i in range(nsamples): if not is_last_layer(layer_id): outs[i] = layer(inps[i]) @@ -73,13 +74,38 @@ def tmp(_, inp, out): if not is_last_layer(layer_id): inps, outs = outs, inps + + return quantizers + +# TODO: perform packing on GPU +def model_pack(model, quantizers, wbits, groupsize): + layers = find_layers(model) + layers = {n: layers[n] for n in quantizers} + make_quant(model, quantizers, wbits, groupsize) + qlayers = find_layers(model, [QuantLinear]) + print('Packing ...') + for name in qlayers: + print(name) + quantizers[name],scale,zero,g_idx = quantizers[name] + qlayers[name].pack(layers[name], scale, zero, g_idx) + print('Done.') + return model + +def load_quant(model, checkpoint, wbits, groupsize): + print('Loading model ...') + model = model.eval() + layers = find_layers(model) + make_quant(model, layers, wbits, groupsize) + model.load_state_dict(torch.load(checkpoint)) + print('Done.') + return model if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("--train", action="store_true") parser.add_argument("--gptq", action="store_true") - parser.add_argument("--pyquant", action="store_true") + parser.add_argument("--eval_gptq", action="store_true") args = parser.parse_args() @@ -100,9 +126,22 @@ def tmp(_, inp, out): device = torch.device("cpu") model.load_state_dict(torch.load("./model.pt", map_location="cpu")) model = model.to(device) - quantize_gptq(model, train_loader, device) - elif args.pyquant: - pass + quantizers = quantize_gptq(model, train_loader, device) + model_pack(model, quantizers, WBITS, GROUPSIZE) + torch.save(model.state_dict(), "model_quantized.pt") + print("Done GPTQ") + elif args.eval_gptq: + device = torch.device("cuda:0") + model = load_quant(model, "model_quantized.pt", WBITS, GROUPSIZE) + model = model.to(device) + + start = time.time() + val_loss, val_acc = evaluate(device, model, criterion, train_loader) + end = time.time() + + print(f"wbits = {WBITS} using {device}") + print(f"val_loss: {val_loss:.3f} \t val_acc: {val_acc:.3f}") + print(f"Latency: {end - start}") else: device = torch.device("cpu") model.load_state_dict(torch.load("./model.pt", map_location="cpu")) diff --git a/quantize/gptq/sanity_check_utils.py b/quantize/gptq/sanity_check_utils.py index 08bd120d..0de025ce 100644 --- a/quantize/gptq/sanity_check_utils.py +++ b/quantize/gptq/sanity_check_utils.py @@ -7,6 +7,7 @@ import torch from torch.utils.data import DataLoader, random_split from torchvision import datasets, transforms +from 
collections import OrderedDict def seed_everything(seed: int): random.seed(seed) @@ -22,13 +23,16 @@ class SimpleNet(nn.Module): def __init__(self, num_classes=10, init_weights=True): super(SimpleNet, self).__init__() self.N = 32 * 32 - self.linear1 = nn.Linear(in_features=32 * 32, out_features=self.N) + + self.linear1 = nn.Linear(in_features=self.N, out_features=self.N) self.linear2 = nn.Linear(in_features=self.N, out_features=self.N) self.linear3 = nn.Linear(in_features=self.N, out_features=self.N) self.linear4 = nn.Linear(in_features=self.N, out_features=num_classes) def forward(self, x): # Assume input is already flattened + if len(x.shape) == 4: + x = x.view(x.size(0), -1) x = F.relu(self.linear1(x)) x = F.relu(self.linear2(x)) x = F.relu(self.linear3(x)) diff --git a/requirements-quantize.txt b/requirements-quantize.txt index 006cd3b8..352f3a9a 100644 --- a/requirements-quantize.txt +++ b/requirements-quantize.txt @@ -7,4 +7,6 @@ ninja tokenizers>=0.13.2 prompt_toolkit # Debug -pdbpp \ No newline at end of file +pdbpp +line_profiler +torchvision==0.14.1+cu117 \ No newline at end of file From 816def4edcce9c154b290465c1f2053da80ef1e8 Mon Sep 17 00:00:00 2001 From: Ferdinand Mom Date: Thu, 27 Apr 2023 21:52:32 +0000 Subject: [PATCH 11/20] breaking(sanity-check): enhance with dummy model --- quantize/gptq/sanity_check_main.py | 28 +++++++++- quantize/gptq/sanity_check_utils.py | 81 ++++++++++++++++++++++++++--- 2 files changed, 101 insertions(+), 8 deletions(-) diff --git a/quantize/gptq/sanity_check_main.py b/quantize/gptq/sanity_check_main.py index 475332f7..636e243f 100644 --- a/quantize/gptq/sanity_check_main.py +++ b/quantize/gptq/sanity_check_main.py @@ -16,6 +16,7 @@ def quantize_gptq(model, train_loader, device): quantizers = {} layers = list(model.modules())[1:] + layers = [l for l in layers if isinstance(l, nn.Linear)] is_last_layer = lambda x: x == (len(layers) - 1) nsamples = len(train_loader.dataset) @@ -25,6 +26,7 @@ def quantize_gptq(model, train_loader, device): for i, (inp, _) in enumerate(train_loader): inps[i*batch_size:(i+1)*batch_size] = inp.view(-1, 32*32) outs = torch.zeros_like(inps) + for layer_id in range(len(layers)): layer = layers[layer_id] @@ -58,7 +60,7 @@ def tmp(_, inp, out): for name in subset: print(i, name) print('Quantizing ...') - scale,zero,g_idx = gptq[name].fasterquant(percdamp=0.1, groupsize=GROUPSIZE, actorder=False) + scale,zero,g_idx = gptq[name].fasterquant(percdamp=0.01, groupsize=GROUPSIZE, actorder=False) quantizers[f"linear{layer_id + 1}"] = (gptq[name].quantizer.cpu(), scale.cpu(), zero.cpu(), g_idx.cpu()) gptq[name].free() @@ -106,6 +108,7 @@ def load_quant(model, checkpoint, wbits, groupsize): parser.add_argument("--train", action="store_true") parser.add_argument("--gptq", action="store_true") parser.add_argument("--eval_gptq", action="store_true") + parser.add_argument("--pyquant", action="store_true") args = parser.parse_args() @@ -123,6 +126,7 @@ def load_quant(model, checkpoint, wbits, groupsize): train(num_epochs, model, optimizer, criterion, train_loader, device) torch.save(model.state_dict(), "model.pt") elif args.gptq: + #FIXME: WHY ON EARTH QUANTIZATION ERROR IS SO DAMN HIGH FOR LAYER 3 AND 4 ?! 
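For reference, the hunk above lowers `percdamp` from 0.1 to 0.01. A minimal illustrative sketch (not taken from the patch) of what that parameter controls inside `fasterquant`: a fraction of the mean Hessian diagonal is added back to the diagonal so the Cholesky inverse stays well conditioned. Variable names mirror the fasterquant implementation shown later in this series.

import torch

def damp_hessian(H: torch.Tensor, percdamp: float = 0.01) -> torch.Tensor:
    # H is the (columns x columns) Hessian accumulated by add_batch()
    damp = percdamp * torch.mean(torch.diag(H))
    diag = torch.arange(H.shape[0], device=H.device)
    H[diag, diag] += damp   # larger percdamp = stronger regularization of the inverse
    return H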
device = torch.device("cpu") model.load_state_dict(torch.load("./model.pt", map_location="cpu")) model = model.to(device) @@ -142,6 +146,28 @@ def load_quant(model, checkpoint, wbits, groupsize): print(f"wbits = {WBITS} using {device}") print(f"val_loss: {val_loss:.3f} \t val_acc: {val_acc:.3f}") print(f"Latency: {end - start}") + elif args.pyquant: + # Baseline post-training quantization from Pytorch + device = torch.device("cpu") + model.load_state_dict(torch.load("./model.pt")) + model.eval() + model.qconfig = torch.ao.quantization.get_default_qconfig('x86') + model_prepared = torch.ao.quantization.prepare(model) + + for inputs, labels in train_loader: + inputs, labels = inputs.to(device), labels.to(device) + outputs = model_prepared.forward(inputs, is_pyquant=True) + + model_quant = torch.ao.quantization.convert(model_prepared) + + start_q = time.time() + val_loss_q, val_acc_q = evaluate(device, model_quant, criterion, train_loader, is_pyquant=True) + end_q = time.time() + + print("Pytorch post-training quantization INT8") + print(model_quant) + print(f"val_loss_q: {val_loss_q:.3f} \t val_acc_q:{val_acc_q:.3f}") + print(f"Latency: {end_q - start_q}") else: device = torch.device("cpu") model.load_state_dict(torch.load("./model.pt", map_location="cpu")) diff --git a/quantize/gptq/sanity_check_utils.py b/quantize/gptq/sanity_check_utils.py index 0de025ce..5790414b 100644 --- a/quantize/gptq/sanity_check_utils.py +++ b/quantize/gptq/sanity_check_utils.py @@ -23,22 +23,89 @@ class SimpleNet(nn.Module): def __init__(self, num_classes=10, init_weights=True): super(SimpleNet, self).__init__() self.N = 32 * 32 - self.linear1 = nn.Linear(in_features=self.N, out_features=self.N) self.linear2 = nn.Linear(in_features=self.N, out_features=self.N) self.linear3 = nn.Linear(in_features=self.N, out_features=self.N) self.linear4 = nn.Linear(in_features=self.N, out_features=num_classes) + + self.quant = torch.ao.quantization.QuantStub() + self.dequant = torch.ao.quantization.DeQuantStub() + self.skip_add = nn.quantized.FloatFunctional() - def forward(self, x): - # Assume input is already flattened + def forward(self, x, is_pyquant=False): if len(x.shape) == 4: x = x.view(x.size(0), -1) + + if is_pyquant: x = self.quant(x) + + residual = x x = F.relu(self.linear1(x)) - x = F.relu(self.linear2(x)) - x = F.relu(self.linear3(x)) + + x = self.linear2(x) + + if is_pyquant: + x = self.skip_add.add(F.relu(x), residual) + else: + x = F.relu(x) + residual + + x = self.linear3(x) + + if is_pyquant: + x = self.skip_add.add(F.relu(x), residual) + else: + x = F.relu(x) + residual + x = self.linear4(x) + + if is_pyquant: x = self.dequant(x) + return x +# class ResNet(nn.Module): +# class BasicBlock(nn.Module): +# expansion = 1 + +# def __init__(self, in_planes, planes, stride=1): +# super().__init__() +# self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) +# self.conv2 = nn.Conv2d(planes, planes, kernel_size=3,stride=1, padding=1, bias=False) + +# self.shortcut = lambda x: x +# if stride != 1 or in_planes != self.expansion*planes: +# self.shortcut = nn.Conv2d(in_planes, self.expansion*planes,kernel_size=1, stride=stride, bias=False) + +# self.skip_add = nn.quantized.FloatFunctional() + +# def forward(self, x, is_pyquant=False): +# out = F.relu(self.conv1(x)) +# out = self.conv2(out) +# if is_pyquant: +# out = self.skip_add.add(out, self.shortcut(x)) +# else: +# out += self.shortcut(x) +# out = F.relu(out) +# return out + +# def __init__(self, num_classes=10): +# super(ResNet, 
self).__init__() +# self.conv1 = nn.Conv2d(1, 1, kernel_size=3,stride=1, padding=1, bias=False) +# self.layer1 = self.BasicBlock(1, 1, stride=1) +# self.linear = nn.Linear(64, num_classes) +# self.quant = torch.ao.quantization.QuantStub() +# self.dequant = torch.ao.quantization.DeQuantStub() + +# def forward(self, x, is_pyquant=False): +# if is_pyquant: +# x = self.quant(x) +# out = F.relu(self.conv1(x)) +# out = self.layer1.forward(out, is_pyquant=is_pyquant) +# out = F.avg_pool2d(out, 4) +# out = torch.flatten(out, 1) +# out = self.linear(out) +# if is_pyquant: +# out = self.dequant(out) +# return out + # Dataset class MNISTloader: @@ -119,7 +186,7 @@ def load(self): # Train + evaluate -def evaluate(device, model, criterion, val_loader): +def evaluate(device, model, criterion, val_loader, is_pyquant=False): val_loss_running, val_acc_running = 0, 0 @@ -128,7 +195,7 @@ def evaluate(device, model, criterion, val_loader): with torch.no_grad(): for inputs, labels in val_loader: inputs, labels = inputs.to(device), labels.to(device) - outputs = model(inputs) + outputs = model.forward(inputs, is_pyquant) loss = criterion(outputs, labels) _, predictions = torch.max(outputs, dim=1) val_loss_running += loss.item() * inputs.shape[0] From f141e5223a2f7ee7c33df269f1cb663051def68b Mon Sep 17 00:00:00 2001 From: Ferdinand Mom Date: Fri, 28 Apr 2023 10:36:30 +0000 Subject: [PATCH 12/20] fix(sanity-check): dont quantize last layer for dummy example --- quantize/gptq/sanity_check_main.py | 20 ++++-- quantize/gptq/sanity_check_utils.py | 94 +++++++++-------------------- 2 files changed, 42 insertions(+), 72 deletions(-) diff --git a/quantize/gptq/sanity_check_main.py b/quantize/gptq/sanity_check_main.py index 636e243f..a5e2b549 100644 --- a/quantize/gptq/sanity_check_main.py +++ b/quantize/gptq/sanity_check_main.py @@ -13,10 +13,12 @@ WBITS = 8 GROUPSIZE = -1 -def quantize_gptq(model, train_loader, device): +@torch.no_grad() +def quantize_gptq(model, train_loader): quantizers = {} layers = list(model.modules())[1:] layers = [l for l in layers if isinstance(l, nn.Linear)] + layers = layers[:-1] is_last_layer = lambda x: x == (len(layers) - 1) nsamples = len(train_loader.dataset) @@ -37,7 +39,7 @@ def quantize_gptq(model, train_loader, device): for name in subset: gptq[name] = GPTQ(subset[name], name) gptq[name].quantizer = Quantizer() - gptq[name].quantizer.configure(bits=WBITS, perchannel=True, sym=False, mse=False, trits=False) + gptq[name].quantizer.configure(bits=WBITS, perchannel=True, sym=True, mse=False, trits=False) def add_batch(name): def tmp(_, inp, out): @@ -88,7 +90,7 @@ def model_pack(model, quantizers, wbits, groupsize): print('Packing ...') for name in qlayers: print(name) - quantizers[name],scale,zero,g_idx = quantizers[name] + quantizers[name],scale,zero,g_idx = quantizers[name] qlayers[name].pack(layers[name], scale, zero, g_idx) print('Done.') return model @@ -97,6 +99,13 @@ def load_quant(model, checkpoint, wbits, groupsize): print('Loading model ...') model = model.eval() layers = find_layers(model) + + # Don't quantize the last layer because qzeros is empty (I don't know why they create qzeros that way) + # (gptq.py:L235, second dimension of qzeros is 0 because last layer is 10 for classification) + for name in ["linear4"]: + if name in layers: + del layers[name] + make_quant(model, layers, wbits, groupsize) model.load_state_dict(torch.load(checkpoint)) print('Done.') @@ -126,11 +135,10 @@ def load_quant(model, checkpoint, wbits, groupsize): train(num_epochs, model, optimizer, 
criterion, train_loader, device) torch.save(model.state_dict(), "model.pt") elif args.gptq: - #FIXME: WHY ON EARTH QUANTIZATION ERROR IS SO DAMN HIGH FOR LAYER 3 AND 4 ?! device = torch.device("cpu") model.load_state_dict(torch.load("./model.pt", map_location="cpu")) model = model.to(device) - quantizers = quantize_gptq(model, train_loader, device) + quantizers = quantize_gptq(model, train_loader) model_pack(model, quantizers, WBITS, GROUPSIZE) torch.save(model.state_dict(), "model_quantized.pt") print("Done GPTQ") @@ -156,7 +164,7 @@ def load_quant(model, checkpoint, wbits, groupsize): for inputs, labels in train_loader: inputs, labels = inputs.to(device), labels.to(device) - outputs = model_prepared.forward(inputs, is_pyquant=True) + outputs = model_prepared.forward_pyquant(inputs) model_quant = torch.ao.quantization.convert(model_prepared) diff --git a/quantize/gptq/sanity_check_utils.py b/quantize/gptq/sanity_check_utils.py index 5790414b..90e17cbc 100644 --- a/quantize/gptq/sanity_check_utils.py +++ b/quantize/gptq/sanity_check_utils.py @@ -27,84 +27,44 @@ def __init__(self, num_classes=10, init_weights=True): self.linear2 = nn.Linear(in_features=self.N, out_features=self.N) self.linear3 = nn.Linear(in_features=self.N, out_features=self.N) self.linear4 = nn.Linear(in_features=self.N, out_features=num_classes) - + self.quant = torch.ao.quantization.QuantStub() self.dequant = torch.ao.quantization.DeQuantStub() self.skip_add = nn.quantized.FloatFunctional() - def forward(self, x, is_pyquant=False): + def forward(self, x): if len(x.shape) == 4: x = x.view(x.size(0), -1) - if is_pyquant: x = self.quant(x) - residual = x + x = F.relu(self.linear1(x)) - x = self.linear2(x) - - if is_pyquant: - x = self.skip_add.add(F.relu(x), residual) - else: - x = F.relu(x) + residual - + x = F.relu(x) + residual x = self.linear3(x) + x = F.relu(x) + residual + x = self.linear4(x) + return x + + def forward_pyquant(self, x): - if is_pyquant: - x = self.skip_add.add(F.relu(x), residual) - else: - x = F.relu(x) + residual - + if len(x.shape) == 4: + x = x.view(x.size(0), -1) + + x = self.quant(x) + + residual = x + + x = F.relu(self.linear1(x)) + x = self.linear2(x) + x = self.skip_add.add(F.relu(x), residual) + x = self.linear3(x) + x = self.skip_add.add(F.relu(x), residual) x = self.linear4(x) - - if is_pyquant: x = self.dequant(x) - - return x -# class ResNet(nn.Module): -# class BasicBlock(nn.Module): -# expansion = 1 - -# def __init__(self, in_planes, planes, stride=1): -# super().__init__() -# self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) -# self.conv2 = nn.Conv2d(planes, planes, kernel_size=3,stride=1, padding=1, bias=False) - -# self.shortcut = lambda x: x -# if stride != 1 or in_planes != self.expansion*planes: -# self.shortcut = nn.Conv2d(in_planes, self.expansion*planes,kernel_size=1, stride=stride, bias=False) - -# self.skip_add = nn.quantized.FloatFunctional() - -# def forward(self, x, is_pyquant=False): -# out = F.relu(self.conv1(x)) -# out = self.conv2(out) -# if is_pyquant: -# out = self.skip_add.add(out, self.shortcut(x)) -# else: -# out += self.shortcut(x) -# out = F.relu(out) -# return out - -# def __init__(self, num_classes=10): -# super(ResNet, self).__init__() -# self.conv1 = nn.Conv2d(1, 1, kernel_size=3,stride=1, padding=1, bias=False) -# self.layer1 = self.BasicBlock(1, 1, stride=1) -# self.linear = nn.Linear(64, num_classes) -# self.quant = torch.ao.quantization.QuantStub() -# self.dequant = torch.ao.quantization.DeQuantStub() - -# 
def forward(self, x, is_pyquant=False): -# if is_pyquant: -# x = self.quant(x) -# out = F.relu(self.conv1(x)) -# out = self.layer1.forward(out, is_pyquant=is_pyquant) -# out = F.avg_pool2d(out, 4) -# out = torch.flatten(out, 1) -# out = self.linear(out) -# if is_pyquant: -# out = self.dequant(out) -# return out + x = self.dequant(x) + + return x # Dataset @@ -185,7 +145,6 @@ def load(self): return train_loader, val_loader, test_loader # Train + evaluate - def evaluate(device, model, criterion, val_loader, is_pyquant=False): val_loss_running, val_acc_running = 0, 0 @@ -195,7 +154,10 @@ def evaluate(device, model, criterion, val_loader, is_pyquant=False): with torch.no_grad(): for inputs, labels in val_loader: inputs, labels = inputs.to(device), labels.to(device) - outputs = model.forward(inputs, is_pyquant) + if is_pyquant: + outputs = model.forward_pyquant(inputs) + else: + outputs = model(inputs) loss = criterion(outputs, labels) _, predictions = torch.max(outputs, dim=1) val_loss_running += loss.item() * inputs.shape[0] From a1ea8822a84cca7a1e59b7fb31438169d83ddc50 Mon Sep 17 00:00:00 2001 From: Ferdinand Mom Date: Fri, 28 Apr 2023 14:49:11 +0000 Subject: [PATCH 13/20] breaking(sanity-check): adding my implem gptq --- quantize/gptq/sanity_check_main.py | 323 +++++++++++++++++++++++++--- quantize/gptq/sanity_check_utils.py | 38 +++- 2 files changed, 330 insertions(+), 31 deletions(-) diff --git a/quantize/gptq/sanity_check_main.py b/quantize/gptq/sanity_check_main.py index a5e2b549..ee60595e 100644 --- a/quantize/gptq/sanity_check_main.py +++ b/quantize/gptq/sanity_check_main.py @@ -1,11 +1,11 @@ import argparse import time -import numpy as np +import re import torch import torch.nn as nn import torch.optim as optim -from sanity_check_utils import seed_everything, MNISTloader, SimpleNet, train, evaluate +from sanity_check_utils import seed_everything, MNISTloader, SimpleNet, train, evaluate, SimpleNet_V2 from gptq import * from modelutils import * from quant import * @@ -13,6 +13,36 @@ WBITS = 8 GROUPSIZE = -1 +## =============== REFERENCE =============== +def model_pack(model, quantizers, wbits, groupsize): + layers = find_layers(model) + layers = {n: layers[n] for n in quantizers} + make_quant(model, quantizers, wbits, groupsize) + qlayers = find_layers(model, [QuantLinear]) + print('Packing ...') + for name in qlayers: + print(name) + quantizers[name],scale,zero,g_idx = quantizers[name] + qlayers[name].pack(layers[name], scale, zero, g_idx) + print('Done.') + return model + +def load_quant(model, checkpoint, wbits, groupsize): + print('Loading model ...') + model = model.eval() + layers = find_layers(model) + + # Don't quantize the last layer because qzeros is empty (I don't know why they create qzeros that way) + # (gptq.py:L235, second dimension of qzeros is 0 because last layer is 10 for classification) + for name in ["linear4"]: + if name in layers: + del layers[name] + + make_quant(model, layers, wbits, groupsize) + model.load_state_dict(torch.load(checkpoint)) + print('Done.') + return model + @torch.no_grad() def quantize_gptq(model, train_loader): quantizers = {} @@ -81,42 +111,248 @@ def tmp(_, inp, out): return quantizers -# TODO: perform packing on GPU -def model_pack(model, quantizers, wbits, groupsize): - layers = find_layers(model) - layers = {n: layers[n] for n in quantizers} - make_quant(model, quantizers, wbits, groupsize) - qlayers = find_layers(model, [QuantLinear]) - print('Packing ...') - for name in qlayers: - print(name) - quantizers[name],scale,zero,g_idx = 
quantizers[name] - qlayers[name].pack(layers[name], scale, zero, g_idx) - print('Done.') - return model +## =============== OUR IMPLEMENTATION =============== +class GPTQ_CUSTOM(SimpleNet_V2): -def load_quant(model, checkpoint, wbits, groupsize): - print('Loading model ...') - model = model.eval() - layers = find_layers(model) + ### begin GPTQ + class GPTQ: + def __init__(self, weight, name): + #TODO: Remove name, only used for debugging + self.name = name + self.weight = weight.clone() + self.dev = weight.device + # In GPTQ, they use nn.Linear(x) which performs x @ w.T but in RWKV, we perform x @ w instead + # Problem is self.H is a square matrix which depends on self.columns = W.shape[1] in the original code + # But if we keep it that way, this will break self.H += inp.matmul(inp.t()) because inp.shape[1] != W.shape[1] + # Thus, we have to use self.W.shape[0] instead + self.columns = self.weight.shape[0] + self.H = torch.zeros((self.columns, self.columns), device=self.dev) + self.nsamples = 0 + self.deactivate_add_batch_call = False - # Don't quantize the last layer because qzeros is empty (I don't know why they create qzeros that way) - # (gptq.py:L235, second dimension of qzeros is 0 because last layer is 10 for classification) - for name in ["linear4"]: - if name in layers: - del layers[name] + def add_batch(self, inp): + + # After calling fasterquant, we don't want to call add_batch anymore + if self.deactivate_add_batch_call: + return - make_quant(model, layers, wbits, groupsize) - model.load_state_dict(torch.load(checkpoint)) - print('Done.') - return model + if len(inp.shape) == 2: + inp = inp.unsqueeze(0) + + #TODO: is the case with len = 1 still necessary ? + tmp = 1 if len(inp.shape) == 1 else inp.shape[0] + + # Assume weight come from nn.Linear + if len(inp.shape) == 3: + inp = inp.reshape((-1, inp.shape[-1])) + inp = inp.t() + + self.H *= self.nsamples / (self.nsamples + tmp) + self.nsamples += tmp + inp = math.sqrt(2 / self.nsamples) * inp.float() + self.H += inp.matmul(inp.t()) + + def fasterquant(self, blocksize=128, percdamp=.01, groupsize=-1, actorder=False): + W = self.weight.data.clone() + # Need to transpose here, same reason as in __init__ with self.columns + W = W.t() + W = W.float() + + tick = time.time() + + if not self.quantizer.ready(): + self.quantizer.find_params(W, weight=True) + + H = self.H + del self.H + dead = torch.diag(H) == 0 + H[dead, dead] = 1 + W[:, dead] = 0 + + if actorder: + perm = torch.argsort(torch.diag(H), descending=True) + W = W[:, perm] + H = H[perm][:, perm] + + Losses = torch.zeros_like(W) + Q = torch.zeros_like(W) + damp = percdamp * torch.mean(torch.diag(H)) + diag = torch.arange(self.columns, device=self.dev) + H[diag, diag] += damp + H = torch.linalg.cholesky(H) + H = torch.cholesky_inverse(H) + H = torch.linalg.cholesky(H, upper=True) + Hinv = H + + g_idx = [] + scale = [] + zero = [] + now_idx = 1 + + for i1 in range(0, self.columns, blocksize): + i2 = min(i1 + blocksize, self.columns) + count = i2 - i1 + + W1 = W[:, i1:i2].clone() + Q1 = torch.zeros_like(W1) + Err1 = torch.zeros_like(W1) + Losses1 = torch.zeros_like(W1) + Hinv1 = Hinv[i1:i2, i1:i2] + + for i in range(count): + w = W1[:, i] + d = Hinv1[i, i] + + if groupsize != -1: + if (i1 + i) % groupsize == 0: + self.quantizer.find_params(W[:, (i1 + i):(i1 + i + groupsize)], weight=True) + + if ((i1 + i) // groupsize) - now_idx == -1: + scale.append(self.quantizer.scale) + zero.append(self.quantizer.zero) + now_idx += 1 + + q = quantize( + w.unsqueeze(1), self.quantizer.scale, 
self.quantizer.zero, self.quantizer.maxq + ).flatten() + Q1[:, i] = q + Losses1[:, i] = (w - q) ** 2 / d ** 2 + + err1 = (w - q) / d + W1[:, i:] -= err1.unsqueeze(1).matmul(Hinv1[i, i:].unsqueeze(0)) + Err1[:, i] = err1 + + Q[:, i1:i2] = Q1 + Losses[:, i1:i2] = Losses1 / 2 + + W[:, i2:] -= Err1.matmul(Hinv[i1:i2, i2:]) + + + torch.cuda.synchronize() + print('time %.2f' % (time.time() - tick)) + print('error', torch.sum(Losses).item()) + + groupsize = groupsize if groupsize != -1 else self.columns + g_idx = [i // groupsize for i in range(self.columns)] + g_idx = torch.tensor(g_idx, dtype=torch.int32, device=Q.device) + if actorder: + invperm = torch.argsort(perm) + Q = Q[:, invperm] + g_idx = g_idx[invperm] + + #TODO: Do we have to uncomment it ? + # if isinstance(self.layer, transformers.Conv1D): + # Q = Q.t() + self.weight.data = Q.reshape(self.weight.shape).to(self.weight.data.dtype) + + if scale == []: + scale.append(self.quantizer.scale) + zero.append(self.quantizer.zero) + scale = torch.cat(scale,dim=1) + zero = torch.cat(zero,dim=1) + return scale,zero,g_idx + + ### end GPTQ + + ### begin GPTQ_CUSTOM + def __init__(self, checkpoint_path): + super().__init__() + self.load_state_dict(torch.load(checkpoint_path, map_location="cpu")) + + def _fill_subset(self, layer_id): + is_last_layer = (layer_id == self.nb_layers - 1) + if is_last_layer: + return {} + # Keep only layer within block layer_id + is_weight = re.compile(f'^linear{layer_id}_w$') + for name in self.w.keys(): + if is_weight.match(name): + self.subset[name] = self.w[name] + return self.subset + + def alloc_gptq(self, layer_id): + self.subset = {} + self.gptq = {} + + self.subset = self._fill_subset(layer_id) + + for name in self.subset: + self.gptq[name] = self.GPTQ(self.subset[name], name) + self.gptq[name].quantizer = Quantizer() + self.gptq[name].quantizer.configure(bits=WBITS, perchannel=True, sym=False, mse=False, trits=False) + + def free_gptq(self): + self.subset = {} + self.gptq = {} + + def fasterquant(self, layer_id, quantizers): + + for name in self.subset: + print(layer_id, name) + print('Quantizing ...') + scale,zero,g_idx = self.gptq[name].fasterquant(percdamp=0.01, groupsize=GROUPSIZE, actorder=False) + quantizers[f"linear{layer_id + 1}"] = (self.gptq[name].quantizer.cpu(), scale.cpu(), zero.cpu(), g_idx.cpu()) + + ## end GPTQ_CUSTOM + + ## Begin SimpleNet_V2 + def my_linear(self, x, weight, bias): + out = x @ weight.weight + bias + weight.add_batch(x) + return out + ## End SimpleNet_V2 + + +@torch.no_grad() +def quantize_gptq_custom(model, train_loader): + + nb_layers = model.nb_layers + is_last_layer = lambda x: x == (nb_layers - 1) + + nsamples = len(train_loader.dataset) + batch_size = train_loader.batch_size + + inps = torch.zeros((nsamples, model.N), dtype=torch.float) + for i, (inp, _) in enumerate(train_loader): + inps[i*batch_size:(i+1)*batch_size] = inp.view(-1, 32*32) + outs = torch.zeros_like(inps) + + quantizers = {} + + for layer_id in range(nb_layers): + + if not is_last_layer(layer_id): + + model.alloc_gptq(layer_id) + + for i in range(nsamples): + outs[i] = model.my_linear(inps[i], model.gptq[f"linear{layer_id}_w"], model.w[f"linear{layer_id}_b"]) + + model.gptq[f"linear{layer_id}_w"].deactivate_add_batch_call = True + + model.fasterquant(layer_id, quantizers) + + for i in range(nsamples): + outs[i] = model.my_linear(inps[i], model.gptq[f"linear{layer_id}_w"], model.w[f"linear{layer_id}_b"]) + + model.free_gptq() + + inps, outs = outs, inps + + return quantizers + + +def model_pack_custom(model, 
quantizers, wbits, groupsize): + pass if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("--train", action="store_true") parser.add_argument("--gptq", action="store_true") parser.add_argument("--eval_gptq", action="store_true") + parser.add_argument("--train_custom", action="store_true") + parser.add_argument("--gptq_custom", action="store_true") parser.add_argument("--pyquant", action="store_true") args = parser.parse_args() @@ -124,17 +360,22 @@ def load_quant(model, checkpoint, wbits, groupsize): seed_everything(42) lr = 0.02 num_epochs = 5 - model = SimpleNet() - optimizer = optim.Adam(model.parameters(), lr) criterion = nn.CrossEntropyLoss() train_loader, _, _ = MNISTloader(train_val_split=0.95).load() + #TODO: Why is training for ref and custom not the same + #TODO: Custom packing + + ## ================== REFERENCE ================== if args.train: + model = SimpleNet() + optimizer = optim.Adam(model.parameters(), lr) device = torch.device("cuda" if torch.cuda.is_available() else "cpu") model = model.to(device) train(num_epochs, model, optimizer, criterion, train_loader, device) torch.save(model.state_dict(), "model.pt") elif args.gptq: + model = SimpleNet() device = torch.device("cpu") model.load_state_dict(torch.load("./model.pt", map_location="cpu")) model = model.to(device) @@ -142,7 +383,9 @@ def load_quant(model, checkpoint, wbits, groupsize): model_pack(model, quantizers, WBITS, GROUPSIZE) torch.save(model.state_dict(), "model_quantized.pt") print("Done GPTQ") + elif args.eval_gptq: + model = SimpleNet() device = torch.device("cuda:0") model = load_quant(model, "model_quantized.pt", WBITS, GROUPSIZE) model = model.to(device) @@ -154,8 +397,26 @@ def load_quant(model, checkpoint, wbits, groupsize): print(f"wbits = {WBITS} using {device}") print(f"val_loss: {val_loss:.3f} \t val_acc: {val_acc:.3f}") print(f"Latency: {end - start}") + ## ================== CUSTOM ================== + elif args.train_custom: + model = SimpleNet_V2() + optimizer = optim.Adam(model.parameters(), lr) + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + model = model.to(device) + train(num_epochs, model, optimizer, criterion, train_loader, device) + torch.save(model.state_dict(), "model_custom.pt") + elif args.gptq_custom: + device = torch.device("cpu") + model = GPTQ_CUSTOM("./model_custom.pt") + model = model.to(device) + quantizers = quantize_gptq_custom(model, train_loader) + model_pack_custom(model, quantizers, WBITS, GROUPSIZE) + torch.save(model.state_dict(), "model_quantized_custom.pt") + print("Done Custom GPTQ") + ## ================== MISC ================== elif args.pyquant: # Baseline post-training quantization from Pytorch + model = SimpleNet() device = torch.device("cpu") model.load_state_dict(torch.load("./model.pt")) model.eval() @@ -177,6 +438,8 @@ def load_quant(model, checkpoint, wbits, groupsize): print(f"val_loss_q: {val_loss_q:.3f} \t val_acc_q:{val_acc_q:.3f}") print(f"Latency: {end_q - start_q}") else: + # Evaluate float 32 + model = SimpleNet() device = torch.device("cpu") model.load_state_dict(torch.load("./model.pt", map_location="cpu")) model = model.to(device) diff --git a/quantize/gptq/sanity_check_utils.py b/quantize/gptq/sanity_check_utils.py index 90e17cbc..5fcf5ec4 100644 --- a/quantize/gptq/sanity_check_utils.py +++ b/quantize/gptq/sanity_check_utils.py @@ -20,7 +20,7 @@ def seed_everything(seed: int): # Model class SimpleNet(nn.Module): - def __init__(self, num_classes=10, init_weights=True): + def 
__init__(self, num_classes=10): super(SimpleNet, self).__init__() self.N = 32 * 32 self.linear1 = nn.Linear(in_features=self.N, out_features=self.N) @@ -66,6 +66,42 @@ def forward_pyquant(self, x): return x +class SimpleNet_V2(nn.Module): + def __init__(self, num_classes=10): + super(SimpleNet_V2, self).__init__() + self.N = 32 * 32 + self.linear0_w = nn.Parameter(torch.randn(self.N, self.N)) + self.linear0_b = nn.Parameter(torch.randn(self.N)) + self.linear1_w = nn.Parameter(torch.randn(self.N, self.N)) + self.linear1_b = nn.Parameter(torch.randn(self.N)) + self.linear2_w = nn.Parameter(torch.randn(self.N, self.N)) + self.linear2_b = nn.Parameter(torch.randn(self.N)) + self.linear3_w = nn.Parameter(torch.randn(self.N, num_classes)) + self.linear3_b = nn.Parameter(torch.randn(num_classes)) + + self.w = {} + self.nb_layers = 0 + for i in range(0, 4): + self.w[f"linear{i}_w"] = getattr(self, f"linear{i}_w") + self.w[f"linear{i}_b"] = getattr(self, f"linear{i}_b") + self.nb_layers += 1 + + def my_linear(self, x, weight, bias): + return x @ weight + bias + + def forward(self, x): + if len(x.shape) == 4: + x = x.view(x.size(0), -1) + + residual = x + x = F.relu(self.my_linear(x, self.linear0_w, self.linear0_b)) + x = self.my_linear(x, self.linear1_w, self.linear1_b) + x = F.relu(x) + residual + x = self.my_linear(x, self.linear2_w, self.linear2_b) + x = F.relu(x) + residual + x = self.my_linear(x, self.linear3_w, self.linear3_b) + return x + # Dataset class MNISTloader: From 8a37fb4b783c5e77e48a20fd895b773dca78df57 Mon Sep 17 00:00:00 2001 From: Ferdinand Mom Date: Fri, 28 Apr 2023 16:18:49 +0000 Subject: [PATCH 14/20] fix(sanity-check): training ref and implem now yield same outputs --- quantize/gptq/sanity_check_main.py | 23 +++++++++++++++--- quantize/gptq/sanity_check_utils.py | 37 +++++++++++++++++++++-------- 2 files changed, 47 insertions(+), 13 deletions(-) diff --git a/quantize/gptq/sanity_check_main.py b/quantize/gptq/sanity_check_main.py index ee60595e..eec4114a 100644 --- a/quantize/gptq/sanity_check_main.py +++ b/quantize/gptq/sanity_check_main.py @@ -243,7 +243,7 @@ def fasterquant(self, blocksize=128, percdamp=.01, groupsize=-1, actorder=False) #TODO: Do we have to uncomment it ? 
# if isinstance(self.layer, transformers.Conv1D): - # Q = Q.t() + # Q = Q.t() self.weight.data = Q.reshape(self.weight.shape).to(self.weight.data.dtype) if scale == []: @@ -346,6 +346,24 @@ def quantize_gptq_custom(model, train_loader): def model_pack_custom(model, quantizers, wbits, groupsize): pass +def load_quant_custom(model, quantizers, wbits, groupsize): + pass + +def assert_parameters(model, model_custom): + is_weight = re.compile(r'^linear\d+.weight$') + weights, bias = {}, {} + for name, param in model.named_parameters(): + if is_weight.match(name): + weights[name] = param + else: + bias[name] = param + + for i, (name, param) in enumerate(weights.items()): + assert torch.allclose(param, model_custom.state_dict()[f"linear{i}_w"]) + + for i, (name, param) in enumerate(bias.items()): + assert torch.allclose(param, model_custom.state_dict()[f"linear{i}_b"]) + if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("--train", action="store_true") @@ -363,8 +381,7 @@ def model_pack_custom(model, quantizers, wbits, groupsize): criterion = nn.CrossEntropyLoss() train_loader, _, _ = MNISTloader(train_val_split=0.95).load() - #TODO: Why is training for ref and custom not the same - #TODO: Custom packing + #TODO: Do Custom packing ## ================== REFERENCE ================== if args.train: diff --git a/quantize/gptq/sanity_check_utils.py b/quantize/gptq/sanity_check_utils.py index 5fcf5ec4..0c0a6888 100644 --- a/quantize/gptq/sanity_check_utils.py +++ b/quantize/gptq/sanity_check_utils.py @@ -7,7 +7,7 @@ import torch from torch.utils.data import DataLoader, random_split from torchvision import datasets, transforms -from collections import OrderedDict +import math def seed_everything(seed: int): random.seed(seed) @@ -22,6 +22,8 @@ def seed_everything(seed: int): class SimpleNet(nn.Module): def __init__(self, num_classes=10): super(SimpleNet, self).__init__() + seed_everything(42) + self.N = 32 * 32 self.linear1 = nn.Linear(in_features=self.N, out_features=self.N) self.linear2 = nn.Linear(in_features=self.N, out_features=self.N) @@ -69,15 +71,28 @@ def forward_pyquant(self, x): class SimpleNet_V2(nn.Module): def __init__(self, num_classes=10): super(SimpleNet_V2, self).__init__() + seed_everything(42) self.N = 32 * 32 - self.linear0_w = nn.Parameter(torch.randn(self.N, self.N)) - self.linear0_b = nn.Parameter(torch.randn(self.N)) - self.linear1_w = nn.Parameter(torch.randn(self.N, self.N)) - self.linear1_b = nn.Parameter(torch.randn(self.N)) - self.linear2_w = nn.Parameter(torch.randn(self.N, self.N)) - self.linear2_b = nn.Parameter(torch.randn(self.N)) - self.linear3_w = nn.Parameter(torch.randn(self.N, num_classes)) - self.linear3_b = nn.Parameter(torch.randn(num_classes)) + + self.linear0_w = nn.Parameter(torch.nn.init.kaiming_uniform_(torch.empty(self.N, self.N), a=math.sqrt(5))) + fan_in, _ = torch.nn.init._calculate_fan_in_and_fan_out(self.linear0_w) + bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0 + self.linear0_b = nn.Parameter(torch.nn.init.uniform_(torch.empty(self.N), -bound, bound)) + + self.linear1_w = nn.Parameter(torch.nn.init.kaiming_uniform_(torch.empty(self.N, self.N), a=math.sqrt(5))) + fan_in, _ = torch.nn.init._calculate_fan_in_and_fan_out(self.linear1_w) + bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0 + self.linear1_b = nn.Parameter(torch.nn.init.uniform_(torch.empty(self.N), -bound, bound)) + + self.linear2_w = nn.Parameter(torch.nn.init.kaiming_uniform_(torch.empty(self.N, self.N), a=math.sqrt(5))) + fan_in, _ = 
torch.nn.init._calculate_fan_in_and_fan_out(self.linear2_w) + bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0 + self.linear2_b = nn.Parameter(torch.nn.init.uniform_(torch.empty(self.N), -bound, bound)) + + self.linear3_w = nn.Parameter(torch.nn.init.kaiming_uniform_(torch.empty(num_classes, self.N), a=math.sqrt(5))) + fan_in, _ = torch.nn.init._calculate_fan_in_and_fan_out(self.linear3_w) + bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0 + self.linear3_b = nn.Parameter(torch.nn.init.uniform_(torch.empty(num_classes), -bound, bound)) self.w = {} self.nb_layers = 0 @@ -87,7 +102,9 @@ def __init__(self, num_classes=10): self.nb_layers += 1 def my_linear(self, x, weight, bias): - return x @ weight + bias + # return x @ weight.t() + bias. + # Although this is the same, they yield different results as here: https://discuss.pytorch.org/t/differences-between-implementations/129237 + return F.linear(x, weight, bias) def forward(self, x): if len(x.shape) == 4: From 423352276f0925e80c24ff235b949c33a32b71b6 Mon Sep 17 00:00:00 2001 From: Ferdinand Mom Date: Fri, 28 Apr 2023 20:25:55 +0000 Subject: [PATCH 15/20] feat(sanity-check): implem version of gptq now added --- quantize/gptq/quant.py | 182 +++++++++++++++++++++++++++++ quantize/gptq/sanity_check_main.py | 102 +++++++++++++--- 2 files changed, 271 insertions(+), 13 deletions(-) diff --git a/quantize/gptq/quant.py b/quantize/gptq/quant.py index e301c521..00cb2819 100644 --- a/quantize/gptq/quant.py +++ b/quantize/gptq/quant.py @@ -147,6 +147,188 @@ def make_quant(module, names, bits, groupsize, name=''): for name1, child in module.named_children(): make_quant(child, names, bits, groupsize, name + '.' + name1 if name != '' else name1) +def make_quant_custom(module, names, bits, groupsize, name=''): + if isinstance(module, QuantLinear): + return + for attr in dir(module): + tmp = getattr(module, attr) + name1 = name + '.' 
+ attr if name != '' else attr + if name1 in names: + + bias_name = attr.replace('w', 'b') + layer_name = attr.replace('w', 'quant') + setattr(module, layer_name, QuantLinear_custom(bits, groupsize, tmp.shape[0], tmp.shape[1], module.w[bias_name] is not None)) + + +class QuantLinear_custom(nn.Module): + def __init__(self, bits, groupsize, infeatures, outfeatures, bias, kernel_switch_threshold=128, is_cuda=is_cuda): + super().__init__() + if bits not in [2,3,4,8]: + raise NotImplementedError("Only 2,3,4,8 bits are supported.") + self.infeatures = infeatures + self.outfeatures = outfeatures + self.bits = bits + self.groupsize = groupsize if groupsize != -1 else infeatures + self.maxq = 2 ** self.bits - 1 + + self.register_buffer('qweight', torch.zeros((infeatures // 32 * self.bits, outfeatures), dtype=torch.int32)) + self.register_buffer('qzeros', torch.zeros((math.ceil(infeatures / self.groupsize), outfeatures // 32 * self.bits), dtype=torch.int32)) + self.register_buffer('scales', torch.zeros((math.ceil(infeatures / self.groupsize), outfeatures), dtype=torch.float16)) + self.register_buffer('g_idx', torch.tensor([i // self.groupsize for i in range(infeatures)], dtype = torch.int32)) + if bias: + self.register_buffer('bias', torch.zeros((outfeatures),dtype=torch.float16)) + else: + self.bias = None + + # is performed by unpacking the weights and using torch.matmul + if self.bits in [2,4,8]: + self.register_buffer('wf',torch.tensor(list(range(0,32,self.bits)), dtype=torch.int32).unsqueeze(0),persistent=False) + elif self.bits == 3: + self.register_buffer('wf', torch.tensor([[0, 3, 6, 9, 12, 15, 18, 21, 24, 27, 30, 0], + [0, 1, 4, 7, 10, 13, 16, 19, 22, 25, 28, 31], + [0, 2, 5, 8, 11, 14, 17, 20, 23, 26, 29, 0],], dtype=torch.int32).reshape(1,3,12), persistent=False) + + self.kernel_switch_threshold = kernel_switch_threshold + self.is_cuda = is_cuda + + def pack(self, weight, bias, scales, zeros, g_idx = None): + self.g_idx = g_idx.clone() if g_idx is not None else self.g_idx + + scales = scales.t().contiguous() + zeros = zeros.t().contiguous() + scale_zeros = zeros * scales + self.scales = scales.clone().half() + if bias is not None: + self.bias = bias.clone().half() + + intweight = [] + for idx in range(self.infeatures): + intweight.append(torch.round((weight[:,idx] + scale_zeros[self.g_idx[idx]]) / self.scales[self.g_idx[idx]]).to(torch.int)[:,None]) + intweight = torch.cat(intweight,dim=1) + intweight = intweight.t().contiguous() + intweight = intweight.numpy().astype(np.uint32) + qweight = np.zeros( + (intweight.shape[0] // 32 * self.bits, intweight.shape[1]), dtype=np.uint32 + ) + i = 0 + row = 0 + while row < qweight.shape[0]: + if self.bits in [2,4,8]: + for j in range(i, i + (32//self.bits)): + qweight[row] |= intweight[j] << (self.bits * (j - i)) + i += 32//self.bits + row += 1 + elif self.bits == 3: + for j in range(i, i + 10): + qweight[row] |= intweight[j] << (3 * (j - i)) + i += 10 + qweight[row] |= intweight[i] << 30 + row += 1 + qweight[row] |= (intweight[i] >> 2) & 1 + i += 1 + for j in range(i, i + 10): + qweight[row] |= intweight[j] << (3 * (j - i) + 1) + i += 10 + qweight[row] |= intweight[i] << 31 + row += 1 + qweight[row] |= (intweight[i] >> 1) & 0x3 + i += 1 + for j in range(i, i + 10): + qweight[row] |= intweight[j] << (3 * (j - i) + 2) + i += 10 + row += 1 + else: + raise NotImplementedError("Only 2,3,4,8 bits are supported.") + + qweight = qweight.astype(np.int32) + self.qweight = torch.from_numpy(qweight) + + zeros -= 1 + zeros = zeros.numpy().astype(np.uint32) + 
qzeros = np.zeros((zeros.shape[0], zeros.shape[1] // 32 * self.bits), dtype=np.uint32) + i = 0 + col = 0 + while col < qzeros.shape[1]: + if self.bits in [2,4,8]: + for j in range(i, i + (32//self.bits)): + qzeros[:, col] |= zeros[:, j] << (self.bits * (j - i)) + i += 32//self.bits + col += 1 + elif self.bits == 3: + for j in range(i, i + 10): + qzeros[:, col] |= zeros[:, j] << (3 * (j - i)) + i += 10 + qzeros[:, col] |= zeros[:, i] << 30 + col += 1 + qzeros[:, col] |= (zeros[:, i] >> 2) & 1 + i += 1 + for j in range(i, i + 10): + qzeros[:, col] |= zeros[:, j] << (3 * (j - i) + 1) + i += 10 + qzeros[:, col] |= zeros[:, i] << 31 + col += 1 + qzeros[:, col] |= (zeros[:, i] >> 1) & 0x3 + i += 1 + for j in range(i, i + 10): + qzeros[:, col] |= zeros[:, j] << (3 * (j - i) + 2) + i += 10 + col += 1 + else: + raise NotImplementedError("Only 2,3,4,8 bits are supported.") + + qzeros = qzeros.astype(np.int32) + self.qzeros = torch.from_numpy(qzeros) + + def forward(self, x): + out_shape = x.shape[:-1] + (self.outfeatures, ) + x = x.reshape(-1,x.shape[-1]) + if self.is_cuda is True and (self.kernel_switch_threshold is False or x.shape[0] < self.kernel_switch_threshold): + out = torch.zeros((x.shape[0], self.outfeatures), device=x.device, dtype=torch.float32) + if self.bits == 2: + quant_cuda.vecquant2matmul(x.float(), self.qweight, out, self.scales.float(), self.qzeros, self.g_idx) + elif self.bits == 3: + quant_cuda.vecquant3matmul(x.float(), self.qweight, out, self.scales.float(), self.qzeros, self.g_idx) + elif self.bits == 4: + quant_cuda.vecquant4matmul(x.float(), self.qweight, out, self.scales.float(), self.qzeros, self.g_idx) + elif self.bits == 8: + quant_cuda.vecquant8matmul(x.float(), self.qweight, out, self.scales.float(), self.qzeros, self.g_idx) + out = out.half() + else: + if self.bits in [2,4,8]: + zeros = torch.bitwise_right_shift(torch.unsqueeze(self.qzeros, 2).expand(-1, -1, 32 // self.bits), self.wf.unsqueeze(0)).to(torch.int16 if self.bits == 8 else torch.int8) + torch.bitwise_and(zeros, (2 ** self.bits) - 1, out=zeros) + + zeros = zeros + 1 + zeros = zeros.reshape(self.scales.shape) + + weight = torch.bitwise_right_shift(torch.unsqueeze(self.qweight, 1).expand(-1, 32 // self.bits, -1), self.wf.unsqueeze(-1)).to(torch.int16 if self.bits == 8 else torch.int8) + torch.bitwise_and(weight,(2 ** self.bits) - 1, out=weight) + elif self.bits == 3: + zeros = self.qzeros.reshape(self.qzeros.shape[0], self.qzeros.shape[1]//3, 3, 1).expand(-1, -1, -1, 12) + zeros = (zeros >> self.wf.unsqueeze(0)) + zeros[:,:,0,10] = (zeros[:,:,0,10]&0x3) | ((zeros[:,:,1,0] << 2)&0x4) + zeros[:,:,1,11] = (zeros[:,:,1,11]&0x1) | ((zeros[:,:,2,0] << 1)&0x6) + zeros = zeros & 0x7 + zeros = torch.cat([zeros[:,:,0,:11], zeros[:,:,1,1:12], zeros[:,:,2,1:11]], dim=2) + + zeros = zeros + 1 + zeros = zeros.reshape(self.scales.shape) + + weight = self.qweight.reshape(self.qweight.shape[0]//3, 3, 1, self.qweight.shape[1]).expand(-1, -1, 12, -1) + weight = (weight >> self.wf.unsqueeze(-1))&0x7 + weight[:,0,10] = (weight[:,0,10]&0x3) | ((weight[:,1,0] << 2)&0x4) + weight[:,1,11] = (weight[:,1,11]&0x1) | ((weight[:,2,0] << 1)&0x6) + weight = weight & 0x7 + weight = torch.cat([weight[:,0,:11], weight[:,1,1:12], weight[:,2,1:11]], dim=1) + + weight = weight.reshape(weight.shape[0] * weight.shape[1], weight.shape[2]) + + weights = (self.scales[self.g_idx] * (weight - zeros[self.g_idx])) + out = torch.matmul(x.half(), weights) + out = out.reshape(out_shape) + out = out + self.bias if self.bias is not None else out + return out 
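As a reference for the shift/mask logic in `QuantLinear_custom.pack()` and `forward()` above, a minimal illustrative sketch of the 4-bit case (the values below are arbitrary examples, not from the patch): eight quantized integers share one 32-bit word, and dequantization recovers `scale * (w_int - zero)` per group.

bits = 4
vals = [3, 12, 7, 0, 15, 9, 1, 6]            # eight 4-bit integers (32 // bits)

packed = 0
for j, v in enumerate(vals):
    packed |= v << (bits * j)                # each value occupies its own 4-bit slot

unpacked = [(packed >> (bits * j)) & 0xF for j in range(len(vals))]
assert unpacked == vals

# A weight is then dequantized per group as scale[g] * (w_int - zero[g]),
# which is what both the CUDA kernels and the torch fallback path compute.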
+ class QuantLinear(nn.Module): def __init__(self, bits, groupsize, infeatures, outfeatures, bias, kernel_switch_threshold=128, is_cuda=is_cuda): super().__init__() diff --git a/quantize/gptq/sanity_check_main.py b/quantize/gptq/sanity_check_main.py index eec4114a..31803adf 100644 --- a/quantize/gptq/sanity_check_main.py +++ b/quantize/gptq/sanity_check_main.py @@ -4,6 +4,8 @@ import torch import torch.nn as nn import torch.optim as optim +from collections import OrderedDict +import torch.nn.functional as F from sanity_check_utils import seed_everything, MNISTloader, SimpleNet, train, evaluate, SimpleNet_V2 from gptq import * @@ -34,9 +36,8 @@ def load_quant(model, checkpoint, wbits, groupsize): # Don't quantize the last layer because qzeros is empty (I don't know why they create qzeros that way) # (gptq.py:L235, second dimension of qzeros is 0 because last layer is 10 for classification) - for name in ["linear4"]: - if name in layers: - del layers[name] + if "linear4" in layers: + del layers["linear4"] make_quant(model, layers, wbits, groupsize) model.load_state_dict(torch.load(checkpoint)) @@ -258,8 +259,8 @@ def fasterquant(self, blocksize=128, percdamp=.01, groupsize=-1, actorder=False) ### begin GPTQ_CUSTOM def __init__(self, checkpoint_path): super().__init__() - self.load_state_dict(torch.load(checkpoint_path, map_location="cpu")) - + self.load_state_dict(torch.load(checkpoint_path, map_location="cpu")) + def _fill_subset(self, layer_id): is_last_layer = (layer_id == self.nb_layers - 1) if is_last_layer: @@ -292,7 +293,7 @@ def fasterquant(self, layer_id, quantizers): print(layer_id, name) print('Quantizing ...') scale,zero,g_idx = self.gptq[name].fasterquant(percdamp=0.01, groupsize=GROUPSIZE, actorder=False) - quantizers[f"linear{layer_id + 1}"] = (self.gptq[name].quantizer.cpu(), scale.cpu(), zero.cpu(), g_idx.cpu()) + quantizers[f"linear{layer_id}_w"] = (self.gptq[name].quantizer.cpu(), scale.cpu(), zero.cpu(), g_idx.cpu()) ## end GPTQ_CUSTOM @@ -301,6 +302,19 @@ def my_linear(self, x, weight, bias): out = x @ weight.weight + bias weight.add_batch(x) return out + + def forward(self, x): + if len(x.shape) == 4: + x = x.view(x.size(0), -1) + + residual = x + x = F.relu(self.linear0_quant(x)) + x = self.linear1_quant(x) + x = F.relu(x) + residual + x = self.linear2_quant(x) + x = F.relu(x) + residual + x = super().my_linear(x, self.linear3_w, self.linear3_b) + return x ## End SimpleNet_V2 @@ -321,9 +335,11 @@ def quantize_gptq_custom(model, train_loader): quantizers = {} for layer_id in range(nb_layers): - + if not is_last_layer(layer_id): - + + print(f"Quantizing layer {layer_id} ...") + model.alloc_gptq(layer_id) for i in range(nsamples): @@ -342,12 +358,56 @@ def quantize_gptq_custom(model, train_loader): return quantizers - def model_pack_custom(model, quantizers, wbits, groupsize): - pass + # Extract weights and bias from model + is_weight = re.compile(r'^linear\d+_w$') + weights, bias = OrderedDict(), OrderedDict() + for name, param in model.w.items(): + if is_weight.match(name): + weights[name] = param + else: + bias[name] = param + + make_quant_custom(model, quantizers, wbits, groupsize) + qlayers = find_layers(model, [QuantLinear_custom]) + + print('Packing ...') + for i in range(len(qlayers)): + name_w, name_b, layer_quant_name = f'linear{i}_w', f'linear{i}_b', f'linear{i}_quant' + quantizers[name_w],scale,zero,g_idx = quantizers[name_w] + qlayers[layer_quant_name].pack(weights[name_w], bias[name_b], scale, zero, g_idx) + print('Done.') + return model + +def 
load_quant_custom(model, checkpoint, wbits, groupsize): + print('Loading model ...') + model = model.eval() + # Extract weights and bias from model + is_weight = re.compile(r'^linear\d+_w$') + weights, bias = OrderedDict(), OrderedDict() + for name, param in model.w.items(): + if is_weight.match(name): + weights[name] = param + else: + bias[name] = param + + # Create linear layer out of weights and bias + layers = {} + for (w_name, w_param), (_, b_param) in zip(weights.items(), bias.items()): + layers[w_name] = nn.Linear(w_param.shape[1], w_param.shape[0], bias=True) + layers[w_name].weight.data = w_param + layers[w_name].bias.data = b_param + + # Don't quantize the last layer because qzeros is empty (I don't know why they create qzeros that way) + # (gptq.py:L235, second dimension of qzeros is 0 because last layer is 10 for classification) + if "linear3_w" in layers: + del layers["linear3_w"] + + make_quant_custom(model, layers, wbits, groupsize) + model.load_state_dict(torch.load(checkpoint)) + print('Done.') + return model -def load_quant_custom(model, quantizers, wbits, groupsize): - pass def assert_parameters(model, model_custom): is_weight = re.compile(r'^linear\d+.weight$') @@ -371,6 +431,7 @@ def assert_parameters(model, model_custom): parser.add_argument("--eval_gptq", action="store_true") parser.add_argument("--train_custom", action="store_true") parser.add_argument("--gptq_custom", action="store_true") + parser.add_argument("--eval_gptq_custom", action="store_true") parser.add_argument("--pyquant", action="store_true") args = parser.parse_args() @@ -381,7 +442,9 @@ def assert_parameters(model, model_custom): criterion = nn.CrossEntropyLoss() train_loader, _, _ = MNISTloader(train_val_split=0.95).load() - #TODO: Do Custom packing + #TODO: Do custom eval gptq + #TODO: Is reference GPTQ quantizing bias as well ? + #TODO: Add seed everywhere in GPT for reproducibility ## ================== REFERENCE ================== if args.train: @@ -430,6 +493,19 @@ def assert_parameters(model, model_custom): model_pack_custom(model, quantizers, WBITS, GROUPSIZE) torch.save(model.state_dict(), "model_quantized_custom.pt") print("Done Custom GPTQ") + elif args.eval_gptq_custom: + model = GPTQ_CUSTOM("./model_custom.pt") + device = torch.device("cuda:0") + model = load_quant_custom(model, "model_quantized_custom.pt", WBITS, GROUPSIZE) + model = model.to(device) + + start = time.time() + val_loss, val_acc = evaluate(device, model, criterion, train_loader) + end = time.time() + + print(f"wbits = {WBITS} using {device}") + print(f"val_loss: {val_loss:.3f} \t val_acc: {val_acc:.3f}") + print(f"Latency: {end - start}") ## ================== MISC ================== elif args.pyquant: # Baseline post-training quantization from Pytorch From e74d72a1445f50a23912083f7a73d1a76ffa9e0d Mon Sep 17 00:00:00 2001 From: Ferdinand Mom Date: Tue, 2 May 2023 18:17:57 +0000 Subject: [PATCH 16/20] fix(sanity-check): ref and implem now yield the same results at every step Date: Tue May 2 18:17:57 2023 +0000 --- quantize/gptq/quant.py | 11 +- quantize/gptq/sanity_check_main.py | 149 +++++++++++++--------------- quantize/gptq/sanity_check_utils.py | 75 ++++++++++++-- 3 files changed, 144 insertions(+), 91 deletions(-) diff --git a/quantize/gptq/quant.py b/quantize/gptq/quant.py index 00cb2819..23933bae 100644 --- a/quantize/gptq/quant.py +++ b/quantize/gptq/quant.py @@ -148,16 +148,15 @@ def make_quant(module, names, bits, groupsize, name=''): make_quant(child, names, bits, groupsize, name + '.' 
+ name1 if name != '' else name1) def make_quant_custom(module, names, bits, groupsize, name=''): - if isinstance(module, QuantLinear): + if isinstance(module, QuantLinear_custom): return for attr in dir(module): tmp = getattr(module, attr) name1 = name + '.' + attr if name != '' else attr - if name1 in names: - - bias_name = attr.replace('w', 'b') + if name1 in names: + bias = getattr(module, attr.replace('w', 'b')) layer_name = attr.replace('w', 'quant') - setattr(module, layer_name, QuantLinear_custom(bits, groupsize, tmp.shape[0], tmp.shape[1], module.w[bias_name] is not None)) + setattr(module, layer_name, QuantLinear_custom(bits, groupsize, tmp.shape[0], tmp.shape[1], bias is not None)) class QuantLinear_custom(nn.Module): @@ -203,7 +202,7 @@ def pack(self, weight, bias, scales, zeros, g_idx = None): intweight = [] for idx in range(self.infeatures): - intweight.append(torch.round((weight[:,idx] + scale_zeros[self.g_idx[idx]]) / self.scales[self.g_idx[idx]]).to(torch.int)[:,None]) + intweight.append(torch.round((weight.data[:,idx] + scale_zeros[self.g_idx[idx]]) / self.scales[self.g_idx[idx]]).to(torch.int)[:,None]) intweight = torch.cat(intweight,dim=1) intweight = intweight.t().contiguous() intweight = intweight.numpy().astype(np.uint32) diff --git a/quantize/gptq/sanity_check_main.py b/quantize/gptq/sanity_check_main.py index 31803adf..d96ee24f 100644 --- a/quantize/gptq/sanity_check_main.py +++ b/quantize/gptq/sanity_check_main.py @@ -49,7 +49,6 @@ def quantize_gptq(model, train_loader): quantizers = {} layers = list(model.modules())[1:] layers = [l for l in layers if isinstance(l, nn.Linear)] - layers = layers[:-1] is_last_layer = lambda x: x == (len(layers) - 1) nsamples = len(train_loader.dataset) @@ -60,54 +59,50 @@ def quantize_gptq(model, train_loader): inps[i*batch_size:(i+1)*batch_size] = inp.view(-1, 32*32) outs = torch.zeros_like(inps) - for layer_id in range(len(layers)): - layer = layers[layer_id] - - subset = find_layers(layer) - gptq = {} + + if not is_last_layer(layer_id): - for name in subset: - gptq[name] = GPTQ(subset[name], name) - gptq[name].quantizer = Quantizer() - gptq[name].quantizer.configure(bits=WBITS, perchannel=True, sym=True, mse=False, trits=False) + layer = layers[layer_id] + + subset = find_layers(layer) + gptq = {} - def add_batch(name): - def tmp(_, inp, out): - gptq[name].add_batch(inp[0].data, out.data) - return tmp - - handles = [] - - for name in subset: - handles.append(subset[name].register_forward_hook(add_batch(name))) + print(f"Quantizing layer {layer_id} ...") + for name in subset: + gptq[name] = GPTQ(subset[name], name) + gptq[name].quantizer = Quantizer() + gptq[name].quantizer.configure(bits=WBITS, perchannel=True, sym=False, mse=False, trits=False) + + def add_batch(name): + def tmp(_, inp, out): + gptq[name].add_batch(inp[0].data, out.data) + return tmp + + handles = [] + + for name in subset: + handles.append(subset[name].register_forward_hook(add_batch(name))) - for i in range(nsamples): - if not is_last_layer(layer_id): + for i in range(nsamples): outs[i] = layer(inps[i]) - else: - _ = layer(inps[i]) - for h in handles: h.remove() - - for name in subset: - print(i, name) - print('Quantizing ...') - scale,zero,g_idx = gptq[name].fasterquant(percdamp=0.01, groupsize=GROUPSIZE, actorder=False) - quantizers[f"linear{layer_id + 1}"] = (gptq[name].quantizer.cpu(), scale.cpu(), zero.cpu(), g_idx.cpu()) - gptq[name].free() + for h in handles: h.remove() + + for name in subset: + print(i, name) + print('Quantizing ...') + 
scale,zero,g_idx = gptq[name].fasterquant(percdamp=0.01, groupsize=GROUPSIZE, actorder=False) + quantizers[f"linear{layer_id + 1}"] = (gptq[name].quantizer.cpu(), scale.cpu(), zero.cpu(), g_idx.cpu()) + gptq[name].free() - for i in range(nsamples): - if not is_last_layer(layer_id): + for i in range(nsamples): outs[i] = layer(inps[i]) - else: - _ = layer(inps[i]) - - del layer - del gptq - torch.cuda.empty_cache() - if not is_last_layer(layer_id): + del layer + del gptq + torch.cuda.empty_cache() + inps, outs = outs, inps return quantizers @@ -132,7 +127,6 @@ def __init__(self, weight, name): self.deactivate_add_batch_call = False def add_batch(self, inp): - # After calling fasterquant, we don't want to call add_batch anymore if self.deactivate_add_batch_call: return @@ -140,14 +134,10 @@ def add_batch(self, inp): if len(inp.shape) == 2: inp = inp.unsqueeze(0) - #TODO: is the case with len = 1 still necessary ? - tmp = 1 if len(inp.shape) == 1 else inp.shape[0] + tmp = inp.shape[0] # Assume weight come from nn.Linear - if len(inp.shape) == 3: - inp = inp.reshape((-1, inp.shape[-1])) inp = inp.t() - self.H *= self.nsamples / (self.nsamples + tmp) self.nsamples += tmp inp = math.sqrt(2 / self.nsamples) * inp.float() @@ -155,8 +145,9 @@ def add_batch(self, inp): def fasterquant(self, blocksize=128, percdamp=.01, groupsize=-1, actorder=False): W = self.weight.data.clone() - # Need to transpose here, same reason as in __init__ with self.columns - W = W.t() + # OLD: Need to transpose here, same reason as in __init__ with self.columns + # UPDATE: no need to tranpose as we already transpose in my_linear() + # W = W.t() W = W.float() tick = time.time() @@ -166,6 +157,7 @@ def fasterquant(self, blocksize=128, percdamp=.01, groupsize=-1, actorder=False) H = self.H del self.H + dead = torch.diag(H) == 0 H[dead, dead] = 1 W[:, dead] = 0 @@ -242,9 +234,6 @@ def fasterquant(self, blocksize=128, percdamp=.01, groupsize=-1, actorder=False) Q = Q[:, invperm] g_idx = g_idx[invperm] - #TODO: Do we have to uncomment it ? - # if isinstance(self.layer, transformers.Conv1D): - # Q = Q.t() self.weight.data = Q.reshape(self.weight.shape).to(self.weight.data.dtype) if scale == []: @@ -267,9 +256,10 @@ def _fill_subset(self, layer_id): return {} # Keep only layer within block layer_id is_weight = re.compile(f'^linear{layer_id}_w$') - for name in self.w.keys(): - if is_weight.match(name): - self.subset[name] = self.w[name] + + for name in dir(self): + if is_weight.match(name): + self.subset[name] = getattr(self, name) return self.subset def alloc_gptq(self, layer_id): @@ -277,7 +267,7 @@ def alloc_gptq(self, layer_id): self.gptq = {} self.subset = self._fill_subset(layer_id) - + for name in self.subset: self.gptq[name] = self.GPTQ(self.subset[name], name) self.gptq[name].quantizer = Quantizer() @@ -299,7 +289,8 @@ def fasterquant(self, layer_id, quantizers): ## Begin SimpleNet_V2 def my_linear(self, x, weight, bias): - out = x @ weight.weight + bias + # out = x @ weight.weight.T + bias # Use version below as it is more stable + out = F.linear(x, weight.weight, bias) weight.add_batch(x) return out @@ -308,6 +299,7 @@ def forward(self, x): x = x.view(x.size(0), -1) residual = x + #TODO: maybe we would need to transpose weight when building linear0_quant ? 
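        # Context for the TODO above: make_quant_custom() builds QuantLinear_custom(bits, groupsize,
        # tmp.shape[0], tmp.shape[1], ...) directly from the stored weight tensor, so whether a
        # transpose is needed here depends on whether linear{i}_w is laid out as
        # (out_features, in_features) or already transposed; the question is left open at this point.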
x = F.relu(self.linear0_quant(x)) x = self.linear1_quant(x) x = F.relu(x) + residual @@ -320,7 +312,6 @@ def forward(self, x): @torch.no_grad() def quantize_gptq_custom(model, train_loader): - nb_layers = model.nb_layers is_last_layer = lambda x: x == (nb_layers - 1) @@ -333,40 +324,44 @@ def quantize_gptq_custom(model, train_loader): outs = torch.zeros_like(inps) quantizers = {} - + for layer_id in range(nb_layers): if not is_last_layer(layer_id): print(f"Quantizing layer {layer_id} ...") + bias = getattr(model, f"linear{layer_id}_b") + model.alloc_gptq(layer_id) for i in range(nsamples): - outs[i] = model.my_linear(inps[i], model.gptq[f"linear{layer_id}_w"], model.w[f"linear{layer_id}_b"]) - + outs[i] = model.my_linear(inps[i], model.gptq[f"linear{layer_id}_w"], bias) + model.gptq[f"linear{layer_id}_w"].deactivate_add_batch_call = True model.fasterquant(layer_id, quantizers) for i in range(nsamples): - outs[i] = model.my_linear(inps[i], model.gptq[f"linear{layer_id}_w"], model.w[f"linear{layer_id}_b"]) - + outs[i] = model.my_linear(inps[i], model.gptq[f"linear{layer_id}_w"], bias) + + setattr(model, f"linear{layer_id}_w", nn.Parameter(model.gptq[f"linear{layer_id}_w"].weight)) model.free_gptq() inps, outs = outs, inps - + return quantizers def model_pack_custom(model, quantizers, wbits, groupsize): # Extract weights and bias from model - is_weight = re.compile(r'^linear\d+_w$') + is_weight, is_bias = re.compile(r'^linear\d+_w$'), re.compile(r'^linear\d+_b$') weights, bias = OrderedDict(), OrderedDict() - for name, param in model.w.items(): - if is_weight.match(name): - weights[name] = param - else: - bias[name] = param + + for attr in dir(model): + if is_weight.match(attr): + weights[attr] = getattr(model, attr) + elif is_bias.match(attr): + bias[attr] = getattr(model, attr) make_quant_custom(model, quantizers, wbits, groupsize) qlayers = find_layers(model, [QuantLinear_custom]) @@ -383,13 +378,13 @@ def load_quant_custom(model, checkpoint, wbits, groupsize): print('Loading model ...') model = model.eval() # Extract weights and bias from model - is_weight = re.compile(r'^linear\d+_w$') + is_weight, is_bias = re.compile(r'^linear\d+_w$'), re.compile(r'^linear\d+_b$') weights, bias = OrderedDict(), OrderedDict() - for name, param in model.w.items(): - if is_weight.match(name): - weights[name] = param - else: - bias[name] = param + for attr in dir(model): + if is_weight.match(attr): + weights[attr] = getattr(model, attr) + elif is_bias.match(attr): + bias[attr] = getattr(model, attr) # Create linear layer out of weights and bias layers = {} @@ -442,10 +437,6 @@ def assert_parameters(model, model_custom): criterion = nn.CrossEntropyLoss() train_loader, _, _ = MNISTloader(train_val_split=0.95).load() - #TODO: Do custom eval gptq - #TODO: Is reference GPTQ quantizing bias as well ? 
- #TODO: Add seed everywhere in GPT for reproducibility - ## ================== REFERENCE ================== if args.train: model = SimpleNet() diff --git a/quantize/gptq/sanity_check_utils.py b/quantize/gptq/sanity_check_utils.py index 0c0a6888..22e4a1f3 100644 --- a/quantize/gptq/sanity_check_utils.py +++ b/quantize/gptq/sanity_check_utils.py @@ -8,6 +8,7 @@ from torch.utils.data import DataLoader, random_split from torchvision import datasets, transforms import math +import struct def seed_everything(seed: int): random.seed(seed) @@ -39,7 +40,6 @@ def forward(self, x): x = x.view(x.size(0), -1) residual = x - x = F.relu(self.linear1(x)) x = self.linear2(x) x = F.relu(x) + residual @@ -95,11 +95,7 @@ def __init__(self, num_classes=10): self.linear3_b = nn.Parameter(torch.nn.init.uniform_(torch.empty(num_classes), -bound, bound)) self.w = {} - self.nb_layers = 0 - for i in range(0, 4): - self.w[f"linear{i}_w"] = getattr(self, f"linear{i}_w") - self.w[f"linear{i}_b"] = getattr(self, f"linear{i}_b") - self.nb_layers += 1 + self.nb_layers = 4 def my_linear(self, x, weight, bias): # return x @ weight.t() + bias. @@ -252,3 +248,70 @@ def train(num_epochs, model, optimizer, criterion, train_loader, device): info = "Epoch: {:3}/{} \t train_loss: {:.3f} \t train_acc: {:.3f}" print(info.format(epoch + 1, num_epochs, train_loss, train_acc)) + +def write_bin(filename, array): + from functools import reduce + # Force endianess: https://stackoverflow.com/questions/23831422/what-endianness-does-python-use-to-write-into-files + dtype_to_format = { + np.int8: 'i', + np.int16: 'i', + np.int32: 'i', + np.int64: 'i', + np.unsignedinteger: 'I', + np.float16: 'f', + np.float32: 'f', + np.float64: 'f', + np.double: 'd' + } + fmt = dtype_to_format[array.dtype.type] + shapes = [shape for shape in array.shape] + # n, c, h, w = array.shape + with open(filename, "wb") as f: + # number of dim + f.write(struct.pack('I', len(shapes))) + for shape in shapes: + f.write(struct.pack('I', shape)) + f.write(struct.pack('c', bytes(fmt, 'utf-8'))) + f.write(struct.pack(f"{fmt}"*(reduce(lambda x, y: x * y, shapes)), *array.flatten(order="C").tolist())) + +def read_bin(filename): + # https://qiita.com/madaikiteruyo/items/dadc99aa29f7eae0cdd0 + format_to_byte = { + 'c': 1, + 'i': 4, + 'I': 4, + 'f': 4, + 'd': 8 + } + + data = [] + dims, fmt = None, None + with open(filename, "rb") as f: + # read row and col (np.int = 4 bytes) + byte = f.read(format_to_byte['i']) + + if byte == b'': + raise Exception("read_bin: Empty binary") + else: + nb_dim = struct.unpack('I', byte) + + # Read dims + byte = f.read(nb_dim[0] * format_to_byte['I']) + dims = struct.unpack('I'*nb_dim[0], byte) + # Read character format + byte = f.read(1) + if byte == b'': + raise Exception("read_bin: Empty binary") + else: + fmt = chr(struct.unpack('c', byte)[0][0]) + + if len(fmt) != 1: raise Exception("read_bin: No format dumped in binary") + + while True: + byte = f.read(format_to_byte[fmt]) + if byte == b'': + break + else: + data.append(struct.unpack(fmt, byte)[0]) + + return np.array(data).reshape(*dims) \ No newline at end of file From cf14124dbdb911ac2d99decbc1f54a46419d5f31 Mon Sep 17 00:00:00 2001 From: Ferdinand Mom Date: Wed, 3 May 2023 12:30:10 +0000 Subject: [PATCH 17/20] feat(quantize): readapt GPTQ for rwkv --- quantize/tmp_rwkv.py | 166 +++++++++++++++++++++++++------------------ 1 file changed, 97 insertions(+), 69 deletions(-) diff --git a/quantize/tmp_rwkv.py b/quantize/tmp_rwkv.py index 1ed2ad1f..83d0dfd7 100644 --- a/quantize/tmp_rwkv.py +++ 
b/quantize/tmp_rwkv.py @@ -5,11 +5,14 @@ import os import torch.nn.functional as F -import torch.nn as nn +from collections import OrderedDict import time import math import re +WBITS = 8 +GROUPSIZE = -1 + class GPTQ_RWKV(RWKV): ### begin GPTQ @@ -29,7 +32,6 @@ def __init__(self, weight, name): self.deactivate_add_batch_call = False def add_batch(self, inp): - # After calling fasterquant, we don't want to call add_batch anymore if self.deactivate_add_batch_call: return @@ -37,9 +39,8 @@ def add_batch(self, inp): if len(inp.shape) == 2: inp = inp.unsqueeze(0) - #TODO: is the case with len = 1 still necessary ? - tmp = 1 if len(inp.shape) == 1 else inp.shape[0] - + tmp = inp.shape[0] + # Assume weight come from nn.Linear if len(inp.shape) == 3: inp = inp.reshape((-1, inp.shape[-1])) @@ -52,7 +53,9 @@ def add_batch(self, inp): def fasterquant(self, blocksize=128, percdamp=.01, groupsize=-1, actorder=False): W = self.weight.data.clone() - # Need to transpose here, same reason as in __init__ with self.columns + # OLD: Need to transpose here, same reason as in __init__ with self.columns + # UPDATE: no need to tranpose as we already transpose in my_linear() + # UPDATE2: for rwkv, this is necessary W = W.t() W = W.float() @@ -63,10 +66,11 @@ def fasterquant(self, blocksize=128, percdamp=.01, groupsize=-1, actorder=False) H = self.H del self.H + dead = torch.diag(H) == 0 H[dead, dead] = 1 W[:, dead] = 0 - + if actorder: perm = torch.argsort(torch.diag(H), descending=True) W = W[:, perm] @@ -82,6 +86,11 @@ def fasterquant(self, blocksize=128, percdamp=.01, groupsize=-1, actorder=False) H = torch.cholesky_inverse(H) H = torch.linalg.cholesky(H, upper=True) Hinv = H + + g_idx = [] + scale = [] + zero = [] + now_idx = 1 for i1 in range(0, self.columns, blocksize): i2 = min(i1 + blocksize, self.columns) @@ -101,6 +110,11 @@ def fasterquant(self, blocksize=128, percdamp=.01, groupsize=-1, actorder=False) if (i1 + i) % groupsize == 0: self.quantizer.find_params(W[:, (i1 + i):(i1 + i + groupsize)], weight=True) + if ((i1 + i) // groupsize) - now_idx == -1: + scale.append(self.quantizer.scale) + zero.append(self.quantizer.zero) + now_idx += 1 + q = quantize( w.unsqueeze(1), self.quantizer.scale, self.quantizer.zero, self.quantizer.maxq ).flatten() @@ -116,15 +130,27 @@ def fasterquant(self, blocksize=128, percdamp=.01, groupsize=-1, actorder=False) W[:, i2:] -= Err1.matmul(Hinv[i1:i2, i2:]) + torch.cuda.synchronize() print('time %.2f' % (time.time() - tick)) print('error', torch.sum(Losses).item()) - + + groupsize = groupsize if groupsize != -1 else self.columns + g_idx = [i // groupsize for i in range(self.columns)] + g_idx = torch.tensor(g_idx, dtype=torch.int32, device=Q.device) if actorder: invperm = torch.argsort(perm) Q = Q[:, invperm] + g_idx = g_idx[invperm] self.weight.data = Q.reshape(self.weight.shape).to(self.weight.data.dtype) + + if scale == []: + scale.append(self.quantizer.scale) + zero.append(self.quantizer.zero) + scale = torch.cat(scale,dim=1) + zero = torch.cat(zero,dim=1) + return scale,zero,g_idx ### end GPTQ @@ -134,6 +160,7 @@ def __init__(self, model, strategy): for i in range(self.args.n_layer): assert self.strategy[i].device == "cpu" + #TODO: Change to match my implem def _fill_subset(self, layer_id): # Keep only layer within block layer_id is_weight = re.compile(f'^blocks\.{layer_id}\..*\.weight$') @@ -146,18 +173,18 @@ def _fill_subset(self, layer_id): if is_last_layer: self.subset["head.weight"] = self.w["head.weight"] - + return self.subset + def alloc_gptq(self, layer_id): 
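        # alloc_gptq() wraps every weight returned by _fill_subset() in a GPTQ accumulator with a
        # fresh Quantizer configured for WBITS (perchannel, asymmetric); free_gptq() releases them
        # once the block has been quantized.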
self.subset = {} self.gptq = {} - self._fill_subset(layer_id) - + self.subset = self._fill_subset(layer_id) + for name in self.subset: self.gptq[name] = self.GPTQ(self.subset[name], name) self.gptq[name].quantizer = Quantizer() - #TODO: add argparse to configure - self.gptq[name].quantizer.configure(bits=4, perchannel=True, sym=False, mse=False, trits=False) + self.gptq[name].quantizer.configure(bits=WBITS, perchannel=True, sym=False, mse=False, trits=False) def free_gptq(self): self.subset = {} @@ -166,11 +193,10 @@ def free_gptq(self): def fasterquant(self, layer_id, quantizers): for name in self.subset: - print(f"Quantizing {name} of layer {layer_id}") - #TODO: add argparse to fastquant - self.gptq[name].fasterquant(percdamp=0.01, groupsize=-1, actorder=False) - # self.gptq[name].fastquant(percdamp=args.percdamp, groupsize=args.groupsize, actorder=args.act_order) - quantizers[name] = self.gptq[name].quantizer + print(layer_id, name) + print('Quantizing ...') + scale,zero,g_idx = self.gptq[name].fasterquant(percdamp=0.01, groupsize=GROUPSIZE, actorder=False) + quantizers[f"linear{layer_id}_w"] = (self.gptq[name].quantizer.cpu(), scale.cpu(), zero.cpu(), g_idx.cpu()) ### end GPTQ_RWKV @@ -326,7 +352,7 @@ def forward_block(self, x, state, i, seq_mode, full_output=False): orx = self.w[f'{att}output.weight_rx'] if wtype == torch.uint8 else x omy = self.w[f'{att}output.weight_my'] if wtype == torch.uint8 else x ory = self.w[f'{att}output.weight_ry'] if wtype == torch.uint8 else x - + x, state[i*5+0], state[i*5+1], state[i*5+2], state[i*5+3] = ATT( x=x, sx=state[i*5+0], aa=state[i*5+1], bb=state[i*5+2], pp=state[i*5+3], ln_w=self.w[f'{bbb}ln1.weight'], ln_b=self.w[f'{bbb}ln1.bias'], @@ -338,12 +364,6 @@ def forward_block(self, x, state, i, seq_mode, full_output=False): rmx=rmx, rrx=rrx, rmy=rmy, rry=rry, omx=omx, orx=orx, omy=omy, ory=ory, ) - - # Deactivate add_batch() after quantization is applied - kw.deactivate_add_batch_call = True - vw.deactivate_add_batch_call = True - rw.deactivate_add_batch_call = True - ow.deactivate_add_batch_call = True if dd.stream: del kw, vw, rw, ow @@ -378,11 +398,6 @@ def forward_block(self, x, state, i, seq_mode, full_output=False): vmx=vmx, vrx=vrx, vmy=vmy, vry=vry, rmx=rmx, rrx=rrx, rmy=rmy, rry=rry, ) - - # Deactivate add_batch() after quantization is applied - kw.deactivate_add_batch_call = True - vw.deactivate_add_batch_call = True - rw.deactivate_add_batch_call = True if dd.stream: del kw, vw, rw @@ -392,7 +407,6 @@ def forward_block(self, x, state, i, seq_mode, full_output=False): x = x / 2 is_last_layer = i == (args.n_layer - 1) - if is_last_layer: dd = self.strategy[args.n_layer] x = x[-1,:] if (seq_mode and (not full_output)) else x @@ -410,63 +424,77 @@ def forward_block(self, x, state, i, seq_mode, full_output=False): ### end RWKV -model = GPTQ_RWKV("./RWKV-4-Pile-169M-20220807-8023.pth", strategy='cpu fp32') - -NSAMPLES=2 -HIDDEN_SIZE=model.args.n_embd -SEQLEN=1024 # cf https://huggingface.co/BlinkDL/rwkv-4-pile-169m - -# train_tokens, test_tokens = get_loaders( -# dataset_name="wikitext2", -# nsamples=NSAMPLES, -# seed=42, -# seqlen=SEQLEN, -# model=model -# ) - -# tokens = torch.cat([inp for inp, _ in train_tokens], dim=0) -tokens = torch.zeros((NSAMPLES, SEQLEN), dtype=torch.int64) -print("tokens.shape", tokens.shape) - -is_last_layer = lambda x: x == (model.args.n_layer - 1) - -start_time = time.time() - -#TODO: Do the same in GPU side -with torch.no_grad(): +@torch.no_grad() +def quantize_gptq_custom(model, tokens): + nsamples = 
tokens.shape[0] seq_mode = len(tokens) > 1 + is_last_layer = lambda x: x == (model.args.n_layer - 1) + inps = model.w['emb.weight'][tokens if seq_mode else tokens[0]] outs = torch.zeros_like(inps) - quantizers = {} - + for layer_id in range(model.args.n_layer): + + print(f"Quantizing layer {layer_id} ...") model.alloc_gptq(layer_id) - for j in range(NSAMPLES): + for i in range(nsamples): + #TODO: Are outs value normal ? (they look almost all the same) if not is_last_layer(layer_id): - outs[j] = model.forward_block(inps[j], state=None, i=layer_id, seq_mode=seq_mode) + outs[i] = model.forward_block(inps[i], state=None, i=layer_id, seq_mode=seq_mode) else: - _ = model.forward_block(inps[j], state=None, i=layer_id, seq_mode=seq_mode) - + _ = model.forward_block(inps[i], state=None, i=layer_id, seq_mode=seq_mode) + + for gptq_layer in model.gptq.values(): + gptq_layer.deactivate_add_batch_call = True + + tmp = model.w["blocks.0.att.key.weight"] + model.fasterquant(layer_id, quantizers) - for j in range(NSAMPLES): + for i in range(nsamples): if not is_last_layer(layer_id): - outs[j] = model.forward_block(inps[j], state=None, i=layer_id, seq_mode=seq_mode) + outs[i] = model.forward_block(inps[i], state=None, i=layer_id, seq_mode=seq_mode) else: - _ = model.forward_block(inps[j], state=None, i=layer_id, seq_mode=seq_mode) - + _ = model.forward_block(inps[i], state=None, i=layer_id, seq_mode=seq_mode) + + # Assign the quantized weights to the model + for key in model.gptq.keys(): + model.w[key].copy_(model.gptq[key].weight) + model.free_gptq() # We need to pass the outputs of block i as input of block i+1 (except for last block) if not is_last_layer(layer_id): inps, outs = outs, inps -end_time = time.time() + return quantizers + +if __name__ == "__main__": + + model = GPTQ_RWKV("./RWKV-4-Pile-169M-20220807-8023.pth", strategy='cpu fp32') + + NSAMPLES=2 + HIDDEN_SIZE=model.args.n_embd + SEQLEN=1024 # cf https://huggingface.co/BlinkDL/rwkv-4-pile-169m + + train_tokens, test_tokens = get_loaders( + dataset_name="wikitext2", + nsamples=NSAMPLES, + seed=42, + seqlen=SEQLEN, + model=model + ) + + tokens = torch.cat([inp for inp, _ in train_tokens], dim=0) + tokens = torch.zeros((NSAMPLES, SEQLEN), dtype=torch.int64) + print("tokens.shape", tokens.shape) -print(f"Done in {end_time - start_time:.2f} seconds") + import pdb; pdb.set_trace() + # quantizers = quantize_gptq_custom(model, tokens) -# TODO: Do something with quantizers dictionary -# TODO: pack3 save model \ No newline at end of file + # model_pack_custom(model, quantizers, WBITS, GROUPSIZE) + # torch.save(model.state_dict(), "model_quantized_custom.pt") + # print("Done Custom GPTQ") \ No newline at end of file From c2bbe6430007fc81d4e6a0b9f132dc9c2eda04db Mon Sep 17 00:00:00 2001 From: Ferdinand Mom Date: Sun, 7 May 2023 19:47:17 +0000 Subject: [PATCH 18/20] breaking(gptq): quantizing only 1 layer yield high perplexity --- quantize/compress_rwkv.py | 239 ----------- quantize/gptq/datautils.py | 2 +- quantize/gptq/quant.py | 5 +- quantize/measure_perplexity.py | 15 +- quantize/myRWKV.py | 747 +++++++++++++++++++++++++++++++++ quantize/tmp_rwkv.py | 186 +++++--- quantize/tmp_rwkv.py.lprof | Bin 0 -> 426 bytes 7 files changed, 892 insertions(+), 302 deletions(-) delete mode 100644 quantize/compress_rwkv.py create mode 100644 quantize/myRWKV.py create mode 100644 quantize/tmp_rwkv.py.lprof diff --git a/quantize/compress_rwkv.py b/quantize/compress_rwkv.py deleted file mode 100644 index dbe17a87..00000000 --- a/quantize/compress_rwkv.py +++ /dev/null 
@@ -1,239 +0,0 @@ -import time -import torch -import torch.nn as nn - -from rwkv.model import RWKV -from gptq.gptq import * -from gptq.modelutils import * -from gptq.quant import * -from gptq.datautils import * - -# TODO: perform packing on GPU -def opt_pack3(model, quantizers): - layers = find_layers(model) - layers = {n: layers[n] for n in quantizers} - make_quant3(model, quantizers, faster=args.faster_kernel) - qlayers = find_layers(model, [Quant3Linear]) - print('Packing ...') - for name in qlayers: - print(name) - quantizers[name] = quantizers[name].cpu() - qlayers[name].pack(layers[name], quantizers[name].scale, quantizers[name].zero) - print('Done.') - return model - -@torch.no_grad() -def quantize_model(model, train_tokens, device): - print('Starting ...') - - use_cache = model.config.use_cache - model.config.use_cache = False - layers = model.model.decoder.layers - - # Load layer to device - model.model.decoder.embed_tokens = model.model.decoder.embed_tokens.to(device) - model.model.decoder.embed_positions = model.model.decoder.embed_positions.to(device) - if hasattr(model.model.decoder, 'project_out') and model.model.decoder.project_out: - model.model.decoder.project_out = model.model.decoder.project_out.to(device) - if hasattr(model.model.decoder, 'project_in') and model.model.decoder.project_in: - model.model.decoder.project_in = model.model.decoder.project_in.to(device) - layers[0] = layers[0].to(device) - - dtype = next(iter(model.parameters())).dtype - inps = torch.zeros( - (args.nsamples, model.seqlen, model.config.hidden_size), dtype=dtype, device=device - ) - cache = {'i': 0, 'attention_mask': None} - - class Catcher(nn.Module): - def __init__(self, module): - super().__init__() - self.module = module - def forward(self, inp, **kwargs): - inps[cache['i']] = inp - cache['i'] += 1 - cache['attention_mask'] = kwargs['attention_mask'] - raise ValueError - layers[0] = Catcher(layers[0]) - - for batch in train_tokens: - try: - # model(batch[0].to(device)) - # IndexError: invalid index of a 0-dim tensor. 
- # Use `tensor.item()` in Python or `tensor.item()` - # in C++ to convert a 0-dim tensor to a number - model(batch[0].to(device)) - except ValueError: - pass - layers[0] = layers[0].module - - layers[0] = layers[0].cpu() - model.model.decoder.embed_tokens = model.model.decoder.embed_tokens.cpu() - model.model.decoder.embed_positions = model.model.decoder.embed_positions.cpu() - if hasattr(model.model.decoder, 'project_out') and model.model.decoder.project_out: - model.model.decoder.project_out = model.model.decoder.project_out.cpu() - if hasattr(model.model.decoder, 'project_in') and model.model.decoder.project_in: - model.model.decoder.project_in = model.model.decoder.project_in.cpu() - torch.cuda.empty_cache() - - outs = torch.zeros_like(inps) - attention_mask = cache['attention_mask'] - - print('Ready.') - - quantizers = {} - for i in range(len(layers)): - layer = layers[i].to(device) - - subset = find_layers(layer) - gptq = {} - for name in subset: - gptq[name] = GPTQ(subset[name]) - gptq[name].quantizer = Quantizer() - gptq[name].quantizer.configure( - args.wbits, perchannel=True, sym=args.sym, mse=False, trits=args.trits - ) - - def add_batch(name): - def tmp(_, inp, out): - gptq[name].add_batch(inp[0].data, out.data) - return tmp - handles = [] - for name in subset: - handles.append(subset[name].register_forward_hook(add_batch(name))) - for j in range(args.nsamples): - outs[j] = layer(inps[j].unsqueeze(0), attention_mask=attention_mask)[0] - for h in handles: - h.remove() - - for name in subset: - print(i, name) - print('Quantizing ...') - gptq[name].fasterquant(percdamp=args.percdamp, groupsize=args.groupsize, actorder=args.act_order) - quantizers['model.decoder.layers.%d.%s' % (i, name)] = gptq[name].quantizer - gptq[name].free() - for j in range(args.nsamples): - outs[j] = layer(inps[j].unsqueeze(0), attention_mask=attention_mask)[0] - - layers[i] = layer.cpu() - del layer - del gptq - torch.cuda.empty_cache() - - inps, outs = outs, inps - - model.config.use_cache = use_cache - - return quantizers - -if __name__ == "__main__": - import argparse - - parser = argparse.ArgumentParser() - - parser.add_argument( - '--model_path', type=str, - help='Path to model checkpoint file.' - ) - - parser.add_argument( - '--dataset_name', type=str, choices=['wikitext2', 'ptb', 'c4'], - help='Where to extract calibration data from.' - ) - - parser.add_argument( - '--wbits', type=int, default=16, choices=[2, 3, 4, 16], - help='#bits to use for quantization; use 16 for evaluating base model.' - ) - - parser.add_argument( - '--groupsize', type=int, default=-1, - help='Groupsize to use for quantization; default uses full row.' - ) - - parser.add_argument( - '--save', type=str, default='', - help='Save quantized checkpoint under this name.' - ) - - # ==== DEFAULT ==== - parser.add_argument( - '--seed', - type=int, default=0, help='Seed for sampling the calibration data.' - ) - parser.add_argument( - '--nsamples', type=int, default=128, - help='Number of calibration data samples.' - ) - parser.add_argument( - '--percdamp', type=float, default=.01, - help='Percent of the average Hessian diagonal to use for dampening.' - ) - - parser.add_argument( - '--nearest', action='store_true', - help='Whether to run the RTN baseline.' - ) - - parser.add_argument( - '--trits', action='store_true', - help='Whether to use trits for quantization.' - ) - - parser.add_argument( - '--sym', action='store_true', - help='Whether to perform symmetric quantization.' 
- ) - - parser.add_argument( - '--faster-kernel', action='store_true', - help='Whether to use the new faster kernel for benchmarking.' - ) - - parser.add_argument( - '--act-order', action='store_true', - help='Whether to apply the activation order GPTQ heuristic' - ) - - args = parser.parse_args() - - # FIXME: Seems like quantization with OPT is not working in CPU mode - # device = torch.device('cpu') - device = torch.device('cuda:0') - - # Model - # model = model = RWKV(args.model_path, strategy='cpu fp32') - - def skip(*args, **kwargs): pass - torch.nn.init.kaiming_uniform_ = skip - torch.nn.init.uniform_ = skip - torch.nn.init.normal_ = skip - from transformers import OPTForCausalLM - model = OPTForCausalLM.from_pretrained("facebook/opt-125m", torch_dtype='auto') - model.seqlen = model.config.max_position_embeddings - - model.eval() - - # Dataset - train_tokens, test_tokens = get_loaders( - dataset_name=args.dataset_name, - nsamples=args.nsamples, - seed=args.seed, - seqlen=model.seqlen, - model="facebook/opt-125m", - # model=None - ) - - print(f'{len(train_tokens)} train tokens in the text') - print(f'{len(test_tokens)} test tokens in the text') - - if args.wbits < 16 and not args.nearest: - start_time = time.time() - quantizers = quantize_model(model, train_tokens, device) - end_time = time.time() - print('Quantization time: ', end_time - start_time) - - if args.save: - print('Saving quantized model to ', args.save) - opt_pack3(model, quantizers) - torch.save(model.state_dict(), args.save) \ No newline at end of file diff --git a/quantize/gptq/datautils.py b/quantize/gptq/datautils.py index cd296d3c..166e6547 100644 --- a/quantize/gptq/datautils.py +++ b/quantize/gptq/datautils.py @@ -4,7 +4,7 @@ import pathlib import tokenizers import random -from rwkv.model import RWKV +from myRWKV import RWKV from datasets import load_dataset diff --git a/quantize/gptq/quant.py b/quantize/gptq/quant.py index 23933bae..f584325b 100644 --- a/quantize/gptq/quant.py +++ b/quantize/gptq/quant.py @@ -202,7 +202,8 @@ def pack(self, weight, bias, scales, zeros, g_idx = None): intweight = [] for idx in range(self.infeatures): - intweight.append(torch.round((weight.data[:,idx] + scale_zeros[self.g_idx[idx]]) / self.scales[self.g_idx[idx]]).to(torch.int)[:,None]) + #OLD: intweight.append(torch.round((weight.data[:,idx] + scale_zeros[self.g_idx[idx]]) / self.scales[self.g_idx[idx]]).to(torch.int)[:,None]) + intweight.append(torch.round((weight.data[idx, :] + scale_zeros[self.g_idx[idx]]) / self.scales[self.g_idx[idx]]).to(torch.int)[:,None]) intweight = torch.cat(intweight,dim=1) intweight = intweight.t().contiguous() intweight = intweight.numpy().astype(np.uint32) @@ -411,7 +412,7 @@ def pack(self, linear, scales, zeros, g_idx = None): qweight = qweight.astype(np.int32) self.qweight = torch.from_numpy(qweight) - zeros -= 1; + zeros -= 1 zeros = zeros.numpy().astype(np.uint32) qzeros = np.zeros((zeros.shape[0], zeros.shape[1] // 32 * self.bits), dtype=np.uint32) i = 0 diff --git a/quantize/measure_perplexity.py b/quantize/measure_perplexity.py index e0b5be29..8cdbe4a3 100644 --- a/quantize/measure_perplexity.py +++ b/quantize/measure_perplexity.py @@ -10,6 +10,8 @@ import torch from typing import List from rwkv.model import RWKV +os.environ['RWKV_JIT_ON'] = '1' +os.environ["RWKV_CUDA_ON"] = '0' def parse_args(): parser = argparse.ArgumentParser(description='Measure perplexity and per-token latency of an RWKV model on a given text file') @@ -56,9 +58,10 @@ def format_loss_with_perplexity(loss: torch.Tensor) -> 
str: # --- device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') -# device=torch.device('cpu') +# device = torch.device('cpu') -model = RWKV(model=args.model_path, strategy='cuda fp16i8') +#TODO: Why is PERPLEXITY SO DAMN HIGH ? +model = RWKV(model=args.model_path, strategy='cuda fp16') logits, state = None, None loss_sum: torch.Tensor = torch.tensor([0.0], device=device) @@ -72,7 +75,7 @@ def format_loss_with_perplexity(loss: torch.Tensor) -> str: for i in range(run_count): token: int = test_tokens[i] target: int = test_tokens[i + 1] - + logits, state = model.forward([token], None if i == 0 else state) if ignore_first_n_tokens == 0 or i + 1 >= ignore_first_n_tokens: @@ -105,7 +108,7 @@ def format_loss_with_perplexity(loss: torch.Tensor) -> str: print(f'Average latency: {int((time.time() - start) * 1000 / run_count)} ms per token') print() -print(f'Model: {os.path.basename(args.model_path)}, ' - f'data: {os.path.basename(args.dataset_path)} with {token_count} tokens, ' - f'Ignored first {ignore_first_n_tokens} tokens, ' +print(f'Model: {os.path.basename(args.model_path)}\n' + f'data: {os.path.basename(args.dataset_path)} with {token_count} tokens\n' + f'Ignored first {ignore_first_n_tokens} tokens\n' f'averages: {format_loss_with_perplexity(loss_sum / loss_count)}') diff --git a/quantize/myRWKV.py b/quantize/myRWKV.py new file mode 100644 index 00000000..524b9481 --- /dev/null +++ b/quantize/myRWKV.py @@ -0,0 +1,747 @@ +######################################################################################################## +# The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM +######################################################################################################## + +import types, gc, os, time, re +import torch +from torch.nn import functional as F +torch.backends.cudnn.benchmark = True +torch.backends.cudnn.allow_tf32 = True +torch.backends.cuda.matmul.allow_tf32 = True +current_path = os.path.dirname(os.path.abspath(__file__)) + +######################################################################################################## + +if os.environ.get('RWKV_JIT_ON') != '0': + os.environ["RWKV_JIT_ON"] = '1' + MyModule = torch.jit.ScriptModule + MyFunction = torch.jit.script_method + MyStatic = torch.jit.script +else: + MyModule = torch.nn.Module + def __nop(ob): + return ob + MyFunction = __nop + MyStatic = __nop + +if os.environ.get('RWKV_CUDA_ON') == '1': + from torch.utils.cpp_extension import load + load( + name=f"wkv_cuda", + sources=[f"{current_path}/cuda/wrapper.cpp", f"{current_path}/cuda/operators.cu"], + verbose=True, + extra_cuda_cflags=["-t 4", "-std=c++17", "--use_fast_math", "-O3", "--extra-device-vectorization"], + is_python_module=False) + + @MyStatic + def cuda_wkv(T: int, C: int, w, u, k, v, aa, bb, pp): + assert 1 * C % min(C, 32) == 0 + assert k.dtype == torch.float16 + w = w.contiguous() + u = u.contiguous() + k = k.contiguous() + v = v.contiguous() + y = torch.empty((T, C), device=w.device, memory_format=torch.contiguous_format, dtype=torch.float16) + torch.ops.rwkv.wkv_forward(1, T, C, w, u, k, v, y, aa, bb, pp) + return y, aa, bb, pp + @MyStatic + def cuda_mm8_seq(B: int, N: int, M: int, x, w, mx, rx, my, ry): + assert x.dtype == mx.dtype == rx.dtype == my.dtype == ry.dtype == torch.float16 + assert w.dtype == torch.uint8 + assert x.shape == [B, N] + assert w.shape == [N, M] + assert rx.shape == mx.shape == [M] + assert ry.shape == my.shape == [N, 1] + y = torch.empty((B, M), device=w.device, dtype=torch.float16) + 
torch.ops.rwkv.mm8_seq(B, N, M, x, w, mx, rx, my, ry, y) + return y + @MyStatic + def cuda_mm8_one(N: int, M: int, x, w, mx, rx, my, ry): + assert x.dtype == mx.dtype == rx.dtype == my.dtype == ry.dtype == torch.float16 + assert w.dtype == torch.uint8 + assert x.shape == [N] + assert w.shape == [N, M] + assert rx.shape == mx.shape == [M] + assert ry.shape == my.shape == [N, 1] + y = torch.zeros((M,), device=w.device, dtype=torch.float32) + torch.ops.rwkv.mm8_one(N, M, x, w, mx, rx, my, ry, y) + return y.to(dtype=torch.float16) +else: + os.environ["RWKV_CUDA_ON"] = '0' + +######################################################################################################## + +class RWKV(MyModule): + def __init__(self, model, strategy, verbose = True, convert_and_save_and_exit = None): + super().__init__() + if verbose: + prxxx = lambda *args, **kwargs: print(*args, **kwargs) + else: + prxxx = lambda *args, **kwargs: None + + STRATEGY_REGEX = r"^(?:(?:^|->) *(?:cuda(?::[\d]+)?|cpu|mps) (?:fp(?:16|32)|bf16)(?:i8|i4|i3)?(?: \*[\d]+\+?)? *)+$" + if not re.match(STRATEGY_REGEX, strategy): + raise ValueError("Invalid strategy. Please read https://pypi.org/project/rwkv/") + + strategy = ('->'.join([x.strip() for x in strategy.split('->')])).replace('->', ' -> ') + self.args = types.SimpleNamespace() + args = self.args + args.MODEL_NAME = model + args.strategy_string = strategy + + # Rescale for fp16 mode: set x = x/2 every X layer (to avoid fp16 overflow) + self.RESCALE_LAYER = 6 if 'fp16' in strategy else 0 + prxxx(f'RWKV_JIT_ON {os.environ["RWKV_JIT_ON"]} RWKV_CUDA_ON {os.environ["RWKV_CUDA_ON"]} RESCALE_LAYER {self.RESCALE_LAYER}\n') + + args.MODEL_NAME = args.MODEL_NAME.strip() + if not args.MODEL_NAME.endswith('.pth'): + args.MODEL_NAME += '.pth' + prxxx(f'Loading {args.MODEL_NAME} ...') + with torch.no_grad(): + obj = torch.load(args.MODEL_NAME, map_location='cpu') # load model to CPU first + if isinstance(obj, list): # GPTQ + self.w_quant = obj[0] + self.w = obj[1] + gc.collect() + w = self.w + ALREADY_CONVERTED = True + else: + self.w_quant = {} + self.w = obj + gc.collect() + w = self.w + + ALREADY_CONVERTED = False + if '_strategy' in w: + ALREADY_CONVERTED = True + assert convert_and_save_and_exit == None # you should only convert a raw model + prxxx(f"Converted model: strategy {w['_strategy']}, version {w['_version']}\n") + assert w['_strategy'] == args.strategy_string # if you are using a new strategy, re-convert the model + assert float(w['_version']) >= 0.7 # sometimes you should re-convert using latest convert_model.py + assert w['_rescale_layer'] == self.RESCALE_LAYER + del w['_strategy'] + del w['_version'] + del w['_rescale_layer'] + + args.n_embd = w['emb.weight'].shape[1] + args.n_layer = 0 + keys = list(w.keys()) + for x in keys: + layer_id = int(x.split('.')[1]) if ('blocks.' 
in x) else 0 + args.n_layer = max(args.n_layer, layer_id+1) + + ####################### Compute strategy + + s = [x.strip().split(' ') for x in strategy.split('->')] + plan = [0] * len(s) + stream_i = -1 + stream_count = 0 + to_allocate = args.n_layer + 1 + allocated = 0 + free_slots = 0 + for i in range(len(s)): + si = s[i] + si1 = si[1] + if si1.startswith('fp32'): si[1] = [torch.float] + elif si1.startswith('fp16'): si[1] = [torch.float16] + elif si1.startswith('bf16'): si[1] = [torch.bfloat16] + if si1.endswith('i8'): si[1] += [torch.uint8] + else: si[1] += [si[1][0]] + if len(si) > 2: + ss = si[2] + assert ss.startswith('*') + if ss.endswith('+'): + plan[i] = int(ss[1:-1]) + stream_i = i + else: + plan[i] = int(ss[1:]) + allocated += plan[i] + if allocated >= to_allocate: + plan[i] += to_allocate - allocated + break + else: + free_slots += 1 + if stream_i < 0: + if free_slots > 0 and to_allocate > allocated: + for i in range(len(s)): + if plan[i] == 0: + plan[i] = (to_allocate - allocated) // free_slots + allocated += plan[i] + free_slots -= 1 + if to_allocate > allocated: + plan[len(s)-1] += to_allocate - allocated + else: + if to_allocate > allocated: + stream_count = to_allocate - allocated + plan[stream_i] += stream_count + prxxx(f'Strategy: (total {args.n_layer}+1={args.n_layer+1} layers)') + for i in range(len(s)): + ss = s[i] + if i != stream_i: + prxxx(f'* {ss[0]} {str(ss[1]).replace("torch.","")}, store {plan[i]} layers') + else: + prxxx(f'* {ss[0]} {str(ss[1]).replace("torch.","")}, store {plan[i]-stream_count} layers, stream {stream_count} layers') + plan[i] += (0 if i == 0 else plan[i-1]) + self.strategy = [None] * (args.n_layer + 1) + strategy = self.strategy + + for n in range(args.n_layer + 1): + for i in range(len(s)): + if n < plan[i]: + strategy[n] = types.SimpleNamespace() + strategy[n].device = s[i][0] + strategy[n].atype = s[i][1][0] + strategy[n].wtype = s[i][1][1] + strategy[n].stream = False + if i == stream_i and n >= (plan[i] - stream_count): + strategy[n].stream = True + break + prxxx(f"{n}-{strategy[n].device}-{str(strategy[n].atype).replace('torch.','')}-{str(strategy[n].wtype).replace('torch.','')}{'-stream' if strategy[n].stream else ''}",end=' ') + prxxx() + + ####################### Load weights to self.w + if not ALREADY_CONVERTED: + try: # precompute embedding + w['emb.weight'] = F.layer_norm(w['emb.weight'], (args.n_embd,), weight=w['blocks.0.ln0.weight'], bias=w['blocks.0.ln0.bias']) + except: + w['emb.weight'] = F.layer_norm(w['emb.weight'].float(), (args.n_embd,), weight=w['blocks.0.ln0.weight'].float(), bias=w['blocks.0.ln0.bias'].float()) + del w['blocks.0.ln0.weight'] + del w['blocks.0.ln0.bias'] + + print_need_newline = False + keys = list(w.keys()) + for x in keys: + w[x].requires_grad = False + layer_id = int(x.split('.')[1]) if ('blocks.' in x) else 0 + if ('ln_out.' in x) or ('head.' 
in x): + layer_id = args.n_layer + dd = strategy[layer_id] + DEVICE = dd.device + ATYPE = dd.atype + WTYPE = dd.wtype + + if not ALREADY_CONVERTED: + if self.RESCALE_LAYER > 0: + if 'att.output.weight' in x: + w[x] = w[x] / (2 ** int(layer_id // self.RESCALE_LAYER)) + if 'ffn.value.weight' in x: + w[x] = w[x] / (2 ** int(layer_id // self.RESCALE_LAYER)) + + if '.time_' in x: + w[x] = w[x].squeeze() + if 'key.weight' in x or 'value.weight' in x or 'receptance.weight' in x or 'output.weight' in x or 'head.weight' in x: + w[x] = w[x].t() + + if '.time_decay' in x: # need fp32 for this + w[x] = -torch.exp(w[x].float()) + elif '.time_first' in x: # need fp32 for this + w[x] = w[x].float() + else: + if (len(w[x].shape) == 2) and ('emb' not in x): + if WTYPE != torch.uint8: + w[x] = w[x].to(dtype=WTYPE) + else: + w[x] = w[x].float() + + if w[x].shape[0] > w[x].shape[1]: + w[x+'_my'] = torch.amin(w[x], dim=1).unsqueeze(1) + w[x] = w[x] - w[x+'_my'] + w[x+'_mx'] = torch.amin(w[x], dim=0) + w[x] = w[x] - w[x+'_mx'] + w[x+'_rx'] = torch.amax(w[x], dim=0) + w[x] = w[x] / w[x+'_rx'] + w[x+'_ry'] = torch.amax(w[x], dim=1).unsqueeze(1) + w[x] = w[x] / w[x+'_ry'] + else: + w[x+'_mx'] = torch.amin(w[x], dim=0) + w[x] = w[x] - w[x+'_mx'] + w[x+'_my'] = torch.amin(w[x], dim=1).unsqueeze(1) + w[x] = w[x] - w[x+'_my'] + w[x+'_rx'] = torch.amax(w[x], dim=0) + w[x] = w[x] / w[x+'_rx'] + w[x+'_ry'] = torch.amax(w[x], dim=1).unsqueeze(1) + w[x] = w[x] / w[x+'_ry'] + + w[x] = torch.clip(torch.floor(w[x] * 256), min=0, max=255).to(dtype=torch.uint8) + w[x+'_mx'] = w[x+'_mx'].to(dtype=ATYPE).contiguous() + w[x+'_rx'] = (w[x+'_rx'] / 16).to(dtype=ATYPE).contiguous() + w[x+'_my'] = w[x+'_my'].to(dtype=ATYPE).contiguous() + w[x+'_ry'] = (w[x+'_ry'] / 16).to(dtype=ATYPE).contiguous() + else: + w[x] = w[x].to(dtype=ATYPE) + + # GPTQ + if self.w_quant != {}: + if (len(w[x].shape) == 2) and ('emb' not in x): + if WTYPE != torch.uint8: + w[x] = w[x].to(dtype=WTYPE) + else: + w[x+'_mx'] = w[x+'_mx'].to(dtype=ATYPE).contiguous() + w[x+'_rx'] = (w[x+'_rx'] / 16).to(dtype=ATYPE).contiguous() + w[x+'_my'] = w[x+'_my'].to(dtype=ATYPE).contiguous() + w[x+'_ry'] = (w[x+'_ry'] / 16).to(dtype=ATYPE).contiguous() + else: + w[x] = w[x].to(dtype=ATYPE) + + if convert_and_save_and_exit == None: + if 'emb.' in x: + w[x] = w[x].contiguous() + elif (dd.stream) and (x.endswith('key.weight') or x.endswith('value.weight') or x.endswith('receptance.weight') or x.endswith('output.weight')): + try: + w[x] = w[x].contiguous().pin_memory() # if you see "CUDA error: out of memory" here, that's out of CPU RAM, not VRAM. Get more RAM :) + except: + print('Note: You are running out of RAM. Get more CPU RAM. 
Now this will run much slower.') + elif DEVICE != 'cpu': + w[x] = w[x].to(device=DEVICE).contiguous() + + if (dd.stream) or (DEVICE != 'cpu'): + try: + w[x+'_mx'] = w[x+'_mx'].to(device=DEVICE).contiguous() + w[x+'_rx'] = w[x+'_rx'].to(device=DEVICE).contiguous() + w[x+'_my'] = w[x+'_my'].to(device=DEVICE).contiguous() + w[x+'_ry'] = w[x+'_ry'].to(device=DEVICE).contiguous() + except: + pass + + if 'ffn.value.weight' in x: + gc.collect() + if 'cuda' in args.strategy_string: + torch.cuda.empty_cache() + + shape = [i for i in w[x].shape if i != 1] + if len(shape) > 1: + shape = f" {str(shape[0]).rjust(5)} {str(shape[1]).rjust(5)}" + else: + shape = f" {str(shape[0]).rjust(5)} " + if layer_id == 0 or layer_id >= args.n_layer-1: + if print_need_newline: + prxxx('\n', end = '') + print_need_newline = False + dt = str(w[x].dtype).replace('torch.', '') + dt = dt.replace('float32', 'f32').replace('bfloat16', 'bf16').replace('float16', 'f16').replace('uint8', 'i8') + prxxx(x.ljust(32), dt.rjust(4), str(w[x].device).rjust(8), shape, ' (pinned)' if w[x].is_pinned() else '') + else: + print_need_newline = True + prxxx('.', end = '', flush = True) + + if convert_and_save_and_exit: + w['_strategy'] = args.strategy_string + w['_rescale_layer'] = self.RESCALE_LAYER + w['_version'] = '0.7' + if not convert_and_save_and_exit.endswith('.pth'): + convert_and_save_and_exit += '.pth' + prxxx(f'Saving to {convert_and_save_and_exit}...') + torch.save(w, convert_and_save_and_exit) + prxxx(f'Converted and saved. Now this will exit.') + exit(0) + + gc.collect() + if 'cuda' in args.strategy_string: + torch.cuda.empty_cache() + + if os.environ.get('RWKV_CUDA_ON') == '1': + @MyFunction + def mm8_seq(self, x, w, mx, rx, my, ry): + B, N, M = x.shape[0], w.shape[0], w.shape[1] + return cuda_mm8_seq(B, N, M, x, w, mx, rx, my, ry) + @MyFunction + def mm8_one(self, x, w, mx, rx, my, ry): + N, M = w.shape[0], w.shape[1] + return cuda_mm8_one(N, M, x, w, mx, rx, my, ry) + else: + @MyFunction + def mm8_seq(self, x, w, mx, rx, my, ry): + return x @ ((w.to(dtype=x.dtype) + 0.5) * ry * rx + my + mx) + + @MyFunction + def mm8_one(self, x, w, mx, rx, my, ry): + return x @ ((w.to(dtype=x.dtype) + 0.5) * ry * rx + my + mx) + + ######################################################################################################## + + # @MyFunction + def ffn_one(self, x, sx, ln_w, ln_b, k_mix, r_mix, kw, vw, rw, kmx, krx, kmy, kry, vmx, vrx, vmy, vry, rmx, rrx, rmy, rry): + xx = F.layer_norm(x, (x.shape[-1],), weight=ln_w, bias=ln_b) + kx = xx * k_mix + sx * (1 - k_mix) + rx = xx * r_mix + sx * (1 - r_mix) + + # r = torch.sigmoid(rx @ rw) + r = torch.sigmoid(self._trigger_gptq(rx, weight=rw[0], name=rw[1])) + # vx = torch.square(torch.relu(kx @ kw)) + vx = torch.square(torch.relu(self._trigger_gptq(kx, weight=kw[0], name=kw[1]))) + # out = r * (vx @ vw) + out = r * (self._trigger_gptq(vx, weight=vw[0], name=vw[1])) + return x + out, xx + + # @MyFunction + def ffn_one_i8(self, x, sx, ln_w, ln_b, k_mix, r_mix, kw, vw, rw, kmx, krx, kmy, kry, vmx, vrx, vmy, vry, rmx, rrx, rmy, rry): + xx = F.layer_norm(x, (x.shape[-1],), weight=ln_w, bias=ln_b) + kx = xx * k_mix + sx * (1 - k_mix) + rx = xx * r_mix + sx * (1 - r_mix) + + r = torch.sigmoid(self.mm8_one(rx, rw, rmx, rrx, rmy, rry)) + vx = torch.square(torch.relu(self.mm8_one(kx, kw, kmx, krx, kmy, kry))) + out = r * (self.mm8_one(vx, vw, vmx, vrx, vmy, vry)) + return x + out, xx + + ######################################################################################################## + + 
# @MyFunction + def ffn_seq(self, x, sx, ln_w, ln_b, k_mix, r_mix, kw, vw, rw, kmx, krx, kmy, kry, vmx, vrx, vmy, vry, rmx, rrx, rmy, rry): + xx = F.layer_norm(x, (x.shape[-1],), weight=ln_w, bias=ln_b) + sx = torch.cat((sx.unsqueeze(0), xx[:-1,:])) + kx = xx * k_mix + sx * (1 - k_mix) + rx = xx * r_mix + sx * (1 - r_mix) + + # r = torch.sigmoid(rx @ rw) + r = torch.sigmoid(self._trigger_gptq(rx, weight=rw[0], name=rw[1])) + # vx = torch.square(torch.relu(kx @ kw)) + vx = torch.square(torch.relu(self._trigger_gptq(kx, weight=kw[0], name=kw[1]))) + # out = r * (vx @ vw) + out = r * self._trigger_gptq(vx, weight=vw[0], name=vw[1]) + return x + out, xx[-1,:] + + # @MyFunction + def ffn_seq_i8(self, x, sx, ln_w, ln_b, k_mix, r_mix, kw, vw, rw, kmx, krx, kmy, kry, vmx, vrx, vmy, vry, rmx, rrx, rmy, rry): + xx = F.layer_norm(x, (x.shape[-1],), weight=ln_w, bias=ln_b) + sx = torch.cat((sx.unsqueeze(0), xx[:-1,:])) + kx = xx * k_mix + sx * (1 - k_mix) + rx = xx * r_mix + sx * (1 - r_mix) + + r = torch.sigmoid(self.mm8_seq(rx, rw, rmx, rrx, rmy, rry)) + vx = torch.square(torch.relu(self.mm8_seq(kx, kw, kmx, krx, kmy, kry))) + out = r * (self.mm8_seq(vx, vw, vmx, vrx, vmy, vry)) + return x + out, xx[-1,:] + + ######################################################################################################## + + # @MyFunction + def att_one(self, x, sx, aa, bb, pp, ln_w, ln_b, k_mix, v_mix, r_mix, t_decay, t_first, kw, vw, rw, ow, kmx, krx, kmy, kry, vmx, vrx, vmy, vry, rmx, rrx, rmy, rry, omx, orx, omy, ory): + xx = F.layer_norm(x, (x.shape[-1],), weight=ln_w, bias=ln_b) + kx = xx * k_mix + sx * (1 - k_mix) + vx = xx * v_mix + sx * (1 - v_mix) + rx = xx * r_mix + sx * (1 - r_mix) + + # r = torch.sigmoid(rx @ rw) + # k = (kx @ kw).float() + # v = (vx @ vw).float() + + r = torch.sigmoid(self._trigger_gptq(rx, weight=rw[0], name=rw[1])) + k = self._trigger_gptq(kx, weight=kw[0], name=kw[1]).float() + v = self._trigger_gptq(vx, weight=vw[0], name=vw[1]).float() + + ww = t_first + k + p = torch.maximum(pp, ww) + e1 = torch.exp(pp - p) + e2 = torch.exp(ww - p) + wkv = ((e1 * aa + e2 * v) / (e1 * bb + e2)).to(dtype=x.dtype) + ww = t_decay + pp + p = torch.maximum(ww, k) + e1 = torch.exp(ww - p) + e2 = torch.exp(k - p) + + # out = (r * wkv) @ ow + out = self._trigger_gptq(r * wkv, weight=ow[0], name=ow[1]) + return x + out, xx, e1 * aa + e2 * v, e1 * bb + e2, p + + # @MyFunction + def att_one_i8(self, x, sx, aa, bb, pp, ln_w, ln_b, k_mix, v_mix, r_mix, t_decay, t_first, kw, vw, rw, ow, kmx, krx, kmy, kry, vmx, vrx, vmy, vry, rmx, rrx, rmy, rry, omx, orx, omy, ory): + xx = F.layer_norm(x, (x.shape[-1],), weight=ln_w, bias=ln_b) + kx = xx * k_mix + sx * (1 - k_mix) + vx = xx * v_mix + sx * (1 - v_mix) + rx = xx * r_mix + sx * (1 - r_mix) + + r = torch.sigmoid(self.mm8_one(rx, rw, rmx, rrx, rmy, rry)) + k = (self.mm8_one(kx, kw, kmx, krx, kmy, kry)).float() + v = (self.mm8_one(vx, vw, vmx, vrx, vmy, vry)).float() + + ww = t_first + k + p = torch.maximum(pp, ww) + e1 = torch.exp(pp - p) + e2 = torch.exp(ww - p) + wkv = ((e1 * aa + e2 * v) / (e1 * bb + e2)).to(dtype=x.dtype) + ww = t_decay + pp + p = torch.maximum(ww, k) + e1 = torch.exp(ww - p) + e2 = torch.exp(k - p) + + out = self.mm8_one(r * wkv, ow, omx, orx, omy, ory) + return x + out, xx, e1 * aa + e2 * v, e1 * bb + e2, p + + ######################################################################################################## + + # @MyFunction + def att_seq(self, x, sx, aa, bb, pp, ln_w, ln_b, k_mix, v_mix, r_mix, t_decay, t_first, kw, vw, rw, ow, 
kmx, krx, kmy, kry, vmx, vrx, vmy, vry, rmx, rrx, rmy, rry, omx, orx, omy, ory): + xx = F.layer_norm(x, (x.shape[-1],), weight=ln_w, bias=ln_b) + sx = torch.cat((sx.unsqueeze(0), xx[:-1,:])) + kx = xx * k_mix + sx * (1 - k_mix) + vx = xx * v_mix + sx * (1 - v_mix) + rx = xx * r_mix + sx * (1 - r_mix) + + # r = torch.sigmoid(rx @ rw) + # k = (kx @ kw).float() + # v = (vx @ vw).float() + r = torch.sigmoid(self._trigger_gptq(rx, weight=rw[0], name=rw[1])) + k = (self._trigger_gptq(kx, weight=kw[0], name=kw[1])).float() + v = (self._trigger_gptq(vx, weight=vw[0], name=vw[1])).float() + + T = x.shape[0] + for t in range(T): + kk = k[t] + vv = v[t] + ww = t_first + kk + p = torch.maximum(pp, ww) + e1 = torch.exp(pp - p) + e2 = torch.exp(ww - p) + sx[t] = ((e1 * aa + e2 * vv) / (e1 * bb + e2)).to(dtype=x.dtype) + ww = t_decay + pp + p = torch.maximum(ww, kk) + e1 = torch.exp(ww - p) + e2 = torch.exp(kk - p) + aa = e1 * aa + e2 * vv + bb = e1 * bb + e2 + pp = p + # out = (r * sx) @ ow + out = self._trigger_gptq(r * sx, weight=ow[0], name=ow[1]) + return x + out, xx[-1,:], aa, bb, pp + + # @MyFunction + def att_seq_i8(self, x, sx, aa, bb, pp, ln_w, ln_b, k_mix, v_mix, r_mix, t_decay, t_first, kw, vw, rw, ow, kmx, krx, kmy, kry, vmx, vrx, vmy, vry, rmx, rrx, rmy, rry, omx, orx, omy, ory): + xx = F.layer_norm(x, (x.shape[-1],), weight=ln_w, bias=ln_b) + sx = torch.cat((sx.unsqueeze(0), xx[:-1,:])) + kx = xx * k_mix + sx * (1 - k_mix) + vx = xx * v_mix + sx * (1 - v_mix) + rx = xx * r_mix + sx * (1 - r_mix) + + r = torch.sigmoid(self.mm8_seq(rx, rw, rmx, rrx, rmy, rry)) + k = self.mm8_seq(kx, kw, kmx, krx, kmy, kry).float() + v = self.mm8_seq(vx, vw, vmx, vrx, vmy, vry).float() + + T = x.shape[0] + for t in range(T): + kk = k[t] + vv = v[t] + ww = t_first + kk + p = torch.maximum(pp, ww) + e1 = torch.exp(pp - p) + e2 = torch.exp(ww - p) + sx[t] = ((e1 * aa + e2 * vv) / (e1 * bb + e2)).to(dtype=x.dtype) + ww = t_decay + pp + p = torch.maximum(ww, kk) + e1 = torch.exp(ww - p) + e2 = torch.exp(kk - p) + aa = e1 * aa + e2 * vv + bb = e1 * bb + e2 + pp = p + out = self.mm8_seq(r * sx, ow, omx, orx, omy, ory) + return x + out, xx[-1,:], aa, bb, pp + + ######################################################################################################## + + if os.environ["RWKV_CUDA_ON"] == '1': + @MyFunction + def cuda_att_seq(self, x, sx, aa, bb, pp, ln_w, ln_b, k_mix, v_mix, r_mix, t_decay, t_first, kw, vw, rw, ow, kmx, krx, kmy, kry, vmx, vrx, vmy, vry, rmx, rrx, rmy, rry, omx, orx, omy, ory): + T, C = x.size() + xx = F.layer_norm(x, (C,), weight=ln_w, bias=ln_b) + sx = torch.cat((sx.unsqueeze(0), xx[:-1,:])) + kx = xx * k_mix + sx * (1 - k_mix) + vx = xx * v_mix + sx * (1 - v_mix) + rx = xx * r_mix + sx * (1 - r_mix) + + r = torch.sigmoid(rx @ rw) + k = kx @ kw + v = vx @ vw + y, aa, bb, pp = cuda_wkv(T, C, t_decay, t_first, k, v, aa, bb, pp) + + out = (r * y) @ ow + return x + out, xx[-1,:], aa, bb, pp + + @MyFunction + def cuda_att_seq_i8(self, x, sx, aa, bb, pp, ln_w, ln_b, k_mix, v_mix, r_mix, t_decay, t_first, kw, vw, rw, ow, kmx, krx, kmy, kry, vmx, vrx, vmy, vry, rmx, rrx, rmy, rry, omx, orx, omy, ory): + T, C = x.size() + xx = F.layer_norm(x, (C,), weight=ln_w, bias=ln_b) + sx = torch.cat((sx.unsqueeze(0), xx[:-1,:])) + kx = xx * k_mix + sx * (1 - k_mix) + vx = xx * v_mix + sx * (1 - v_mix) + rx = xx * r_mix + sx * (1 - r_mix) + + r = torch.sigmoid(self.mm8_seq(rx, rw, rmx, rrx, rmy, rry)) + k = self.mm8_seq(kx, kw, kmx, krx, kmy, kry) + v = self.mm8_seq(vx, vw, vmx, vrx, vmy, vry) + y, aa, bb, 
pp = cuda_wkv(T, C, t_decay, t_first, k, v, aa, bb, pp) + + out = self.mm8_seq(r * y, ow, omx, orx, omy, ory) + return x + out, xx[-1,:], aa, bb, pp + + ######################################################################################################## + + def _trigger_gptq(self, x, weight, name): + + # GPTQ + if name in self.w_quant.keys(): + w_quant = self.w_quant[name] + # Only work on CUDA because of the use of fp16 (not available on CPU) + if w_quant.bits in [2, 4, 8]: + w_quant.scales = w_quant.scales.to(x.device) + out_shape = x.shape[:-1] + (w_quant.outfeatures, ) + x = x.reshape(-1,x.shape[-1]) + + zeros = torch.bitwise_right_shift(torch.unsqueeze(w_quant.qzeros, 2).expand(-1, -1, 32 // w_quant.bits), w_quant.wf.unsqueeze(0)).to(x.device, torch.int16 if w_quant.bits == 8 else torch.int8) + torch.bitwise_and(zeros, (2 ** w_quant.bits) - 1, out=zeros) + + zeros = zeros + 1 + # if name == "head.weight": + # import pdb; pdb.set_trace() + zeros = zeros.reshape(w_quant.scales.shape) + + new_weight = torch.bitwise_right_shift(torch.unsqueeze(w_quant.qweight, 1).expand(-1, 32 // w_quant.bits, -1), w_quant.wf.unsqueeze(-1)).to(x.device, torch.int16 if w_quant.bits == 8 else torch.int8) + torch.bitwise_and(new_weight,(2 ** w_quant.bits) - 1, out=new_weight) + new_weight = new_weight.reshape(new_weight.shape[0] * new_weight.shape[1], new_weight.shape[2]) + new_weight = (w_quant.scales[w_quant.g_idx.long()] * (new_weight - zeros[w_quant.g_idx.long()])) + + out = torch.matmul(x.half(), new_weight) + out = out.reshape(out_shape) + return out + else: + return x @ weight + + def forward(self, tokens, state, full_output=False): + with torch.no_grad(): + w = self.w + args = self.args + + if state == None: + state = [None] * args.n_layer * 5 + for i in range(args.n_layer): # state: 0=att_xx 1=att_aa 2=att_bb 3=att_pp 4=ffn_xx + dd = self.strategy[i] + dev = dd.device + atype = dd.atype + state[i*5+0] = torch.zeros(args.n_embd, dtype=atype, requires_grad=False, device=dev).contiguous() + state[i*5+1] = torch.zeros(args.n_embd, dtype=torch.float, requires_grad=False, device=dev).contiguous() + state[i*5+2] = torch.zeros(args.n_embd, dtype=torch.float, requires_grad=False, device=dev).contiguous() + state[i*5+3] = torch.zeros(args.n_embd, dtype=torch.float, requires_grad=False, device=dev).contiguous() - 1e30 + state[i*5+4] = torch.zeros(args.n_embd, dtype=atype, requires_grad=False, device=dev).contiguous() + + seq_mode = len(tokens) > 1 + + x = w['emb.weight'][tokens if seq_mode else tokens[0]] + + for i in range(args.n_layer): + bbb = f'blocks.{i}.' + att = f'blocks.{i}.att.' + ffn = f'blocks.{i}.ffn.' 
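+                # Each weight below is paired with its parameter name so _trigger_gptq() can look it
+                # up in self.w_quant and dequantize the packed GPTQ weight on the fly, falling back
+                # to a plain matmul when the weight was not quantized.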
+ dd = self.strategy[i] + dev = dd.device + atype = dd.atype + wtype = dd.wtype + if seq_mode: + if 'cuda' in str(dev) and os.environ["RWKV_CUDA_ON"] == '1': + ATT = self.cuda_att_seq if wtype != torch.uint8 else self.cuda_att_seq_i8 + else: + ATT = self.att_seq if wtype != torch.uint8 else self.att_seq_i8 + FFN = self.ffn_seq if wtype != torch.uint8 else self.ffn_seq_i8 + else: + ATT = self.att_one if wtype != torch.uint8 else self.att_one_i8 + FFN = self.ffn_one if wtype != torch.uint8 else self.ffn_one_i8 + + x = x.to(dtype=atype, device=dev) + + kw = w[f'{att}key.weight'], f'{att}key.weight' + vw = w[f'{att}value.weight'], f'{att}value.weight' + rw = w[f'{att}receptance.weight'], f'{att}receptance.weight' + ow = w[f'{att}output.weight'], f'{att}output.weight' + + if dd.stream: + kw = kw.to(device=dev, non_blocking=True) + vw = vw.to(device=dev, non_blocking=True) + rw = rw.to(device=dev, non_blocking=True) + ow = ow.to(device=dev, non_blocking=True) + + kmx = w[f'{att}key.weight_mx'] if wtype == torch.uint8 else x + krx = w[f'{att}key.weight_rx'] if wtype == torch.uint8 else x + kmy = w[f'{att}key.weight_my'] if wtype == torch.uint8 else x + kry = w[f'{att}key.weight_ry'] if wtype == torch.uint8 else x + vmx = w[f'{att}value.weight_mx'] if wtype == torch.uint8 else x + vrx = w[f'{att}value.weight_rx'] if wtype == torch.uint8 else x + vmy = w[f'{att}value.weight_my'] if wtype == torch.uint8 else x + vry = w[f'{att}value.weight_ry'] if wtype == torch.uint8 else x + rmx = w[f'{att}receptance.weight_mx'] if wtype == torch.uint8 else x + rrx = w[f'{att}receptance.weight_rx'] if wtype == torch.uint8 else x + rmy = w[f'{att}receptance.weight_my'] if wtype == torch.uint8 else x + rry = w[f'{att}receptance.weight_ry'] if wtype == torch.uint8 else x + omx = w[f'{att}output.weight_mx'] if wtype == torch.uint8 else x + orx = w[f'{att}output.weight_rx'] if wtype == torch.uint8 else x + omy = w[f'{att}output.weight_my'] if wtype == torch.uint8 else x + ory = w[f'{att}output.weight_ry'] if wtype == torch.uint8 else x + x, state[i*5+0], state[i*5+1], state[i*5+2], state[i*5+3] = ATT( + x, state[i*5+0], state[i*5+1], state[i*5+2], state[i*5+3], + w[f'{bbb}ln1.weight'], w[f'{bbb}ln1.bias'], + w[f'{att}time_mix_k'], w[f'{att}time_mix_v'], w[f'{att}time_mix_r'], + w[f'{att}time_decay'], w[f'{att}time_first'], + kw, vw, rw, ow, + kmx, krx, kmy, kry, + vmx, vrx, vmy, vry, + rmx, rrx, rmy, rry, + omx, orx, omy, ory, + ) + if dd.stream: + del kw, vw, rw, ow + + kw = w[f'{ffn}key.weight'], f'{ffn}key.weight' + vw = w[f'{ffn}value.weight'], f'{ffn}value.weight' + rw = w[f'{ffn}receptance.weight'], f'{ffn}receptance.weight' + + if dd.stream: + kw = kw.to(device=dev, non_blocking=True) + vw = vw.to(device=dev, non_blocking=True) + rw = rw.to(device=dev, non_blocking=True) + kmx = w[f'{ffn}key.weight_mx'] if wtype == torch.uint8 else x + krx = w[f'{ffn}key.weight_rx'] if wtype == torch.uint8 else x + kmy = w[f'{ffn}key.weight_my'] if wtype == torch.uint8 else x + kry = w[f'{ffn}key.weight_ry'] if wtype == torch.uint8 else x + vmx = w[f'{ffn}value.weight_mx'] if wtype == torch.uint8 else x + vrx = w[f'{ffn}value.weight_rx'] if wtype == torch.uint8 else x + vmy = w[f'{ffn}value.weight_my'] if wtype == torch.uint8 else x + vry = w[f'{ffn}value.weight_ry'] if wtype == torch.uint8 else x + rmx = w[f'{ffn}receptance.weight_mx'] if wtype == torch.uint8 else x + rrx = w[f'{ffn}receptance.weight_rx'] if wtype == torch.uint8 else x + rmy = w[f'{ffn}receptance.weight_my'] if wtype == torch.uint8 else x + rry = 
w[f'{ffn}receptance.weight_ry'] if wtype == torch.uint8 else x + x, state[i*5+4] = FFN( + x, state[i*5+4], + w[f'{bbb}ln2.weight'], w[f'{bbb}ln2.bias'], + w[f'{ffn}time_mix_k'], w[f'{ffn}time_mix_r'], + kw, vw, rw, + kmx, krx, kmy, kry, + vmx, vrx, vmy, vry, + rmx, rrx, rmy, rry, + ) + if dd.stream: + del kw, vw, rw + + if self.RESCALE_LAYER > 0: + if (i+1) % self.RESCALE_LAYER == 0: + x = x / 2 + + dd = self.strategy[args.n_layer] + x = x[-1,:] if (seq_mode and (not full_output)) else x + x = x.to(dtype=dd.atype, device=dd.device) + + x = F.layer_norm(x, (args.n_embd,), weight=w['ln_out.weight'], bias=w['ln_out.bias']) + if w['head.weight'].dtype != torch.uint8: + x = x @ w['head.weight'] + # FIXME: WTF check load model and see why + # (Pdb++) w_quant.scales.shape + # torch.Size([1, 50277]) + # (Pdb++) zeros.shape + # torch.Size([1, 12568, 4]) + # (Pdb++) 12568 * 4 + # 50272 + # x = self._trigger_gptq(x, w["head.weight"], "head.weight") + else: + if seq_mode and full_output: + x = self.mm8_seq(x, w['head.weight'], w['head.weight_mx'], w['head.weight_rx'], w['head.weight_my'], w['head.weight_ry']) + else: + x = self.mm8_one(x, w['head.weight'], w['head.weight_mx'], w['head.weight_rx'], w['head.weight_my'], w['head.weight_ry']) + + return x.float(), state \ No newline at end of file diff --git a/quantize/tmp_rwkv.py b/quantize/tmp_rwkv.py index 83d0dfd7..38b0f070 100644 --- a/quantize/tmp_rwkv.py +++ b/quantize/tmp_rwkv.py @@ -1,5 +1,4 @@ - -from rwkv.model import RWKV +from myRWKV import RWKV from gptq.datautils import * from gptq.quant import Quantizer, quantize @@ -9,6 +8,7 @@ import time import math import re +from gptq.gptq import QuantLinear_custom WBITS = 8 GROUPSIZE = -1 @@ -155,23 +155,26 @@ def fasterquant(self, blocksize=128, percdamp=.01, groupsize=-1, actorder=False) ### end GPTQ ### begin GPTQ_RWKV - def __init__(self, model, strategy): - super().__init__(model, strategy) + def __init__(self, checkpoint_path, strategy): + super().__init__(checkpoint_path, strategy) for i in range(self.args.n_layer): assert self.strategy[i].device == "cpu" - #TODO: Change to match my implem def _fill_subset(self, layer_id): # Keep only layer within block layer_id - is_weight = re.compile(f'^blocks\.{layer_id}\..*\.weight$') + + #TODO: Uncomment me when quantizing 1 layer works + # is_weight = re.compile(f'^blocks\.{layer_id}\..*\.weight$') + is_weight = re.compile("blocks.0.att.key.weight") for name in self.w.keys(): if is_weight.match(name): if len(self.w[name].shape) == 1: continue #TODO: Skip 1D tensors for now self.subset[name] = self.w[name] - is_last_layer = (layer_id == self.args.n_layer - 1) - if is_last_layer: - self.subset["head.weight"] = self.w["head.weight"] + # TODO: Uncomment me when quantizing 1 layer works + # is_last_layer = (layer_id == self.args.n_layer - 1) + # if is_last_layer: + # self.subset["head.weight"] = self.w["head.weight"] return self.subset @@ -196,7 +199,7 @@ def fasterquant(self, layer_id, quantizers): print(layer_id, name) print('Quantizing ...') scale,zero,g_idx = self.gptq[name].fasterquant(percdamp=0.01, groupsize=GROUPSIZE, actorder=False) - quantizers[f"linear{layer_id}_w"] = (self.gptq[name].quantizer.cpu(), scale.cpu(), zero.cpu(), g_idx.cpu()) + quantizers[name] = (self.gptq[name].quantizer.cpu(), scale.cpu(), zero.cpu(), g_idx.cpu()) ### end GPTQ_RWKV @@ -208,12 +211,16 @@ def att_one(self, x, sx, aa, bb, pp, ln_w, ln_b, k_mix, v_mix, r_mix, t_decay, t vx = xx * v_mix + sx * (1 - v_mix) rx = xx * r_mix + sx * (1 - r_mix) - r = torch.sigmoid(rx @ 
rw.weight) - rw.add_batch(rx) + # r = torch.sigmoid(rx @ rw.weight) + r = torch.sigmoid(rx @ rw) + # rw.add_batch(rx) + k = (kx @ kw.weight).float() kw.add_batch(kx) - v = (vx @ vw.weight).float() - vw.add_batch(vx) + + # v = (vx @ vw.weight).float() + v = (vx @ vw).float() + # vw.add_batch(vx) ww = t_first + k p = torch.maximum(pp, ww) @@ -225,8 +232,9 @@ def att_one(self, x, sx, aa, bb, pp, ln_w, ln_b, k_mix, v_mix, r_mix, t_decay, t e1 = torch.exp(ww - p) e2 = torch.exp(k - p) - out = (r * wkv) @ ow.weight - ow.add_batch((r * wkv)) + # out = (r * wkv) @ ow.weight + out = (r * wkv) @ ow + # ow.add_batch(r * wkv) return x + out, xx, e1 * aa + e2 * v, e1 * bb + e2, p def att_seq(self, x, sx, aa, bb, pp, ln_w, ln_b, k_mix, v_mix, r_mix, t_decay, t_first, kw, vw, rw, ow, kmx, krx, kmy, kry, vmx, vrx, vmy, vry, rmx, rrx, rmy, rry, omx, orx, omy, ory): @@ -236,12 +244,14 @@ def att_seq(self, x, sx, aa, bb, pp, ln_w, ln_b, k_mix, v_mix, r_mix, t_decay, t vx = xx * v_mix + sx * (1 - v_mix) rx = xx * r_mix + sx * (1 - r_mix) - r = torch.sigmoid(rx @ rw.weight) - rw.add_batch(rx) + # r = torch.sigmoid(rx @ rw.weight) + r = torch.sigmoid(rx @ rw) + # rw.add_batch(rx) k = (kx @ kw.weight).float() kw.add_batch(kx) - v = (vx @ vw.weight).float() - vw.add_batch(vx) + # v = (vx @ vw.weight).float() + v = (vx @ vw).float() + # vw.add_batch(vx) T = x.shape[0] for t in range(T): @@ -259,8 +269,9 @@ def att_seq(self, x, sx, aa, bb, pp, ln_w, ln_b, k_mix, v_mix, r_mix, t_decay, t aa = e1 * aa + e2 * vv bb = e1 * bb + e2 pp = p - out = (r * sx) @ ow.weight - ow.add_batch((r * sx)) + # out = (r * sx) @ ow.weight + out = (r * sx) @ ow + # ow.add_batch(r * sx) return x + out, xx[-1,:], aa, bb, pp def ffn_one(self, x, sx, ln_w, ln_b, k_mix, r_mix, kw, vw, rw, kmx, krx, kmy, kry, vmx, vrx, vmy, vry, rmx, rrx, rmy, rry): @@ -268,12 +279,15 @@ def ffn_one(self, x, sx, ln_w, ln_b, k_mix, r_mix, kw, vw, rw, kmx, krx, kmy, kr kx = xx * k_mix + sx * (1 - k_mix) rx = xx * r_mix + sx * (1 - r_mix) - r = torch.sigmoid(rx @ rw.weight) - rw.add_batch(rx) - vx = torch.square(torch.relu(kx @ kw.weight)) - kw.add_batch(kx) - out = r * (vx @ vw.weight) - vw.add_batch(vx) + # r = torch.sigmoid(rx @ rw.weight) + r = torch.sigmoid(rx @ rw) + # rw.add_batch(rx) + # vx = torch.square(torch.relu(kx @ kw.weight)) + vx = torch.square(torch.relu(kx @ kw)) + # kw.add_batch(kx) + # out = r * (vx @ vw.weight) + out = r * (vx @ vw) + # vw.add_batch(vx) return x + out, xx def ffn_seq(self, x, sx, ln_w, ln_b, k_mix, r_mix, kw, vw, rw, kmx, krx, kmy, kry, vmx, vrx, vmy, vry, rmx, rrx, rmy, rry): @@ -282,12 +296,15 @@ def ffn_seq(self, x, sx, ln_w, ln_b, k_mix, r_mix, kw, vw, rw, kmx, krx, kmy, kr kx = xx * k_mix + sx * (1 - k_mix) rx = xx * r_mix + sx * (1 - r_mix) - r = torch.sigmoid(rx @ rw.weight) - rw.add_batch(rx) - vx = torch.square(torch.relu(kx @ kw.weight)) - kw.add_batch(kx) - out = r * (vx @ vw.weight) - vw.add_batch(vx) + # r = torch.sigmoid(rx @ rw.weight) + r = torch.sigmoid(rx @ rw) + # rw.add_batch(rx) + # vx = torch.square(torch.relu(kx @ kw.weight)) + vx = torch.square(torch.relu(kx @ kw)) + # kw.add_batch(kx) + # out = r * (vx @ vw.weight) + out = r * (vx @ vw) + # vw.add_batch(vx) return x + out, xx[-1,:] def forward_block(self, x, state, i, seq_mode, full_output=False): @@ -326,9 +343,12 @@ def forward_block(self, x, state, i, seq_mode, full_output=False): x = x.to(dtype=atype, device=dev) kw = self.gptq[f'{att}key.weight'] - vw = self.gptq[f'{att}value.weight'] - rw = self.gptq[f'{att}receptance.weight'] - ow = 
self.gptq[f'{att}output.weight'] + # vw = self.gptq[f'{att}value.weight'] + vw = self.w[f'{att}value.weight'] + # rw = self.gptq[f'{att}receptance.weight'] + rw = self.w[f'{att}receptance.weight'] + # ow = self.gptq[f'{att}output.weight'] + ow = self.w[f'{att}output.weight'] if dd.stream: kw = kw.to(device=dev, non_blocking=True) @@ -368,9 +388,12 @@ def forward_block(self, x, state, i, seq_mode, full_output=False): if dd.stream: del kw, vw, rw, ow - kw = self.gptq[f'{ffn}key.weight'] - vw = self.gptq[f'{ffn}value.weight'] - rw = self.gptq[f'{ffn}receptance.weight'] + # kw = self.gptq[f'{ffn}key.weight'] + kw = self.w[f'{ffn}key.weight'] + # vw = self.gptq[f'{ffn}value.weight'] + vw = self.w[f'{ffn}value.weight'] + # rw = self.gptq[f'{ffn}receptance.weight'] + rw = self.w[f'{ffn}receptance.weight'] if dd.stream: kw = kw.to(device=dev, non_blocking=True) @@ -416,9 +439,10 @@ def forward_block(self, x, state, i, seq_mode, full_output=False): x = F.layer_norm(x, (args.n_embd,), weight=self.w['ln_out.weight'], bias=self.w['ln_out.bias']) if self.w['head.weight'].dtype != torch.uint8: - x = x @ self.gptq['head.weight'].weight - self.gptq['head.weight'].add_batch(x) - self.gptq['head.weight'].deactivate_add_batch_call = True + x = x @ self.w['head.weight'] + #TODO: uncommenbt me when quantizing 1 layer work + # x = x @ self.gptq['head.weight'].weight + # self.gptq['head.weight'].add_batch(x) return x.float() @@ -434,7 +458,8 @@ def quantize_gptq_custom(model, tokens): outs = torch.zeros_like(inps) quantizers = {} - for layer_id in range(model.args.n_layer): + # for layer_id in range(model.args.n_layer): + for layer_id in range(1): print(f"Quantizing layer {layer_id} ...") @@ -450,8 +475,6 @@ def quantize_gptq_custom(model, tokens): for gptq_layer in model.gptq.values(): gptq_layer.deactivate_add_batch_call = True - tmp = model.w["blocks.0.att.key.weight"] - model.fasterquant(layer_id, quantizers) for i in range(nsamples): @@ -472,11 +495,50 @@ def quantize_gptq_custom(model, tokens): return quantizers +def model_pack_custom(model, quantizers, wbits, groupsize): + + weights = OrderedDict() + + # is_weight = re.compile('^blocks\.\d+(\.[a-z]+[0-9]?)*\.weight$') + # for name in model.w.keys(): + # if is_weight.match(name): + # if len(model.w[name].shape) == 1: continue #TODO: Skip 1D tensors for now + # weights[name] = model.w[name] + + for name in quantizers.keys(): + if len(model.w[name].shape) == 1: continue + weights[name] = model.w[name] + + #TODO: uncommenbt me when done + # weights["head.weight"] = model.w["head.weight"] + + assert set(quantizers) - set(weights) == set(), "Quantizers and weights don't match" + assert set(weights) - set(quantizers) == set(), "Quantizers and weights don't match" + + # Replace layer by QuantLinear + model.w_quant = {} + for key, value in model.w.items(): + if key in quantizers.keys(): + #FIXME: So far, we don't quantize ln0 et ln1 (which have bias) because 1d tensors + bias = None + model.w_quant[key] = QuantLinear_custom(wbits, groupsize, value.shape[0], value.shape[1], bias) + + # Fill QuantLinear + print('Packing ...') + for key in model.w_quant.keys(): + _, scale,zero,g_idx = quantizers[key] + bias = None + model.w_quant[key].pack(weights[key], bias, scale, zero, g_idx) + print('Done.') + return model + + if __name__ == "__main__": + model_ref = GPTQ_RWKV("./RWKV-4-Pile-169M-20220807-8023.pth", strategy='cpu fp32') model = GPTQ_RWKV("./RWKV-4-Pile-169M-20220807-8023.pth", strategy='cpu fp32') - NSAMPLES=2 + NSAMPLES=1 HIDDEN_SIZE=model.args.n_embd 
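quantize_gptq_custom above follows the usual GPTQ layer-by-layer recipe: run the calibration batch through one block so add_batch accumulates Hessian statistics, quantize that block with fasterquant, then re-run the batch so the next block is calibrated on quantized activations. A schematic restatement (method names follow GPTQ_RWKV, but treat the loop as a sketch rather than the exact driver — it skips the last-layer special case):

```
import torch

def gptq_calibrate(model, inps, quantizers, seq_mode=True):
    outs = torch.zeros_like(inps)
    for layer_id in range(model.args.n_layer):
        model.alloc_gptq(layer_id)                 # wrap this block's weights in GPTQ objects
        for j in range(len(inps)):                 # pass 1: Hessians accumulate via add_batch()
            outs[j] = model.forward_block(inps[j], state=None, i=layer_id, seq_mode=seq_mode)
        for gptq_layer in model.gptq.values():     # freeze the statistics
            gptq_layer.deactivate_add_batch_call = True
        model.fasterquant(layer_id, quantizers)    # solve for the quantized weights
        for j in range(len(inps)):                 # pass 2: feed quantized outputs forward
            outs[j] = model.forward_block(inps[j], state=None, i=layer_id, seq_mode=seq_mode)
        inps, outs = outs, inps                    # next block consumes this block's outputs
    return quantizers
```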
SEQLEN=1024 # cf https://huggingface.co/BlinkDL/rwkv-4-pile-169m @@ -491,10 +553,26 @@ def quantize_gptq_custom(model, tokens): tokens = torch.cat([inp for inp, _ in train_tokens], dim=0) tokens = torch.zeros((NSAMPLES, SEQLEN), dtype=torch.int64) print("tokens.shape", tokens.shape) - - import pdb; pdb.set_trace() - # quantizers = quantize_gptq_custom(model, tokens) - - # model_pack_custom(model, quantizers, WBITS, GROUPSIZE) - # torch.save(model.state_dict(), "model_quantized_custom.pt") - # print("Done Custom GPTQ") \ No newline at end of file + + quantizers = quantize_gptq_custom(model, tokens) + model = model_pack_custom(model, quantizers, WBITS, GROUPSIZE) + torch.save([model.w_quant, model.w], "1sample_quantized.pth") + + # Make sure only 1 layer was quantized + assert len(model.w_quant.keys()) == 1 and "blocks.0.att.key.weight" in model.w_quant.keys() + + for (ref_key, ref_value), (key, value) in zip(model_ref.w.items(), model.w.items()): + if key != "blocks.0.att.key.weight": + assert torch.allclose(ref_value, value, atol=1e-5) + else: + assert not torch.allclose(ref_value, value, atol=1e-5) + + print("Done Custom GPTQ") + + # I have noticed QuantLinear.forward() can be divded in 2 parts: + # 1. Quantize the weights (using info from model.w_quant thanks to QuantLinear.pack()) + # 2. Perform x @ weights + # We can load checkpoint RWKV of base class with model_w (which are quantized but doesnt have the scale, zero info) + # Then, if isinstancce(model, w_quant) exist, we load this dict as well + # Each time the weights are called, we do a trigger() by checking if isinstancce(moded.w_quant is QuantLinear) + # This way, we can reuse RWKV base class with minimal change \ No newline at end of file diff --git a/quantize/tmp_rwkv.py.lprof b/quantize/tmp_rwkv.py.lprof new file mode 100644 index 0000000000000000000000000000000000000000..e245fb51861a6d11d033e6007b556883f28e8491 GIT binary patch literal 426 zcmY+<-7AAp90zdEY?jv2Li4ig*~D{fLR`3!)Q(-C=0aMJoNb3a&t}hg@)AX0+sn;dxe4 z)Bi$gH%_8FBQ#;kLmAzXTQXvM1=qQeHGhL4q8ox38I$A+IwKiXFX&_}r-|uVnj$Nf zW(dB)NVX_e$f71ROYpp@7S|$)Knqn8 zA9qUwU@N#Ayy^#oD8I&lHlX|nw4+!(h`7NHs_c70H`vufj9 Date: Tue, 23 May 2023 09:44:13 +0000 Subject: [PATCH 19/20] fix(ppl): measure ppl using sliding window --- quantize/measure_perplexity.py | 178 +++++++++++++-------------------- 1 file changed, 67 insertions(+), 111 deletions(-) diff --git a/quantize/measure_perplexity.py b/quantize/measure_perplexity.py index 8cdbe4a3..248d5069 100644 --- a/quantize/measure_perplexity.py +++ b/quantize/measure_perplexity.py @@ -1,114 +1,70 @@ -# Measures perplexity and per-token latency of an RWKV model on a given text file. -# Perplexity is defined here as exp() of average cross-entropy loss. 
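The note above about reusing the RWKV base class boils down to a two-element checkpoint: tmp_rwkv.py saves [model.w_quant, model.w], and the patched myRWKV.__init__ branches on the type of the loaded object. A small sketch of that convention (file name taken from the save call; the layout is an assumption about the final design):

```
import torch

obj = torch.load("1sample_quantized.pth", map_location="cpu")
if isinstance(obj, list):            # GPTQ checkpoint: [w_quant, w]
    w_quant, w = obj[0], obj[1]      # packed QuantLinear state + (partly replaced) fp weights
else:                                # plain RWKV checkpoint: a single state dict
    w_quant, w = {}, obj
```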
-# Usage: python measure_perplexity.py RWKV-4-Pile-169M-20220807-8023.pth wikitext2 2048 - -import os -import time -import pathlib -import argparse -import tokenizers import torch -from typing import List +from tqdm import tqdm +from transformers import AutoModelForCausalLM, AutoTokenizer +from datasets import load_dataset +# from myRWKV import RWKV from rwkv.model import RWKV -os.environ['RWKV_JIT_ON'] = '1' -os.environ["RWKV_CUDA_ON"] = '0' - -def parse_args(): - parser = argparse.ArgumentParser(description='Measure perplexity and per-token latency of an RWKV model on a given text file') - parser.add_argument('model_path', help='Path to model checkpoint file') - parser.add_argument('dataset_path', type=str, choices=['wikitext2', 'ptb', 'c4'], help='Where to extract calibration data from.') - parser.add_argument('nsamples', help='How many samples', type=int, default=4096) - return parser.parse_args() - -args = parse_args() - -def get_wikitext2(nsamples): - from datasets import load_dataset - testdata = load_dataset('wikitext', 'wikitext-2-raw-v1', split='test') - - print('Loading 20B tokenizer (RWKV)') - tokenizer_path: pathlib.Path = pathlib.Path(os.path.abspath(__file__)).parent / '20B_tokenizer.json' - tokenizer: tokenizers.Tokenizer = tokenizers.Tokenizer.from_file(str(tokenizer_path)) - - print('Loading text') - test_text: str = "\n\n".join(testdata['text']) - test_tokens = torch.tensor(tokenizer.encode(test_text).ids, dtype=torch.long) - print(f'{len(test_tokens)} test tokens in the text') - - import random - random.seed(42) - # Randomly select a sample of nsamples tokens - i = random.randint(0, len(test_tokens) - nsamples) - return tokenizer, test_tokens[i:i+nsamples] - -def get_loaders(dataset_path, nsamples): - if 'wikitext2' in dataset_path: - return get_wikitext2(nsamples) - else: - # https://github.com/IST-DASLab/gptq/blob/main/datautils.py - raise NotImplementedError("Only wikitext2 is supported for now") - -tokenizer, test_tokens = get_loaders(args.dataset_path, args.nsamples) - -def format_loss(loss: torch.Tensor) -> str: - return str(['%.3f' % (loss[i].item(),) for i in range(len(loss))]).replace('\'', '')[1:-1] - -def format_loss_with_perplexity(loss: torch.Tensor) -> str: - return f'loss [{format_loss(loss)}], perplexity {"%.3f" % (torch.exp(loss[0]).item(),)}' - -# --- -device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') -# device = torch.device('cpu') - -#TODO: Why is PERPLEXITY SO DAMN HIGH ? 
-model = RWKV(model=args.model_path, strategy='cuda fp16') - +import torch.nn as nn + +device = "cpu" + +# Model +model = RWKV("./RWKV-4-Pile-169M-20220807-8023.pth", strategy='cpu fp32') +# Dataset +tokenizer = AutoTokenizer.from_pretrained("RWKV/rwkv-4-169m-pile") +# sentence = "My name is Bob" +# encodings = tokenizer("\n\n".join(sentence), return_tensors='pt') +# ctx_len = 5 +# stride = ctx_len // 2 +# seq_len = encodings.input_ids.size(1) + +test = load_dataset("wikitext", "wikitext-2-raw-v1", split="test") +tokenizer = AutoTokenizer.from_pretrained("RWKV/rwkv-4-169m-pile") +encodings = tokenizer("\n\n".join(test["text"]), return_tensors="pt") +ctx_len = 1024 +stride = ctx_len // 2 +seq_len = encodings.input_ids.size(1) + +nlls = [] +prev_end_loc = 0 logits, state = None, None -loss_sum: torch.Tensor = torch.tensor([0.0], device=device) -loss_count: int = 0 -token_count = len(test_tokens) -run_count = token_count - 1 -# Ignore 20% of the tokens to let the model warmup -ignore_first_n_tokens = int(token_count * 0.2) -start: float = time.time() - -for i in range(run_count): - token: int = test_tokens[i] - target: int = test_tokens[i + 1] - - logits, state = model.forward([token], None if i == 0 else state) - - if ignore_first_n_tokens == 0 or i + 1 >= ignore_first_n_tokens: - losses = torch.tensor([ - torch.nn.functional.cross_entropy(logits, torch.tensor(target, dtype=torch.long, device=device), reduction='none').item() - ] - , device=device) - - loss_sum += losses - loss_count += 1 - - if i % 100 == 0: - avg_loss_so_far = loss_sum / loss_count - - duration: float = time.time() - start - duration_per_token: float = duration / (i + 1) - runs_remaining: int = run_count - i - 1 - duration_remaining: int = int(runs_remaining * duration_per_token) - - print(f'Token #{i}/{token_count}, ' - f'{int(100.0 * i / token_count)}%, ' - f'ETA {duration_remaining // 60} m {duration_remaining % 60} s', end='') - - if loss_count > 0: - print(f', averages so far: {format_loss_with_perplexity(avg_loss_so_far)}') - else: - print() - -print() -print(f'Average latency: {int((time.time() - start) * 1000 / run_count)} ms per token') - -print() -print(f'Model: {os.path.basename(args.model_path)}\n' - f'data: {os.path.basename(args.dataset_path)} with {token_count} tokens\n' - f'Ignored first {ignore_first_n_tokens} tokens\n' - f'averages: {format_loss_with_perplexity(loss_sum / loss_count)}') +loss_fct = nn.CrossEntropyLoss() + +# for begin_loc in tqdm(range(0, seq_len, stride)): +for begin_loc in tqdm(range(0, stride * 3, stride)): + end_loc = min(begin_loc + ctx_len, seq_len) + trg_len = end_loc - prev_end_loc # may be different from stride on last loop + input_ids = encodings.input_ids[:, begin_loc:end_loc].to(device) + target_ids = input_ids.clone() + target_ids[:, :-trg_len] = -100 + + full_logits = torch.zeros((input_ids.size(1), model.w["emb.weight"].shape[0])) + + with torch.no_grad(): + for i in range(input_ids.size(1)): + logits, state = model.forward([input_ids[0, i]], state) + full_logits[i, :] = logits + + # loss is calculated using CrossEntropyLoss which averages over valid labels + # N.B. the model only calculates loss over trg_len - 1 labels, because it internally shifts the labels + # to the left by 1. 
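The shift-by-one and ignore_index bookkeeping described in the comment above can be checked in isolation; a toy example (assumed tiny vocabulary and sequence, unrelated to the real tokenizer):

```
import torch
import torch.nn as nn

vocab, T = 5, 4
logits = torch.randn(T, vocab)                 # one logit row per position
labels = torch.tensor([2, 4, 1, 3])
labels[:2] = -100                              # overlap with the previous window: no loss here

loss_fct = nn.CrossEntropyLoss()               # ignore_index defaults to -100
shift_logits = logits[:-1, :]                  # position t predicts token t + 1
shift_labels = labels[1:]
nll = loss_fct(shift_logits, shift_labels)     # averaged over the non-masked labels only
ppl = torch.exp(nll)                           # perplexity contribution of this window
```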
+ labels = target_ids + labels = labels.to(full_logits.device) + # Shift so that tokens < n predict n + shift_logits = full_logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + # Flatten the tokens + neg_log_likelihood = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)) + + nlls.append(neg_log_likelihood) + + prev_end_loc = end_loc + if end_loc == seq_len: + break + +print(f"nlls: {torch.stack(nlls)}") +mean_nll = torch.stack(nlls).mean() +if mean_nll.is_cuda: + mean_nll = mean_nll.cpu().float() +ppl = torch.exp(mean_nll) +print(f"Perplexity: {ppl}") From 3399ef008258b72dedcf496124a6271dc59cdb54 Mon Sep 17 00:00:00 2001 From: Ferdinand Mom Date: Fri, 4 Aug 2023 14:09:29 +0000 Subject: [PATCH 20/20] update --- quantize/gptq/datautils.py | 146 ++-------------------- quantize/gptq/gptq.py | 2 +- quantize/gptq/quant.py | 33 ++++- quantize/measure_perplexity.py | 5 +- quantize/myRWKV.py | 52 +++++--- quantize/opt.py | 216 --------------------------------- quantize/tmp_rwkv.py | 148 +++++++++++----------- quantize/tmp_rwkv.py.lprof | Bin 426 -> 0 bytes 8 files changed, 148 insertions(+), 454 deletions(-) delete mode 100644 quantize/opt.py delete mode 100644 quantize/tmp_rwkv.py.lprof diff --git a/quantize/gptq/datautils.py b/quantize/gptq/datautils.py index 166e6547..b8f12ce7 100644 --- a/quantize/gptq/datautils.py +++ b/quantize/gptq/datautils.py @@ -25,14 +25,17 @@ def get_wikitext2(nsamples, seed, seqlen, model): trainenc = torch.unsqueeze(torch.tensor(tokenizer.encode("\n\n".join(traindata['text'])).ids, dtype=torch.long), 0) testenc = torch.unsqueeze(torch.tensor(tokenizer.encode("\n\n".join(testdata['text'])).ids, dtype=torch.long), 0) else: - traindata = load_dataset('wikitext', 'wikitext-2-raw-v1', split='train') - testdata = load_dataset('wikitext', 'wikitext-2-raw-v1', split='test') + # traindata = load_dataset('wikitext', 'wikitext-2-raw-v1', split='train') + # testdata = load_dataset('wikitext', 'wikitext-2-raw-v1', split='test') print('Loading tokenizer') from transformers import AutoTokenizer tokenizer = AutoTokenizer.from_pretrained(model, use_fast=False) - trainenc = tokenizer("\n\n".join(traindata['text']), return_tensors='pt') - testenc = tokenizer("\n\n".join(testdata['text']), return_tensors='pt') + # trainenc = tokenizer("\n\n".join(traindata['text']), return_tensors='pt') + # testenc = tokenizer("\n\n".join(testdata['text']), return_tensors='pt') + sentence = "My name is Ferdinand and I live in France" + trainenc = tokenizer("\n\n".join(sentence), return_tensors='pt') + testenc = tokenizer("\n\n".join(sentence), return_tensors='pt') random.seed(seed) trainloader = [] @@ -48,141 +51,6 @@ def get_wikitext2(nsamples, seed, seqlen, model): trainloader.append((inp, tar)) return trainloader, testenc -def get_ptb(nsamples, seed, seqlen, model): - raise NotImplementedError('PTB not implemented yet') - # from datasets import load_dataset - # traindata = load_dataset('ptb_text_only', 'penn_treebank', split='train') - # valdata = load_dataset('ptb_text_only', 'penn_treebank', split='validation') - - # from transformers import AutoTokenizer - # tokenizer = AutoTokenizer.from_pretrained(model, use_fast=False) - # trainenc = tokenizer("\n\n".join(traindata['sentence']), return_tensors='pt') - # testenc = tokenizer("\n\n".join(valdata['sentence']), return_tensors='pt') - - # import random - # random.seed(seed) - # trainloader = [] - # for _ in range(nsamples): - # i = random.randint(0, trainenc.input_ids.shape[1] - seqlen - 1) - # j 
= i + seqlen - # inp = trainenc.input_ids[:, i:j] - # tar = inp.clone() - # tar[:, :-1] = -100 - # trainloader.append((inp, tar)) - # return trainloader, testenc - -def get_c4(nsamples, seed, seqlen, model): - raise NotImplementedError('C4 not implemented yet') - # from datasets import load_dataset - # traindata = load_dataset( - # 'allenai/c4', 'allenai--c4', data_files={'train': 'en/c4-train.00000-of-01024.json.gz'}, split='train' - # ) - # valdata = load_dataset( - # 'allenai/c4', 'allenai--c4', data_files={'validation': 'en/c4-validation.00000-of-00008.json.gz'}, split='validation' - # ) - - # from transformers import AutoTokenizer - # tokenizer = AutoTokenizer.from_pretrained(model, use_fast=False) - - # import random - # random.seed(seed) - # trainloader = [] - # for _ in range(nsamples): - # while True: - # i = random.randint(0, len(traindata) - 1) - # trainenc = tokenizer(traindata[i]['text'], return_tensors='pt') - # if trainenc.input_ids.shape[1] >= seqlen: - # break - # i = random.randint(0, trainenc.input_ids.shape[1] - seqlen - 1) - # j = i + seqlen - # inp = trainenc.input_ids[:, i:j] - # tar = inp.clone() - # tar[:, :-1] = -100 - # trainloader.append((inp, tar)) - - # import random - # random.seed(0) - # valenc = [] - # for _ in range(256): - # while True: - # i = random.randint(0, len(valdata) - 1) - # tmp = tokenizer(valdata[i]['text'], return_tensors='pt') - # if tmp.input_ids.shape[1] >= seqlen: - # break - # i = random.randint(0, tmp.input_ids.shape[1] - seqlen - 1) - # j = i + seqlen - # valenc.append(tmp.input_ids[:, i:j]) - # valenc = torch.hstack(valenc) - # class TokenizerWrapper: - # def __init__(self, input_ids): - # self.input_ids = input_ids - # valenc = TokenizerWrapper(valenc) - - # return trainloader, valenc - -def get_ptb_new(nsamples, seed, seqlen, model): - raise NotImplementedError('PTB not implemented yet') - # from datasets import load_dataset - # traindata = load_dataset('ptb_text_only', 'penn_treebank', split='train') - # testdata = load_dataset('ptb_text_only', 'penn_treebank', split='test') - - # from transformers import AutoTokenizer - # tokenizer = AutoTokenizer.from_pretrained(model, use_fast=False) - # trainenc = tokenizer(" ".join(traindata['sentence']), return_tensors='pt') - # testenc = tokenizer(" ".join(testdata['sentence']), return_tensors='pt') - - # import random - # random.seed(seed) - # trainloader = [] - # for _ in range(nsamples): - # i = random.randint(0, trainenc.input_ids.shape[1] - seqlen - 1) - # j = i + seqlen - # inp = trainenc.input_ids[:, i:j] - # tar = inp.clone() - # tar[:, :-1] = -100 - # trainloader.append((inp, tar)) - # return trainloader, testenc - -def get_c4_new(nsamples, seed, seqlen, model): - raise NotImplementedError('C4 not implemented yet') - # from datasets import load_dataset - # traindata = load_dataset( - # 'allenai/c4', 'allenai--c4', data_files={'train': 'en/c4-train.00000-of-01024.json.gz'}, split='train' - # ) - # valdata = load_dataset( - # 'allenai/c4', 'allenai--c4', data_files={'validation': 'en/c4-validation.00000-of-00008.json.gz'}, split='validation' - # ) - - # from transformers import AutoTokenizer - # tokenizer = AutoTokenizer.from_pretrained(model, use_fast=False) - - # import random - # random.seed(seed) - # trainloader = [] - # for _ in range(nsamples): - # while True: - # i = random.randint(0, len(traindata) - 1) - # trainenc = tokenizer(traindata[i]['text'], return_tensors='pt') - # if trainenc.input_ids.shape[1] >= seqlen: - # break - # i = random.randint(0, trainenc.input_ids.shape[1] 
- seqlen - 1) - # j = i + seqlen - # inp = trainenc.input_ids[:, i:j] - # tar = inp.clone() - # tar[:, :-1] = -100 - # trainloader.append((inp, tar)) - - # valenc = tokenizer(' '.join(valdata[:1100]['text']), return_tensors='pt') - # valenc = valenc.input_ids[:, :(256 * seqlen)] - - # class TokenizerWrapper: - # def __init__(self, input_ids): - # self.input_ids = input_ids - # valenc = TokenizerWrapper(valenc) - - # return trainloader, valenc - - def get_loaders( dataset_name, nsamples, seed, seqlen, model ): diff --git a/quantize/gptq/gptq.py b/quantize/gptq/gptq.py index 4cb03c85..62ed697f 100644 --- a/quantize/gptq/gptq.py +++ b/quantize/gptq/gptq.py @@ -5,7 +5,7 @@ import torch.nn as nn import transformers -from quant import * +from .quant import * DEBUG = False diff --git a/quantize/gptq/quant.py b/quantize/gptq/quant.py index f584325b..c3a9a7ee 100644 --- a/quantize/gptq/quant.py +++ b/quantize/gptq/quant.py @@ -128,6 +128,7 @@ def enabled(self): def ready(self): return torch.all(self.scale != 0) + try: import quant_cuda is_cuda = True @@ -322,13 +323,25 @@ def forward(self, x): weight = torch.cat([weight[:,0,:11], weight[:,1,1:12], weight[:,2,1:11]], dim=1) weight = weight.reshape(weight.shape[0] * weight.shape[1], weight.shape[2]) - - weights = (self.scales[self.g_idx] * (weight - zeros[self.g_idx])) + num_itr = self.g_idx.shape[0]//x.shape[-1] + if num_itr == 1: + weights = (self.scales[self.g_idx.long()] * (weight - zeros[self.g_idx.long()])) + else: + num_dim = self.g_idx.shape[0]//num_itr + weights = [] + for i in range(num_itr): + scale_i = self.scales[:,i*num_dim:(i+1)*num_dim] + weight_i = weight[:,i*num_dim:(i+1)*num_dim] + zeros_i = zeros[:,i*num_dim:(i+1)*num_dim] + g_idx_i = self.g_idx[i*num_dim:(i+1)*num_dim] + weights.append(scale_i[g_idx_i.long()] * (weight_i - zeros_i[g_idx_i.long()])) + weights = torch.cat(weights,dim=1) out = torch.matmul(x.half(), weights) out = out.reshape(out_shape) out = out + self.bias if self.bias is not None else out return out + class QuantLinear(nn.Module): def __init__(self, bits, groupsize, infeatures, outfeatures, bias, kernel_switch_threshold=128, is_cuda=is_cuda): super().__init__() @@ -491,9 +504,21 @@ def forward(self, x): weight = torch.cat([weight[:,0,:11], weight[:,1,1:12], weight[:,2,1:11]], dim=1) weight = weight.reshape(weight.shape[0] * weight.shape[1], weight.shape[2]) - - weights = (self.scales[self.g_idx] * (weight - zeros[self.g_idx])) + num_itr = self.g_idx.shape[0]//x.shape[-1] + if num_itr == 1: + weights = (self.scales[self.g_idx.long()] * (weight - zeros[self.g_idx.long()])) + else: + num_dim = self.g_idx.shape[0]//num_itr + weights = [] + for i in range(num_itr): + scale_i = self.scales[:,i*num_dim:(i+1)*num_dim] + weight_i = weight[:,i*num_dim:(i+1)*num_dim] + zeros_i = zeros[:,i*num_dim:(i+1)*num_dim] + g_idx_i = self.g_idx[i*num_dim:(i+1)*num_dim] + weights.append(scale_i[g_idx_i.long()] * (weight_i - zeros_i[g_idx_i.long()])) + weights = torch.cat(weights,dim=1) out = torch.matmul(x.half(), weights) out = out.reshape(out_shape) out = out + self.bias if self.bias is not None else out return out + diff --git a/quantize/measure_perplexity.py b/quantize/measure_perplexity.py index 248d5069..c1759f15 100644 --- a/quantize/measure_perplexity.py +++ b/quantize/measure_perplexity.py @@ -2,14 +2,15 @@ from tqdm import tqdm from transformers import AutoModelForCausalLM, AutoTokenizer from datasets import load_dataset -# from myRWKV import RWKV -from rwkv.model import RWKV +from myRWKV import RWKV +# from rwkv.model import 
RWKV import torch.nn as nn device = "cpu" # Model model = RWKV("./RWKV-4-Pile-169M-20220807-8023.pth", strategy='cpu fp32') +# model = RWKV("./1sample_quantized.pth", strategy='cpu fp32') # Dataset tokenizer = AutoTokenizer.from_pretrained("RWKV/rwkv-4-169m-pile") # sentence = "My name is Bob" diff --git a/quantize/myRWKV.py b/quantize/myRWKV.py index 524b9481..7e7e64b7 100644 --- a/quantize/myRWKV.py +++ b/quantize/myRWKV.py @@ -99,6 +99,7 @@ def __init__(self, model, strategy, verbose = True, convert_and_save_and_exit = prxxx(f'Loading {args.MODEL_NAME} ...') with torch.no_grad(): obj = torch.load(args.MODEL_NAME, map_location='cpu') # load model to CPU first + import pdb; pdb.set_trace() if isinstance(obj, list): # GPTQ self.w_quant = obj[0] self.w = obj[1] @@ -349,11 +350,11 @@ def mm8_one(self, x, w, mx, rx, my, ry): N, M = w.shape[0], w.shape[1] return cuda_mm8_one(N, M, x, w, mx, rx, my, ry) else: - @MyFunction + # @MyFunction def mm8_seq(self, x, w, mx, rx, my, ry): return x @ ((w.to(dtype=x.dtype) + 0.5) * ry * rx + my + mx) - @MyFunction + # @MyFunction def mm8_one(self, x, w, mx, rx, my, ry): return x @ ((w.to(dtype=x.dtype) + 0.5) * ry * rx + my + mx) @@ -366,10 +367,10 @@ def ffn_one(self, x, sx, ln_w, ln_b, k_mix, r_mix, kw, vw, rw, kmx, krx, kmy, kr rx = xx * r_mix + sx * (1 - r_mix) # r = torch.sigmoid(rx @ rw) - r = torch.sigmoid(self._trigger_gptq(rx, weight=rw[0], name=rw[1])) # vx = torch.square(torch.relu(kx @ kw)) - vx = torch.square(torch.relu(self._trigger_gptq(kx, weight=kw[0], name=kw[1]))) # out = r * (vx @ vw) + r = torch.sigmoid(self._trigger_gptq(rx, weight=rw[0], name=rw[1])) + vx = torch.square(torch.relu(self._trigger_gptq(kx, weight=kw[0], name=kw[1]))) out = r * (self._trigger_gptq(vx, weight=vw[0], name=vw[1])) return x + out, xx @@ -501,7 +502,7 @@ def att_seq(self, x, sx, aa, bb, pp, ln_w, ln_b, k_mix, v_mix, r_mix, t_decay, t aa = e1 * aa + e2 * vv bb = e1 * bb + e2 pp = p - # out = (r * sx) @ ow + out = (r * sx) @ ow out = self._trigger_gptq(r * sx, weight=ow[0], name=ow[1]) return x + out, xx[-1,:], aa, bb, pp @@ -590,17 +591,29 @@ def _trigger_gptq(self, x, weight, name): torch.bitwise_and(zeros, (2 ** w_quant.bits) - 1, out=zeros) zeros = zeros + 1 - # if name == "head.weight": - # import pdb; pdb.set_trace() zeros = zeros.reshape(w_quant.scales.shape) new_weight = torch.bitwise_right_shift(torch.unsqueeze(w_quant.qweight, 1).expand(-1, 32 // w_quant.bits, -1), w_quant.wf.unsqueeze(-1)).to(x.device, torch.int16 if w_quant.bits == 8 else torch.int8) torch.bitwise_and(new_weight,(2 ** w_quant.bits) - 1, out=new_weight) - new_weight = new_weight.reshape(new_weight.shape[0] * new_weight.shape[1], new_weight.shape[2]) - new_weight = (w_quant.scales[w_quant.g_idx.long()] * (new_weight - zeros[w_quant.g_idx.long()])) - out = torch.matmul(x.half(), new_weight) - out = out.reshape(out_shape) + new_weight = new_weight.reshape(new_weight.shape[0] * new_weight.shape[1], new_weight.shape[2]) + num_itr = w_quant.g_idx.shape[0]//x.shape[-1] + if num_itr == 1: + weights = (w_quant.scales[w_quant.g_idx.long()] * (new_weight - zeros[w_quant.g_idx.long()])) + else: + num_dim = w_quant.g_idx.shape[0]//num_itr + weights = [] + for i in range(num_itr): + scale_i = w_quant.scales[:,i*num_dim:(i+1)*num_dim] + weight_i = new_weight[:,i*num_dim:(i+1)*num_dim] + zeros_i = zeros[:,i*num_dim:(i+1)*num_dim] + g_idx_i = w_quant.g_idx[i*num_dim:(i+1)*num_dim] + weights.append(scale_i[g_idx_i.long()] * (weight_i - zeros_i[g_idx_i.long()])) + weights = torch.cat(weights,dim=1) + + 
# out = torch.matmul(x.half(), weights) + out = torch.matmul(x, weights.to(x.dtype)) + out = out.reshape(out_shape) return out else: return x @ weight @@ -621,15 +634,16 @@ def forward(self, tokens, state, full_output=False): state[i*5+2] = torch.zeros(args.n_embd, dtype=torch.float, requires_grad=False, device=dev).contiguous() state[i*5+3] = torch.zeros(args.n_embd, dtype=torch.float, requires_grad=False, device=dev).contiguous() - 1e30 state[i*5+4] = torch.zeros(args.n_embd, dtype=atype, requires_grad=False, device=dev).contiguous() - + seq_mode = len(tokens) > 1 x = w['emb.weight'][tokens if seq_mode else tokens[0]] - + for i in range(args.n_layer): bbb = f'blocks.{i}.' att = f'blocks.{i}.att.' ffn = f'blocks.{i}.ffn.' + dd = self.strategy[i] dev = dd.device atype = dd.atype @@ -650,7 +664,12 @@ def forward(self, tokens, state, full_output=False): vw = w[f'{att}value.weight'], f'{att}value.weight' rw = w[f'{att}receptance.weight'], f'{att}receptance.weight' ow = w[f'{att}output.weight'], f'{att}output.weight' - + + # kw = w[f'{att}key.weight'] + # vw = w[f'{att}value.weight'] + # rw = w[f'{att}receptance.weight'] + # ow = w[f'{att}output.weight'] + if dd.stream: kw = kw.to(device=dev, non_blocking=True) vw = vw.to(device=dev, non_blocking=True) @@ -673,6 +692,7 @@ def forward(self, tokens, state, full_output=False): orx = w[f'{att}output.weight_rx'] if wtype == torch.uint8 else x omy = w[f'{att}output.weight_my'] if wtype == torch.uint8 else x ory = w[f'{att}output.weight_ry'] if wtype == torch.uint8 else x + x, state[i*5+0], state[i*5+1], state[i*5+2], state[i*5+3] = ATT( x, state[i*5+0], state[i*5+1], state[i*5+2], state[i*5+3], w[f'{bbb}ln1.weight'], w[f'{bbb}ln1.bias'], @@ -690,6 +710,9 @@ def forward(self, tokens, state, full_output=False): kw = w[f'{ffn}key.weight'], f'{ffn}key.weight' vw = w[f'{ffn}value.weight'], f'{ffn}value.weight' rw = w[f'{ffn}receptance.weight'], f'{ffn}receptance.weight' + # kw = w[f'{ffn}key.weight'] + # vw = w[f'{ffn}value.weight'] + # rw = w[f'{ffn}receptance.weight'] if dd.stream: kw = kw.to(device=dev, non_blocking=True) @@ -743,5 +766,4 @@ def forward(self, tokens, state, full_output=False): x = self.mm8_seq(x, w['head.weight'], w['head.weight_mx'], w['head.weight_rx'], w['head.weight_my'], w['head.weight_ry']) else: x = self.mm8_one(x, w['head.weight'], w['head.weight_mx'], w['head.weight_rx'], w['head.weight_my'], w['head.weight_ry']) - return x.float(), state \ No newline at end of file diff --git a/quantize/opt.py b/quantize/opt.py deleted file mode 100644 index 16a776f5..00000000 --- a/quantize/opt.py +++ /dev/null @@ -1,216 +0,0 @@ -import time - -import torch -import torch.nn as nn - -from gptq.gptq import * -from gptq.modelutils import * -from gptq.quant import * - - -def get_opt(model): - import torch - def skip(*args, **kwargs): - pass - torch.nn.init.kaiming_uniform_ = skip - torch.nn.init.uniform_ = skip - torch.nn.init.normal_ = skip - from transformers import OPTForCausalLM - model = OPTForCausalLM.from_pretrained(model, torch_dtype='auto') - model.seqlen = model.config.max_position_embeddings - return model - -@torch.no_grad() -def opt_sequential(model, dataloader, dev): - print('Starting ...') - - use_cache = model.config.use_cache - model.config.use_cache = False - layers = model.model.decoder.layers - - model.model.decoder.embed_tokens = model.model.decoder.embed_tokens.to(dev) - model.model.decoder.embed_positions = model.model.decoder.embed_positions.to(dev) - if hasattr(model.model.decoder, 'project_out') and 
model.model.decoder.project_out: - model.model.decoder.project_out = model.model.decoder.project_out.to(dev) - if hasattr(model.model.decoder, 'project_in') and model.model.decoder.project_in: - model.model.decoder.project_in = model.model.decoder.project_in.to(dev) - layers[0] = layers[0].to(dev) - - dtype = next(iter(model.parameters())).dtype - inps = torch.zeros( - (args.nsamples, model.seqlen, model.config.hidden_size), dtype=dtype, device=dev - ) - cache = {'i': 0, 'attention_mask': None} - - class Catcher(nn.Module): - def __init__(self, module): - super().__init__() - self.module = module - def forward(self, inp, **kwargs): - inps[cache['i']] = inp - cache['i'] += 1 - cache['attention_mask'] = kwargs['attention_mask'] - raise ValueError - layers[0] = Catcher(layers[0]) - for batch in dataloader: - try: - model(batch[0].to(dev)) - except ValueError: - pass - layers[0] = layers[0].module - - layers[0] = layers[0].cpu() - model.model.decoder.embed_tokens = model.model.decoder.embed_tokens.cpu() - model.model.decoder.embed_positions = model.model.decoder.embed_positions.cpu() - if hasattr(model.model.decoder, 'project_out') and model.model.decoder.project_out: - model.model.decoder.project_out = model.model.decoder.project_out.cpu() - if hasattr(model.model.decoder, 'project_in') and model.model.decoder.project_in: - model.model.decoder.project_in = model.model.decoder.project_in.cpu() - torch.cuda.empty_cache() - - outs = torch.zeros_like(inps) - attention_mask = cache['attention_mask'] - - print('Ready.') - - quantizers = {} - for i in range(len(layers)): - layer = layers[i].to(dev) - - subset = find_layers(layer) - gptq = {} - for name in subset: - gptq[name] = GPTQ(subset[name]) - gptq[name].quantizer = Quantizer() - gptq[name].quantizer.configure( - args.wbits, perchannel=True, sym=args.sym, mse=False, trits=args.trits - ) - - def add_batch(name): - def tmp(_, inp, out): - gptq[name].add_batch(inp[0].data, out.data) - return tmp - handles = [] - for name in subset: - handles.append(subset[name].register_forward_hook(add_batch(name))) - for j in range(args.nsamples): - outs[j] = layer(inps[j].unsqueeze(0), attention_mask=attention_mask)[0] - for h in handles: - h.remove() - - for name in subset: - print(i, name) - print('Quantizing ...') - gptq[name].fasterquant(percdamp=args.percdamp, groupsize=args.groupsize, actorder=args.act_order) - quantizers['model.decoder.layers.%d.%s' % (i, name)] = gptq[name].quantizer - gptq[name].free() - for j in range(args.nsamples): - outs[j] = layer(inps[j].unsqueeze(0), attention_mask=attention_mask)[0] - - layers[i] = layer.cpu() - del layer - del gptq - torch.cuda.empty_cache() - - inps, outs = outs, inps - - model.config.use_cache = use_cache - - return quantizers - -if __name__ == '__main__': - import argparse - from gptq.datautils import * - - parser = argparse.ArgumentParser() - - parser.add_argument( - 'model', type=str, - help='OPT model to load; pass `facebook/opt-X`.' - ) - parser.add_argument( - 'dataset', type=str, choices=['wikitext2', 'ptb', 'c4'], - help='Where to extract calibration data from.' - ) - parser.add_argument( - '--seed', - type=int, default=0, help='Seed for sampling the calibration data.' - ) - parser.add_argument( - '--nsamples', type=int, default=128, - help='Number of calibration data samples.' - ) - parser.add_argument( - '--percdamp', type=float, default=.01, - help='Percent of the average Hessian diagonal to use for dampening.' 
- ) - parser.add_argument( - '--nearest', action='store_true', - help='Whether to run the RTN baseline.' - ) - parser.add_argument( - '--wbits', type=int, default=16, choices=[2, 3, 4, 16], - help='#bits to use for quantization; use 16 for evaluating base model.' - ) - parser.add_argument( - '--trits', action='store_true', - help='Whether to use trits for quantization.' - ) - parser.add_argument( - '--groupsize', type=int, default=-1, - help='Groupsize to use for quantization; default uses full row.' - ) - parser.add_argument( - '--sym', action='store_true', - help='Whether to perform symmetric quantization.' - ) - parser.add_argument( - '--save', type=str, default='', - help='Save quantized checkpoint under this name.' - ) - parser.add_argument( - '--load', type=str, default='', - help='Load quantized model.' - ) - parser.add_argument( - '--benchmark', type=int, default=0, - help='Number of tokens to use for benchmarking.' - ) - parser.add_argument( - '--check', action='store_true', - help='Whether to compute perplexity during benchmarking for verification.' - ) - parser.add_argument( - '--new-eval', action='store_true', - help='Whether to use the new PTB and C4 eval.' - ) - parser.add_argument( - '--faster-kernel', action='store_true', - help='Whether to use the new faster kernel for benchmarking.' - ) - parser.add_argument( - '--act-order', action='store_true', - help='Whether to apply the activation order GPTQ heuristic' - ) - - args = parser.parse_args() - - model = get_opt(args.model) - model.eval() - - dataloader, testloader = get_loaders( - args.dataset, nsamples=args.nsamples, seed=args.seed, model=args.model, seqlen=model.seqlen - ) - - print(f'{len(dataloader)} train tokens in the text') - print(f'{len(testloader)} test tokens in the text') - - - if args.wbits < 16 and not args.nearest: - tick = time.time() - quantizers = opt_sequential(model, dataloader, DEV) - print(time.time() - tick) - - # if args.save: - # opt_pack3(model, quantizers) - # torch.save(model.state_dict(), args.save) diff --git a/quantize/tmp_rwkv.py b/quantize/tmp_rwkv.py index 38b0f070..d550e196 100644 --- a/quantize/tmp_rwkv.py +++ b/quantize/tmp_rwkv.py @@ -52,6 +52,9 @@ def add_batch(self, inp): self.H += inp.matmul(inp.t()) def fasterquant(self, blocksize=128, percdamp=.01, groupsize=-1, actorder=False): + # if self.name == "blocks.0.ffn.value.weight": + # if self.name == "blocks.0.ffn.key.weight": + # import pdb; pdb.set_trace() W = self.weight.data.clone() # OLD: Need to transpose here, same reason as in __init__ with self.columns # UPDATE: no need to tranpose as we already transpose in my_linear() @@ -120,7 +123,6 @@ def fasterquant(self, blocksize=128, percdamp=.01, groupsize=-1, actorder=False) ).flatten() Q1[:, i] = q Losses1[:, i] = (w - q) ** 2 / d ** 2 - err1 = (w - q) / d W1[:, i:] -= err1.unsqueeze(1).matmul(Hinv1[i, i:].unsqueeze(0)) Err1[:, i] = err1 @@ -164,9 +166,9 @@ def _fill_subset(self, layer_id): # Keep only layer within block layer_id #TODO: Uncomment me when quantizing 1 layer works - # is_weight = re.compile(f'^blocks\.{layer_id}\..*\.weight$') - is_weight = re.compile("blocks.0.att.key.weight") - for name in self.w.keys(): + is_weight = re.compile(f'^blocks\.{layer_id}\..*\.weight$') + # is_weight = re.compile("blocks.0.att.key.weight") + for name in self.w.keys(): if is_weight.match(name): if len(self.w[name].shape) == 1: continue #TODO: Skip 1D tensors for now self.subset[name] = self.w[name] @@ -211,16 +213,15 @@ def att_one(self, x, sx, aa, bb, pp, ln_w, ln_b, k_mix, v_mix, 
r_mix, t_decay, t vx = xx * v_mix + sx * (1 - v_mix) rx = xx * r_mix + sx * (1 - r_mix) - # r = torch.sigmoid(rx @ rw.weight) - r = torch.sigmoid(rx @ rw) - # rw.add_batch(rx) + # r = torch.sigmoid(rx @ rw) + # v = (vx @ vw).float() + r = torch.sigmoid(rx @ rw.weight) + rw.add_batch(rx) k = (kx @ kw.weight).float() kw.add_batch(kx) - - # v = (vx @ vw.weight).float() - v = (vx @ vw).float() - # vw.add_batch(vx) + v = (vx @ vw.weight).float() + vw.add_batch(vx) ww = t_first + k p = torch.maximum(pp, ww) @@ -232,9 +233,9 @@ def att_one(self, x, sx, aa, bb, pp, ln_w, ln_b, k_mix, v_mix, r_mix, t_decay, t e1 = torch.exp(ww - p) e2 = torch.exp(k - p) - # out = (r * wkv) @ ow.weight - out = (r * wkv) @ ow - # ow.add_batch(r * wkv) + # out = (r * wkv) @ ow + out = (r * wkv) @ ow.weight + ow.add_batch(r * wkv) return x + out, xx, e1 * aa + e2 * v, e1 * bb + e2, p def att_seq(self, x, sx, aa, bb, pp, ln_w, ln_b, k_mix, v_mix, r_mix, t_decay, t_first, kw, vw, rw, ow, kmx, krx, kmy, kry, vmx, vrx, vmy, vry, rmx, rrx, rmy, rry, omx, orx, omy, ory): @@ -244,14 +245,16 @@ def att_seq(self, x, sx, aa, bb, pp, ln_w, ln_b, k_mix, v_mix, r_mix, t_decay, t vx = xx * v_mix + sx * (1 - v_mix) rx = xx * r_mix + sx * (1 - r_mix) - # r = torch.sigmoid(rx @ rw.weight) - r = torch.sigmoid(rx @ rw) - # rw.add_batch(rx) + # r = torch.sigmoid(rx @ rw) + # k = (kx @ kw).float() + # v = (vx @ vw).float() + + r = torch.sigmoid(rx @ rw.weight) + rw.add_batch(rx) k = (kx @ kw.weight).float() kw.add_batch(kx) - # v = (vx @ vw.weight).float() - v = (vx @ vw).float() - # vw.add_batch(vx) + v = (vx @ vw.weight).float() + vw.add_batch(vx) T = x.shape[0] for t in range(T): @@ -269,9 +272,9 @@ def att_seq(self, x, sx, aa, bb, pp, ln_w, ln_b, k_mix, v_mix, r_mix, t_decay, t aa = e1 * aa + e2 * vv bb = e1 * bb + e2 pp = p - # out = (r * sx) @ ow.weight - out = (r * sx) @ ow - # ow.add_batch(r * sx) + # out = (r * sx) @ ow + out = (r * sx) @ ow.weight + ow.add_batch(r * sx) return x + out, xx[-1,:], aa, bb, pp def ffn_one(self, x, sx, ln_w, ln_b, k_mix, r_mix, kw, vw, rw, kmx, krx, kmy, kry, vmx, vrx, vmy, vry, rmx, rrx, rmy, rry): @@ -279,15 +282,17 @@ def ffn_one(self, x, sx, ln_w, ln_b, k_mix, r_mix, kw, vw, rw, kmx, krx, kmy, kr kx = xx * k_mix + sx * (1 - k_mix) rx = xx * r_mix + sx * (1 - r_mix) - # r = torch.sigmoid(rx @ rw.weight) - r = torch.sigmoid(rx @ rw) - # rw.add_batch(rx) - # vx = torch.square(torch.relu(kx @ kw.weight)) - vx = torch.square(torch.relu(kx @ kw)) - # kw.add_batch(kx) - # out = r * (vx @ vw.weight) - out = r * (vx @ vw) - # vw.add_batch(vx) + # r = torch.sigmoid(rx @ rw) + # vx = torch.square(torch.relu(kx @ kw)) + # out = r * (vx @ vw) + + r = torch.sigmoid(rx @ rw.weight) + rw.add_batch(rx) + vx = torch.square(torch.relu(kx @ kw.weight)) + kw.add_batch(kx) + out = r * (vx @ vw.weight) + vw.add_batch(vx) + return x + out, xx def ffn_seq(self, x, sx, ln_w, ln_b, k_mix, r_mix, kw, vw, rw, kmx, krx, kmy, kry, vmx, vrx, vmy, vry, rmx, rrx, rmy, rry): @@ -296,15 +301,16 @@ def ffn_seq(self, x, sx, ln_w, ln_b, k_mix, r_mix, kw, vw, rw, kmx, krx, kmy, kr kx = xx * k_mix + sx * (1 - k_mix) rx = xx * r_mix + sx * (1 - r_mix) - # r = torch.sigmoid(rx @ rw.weight) - r = torch.sigmoid(rx @ rw) - # rw.add_batch(rx) - # vx = torch.square(torch.relu(kx @ kw.weight)) - vx = torch.square(torch.relu(kx @ kw)) - # kw.add_batch(kx) - # out = r * (vx @ vw.weight) - out = r * (vx @ vw) - # vw.add_batch(vx) + # r = torch.sigmoid(rx @ rw) + # vx = torch.square(torch.relu(kx @ kw)) + # out = r * (vx @ vw) + + r = 
torch.sigmoid(rx @ rw.weight) + rw.add_batch(rx) + vx = torch.square(torch.relu(kx @ kw.weight)) + kw.add_batch(kx) + out = r * (vx @ vw.weight) + vw.add_batch(vx) return x + out, xx[-1,:] def forward_block(self, x, state, i, seq_mode, full_output=False): @@ -343,12 +349,14 @@ def forward_block(self, x, state, i, seq_mode, full_output=False): x = x.to(dtype=atype, device=dev) kw = self.gptq[f'{att}key.weight'] - # vw = self.gptq[f'{att}value.weight'] - vw = self.w[f'{att}value.weight'] - # rw = self.gptq[f'{att}receptance.weight'] - rw = self.w[f'{att}receptance.weight'] - # ow = self.gptq[f'{att}output.weight'] - ow = self.w[f'{att}output.weight'] + vw = self.gptq[f'{att}value.weight'] + rw = self.gptq[f'{att}receptance.weight'] + ow = self.gptq[f'{att}output.weight'] + + # kw = self.w[f'{att}key.weight'] + # vw = self.w[f'{att}value.weight'] + # rw = self.w[f'{att}receptance.weight'] + # ow = self.w[f'{att}output.weight'] if dd.stream: kw = kw.to(device=dev, non_blocking=True) @@ -388,12 +396,13 @@ def forward_block(self, x, state, i, seq_mode, full_output=False): if dd.stream: del kw, vw, rw, ow - # kw = self.gptq[f'{ffn}key.weight'] - kw = self.w[f'{ffn}key.weight'] - # vw = self.gptq[f'{ffn}value.weight'] - vw = self.w[f'{ffn}value.weight'] - # rw = self.gptq[f'{ffn}receptance.weight'] - rw = self.w[f'{ffn}receptance.weight'] + # kw = self.w[f'{ffn}key.weight'] + # vw = self.w[f'{ffn}value.weight'] + # rw = self.w[f'{ffn}receptance.weight'] + + kw = self.gptq[f'{ffn}key.weight'] + vw = self.gptq[f'{ffn}value.weight'] + rw = self.gptq[f'{ffn}receptance.weight'] if dd.stream: kw = kw.to(device=dev, non_blocking=True) @@ -412,6 +421,7 @@ def forward_block(self, x, state, i, seq_mode, full_output=False): rrx = self.w[f'{ffn}receptance.weight_rx'] if wtype == torch.uint8 else x rmy = self.w[f'{ffn}receptance.weight_my'] if wtype == torch.uint8 else x rry = self.w[f'{ffn}receptance.weight_ry'] if wtype == torch.uint8 else x + x, state[i*5+4] = FFN( x=x, sx=state[i*5+4], ln_w=self.w[f'{bbb}ln2.weight'], ln_b=self.w[f'{bbb}ln2.bias'], @@ -466,8 +476,7 @@ def quantize_gptq_custom(model, tokens): model.alloc_gptq(layer_id) for i in range(nsamples): - #TODO: Are outs value normal ? 
(they look almost all the same) - if not is_last_layer(layer_id): + if not is_last_layer(layer_id): outs[i] = model.forward_block(inps[i], state=None, i=layer_id, seq_mode=seq_mode) else: _ = model.forward_block(inps[i], state=None, i=layer_id, seq_mode=seq_mode) @@ -476,6 +485,7 @@ def quantize_gptq_custom(model, tokens): gptq_layer.deactivate_add_batch_call = True model.fasterquant(layer_id, quantizers) + # model.gptq["blocks.0.ffn.value.weight"].weight for i in range(nsamples): if not is_last_layer(layer_id): @@ -532,16 +542,15 @@ def model_pack_custom(model, quantizers, wbits, groupsize): print('Done.') return model - if __name__ == "__main__": - model_ref = GPTQ_RWKV("./RWKV-4-Pile-169M-20220807-8023.pth", strategy='cpu fp32') model = GPTQ_RWKV("./RWKV-4-Pile-169M-20220807-8023.pth", strategy='cpu fp32') - NSAMPLES=1 + NSAMPLES=128 HIDDEN_SIZE=model.args.n_embd SEQLEN=1024 # cf https://huggingface.co/BlinkDL/rwkv-4-pile-169m + print("Loading data ...") train_tokens, test_tokens = get_loaders( dataset_name="wikitext2", nsamples=NSAMPLES, @@ -551,28 +560,13 @@ def model_pack_custom(model, quantizers, wbits, groupsize): ) tokens = torch.cat([inp for inp, _ in train_tokens], dim=0) - tokens = torch.zeros((NSAMPLES, SEQLEN), dtype=torch.int64) print("tokens.shape", tokens.shape) + print("Quantizing ...") quantizers = quantize_gptq_custom(model, tokens) model = model_pack_custom(model, quantizers, WBITS, GROUPSIZE) + print("Saving ...") torch.save([model.w_quant, model.w], "1sample_quantized.pth") - # Make sure only 1 layer was quantized - assert len(model.w_quant.keys()) == 1 and "blocks.0.att.key.weight" in model.w_quant.keys() - - for (ref_key, ref_value), (key, value) in zip(model_ref.w.items(), model.w.items()): - if key != "blocks.0.att.key.weight": - assert torch.allclose(ref_value, value, atol=1e-5) - else: - assert not torch.allclose(ref_value, value, atol=1e-5) - - print("Done Custom GPTQ") - - # I have noticed QuantLinear.forward() can be divded in 2 parts: - # 1. Quantize the weights (using info from model.w_quant thanks to QuantLinear.pack()) - # 2. Perform x @ weights - # We can load checkpoint RWKV of base class with model_w (which are quantized but doesnt have the scale, zero info) - # Then, if isinstancce(model, w_quant) exist, we load this dict as well - # Each time the weights are called, we do a trigger() by checking if isinstancce(moded.w_quant is QuantLinear) - # This way, we can reuse RWKV base class with minimal change \ No newline at end of file + + print("Done Custom GPTQ") \ No newline at end of file diff --git a/quantize/tmp_rwkv.py.lprof b/quantize/tmp_rwkv.py.lprof deleted file mode 100644 index e245fb51861a6d11d033e6007b556883f28e8491..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 426 zcmY+<-7AAp90zdEY?jv2Li4ig*~D{fLR`3!)Q(-C=0aMJoNb3a&t}hg@)AX0+sn;dxe4 z)Bi$gH%_8FBQ#;kLmAzXTQXvM1=qQeHGhL4q8ox38I$A+IwKiXFX&_}r-|uVnj$Nf zW(dB)NVX_e$f71ROYpp@7S|$)Knqn8 zA9qUwU@N#Ayy^#oD8I&lHlX|nw4+!(h`7NHs_c70H`vufj9
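For reference, loading the checkpoint produced by tmp_rwkv.py through the patched wrapper is enough to route the quantized layers through _trigger_gptq; a usage sketch (it assumes the temporary pdb.set_trace() left in myRWKV.__init__ is removed first):

```
from myRWKV import RWKV

# myRWKV detects the [w_quant, w] list saved by tmp_rwkv.py and keeps both dicts.
model = RWKV("./1sample_quantized.pth", strategy='cpu fp32')

logits, state = model.forward([0], None)   # single dummy token id
print(logits.shape)                        # vocab-sized vector (50277 for the Pile models)
```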