From 8e78f2d405ee9d66270bbda276825680de3ef8c2 Mon Sep 17 00:00:00 2001 From: Ferdinand Mom Date: Tue, 25 Apr 2023 13:14:17 +0000 Subject: [PATCH] fix(quantize): add missing part in forward block + support head.weight quantization --- quantize/gptq/datautils.py | 11 +++-- quantize/tmp_rwkv.py | 96 +++++++++++++++++++++----------------- 2 files changed, 58 insertions(+), 49 deletions(-) diff --git a/quantize/gptq/datautils.py b/quantize/gptq/datautils.py index 4ed1e39b..cd296d3c 100644 --- a/quantize/gptq/datautils.py +++ b/quantize/gptq/datautils.py @@ -4,6 +4,7 @@ import pathlib import tokenizers import random +from rwkv.model import RWKV from datasets import load_dataset @@ -12,7 +13,7 @@ def set_seed(seed): torch.random.manual_seed(seed) def get_wikitext2(nsamples, seed, seqlen, model): - is_rwkv = True if model is None else False + is_rwkv = isinstance(model, RWKV) if is_rwkv: traindata = load_dataset('wikitext', 'wikitext-2-raw-v1', split='train') @@ -37,11 +38,11 @@ def get_wikitext2(nsamples, seed, seqlen, model): trainloader = [] shape = trainenc.shape if is_rwkv else trainenc.input_ids.shape trainenc = trainenc if is_rwkv else trainenc.input_ids + random_idx = [random.randint(0, shape[1] - seqlen - 1) for _ in range(nsamples)] - for _ in range(nsamples): - i = random.randint(0, shape[1] - seqlen - 1) - j = i + seqlen - inp = trainenc[:, i:j] + for i in range(nsamples): + j = random_idx[i] + seqlen + inp = trainenc[:, random_idx[i]:j] tar = inp.clone() tar[:, :-1] = -100 trainloader.append((inp, tar)) diff --git a/quantize/tmp_rwkv.py b/quantize/tmp_rwkv.py index b4697b64..1ed2ad1f 100644 --- a/quantize/tmp_rwkv.py +++ b/quantize/tmp_rwkv.py @@ -7,7 +7,6 @@ import torch.nn.functional as F import torch.nn as nn import time -import gc import math import re @@ -132,27 +131,21 @@ def fasterquant(self, blocksize=128, percdamp=.01, groupsize=-1, actorder=False) ### begin GPTQ_RWKV def __init__(self, model, strategy): super().__init__(model, strategy) - #TODO: add assert to only quantize in CPU FP32 mode + for i in range(self.args.n_layer): + assert self.strategy[i].device == "cpu" def _fill_subset(self, layer_id): # Keep only layer within block layer_id - dd = self.strategy[layer_id] - dev = dd.device - - for name in self.w.keys(): - if re.match(f'^blocks\.{layer_id}\..*\.weight$', name): - tensor = self.w[name] - - #TODO: Skip 1D tensors for now - if len(tensor.shape) == 1: - continue - - print(f"{name} = {self.w[name].shape}") - - if re.match(f'^blocks\.{layer_id}\.(?:att|ffn)\.(?:key|value|output|receptance)\.weight$', name): - tensor = tensor.to(device=dev, non_blocking=True) - - self.subset[name] = tensor + is_weight = re.compile(f'^blocks\.{layer_id}\..*\.weight$') + for name in self.w.keys(): + if is_weight.match(name): + if len(self.w[name].shape) == 1: continue #TODO: Skip 1D tensors for now + self.subset[name] = self.w[name] + + is_last_layer = (layer_id == self.args.n_layer - 1) + if is_last_layer: + self.subset["head.weight"] = self.w["head.weight"] + def alloc_gptq(self, layer_id): self.subset = {} @@ -178,7 +171,6 @@ def fasterquant(self, layer_id, quantizers): self.gptq[name].fasterquant(percdamp=0.01, groupsize=-1, actorder=False) # self.gptq[name].fastquant(percdamp=args.percdamp, groupsize=args.groupsize, actorder=args.act_order) quantizers[name] = self.gptq[name].quantizer - # TODO: may be free gptq here to save memory ### end GPTQ_RWKV @@ -272,7 +264,7 @@ def ffn_seq(self, x, sx, ln_w, ln_b, k_mix, r_mix, kw, vw, rw, kmx, krx, kmy, kr vw.add_batch(vx) return x + out, xx[-1,:] - def forward_block(self, x, state, i, seq_mode, is_last_layer, full_output=False): + def forward_block(self, x, state, i, seq_mode, full_output=False): with torch.no_grad(): args = self.args @@ -312,6 +304,12 @@ def forward_block(self, x, state, i, seq_mode, is_last_layer, full_output=False) rw = self.gptq[f'{att}receptance.weight'] ow = self.gptq[f'{att}output.weight'] + if dd.stream: + kw = kw.to(device=dev, non_blocking=True) + vw = vw.to(device=dev, non_blocking=True) + rw = rw.to(device=dev, non_blocking=True) + ow = ow.to(device=dev, non_blocking=True) + kmx = self.w[f'{att}key.weight_mx'] if wtype == torch.uint8 else x krx = self.w[f'{att}key.weight_rx'] if wtype == torch.uint8 else x kmy = self.w[f'{att}key.weight_my'] if wtype == torch.uint8 else x @@ -341,6 +339,7 @@ def forward_block(self, x, state, i, seq_mode, is_last_layer, full_output=False) omx=omx, orx=orx, omy=omy, ory=ory, ) + # Deactivate add_batch() after quantization is applied kw.deactivate_add_batch_call = True vw.deactivate_add_batch_call = True rw.deactivate_add_batch_call = True @@ -352,6 +351,7 @@ def forward_block(self, x, state, i, seq_mode, is_last_layer, full_output=False) kw = self.gptq[f'{ffn}key.weight'] vw = self.gptq[f'{ffn}value.weight'] rw = self.gptq[f'{ffn}receptance.weight'] + if dd.stream: kw = kw.to(device=dev, non_blocking=True) vw = vw.to(device=dev, non_blocking=True) @@ -391,43 +391,46 @@ def forward_block(self, x, state, i, seq_mode, is_last_layer, full_output=False) if (i+1) % self.RESCALE_LAYER == 0: x = x / 2 - if is_last_layer: + is_last_layer = i == (args.n_layer - 1) + + if is_last_layer: dd = self.strategy[args.n_layer] x = x[-1,:] if (seq_mode and (not full_output)) else x x = x.to(dtype=dd.atype, device=dd.device) - #TODO: Add GPTQ support for head & ln_out + #TODO: ln_out.weight is 1D tensor x = F.layer_norm(x, (args.n_embd,), weight=self.w['ln_out.weight'], bias=self.w['ln_out.bias']) + if self.w['head.weight'].dtype != torch.uint8: - x = x @ self.w['head.weight'] - else: - if seq_mode and full_output: - x = self.mm8_seq(x, self.w['head.weight'], self.w['head.weight_mx'], self.w['head.weight_rx'], self.w['head.weight_my'], self.w['head.weight_ry']) - else: - x = self.mm8_one(x, self.w['head.weight'], self.w['head.weight_mx'], self.w['head.weight_rx'], self.w['head.weight_my'], self.w['head.weight_ry']) + x = x @ self.gptq['head.weight'].weight + self.gptq['head.weight'].add_batch(x) + self.gptq['head.weight'].deactivate_add_batch_call = True return x.float() ### end RWKV +model = GPTQ_RWKV("./RWKV-4-Pile-169M-20220807-8023.pth", strategy='cpu fp32') + NSAMPLES=2 -HIDDEN_SIZE=768 -SEQLEN=2048 # TODO: this is chosen by the model +HIDDEN_SIZE=model.args.n_embd +SEQLEN=1024 # cf https://huggingface.co/BlinkDL/rwkv-4-pile-169m # train_tokens, test_tokens = get_loaders( # dataset_name="wikitext2", # nsamples=NSAMPLES, # seed=42, # seqlen=SEQLEN, -# model=None +# model=model # ) # tokens = torch.cat([inp for inp, _ in train_tokens], dim=0) tokens = torch.zeros((NSAMPLES, SEQLEN), dtype=torch.int64) print("tokens.shape", tokens.shape) -model = GPTQ_RWKV("./RWKV-4-Pile-169M-20220807-8023.pth", strategy='cpu fp32') -is_last_layer = [False] * (model.args.n_layer - 1) + [True] +is_last_layer = lambda x: x == (model.args.n_layer - 1) + +start_time = time.time() #TODO: Do the same in GPU side with torch.no_grad(): @@ -442,23 +445,28 @@ def forward_block(self, x, state, i, seq_mode, is_last_layer, full_output=False) model.alloc_gptq(layer_id) for j in range(NSAMPLES): - if not is_last_layer[layer_id]: - outs[j] = model.forward_block(inps[j], state=None, i=layer_id, seq_mode=seq_mode, is_last_layer=is_last_layer[layer_id]) + if not is_last_layer(layer_id): + outs[j] = model.forward_block(inps[j], state=None, i=layer_id, seq_mode=seq_mode) else: - _ = model.forward_block(inps[j], state=None, i=layer_id, seq_mode=seq_mode, is_last_layer=is_last_layer[layer_id]) + _ = model.forward_block(inps[j], state=None, i=layer_id, seq_mode=seq_mode) model.fasterquant(layer_id, quantizers) for j in range(NSAMPLES): - if not is_last_layer[layer_id]: - outs[j] = model.forward_block(inps[j], state=None, i=layer_id, seq_mode=seq_mode, is_last_layer=is_last_layer[layer_id]) + if not is_last_layer(layer_id): + outs[j] = model.forward_block(inps[j], state=None, i=layer_id, seq_mode=seq_mode) else: - _ = model.forward_block(inps[j], state=None, i=layer_id, seq_mode=seq_mode, is_last_layer=is_last_layer[layer_id]) + _ = model.forward_block(inps[j], state=None, i=layer_id, seq_mode=seq_mode) + model.free_gptq() - if not is_last_layer[layer_id]: - # We need to pass the outputs of block i as input of block i+1 (except for last block) + # We need to pass the outputs of block i as input of block i+1 (except for last block) + if not is_last_layer(layer_id): inps, outs = outs, inps -# TODO: create a function that check if all weights were properly quantized -print("Done") \ No newline at end of file +end_time = time.time() + +print(f"Done in {end_time - start_time:.2f} seconds") + +# TODO: Do something with quantizers dictionary +# TODO: pack3 save model \ No newline at end of file