@@ -32,10 +32,7 @@ def get_indexed_dataset_(data_prefix, data_impl, skip_warmup):
     start_time = time.time()
     indexed_dataset = make_indexed_dataset(data_prefix, data_impl, skip_warmup)
-    print_rank_0(
-        " > finished creating indexed dataset in {:4f} "
-        "seconds".format(time.time() - start_time)
-    )
+    print_rank_0(" > finished creating indexed dataset in {:4f} " "seconds".format(time.time() - start_time))
     print_rank_0(" number of documents: {}".format(indexed_dataset.sizes.shape[0]))
 
     return indexed_dataset
@@ -53,20 +50,22 @@ def build_train_valid_test_datasets(
     build_index_mappings=True,
     shuffle_before_split=False,
     weighted_loss_mode=None,
-    ds_weights=[1., 1., 1.],
-    train_mode='sft',
+    ds_weights=[1.0, 1.0, 1.0],
+    train_mode="sft",
 ):
     """Build train, valid, and test datasets."""
 
     # Indexed dataset.
-    assert os.path.exists(data_prefix + "_input_ids.bin"), f"Input tokens datafile not found: {data_prefix}_input_ids.bin"
+    assert os.path.exists(
+        data_prefix + "_input_ids.bin"
+    ), f"Input tokens datafile not found: {data_prefix}_input_ids.bin"
 
     # Indexed dataset.
     input_ids_indexed_dataset = get_indexed_dataset_(data_prefix + "_input_ids", data_impl, skip_warmup)
-    if train_mode == 'sft':
+    if train_mode == "sft":
         loss_mask_indexed_dataset = get_indexed_dataset_(data_prefix + "_loss_mask", data_impl, skip_warmup)
     else:
-        print(f' pretrain mode, loss mask is ones')
+        print(f" pretrain mode, loss mask is ones")
         loss_mask_indexed_dataset = None
 
     total_num_of_documents = input_ids_indexed_dataset.sizes.shape[0]
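Note: the tightened assertion and the new train_mode branch imply a per-prefix file layout of <data_prefix>_input_ids.bin always, plus <data_prefix>_loss_mask.bin when fine-tuning (pretrain mode falls back to an all-ones loss mask). A minimal sketch of that check follows; the prefix is hypothetical, and the .idx companions are an assumption from the usual mmap indexed-dataset convention, not something this diff shows.

import os
from typing import List

def missing_prompt_dataset_files(data_prefix: str, train_mode: str = "sft") -> List[str]:
    # Sketch: list the datafiles this loader would expect for one prefix.
    suffixes = ["_input_ids.bin"]          # asserted explicitly in the hunk above
    if train_mode == "sft":
        suffixes.append("_loss_mask.bin")  # only loaded in sft mode
    # mmap-style indexed datasets usually ship a companion .idx per .bin (assumption)
    suffixes += [s.replace(".bin", ".idx") for s in suffixes]
    return [data_prefix + s for s in suffixes if not os.path.exists(data_prefix + s)]

print(missing_prompt_dataset_files("/data/my_corpus", train_mode="sft"))  # hypothetical prefix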
@@ -79,9 +78,7 @@ def print_split_stats(name, index):
         print_rank_0(" {}:".format(name))
         print_rank_0(
             " document indices in [{}, {}) total of {} "
-            "documents".format(
-                splits[index], splits[index + 1], splits[index + 1] - splits[index]
-            )
+            "documents".format(splits[index], splits[index + 1], splits[index + 1] - splits[index])
         )
 
     print_split_stats("train", 0)
@@ -100,11 +97,9 @@ def build_dataset(index, name, ds_weight=1.0):
         dataset = None
         if splits[index + 1] > splits[index]:
             if shuffle_before_split:
-                documents = shuffle_doc_index[splits[index]: splits[index + 1]]
+                documents = shuffle_doc_index[splits[index] : splits[index + 1]]
             else:
-                documents = np.arange(
-                    start=splits[index], stop=splits[index + 1], step=1, dtype=np.int32
-                )
+                documents = np.arange(start=splits[index], stop=splits[index + 1], step=1, dtype=np.int32)
 
             dataset = GPT2PromptDataset(
                 name,
@@ -130,11 +125,13 @@ def build_dataset(index, name, ds_weight=1.0):
     return train_dataset, valid_dataset, test_dataset, total_num_of_documents
 
 
-def build_multiple_train_valid_test_datasets(args, train_valid_test_num_samples, use_shared_fs=True, data_impl="mmap", mmap_warmup=False):
+def build_multiple_train_valid_test_datasets(
+    args, train_valid_test_num_samples, use_shared_fs=True, data_impl="mmap", mmap_warmup=False
+):
     """Build multiple train, valid, and test datasets."""
-    data_prefixes = list(args.data_paths[1:-1].split(','))
+    data_prefixes = list(args.data_paths[1:-1].split(","))
 
-    data_weights = list(map(float, args.data_weights[1:-1].split(',')))
+    data_weights = list(map(float, args.data_weights[1:-1].split(",")))
     print("data weights: ")
     print(data_weights)
     use_shared_fs = use_shared_fs
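Note: the data_prefixes/data_weights parsing reformatted above assumes bracket-wrapped, comma-separated CLI strings; the [1:-1] slice strips the brackets before splitting. A quick trace with hypothetical values:

# Hypothetical CLI values; paths and weights are illustrative only.
data_paths = "[/data/code_sft,/data/chat_sft,/data/math_sft]"
data_weights = "[0.5,0.3,0.2]"

data_prefixes = list(data_paths[1:-1].split(","))
weights = list(map(float, data_weights[1:-1].split(",")))

print(data_prefixes)  # ['/data/code_sft', '/data/chat_sft', '/data/math_sft']
print(weights)        # [0.5, 0.3, 0.2]

The slice does no validation, so a value passed without the surrounding brackets would silently lose its first and last characters.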
@@ -143,7 +140,7 @@ def build_multiple_train_valid_test_datasets(args, train_valid_test_num_samples,
     seq_length = args.seq_length
     # seq_length = args.block_size
     seed = args.seed
-    skip_warmup = (not mmap_warmup)
+    skip_warmup = not mmap_warmup
     weight_by_num_documents = args.weight_by_num_documents
     shuffle_before_split = args.shuffle_before_split
     weighted_loss_mode = args.weighted_loss_mode
@@ -183,9 +180,7 @@ def build_multiple_train_valid_test_datasets(args, train_valid_test_num_samples,
     factor = 1
     if weight_by_num_documents:
         # gets the number of documents in each data path
-        get_num_docs_list = lambda datasets: [
-            dataset.input_ids_indexed_dataset.sizes.shape[0] for dataset in datasets
-        ]
+        get_num_docs_list = lambda datasets: [dataset.input_ids_indexed_dataset.sizes.shape[0] for dataset in datasets]
         train_num_docs, valid_num_docs, test_num_docs = (
             get_num_docs_list(train_datasets),
             get_num_docs_list(valid_datasets),
@@ -201,7 +196,7 @@ def build_multiple_train_valid_test_datasets(args, train_valid_test_num_samples,
         )
         assert sum(train_weights) != 0.0, "found train weights to be 0.0"
         assert sum(valid_weights) != 0.0, "found valid weights to be 0.0"
-
+
         train_weights, train_num_samples = get_normalized_weights_and_num_samples(
             train_weights, train_valid_test_num_samples[0]
         )
@@ -265,7 +260,7 @@ def build_multiple_train_valid_test_datasets(args, train_valid_test_num_samples,
     if num_tokens:
         factor = sum(num_tokens) / (sum(total_sample_cnt) * args.seq_length)
         factor /= sum([1.0 / w for w in train_ds_weights]) / len(train_ds_weights)
-
+
     print_rank_0(f"> common denomination factor for CE loss: {factor}")
 
     # Blend.
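Note: the factor computation touched above scales the per-dataset loss weights by the overall token fill ratio and the mean inverse weight. A small numeric trace with made-up counts (every value here is hypothetical):

num_tokens = [6_000_000, 2_000_000]      # effective tokens per dataset
total_sample_cnt = [2048, 1024]          # packed samples per dataset
seq_length = 4096
train_ds_weights = [1.0, 0.5]            # per-dataset loss weights before rescaling

factor = sum(num_tokens) / (sum(total_sample_cnt) * seq_length)           # ~0.636, the token fill ratio
factor /= sum(1.0 / w for w in train_ds_weights) / len(train_ds_weights)  # divide by mean inverse weight (1.5)
print(f"> common denomination factor for CE loss: {factor}")              # ~0.424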
@@ -274,7 +269,7 @@ def build_multiple_train_valid_test_datasets(args, train_valid_test_num_samples,
     i = 0
     for ds in train_datasets:
         ds.update_ds_weight(ds.ds_weight / factor)
-        print(f' loss weight of dataset {i} after update: {ds.ds_weight}')
+        print(f" loss weight of dataset {i} after update: {ds.ds_weight}")
         i += 1
     blending_train_dataset = BlendableDataset(train_datasets, train_weights)
     blending_valid_dataset = None
@@ -318,9 +313,7 @@ def get_train_valid_test_split_(splits_string, size):
     return splits_index
 
 
-def get_normalized_weights_and_num_samples(
-    weights: List[float], num_samples: int
-) -> Tuple[List[float], List[int]]:
+def get_normalized_weights_and_num_samples(weights: List[float], num_samples: int) -> Tuple[List[float], List[int]]:
     # Normalize weights
     weight_sum = sum(weights)
     assert weight_sum > 0.0
@@ -346,12 +339,7 @@ def get_datasets_normalized_weights_and_num_samples(
     # samples left to feed to the network.
     weighted_num_samples = []
     for weight in weights:
-        weighted_num_samples.append(
-            [
-                int(math.ceil(val * weight * 1.005))
-                for val in num_samples
-            ]
-        )
+        weighted_num_samples.append([int(math.ceil(val * weight * 1.005)) for val in num_samples])
     return weights, weighted_num_samples
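Note: the two reformatted helpers at the end share one pattern: normalize the weights to sum to 1, then pad each requested sample count by 0.5% so a blended dataset never runs short of samples from any component. A self-contained sketch under that reading of the hunks (not the repo's exact code):

import math
from typing import List, Tuple

def normalized_weights_and_num_samples(weights: List[float], num_samples: int) -> Tuple[List[float], List[int]]:
    # Normalize, then oversample slightly, mirroring the int(math.ceil(val * weight * 1.005)) line above.
    weight_sum = sum(weights)
    assert weight_sum > 0.0
    weights = [w / weight_sum for w in weights]
    return weights, [int(math.ceil(num_samples * w * 1.005)) for w in weights]

print(normalized_weights_and_num_samples([2.0, 1.0, 1.0], 1000))
# ([0.5, 0.25, 0.25], [503, 252, 252])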