# search.py (forked from uma-pi1/kge)
import copy
import concurrent.futures
import gc
import os

import torch.multiprocessing

from kge.job import Job, Trace
from kge.config import _process_deprecated_options
from kge.util.io import get_checkpoint_file, load_checkpoint
from kge.util.metric import Metric
from kge.misc import init_from

class SearchJob(Job):
    """Base class of jobs for hyperparameter search.

    Provides functionality for scheduling training jobs across workers.

    """

    def __init__(self, config, dataset, parent_job=None):
        super().__init__(config, dataset, parent_job)

        # create data structures for parallel job submission
        self.num_workers = self.config.get("search.num_workers")
        self.device_pool = self.config.get("search.device_pool")
        if len(self.device_pool) == 0:
            self.device_pool = [self.config.get("job.device")]
        if len(self.device_pool) < self.num_workers:
            self.device_pool = self.device_pool * self.num_workers
        self.device_pool = self.device_pool[: self.num_workers]
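        # Worked example (illustrative, not in the original source): with
        # device_pool = ["cuda:0", "cuda:1"] and num_workers = 3, the pool is
        # first tiled to six entries and then truncated to
        # ["cuda:0", "cuda:1", "cuda:0"], i.e. one device slot per worker.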
self.config.log("Using device pool: {}".format(self.device_pool))
self.free_devices = copy.deepcopy(self.device_pool)
self.on_error = self.config.check("search.on_error", ["abort", "continue"])
self.running_tasks = set() #: set of futures currently runnning
self.ready_task_results = list() #: set of results
if self.num_workers > 1:
self.process_pool = concurrent.futures.ProcessPoolExecutor(
max_workers=self.num_workers,
mp_context=torch.multiprocessing.get_context("spawn"),
)
else:
self.process_pool = None # marks that we run in single process
self.config.check_range("valid.every", 1, config.get("train.max_epochs"))
if self.__class__ == SearchJob:
for f in Job.job_created_hooks:
f(self)

    @staticmethod
    def create(config, dataset, parent_job=None):
        """Factory method to create a search job."""
        search_type = config.get("search.type")
        class_name = config.get_default(f"{search_type}.class_name")
        return init_from(class_name, config.modules(), config, dataset, parent_job)
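    # Illustrative note (not in the original source): dispatch is config
    # driven. For example, with search.type set to "ax_search", class_name is
    # read from the "ax_search.class_name" default and init_from instantiates
    # that class (e.g., AxSearchJob) with the arguments given above.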

    def submit_task(self, task, task_arg, wait_when_full=True):
        """Runs the given task with the given argument.

        When ``search.num_workers`` is 1, blocks and runs the task
        synchronously. Otherwise, schedules the task at a free worker. If no
        worker is free, either waits (``wait_when_full`` true) or raises an
        error (``wait_when_full`` false).

        In addition to ``task_arg``, the task is given a keyword argument
        ``device``, holding the device on which it should run.

        """
        if self.process_pool is None:
            self.ready_task_results.append(
                task(task_arg, device=self.free_devices[0])
            )
        else:
            if len(self.running_tasks) >= self.num_workers:
                if wait_when_full:
                    self.config.log("No more free workers.")
                    self.wait_task()
                else:
                    raise ValueError("no more free workers for running the task")
            task_device = self.free_devices.pop(0)
            future = self.process_pool.submit(task, task_arg, device=task_device)
            future.add_done_callback(lambda _: self.free_devices.append(task_device))
            self.running_tasks.add(future)
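    # Illustrative sketch (not in the original source): a subclass typically
    # drives the pool like this, where `trial_configs` and `evaluate_trial`
    # are hypothetical names; `evaluate_trial` must be a picklable
    # module-level callable that accepts a `device` keyword argument:
    #
    #   for i, cfg in enumerate(trial_configs):
    #       self.submit_task(evaluate_trial, (self, i, cfg))  # may block
    #   while self.running_tasks:
    #       self.wait_task()  # drain remaining workers
    #   results = self.ready_task_results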

    def wait_task(self, return_when=concurrent.futures.FIRST_COMPLETED):
        """Waits for one or more running tasks to complete.

        Results of all completed tasks are copied into
        ``self.ready_task_results``. When no task is running, does nothing.

        """
        if len(self.running_tasks) > 0:
            self.config.log("Waiting for tasks to complete...")
            ready_tasks, self.running_tasks = concurrent.futures.wait(
                self.running_tasks, return_when=return_when
            )
            for task in ready_tasks:
                self.ready_task_results.append(task.result())

    # Overridden such that instances of search job can be pickled to workers
    def __getstate__(self):
        state = dict(self.__dict__)
        del state["process_pool"]
        del state["running_tasks"]
        return state
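

# Illustrative note (not in the original source): `sicnk` below is the packed
# argument tuple (search_job, train_job_index, train_job_config,
# train_job_count, trace_keys); callers hand it to SearchJob.submit_task,
# e.g. (hypothetical call):
#
#   self.submit_task(_run_train_job, (self, i, trial_config, n_trials, keys))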
def _run_train_job(sicnk, device=None):
    """Runs a training job and returns the trace entry of its best validation result.

    Also takes care of appropriate tracing.

    """
    search_job, train_job_index, train_job_config, train_job_count, trace_keys = sicnk

    try:
        # load the job
        if device is not None:
            train_job_config.set("job.device", device)
        search_job.config.log(
            "Starting training job {} ({}/{}) on device {}...".format(
                train_job_config.folder,
                train_job_index + 1,
                train_job_count,
                train_job_config.get("job.device"),
            )
        )
        checkpoint_file = get_checkpoint_file(train_job_config)
        if checkpoint_file is not None:
            checkpoint = load_checkpoint(
                checkpoint_file, train_job_config.get("job.device")
            )
            job = Job.create_from(
                checkpoint=checkpoint,
                new_config=train_job_config,
                dataset=search_job.dataset,
                parent_job=search_job,
            )
        else:
            job = Job.create(
                config=train_job_config,
                dataset=search_job.dataset,
                parent_job=search_job,
            )

        # process the trace entries so far (in case of a resumed job)
        metric_name = search_job.config.get("valid.metric")
        valid_trace = []

        def copy_to_search_trace(job, trace_entry=None):
            if trace_entry is None:
                trace_entry = job.valid_trace[-1]
            trace_entry = copy.deepcopy(trace_entry)
            for key in trace_keys:
                # Process deprecated options to some extent. Supports key
                # renames, but not value renames.
                actual_key = {key: None}
                _process_deprecated_options(actual_key)
                if len(actual_key) > 1:
                    raise KeyError(
                        f"{key} is deprecated but cannot be handled automatically"
                    )
                actual_key = next(iter(actual_key.keys()))
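                # Illustrative note (not in the original source): if `key` was
                # renamed by a newer config version, `actual_key` now holds
                # the new name; the value is read under the new name but
                # recorded in the trace under the originally requested key.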
                value = train_job_config.get(actual_key)
                trace_entry[key] = value
            trace_entry["folder"] = os.path.split(train_job_config.folder)[1]
            metric_value = Trace.get_metric(trace_entry, metric_name)
            trace_entry["metric_name"] = metric_name
            trace_entry["metric_value"] = metric_value
            trace_entry["parent_job_id"] = search_job.job_id
            search_job.config.trace(**trace_entry)
            valid_trace.append(trace_entry)

        for trace_entry in job.valid_trace:
            copy_to_search_trace(None, trace_entry)

        # run the job (adding new trace entries as we go)
        # TODO make this less hacky (easier once integrated into SearchJob)
        from kge.job import ManualSearchJob

        if not isinstance(search_job, ManualSearchJob) or search_job.config.get(
            "manual_search.run"
        ):
            job.post_valid_hooks.append(copy_to_search_trace)
            job.run()
        else:
            search_job.config.log(
                "Skipping running of training job as requested by user."
            )
            return (train_job_index, None, None)

        # analyze the result
        search_job.config.log("Best result in this training job:")
        best = None
        best_metric = None
        for trace_entry in valid_trace:
            metric = trace_entry["metric_value"]
            if not best or Metric(search_job).better(metric, best_metric):
                best = trace_entry
                best_metric = metric

        # record the best result of this job
        best["child_job_id"] = best["job_id"]
        for k in ["job", "job_id", "type", "parent_job_id", "scope", "event"]:
            if k in best:
                del best[k]
        search_job.trace(
            event="search_completed",
            echo=True,
            echo_prefix=" ",
            log=True,
            scope="train",
            **best,
        )

        # force releasing the GPU memory of the job to avoid memory leakage
        del job
        gc.collect()

        return (train_job_index, best, best_metric)
    except BaseException as e:
        search_job.config.log(
            "Trial {:05d} failed: {}".format(train_job_index, repr(e))
        )
        if search_job.on_error == "continue":
            return (train_job_index, None, None)
        else:
            search_job.config.log(
                "Aborting search due to failure of trial {:05d}".format(
                    train_job_index
                )
            )
            raise e
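

# ---------------------------------------------------------------------------
# Illustrative usage sketch (not part of the original module). Assumes a
# LibKGE search config (job.type: search) has already been written to the
# given folder; entry-point names follow the LibKGE API but may differ by
# version, so treat this as a sketch rather than a recipe:
#
#   from kge import Config, Dataset
#
#   config = Config()
#   config.load("local/experiments/my-search/config.yaml")  # hypothetical path
#   dataset = Dataset.create(config)
#   search_job = SearchJob.create(config, dataset)
#   search_job.run()  # schedules trials via submit_task/_run_train_job
# ---------------------------------------------------------------------------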