-
Notifications
You must be signed in to change notification settings - Fork 5
/
Copy pathengine.py
150 lines (134 loc) · 5.15 KB
/
engine.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
from tester import TensorConfig, APIConfig, analyse_configs
from tester import APITestAccuracy, APITestPaddleOnly, APITestCINNVSDygraph
import re
import collections
import paddle
import numpy
import math
import json
import paddle
import inspect
import argparse
import subprocess
import os
from datetime import datetime
DIR_PATH = os.path.dirname(os.path.realpath(__file__))[0:os.path.dirname(os.path.realpath(__file__)).index("PaddleAPITest")+13]
def get_notsupport_config():
not_support_files = [
"tester/api_config/api_config_merged_not_support_amp.txt",
"tester/api_config/api_config_merged_not_support_arange.txt",
"tester/api_config/api_config_merged_not_support_empty.txt",
"tester/api_config/api_config_merged_not_support_eye.txt",
"tester/api_config/api_config_merged_not_support_flatten.txt",
"tester/api_config/api_config_merged_not_support_full.txt",
"tester/api_config/api_config_merged_not_support_getset_item.txt",
"tester/api_config/api_config_merged_not_support_reshape.txt",
"tester/api_config/api_config_merged_not_support_slice.txt",
"tester/api_config/api_config_merged_not_support_sparse.txt",
"tester/api_config/api_config_merged_not_support_tensor_init.txt",
"tester/api_config/api_config_merged_not_support_topk.txt",
"tester/api_config/api_config_merged_not_support_zeros.txt",
"tester/api_config/api_config_merged_not_support.txt",
]
configs = set()
for flie in not_support_files:
with open(DIR_PATH+"/"+flie, "r") as f:
origin_configs = f.readlines()
f.close()
for config in origin_configs:
configs.add(config)
return configs
def main():
parser = argparse.ArgumentParser(
description='API Test'
)
parser.add_argument(
'--api_config_file',
default="",
)
parser.add_argument(
'--api_config',
default="",
)
parser.add_argument(
'--paddle_only',
default=False,
)
parser.add_argument(
'--paddle_cinn',
default=False,
)
parser.add_argument(
'--accuracy',
default=False,
)
parser.add_argument(
'--test_amp',
default=False,
)
options = parser.parse_args()
test_class = APITestAccuracy
if options.paddle_only:
test_class = APITestPaddleOnly
elif options.paddle_cinn:
test_class = APITestCINNVSDygraph
elif options.accuracy:
test_class = APITestAccuracy
if options.api_config != "":
print("test begin:", options.api_config, flush=True)
try:
api_config = APIConfig(options.api_config)
except Exception as err:
print("[config parse error]", options.api_config, str(err))
return
case = test_class(api_config, options.test_amp)
case.test()
case.clear_tensor()
del case
del api_config
elif options.api_config_file != "":
try:
checkpoint_r = open(DIR_PATH+"/tester/api_config/test_log/checkpoint.txt", "r")
finish_configs = set(checkpoint_r.readlines())
checkpoint_r.close()
except Exception as err:
finish_configs = set()
not_support_api_config = open(DIR_PATH+"/tester/api_config/api_config_merged_not_support.txt", "r")
not_support_api_config = set(not_support_api_config.readlines())
checkpoint = open(DIR_PATH+"/tester/api_config/test_log/checkpoint.txt", "a")
api_config_file = open(options.api_config_file, "r")
api_configs = set(api_config_file.readlines())
api_configs = api_configs - finish_configs - not_support_api_config
api_configs = sorted(api_configs)
for api_config_str in api_configs:
checkpoint.write(api_config_str)
checkpoint.flush()
print(datetime.now(), "test begin:", api_config_str, flush=True)
try:
api_config = APIConfig(api_config_str)
except Exception as err:
print("[config parse error]", api_config_str, str(err))
continue
case = test_class(api_config, options.test_amp)
try:
case.test()
except Exception as err:
if "CUDA error" in str(err) or "memory corruption" in str(err) or "CUDA out of memory" in str(err):
exit(0)
case.clear_tensor()
del case
del api_config
# elif options.api_config_file != "":
# with open(options.api_config_file, "r") as f:
# configs = f.readlines()
# f.close()
# for config in configs:
# config = config.replace("\n", "")
# cmd = ["python", "engine.py", "--api_config=" + config]
# res = subprocess.Popen(cmd,stdout=subprocess.PIPE,stderr=subprocess.PIPE)
# # res.wait()
# print(str(res.stdout.read(), encoding="utf-8"), flush=True)
# print(str(res.stderr.read(), encoding="utf-8"), flush=True)
# # res.terminate()
if __name__ == '__main__':
main()