-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathautograd_4bit.py
295 lines (239 loc) · 9.28 KB
/
autograd_4bit.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
import quant
import torch
import numpy as np
import torch.nn as nn
import time
import transformers
import accelerate
from transformers import GPTJConfig, GPTJForCausalLM, LlamaConfig, LlamaForCausalLM, AutoConfig, AutoModelForCausalLM
from modelutils import find_layers
# Global buffer cache: scratch tensors that the 4-bit layers dequantize their
# packed weights into. Keyed by (qweight shape, dtype, device) so that layers
# sharing a packed shape but living on different devices (e.g. after
# accelerate's device_map='auto' dispatch) or using different float dtypes
# never receive a buffer they cannot use.
buffer_mat_dic = {}
# Path selection for fast_4bit_forward: True -> dequantize into a cached
# buffer and use torch.matmul; False -> call matmul4bit directly.
use_new = True
# When True, fast_4bit_forward ignores `use_new` and chooses the path by
# comparing the input size against `auto_switch_thd`.
auto_switch = True
auto_switch_thd = 16


def get_buffer(shape_of_qweight, dtype=torch.float16, device='cuda'):
    """Return the cached dequantization buffer for a packed weight shape.

    Args:
        shape_of_qweight: shape ``(j, k)`` of the packed ``qweight`` tensor.
        dtype: floating-point dtype of the buffer.
        device: device (string or ``torch.device``) the buffer lives on.

    Returns:
        A ``(j * 8, k)`` tensor. It is created zero-filled on the first
        request for a given (shape, dtype, device); later requests return the
        same tensor, which may still hold data written by a previous caller.
    """
    # Key on dtype and device as well as shape: a shape-only key would hand a
    # buffer created for one dtype/device to a caller asking for another.
    device = device if isinstance(device, torch.device) else torch.device(device)
    key = (tuple(shape_of_qweight), dtype, device)
    buffer = buffer_mat_dic.get(key)
    if buffer is None:
        buffer = torch.zeros(
            (shape_of_qweight[0] * 8, shape_of_qweight[1]), dtype=dtype, device=device
        )
        buffer_mat_dic[key] = buffer
    return buffer
def matmul4bit(x, qweight, scales, zeros):
    """Compute ``x @ qweight`` for a packed 4-bit ``qweight`` (float32 path).

    Shapes:
        x:       (..., m)
        qweight: (j, k) packed integer matrix, where m == j * 8

    Returns:
        Tensor of shape (..., k) in the original dtype of ``x``.

    ``x`` is flattened to 2-D and upcast to float32, and the
    ``vecquant4matmul`` kernel accumulates into a float32 output. The result
    is cast back to the input dtype.
    """
    assert qweight.shape[0] * 8 == x.shape[-1]
    target_shape = tuple(x.shape[:-1]) + (qweight.shape[1],)
    orig_dtype = x.dtype
    x2d = x.reshape(-1, x.shape[-1]).float()
    acc = torch.zeros(
        (x2d.shape[0], qweight.shape[-1]), dtype=torch.float32, device=x2d.device
    )
    quant.quant_cuda.vecquant4matmul(x2d, qweight, acc, scales, zeros)
    return acc.to(orig_dtype).reshape(target_shape)
def matmul4bit_transpose(x, qweight, scales, zeros):
    """Multiply by the transpose of a packed 4-bit matrix (float32 path).

    The operation is described upstream as ``qweight @ x.T``.

    Shapes:
        x:       (..., k)
        qweight: (j, k) packed integer matrix

    Returns:
        Tensor of shape (..., j * 8) in the original dtype of ``x``.

    NOTE(review): the accumulator is allocated as (j * 8, n) but the result is
    reshaped, not transposed, to (..., j * 8). That is only correct if the
    vecquant4transposematmul kernel writes its output in row-major
    (n, j * 8) order; confirm against the CUDA source.
    """
    assert qweight.shape[1] == x.shape[-1]
    out_features = qweight.shape[0] * 8
    target_shape = tuple(x.shape[:-1]) + (out_features,)
    orig_dtype = x.dtype
    x2d = x.reshape(-1, x.shape[-1]).float()
    acc = torch.zeros(
        (out_features, x2d.shape[0]), dtype=torch.float32, device=x2d.device
    )
    quant.quant_cuda.vecquant4transposematmul(x2d, qweight, acc, scales, zeros)
    return acc.to(orig_dtype).reshape(target_shape)
