#!/usr/bin/env python3
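"""Export a JSON file of slow test cases in the PyTorch unit test suite.

Average test-case times are computed from previous S3 stats for
origin/viable/strict; cases at or above SLOW_TEST_CASE_THRESHOLD_SEC are
written out. Example invocation (a sketch; run from wherever the `tools`
package is importable):

    python export_slow_tests.py --ignore-small-diffs 0.1
"""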

import argparse
import json
import os
import statistics
from collections import defaultdict
from tools.stats_utils.s3_stat_parser import (
    get_previous_reports_for_branch,
    Report,
    Version2Report,
)
from typing import cast, DefaultDict, Dict, List, Any
from urllib.request import urlopen

SLOW_TESTS_FILE = '.pytorch-slow-tests.json'
SLOW_TEST_CASE_THRESHOLD_SEC = 60.0
RELATIVE_DIFFERENCE_THRESHOLD = 0.1
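
# For reference: with these defaults, a test case whose mean runtime is 60
# seconds or more is exported as "slow", and a regenerated file is treated as
# unchanged if every test's time is within 10% of the published copy.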


def get_test_case_times() -> Dict[str, float]:
    reports: List[Report] = get_previous_reports_for_branch('origin/viable/strict', "")
    # An entry will look like: "test_doc_examples (__main__.TestTypeHints)" -> [list of times]
    test_names_to_times: DefaultDict[str, List[float]] = defaultdict(list)
    for report in reports:
        if report.get('format_version', 1) != 2:  # type: ignore[misc]
            raise RuntimeError("S3 format currently handled is version 2 only")
        v2report = cast(Version2Report, report)
        for test_file in v2report['files'].values():
            for suitename, test_suite in test_file['suites'].items():
                for casename, test_case in test_suite['cases'].items():
                    # The below attaches a __main__ as that matches the format of test.__class__ in
                    # common_utils.py (where this data will be used), and also matches what the output
                    # of a running test would look like.
                    name = f'{casename} (__main__.{suitename})'
                    succeeded: bool = test_case['status'] is None
                    if succeeded:
                        test_names_to_times[name].append(test_case['seconds'])
    return {test_case: statistics.mean(times) for test_case, times in test_names_to_times.items()}
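
# Illustrative example (hypothetical data): if a test passed in three recent
# reports with times [55.0, 65.0, 60.0], it is recorded as
#   'test_foo (__main__.TestBar)' -> 60.0  (the mean); failed runs are skipped.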


def filter_slow_tests(test_cases_dict: Dict[str, float]) -> Dict[str, float]:
    return {test_case: time for test_case, time in test_cases_dict.items() if time >= SLOW_TEST_CASE_THRESHOLD_SEC}
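
# E.g. (hypothetical values), with the 60-second threshold:
#   filter_slow_tests({'test_a (__main__.TestFoo)': 75.3, 'test_b (__main__.TestFoo)': 2.1})
#   == {'test_a (__main__.TestFoo)': 75.3}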


def get_test_infra_slow_tests() -> Dict[str, float]:
    url = "https://raw.githubusercontent.com/pytorch/test-infra/master/stats/slow-tests.json"
    contents = urlopen(url, timeout=1).read().decode('utf-8')
    return cast(Dict[str, float], json.loads(contents))
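
# The fetched file is expected to have the same shape as this script's output:
# a JSON object mapping test case names to times in seconds (an assumption
# implied by the key/value comparison in too_similar below).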


def too_similar(calculated_times: Dict[str, float], other_times: Dict[str, float], threshold: float) -> bool:
    # check that their keys are the same
    if calculated_times.keys() != other_times.keys():
        return False
    for test_case, test_time in calculated_times.items():
        other_test_time = other_times[test_case]
        relative_difference = abs((other_test_time - test_time) / max(other_test_time, test_time))
        if relative_difference > threshold:
            return False
    return True
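
# Worked example (hypothetical times): for a calculated 55.0s vs. a published
# 60.0s, the relative difference is |60.0 - 55.0| / max(60.0, 55.0) ~= 0.083,
# which is under the default 0.1 threshold, so that entry counts as similar.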


def export_slow_tests(options: Any) -> None:
    filename = options.filename
    if os.path.exists(filename):
        print(f'Overwriting existing file: {filename}')
    with open(filename, 'w+') as file:
        slow_test_times: Dict[str, float] = filter_slow_tests(get_test_case_times())
        if options.ignore_small_diffs:
            test_infra_slow_tests_dict = get_test_infra_slow_tests()
            if too_similar(slow_test_times, test_infra_slow_tests_dict, options.ignore_small_diffs):
                slow_test_times = test_infra_slow_tests_dict
        json.dump(slow_test_times, file, indent=' ', separators=(',', ': '), sort_keys=True)
        file.write('\n')
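
# Note: keys are sorted and the JSON is pretty-printed, so regenerating the
# file produces a deterministic layout (and hence clean diffs) even when the
# measured times themselves change.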


def parse_args() -> argparse.Namespace:
    parser = argparse.ArgumentParser(
        description='Export a JSON of slow test cases in the PyTorch unit test suite')
    parser.add_argument(
        '-f',
        '--filename',
        nargs='?',
        type=str,
        default=SLOW_TESTS_FILE,
        const=SLOW_TESTS_FILE,
        help='Specify a file path to dump slow test times from previous S3 stats. Default file path: '
             '.pytorch-slow-tests.json',
    )
    parser.add_argument(
        '--ignore-small-diffs',
        nargs='?',
        type=float,
        const=RELATIVE_DIFFERENCE_THRESHOLD,
        help='Compares generated results with stats/slow-tests.json in pytorch/test-infra. If the relative '
             'differences between test times for each test are smaller than the threshold and the set of test '
             'cases has not changed, we will export the stats already in stats/slow-tests.json. Else, we will '
             'export the calculated results. The default threshold is 10%%.',
    )
    return parser.parse_args()


def main() -> None:
    options = parse_args()
    export_slow_tests(options)


if __name__ == '__main__':
    main()