-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathbatch_checkpoints.py
51 lines (40 loc) · 1.6 KB
/
batch_checkpoints.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
# import json
import logging
import sys
import glob
import daisy
import copy
sys.path.insert(0, 'segway/tasks')
import task_helper
from task_04_extract_segmentation import SegmentationTask
import os

logger = logging.getLogger(__name__)

# Checkpoints are discovered through their data file, whose name is assumed to
# look like '<prefix>_<iteration>.data-00000-of-00001' (TensorFlow-style
# checkpoint naming -- TODO confirm against the training code).
CHECKPOINT_SUFFIX = ".data-00000-of-00001"


def checkpoint_iteration(path):
    """Return the training iteration encoded in a checkpoint file path.

    The iteration is the last '_'-separated token of the file's base name,
    up to its first '.', e.g.
    '.../train_net_checkpoint_200000.data-00000-of-00001' -> 200000.
    Only the base name is parsed, so dots or underscores in the directory
    part of *path* (e.g. './setup01' or 'run.1_a') cannot corrupt the result.

    Raises:
        ValueError: if that token is not an integer.
    """
    stem = os.path.basename(path).split('.')[0]
    return int(stem.split('_')[-1])


def find_checkpoints(train_dir, min_iteration=None, max_iteration=None):
    """Return the iterations of all checkpoints found in *train_dir*, ascending.

    Args:
        train_dir: directory searched (non-recursively) for checkpoint files.
        min_iteration: if not None, drop iterations below this value.
        max_iteration: if not None, drop iterations above this value.
            Both bounds are inclusive.
    """
    # Escape the directory so glob metacharacters in it ('[', '*', '?') are
    # matched literally; only the file-name wildcard acts as a pattern.
    pattern = os.path.join(glob.escape(train_dir), "*" + CHECKPOINT_SUFFIX)
    # Sort so inference runs in training order, not filesystem order.
    iterations = sorted(checkpoint_iteration(p) for p in glob.glob(pattern))
    return [
        it for it in iterations
        if (min_iteration is None or it >= min_iteration)
        and (max_iteration is None or it <= max_iteration)]


def iteration_bounds(network_conf):
    """Return the (min, max) iteration filter from the 'Network' config section.

    Returns (None, None) when neither 'batch_min_iteration' nor
    'batch_max_iteration' is configured.

    Raises:
        ValueError: if only one of the two keys is present; they must be
            configured together.
    """
    has_min = "batch_min_iteration" in network_conf
    has_max = "batch_max_iteration" in network_conf
    if has_min != has_max:
        # Explicit check instead of `assert`, which is stripped under -O.
        raise ValueError(
            "batch_min_iteration and batch_max_iteration must be set together")
    if not has_min:
        return None, None
    return (network_conf["batch_min_iteration"],
            network_conf["batch_max_iteration"])


if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)

    user_configs, global_config = task_helper.parseConfigs(
        sys.argv[1:], aggregate_configs=False)
    # aggregateConfigs() below appears to modify its argument in place (its
    # return value is unused), so every checkpoint starts from a fresh deep
    # copy of the un-aggregated config.
    orig_global_config = copy.deepcopy(global_config)

    network_conf = global_config["Network"]
    train_dir = network_conf["train_dir"]
    min_iteration, max_iteration = iteration_bounds(network_conf)
    checkpoints = find_checkpoints(train_dir, min_iteration, max_iteration)
    if not checkpoints:
        logger.warning("No checkpoints found in %s", train_dir)

    for iteration in checkpoints:
        global_config = copy.deepcopy(orig_global_config)
        logger.info("Running inference for iteration %s", iteration)
        global_config["Network"]["iteration"] = iteration
        task_helper.aggregateConfigs(global_config)
        logger.info("Config for iteration %s: %s", iteration, global_config)
        daisy.distribute(
            [{'task': SegmentationTask(global_config=global_config,
                                       **user_configs),
              'request': None}],
            global_config=global_config)