-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathcollect_jahs.py
56 lines (43 loc) · 1.56 KB
/
collect_jahs.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
import itertools
import json
import os
from argparse import ArgumentParser

import numpy as np
import pandas as pd

from utils.api_wrapper import BenchmarkWrapper
from utils.constants import PARAM_NAMES, SEARCH_SPACE, TASK_NAMES
# Column-oriented table of the configurations sent to the benchmark in one
# query. The first two parameters only get a scalar placeholder here;
# save_results() overwrites those two columns before every query.
FIXED_CONFIGS = {name: [] for name in PARAM_NAMES}
FIXED_CONFIGS[PARAM_NAMES[0]] = 0.0
FIXED_CONFIGS[PARAM_NAMES[1]] = 0.0

# Enumerate the Cartesian product of the remaining parameters' search spaces
# and transpose it, so that each parameter receives its full column of values
# (one entry per combination, in product order).
for name, column in zip(
    PARAM_NAMES[2:],
    zip(*itertools.product(*(SEARCH_SPACE[key] for key in PARAM_NAMES[2:]))),
):
    FIXED_CONFIGS[name] = list(column)

# These two entries are pinned to a single value for every configuration.
FIXED_CONFIGS["Optimizer"] = "SGD"
FIXED_CONFIGS["Resolution"] = 1.0
def save_results(dataset_name: str) -> None:
    """Query the benchmark over the full grid and dump the accuracies to JSON.

    For every pair drawn from ``SEARCH_SPACE["LearningRate"]`` x
    ``SEARCH_SPACE["WeightDecay"]``, the ``PARAM_NAMES[0]`` and
    ``PARAM_NAMES[1]`` columns of the configuration table are set to that
    pair and the benchmark is queried for the whole table. The "valid-acc"
    results are concatenated in query order, truncated to three decimal
    places, and written to ``data/<dataset_name>.json`` as a flat list.

    Args:
        dataset_name: Task name handed to ``BenchmarkWrapper``; also used as
            the stem of the output file name.
    """
    # Create the output directory before running any query, so that a missing
    # ``data/`` folder cannot throw away the results at write time.
    os.makedirs("data", exist_ok=True)

    bench = BenchmarkWrapper(task=dataset_name)
    config_table = pd.DataFrame(FIXED_CONFIGS)
    results = []
    for ps in itertools.product(
        SEARCH_SPACE["LearningRate"],
        SEARCH_SPACE["WeightDecay"],
    ):
        print(ps)
        # NOTE(review): presumably PARAM_NAMES[0] / PARAM_NAMES[1] are
        # "LearningRate" / "WeightDecay" -- confirm against utils.constants.
        config_table[PARAM_NAMES[0]] = ps[0]
        config_table[PARAM_NAMES[1]] = ps[1]
        preds = bench(config_table)["valid-acc"]
        results.append(preds.to_numpy())

    # Keep three decimals to shrink the JSON file. The int32 cast truncates
    # toward zero (it does not round to nearest).
    divisor = 1000
    rounded_data = (
        np.asarray(np.hstack(results) * divisor, dtype=np.int32) / divisor
    )
    print(rounded_data, rounded_data.min())

    # Open the file only once the data is ready, so a failure above never
    # leaves a truncated, empty JSON file behind.
    with open(f"data/{dataset_name}.json", mode="w") as f:
        json.dump(rounded_data.tolist(), f, indent=4)
if __name__ == "__main__":
    # Command-line entry point: pick one of the known tasks (cifar10 unless
    # told otherwise) and collect its results.
    cli = ArgumentParser()
    cli.add_argument("--dataset", default="cifar10", choices=TASK_NAMES)
    save_results(cli.parse_args().dataset)