Commit c3318cf

k quant
Signed-off-by: David Fan <[email protected]>
1 parent e9bd2e7 commit c3318cf

File tree

1 file changed: +171 -3 lines changed

neural_compressor/adaptor/ox_utils/weight_only.py

+171 -3
@@ -246,6 +246,167 @@ def quant_tensor(data, num_bits=4, group_size=32, scheme="asym", dtype="int", ra
 
     return q_weight, scale, zero_point
 
+def quant_tensor_k_quant_cpu(data, num_bits=4, group_size=32):
+    """Quantize tensor per group based on k quant.
+    Ref: https://github.com/ggml-org/llama.cpp/blob/64eda5deb9859e87a020e56bab5d2f9ca956f1de/ggml/src/ggml-quants.c
+
+    Args:
+        data : input weight
+        num_bits (int, optional): num_bits. Defaults to 4.
+        group_size (int, optional): how many elements share one scale/zp. Defaults to 32.
+
+    Returns:
+        output: quantized weight
+        scale: scale
+        zero_point: zero point
+    """
+    data = np.reshape(data, (-1, group_size)).astype(np.float32)  # (nb, group_size)
+    maxq = 2**num_bits - 1
+    minq = 0
+    sum_x2 = np.sum(data**2, axis=1, keepdims=True)  # (nb, 1)
+    av_x = np.sqrt(sum_x2 / group_size)  # (nb, 1)
+    weights = np.add(av_x, np.abs(data))  # (nb, group_size)
+    rmin = np.min(data, axis=1, keepdims=True)  # (nb, 1)
+    rmax = np.max(data, axis=1, keepdims=True)  # (nb, 1)
+    sum_w = np.sum(weights, axis=1, keepdims=True)  # (nb, 1)
+    sum_x = np.sum(weights * data, axis=1, keepdims=True)  # (nb, 1)
+    iscale = np.ones(rmax.shape, dtype=data.dtype)  # (nb, 1)
+    mask = rmin != rmax
+    iscale[mask] = (maxq - minq) / (rmax[mask] - rmin[mask])
+    scale = 1 / iscale
+    quant_data = np.clip(np.round(iscale * (data - rmin)), minq, maxq)  # (nb, group_size)
+    diff = scale * quant_data + rmin - data  # (nb, group_size)
+    best_mad = np.sum(weights * diff**2, axis=1, keepdims=True)  # (nb, 1)
+    nstep = 20
+    rdelta = 0.1
+    # nstep * rdelta = -2 * rrmin, maxq - minq = 2**num_bits - 1
+    rrmin = -1
+    for is_ in range(nstep):
+        iscale_new = np.ones(rmax.shape, dtype=data.dtype)  # (nb, 1)
+        factor = np.array([rrmin + rdelta * is_ + maxq - minq]).astype(data.dtype)[0]
+        mask = rmin != rmax
+        iscale_new[mask] = factor / (rmax[mask] - rmin[mask])
+        quant_data_new = np.clip(np.round(iscale_new * (data - rmin)), minq, maxq)  # (nb, group_size)
+        mul_weights_quant_data_new = weights * quant_data_new
+        sum_l = np.sum(mul_weights_quant_data_new, axis=1, keepdims=True)  # (nb, 1)
+        sum_l2 = np.sum(mul_weights_quant_data_new * quant_data_new, axis=1, keepdims=True)  # (nb, 1)
+        sum_xl = np.sum(mul_weights_quant_data_new * data, axis=1, keepdims=True)  # (nb, 1)
+        D = np.subtract(sum_w * sum_l2, sum_l**2)  # (nb, 1)
+
+        this_scale = (sum_w * sum_xl - sum_x * sum_l) / D  # (nb, 1)
+        this_min = (sum_l2 * sum_x - sum_l * sum_xl) / D  # (nb, 1)
+
+        diff = this_scale * quant_data_new + this_min - data  # (nb, group_size)
+        mad = np.sum(weights * diff**2, axis=1, keepdims=True)  # (nb, 1)
+
+        mad_1 = np.array(mad)
+        best_mad_1 = np.array(best_mad)
+        idx_to_replace = np.where(mad_1 < best_mad_1)[0]
+        quant_data[idx_to_replace, :] = quant_data_new[idx_to_replace, :]
+        best_mad[idx_to_replace] = mad[idx_to_replace]
+        scale[idx_to_replace] = this_scale[idx_to_replace]
+        rmin[idx_to_replace] = this_min[idx_to_replace]
+
+    zero_point = np.clip(((-rmin) / scale).round(), 0, maxq).astype("uint8")
+    scale = scale.astype(np.float64)
+    q_weight = np.empty_like(data, dtype=scale.dtype)
+    np.divide(data, scale, out=q_weight)
+    np.add(q_weight, zero_point, out=q_weight)
+    np.round(q_weight, out=q_weight)
+    np.clip(q_weight, minq, maxq, out=q_weight)
+
+    return q_weight, scale, zero_point
+
+def quant_tensor_k_quant_cuda(data, num_bits=4, group_size=32):
+    """Quantize tensor per group based on k quant.
+    Ref: https://github.com/ggml-org/llama.cpp/blob/64eda5deb9859e87a020e56bab5d2f9ca956f1de/ggml/src/ggml-quants.c
+
+    Args:
+        data : input weight
+        num_bits (int, optional): num_bits. Defaults to 4.
+        group_size (int, optional): how many elements share one scale/zp. Defaults to 32.
+
+    Returns:
+        output: quantized weight
+        scale: scale
+        zero_point: zero point
+    """
+    try:
+        import cupy as cp
+        import torch
+        if torch.cuda.is_available():
+            data = cp.asarray(data)
+            data = data.reshape((-1, group_size)).astype(cp.float32)  # nb = data.shape[0], (nb, group_size)
+            maxq = 2**num_bits - 1
+            minq = 0
+            sum_x2 = cp.sum(data**2, axis=1, keepdims=True)  # (nb, 1)
+            av_x = cp.sqrt(sum_x2 / group_size)  # (nb, 1)
+            weights = cp.add(av_x, cp.abs(data))  # (nb, group_size)
+            rmin = cp.min(data, axis=1, keepdims=True)  # (nb, 1)
+            rmax = cp.max(data, axis=1, keepdims=True)  # (nb, 1)
+            sum_w = cp.sum(weights, axis=1, keepdims=True)  # (nb, 1)
+            sum_x = cp.sum(weights * data, axis=1, keepdims=True)  # (nb, 1)
+            iscale = cp.ones(rmax.shape, dtype=data.dtype)  # (nb, 1)
+            mask = rmin != rmax
+            iscale[mask] = (maxq - minq) / (rmax[mask] - rmin[mask])
+            scale = 1 / iscale
+            quant_data = cp.clip(cp.round(iscale * (data - rmin)), minq, maxq)  # (nb, group_size)
+            diff = scale * quant_data + rmin - data  # (nb, group_size)
+            best_mad = cp.sum(weights * diff**2, axis=1, keepdims=True)  # (nb, 1)
+            nstep = 20
+            rdelta = 0.1
+            rrmin = -1
+            for is_ in range(nstep):
+                iscale_new = cp.ones(rmax.shape, dtype=data.dtype)  # (nb, 1)
+                factor = cp.array([rrmin + rdelta * is_ + maxq - minq]).astype(data.dtype)[0]
+                mask = rmin != rmax
+                iscale_new[mask] = factor / (rmax[mask] - rmin[mask])
+                quant_data_new = cp.clip(cp.round(iscale_new * (data - rmin)), minq, maxq)  # (nb, group_size)
+                mul_weights_quant_data_new = weights * quant_data_new
+                sum_l = cp.sum(mul_weights_quant_data_new, axis=1, keepdims=True)  # (nb, 1)
+                sum_l2 = cp.sum(mul_weights_quant_data_new * quant_data_new, axis=1, keepdims=True)  # (nb, 1)
+                sum_xl = cp.sum(mul_weights_quant_data_new * data, axis=1, keepdims=True)  # (nb, 1)
+                D = cp.subtract(sum_w * sum_l2, sum_l**2)  # (nb, 1)
+
+                this_scale = (sum_w * sum_xl - sum_x * sum_l) / D  # (nb, 1)
+                this_min = (sum_l2 * sum_x - sum_l * sum_xl) / D  # (nb, 1)
+
+                diff = this_scale * quant_data_new + this_min - data  # (nb, group_size)
+                mad = cp.sum(weights * diff**2, axis=1, keepdims=True)  # (nb, 1)
+
+                mad_1 = cp.array(mad)
+                best_mad_1 = cp.array(best_mad)
+                idx_to_replace = cp.where(mad_1 < best_mad_1)[0]
+                quant_data[idx_to_replace, :] = quant_data_new[idx_to_replace, :]
+                best_mad[idx_to_replace] = mad[idx_to_replace]
+                scale[idx_to_replace] = this_scale[idx_to_replace]
+                rmin[idx_to_replace] = this_min[idx_to_replace]
+
+            zero_point = cp.clip(((-rmin) / scale).round(), 0, maxq).astype("uint8")
+            scale = scale.astype(cp.float64)
+            q_weight = cp.empty_like(data, dtype=scale.dtype)
+            cp.divide(data, scale, out=q_weight)
+            cp.add(q_weight, zero_point, out=q_weight)
+            cp.round(q_weight, out=q_weight)
+            cp.clip(q_weight, minq, maxq, out=q_weight)
+
+            return q_weight.get(), scale.get(), zero_point.get()
+        else:
+            logger.warning("Tried to use k-quant quantization on CUDA, but CUDA is not available. "
+                           "Falling back to k-quant quantization on CPU.")
+            return quant_tensor_k_quant_cpu(
+                data, num_bits, group_size
+            )
+    except ImportError:
+        logger.info(
+            "Now we are using k-quant quantization on CPU, which is time consuming. "
+            "Please consider installing cupy to speed up on CUDA. See https://cupy.dev/. "
+            "Please also install torch to check CUDA availability."
+        )
+        return quant_tensor_k_quant_cpu(
+            data, num_bits, group_size
+        )
+
 
 def qdq_tensor(data, num_bits=4, group_size=32, scheme="asym", dtype="int", ratio=1.0):
     """Quant dequant tensor per group.
@@ -299,6 +460,7 @@ def rtn_quantize(
     ratios={},
     accuracy_level=0,
     providers=["CPUExecutionProvider"],
+    algorithm="rtn",
 ):
     """Quant the model with round to nearest method.
@@ -372,9 +534,15 @@ def rtn_quantize(
         ):  # pragma: no cover
             # MatMulFpQ4 support 4 bits and 32 group_size with ort 1.16.0 and 1.16.1 versions, supported by CPU EP
             # MatMulNBits supports 4 bits and 2^n group_size with ort > 1.16.1, supported by CPU EP AND CUDA EP
-            q_weight, scale, zp = quant_tensor(
-                weight.T, num_bits, group_size, scheme, "uint", ratios.get(node.input[1], 1)
-            )
+            if algorithm == "k_quant":
+                q_weight, scale, zp = quant_tensor_k_quant_cuda(
+                    weight.T, num_bits, group_size
+                )
+            else:
+                q_weight, scale, zp = quant_tensor(
+                    weight.T, num_bits, group_size, scheme, "uint", ratios.get(node.input[1], 1)
+                )
+
             q_matmul_node, new_inits = make_matmul_weight_only_node(
                 node=node,
                 weight_shape=org_w_shape,
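
Note on the new k-quant routines: the nstep loop in quant_tensor_k_quant_cpu / quant_tensor_k_quant_cuda is a small grid search. Each iteration fixes a candidate quantization grid per group, solves a weighted least-squares fit for that group's (scale, min) pair in closed form, and keeps the candidate only where it lowers the weighted squared reconstruction error. A minimal standalone sketch of exercising the CPU routine follows; the import path matches the file changed in this commit, while the shapes, seed, and the dequantization check are illustrative additions, not part of the commit.

import numpy as np

from neural_compressor.adaptor.ox_utils.weight_only import quant_tensor_k_quant_cpu

# Illustrative weight matrix; any 2-D float array whose size is divisible by
# group_size behaves the same way.
rng = np.random.default_rng(0)
weight = rng.standard_normal((128, 64)).astype(np.float32)

q_weight, scale, zero_point = quant_tensor_k_quant_cpu(weight, num_bits=4, group_size=32)

# q_weight holds rounded levels in [0, 2**num_bits - 1] per group of 32 values,
# so (q - zero_point) * scale approximately reconstructs the original groups.
dequant = ((q_weight - zero_point) * scale).reshape(weight.shape)
print("max abs reconstruction error:", np.max(np.abs(dequant - weight)))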

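Downstream, rtn_quantize exposes the k-quant path through the added algorithm keyword: any other value keeps the existing round-to-nearest behavior, while "k_quant" routes 4-bit MatMul weights through quant_tensor_k_quant_cuda. A hedged sketch of such a call follows; the positional model argument and the keyword names other than providers and algorithm are assumptions based on the surrounding file, not shown in this diff.

# Hypothetical invocation (not part of this commit): opt into the k-quant path.
# `model` is assumed to be the ONNX model object rtn_quantize already accepts;
# keyword names other than `providers` and `algorithm` are assumptions here.
from neural_compressor.adaptor.ox_utils.weight_only import rtn_quantize

q_model = rtn_quantize(
    model,
    num_bits=4,                          # MatMulNBits/MatMulFpQ4 require 4-bit weights
    group_size=32,                       # elements sharing one scale/zero point
    providers=["CPUExecutionProvider"],  # default confirmed by the signature above
    algorithm="k_quant",                 # new keyword; defaults to "rtn"
)

On machines without CUDA, cupy, or torch, the same call still works: quant_tensor_k_quant_cuda logs a warning or info message and transparently delegates to quant_tensor_k_quant_cpu.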