Commit d91b4e5

Merge branch 'k_quant' of https://github.com/jiafatom/neural-compressor into k_quant

2 parents 4542a33 + 1b3518a
1 file changed: +34 −34 lines

neural_compressor/adaptor/ox_utils/weight_only.py (+34 −34)
@@ -246,6 +246,7 @@ def quant_tensor(data, num_bits=4, group_size=32, scheme="asym", dtype="int", ra

     return q_weight, scale, zero_point

+
 def quant_tensor_k_quant_cpu(data, num_bits=4, group_size=32):
     """Quantize tensor per group based on k quant.
     Ref: https://github.com/ggml-org/llama.cpp/blob/64eda5deb9859e87a020e56bab5d2f9ca956f1de/ggml/src/ggml-quants.c
@@ -307,7 +308,7 @@ def quant_tensor_k_quant_cpu(data, num_bits=4, group_size=32):
         scale[idx_to_replace] = this_scale[idx_to_replace]
         rmin[idx_to_replace] = this_min[idx_to_replace]

-    zero_point = np.clip((( - rmin) / scale).round(), 0, maxq).astype("uint8")
+    zero_point = np.clip(((-rmin) / scale).round(), 0, maxq).astype("uint8")
     scale = scale.astype(np.float64)
     q_weight = np.empty_like(data, dtype=scale.dtype)
     np.divide(data, scale, out=q_weight)
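
For context, the line touched here computes the asymmetric zero point: the integer offset that maps the group minimum rmin onto 0, clipped to [0, maxq]. A minimal self-contained NumPy sketch of that scheme (illustrative values and names, not the library code):

import numpy as np

# Asymmetric per-group quantization: zero_point maps rmin onto integer 0.
num_bits = 4
maxq = 2**num_bits - 1  # 15 for 4-bit

group = np.array([-0.9, -0.1, 0.4, 1.2], dtype=np.float32)
rmin, rmax = group.min(), group.max()
scale = (rmax - rmin) / maxq
zero_point = np.clip(np.round(-rmin / scale), 0, maxq).astype("uint8")

q = np.clip(np.round(group / scale) + zero_point, 0, maxq)
dequant = (q - zero_point) * scale  # approximately recovers `group`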
@@ -334,45 +335,46 @@ def quant_tensor_k_quant_cuda(data, num_bits=4, group_size=32):
     try:
         import cupy as cp
         import torch
+
         if torch.cuda.is_available():
             data = cp.asarray(data)
-            data = data.reshape((-1, group_size)).astype(cp.float32) # nb = data.shape[0], (nb, group_size)
+            data = data.reshape((-1, group_size)).astype(cp.float32)  # nb = data.shape[0], (nb, group_size)
             maxq = 2**num_bits - 1
             minq = 0
-            sum_x2 = cp.sum(data**2, axis=1, keepdims=True) # (nb, 1)
-            av_x = cp.sqrt(sum_x2 / group_size) # (nb, 1)
-            weights = cp.add(av_x, cp.abs(data)) # (nb, group_size)
-            rmin = cp.min(data, axis=1, keepdims=True) # (nb, 1)
-            rmax = cp.max(data, axis=1, keepdims=True) # (nb, 1)
-            sum_w = cp.sum(weights, axis=1, keepdims=True) # (nb, 1)
-            sum_x = cp.sum(weights * data, axis=1, keepdims=True) # (nb, group_size)
-            iscale = cp.ones(rmax.shape, dtype=data.dtype) # (nb, 1)
+            sum_x2 = cp.sum(data**2, axis=1, keepdims=True)  # (nb, 1)
+            av_x = cp.sqrt(sum_x2 / group_size)  # (nb, 1)
+            weights = cp.add(av_x, cp.abs(data))  # (nb, group_size)
+            rmin = cp.min(data, axis=1, keepdims=True)  # (nb, 1)
+            rmax = cp.max(data, axis=1, keepdims=True)  # (nb, 1)
+            sum_w = cp.sum(weights, axis=1, keepdims=True)  # (nb, 1)
+            sum_x = cp.sum(weights * data, axis=1, keepdims=True)  # (nb, group_size)
+            iscale = cp.ones(rmax.shape, dtype=data.dtype)  # (nb, 1)
             mask = rmin != rmax
             iscale[mask] = (maxq - minq) / (rmax[mask] - rmin[mask])
             scale = 1 / iscale
-            quant_data = cp.clip(cp.round(iscale * (data - rmin)), minq, maxq) # (nb, group_size)
-            diff = scale * quant_data + rmin - data # (nb, group_size)
-            best_mad = cp.sum(weights * diff ** 2, axis=1, keepdims=True) # (nb, 1)
+            quant_data = cp.clip(cp.round(iscale * (data - rmin)), minq, maxq)  # (nb, group_size)
+            diff = scale * quant_data + rmin - data  # (nb, group_size)
+            best_mad = cp.sum(weights * diff**2, axis=1, keepdims=True)  # (nb, 1)
             nstep = 20
             rdelta = 0.1
             rrmin = -1
             for is_ in range(nstep):
-                iscale_new = cp.ones(rmax.shape, dtype=data.dtype) # (nb, 1)
+                iscale_new = cp.ones(rmax.shape, dtype=data.dtype)  # (nb, 1)
                 factor = cp.array([rrmin + rdelta * is_ + maxq - minq]).astype(data.dtype)[0]
                 mask = rmin != rmax
                 iscale_new[mask] = factor / (rmax[mask] - rmin[mask])
-                quant_data_new = cp.clip(cp.round(iscale_new * (data - rmin)), minq, maxq) # (nb, group_size)
+                quant_data_new = cp.clip(cp.round(iscale_new * (data - rmin)), minq, maxq)  # (nb, group_size)
                 mul_weights_quant_data_new = weights * quant_data_new
-                sum_l = cp.sum(mul_weights_quant_data_new, axis=1, keepdims=True) # (nb, 1)
-                sum_l2 = cp.sum(mul_weights_quant_data_new * quant_data_new, axis=1, keepdims=True) # (nb, 1)
-                sum_xl = cp.sum(mul_weights_quant_data_new * data, axis=1, keepdims=True) # (nb, 1)
-                D = cp.subtract(sum_w * sum_l2, sum_l ** 2) # (nb, 1)
+                sum_l = cp.sum(mul_weights_quant_data_new, axis=1, keepdims=True)  # (nb, 1)
+                sum_l2 = cp.sum(mul_weights_quant_data_new * quant_data_new, axis=1, keepdims=True)  # (nb, 1)
+                sum_xl = cp.sum(mul_weights_quant_data_new * data, axis=1, keepdims=True)  # (nb, 1)
+                D = cp.subtract(sum_w * sum_l2, sum_l**2)  # (nb, 1)

-                this_scale = (sum_w * sum_xl - sum_x * sum_l) / D # (nb, 1)
-                this_min = (sum_l2 * sum_x - sum_l * sum_xl) / D # (nb, 1)
+                this_scale = (sum_w * sum_xl - sum_x * sum_l) / D  # (nb, 1)
+                this_min = (sum_l2 * sum_x - sum_l * sum_xl) / D  # (nb, 1)

-                diff = this_scale * quant_data_new + this_min - data # (nb, group_size)
-                mad = cp.sum(weights * diff ** 2, axis=1, keepdims=True) # (nb, 1)
+                diff = this_scale * quant_data_new + this_min - data  # (nb, group_size)
+                mad = cp.sum(weights * diff**2, axis=1, keepdims=True)  # (nb, 1)

                 mad_1 = cp.array(mad)
                 best_mad_1 = cp.array(best_mad)
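
A note on what this loop computes: for each candidate inverse scale on the grid, the code re-quantizes the group and then solves, in closed form, the 2x2 weighted least-squares system for the (scale, min) pair minimizing sum(w * (scale*q + min - x)**2); D is the determinant of that system. A NumPy stand-in for the same algebra (an illustrative sketch, not the committed CuPy code):

import numpy as np

def weighted_scale_min(q, x, w):
    """Closed-form (scale, min) minimizing sum(w * (scale*q + min - x)**2)
    per row; mirrors the sum_l / sum_l2 / sum_xl / D algebra in the CUDA path."""
    sum_w = np.sum(w, axis=1, keepdims=True)
    sum_x = np.sum(w * x, axis=1, keepdims=True)
    sum_l = np.sum(w * q, axis=1, keepdims=True)
    sum_l2 = np.sum(w * q * q, axis=1, keepdims=True)
    sum_xl = np.sum(w * q * x, axis=1, keepdims=True)
    D = sum_w * sum_l2 - sum_l**2  # determinant; D == 0 (constant q in a row) needs guarding in real use
    scale = (sum_w * sum_xl - sum_x * sum_l) / D
    rmin = (sum_l2 * sum_x - sum_l * sum_xl) / D
    return scale, rmin

Usage follows the shapes in the diff: q, x, w are all (nb, group_size), and the returned scale and rmin are (nb, 1).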
@@ -382,7 +384,7 @@ def quant_tensor_k_quant_cuda(data, num_bits=4, group_size=32):
                 scale[idx_to_replace] = this_scale[idx_to_replace]
                 rmin[idx_to_replace] = this_min[idx_to_replace]

-            zero_point = cp.clip((( - rmin) / scale).round(), 0, maxq).astype("uint8")
+            zero_point = cp.clip(((-rmin) / scale).round(), 0, maxq).astype("uint8")
             scale = scale.astype(cp.float64)
             q_weight = cp.empty_like(data, dtype=scale.dtype)
             cp.divide(data, scale, out=q_weight)
@@ -392,20 +394,18 @@ def quant_tensor_k_quant_cuda(data, num_bits=4, group_size=32):

             return q_weight.get(), scale.get(), zero_point.get()
         else:
-            logger.warning("Try to use k-quant quantization on CUDA. However, CUDA is not available." \
-                "Fall back to k-quant quantization on CPU.")
-            return quant_tensor_k_quant_cpu(
-                data, num_bits, group_size
+            logger.warning(
+                "Try to use k-quant quantization on CUDA. However, CUDA is not available."
+                "Fall back to k-quant quantization on CPU."
             )
+            return quant_tensor_k_quant_cpu(data, num_bits, group_size)
     except ImportError:
         logger.info(
-            "Now we are using k-quant quantization on cpu, which is time consuming." \
-            "Please consider install cupy to speed up on CUDA. See https://cupy.dev/" \
-            "Please also install torch to check CUDA availablity."
-        )
-        return quant_tensor_k_quant_cpu(
-            data, num_bits, group_size
+            "Now we are using k-quant quantization on cpu, which is time consuming."
+            "Please consider install cupy to speed up on CUDA. See https://cupy.dev/"
+            "Please also install torch to check CUDA availability."
         )
+        return quant_tensor_k_quant_cpu(data, num_bits, group_size)


 def qdq_tensor(data, num_bits=4, group_size=32, scheme="asym", dtype="int", ratio=1.0):
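
For completeness, a hedged usage sketch: quant_tensor_k_quant_cuda degrades gracefully to the CPU path when cupy/torch are missing or no GPU is visible, so callers need no branching of their own. The import path below follows this file's location; the weight shape is illustrative:

import numpy as np

from neural_compressor.adaptor.ox_utils.weight_only import quant_tensor_k_quant_cuda

# Any float weight whose size is a multiple of group_size; values here are illustrative.
weight = np.random.randn(128, 32).astype(np.float32)

# Transparently falls back to quant_tensor_k_quant_cpu when cupy/torch are
# not installed or torch.cuda.is_available() is False (see the diff above).
q_weight, scale, zero_point = quant_tensor_k_quant_cuda(weight, num_bits=4, group_size=32)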
