forked from jjfeng/spinn
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathparallel_worker.py
88 lines (78 loc) · 2.48 KB
/
parallel_worker.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
import sys
import os
import traceback
import numpy as np
import tensorflow as tf
class ParallelWorker:
    """
    Stores the information for running a task in parallel.
    These workers can be run through the ParallelWorkerManager.

    Subclasses must implement `__init__` (which should store `self.seed`),
    `run_worker`, and `__str__`.
    """
    def __init__(self, seed):
        """
        @param seed: a seed for each parallel worker

        Abstract: subclasses must store the seed as `self.seed` since
        `run` reads it.
        """
        raise NotImplementedError()

    def run(self):
        """
        Do not implement this function!

        Seeds numpy and tensorflow from `self.seed` for reproducibility,
        then delegates to `run_worker`.

        @return: the value from `run_worker`, or None if it raised
        """
        np.random.seed(self.seed)
        tf.set_random_seed(self.seed)
        result = None
        try:
            result = self.run_worker()
        except Exception as e:
            # Interpolate with %: previously the format string and `e` were
            # passed to print as two arguments, so "%s" printed literally.
            print("Exception caught in parallel worker: %s" % e)
            traceback.print_exc()
        return result

    def run_worker(self):
        """
        Implement this function!
        Returns whatever value needed from this task
        """
        raise NotImplementedError()

    def __str__(self):
        """
        @return: string for identifying this worker in an error
        """
        raise NotImplementedError()
class ParallelWorkerManager:
    """
    Abstract interface for executing a collection of ParallelWorkers.
    """

    def run(self):
        """Execute all workers; concrete managers must override this."""
        raise NotImplementedError()
class MultiprocessingManager(ParallelWorkerManager):
    """
    Handles submitting jobs to a multiprocessing pool
    We have written our own custom function for batching jobs together
    So runs ParallelWorkers using multiple CPUs on the same machine
    """
    def __init__(self, pool, worker_list):
        """
        @param pool: a multiprocessing pool used to dispatch the workers
        @param worker_list: List of ParallelWorkers
        """
        self.pool = pool
        self.worker_list = worker_list

    def run(self):
        """
        Run every worker on the pool; if the pool itself fails, fall back
        to running the workers serially in this process.

        @return: list of the non-None worker results (a None result means
                 that worker raised inside ParallelWorker.run and is dropped
                 with a warning)
        """
        try:
            results_raw = self.pool.map(run_multiprocessing_worker, self.worker_list)
        except Exception as e:
            # Interpolate with %: previously the format string and `e` were
            # passed to print as two arguments, so "%s" printed literally.
            print("Error occured when trying to process workers in parallel: %s" % e)
            # Just do it all one at a time instead
            results_raw = map(run_multiprocessing_worker, self.worker_list)
        results = []
        for i, r in enumerate(results_raw):
            if r is None:
                print("WARNING: multiprocessing worker for this worker failed: %s" % self.worker_list[i])
            else:
                results.append(r)
        return results
def run_multiprocessing_worker(worker):
    """
    Entry point executed for each worker by MultiprocessingManager.

    Note: this must stay a module-level (global) function so that the
    multiprocessing machinery can pickle and dispatch it to pool processes.

    @param worker: ParallelWorker
    @return: whatever worker.run() produces
    """
    result = worker.run()
    return result