
Commit 903604f

Merge branch 'int8_new' into k_quant

2 parents: 99f10df + 0a1a0d4

File tree: 1 file changed (+22, -10)

neural_compressor/adaptor/ox_utils/weight_only.py (+22, -10)
```diff
@@ -40,19 +40,19 @@
 ONNXRT1161_VERSION = Version("1.16.1")
 
 
-def get_blob_size(group_size, has_zp):  # pragma: no cover
+def get_blob_size(group_size, num_bits, has_zp):  # pragma: no cover
     """Get blob_size.
 
     Args:
         group_size (int): how many elements share one scale/zp
         has_zp (bool): whether zero_point is None
     """
     if Version(ort.__version__) > ONNXRT1161_VERSION:
-        blob_size = group_size // 2
+        blob_size = group_size * num_bits // 8
     elif has_zp:
-        blob_size = group_size // 2 + 4 + 1
+        blob_size = group_size * num_bits // 8 + 4 + 1
     else:
-        blob_size = group_size // 2 + 4
+        blob_size = group_size * num_bits // 8 + 4
     return blob_size
 
 
```
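Note (editorial aside, not part of the commit): the generalized formula `group_size * num_bits // 8` reduces to the old `group_size // 2` at 4 bits and gives one byte per element at 8 bits. A quick sanity check:

```python
# Blob-size arithmetic for the ORT > 1.16.1 branch (values assumed for illustration).
group_size = 32

assert group_size // 2 == 16        # old 4-bit-only formula
assert group_size * 4 // 8 == 16    # new formula, num_bits=4: unchanged
assert group_size * 8 // 8 == 32    # new formula, num_bits=8: one byte per element
```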

```diff
@@ -86,7 +86,7 @@ def make_matmul_weight_only_node(
         matmul_weight_only_node: MatMulFpQ4 or MatMulNBits node
         new_inits: initializers of the new node
     """
-    blob_size = get_blob_size(group_size, zero_point is not None)
+    blob_size = get_blob_size(group_size, num_bits, zero_point is not None)
     packed = np.zeros((q_weight.shape[0], blob_size), dtype="uint8")
     q_weight_name = node.input[1] + "_Q{}G{}".format(str(num_bits), str(group_size))
     input_names = [node.input[0], q_weight_name]
```
```diff
@@ -97,8 +97,16 @@
         op_type = "MatMulNBits"
 
         # pack quantized weight
-        q_weight_pairs = q_weight[:, ::2] | q_weight[:, 1::2] << 4
-        packed[:, :] = q_weight_pairs[:, :blob_size]
+        if num_bits == 4:
+            q_weight_pairs = q_weight[:, ::2] | q_weight[:, 1::2] << 4
+            packed[:, :] = q_weight_pairs[:, :blob_size]
+        elif num_bits == 8:
+            packed = q_weight
+        else:
+            logger.error(
+                "MatMulNBits does not have kernel support for num_bits = {}.".format(num_bits)
+            )
+
         packed = np.reshape(packed, (-1, k_blocks, blob_size))
 
         # build scale tensor
```
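The 4-bit branch keeps the existing packing: even-indexed columns land in the low nibble of each byte and odd-indexed columns in the high nibble, while the 8-bit branch can reuse `q_weight` byte-for-byte. A minimal NumPy sketch of the 4-bit packing (toy values, not from the patch):

```python
import numpy as np

q_weight = np.array([[1, 2, 3, 4]], dtype="uint8")  # quantized values already < 16

# Same expression as the diff: even columns fill the low nibble,
# odd columns are shifted into the high nibble.
packed = q_weight[:, ::2] | q_weight[:, 1::2] << 4
assert packed.tolist() == [[0x21, 0x43]]  # 1 | (2 << 4), 3 | (4 << 4)
```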
```diff
@@ -115,7 +123,9 @@
 
         # build zero_point tensor
         if zero_point is not None:
-            if num_bits > 4:
+            if num_bits == 8:
+                packed_zp = zero_point.astype("uint8")
+            elif num_bits > 4:
                 packed_zp = np.reshape(zero_point, (1, -1)).astype("uint8")
             else:
                 packed_zp = np.full((zero_point.shape[0] + 1) // 2, 136, dtype="uint8")
@@ -128,6 +138,7 @@
                 packed_zp[even_idx // 2] = (packed_zp[even_idx // 2] & 0xF0) | zero_point[even_idx].ravel()
                 packed_zp[odd_idx // 2] = (packed_zp[odd_idx // 2] & 0x0F) | (zero_point[odd_idx].ravel() << 4)
 
+            packed_zp = np.reshape(packed_zp, (weight_shape[1], -1))
             zp_tensor = onnx.helper.make_tensor(
                 name=node.input[1] + "_zp", data_type=2, dims=packed_zp.shape, vals=packed_zp.tobytes(), raw=True
             )
```
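With `num_bits == 8` each zero point already fills a byte, so the new branch stores it directly; the 4-bit path still packs two zero points per byte, starting from 136 (0x88, both nibbles at 8, the unsigned 4-bit midpoint). A standalone sketch of that nibble packing, assuming a 1-D `zero_point` (the diff operates on the raveled array):

```python
import numpy as np

zero_point = np.array([1, 2, 3], dtype="uint8")  # odd length on purpose

packed_zp = np.full((zero_point.shape[0] + 1) // 2, 136, dtype="uint8")  # 0x88 default
idx = np.arange(zero_point.shape[0])
even_idx, odd_idx = idx[::2], idx[1::2]
packed_zp[even_idx // 2] = (packed_zp[even_idx // 2] & 0xF0) | zero_point[even_idx]
packed_zp[odd_idx // 2] = (packed_zp[odd_idx // 2] & 0x0F) | (zero_point[odd_idx] << 4)

# byte 0: low nibble 1, high nibble 2 -> 0x21; byte 1: low nibble 3, high stays 8 -> 0x83
assert packed_zp.tolist() == [0x21, 0x83]
```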
```diff
@@ -463,7 +474,7 @@ def rtn_quantize(
     ratios={},
     accuracy_level=0,
     providers=["CPUExecutionProvider"],
-    algorithm="rtn",
+    algorithm="k_quant",
 ):
     """Quant the model with round to nearest method.
 
```
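Worth flagging for downstream users: this flips the default `algorithm` of `rtn_quantize` from "rtn" to "k_quant", so plain round-to-nearest now has to be requested explicitly. A hypothetical call (the `model` argument is assumed; only the keyword comes from this diff):

```python
# Keep the pre-commit RTN behavior despite the new default:
q_model = rtn_quantize(model, algorithm="rtn")
```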
```diff
@@ -527,7 +538,8 @@ def rtn_quantize(
 
         weight = pad_tensor(weight, group_size, k_blocks)
 
-        satisfy_MatMulNBits_condition = Version(ort.__version__) > ONNXRT1161_VERSION and num_bits == 4
+        enable_MatMulNBits_8bits = True
+        satisfy_MatMulNBits_condition = (Version(ort.__version__) > ONNXRT1161_VERSION and num_bits == 4) or (enable_MatMulNBits_8bits and num_bits == 8)
         satisfy_MatMulFpQ4_condition = (
             Version(ort.__version__) >= ONNXRT116_VERSION and num_bits == 4 and group_size == 32
         )
```
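Net effect of the gate: with the hardcoded `enable_MatMulNBits_8bits = True`, 8-bit weights now route to `MatMulNBits` regardless of the ONNX Runtime version check, while 4-bit still requires ort > 1.16.1. A small truth-table sketch (the installed version is faked; `Version` comes from `packaging`, which the module already uses):

```python
from packaging.version import Version

ONNXRT1161_VERSION = Version("1.16.1")
ort_version = Version("1.17.0")  # assumed installed onnxruntime version
enable_MatMulNBits_8bits = True

for num_bits in (3, 4, 8):
    satisfy = (ort_version > ONNXRT1161_VERSION and num_bits == 4) or (
        enable_MatMulNBits_8bits and num_bits == 8
    )
    print(num_bits, satisfy)  # -> 3 False, 4 True, 8 True
```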
