@@ -246,6 +246,167 @@ def quant_tensor(data, num_bits=4, group_size=32, scheme="asym", dtype="int", ra
    return q_weight, scale, zero_point


+def quant_tensor_k_quant_cpu(data, num_bits=4, group_size=32):
+    """Quantize tensor per group based on k quant.
+
+    Ref: https://github.com/ggml-org/llama.cpp/blob/64eda5deb9859e87a020e56bab5d2f9ca956f1de/ggml/src/ggml-quants.c
+
+    Args:
+        data: input weight
+        num_bits (int, optional): number of bits. Defaults to 4.
+        group_size (int, optional): how many elements share one scale/zp. Defaults to 32.
+
+    Returns:
+        output: quantized weight
+        scale: scale
+        zero_point: zero point
+    """
+    data = np.reshape(data, (-1, group_size)).astype(np.float32)  # (nb, group_size)
+    maxq = 2**num_bits - 1
+    minq = 0
+    sum_x2 = np.sum(data**2, axis=1, keepdims=True)  # (nb, 1)
+    av_x = np.sqrt(sum_x2 / group_size)  # (nb, 1)
+    weights = np.add(av_x, np.abs(data))  # (nb, group_size)
+    rmin = np.min(data, axis=1, keepdims=True)  # (nb, 1)
+    rmax = np.max(data, axis=1, keepdims=True)  # (nb, 1)
+    sum_w = np.sum(weights, axis=1, keepdims=True)  # (nb, 1)
+    sum_x = np.sum(weights * data, axis=1, keepdims=True)  # (nb, 1)
+    iscale = np.ones(rmax.shape, dtype=data.dtype)  # (nb, 1)
+    mask = rmin != rmax
+    iscale[mask] = (maxq - minq) / (rmax[mask] - rmin[mask])
+    scale = 1 / iscale
+    quant_data = np.clip(np.round(iscale * (data - rmin)), minq, maxq)  # (nb, group_size)
+    diff = scale * quant_data + rmin - data  # (nb, group_size)
+    best_mad = np.sum(weights * diff**2, axis=1, keepdims=True)  # (nb, 1)
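+    # Grid search: nstep candidate scales around the min/max estimate; each is
+    # refined below and kept only if it lowers the weighted error best_mad.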
+    nstep = 20
+    rdelta = 0.1
+    # nstep * rdelta = -2 * rrmin, maxq - minq = 2**num_bits - 1
+    rrmin = -1
+    for is_ in range(nstep):
+        iscale_new = np.ones(rmax.shape, dtype=data.dtype)  # (nb, 1)
+        factor = np.array([rrmin + rdelta * is_ + maxq - minq]).astype(data.dtype)[0]
+        mask = rmin != rmax
+        iscale_new[mask] = factor / (rmax[mask] - rmin[mask])
+        quant_data_new = np.clip(np.round(iscale_new * (data - rmin)), minq, maxq)  # (nb, group_size)
+        mul_weights_quant_data_new = weights * quant_data_new
+        sum_l = np.sum(mul_weights_quant_data_new, axis=1, keepdims=True)  # (nb, 1)
+        sum_l2 = np.sum(mul_weights_quant_data_new * quant_data_new, axis=1, keepdims=True)  # (nb, 1)
+        sum_xl = np.sum(mul_weights_quant_data_new * data, axis=1, keepdims=True)  # (nb, 1)
+        D = np.subtract(sum_w * sum_l2, sum_l**2)  # (nb, 1)
+
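+        # Weighted least-squares fit: with the integer assignment
+        # quant_data_new fixed, solve the 2x2 normal equations for the real
+        # (scale, min) minimizing sum(weights * (scale*q + min - x)**2).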
+        this_scale = (sum_w * sum_xl - sum_x * sum_l) / D  # (nb, 1)
+        this_min = (sum_l2 * sum_x - sum_l * sum_xl) / D  # (nb, 1)
+
+        diff = this_scale * quant_data_new + this_min - data  # (nb, group_size)
+        mad = np.sum(weights * diff**2, axis=1, keepdims=True)  # (nb, 1)
+
+        idx_to_replace = np.where(mad < best_mad)[0]
+        quant_data[idx_to_replace, :] = quant_data_new[idx_to_replace, :]
+        best_mad[idx_to_replace] = mad[idx_to_replace]
+        scale[idx_to_replace] = this_scale[idx_to_replace]
+        rmin[idx_to_replace] = this_min[idx_to_replace]
+
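+    # Turn the fitted minimum into an integer zero point, then quantize:
+    # q = clip(round(x / scale + zero_point), minq, maxq).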
+    zero_point = np.clip(((-rmin) / scale).round(), 0, maxq).astype("uint8")
+    scale = scale.astype(np.float64)
+    q_weight = np.empty_like(data, dtype=scale.dtype)
+    np.divide(data, scale, out=q_weight)
+    np.add(q_weight, zero_point, out=q_weight)
+    np.round(q_weight, out=q_weight)
+    np.clip(q_weight, minq, maxq, out=q_weight)
+
+    return q_weight, scale, zero_point
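+
+# Illustration (hypothetical shapes): a (4096, 4096) float weight reshapes to
+# (524288, 32) groups, so q_weight comes back as (524288, 32), with scale and
+# zero_point of shape (524288, 1).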
+
+
+def quant_tensor_k_quant_cuda(data, num_bits=4, group_size=32):
+    """Quantize tensor per group based on k quant.
+
+    Ref: https://github.com/ggml-org/llama.cpp/blob/64eda5deb9859e87a020e56bab5d2f9ca956f1de/ggml/src/ggml-quants.c
+
+    Args:
+        data: input weight
+        num_bits (int, optional): number of bits. Defaults to 4.
+        group_size (int, optional): how many elements share one scale/zp. Defaults to 32.
+
+    Returns:
+        output: quantized weight
+        scale: scale
+        zero_point: zero point
+    """
+    try:
+        import cupy as cp
+        import torch
+
+        if torch.cuda.is_available():
+            data = cp.asarray(data)
+            data = data.reshape((-1, group_size)).astype(cp.float32)  # nb = data.shape[0], (nb, group_size)
+            maxq = 2**num_bits - 1
+            minq = 0
+            sum_x2 = cp.sum(data**2, axis=1, keepdims=True)  # (nb, 1)
+            av_x = cp.sqrt(sum_x2 / group_size)  # (nb, 1)
+            weights = cp.add(av_x, cp.abs(data))  # (nb, group_size)
+            rmin = cp.min(data, axis=1, keepdims=True)  # (nb, 1)
+            rmax = cp.max(data, axis=1, keepdims=True)  # (nb, 1)
+            sum_w = cp.sum(weights, axis=1, keepdims=True)  # (nb, 1)
+            sum_x = cp.sum(weights * data, axis=1, keepdims=True)  # (nb, 1)
+            iscale = cp.ones(rmax.shape, dtype=data.dtype)  # (nb, 1)
+            mask = rmin != rmax
+            iscale[mask] = (maxq - minq) / (rmax[mask] - rmin[mask])
+            scale = 1 / iscale
+            quant_data = cp.clip(cp.round(iscale * (data - rmin)), minq, maxq)  # (nb, group_size)
+            diff = scale * quant_data + rmin - data  # (nb, group_size)
+            best_mad = cp.sum(weights * diff**2, axis=1, keepdims=True)  # (nb, 1)
+            nstep = 20
+            rdelta = 0.1
+            rrmin = -1
+            for is_ in range(nstep):
+                iscale_new = cp.ones(rmax.shape, dtype=data.dtype)  # (nb, 1)
+                factor = cp.array([rrmin + rdelta * is_ + maxq - minq]).astype(data.dtype)[0]
+                mask = rmin != rmax
+                iscale_new[mask] = factor / (rmax[mask] - rmin[mask])
+                quant_data_new = cp.clip(cp.round(iscale_new * (data - rmin)), minq, maxq)  # (nb, group_size)
+                mul_weights_quant_data_new = weights * quant_data_new
+                sum_l = cp.sum(mul_weights_quant_data_new, axis=1, keepdims=True)  # (nb, 1)
+                sum_l2 = cp.sum(mul_weights_quant_data_new * quant_data_new, axis=1, keepdims=True)  # (nb, 1)
+                sum_xl = cp.sum(mul_weights_quant_data_new * data, axis=1, keepdims=True)  # (nb, 1)
+                D = cp.subtract(sum_w * sum_l2, sum_l**2)  # (nb, 1)
+
+                this_scale = (sum_w * sum_xl - sum_x * sum_l) / D  # (nb, 1)
+                this_min = (sum_l2 * sum_x - sum_l * sum_xl) / D  # (nb, 1)
+
+                diff = this_scale * quant_data_new + this_min - data  # (nb, group_size)
+                mad = cp.sum(weights * diff**2, axis=1, keepdims=True)  # (nb, 1)
+
+                idx_to_replace = cp.where(mad < best_mad)[0]
+                quant_data[idx_to_replace, :] = quant_data_new[idx_to_replace, :]
+                best_mad[idx_to_replace] = mad[idx_to_replace]
+                scale[idx_to_replace] = this_scale[idx_to_replace]
+                rmin[idx_to_replace] = this_min[idx_to_replace]
+
+            zero_point = cp.clip(((-rmin) / scale).round(), 0, maxq).astype("uint8")
+            scale = scale.astype(cp.float64)
+            q_weight = cp.empty_like(data, dtype=scale.dtype)
+            cp.divide(data, scale, out=q_weight)
+            cp.add(q_weight, zero_point, out=q_weight)
+            cp.round(q_weight, out=q_weight)
+            cp.clip(q_weight, minq, maxq, out=q_weight)
+
+            # Move results back to host memory as NumPy arrays.
+            return q_weight.get(), scale.get(), zero_point.get()
+        else:
+            logger.warning(
+                "Tried to use k-quant quantization on CUDA, but CUDA is not available. "
+                "Falling back to k-quant quantization on CPU."
+            )
+            return quant_tensor_k_quant_cpu(data, num_bits, group_size)
+    except ImportError:
+        logger.info(
+            "Now we are using k-quant quantization on CPU, which is time consuming. "
+            "Please consider installing cupy to speed it up on CUDA (see https://cupy.dev/), "
+            "and torch to check CUDA availability."
+        )
+        return quant_tensor_k_quant_cpu(data, num_bits, group_size)
+
+
def qdq_tensor(data, num_bits=4, group_size=32, scheme="asym", dtype="int", ratio=1.0):
    """Quant dequant tensor per group.
@@ -299,6 +460,7 @@ def rtn_quantize(
    ratios={},
    accuracy_level=0,
    providers=["CPUExecutionProvider"],
+    algorithm="rtn",
):
    """Quantize the model with the round-to-nearest method.
@@ -372,9 +534,15 @@ def rtn_quantize(
            ):  # pragma: no cover
                # MatMulFpQ4 supports 4 bits and group_size 32 with ort 1.16.0 and 1.16.1 versions, supported by CPU EP
                # MatMulNBits supports 4 bits and 2^n group_size with ort > 1.16.1, supported by CPU EP and CUDA EP
-                q_weight, scale, zp = quant_tensor(
-                    weight.T, num_bits, group_size, scheme, "uint", ratios.get(node.input[1], 1)
-                )
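+                # algorithm="k_quant" switches to the error-minimizing k-quant
+                # search defined above; the default "rtn" keeps the original
+                # round-to-nearest behavior.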
+                if algorithm == "k_quant":
+                    q_weight, scale, zp = quant_tensor_k_quant_cuda(weight.T, num_bits, group_size)
+                else:
+                    q_weight, scale, zp = quant_tensor(
+                        weight.T, num_bits, group_size, scheme, "uint", ratios.get(node.input[1], 1)
+                    )
+
                q_matmul_node, new_inits = make_matmul_weight_only_node(
                    node=node,
                    weight_shape=org_w_shape,