
Commit e827775

0.5.dev: first refactor; add offline_tokenization for pretraining; add xxpo for DPO/RPO/ORPO; add mpt as an optional, safe, but space-consuming way for full-parameter pretraining/MFT
1 parent bce9d43 commit e827775

31 files changed: +5633 -596 lines

mftcoder_accelerate/src/data/blendable_dataset.py (+3 -5)
@@ -43,7 +43,7 @@ def __init__(self, datasets, weights):

        # recompute weights
        weights = self.calc_weights()
-
+
        # Build indices.
        start_time = time.time()
        assert num_datasets < 255
@@ -63,17 +63,15 @@ def __init__(self, datasets, weights):

        print(
            "> RANK {} elapsed time for building blendable dataset indices: "
-            "{:.2f} (sec)".format(
-                torch.distributed.get_rank(), time.time() - start_time
-            )
+            "{:.2f} (sec)".format(torch.distributed.get_rank(), time.time() - start_time)
        )

    def calc_weights(self):
        dataset_sample_cnt = [len(ds) for ds in self.datasets]
        total_cnt = sum(dataset_sample_cnt)
        weights = np.array([(cnt + 0.0) / total_cnt for cnt in dataset_sample_cnt], dtype=np.float64)
        return weights
-
+
    def __len__(self):
        return self.size

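For context on what the reflowed `calc_weights` hunk computes: each dataset's blending weight is its sample count divided by the total sample count across datasets. A minimal standalone sketch of that proportional weighting (the `FakeDataset` class and the sizes below are illustrative, not part of the repo):

```python
import numpy as np

# Hypothetical stand-in for the indexed datasets blended by BlendableDataset.
class FakeDataset:
    def __init__(self, n):
        self.n = n

    def __len__(self):
        return self.n

def calc_weights(datasets):
    # Mirrors the logic visible in BlendableDataset.calc_weights above:
    # weight_i = len(dataset_i) / sum(len(dataset_j))
    dataset_sample_cnt = [len(ds) for ds in datasets]
    total_cnt = sum(dataset_sample_cnt)
    return np.array([(cnt + 0.0) / total_cnt for cnt in dataset_sample_cnt], dtype=np.float64)

print(calc_weights([FakeDataset(100), FakeDataset(300)]))  # -> [0.25 0.75]
```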
mftcoder_accelerate/src/data/data_utils.py (+23 -35)
@@ -32,10 +32,7 @@ def get_indexed_dataset_(data_prefix, data_impl, skip_warmup):

    start_time = time.time()
    indexed_dataset = make_indexed_dataset(data_prefix, data_impl, skip_warmup)
-    print_rank_0(
-        " > finished creating indexed dataset in {:4f} "
-        "seconds".format(time.time() - start_time)
-    )
+    print_rank_0(" > finished creating indexed dataset in {:4f} " "seconds".format(time.time() - start_time))
    print_rank_0(" number of documents: {}".format(indexed_dataset.sizes.shape[0]))

    return indexed_dataset
@@ -53,20 +50,22 @@ def build_train_valid_test_datasets(
    build_index_mappings=True,
    shuffle_before_split=False,
    weighted_loss_mode=None,
-    ds_weights=[1., 1., 1.],
-    train_mode='sft',
+    ds_weights=[1.0, 1.0, 1.0],
+    train_mode="sft",
):
    """Build train, valid, and test datasets."""

    # Indexed dataset.
-    assert os.path.exists(data_prefix + "_input_ids.bin"), f"Input tokens datafile not found: {data_prefix}_input_ids.bin"
+    assert os.path.exists(
+        data_prefix + "_input_ids.bin"
+    ), f"Input tokens datafile not found: {data_prefix}_input_ids.bin"

    # Indexed dataset.
    input_ids_indexed_dataset = get_indexed_dataset_(data_prefix + "_input_ids", data_impl, skip_warmup)
-    if train_mode == 'sft':
+    if train_mode == "sft":
        loss_mask_indexed_dataset = get_indexed_dataset_(data_prefix + "_loss_mask", data_impl, skip_warmup)
    else:
-        print(f'pretrain mode, loss mask is ones')
+        print(f"pretrain mode, loss mask is ones")
        loss_mask_indexed_dataset = None

    total_num_of_documents = input_ids_indexed_dataset.sizes.shape[0]
@@ -79,9 +78,7 @@ def print_split_stats(name, index):
        print_rank_0(" {}:".format(name))
        print_rank_0(
            " document indices in [{}, {}) total of {} "
-            "documents".format(
-                splits[index], splits[index + 1], splits[index + 1] - splits[index]
-            )
+            "documents".format(splits[index], splits[index + 1], splits[index + 1] - splits[index])
        )

    print_split_stats("train", 0)
@@ -100,11 +97,9 @@ def build_dataset(index, name, ds_weight=1.0):
        dataset = None
        if splits[index + 1] > splits[index]:
            if shuffle_before_split:
-                documents = shuffle_doc_index[splits[index]:splits[index + 1]]
+                documents = shuffle_doc_index[splits[index] : splits[index + 1]]
            else:
-                documents = np.arange(
-                    start=splits[index], stop=splits[index + 1], step=1, dtype=np.int32
-                )
+                documents = np.arange(start=splits[index], stop=splits[index + 1], step=1, dtype=np.int32)

            dataset = GPT2PromptDataset(
                name,
@@ -130,11 +125,13 @@ def build_dataset(index, name, ds_weight=1.0):
    return train_dataset, valid_dataset, test_dataset, total_num_of_documents


-def build_multiple_train_valid_test_datasets(args, train_valid_test_num_samples, use_shared_fs=True, data_impl="mmap", mmap_warmup=False):
+def build_multiple_train_valid_test_datasets(
+    args, train_valid_test_num_samples, use_shared_fs=True, data_impl="mmap", mmap_warmup=False
+):
    """Build multiple train, valid, and test datasets."""
-    data_prefixes = list(args.data_paths[1:-1].split(','))
+    data_prefixes = list(args.data_paths[1:-1].split(","))

-    data_weights = list(map(float, args.data_weights[1:-1].split(',')))
+    data_weights = list(map(float, args.data_weights[1:-1].split(",")))
    print("data weights: ")
    print(data_weights)
    use_shared_fs = use_shared_fs
@@ -143,7 +140,7 @@ def build_multiple_train_valid_test_datasets(args, train_valid_test_num_samples,
    seq_length = args.seq_length
    # seq_length = args.block_size
    seed = args.seed
-    skip_warmup = (not mmap_warmup)
+    skip_warmup = not mmap_warmup
    weight_by_num_documents = args.weight_by_num_documents
    shuffle_before_split = args.shuffle_before_split
    weighted_loss_mode = args.weighted_loss_mode
@@ -183,9 +180,7 @@ def build_multiple_train_valid_test_datasets(args, train_valid_test_num_samples,
    factor = 1
    if weight_by_num_documents:
        # gets the number of documents in each data path
-        get_num_docs_list = lambda datasets: [
-            dataset.input_ids_indexed_dataset.sizes.shape[0] for dataset in datasets
-        ]
+        get_num_docs_list = lambda datasets: [dataset.input_ids_indexed_dataset.sizes.shape[0] for dataset in datasets]
        train_num_docs, valid_num_docs, test_num_docs = (
            get_num_docs_list(train_datasets),
            get_num_docs_list(valid_datasets),
@@ -201,7 +196,7 @@ def build_multiple_train_valid_test_datasets(args, train_valid_test_num_samples,
        )
        assert sum(train_weights) != 0.0, "found train weights to be 0.0"
        assert sum(valid_weights) != 0.0, "found valid weights to be 0.0"
-
+
        train_weights, train_num_samples = get_normalized_weights_and_num_samples(
            train_weights, train_valid_test_num_samples[0]
        )
@@ -265,7 +260,7 @@ def build_multiple_train_valid_test_datasets(args, train_valid_test_num_samples,
    if num_tokens:
        factor = sum(num_tokens) / (sum(total_sample_cnt) * args.seq_length)
        factor /= sum([1.0 / w for w in train_ds_weights]) / len(train_ds_weights)
-
+
    print_rank_0(f"> common denomination factor for CE loss: {factor}")

    # Blend.
@@ -274,7 +269,7 @@ def build_multiple_train_valid_test_datasets(args, train_valid_test_num_samples,
        i = 0
        for ds in train_datasets:
            ds.update_ds_weight(ds.ds_weight / factor)
-            print(f'loss weight of dataset {i} after update: {ds.ds_weight}')
+            print(f"loss weight of dataset {i} after update: {ds.ds_weight}")
            i += 1
        blending_train_dataset = BlendableDataset(train_datasets, train_weights)
    blending_valid_dataset = None
@@ -318,9 +313,7 @@ def get_train_valid_test_split_(splits_string, size):
    return splits_index


-def get_normalized_weights_and_num_samples(
-    weights: List[float], num_samples: int
-) -> Tuple[List[float], List[int]]:
+def get_normalized_weights_and_num_samples(weights: List[float], num_samples: int) -> Tuple[List[float], List[int]]:
    # Normalize weights
    weight_sum = sum(weights)
    assert weight_sum > 0.0
@@ -346,12 +339,7 @@ def get_datasets_normalized_weights_and_num_samples(
    # samples left to feed to the network.
    weighted_num_samples = []
    for weight in weights:
-        weighted_num_samples.append(
-            [
-                int(math.ceil(val * weight * 1.005))
-                for val in num_samples
-            ]
-        )
+        weighted_num_samples.append([int(math.ceil(val * weight * 1.005)) for val in num_samples])
    return weights, weighted_num_samples

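The `get_normalized_weights_and_num_samples` and `get_datasets_normalized_weights_and_num_samples` hunks keep the same arithmetic after the reflow: weights are normalized to sum to 1, then each dataset's requested sample count is padded by a factor of 1.005 so the blend never runs short. A rough, self-contained sketch of that calculation (names here are illustrative, not the module's exact API):

```python
import math
from typing import List, Tuple

def normalized_weights_and_num_samples(weights: List[float], num_samples: int) -> Tuple[List[float], List[int]]:
    # Normalize the raw weights so they sum to 1.0.
    weight_sum = sum(weights)
    assert weight_sum > 0.0
    weights = [w / weight_sum for w in weights]
    # Oversample each dataset slightly (x1.005) so blending never runs out of samples.
    per_dataset = [int(math.ceil(num_samples * w * 1.005)) for w in weights]
    return weights, per_dataset

weights, counts = normalized_weights_and_num_samples([2.0, 1.0, 1.0], 1000)
print(weights)  # [0.5, 0.25, 0.25]
print(counts)   # [503, 252, 252]
```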
mftcoder_accelerate/src/data/gpt2_dataset.py (+16 -39)
@@ -41,7 +41,7 @@ def __init__(
        use_shared_fs=True,
        weighted_loss_mode=None,
        ds_weight=1.0,
-        train_mode='sft',
+        train_mode="sft",
    ):

        self.name = name
@@ -50,9 +50,9 @@ def __init__(

        self.weighted_loss_mode = weighted_loss_mode
        self.ds_weight = ds_weight
-
-        self.task_name = data_prefix.split('/')[-1]
-
+
+        self.task_name = data_prefix.split("/")[-1]
+
        self.task_id = TASK2ID[self.task_name]

        # Checks
@@ -114,14 +114,10 @@ def __getitem__(self, idx):

        else:
            # Otherwise, get the rest of the initial document.
-            input_ids_list = [
-                self.input_ids_indexed_dataset.get(self.doc_idx[doc_index_f], offset=offset_f)
-            ]
+            input_ids_list = [self.input_ids_indexed_dataset.get(self.doc_idx[doc_index_f], offset=offset_f)]

            if self.loss_mask_indexed_dataset is not None:
-                loss_mask_list = [
-                    self.loss_mask_indexed_dataset.get(self.doc_idx[doc_index_f], offset=offset_f)
-                ]
+                loss_mask_list = [self.loss_mask_indexed_dataset.get(self.doc_idx[doc_index_f], offset=offset_f)]
            else:
                loss_mask_list = []

@@ -133,16 +129,12 @@ def __getitem__(self, idx):

            # And finally add the relevant portion of last document.
            input_ids_list.append(
-                self.input_ids_indexed_dataset.get(
-                    self.doc_idx[doc_index_l], length=offset_l + 1
-                )
+                self.input_ids_indexed_dataset.get(self.doc_idx[doc_index_l], length=offset_l + 1)
            )

            if self.loss_mask_indexed_dataset is not None:
                loss_mask_list.append(
-                    self.loss_mask_indexed_dataset.get(
-                        self.doc_idx[doc_index_l], length=offset_l + 1
-                    )
+                    self.loss_mask_indexed_dataset.get(self.doc_idx[doc_index_l], length=offset_l + 1)
                )

            input_ids = np.concatenate(input_ids_list)
@@ -246,18 +238,12 @@ def __getitem__(self, idx):
            )
        else:
            # Otherwise, get the rest of the initial document.
-            sample_list = [
-                self.indexed_dataset.get(self.doc_idx[doc_index_f], offset=offset_f)
-            ]
+            sample_list = [self.indexed_dataset.get(self.doc_idx[doc_index_f], offset=offset_f)]
            # Loop over all in between documents and add the entire document.
            for i in range(doc_index_f + 1, doc_index_l):
                sample_list.append(self.indexed_dataset.get(self.doc_idx[i]))
            # And finally add the relevant portion of last document.
-            sample_list.append(
-                self.indexed_dataset.get(
-                    self.doc_idx[doc_index_l], length=offset_l + 1
-                )
-            )
+            sample_list.append(self.indexed_dataset.get(self.doc_idx[doc_index_l], length=offset_l + 1))
            sample = np.concatenate(sample_list)

        return {"text": np.array(sample, dtype=np.int64)}
@@ -313,10 +299,7 @@ def _build_index_mappings(
        or (not os.path.isfile(sample_idx_filename))
        or (not os.path.isfile(shuffle_idx_filename))
    ):
-        print_rank_0(
-            " > WARNING: could not find index map files, building "
-            "the indices on rank 0 ..."
-        )
+        print_rank_0(" > WARNING: could not find index map files, building " "the indices on rank 0 ...")
        # doc-idx.
        start_time = time.time()
        doc_idx = _build_doc_idx(documents, num_epochs, np_rng)
@@ -338,13 +321,9 @@ def _build_index_mappings(
        # As I understand it, num_samples here shares its name with the num_samples argument; it is only used to compute the total length of the indices being built, which decides whether to use int64 or int32
        num_samples = (num_epochs * tokens_per_epoch - 1) / seq_length
        if 2 * (num_samples + 1) < np.iinfo(np.int32).max:
-            sample_idx = helpers.build_sample_idx_int32(
-                sizes, doc_idx, seq_length, num_epochs, tokens_per_epoch
-            )
+            sample_idx = helpers.build_sample_idx_int32(sizes, doc_idx, seq_length, num_epochs, tokens_per_epoch)
        else:
-            sample_idx = helpers.build_sample_idx_int64(
-                sizes, doc_idx, seq_length, num_epochs, tokens_per_epoch
-            )
+            sample_idx = helpers.build_sample_idx_int64(sizes, doc_idx, seq_length, num_epochs, tokens_per_epoch)
        np.save(sample_idx_filename, sample_idx, allow_pickle=True)
        print_rank_0(
            " > elapsed time to build and save sample-idx mapping "
@@ -360,7 +339,7 @@ def _build_index_mappings(
            " > elapsed time to build and save shuffle-idx mapping"
            " (seconds): {:4f}".format(time.time() - start_time)
        )
-
+
    torch.distributed.barrier()  # TODO: model parallel

    # This should be a barrier but nccl barrier assumes
@@ -370,7 +349,7 @@ def _build_index_mappings(
    # torch.distributed.all_reduce(counts)
    # torch.distributed.all_reduce(counts, group=mpu.get_io_parallel_group())
    # assert counts[0].item() == torch.distributed.get_world_size(
-    #     group=mpu.get_io_parallel_group()
+    #     group=mpu.get_io_parallel_group()
    # )

    # Load mappings.
@@ -381,9 +360,7 @@ def _build_index_mappings(
    sample_idx = np.load(sample_idx_filename, allow_pickle=True, mmap_mode="r")
    print_rank_0(" > loading shuffle-idx mapping from {}".format(shuffle_idx_filename))
    shuffle_idx = np.load(shuffle_idx_filename, allow_pickle=True, mmap_mode="r")
-    print_rank_0(
-        " loaded indexed file in {:3.3f} seconds".format(time.time() - start_time)
-    )
+    print_rank_0(" loaded indexed file in {:3.3f} seconds".format(time.time() - start_time))
    print_rank_0(" total number of samples: {}".format(sample_idx.shape[0]))
    print_rank_0(" total number of epochs: {}".format(num_epochs))

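The `build_sample_idx_int32`/`build_sample_idx_int64` hunk picks the index dtype from an estimate of how many packed samples the token stream yields; if a two-entry-per-sample index could overflow int32, it falls back to int64. A small sketch of just that bound check (the token counts below are made-up, and the helper calls are replaced by a plain dtype choice):

```python
import numpy as np

def sample_idx_dtype(num_epochs: int, tokens_per_epoch: int, seq_length: int):
    # Same estimate as in _build_index_mappings: how many seq_length-sized
    # samples the packed token stream yields across all epochs.
    num_samples = (num_epochs * tokens_per_epoch - 1) / seq_length
    # The sample index keeps roughly two values per sample, hence the factor of 2
    # in the guard against the int32 range.
    if 2 * (num_samples + 1) < np.iinfo(np.int32).max:
        return np.int32
    return np.int64

print(sample_idx_dtype(num_epochs=1, tokens_per_epoch=10_000_000, seq_length=4096))          # int32 is enough
print(sample_idx_dtype(num_epochs=3, tokens_per_epoch=2_000_000_000_000, seq_length=4096))   # needs int64
```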