Skip to content

Commit 3486e2c

Browse files
anijain2305zhiics
authored and committed
[QNN][Legalize] Specialize for Platforms without any fast Int8 arithmetic units. (#4307)
1 parent 8cd5cce commit 3486e2c

File tree

2 files changed

+308
-39
lines changed

2 files changed

+308
-39
lines changed

python/tvm/relay/qnn/op/legalizations.py

Lines changed: 161 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -22,10 +22,43 @@
2222
from tvm import relay
2323
from .. import op as reg
2424

25+
#################################################
26+
# Register the functions for different operators.
27+
#################################################
28+
2529
# Registering QNN Conv2D legalization function.
@reg.register_qnn_legalize("qnn.conv2d")
def legalize_qnn_conv2d(attrs, inputs, types):
    """Dispatch qnn.conv2d legalization to the target-specialized generic func."""
    return qnn_conv2d_legalize(attrs, inputs, types)
33+
34+
# Registering QNN dense legalization function.
@reg.register_qnn_legalize("qnn.dense")
def legalize_qnn_dense(attrs, inputs, types):
    """Dispatch qnn.dense legalization to the target-specialized generic func."""
    return qnn_dense_legalize(attrs, inputs, types)
38+
39+
# Generic QNN Conv2D legalization function. Defaults to None (no legalization);
# targets override this via qnn_conv2d_legalize.register(...).
@tvm.target.generic_func
def qnn_conv2d_legalize(attrs, inputs, types):
    """Default legalization is None."""
    return None
45+
46+
# Generic QNN dense legalization function. Defaults to None (no legalization);
# targets override this via qnn_dense_legalize.register(...).
@tvm.target.generic_func
def qnn_dense_legalize(attrs, inputs, types):
    """Default legalization is None."""
    return None
51+
52+
###################
53+
# Helper functions.
54+
###################
55+
56+
# Helper function for lowering in the absence of fast Int8 arithmetic units.
def helper_no_fast_int8_hw_legalization(attrs, inputs, types, relay_op):
    """ Converts QNN operators into a sequence of Relay operators that are friendly to HW that do
    not have fast Int8 arithmetic. For example, for ARM, LLVM utilizes the assembly instructions
    much more efficiently if the convolution or dense operator input datatypes are int16 instead of
    int8. More details are present at https://github.com/apache/incubator-tvm/pull/4277.

    Parameters
    ----------
    attrs : tvm.attrs.Attrs
        Attributes of current convolution
    inputs : list of tvm.relay.Expr
        The args of the Relay expr to be legalized
    types : list of types
        List of input and output types
    relay_op : callable
        The plain Relay operator (e.g. relay.nn.conv2d) to lower to.

    Returns
    -------
    result : tvm.relay.Expr
        The legalized expr
    """
    data, kernel = inputs

    # Fold each zero point into its tensor by upcasting to int16 and
    # subtracting, so the plain Relay op receives zero-centered int16 inputs.
    input_zp = attrs['input_zero_point']
    kernel_zp = attrs['kernel_zero_point']

    shift_data = relay.subtract(
        relay.cast(data, dtype='int16'), relay.const(input_zp, 'int16'))
    shift_kernel = relay.subtract(
        relay.cast(kernel, dtype='int16'), relay.const(kernel_zp, 'int16'))

    # Strip the QNN-only attributes before calling the non-QNN Relay op.
    new_attrs = {key: attrs[key] for key in attrs.keys()}
    del new_attrs['input_zero_point']
    del new_attrs['kernel_zero_point']
    return relay_op(shift_data, shift_kernel, **new_attrs)
92+
93+
# Helper function to change dtypes to uint8 x int8. Intel VNNI instructions prefer this setting.
94+
def helper_change_dtypes_to_uint8_int8(attrs, inputs, types, relay_op):
95+
"""Legalizes QNN conv2d/dense op for Intel HW. VNNI supports u8 x i8 fast conv/MM. If the dtypes
96+
are already good, we don't transform. Else, we shift the tensor values and zero points to change
97+
the dtype.
5798
5899
Converting from int8 to uint8 can be done in following manner.
59100
@@ -82,26 +123,18 @@ def _qnn_conv2d_legalize(attrs, inputs, types):
82123
The legalized expr
83124
"""
84125

85-
def _shift(data, out_dtype):
126+
def _shift(data, zero_point, out_dtype):
86127
"""Shifts (add/subtracts) the qnn tensor with +/-128)"""
87128
if out_dtype == 'uint8':
88129
shift = 128
89130
elif out_dtype == 'int8':
90131
shift = -128
91132
else:
92-
raise ValueError("Unsupport out dtype.")
133+
raise ValueError("Unsupported out dtype.")
93134
data_modified = relay.cast(data, 'int32')
94135
data_modified = relay.add(data_modified, relay.const(shift, 'int32'))
95136
data_modified = relay.cast(data_modified, out_dtype)
96-
return data_modified
97-
98-
def _is_int8_hw_support(target):
99-
"""
100-
Checks to ensure that we can use Intel DLBoost instructions - Check if the target is skylake
101-
and above.
102-
"""
103-
supported_arches = {'-mcpu=skylake-avx512', '-mcpu=cascadelake'}
104-
return supported_arches.intersection(set(target.options))
137+
return (data_modified, zero_point + shift)
105138

106139
# Collect the dtypes.
107140
data_dtype = types[0].dtype
@@ -110,11 +143,6 @@ def _is_int8_hw_support(target):
110143
# Collect the input exprs.
111144
data, kernel = inputs
112145

113-
# The VNNI transformations are applicable only to Skylake and above.
114-
target = tvm.target.current_target(allow_none=False)
115-
if not _is_int8_hw_support(target):
116-
return None
117-
118146
# VNNI supports u8 x i8 fast conv/MM. Don't do anything if it is already satisfied.
119147
if data_dtype == 'uint8' and kernel_dtype == 'int8':
120148
return None
@@ -123,18 +151,118 @@ def _is_int8_hw_support(target):
123151
input_zp = attrs['input_zero_point']
124152
if data_dtype == 'int8':
125153
# Compute (QA + 128) and (zp_a + 128)
126-
data = _shift(data, 'uint8')
127-
input_zp = input_zp + 128
154+
data, input_zp = _shift(data, input_zp, 'uint8')
128155

129156
# Shift kernel if necessary.
130157
kernel_zp = attrs['kernel_zero_point']
131158
if kernel_dtype == 'uint8':
132159
# Compute (QA - 128) and (zp_a - 128)
133-
kernel = _shift(kernel, 'int8')
134-
kernel_zp = kernel_zp - 128
160+
kernel, kernel_zp = _shift(kernel, kernel_zp, 'int8')
135161

136162
# Call qnn.conv2d with modified inputs and zero points.
137163
new_attrs = {k : attrs[k] for k in attrs.keys()}
138164
new_attrs['input_zero_point'] = input_zp
139165
new_attrs['kernel_zero_point'] = kernel_zp
140-
return relay.qnn.op.conv2d(data, kernel, **new_attrs)
166+
return relay_op(data, kernel, **new_attrs)
167+
168+
# Helper function to change dtypes to be same. ARM dotprod instructions prefer this setting.
def helper_change_dtypes_to_be_same(attrs, inputs, types, relay_op):
    """ Sometimes MxNet + MLDNN can lead to uint8 x int8 datatypes for the conv inputs. However,
    many devices like ARM prefer the datatypes to be same for the HW units. This helper transforms
    conv2d/dense such that both the dtypes are same.

    Parameters
    ----------
    attrs : tvm.attrs.Attrs
        Attributes of current convolution
    inputs : list of tvm.relay.Expr
        The args of the Relay expr to be legalized
    types : list of types
        List of input and output types

    Returns
    -------
    result : tvm.relay.Expr
        The legalized expr
    """

    def _requantize(tensor, zero_point, out_dtype):
        """Shift (add/subtract) a qnn tensor and its zero point by +/-128."""
        if out_dtype == 'uint8':
            offset = 128
        elif out_dtype == 'int8':
            offset = -128
        else:
            raise ValueError("Unsupported out dtype.")
        # Widen to int32 so the +/-128 shift cannot wrap, then narrow back.
        shifted = relay.cast(tensor, 'int32')
        shifted = relay.add(shifted, relay.const(offset, 'int32'))
        shifted = relay.cast(shifted, out_dtype)
        return (shifted, zero_point + offset)

    data_dtype = types[0].dtype
    kernel_dtype = types[1].dtype

    # Already matching dtypes - no legalization needed.
    if data_dtype == kernel_dtype:
        return None

    data, kernel = inputs

    assert 'int8' in data_dtype and 'int8' in kernel_dtype, \
        "Qnn Conv2D/Dense only accepts uint8 or int8 inputs"

    # Move the data tensor into the kernel's dtype, adjusting its zero point.
    input_zp = attrs['input_zero_point']
    data, input_zp = _requantize(data, input_zp, kernel_dtype)

    new_attrs = {key: attrs[key] for key in attrs.keys()}
    new_attrs['input_zero_point'] = input_zp
    return relay_op(data, kernel, **new_attrs)
222+
223+
def is_fast_int8_on_intel():
    """ Checks whether the hardware has support for fast Int8 arithmetic operations. """
    target = tvm.target.current_target(allow_none=False)
    # Skylake-AVX512 and Cascade Lake ship the VNNI/DLBoost instructions.
    # Truthy (non-empty set) when any of these -mcpu options is present.
    return set(target.options) & {'-mcpu=skylake-avx512', '-mcpu=cascadelake'}
228+
229+
def is_fast_int8_on_arm():
    """ Checks whether the hardware has support for fast Int8 arithmetic operations. """
    target = tvm.target.current_target(allow_none=False)
    # ARMv8.2-A with the dotprod extension provides fast int8 dot products.
    joined_options = ' '.join(target.options)
    return '+v8.2a,+dotprod' in joined_options
233+
234+
########################
# ARM CPU legalizations.
########################

@qnn_conv2d_legalize.register('arm_cpu')
def _qnn_conv2d_legalize_arm_cpu(attrs, inputs, types):
    """Legalize qnn.conv2d for ARM CPUs."""
    # With dotprod support, ARM prefers both operand dtypes to be the same.
    if is_fast_int8_on_arm():
        return helper_change_dtypes_to_be_same(attrs, inputs, types, relay.qnn.op.conv2d)
    # No fast int8 units: lower to a plain int16 nn.conv2d.
    return helper_no_fast_int8_hw_legalization(attrs, inputs, types, relay.nn.conv2d)
244+
245+
@qnn_dense_legalize.register('arm_cpu')
def _qnn_dense_legalize_arm_cpu(attrs, inputs, types):
    """Legalize qnn.dense for ARM CPUs."""
    # With dotprod support, ARM prefers both operand dtypes to be the same.
    if is_fast_int8_on_arm():
        return helper_change_dtypes_to_be_same(attrs, inputs, types, relay.qnn.op.dense)
    # No fast int8 units: lower to a plain int16 nn.dense.
    return helper_no_fast_int8_hw_legalization(attrs, inputs, types, relay.nn.dense)
251+
252+
##########################
# Intel CPU legalizations.
##########################

@qnn_conv2d_legalize.register('cpu')
def _qnn_conv2d_legalize_intel_cpu(attrs, inputs, types):
    """Legalize qnn.conv2d for Intel CPUs."""
    # VNNI-capable CPUs prefer uint8 x int8 operand dtypes.
    if is_fast_int8_on_intel():
        return helper_change_dtypes_to_uint8_int8(attrs, inputs, types, relay.qnn.op.conv2d)
    # No fast int8 units: lower to a plain int16 nn.conv2d.
    return helper_no_fast_int8_hw_legalization(attrs, inputs, types, relay.nn.conv2d)
262+
263+
@qnn_dense_legalize.register('cpu')
def _qnn_dense_legalize_intel_cpu(attrs, inputs, types):
    """Legalize qnn.dense for Intel CPUs."""
    # VNNI-capable CPUs prefer uint8 x int8 operand dtypes.
    if is_fast_int8_on_intel():
        return helper_change_dtypes_to_uint8_int8(attrs, inputs, types, relay.qnn.op.dense)
    # No fast int8 units: lower to a plain int16 nn.dense.
    return helper_no_fast_int8_hw_legalization(attrs, inputs, types, relay.nn.dense)

0 commit comments

Comments
 (0)