fix(quantize): add missing part in forward block + support head.weight quantization
3outeille committed Apr 26, 2023
1 parent 57079e7 commit 8e78f2d
Showing 2 changed files with 58 additions and 49 deletions.
11 changes: 6 additions & 5 deletions quantize/gptq/datautils.py
@@ -4,6 +4,7 @@
 import pathlib
 import tokenizers
 import random
+from rwkv.model import RWKV
 
 from datasets import load_dataset
 
@@ -12,7 +13,7 @@ def set_seed(seed):
     torch.random.manual_seed(seed)
 
 def get_wikitext2(nsamples, seed, seqlen, model):
-    is_rwkv = True if model is None else False
+    is_rwkv = isinstance(model, RWKV)
 
     if is_rwkv:
         traindata = load_dataset('wikitext', 'wikitext-2-raw-v1', split='train')
@@ -37,11 +38,11 @@ def get_wikitext2(nsamples, seed, seqlen, model):
     trainloader = []
     shape = trainenc.shape if is_rwkv else trainenc.input_ids.shape
     trainenc = trainenc if is_rwkv else trainenc.input_ids
+    random_idx = [random.randint(0, shape[1] - seqlen - 1) for _ in range(nsamples)]
 
-    for _ in range(nsamples):
-        i = random.randint(0, shape[1] - seqlen - 1)
-        j = i + seqlen
-        inp = trainenc[:, i:j]
+    for i in range(nsamples):
+        j = random_idx[i] + seqlen
+        inp = trainenc[:, random_idx[i]:j]
         tar = inp.clone()
         tar[:, :-1] = -100
         trainloader.append((inp, tar))
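
For context, the refactor above precomputes every random window offset before the loop instead of drawing one per iteration. A minimal standalone sketch of the resulting sampling logic, using a toy token tensor in place of the tokenized corpus (the helper name sample_calibration_windows is illustrative, not part of this repository):

import random
import torch

def sample_calibration_windows(trainenc, nsamples, seqlen):
    # trainenc: (1, total_tokens) tensor of token ids, as produced in get_wikitext2
    random_idx = [random.randint(0, trainenc.shape[1] - seqlen - 1) for _ in range(nsamples)]
    loader = []
    for i in range(nsamples):
        j = random_idx[i] + seqlen
        inp = trainenc[:, random_idx[i]:j]
        tar = inp.clone()
        tar[:, :-1] = -100   # mask every target position except the last one
        loader.append((inp, tar))
    return loader

# Toy usage: 10k fake token ids, 2 windows of length 16.
windows = sample_calibration_windows(torch.randint(0, 100, (1, 10_000)), nsamples=2, seqlen=16)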
96 changes: 52 additions & 44 deletions quantize/tmp_rwkv.py
@@ -7,7 +7,6 @@
 import torch.nn.functional as F
 import torch.nn as nn
 import time
-import gc
 import math
 import re
 
@@ -132,27 +131,21 @@ def fasterquant(self, blocksize=128, percdamp=.01, groupsize=-1, actorder=False)
     ### begin GPTQ_RWKV
     def __init__(self, model, strategy):
         super().__init__(model, strategy)
-        #TODO: add assert to only quantize in CPU FP32 mode
+        for i in range(self.args.n_layer):
+            assert self.strategy[i].device == "cpu"
 
     def _fill_subset(self, layer_id):
         # Keep only layer within block layer_id
         dd = self.strategy[layer_id]
         dev = dd.device
 
-        for name in self.w.keys():
-            if re.match(f'^blocks\.{layer_id}\..*\.weight$', name):
-                tensor = self.w[name]
-
-                #TODO: Skip 1D tensors for now
-                if len(tensor.shape) == 1:
-                    continue
-
-                print(f"{name} = {self.w[name].shape}")
-
-                if re.match(f'^blocks\.{layer_id}\.(?:att|ffn)\.(?:key|value|output|receptance)\.weight$', name):
-                    tensor = tensor.to(device=dev, non_blocking=True)
-
-                self.subset[name] = tensor
+        is_weight = re.compile(f'^blocks\.{layer_id}\..*\.weight$')
+        for name in self.w.keys():
+            if is_weight.match(name):
+                if len(self.w[name].shape) == 1: continue #TODO: Skip 1D tensors for now
+                self.subset[name] = self.w[name]
+
+        is_last_layer = (layer_id == self.args.n_layer - 1)
+        if is_last_layer:
+            self.subset["head.weight"] = self.w["head.weight"]
 
 
     def alloc_gptq(self, layer_id):
         self.subset = {}
@@ -178,7 +171,6 @@ def fasterquant(self, layer_id, quantizers):
             self.gptq[name].fasterquant(percdamp=0.01, groupsize=-1, actorder=False)
             # self.gptq[name].fastquant(percdamp=args.percdamp, groupsize=args.groupsize, actorder=args.act_order)
             quantizers[name] = self.gptq[name].quantizer
-            # TODO: may be free gptq here to save memory
 
     ### end GPTQ_RWKV
 
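As an aside, the filtering rule introduced in _fill_subset above is easy to check in isolation: the compiled pattern selects every weight inside one block, 1D tensors are then skipped by the shape check, and head.weight is appended only for the last block. A tiny sketch, assuming the standard RWKV state-dict naming (blocks.N.*.weight plus a top-level head.weight):

import re

layer_id = 0
is_weight = re.compile(f'^blocks\.{layer_id}\..*\.weight$')

candidates = [
    'blocks.0.att.key.weight',    # matches -> 2D, goes into the subset
    'blocks.0.ln1.weight',        # matches -> but 1D, skipped by the shape check
    'blocks.1.ffn.value.weight',  # different block id -> no match
    'head.weight',                # no match -> added explicitly for the last block
]
for name in candidates:
    print(f"{name}: {bool(is_weight.match(name))}")
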
@@ -272,7 +264,7 @@ def ffn_seq(self, x, sx, ln_w, ln_b, k_mix, r_mix, kw, vw, rw, kmx, krx, kmy, kr
         vw.add_batch(vx)
         return x + out, xx[-1,:]
 
-    def forward_block(self, x, state, i, seq_mode, is_last_layer, full_output=False):
+    def forward_block(self, x, state, i, seq_mode, full_output=False):
         with torch.no_grad():
             args = self.args
 
@@ -312,6 +304,12 @@ def forward_block(self, x, state, i, seq_mode, is_last_layer, full_output=False)
             rw = self.gptq[f'{att}receptance.weight']
             ow = self.gptq[f'{att}output.weight']
 
+            if dd.stream:
+                kw = kw.to(device=dev, non_blocking=True)
+                vw = vw.to(device=dev, non_blocking=True)
+                rw = rw.to(device=dev, non_blocking=True)
+                ow = ow.to(device=dev, non_blocking=True)
+
             kmx = self.w[f'{att}key.weight_mx'] if wtype == torch.uint8 else x
             krx = self.w[f'{att}key.weight_rx'] if wtype == torch.uint8 else x
             kmy = self.w[f'{att}key.weight_my'] if wtype == torch.uint8 else x
@@ -341,6 +339,7 @@ def forward_block(self, x, state, i, seq_mode, is_last_layer, full_output=False)
                 omx=omx, orx=orx, omy=omy, ory=ory,
             )
 
+            # Deactivate add_batch() after quantization is applied
             kw.deactivate_add_batch_call = True
             vw.deactivate_add_batch_call = True
             rw.deactivate_add_batch_call = True
@@ -352,6 +351,7 @@ def forward_block(self, x, state, i, seq_mode, is_last_layer, full_output=False)
             kw = self.gptq[f'{ffn}key.weight']
             vw = self.gptq[f'{ffn}value.weight']
             rw = self.gptq[f'{ffn}receptance.weight']
+
             if dd.stream:
                 kw = kw.to(device=dev, non_blocking=True)
                 vw = vw.to(device=dev, non_blocking=True)
@@ -391,43 +391,46 @@ def forward_block(self, x, state, i, seq_mode, is_last_layer, full_output=False)
             if (i+1) % self.RESCALE_LAYER == 0:
                 x = x / 2
 
-            if is_last_layer:
+            is_last_layer = i == (args.n_layer - 1)
+
+            if is_last_layer:
                 dd = self.strategy[args.n_layer]
                 x = x[-1,:] if (seq_mode and (not full_output)) else x
                 x = x.to(dtype=dd.atype, device=dd.device)
 
                 #TODO: Add GPTQ support for head & ln_out
                 #TODO: ln_out.weight is 1D tensor
                 x = F.layer_norm(x, (args.n_embd,), weight=self.w['ln_out.weight'], bias=self.w['ln_out.bias'])
 
-                if self.w['head.weight'].dtype != torch.uint8:
-                    x = x @ self.w['head.weight']
-                else:
-                    if seq_mode and full_output:
-                        x = self.mm8_seq(x, self.w['head.weight'], self.w['head.weight_mx'], self.w['head.weight_rx'], self.w['head.weight_my'], self.w['head.weight_ry'])
-                    else:
-                        x = self.mm8_one(x, self.w['head.weight'], self.w['head.weight_mx'], self.w['head.weight_rx'], self.w['head.weight_my'], self.w['head.weight_ry'])
+                x = x @ self.gptq['head.weight'].weight
+                self.gptq['head.weight'].add_batch(x)
+                self.gptq['head.weight'].deactivate_add_batch_call = True
 
             return x.float()
 
     ### end RWKV
 
+model = GPTQ_RWKV("./RWKV-4-Pile-169M-20220807-8023.pth", strategy='cpu fp32')
 
 NSAMPLES=2
-HIDDEN_SIZE=768
-SEQLEN=2048 # TODO: this is chosen by the model
+HIDDEN_SIZE=model.args.n_embd
+SEQLEN=1024 # cf https://huggingface.co/BlinkDL/rwkv-4-pile-169m
 
 # train_tokens, test_tokens = get_loaders(
 #     dataset_name="wikitext2",
 #     nsamples=NSAMPLES,
 #     seed=42,
 #     seqlen=SEQLEN,
-#     model=None
+#     model=model
 # )
 
 # tokens = torch.cat([inp for inp, _ in train_tokens], dim=0)
 tokens = torch.zeros((NSAMPLES, SEQLEN), dtype=torch.int64)
 print("tokens.shape", tokens.shape)
 
-model = GPTQ_RWKV("./RWKV-4-Pile-169M-20220807-8023.pth", strategy='cpu fp32')
-is_last_layer = [False] * (model.args.n_layer - 1) + [True]
+is_last_layer = lambda x: x == (model.args.n_layer - 1)
 
+start_time = time.time()
 
 #TODO: Do the same in GPU side
 with torch.no_grad():
@@ -442,23 +445,28 @@ def forward_block(self, x, state, i, seq_mode, is_last_layer, full_output=False)
         model.alloc_gptq(layer_id)
 
         for j in range(NSAMPLES):
-            if not is_last_layer[layer_id]:
-                outs[j] = model.forward_block(inps[j], state=None, i=layer_id, seq_mode=seq_mode, is_last_layer=is_last_layer[layer_id])
+            if not is_last_layer(layer_id):
+                outs[j] = model.forward_block(inps[j], state=None, i=layer_id, seq_mode=seq_mode)
             else:
-                _ = model.forward_block(inps[j], state=None, i=layer_id, seq_mode=seq_mode, is_last_layer=is_last_layer[layer_id])
+                _ = model.forward_block(inps[j], state=None, i=layer_id, seq_mode=seq_mode)
 
         model.fasterquant(layer_id, quantizers)
 
         for j in range(NSAMPLES):
-            if not is_last_layer[layer_id]:
-                outs[j] = model.forward_block(inps[j], state=None, i=layer_id, seq_mode=seq_mode, is_last_layer=is_last_layer[layer_id])
+            if not is_last_layer(layer_id):
+                outs[j] = model.forward_block(inps[j], state=None, i=layer_id, seq_mode=seq_mode)
             else:
-                _ = model.forward_block(inps[j], state=None, i=layer_id, seq_mode=seq_mode, is_last_layer=is_last_layer[layer_id])
+                _ = model.forward_block(inps[j], state=None, i=layer_id, seq_mode=seq_mode)
 
         model.free_gptq()
 
-        if not is_last_layer[layer_id]:
-            # We need to pass the outputs of block i as input of block i+1 (except for last block)
+        # We need to pass the outputs of block i as input of block i+1 (except for last block)
+        if not is_last_layer(layer_id):
             inps, outs = outs, inps
 
 # TODO: create a function that check if all weights were properly quantized
-print("Done")
+end_time = time.time()
 
+print(f"Done in {end_time - start_time:.2f} seconds")
 
 # TODO: Do something with quantizers dictionary
 # TODO: pack3 save model
 
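One way to read the forward_block changes as a whole: each GPTQ-wrapped weight first accumulates activation statistics through add_batch during the calibration forward pass, and deactivate_add_batch_call then freezes it once fasterquant has consumed those statistics, so later forward passes only compute. A minimal sketch of that wrapper pattern (QuantTarget is an illustrative stand-in, not this repository's GPTQ class, and the Hessian update is simplified):

import torch

class QuantTarget:
    # Illustrative stand-in for a GPTQ-wrapped weight matrix.
    def __init__(self, weight):
        self.weight = weight                                   # (in_features, out_features)
        self.H = torch.zeros(weight.shape[0], weight.shape[0])
        self.deactivate_add_batch_call = False

    def add_batch(self, x):
        if self.deactivate_add_batch_call:                     # already quantized: stop accumulating
            return
        x2d = x.reshape(-1, x.shape[-1]).float()
        self.H += x2d.t() @ x2d                                # simplified Hessian accumulation

kw = QuantTarget(torch.randn(8, 4))
x = torch.randn(3, 8)
kw.add_batch(x)                   # calibration pass: statistics are recorded
y = x @ kw.weight                 # the forward computation itself
kw.deactivate_add_batch_call = True
kw.add_batch(x)                   # post-quantization pass: a no-op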