import tvm
from tvm import relay
from .. import op as reg

#################################################
# Register the functions for different operators.
#################################################

# Registering QNN Conv2D legalization function.
@reg.register_qnn_legalize("qnn.conv2d")
def legalize_qnn_conv2d(attrs, inputs, types):
    return qnn_conv2d_legalize(attrs, inputs, types)

# Registering QNN dense legalization function.
@reg.register_qnn_legalize("qnn.dense")
def legalize_qnn_dense(attrs, inputs, types):
    return qnn_dense_legalize(attrs, inputs, types)

# Default to None. If overridden by target, this will not be run.
# Generic QNN Conv2D legalization function.
@tvm.target.generic_func
def qnn_conv2d_legalize(attrs, inputs, types):
    """Default legalization is None."""
    return None

# Generic QNN Dense legalization function.
@tvm.target.generic_func
def qnn_dense_legalize(attrs, inputs, types):
    """Default legalization is None."""
    return None
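
# Usage sketch (illustrative; the Legalize pass name is assumed from the QNN
# transform flow, not defined in this file). The generic functions above are
# dispatched on the target that is current when legalization runs:
#
#   with tvm.target.create('llvm -mcpu=skylake-avx512'):
#       mod = relay.qnn.transform.Legalize()(mod)
#
# Under this target the 'cpu' overrides registered below fire; with no
# matching override, the defaults return None and the QNN ops are unchanged.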

###################
# Helper functions.
###################

# Helper function for lowering in the absence of fast Int8 arithmetic units.
def helper_no_fast_int8_hw_legalization(attrs, inputs, types, relay_op):
    """Converts QNN operators into a sequence of Relay operators that are friendly to HW that does
    not have fast Int8 arithmetic. For example, for ARM, LLVM utilizes the assembly instructions
    much more efficiently if the convolution or dense operator input datatypes are int16 instead of
    int8. More details are present at https://github.com/apache/incubator-tvm/pull/4277.

    Parameters
    ----------
    attrs : tvm.attrs.Attrs
        Attributes of current convolution
    inputs : list of tvm.relay.Expr
        The args of the Relay expr to be legalized
    types : list of types
        List of input and output types

    Returns
    -------
    result : tvm.relay.Expr
        The legalized expr
    """

    # Collect the input exprs.
    data, kernel = inputs

    input_zp = attrs['input_zero_point']
    kernel_zp = attrs['kernel_zero_point']

    shift_data = relay.subtract(relay.cast(data, dtype='int16'),
                                relay.const(input_zp, 'int16'))
    shift_kernel = relay.subtract(relay.cast(kernel, dtype='int16'),
                                  relay.const(kernel_zp, 'int16'))
    new_attrs = {k: attrs[k] for k in attrs.keys()}
    del new_attrs['kernel_zero_point']
    del new_attrs['input_zero_point']
    return relay_op(shift_data, shift_kernel, **new_attrs)
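
# Illustrative sketch (hypothetical shapes and zero points, not part of the
# pass): on a target without fast Int8 units, the helper above turns a
# qnn.conv2d into a plain nn.conv2d over zero-point-subtracted int16 tensors.
def _example_no_fast_int8_rewrite():
    data = relay.var('data', shape=(1, 64, 56, 56), dtype='int8')
    kernel = relay.var('kernel', shape=(64, 64, 3, 3), dtype='int8')
    # Graph equivalent to the helper's output for input_zero_point=1 and
    # kernel_zero_point=2:
    shift_data = relay.subtract(relay.cast(data, dtype='int16'),
                                relay.const(1, 'int16'))
    shift_kernel = relay.subtract(relay.cast(kernel, dtype='int16'),
                                  relay.const(2, 'int16'))
    return relay.nn.conv2d(shift_data, shift_kernel,
                           channels=64, kernel_size=(3, 3))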

# Helper function to change dtypes to uint8 x int8. Intel VNNI instructions prefer this setting.
def helper_change_dtypes_to_uint8_int8(attrs, inputs, types, relay_op):
    """Legalizes QNN conv2d/dense op for Intel HW. VNNI supports u8 x i8 fast conv/MM. If the dtypes
    are already good, we don't transform. Else, we shift the tensor values and zero points to change
    the dtype.

    Converting from int8 to uint8 can be done in the following manner.

    Original equation
        scale * (QA - zp_a)
        scale * (QA + 128 - 128 - zp_a)
        scale * ((QA + 128) - (zp_a + 128))

    Replacing QA + 128 with QA' and (zp_a + 128) with zp_a', we get our new quantized uint8 tensor:
        scale * (QA' - zp_a')

    Similarly, we can convert from uint8 to int8.

    Parameters
    ----------
    attrs : tvm.attrs.Attrs
        Attributes of current convolution
    inputs : list of tvm.relay.Expr
        The args of the Relay expr to be legalized
    types : list of types
        List of input and output types

    Returns
    -------
    result : tvm.relay.Expr
        The legalized expr
    """

    def _shift(data, zero_point, out_dtype):
        """Shifts (adds/subtracts) the QNN tensor and zero point by +/-128."""
        if out_dtype == 'uint8':
            shift = 128
        elif out_dtype == 'int8':
            shift = -128
        else:
            raise ValueError("Unsupported out dtype.")
        data_modified = relay.cast(data, 'int32')
        data_modified = relay.add(data_modified, relay.const(shift, 'int32'))
        data_modified = relay.cast(data_modified, out_dtype)
        return (data_modified, zero_point + shift)

    # Collect the dtypes.
    data_dtype = types[0].dtype
    kernel_dtype = types[1].dtype

    # Collect the input exprs.
    data, kernel = inputs

    # VNNI supports u8 x i8 fast conv/MM. Don't do anything if it is already satisfied.
    if data_dtype == 'uint8' and kernel_dtype == 'int8':
        return None

    # Shift input if necessary.
    input_zp = attrs['input_zero_point']
    if data_dtype == 'int8':
        # Compute (QA + 128) and (zp_a + 128)
        data, input_zp = _shift(data, input_zp, 'uint8')

    # Shift kernel if necessary.
    kernel_zp = attrs['kernel_zero_point']
    if kernel_dtype == 'uint8':
        # Compute (QA - 128) and (zp_a - 128)
        kernel, kernel_zp = _shift(kernel, kernel_zp, 'int8')

    # Call the QNN op with modified inputs and zero points.
    new_attrs = {k: attrs[k] for k in attrs.keys()}
    new_attrs['input_zero_point'] = input_zp
    new_attrs['kernel_zero_point'] = kernel_zp
    return relay_op(data, kernel, **new_attrs)
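
# Worked example (illustrative numbers): an int8 value QA = -3 with zp_a = 5
# represents scale * (-3 - 5). After the shift, QA' = -3 + 128 = 125 (uint8)
# and zp_a' = 5 + 128 = 133, and scale * (125 - 133) is the same real value,
# so the operator's semantics are preserved while the dtypes become the
# uint8 x int8 pair that VNNI prefers.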

# Helper function to change dtypes to be the same. ARM dotprod instructions prefer this setting.
def helper_change_dtypes_to_be_same(attrs, inputs, types, relay_op):
    """Sometimes MXNet + MKLDNN can lead to uint8 x int8 datatypes for the conv inputs. However,
    many devices like ARM prefer the datatypes to be the same for the HW units. This helper
    transforms conv2d/dense such that both the dtypes are the same.

    Parameters
    ----------
    attrs : tvm.attrs.Attrs
        Attributes of current convolution
    inputs : list of tvm.relay.Expr
        The args of the Relay expr to be legalized
    types : list of types
        List of input and output types

    Returns
    -------
    result : tvm.relay.Expr
        The legalized expr
    """
+
189
+ def _shift (data , zero_point , out_dtype ):
190
+ """Shifts (adds/subtracts) the qnn tensor by 128)"""
191
+ if out_dtype == 'uint8' :
192
+ shift = 128
193
+ elif out_dtype == 'int8' :
194
+ shift = - 128
195
+ else :
196
+ raise ValueError ("Unsupported out dtype." )
197
+ data_modified = relay .cast (data , 'int32' )
198
+ data_modified = relay .add (data_modified , relay .const (shift , 'int32' ))
199
+ data_modified = relay .cast (data_modified , out_dtype )
200
+ return (data_modified , zero_point + shift )
201
+
202
+ # Collect the dtypes.
203
+ data_dtype = types [0 ].dtype
204
+ kernel_dtype = types [1 ].dtype
205
+
206
+ if data_dtype == kernel_dtype :
207
+ return None
208
+
209
+ # Collect the input exprs.
210
+ data , kernel = inputs
211
+
212
+ assert 'int8' in data_dtype and 'int8' in kernel_dtype , \
213
+ "Qnn Conv2D/Dense only accepts uint8 or int8 inputs"
214
+
215
+ # Shift input if necessary.
216
+ input_zp = attrs ['input_zero_point' ]
217
+ data , input_zp = _shift (data , input_zp , kernel_dtype )
218
+
219
+ new_attrs = {k : attrs [k ] for k in attrs .keys ()}
220
+ new_attrs ['input_zero_point' ] = input_zp
221
+ return relay_op (data , kernel , ** new_attrs )
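
# Illustrative sketch (hypothetical dtypes/zero points): with uint8 data and
# an int8 kernel, the helper above shifts only the data side so that both
# operands become int8, folding the shift into the zero point:
#
#   data: uint8, input_zero_point = 130
#     -> _shift(data, 130, 'int8') yields int8 data, input_zero_point = 2
#   kernel: int8, kernel_zero_point unchanged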

def is_fast_int8_on_intel():
    """Checks whether the hardware has support for fast Int8 arithmetic operations."""
    target = tvm.target.current_target(allow_none=False)
    intel_supported_arches = {'-mcpu=skylake-avx512', '-mcpu=cascadelake'}
    return intel_supported_arches.intersection(set(target.options))

def is_fast_int8_on_arm():
    """Checks whether the hardware has support for fast Int8 arithmetic operations."""
    target = tvm.target.current_target(allow_none=False)
    return '+v8.2a,+dotprod' in ' '.join(target.options)
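
# Example target strings (illustrative) that these checks accept:
#   llvm -mcpu=skylake-avx512
#   llvm -mcpu=cascadelake
#   llvm -device=arm_cpu -mtriple=aarch64-linux-gnu -mattr=+v8.2a,+dotprod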

########################
# ARM CPU legalizations.
########################

@qnn_conv2d_legalize.register('arm_cpu')
def _qnn_conv2d_legalize_arm_cpu(attrs, inputs, types):
    # ARM prefers the dtypes to be the same.
    if is_fast_int8_on_arm():
        return helper_change_dtypes_to_be_same(attrs, inputs, types, relay.qnn.op.conv2d)
    return helper_no_fast_int8_hw_legalization(attrs, inputs, types, relay.nn.conv2d)

@qnn_dense_legalize.register('arm_cpu')
def _qnn_dense_legalize_arm_cpu(attrs, inputs, types):
    # ARM prefers the dtypes to be the same.
    if is_fast_int8_on_arm():
        return helper_change_dtypes_to_be_same(attrs, inputs, types, relay.qnn.op.dense)
    return helper_no_fast_int8_hw_legalization(attrs, inputs, types, relay.nn.dense)

##########################
# Intel CPU legalizations.
##########################

@qnn_conv2d_legalize.register('cpu')
def _qnn_conv2d_legalize_intel_cpu(attrs, inputs, types):
    # The VNNI transformations prefer uint8 x int8 datatypes.
    if is_fast_int8_on_intel():
        return helper_change_dtypes_to_uint8_int8(attrs, inputs, types, relay.qnn.op.conv2d)
    return helper_no_fast_int8_hw_legalization(attrs, inputs, types, relay.nn.conv2d)

@qnn_dense_legalize.register('cpu')
def _qnn_dense_legalize_intel_cpu(attrs, inputs, types):
    # The VNNI transformations prefer uint8 x int8 datatypes.
    if is_fast_int8_on_intel():
        return helper_change_dtypes_to_uint8_int8(attrs, inputs, types, relay.qnn.op.dense)
    return helper_no_fast_int8_hw_legalization(attrs, inputs, types, relay.nn.dense)