forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
test_nnapi.py
425 lines (368 loc) · 15.9 KB
/
test_nnapi.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
#!/usr/bin/env python3
import os
import ctypes
import torch
from typing import Tuple
from torch.backends._nnapi.prepare import convert_model_to_nnapi
from torch.testing._internal.common_utils import TestCase, run_tests
def qpt(t, scale, zero_point, dtype=torch.quint8):
    """Quantize ``t`` per-tensor with the given ``scale`` and ``zero_point``.

    ``t`` may be a (nested) Python list of numbers or an existing tensor.
    Existing tensors are detached and used directly instead of being passed
    through ``torch.tensor``, which copy-constructs and emits a UserWarning
    when given a tensor.

    Returns the quantized tensor produced by ``torch.quantize_per_tensor``
    with dtype ``dtype`` (``torch.quint8`` by default).
    """
    if isinstance(t, torch.Tensor):
        t = t.detach()
    else:
        t = torch.tensor(t)
    return torch.quantize_per_tensor(t, scale, zero_point, dtype)
def nhwc(t):
    """Return a channels-last copy of ``t`` tagged with ``nnapi_nhwc = True``.

    The input tensor is left untouched (it is cloned first).  The tag is a
    plain Python attribute on the returned tensor; presumably the NNAPI
    converter reads it to treat the input as NHWC -- confirm in
    ``torch.backends._nnapi``.
    """
    tagged = t.clone()
    tagged = tagged.contiguous(memory_format=torch.channels_last)
    tagged.nnapi_nhwc = True
    return tagged
