mapreduce.py
import os
import gzip
import json
import multiprocessing as mp
from pathos.pools import ProcessPool
from tqdm import tqdm
from contextlib import contextmanager
# Jsonl (GZ) handler --------------------------------------------------------------------

class JsonlSaver:
    """Writes JSON objects to rolling `statistics-<i>.jsonl(.gz)` files in `save_dir`."""

    def __init__(self, save_dir, gzip_compress=False, num_objects=1e5):
        self.save_dir = save_dir
        self.num_objects = num_objects
        self.gzip_compress = gzip_compress
        self._file_ending = ".jsonl.gz" if gzip_compress else ".jsonl"
        self.object_count = 0
        self.file_count = 0
        self.file_handler = None
        self._find_unique_index()
        self._update_handler()

    def _file_path(self):
        return os.path.join(self.save_dir, f"statistics-{self.file_count}{self._file_ending}")

    def _find_unique_index(self):
        # Skip indices that already exist on disk so earlier runs are not overwritten
        while os.path.exists(self._file_path()):
            self.file_count += 1

    def _open_file(self, file_path):
        return (gzip.open if self.gzip_compress else open)(file_path, "wb")

    def _update_handler(self):
        # Open a new file on first use or once the current file holds `num_objects` entries
        need_update = self.file_handler is None or self.object_count >= self.num_objects
        if not need_update:
            return
        if self.file_handler is not None:
            self.file_handler.close()
        self.file_handler = self._open_file(self._file_path())
        self.file_count += 1
        self.object_count = 0

    def save(self, obj):
        json_obj = json.dumps(obj) + "\n"
        self.file_handler.write(json_obj.encode("utf-8"))
        self.object_count += 1
        self._update_handler()

    def close(self):
        if self.file_handler is not None:
            self.file_handler.close()
            self.file_handler = None
@contextmanager
def jsonl_reduce_io(output_dir, compress=False):
    saver = JsonlSaver(output_dir, gzip_compress=compress)
    try:
        yield saver.save
    finally:
        saver.close()
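
# Illustrative usage sketch (assumes an existing "stats_out/" output directory):
#
#     with jsonl_reduce_io("stats_out", compress=True) as save:
#         save({"token_count": 42})   # -> stats_out/statistics-0.jsonl.gz
#
# The saver rolls over to a new statistics-<i> file after `num_objects` entries.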
# Map multiprocessing ----------------------------------------------------------------

def pmap(map_fn, data):
    cpu_count = mp.cpu_count()
    if cpu_count <= 4:  # Too few CPUs to benefit from multiprocessing
        for output in map(map_fn, data):
            yield output
    else:
        with ProcessPool(nodes=cpu_count) as pool:
            for output in pool.uimap(map_fn, data, chunksize=4 * cpu_count):
                yield output
# Helper ------------------------------------------------------------------

def _reduce_mapped_instances(mapped_instance_stream, reducer_fn):
    # Reduce all mapped instances
    for mapped_instance in _reduce_generator(mapped_instance_stream):
        reducer_fn(mapped_instance)

def _reduce_to_file(mapped_instance_stream, dir_path, compress=False):
    with jsonl_reduce_io(dir_path, compress) as saver:
        _reduce_mapped_instances(mapped_instance_stream, saver)

def _reduce_generator(mapped_instance_stream):
    # Flatten the stream of mapped collections, skipping None outputs
    for mapped_instances in mapped_instance_stream:
        if mapped_instances is None:
            continue
        for mapped_instance in mapped_instances:
            yield mapped_instance
# API method ----------------------------------------------------------------
# Map step can run in parallel; reduce always runs in a single process

def mapreduce(data, map_fn, reducer_fn=None, parallel=False, compress=False, report=False):
    """
    Apply map_fn to every item of data, then reduce the mapped results.

    map_fn must always return a collection; its items are flattened, and a
    return value of None is skipped.

    reducer_fn is None:             behaves like pmap / map and returns a generator
    reducer_fn is a directory path: saves all entries as jsonl file(s) in that directory
    reducer_fn is callable:         called once per mapped entry
    """
    mapped_instance_stream = (pmap if parallel else map)(map_fn, data)
    if report:
        mapped_instance_stream = tqdm(mapped_instance_stream, total=len(data))

    if isinstance(reducer_fn, str):
        _reduce_to_file(mapped_instance_stream, reducer_fn, compress)
    elif callable(reducer_fn):
        _reduce_mapped_instances(mapped_instance_stream, reducer_fn)
    else:
        return _reduce_generator(mapped_instance_stream)
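
# Usage sketch (illustrative; `square` and the output directory name below are
# assumptions for demonstration, not part of the module's API):
if __name__ == "__main__":

    def square(x):
        # map_fn must return a collection; here each input yields one record
        return [{"value": x, "squared": x * x}]

    data = list(range(10))

    # reducer_fn is None: behaves like (p)map and returns a flat generator
    for record in mapreduce(data, square):
        print(record)

    # reducer_fn is callable: invoked once per mapped record
    collected = []
    mapreduce(data, square, reducer_fn=collected.append)

    # reducer_fn is a directory path: records are written as jsonl file(s)
    os.makedirs("mapreduce_demo_output", exist_ok=True)
    mapreduce(data, square, reducer_fn="mapreduce_demo_output")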