def matmul4bit_half(x, qweight, scales, zeros):
    """Compute ``x @ qweight`` for a packed 4-bit ``qweight`` in x's own dtype.

    Shapes:
        x:       (..., m)
        qweight: (j, k) packed integer matrix, where m == j * 8

    Returns:
        Tensor of shape (..., k) with the same dtype as ``x``.

    Unlike ``matmul4bit`` there is no float32 upcast: the output buffer is
    allocated in x's dtype and filled by ``vecquant4matmul_half``.
    """
    assert qweight.shape[0] * 8 == x.shape[-1]
    target_shape = tuple(x.shape[:-1]) + (qweight.shape[1],)
    x2d = x.reshape(-1, x.shape[-1])
    out = torch.zeros(
        (x2d.shape[0], qweight.shape[-1]), dtype=x2d.dtype, device=x2d.device
    )
    quant.quant_cuda.vecquant4matmul_half(x2d, qweight, out, scales, zeros)
    # `out` already has x's dtype, so no cast is needed before reshaping.
    return out.reshape(target_shape)
def matmul4bit_transpose_half(x, qweight, scales, zeros):
    """Multiply by the transpose of a packed 4-bit matrix in x's own dtype.

    The operation is described upstream as ``qweight @ x.T``.

    Shapes:
        x:       (..., k)
        qweight: (j, k) packed integer matrix

    Returns:
        Tensor of shape (..., j * 8) with the same dtype as ``x``.

    NOTE(review): the output is allocated as (j * 8, n) but reshaped, not
    transposed, to (..., j * 8). This is only correct if the
    vecquant4transposematmul_half kernel writes row-major (n, j * 8) data;
    confirm against the CUDA source.
    """
    assert qweight.shape[1] == x.shape[-1]
    out_features = qweight.shape[0] * 8
    target_shape = tuple(x.shape[:-1]) + (out_features,)
    x2d = x.reshape(-1, x.shape[-1])
    out = torch.zeros((out_features, x2d.shape[0]), dtype=x2d.dtype, device=x2d.device)
    quant.quant_cuda.vecquant4transposematmul_half(x2d, qweight, out, scales, zeros)
    # `out` already has x's dtype, so no cast is needed before reshaping.
    return out.reshape(target_shape)
def fast_4bit_forward(x, qweight, scales, zeros, bias):
    """Gradient-free forward pass of a 4-bit linear layer: ``x @ W + bias``.

    Two execution paths exist:
      * reconstruct: dequantize ``qweight`` into the shared buffer from
        ``get_buffer`` with ``vecquant4recons`` and call ``torch.matmul``;
      * fused: call ``matmul4bit`` with float32 ``scales``/``zeros``.

    The module-level ``use_new`` flag picks the reconstruct path when True.
    If ``auto_switch`` is set, the flag is ignored and the reconstruct path is
    used only when the input's row count (size of the second-to-last
    dimension) exceeds ``auto_switch_thd``.

    Args:
        x: input of shape (..., m), with m == qweight.shape[0] * 8.
        qweight: packed 4-bit weights of shape (j, k).
        scales, zeros: per-output dequantization parameters.
        bias: added to the result in place.

    Returns:
        Tensor of shape (..., k).
    """
    use_new_flag = use_new
    if auto_switch:
        # Use the second-to-last dimension as the row count. For 3-D input,
        # presumably (batch, seq, features), this equals the previous
        # x.shape[1]. For 2-D (n, features) input, x.shape[1] was the feature
        # width, not the row count, and 1-D input used to raise IndexError.
        rows = x.shape[-2] if x.dim() >= 2 else 1
        use_new_flag = rows > auto_switch_thd
    if use_new_flag:
        buffer = get_buffer(qweight.shape, dtype=scales.dtype, device=qweight.device)
        quant.quant_cuda.vecquant4recons(qweight, buffer, scales, zeros)
        output = torch.matmul(x, buffer)
    else:
        output = matmul4bit(x, qweight, scales.float(), zeros.float())
    output += bias
    return output