class TestNNAPI(TestCase):
    """Tests for converting traced TorchScript modules with ``convert_model_to_nnapi``.

    Every check traces a module and converts it.  Outputs are compared
    against eager execution only when the ``LIBNEURALNETWORKS_PATH``
    environment variable names a library that ``setUp`` loads; otherwise
    only the trace/conversion step is exercised.
    """

    def setUp(self):
        super().setUp()
        # The quantized engine is process-global: remember the current one so
        # tearDown can restore it instead of leaking 'qnnpack' into other
        # tests that run in the same process.
        self._prev_quantized_engine = torch.backends.quantized.engine
        # Avoid saturation in fbgemm
        torch.backends.quantized.engine = 'qnnpack'

        libneuralnetworks_path = os.environ.get("LIBNEURALNETWORKS_PATH")
        if libneuralnetworks_path:
            # Only when this library is loaded do checks run converted models.
            ctypes.cdll.LoadLibrary(libneuralnetworks_path)
            print("Will attempt to run NNAPI models.")
            self.can_run_nnapi = True
        else:
            self.can_run_nnapi = False

    def tearDown(self):
        # Undo the global engine change made in setUp.
        torch.backends.quantized.engine = self._prev_quantized_engine
        super().tearDown()

    def check(
        self,
        module,
        arg_or_args,
        *,
        trace_args=None,
        convert_args=None,
        atol_rtol=None,
        limit=None,
    ):
        """Trace ``module``, convert it to NNAPI and compare with eager mode.

        Args:
            module: module under test; it is switched to eval mode.
            arg_or_args: a single tensor or a list of input tensors used to
                run both the eager module and the converted module.
            trace_args: inputs for ``torch.jit.trace``; defaults to the run
                inputs when not given.
            convert_args: inputs passed to ``convert_model_to_nnapi``;
                defaults to the run inputs when not given.  Tests below pass
                tensors with 0-sized dimensions here to exercise flexible
                input sizes.
            atol_rtol: optional ``(atol, rtol)`` pair forwarded to
                ``assertEqual``.
            limit: optional maximum number of output elements whose integer
                representations may differ (after the tolerant comparison
                passed).  Uses ``int_repr``, so outputs must be quantized.

        If no NNAPI library was loaded in ``setUp``, returns right after a
        successful conversion without running anything.
        """
        with torch.no_grad():
            if isinstance(arg_or_args, torch.Tensor):
                args = [arg_or_args]
            else:
                args = arg_or_args
            module.eval()
            traced = torch.jit.trace(module, trace_args or args)
            nnapi_module = convert_model_to_nnapi(traced, convert_args or args)
            if not self.can_run_nnapi:
                # Only test that the model was converted successfully.
                return
            eager_output = module(*args)
            nnapi_output = nnapi_module(*args)
            kwargs = {}
            if atol_rtol is not None:
                kwargs["atol"] = atol_rtol[0]
                kwargs["rtol"] = atol_rtol[1]
            self.assertEqual(eager_output, nnapi_output, **kwargs)
            if limit is not None:
                mismatches = \
                    eager_output.int_repr().to(torch.int32) - \
                    nnapi_output.int_repr().to(torch.int32)
                if mismatches.count_nonzero() > limit:
                    # Too many mismatches.  Re-run the check with no tolerance
                    # to get a nice message.
                    self.assertEqual(eager_output, nnapi_output, atol=0, rtol=0)

    def float_and_quant_and_nhwc(self, inp_float, scale, zero_point):
        """Return ``(name, tensor)`` variants of ``inp_float`` for subtests.

        The variants are: the float input, its NHWC-tagged copy, the input
        quantized with ``scale``/``zero_point`` (quint8, via ``qpt``), and
        the NHWC-tagged copy of that quantized tensor.

        Also re-seeds the global RNG; this does not affect ``inp_float``,
        which the caller has already created.
        """
        torch.manual_seed(29)
        # Previously hard-coded to (0.03, 128), silently ignoring the
        # caller-supplied quantization parameters.
        inp_quant = qpt(inp_float, scale, zero_point)
        return [
            ("float", inp_float),
            ("float-nhwc", nhwc(inp_float)),
            ("quant", inp_quant),
            ("quant-nhwc", nhwc(inp_quant)),
        ]

    def test_prelu(self):
        arg = torch.tensor([[1.0, -1.0, 2.0, -2.0]]).unsqueeze(-1).unsqueeze(-1)
        single_a = torch.nn.PReLU()
        self.check(single_a, arg)
        multi_a = torch.nn.PReLU(4)
        with torch.no_grad():
            multi_a.weight.copy_(torch.tensor([.1, .2, .3, .4]))
        self.check(multi_a, nhwc(arg))

        # Test flexible size
        self.check(
            multi_a,
            arg,
            trace_args=[torch.zeros(1, 4, 3, 3)],
            convert_args=[nhwc(torch.zeros(1, 4, 0, 0))],
        )

    def test_quantize(self):
        self.check(
            torch.nn.quantized.Quantize(0.25, 2, torch.quint8),
            nhwc(torch.tensor([[[[1.0]], [[2.0]]]])))

    def test_dequantize(self):
        self.check(
            torch.nn.quantized.DeQuantize(),
            nhwc(qpt([[[[1.0]], [[2.0]]]], 0.25, 2)))

    def test_unsqueeze(self):
        class UnsqueezeModule(torch.nn.Module):
            def __init__(self, dim):
                super().__init__()
                self.dim = dim

            def forward(self, arg):
                return arg.unsqueeze(self.dim)

        self.check(UnsqueezeModule(-2), torch.randn(4, 2, 2))
        self.check(UnsqueezeModule(-1), torch.randn(4, 2, 2))
        self.check(UnsqueezeModule(0), torch.randn(4, 2, 2))
        self.check(UnsqueezeModule(1), torch.randn(4, 2, 2))
        self.check(UnsqueezeModule(2), torch.randn(4, 2, 2))

    def test_reshape(self):
        class ReshapeModule(torch.nn.Module):
            def __init__(self, shape):
                super().__init__()
                self.shape = shape

            def forward(self, arg):
                return arg.reshape(self.shape)

        self.check(
            ReshapeModule((2, 4)),
            torch.randn(4, 2, 1, 1))

        self.check(
            ReshapeModule((8, -1)),
            nhwc(torch.randn(4, 2, 1, 1)))

        with self.assertRaisesRegex(Exception, "target size"):
            self.check(
                ReshapeModule((2, 4)),
                nhwc(torch.randn(4, 2, 1, 1)))

    def test_cat(self):
        class CatModule(torch.nn.Module):
            def __init__(self, dim):
                super().__init__()
                self.dim = dim

            def forward(self, t1, t2):
                return torch.cat([t1, t2], self.dim)

        self.check(
            CatModule(0),
            [
                torch.randn(1, 2, 3, 3),
                torch.randn(2, 2, 3, 3),
            ])

        self.check(
            CatModule(1),
            [
                torch.randn(1, 2, 3, 3),
                torch.randn(1, 4, 3, 3),
            ])

        self.check(
            CatModule(1),
            [
                nhwc(torch.randn(1, 2, 3, 3)),
                nhwc(torch.randn(1, 4, 3, 3)),
            ])

    def test_pointwise_unary(self):
        for op in ["relu", "sigmoid"]:
            with self.subTest(op):
                class UnaryModule(torch.nn.Module):
                    def forward(self, arg):
                        if op == "relu":
                            return torch.nn.functional.relu(arg)
                        if op == "sigmoid":
                            return torch.sigmoid(arg)
                        raise Exception("Bad op")
                self.check(UnaryModule(), torch.tensor([-1.0, 1.0]))

    def test_pointwise_binary(self):
        for op in ["add", "sub", "mul"]:
            with self.subTest(op):
                class BinaryModule(torch.nn.Module):
                    def forward(self, lhs, rhs):
                        if op == "add":
                            return lhs + rhs
                        if op == "sub":
                            return lhs - rhs
                        if op == "mul":
                            return lhs * rhs
                        raise Exception("Bad op")

                self.check(
                    BinaryModule(),
                    [
                        torch.tensor([1.0, 2.0]),
                        torch.tensor([3.0, 4.0]),
                    ])

                self.check(
                    BinaryModule(),
                    [
                        torch.tensor([[1.0, 2.0]]),
                        torch.tensor([[3.0, 4.0], [5.0, 6.0]]),
                    ])

                with self.assertRaisesRegex(Exception, "Non-equal-rank broadcast"):
                    self.check(
                        BinaryModule(),
                        [
                            torch.tensor([1.0, 2.0]),
                            torch.tensor([[3.0, 4.0], [5.0, 6.0]]),
                        ])

    def test_hardtanh(self):
        inp = torch.tensor([-2.0, -0.5, 0.5, 2.0, 7.0])
        self.check(torch.nn.Hardtanh(), inp)
        self.check(torch.nn.Hardtanh(0.0, 6.0), inp)
        with self.assertRaisesRegex(Exception, "hardtanh with args"):
            self.check(torch.nn.Hardtanh(0.0, 5.0), inp)

    def test_mean(self):
        class MeanModule(torch.nn.Module):
            def __init__(self, dim, keep=False):
                super().__init__()
                self.dim = dim
                self.keep = keep

            def forward(self, t):
                return torch.mean(t, dim=self.dim, keepdim=self.keep)

        self.check(MeanModule(0), torch.randn(2, 3))
        self.check(MeanModule(1), torch.randn(2, 3))
        self.check(MeanModule([2, 3]), torch.randn(2, 3, 6, 6))
        self.check(MeanModule([2, 3]), nhwc(torch.randn(2, 3, 6, 6)))
        self.check(MeanModule([-1, -2]), nhwc(torch.randn(2, 3, 6, 6)))
        self.check(MeanModule([-1, -2], keep=True), nhwc(torch.randn(2, 3, 6, 6)))

    def test_max_pool2d(self):
        for (name, inp) in self.float_and_quant_and_nhwc(torch.randn(2, 3, 12, 16), 0.3, 128):
            with self.subTest(name):
                self.check(torch.nn.MaxPool2d(2), inp)
                self.check(torch.nn.MaxPool2d((3, 4)), inp)
                self.check(torch.nn.MaxPool2d((3, 4), (1, 2)), inp)

    def test_adaptive_avg_pool2d(self):
        for (name, inp) in self.float_and_quant_and_nhwc(torch.randn(2, 3, 12, 16), 0.3, 128):
            with self.subTest(name):
                self.check(torch.nn.AdaptiveAvgPool2d((1, 1)), inp)
                with self.assertRaisesRegex(Exception, "with output size"):
                    self.check(torch.nn.AdaptiveAvgPool2d((2, 2)), inp)

    def test_upsample_nearest2d(self):
        # Zero-sized H/W variants used as flexible-size conversion inputs.
        convert_args = dict(self.float_and_quant_and_nhwc(torch.randn(2, 3, 0, 0), 0.3, 128))
        for (name, inp) in self.float_and_quant_and_nhwc(torch.randn(2, 3, 12, 16), 0.3, 128):
            with self.subTest(name):
                self.check(torch.nn.UpsamplingNearest2d(size=(16, 20)), inp)
                self.check(torch.nn.UpsamplingNearest2d(size=(24, 32)), inp)
                self.check(torch.nn.UpsamplingNearest2d(size=(36, 48)), inp)
                self.check(torch.nn.UpsamplingNearest2d(scale_factor=(1.5, 1.5)), inp)
                self.check(torch.nn.UpsamplingNearest2d(scale_factor=(2.0, 2.0)), inp)
                self.check(torch.nn.UpsamplingNearest2d(scale_factor=(3.0, 3.0)), inp)

                self.check(
                    torch.nn.UpsamplingNearest2d(size=(24, 32)), inp,
                    convert_args=[convert_args[name]]
                )
                self.check(
                    torch.nn.UpsamplingNearest2d(scale_factor=(2.0, 2.0)), inp,
                    convert_args=[convert_args[name]]
                )

    def test_linear(self):
        torch.manual_seed(29)
        self.check(torch.nn.Linear(16, 32), torch.randn(2, 16))

    def test_conv2d(self):
        cases = [
            # in_ch, out_ch, kernel, stride, padding, groups, bias, input_dim, name
            ( 4,  8, (3, 3), 1, 0, 1, 1, (2, 4, 16, 16), "3x3"),        # noqa: E201,E241
            ( 4,  8, (3, 3), 1, 0, 1, 0, (2, 4, 16, 16), "3x3nobias"),  # noqa: E201,E241
            ( 4, 16, (3, 3), 1, 1, 1, 1, (2, 4, 16, 16), "3x3p1"),      # noqa: E201,E241
            ( 8,  8, (3, 3), 2, 0, 1, 1, (2, 8, 16, 16), "3x3s2"),      # noqa: E201,E241
            ( 4,  8, (5, 5), 1, 0, 1, 1, (2, 4, 16, 16), "5x5"),        # noqa: E201,E241
            ( 4,  4, (3, 3), 1, 0, 4, 1, (2, 4, 16, 16), "3x3dw"),      # noqa: E201,E241
            ( 8,  4, (1, 1), 1, 0, 1, 1, (2, 8, 16, 16), "1x1"),        # noqa: E201,E241
        ]

        for kind in ["float", "float-nhwc", "quant", "quant-nhwc"]:
            for case in cases:
                in_ch, out_ch, kernel, stride, padding, groups, bias, input_dim, name = case
                with self.subTest("{}-{}".format(kind, name)):
                    inp = torch.randn(input_dim)
                    model = torch.nn.Conv2d(in_ch, out_ch, kernel, stride, padding, groups=groups, bias=bool(bias))
                    output_size = model(inp).numel()
                    atol_rtol = None
                    limit = None
                    # Same batch/channels, 0-sized H/W: flexible-size conversion input.
                    convert_dims = input_dim[:2] + (0, 0)
                    convert_arg = torch.zeros(*convert_dims)

                    if "quant" in kind:
                        model = torch.nn.Sequential(model)
                        model.eval()
                        model.qconfig = torch.quantization.get_default_qconfig('qnnpack')
                        model = torch.quantization.prepare(model)
                        model(inp)
                        model = torch.quantization.convert(model)
                        inp = qpt(inp, 1.0 / 16, 128)
                        # I've seen numerical differences between QNNPACK and NNAPI,
                        # but never more than 1 quantum, and never more than ~1% of
                        # the output in this test.
                        atol_rtol = (1, 0)
                        limit = output_size * 0.03
                        convert_arg = qpt(torch.zeros(*convert_dims), 1.0 / 16, 128)

                    if "nhwc" in kind:
                        inp = nhwc(inp)
                        convert_arg = nhwc(convert_arg)

                    self.check(model, inp, atol_rtol=atol_rtol, limit=limit)
                    self.check(
                        model,
                        inp,
                        convert_args=[convert_arg],
                        atol_rtol=atol_rtol,
                        limit=limit
                    )

    def test_qadd(self):
        func = torch.nn.quantized.QFunctional()
        func.scale = 0.5
        func.zero_point = 120

        class AddMod(torch.nn.Module):
            def forward(self, lhs, rhs):
                return func.add(lhs, rhs)

        class AddReluMod(torch.nn.Module):
            def forward(self, lhs, rhs):
                return func.add_relu(lhs, rhs)

        for (name, mod) in [("add", AddMod), ("add_relu", AddReluMod)]:
            with self.subTest(name):
                self.check(
                    mod(),
                    [
                        qpt([1.0, 2.0], 0.25, 128),
                        qpt([3.0, 4.0], 0.25, 128),
                    ])
                self.check(
                    mod(),
                    [
                        qpt([[1.0, 2.0]], 0.25, 128),
                        qpt([[3.0, 4.0]], 0.25, 128),
                    ],
                    convert_args=[
                        qpt([[1.0, 2.0]], 0.25, 128),
                        qpt(torch.zeros((1, 2)), 0.25, 128),
                    ]
                )
                self.check(
                    mod(),
                    [
                        qpt([[1.0, 2.0]], 0.25, 128),
                        qpt([[3.0, 4.0]], 0.25, 128),
                    ],
                    convert_args=[
                        qpt(torch.zeros((1, 2)), 0.25, 128),
                        qpt([[3.0, 4.0]], 0.25, 128),
                    ]
                )
                self.check(
                    mod(),
                    [
                        qpt([[1.0, 2.0]], 0.25, 128),
                        qpt([[3.0, 4.0]], 0.25, 128),
                    ],
                    convert_args=[
                        qpt(torch.zeros((1, 2)), 0.25, 128),
                        qpt(torch.zeros((1, 2)), 0.25, 128),
                    ]
                )
            # NOTE: NNAPI qadd supports broadcast, but PT does not.

    def test_qlinear(self):
        torch.manual_seed(29)
        weight = qpt(torch.randn(16, 32), 0.125, 0, torch.qint8)
        bias = torch.randn(16)
        mod = torch.nn.quantized.Linear(32, 16)
        mod.set_weight_bias(weight, bias)
        inp = qpt(torch.randn(2, 32), 0.05, 130, torch.quint8)
        self.check(mod, inp)

    def test_seblock_mul(self):
        # Multiplies an NHWC 4-D tensor by a (1, C, 1, 1) tensor, broadcasting
        # over batch and spatial dims.
        class MulModel(torch.nn.Module):
            def forward(self, lhs, rhs):
                return lhs * rhs

        self.check(
            MulModel(),
            [
                nhwc(torch.randn(2, 3, 4, 4)),
                torch.randn(1, 3, 1, 1),
            ])

    def test_multi_output(self):
        class MultiModel(torch.nn.Module):
            def forward(self, lhs, rhs) -> Tuple[torch.Tensor, torch.Tensor]:
                the_sum = lhs + rhs
                the_diff = lhs - rhs
                return the_sum, the_diff

        self.check(MultiModel(), [torch.tensor([1.0, 2.0]), torch.tensor([1.0, 3.0])])
if __name__ == '__main__':
    # Running this file directly hands control to PyTorch's shared test
    # runner; importing it (e.g. from a test collector) runs nothing here.
    run_tests()