map_eval.py
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: CC-BY-NC-4.0
import argparse
import glob
import json
import multiprocessing as mp
import re
from copy import deepcopy

from tqdm import tqdm
from warcio.archiveiterator import ArchiveIterator

'''
Maps news article URLs from the dev/test splits to the records present in the
CC-News dump, writing each matching page's HTML out as JSONL.
'''
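
# Example invocation, a sketch assuming the CC-News archives were downloaded to
# ./cc_news/ (that directory layout is an assumption for illustration, not part
# of the repo):
#   python map_eval.py --input_path "./cc_news/*.warc.gz"
# Each archive foo.warc.gz yields foo-filter.jsonl, whose lines are JSON
# objects of the form {"html": "...", "url": "..."}.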

def filter_domains(fname):
    """Scan one WARC archive and keep the records whose URL appears in dev/test."""
    filter_count = 0
    # Gather every document URL referenced by the dev and test splits.
    with open("./data/dev.json") as f:
        dev_data = json.load(f)
    with open("./data/test.json") as f:
        test_data = json.load(f)
    inp_data = deepcopy(dev_data)
    inp_data.extend(deepcopy(test_data))
    all_urls = set()
    for item in inp_data:
        for doc in item["documents"]:
            all_urls.add(doc["url"])
    try:
        out_fname = re.sub(r'\.warc\.gz$', '', fname) + '-filter.jsonl'
        with open(fname, 'rb') as stream, open(out_fname, 'w') as f_out:
            for record in tqdm(ArchiveIterator(stream)):
                if record.rec_type == 'response':
                    url = record.rec_headers.get_header('WARC-Target-URI')
                    if url in all_urls:
                        content = record.content_stream().read()
                        # Decode the payload bytes; str(content) would emit a
                        # "b'...'" literal rather than the page HTML.
                        html = content.decode('utf-8', errors='ignore')
                        fields = {'html': html, 'url': url}
                        f_out.write(json.dumps(fields) + "\n")
                        filter_count += 1
    except Exception as e:
        print(str(e))
    return filter_count
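
# Single-archive usage (the path below is hypothetical, for illustration):
#   n = filter_domains("./cc_news/CC-NEWS-20190901.warc.gz")
#   print(n, "matching records written")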

def filter_data(input_path):
    print("Number of processors:", mp.cpu_count())
    fnames = glob.glob(input_path)
    print("Num files:", len(fnames))
    # Filter the archives in parallel; each worker returns its match count.
    with mp.Pool(20) as pool:
        counts = pool.map(filter_domains, fnames)
    print("Total filtered records:", sum(counts))

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='Map dataset URLs to CC-News WARC records')
    parser.add_argument('--input_path', type=str, help="glob pattern matching the .warc.gz files")
    args = parser.parse_args()
    filter_data(args.input_path)