-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathproblems.py
156 lines (129 loc) · 6.19 KB
/
problems.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
import os
import tensorflow as tf
from tensor2tensor.data_generators import problem
from tensor2tensor.data_generators import translate
from tensor2tensor.data_generators import text_encoder
from tensor2tensor.utils import registry
from tensor2tensor.utils import mlperf_log
def txt_line_iterator(txt_path, encoding="utf-8"):
    """Iterate through lines of a text file, stripped of surrounding whitespace.

    Args:
      txt_path: path of the text file to read.
      encoding: text encoding used to decode the file. Defaults to UTF-8 so
        multilingual corpora decode identically regardless of the machine's
        locale (a bare ``open()`` would use the locale's preferred encoding).

    Yields:
      Each line of the file with leading/trailing whitespace removed
      (blank lines yield ``""``).
    """
    with open(txt_path, encoding=encoding) as f:
        for line in f:
            yield line.strip()
# Module-level settings; starts empty and must be populated before the problem
# below is used. Keys read in this file:
#   'languages'      - language codes (also fixes the "2<lang>" token order)
#   'training_files' - dataset entries for the TRAIN split
#   'testing_files'  - dataset entries for every other split
CONFIG = dict()
@registry.register_problem
class TranslateManyToMany(translate.TranslateProblem):
    """Many-to-many translation problem configured through module-level CONFIG.

    Each encoded input is prefixed with the id of a reserved ``"2<lang>"``
    token naming the *target* language, so a single model can translate
    between any pair of languages listed in ``CONFIG['languages']``.

    Keys read from ``CONFIG``:
      * ``'languages'``: language codes; their order fixes the order of the
        ``"2<lang>"`` reserved tokens.
      * ``'training_files'`` / ``'testing_files'``: dataset entries for the
        TRAIN split and for every other split. Element ``[1]`` of each entry
        is a ``(source_path, target_path, index_path)`` triple
        (``index_path`` may be ``None``). Element ``[0]`` is only forwarded
        to the vocabulary builder.
    """

    @property
    def use_small_dataset(self):
        """Whether to use a reduced dataset; always ``False`` here."""
        return False

    @property
    def name(self):
        return "translate_many_to_many"

    @property
    def vocab_filename(self):
        return "vocab.txt"

    def vocab_data_files(self):
        """Files to be passed to get_or_generate_vocab. Skips langpair files.

        The trailing index (language-pair) path is dropped from each file
        triple, so only the source/target text files feed the vocabulary.
        """
        return [
            [head, files[:-1]]
            for head, files in self.source_data_files(problem.DatasetSplit.TRAIN)
        ]

    @property
    def approx_vocab_size(self):
        return 2**16  # 65536

    def dataset_filename(self):
        return "translate"

    @property
    def additional_reserved_tokens(self):
        # One "2<lang>" token per language. text2text_generate_encoded assumes
        # the vocab builder places these right after text_encoder's standard
        # reserved tokens, in this order -- confirm if the vocab code changes.
        return self.prefixes

    @property
    def prefixes(self):
        """Target-language tokens ``"2<lang>"`` in ``CONFIG['languages']`` order."""
        return ["2<%s>" % lang for lang in CONFIG['languages']]

    @property
    def inputs_prefix(self):
        # Not supported: the target-language tag is inserted as a token id in
        # text2text_generate_encoded rather than as a text prefix.
        raise NotImplementedError()

    @property
    def dataset_splits(self):
        """Splits of data to produce and number of output shards for each."""
        return [{
            "split": problem.DatasetSplit.TRAIN,
            "shards": 50,
        }, {
            "split": problem.DatasetSplit.EVAL,
            "shards": 1,
        }]

    def source_data_files(self, dataset_split):
        """Return the CONFIG dataset entries for ``dataset_split``.

        TRAIN reads ``CONFIG['training_files']``; every other split reads
        ``CONFIG['testing_files']``.
        """
        if dataset_split == problem.DatasetSplit.TRAIN:
            return CONFIG['training_files']
        return CONFIG['testing_files']

    def generate_samples(self, data_dir, tmp_dir, dataset_split,
                         custom_iterator='unused'):
        """Yield raw text example dicts for ``dataset_split``.

        ``data_dir``, ``tmp_dir`` and ``custom_iterator`` are accepted only to
        match the base-class signature; they are not used. Files are read
        directly from the paths listed in ``CONFIG``.
        """
        datasets = self.source_data_files(dataset_split)
        return self._meta_iterator([entry[1] for entry in datasets])

    def _meta_iterator(self, data_files):
        """Chain examples from every ``(source, target, index)`` file triple."""
        for source, target, index in data_files:
            print("Processing file", source)
            yield from self.text2text_txt_iterator(source, target, index)

    def _determine_language_from_suffix(self, filename):
        """Return the language code given by ``filename``'s last suffix.

        A trailing ``.txt`` is ignored, so ``corpus.de`` and ``corpus.de.txt``
        both yield ``"de"``.

        Raises:
          ValueError: if the suffix is not one of ``CONFIG['languages']``.
        """
        if filename.endswith(".txt"):
            filename = filename[:-4]
        language = filename.split(".")[-1]
        if language not in CONFIG['languages']:
            raise ValueError(f"The text file {filename} has an unexpected suffix: {language}. Expecting one of the language codes: {CONFIG['languages']}")
        return language

    def text2text_txt_iterator(self, source_txt_path, target_txt_path,
                               index_path=None, bidirectional=True):
        """Yield dicts for Text2TextProblem.generate_samples from lines of files.

        Args:
          source_txt_path: file with one source sentence per line.
          target_txt_path: file with the aligned target sentence per line.
          index_path: optional file whose line i holds the
            ``"<src_lang> <tgt_lang>"`` pair for line i of the text files. When
            ``None``, both languages come from the text files' suffixes.
          bidirectional: when True, also yield the reversed (target -> source)
            example for every line.

        Yields:
          Dicts with keys ``"inputs"``, ``"targets"``, ``"src_lang"`` and
          ``"tgt_lang"``. Iteration stops at the shortest file (``zip``), so
          mismatched line counts are silently truncated.
        """
        if index_path is not None:
            # Pre-merged files with multiple languages: each line of
            # index_path names the languages of the corresponding lines.
            for inputs, targets, language_pair in zip(
                    txt_line_iterator(source_txt_path),
                    txt_line_iterator(target_txt_path),
                    txt_line_iterator(index_path)):
                # split() (not split(" ")) tolerates tabs / repeated spaces.
                src_lang, tgt_lang = language_pair.split()
                yield {"inputs": inputs, "targets": targets, "src_lang": src_lang, "tgt_lang": tgt_lang}
                if bidirectional:
                    yield {"inputs": targets, "targets": inputs, "src_lang": tgt_lang, "tgt_lang": src_lang}
        else:
            # A single language pair: languages come from the file suffixes.
            src_lang = self._determine_language_from_suffix(source_txt_path)
            tgt_lang = self._determine_language_from_suffix(target_txt_path)
            for inputs, targets in zip(txt_line_iterator(source_txt_path),
                                       txt_line_iterator(target_txt_path)):
                yield {"inputs": inputs, "targets": targets, "src_lang": src_lang, "tgt_lang": tgt_lang}
                if bidirectional:
                    yield {"inputs": targets, "targets": inputs, "src_lang": tgt_lang, "tgt_lang": src_lang}

    def generate_encoded_samples(self, data_dir, tmp_dir, dataset_split):
        """Log the tokenization phase, then encode ``generate_samples`` output.

        Text prefixes are disabled (``None``); the target-language tag is added
        as a token id by ``text2text_generate_encoded``.
        """
        if dataset_split == problem.DatasetSplit.TRAIN:
            mlperf_log.transformer_print(key=mlperf_log.PREPROC_TOKENIZE_TRAINING)
        elif dataset_split == problem.DatasetSplit.EVAL:
            mlperf_log.transformer_print(key=mlperf_log.PREPROC_TOKENIZE_EVAL)
        generator = self.generate_samples(data_dir, tmp_dir, dataset_split)
        encoder = self.get_or_create_vocab(data_dir, tmp_dir)
        return self.text2text_generate_encoded(generator, encoder,
                                               has_inputs=self.has_inputs,
                                               inputs_prefix=None,
                                               targets_prefix=None)

    def text2text_generate_encoded(self, sample_generator,
                                   vocab,
                                   targets_vocab=None,
                                   has_inputs=True,
                                   inputs_prefix=None,
                                   targets_prefix=None):
        """Encode Text2Text samples from the generator with the vocab.

        For each sample, the inputs are encoded with ``vocab``. EOS is
        appended and the id of the ``"2<tgt_lang>"`` token is prepended.
        Targets are encoded with ``targets_vocab`` (default: ``vocab``) and get
        EOS appended. Other keys of the sample dict pass through unchanged.

        Args:
          sample_generator: iterable of dicts from ``generate_samples``.
          vocab: encoder for the inputs (and targets if no targets_vocab).
          targets_vocab: optional separate encoder for the targets.
          has_inputs: whether samples carry an ``"inputs"`` field to encode.
          inputs_prefix: optional text prepended to inputs before encoding;
            ``None`` means no prefix.
          targets_prefix: optional text prepended to targets before encoding;
            ``None`` means no prefix.

        Raises:
          ValueError: if a sample's ``tgt_lang`` is not in
            ``CONFIG['languages']``.
        """
        targets_vocab = targets_vocab or vocab
        inputs_prefix = inputs_prefix or ""
        targets_prefix = targets_prefix or ""
        for sample in sample_generator:
            if has_inputs:
                encoded_inputs = vocab.encode(inputs_prefix + sample["inputs"])
                encoded_inputs.append(text_encoder.EOS_ID)
                # "2<lang>" tokens follow text_encoder's standard reserved tokens
                # in CONFIG['languages'] order (see additional_reserved_tokens),
                # hence the NUM_RESERVED_TOKENS offset.
                lang_token_id = (text_encoder.NUM_RESERVED_TOKENS
                                 + CONFIG['languages'].index(sample["tgt_lang"]))
                sample["inputs"] = [lang_token_id] + encoded_inputs
            sample["targets"] = targets_vocab.encode(targets_prefix + sample["targets"])
            sample["targets"].append(text_encoder.EOS_ID)
            yield sample