class AutogradMatmul4bit(torch.autograd.Function):
    """Autograd function computing ``x @ W`` for a packed 4-bit weight ``W``.

    Only ``x`` receives a gradient. ``qweight``, ``scales`` and ``zeros`` are
    treated as constants, and backward returns ``None`` for each of them.
    """

    @staticmethod
    def forward(ctx, x, qweight, scales, zeros):
        # Save the packed tensors, not the dequantized weight: the buffer from
        # get_buffer is shared by every layer with the same packed shape, so
        # its contents may be overwritten before backward runs.
        ctx.save_for_backward(qweight, scales, zeros)
        weight = get_buffer(qweight.shape, dtype=scales.dtype, device=qweight.device)
        quant.quant_cuda.vecquant4recons(qweight, weight, scales, zeros)
        # NOTE(review): torch.matmul already returns a new tensor, so this
        # clone looks redundant; kept to preserve existing behavior.
        return torch.matmul(x, weight).clone()

    @staticmethod
    def backward(ctx, grad_output):
        qweight, scales, zeros = ctx.saved_tensors
        # Dequantize again because the shared buffer may now hold another
        # layer's weights.
        weight = get_buffer(qweight.shape, dtype=scales.dtype, device=qweight.device)
        quant.quant_cuda.vecquant4recons(qweight, weight, scales, zeros)
        # d(x @ W)/dx = grad_output @ W^T
        grad_x = torch.matmul(grad_output, weight.T)
        return grad_x, None, None, None
# Assumes layer is perfectly divisible into 256 * 256 blocks
class Autograd4bitQuantLinear(nn.Module):
    """Linear layer whose weight is stored as packed 4-bit integers.

    Buffers, allocated uninitialised (``torch.empty``) and expected to be
    filled by loading a checkpoint:
        zeros:   (out_features, 1)
        scales:  (out_features, 1)
        bias:    (out_features,)
        qweight: int32, (in_features // 256 * 32, out_features)

    The packed row count floors silently when ``in_features`` is not a
    multiple of 256, hence the divisibility assumption above.

    With autograd enabled, the forward pass goes through AutogradMatmul4bit
    so gradients flow to the input. Otherwise it uses fast_4bit_forward.
    """

    def __init__(self, infeatures, outfeatures):
        super().__init__()
        self.in_features = infeatures
        self.out_features = outfeatures
        self.bits = bits = 4
        # Each int32 packs 8 four-bit weights, so every 256 input features
        # occupy bits * 8 == 32 packed rows.
        packed_rows = infeatures // 256 * (bits * 8)
        self.register_buffer('zeros', torch.empty((outfeatures, 1)))
        self.register_buffer('scales', torch.empty((outfeatures, 1)))
        self.register_buffer('bias', torch.empty(outfeatures))
        self.register_buffer(
            'qweight', torch.empty((packed_rows, outfeatures), dtype=torch.int)
        )

    def forward(self, x):
        if not torch.is_grad_enabled():
            return fast_4bit_forward(x, self.qweight, self.scales, self.zeros, self.bias)
        result = AutogradMatmul4bit.apply(x, self.qweight, self.scales, self.zeros)
        result += self.bias
        return result
def make_quant_for_4bit_autograd(module, names, name=''):
    """Recursively replace the named submodules with Autograd4bitQuantLinear.

    Args:
        module: root module to walk; it is modified in place.
        names: container of dotted module names to replace, relative to the
            root of the first call (the loaders pass the result of
            ``find_layers``). Only membership (``in``) is used.
        name: dotted prefix of ``module``; leave empty for the root call.

    Each replaced attribute must expose ``in_features`` and ``out_features``.
    The new layer's buffers are uninitialised and must be loaded afterwards.
    A module that is already an Autograd4bitQuantLinear is left untouched.
    """
    if isinstance(module, Autograd4bitQuantLinear):
        return
    prefix = name + '.' if name != '' else ''
    for attr in dir(module):
        full_name = prefix + attr
        # Check the name before touching the attribute. dir() also lists
        # methods and properties, and evaluating each of them is wasted work
        # (some properties may be costly or emit warnings).
        if full_name not in names:
            continue
        target = getattr(module, attr)
        setattr(
            module, attr, Autograd4bitQuantLinear(target.in_features, target.out_features)
        )
    for child_name, child in module.named_children():
        make_quant_for_4bit_autograd(child, names, prefix + child_name)
