 ONNXRT1161_VERSION = Version("1.16.1")


-def get_blob_size(group_size, has_zp):  # pragma: no cover
+def get_blob_size(group_size, num_bits, has_zp):  # pragma: no cover
     """Get blob_size.

     Args:
         group_size (int): how many elements share one scale/zp
         has_zp (bool): whether zero_point is not None
     """
     if Version(ort.__version__) > ONNXRT1161_VERSION:
-        blob_size = group_size // 2
+        blob_size = group_size * num_bits // 8
     elif has_zp:
-        blob_size = group_size // 2 + 4 + 1
+        blob_size = group_size * num_bits // 8 + 4 + 1
     else:
-        blob_size = group_size // 2 + 4
+        blob_size = group_size * num_bits // 8 + 4
     return blob_size
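
A hand-checked sketch of the new arithmetic (assuming group_size = 32 and the first branch, i.e. onnxruntime newer than 1.16.1):

    # 4-bit: two elements per byte, matching the old group_size // 2
    assert 32 * 4 // 8 == 16
    # 8-bit: one full byte per element
    assert 32 * 8 // 8 == 32
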
@@ -86,7 +86,7 @@ def make_matmul_weight_only_node(
         matmul_weight_only_node: MatMulFpQ4 or MatMulNBits node
         new_inits: initializers of the new node
     """
-    blob_size = get_blob_size(group_size, zero_point is not None)
+    blob_size = get_blob_size(group_size, num_bits, zero_point is not None)
     packed = np.zeros((q_weight.shape[0], blob_size), dtype="uint8")
     q_weight_name = node.input[1] + "_Q{}G{}".format(str(num_bits), str(group_size))
     input_names = [node.input[0], q_weight_name]
@@ -97,8 +97,16 @@ def make_matmul_weight_only_node(
         op_type = "MatMulNBits"

         # pack quantized weight
-        q_weight_pairs = q_weight[:, ::2] | q_weight[:, 1::2] << 4
-        packed[:, :] = q_weight_pairs[:, :blob_size]
+        if num_bits == 4:
+            q_weight_pairs = q_weight[:, ::2] | q_weight[:, 1::2] << 4
+            packed[:, :] = q_weight_pairs[:, :blob_size]
+        elif num_bits == 8:
+            packed = q_weight
+        else:
+            logger.error(
+                "MatMulNBits does not have kernel support for num_bits = {}.".format(num_bits)
+            )
+
         packed = np.reshape(packed, (-1, k_blocks, blob_size))

         # build scale tensor
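
To see what the 4-bit branch produces, here is a toy run with invented values (q_weight here is a stand-in for the real quantized-weight array):

    import numpy as np

    # one row of eight 4-bit values; byte i holds element 2i in the low
    # nibble and element 2i+1 in the high nibble (<< binds tighter than |)
    q_weight = np.array([[1, 2, 3, 4, 5, 6, 7, 8]], dtype="uint8")
    packed = q_weight[:, ::2] | q_weight[:, 1::2] << 4
    print([hex(v) for v in packed[0]])  # ['0x21', '0x43', '0x65', '0x87']

The 8-bit branch needs no bit twiddling: each quantized element already occupies a full uint8, so q_weight is used as the packed buffer directly.
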
@@ -115,7 +123,9 @@ def make_matmul_weight_only_node(

         # build zero_point tensor
         if zero_point is not None:
-            if num_bits > 4:
+            if num_bits == 8:
+                packed_zp = zero_point.astype("uint8")
+            elif num_bits > 4:
                 packed_zp = np.reshape(zero_point, (1, -1)).astype("uint8")
             else:
                 packed_zp = np.full((zero_point.shape[0] + 1) // 2, 136, dtype="uint8")
@@ -128,6 +138,7 @@ def make_matmul_weight_only_node(
                 packed_zp[even_idx // 2] = (packed_zp[even_idx // 2] & 0xF0) | zero_point[even_idx].ravel()
                 packed_zp[odd_idx // 2] = (packed_zp[odd_idx // 2] & 0x0F) | (zero_point[odd_idx].ravel() << 4)

+            packed_zp = np.reshape(packed_zp, (weight_shape[1], -1))
         zp_tensor = onnx.helper.make_tensor(
             name=node.input[1] + "_zp", data_type=2, dims=packed_zp.shape, vals=packed_zp.tobytes(), raw=True
         )
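
The 4-bit zero-point packing above can be checked on toy values (invented here); 136 is 0x88, i.e. the default zero point 8 in both nibbles:

    import numpy as np

    zero_point = np.array([1, 2, 3], dtype="uint8")  # three 4-bit zero points
    packed_zp = np.full((zero_point.shape[0] + 1) // 2, 136, dtype="uint8")
    idx = np.arange(zero_point.shape[0])
    even_idx, odd_idx = idx[::2], idx[1::2]
    packed_zp[even_idx // 2] = (packed_zp[even_idx // 2] & 0xF0) | zero_point[even_idx]
    packed_zp[odd_idx // 2] = (packed_zp[odd_idx // 2] & 0x0F) | (zero_point[odd_idx] << 4)
    print([hex(v) for v in packed_zp])  # ['0x21', '0x83'] -- high nibble of the last byte keeps the 0x88 default

For num_bits == 8 the new branch skips all of this: each zero point is already one uint8, so only the dtype cast and the final reshape to (weight_shape[1], -1) are needed.
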
@@ -463,7 +474,7 @@ def rtn_quantize(
     ratios={},
     accuracy_level=0,
     providers=["CPUExecutionProvider"],
-    algorithm="rtn",
+    algorithm="k_quant",
 ):
     """Quantize the model with the round-to-nearest method.
@@ -527,7 +538,8 @@ def rtn_quantize(

         weight = pad_tensor(weight, group_size, k_blocks)

-        satisfy_MatMulNBits_condition = Version(ort.__version__) > ONNXRT1161_VERSION and num_bits == 4
+        enable_MatMulNBits_8bits = True
+        satisfy_MatMulNBits_condition = (Version(ort.__version__) > ONNXRT1161_VERSION and num_bits == 4) or (enable_MatMulNBits_8bits and num_bits == 8)
         satisfy_MatMulFpQ4_condition = (
             Version(ort.__version__) >= ONNXRT116_VERSION and num_bits == 4 and group_size == 32
         )
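
A hand-evaluated sketch of how the new dispatch behaves (the version value is illustrative; enable_MatMulNBits_8bits mirrors the flag above):

    from packaging.version import Version

    ONNXRT1161_VERSION = Version("1.16.1")
    ort_version = Version("1.17.1")  # any onnxruntime newer than 1.16.1
    enable_MatMulNBits_8bits = True

    for num_bits in (4, 8):
        satisfy = (ort_version > ONNXRT1161_VERSION and num_bits == 4) or (
            enable_MatMulNBits_8bits and num_bits == 8
        )
        print(num_bits, satisfy)  # 4 True, 8 True: both now route to MatMulNBits
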