def model_to_half(model):
    """Cast ``model`` to float16 in place and print a confirmation.

    After ``model.half()``, the ``zeros``, ``scales`` and ``bias`` buffers of
    every Autograd4bitQuantLinear submodule are explicitly re-cast with
    ``.half()``. The packed ``qweight`` buffer is not touched by this pass.
    """
    model.half()
    for _, module in model.named_modules():
        if not isinstance(module, Autograd4bitQuantLinear):
            continue
        for buf_name in ('zeros', 'scales', 'bias'):
            setattr(module, buf_name, getattr(module, buf_name).half())
    print('Converted as Half.')
def model_to_float(model):
    """Cast ``model`` to float32 in place and print a confirmation.

    After ``model.float()``, the ``zeros``, ``scales`` and ``bias`` buffers of
    every Autograd4bitQuantLinear submodule are explicitly re-cast with
    ``.float()``. The packed ``qweight`` buffer is not touched by this pass.
    """
    model.float()
    for _, module in model.named_modules():
        if not isinstance(module, Autograd4bitQuantLinear):
            continue
        for buf_name in ('zeros', 'scales', 'bias'):
            setattr(module, buf_name, getattr(module, buf_name).float())
    print('Converted as Float.')
def load_auto_model_4bit_low_ram(config_path, model_path, half=False):
    """Build an AutoModelForCausalLM with 4-bit linear layers and load weights.

    The skeleton is created under ``accelerate.init_empty_weights`` with
    float16 as the default dtype. Every layer reported by ``find_layers``
    except the output heads (``embed_out``/``lm_head``) is swapped for an
    Autograd4bitQuantLinear. The checkpoint is then loaded with
    ``accelerate.load_checkpoint_and_dispatch(device_map='auto')``.

    Args:
        config_path: passed to ``AutoConfig.from_pretrained``.
        model_path: checkpoint passed to ``load_checkpoint_and_dispatch``.
        half: when True, also call ``model_to_half`` on the loaded model.

    Returns:
        The loaded model (put in eval mode before loading), with
        ``model.seqlen = 2048``.

    Side effects: sets ``transformers.modeling_utils._init_weights = False``
    for the rest of the process and prints progress and timing.
    """
    print("Loading Auto Model ...")
    t0 = time.time()
    with accelerate.init_empty_weights():
        config = AutoConfig.from_pretrained(config_path)
        transformers.modeling_utils._init_weights = False
        # Build in fp16, but always restore the float default afterwards;
        # otherwise an exception here would leave the whole process in fp16.
        torch.set_default_dtype(torch.half)
        try:
            model = AutoModelForCausalLM.from_config(config)
        finally:
            torch.set_default_dtype(torch.float)
        model = model.eval()
        layers = find_layers(model)
        # The output heads are not replaced by 4-bit layers.
        for head in ['embed_out', 'lm_head']:
            if head in layers:
                del layers[head]
        make_quant_for_4bit_autograd(model, layers)
    model = accelerate.load_checkpoint_and_dispatch(model=model, checkpoint=model_path, device_map='auto')
    # NOTE(review): .cuda() moves every module onto the current CUDA device
    # after the 'auto' dispatch, which may undo a multi-GPU placement.
    # Confirm this is intended.
    model.cuda()
    model.seqlen = 2048
    if half:
        model_to_half(model)
    print(f"Loaded the model in {(time.time()-t0):.2f} seconds.")
    return model
def load_llama_model_4bit_low_ram(config_path, model_path, half=False):
    """Build a LlamaForCausalLM with 4-bit linear layers and load weights.

    The skeleton is created under ``accelerate.init_empty_weights`` with
    float16 as the default dtype. Every layer reported by ``find_layers``
    except ``lm_head`` is swapped for an Autograd4bitQuantLinear. The
    checkpoint is then loaded with
    ``accelerate.load_checkpoint_and_dispatch(device_map='auto')``.

    Args:
        config_path: passed to ``LlamaConfig.from_pretrained``.
        model_path: checkpoint passed to ``load_checkpoint_and_dispatch``.
        half: when True, also call ``model_to_half`` on the loaded model.

    Returns:
        The loaded model (put in eval mode before loading), with
        ``model.seqlen = 2048``.

    Side effects: sets ``transformers.modeling_utils._init_weights = False``
    for the rest of the process and prints progress and timing.
    """
    # transformers, accelerate, LlamaConfig, LlamaForCausalLM and find_layers
    # are imported at module level; the former local re-imports (including an
    # unused LlamaTokenizer) were redundant.
    print("Loading Llama Model ...")
    t0 = time.time()
    with accelerate.init_empty_weights():
        config = LlamaConfig.from_pretrained(config_path)
        transformers.modeling_utils._init_weights = False
        # Build in fp16, but always restore the float default afterwards;
        # otherwise an exception here would leave the whole process in fp16.
        torch.set_default_dtype(torch.half)
        try:
            model = LlamaForCausalLM(config)
        finally:
            torch.set_default_dtype(torch.float)
        model = model.eval()
        layers = find_layers(model)
        # The output head is not replaced by a 4-bit layer.
        for head in ['lm_head']:
            if head in layers:
                del layers[head]
        make_quant_for_4bit_autograd(model, layers)
    model = accelerate.load_checkpoint_and_dispatch(model=model, checkpoint=model_path, device_map='auto')
    # NOTE(review): .cuda() moves every module onto the current CUDA device
    # after the 'auto' dispatch, which may undo a multi-GPU placement.
    # Confirm this is intended.
    model.cuda()
    model.seqlen = 2048
    if half:
        model_to_half(model)
    print(f"Loaded the model in {(time.time()-t0):.2f} seconds.")
    return model
def load_gptj_model_4bit_low_ram(config_path, model_path, half=False):
    """Build a GPTJForCausalLM with 4-bit linear layers and load weights.

    The skeleton is created under ``accelerate.init_empty_weights`` with
    float16 as the default dtype. Every layer reported by ``find_layers``
    except ``lm_head`` is swapped for an Autograd4bitQuantLinear. The
    checkpoint is then loaded with
    ``accelerate.load_checkpoint_and_dispatch(device_map='auto')``.

    Args:
        config_path: passed to ``GPTJConfig.from_pretrained``.
        model_path: checkpoint passed to ``load_checkpoint_and_dispatch``.
        half: when True, also call ``model_to_half`` on the loaded model.

    Returns:
        The loaded model (put in eval mode before loading), with
        ``model.seqlen = 2048``.

    Side effects: sets ``transformers.modeling_utils._init_weights = False``
    for the rest of the process and prints progress and timing.
    """
    # Previously printed "Loading Llama Model ..." (copy/paste slip).
    print("Loading GPT-J Model ...")
    t0 = time.time()
    with accelerate.init_empty_weights():
        config = GPTJConfig.from_pretrained(config_path)
        transformers.modeling_utils._init_weights = False
        # Build in fp16, but always restore the float default afterwards;
        # otherwise an exception here would leave the whole process in fp16.
        torch.set_default_dtype(torch.half)
        try:
            model = GPTJForCausalLM(config)
        finally:
            torch.set_default_dtype(torch.float)
        model = model.eval()
        layers = find_layers(model)
        # The output head is not replaced by a 4-bit layer.
        for head in ['lm_head']:
            if head in layers:
                del layers[head]
        make_quant_for_4bit_autograd(model, layers)
    model = accelerate.load_checkpoint_and_dispatch(model=model, checkpoint=model_path, device_map='auto')
    # NOTE(review): .cuda() moves every module onto the current CUDA device
    # after the 'auto' dispatch, which may undo a multi-GPU placement.
    # Confirm this is intended.
    model.cuda()
    model.seqlen = 2048
    if half:
        model_to_half(model)
    print(f"Loaded the model in {(time.time()-t0):.2f} seconds.")
    return model