
Commit

Merge pull request #57 from RWKV/selective-loss-training
Selective loss training
PicoCreator authored Jan 18, 2024
2 parents 492c41f + ce4a461 commit 36e6737
Showing 7 changed files with 1,291 additions and 251 deletions.
48 changes: 40 additions & 8 deletions RWKV-v5/src/data.py
@@ -46,6 +46,19 @@ def prepare_data_static(**kargs):

# =====================================================

# Util functions
#--------------------------------

# Applies the data_prefix_skip_mask to the given mask where relevant,
# disabling the training (loss) mask for the first X tokens
data_prefix_skip_mask_val = int(kargs["data_prefix_skip_mask"])
def apply_data_prefix_skip_mask(mask):
    mask_len = len(mask)
    if data_prefix_skip_mask_val > 0 and mask_len > 0:
        for i in range(min(data_prefix_skip_mask_val, mask_len)):
            mask[i] = 0
    return mask
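For illustration only, a minimal standalone sketch of this behaviour, assuming a skip value of 4 (the function name apply_prefix_skip below is hypothetical, not part of the file):

# Hypothetical standalone sketch of the prefix-skip behaviour above,
# with the skip value fixed at 4 instead of read from kargs.
def apply_prefix_skip(mask, skip=4):
    for i in range(min(skip, len(mask))):
        mask[i] = 0
    return mask

print(apply_prefix_skip([1] * 8))  # [0, 0, 0, 0, 1, 1, 1, 1]
print(apply_prefix_skip([1] * 3))  # [0, 0, 0]  (fully masked; such rows are filtered out later)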

# Special handling for binidx
#--------------------------------

@@ -66,7 +79,7 @@ def gen():
yield {
'input_ids': tokens,
'token_type_ids': [0] * len(tokens),
'attention_mask': [1] * len(tokens)
'attention_mask': apply_data_prefix_skip_mask([1] * len(tokens))
}

# Load the huggingface dataset from the generator
@@ -375,7 +388,7 @@ def map_tokenizer(x):
return {
'input_ids': input_ids,
'token_type_ids': token_type_ids,
'attention_mask': attention_mask
'attention_mask': apply_data_prefix_skip_mask(attention_mask)
}

# Multi column merging support
@@ -443,7 +456,7 @@ def map_tokenizer(x):
return {
'input_ids': input_ids,
'token_type_ids': token_type_ids,
'attention_mask': attention_mask
'attention_mask': apply_data_prefix_skip_mask(attention_mask)
}

# Prompt completion support
@@ -472,12 +485,17 @@ def map_tokenizer(x):
return {
'input_ids': input_ids,
'token_type_ids': token_type_ids,
'attention_mask': attention_mask,
'attention_mask': apply_data_prefix_skip_mask(attention_mask),
}

# Fallback to standard text tokenization
if 'text' in x:
return encodeTokens(x['text'])
ret = encodeTokens(x['text'])
return {
'input_ids': ret['input_ids'],
'token_type_ids': ret['token_type_ids'],
'attention_mask': apply_data_prefix_skip_mask(ret['attention_mask']),
}

raise ValueError('Invalid dataset format, must contain either the configured "multi column" or prompt/completion or text')

@@ -519,7 +537,7 @@ def rechunk_text(x):
# with the newline token in between
full_input_ids += x["input_ids"][i] + endOfDoc_tokenSet["input_ids"][0]
full_token_type_ids += x["token_type_ids"][i] + endOfDoc_tokenSet["token_type_ids"][0]
full_attention_mask += x["attention_mask"][i] + endOfDoc_tokenSet["attention_mask"][0]
full_attention_mask += apply_data_prefix_skip_mask( x["attention_mask"][i] ) + endOfDoc_tokenSet["attention_mask"][0]

# Total length, and sample count
# note that the "remainder" will be discarded
@@ -540,7 +558,7 @@
# Push the sample to the output arrays
out_input_ids.append(full_input_ids[start:end])
out_token_type_ids.append(full_token_type_ids[start:end])
out_attention_mask.append(full_attention_mask[start:end])
out_attention_mask.append(apply_data_prefix_skip_mask( full_attention_mask[start:end] ))

# Prepare and return the output object
ret = {
@@ -565,6 +583,8 @@ def dataset_filter(x):
return False
if kargs["max_token_size"] > 0 and row_length > kargs["max_token_size"]:
return False
if sum(x["attention_mask"]) <= 0:
return False
return True
src_dataset = src_dataset.filter(dataset_filter, num_proc=num_cpus)
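A small sketch of what the added filter condition guards against, assuming the prefix-skip helper above: records no longer than the skip value end up with an all-zero mask, contribute nothing to the loss, and are therefore dropped (the names below are illustrative):

# Illustrative rows after masking; keep_row mirrors the new filter condition.
def keep_row(row):
    return sum(row["attention_mask"]) > 0

masked_out = {"input_ids": [11, 12, 13], "attention_mask": [0, 0, 0]}
partly_masked = {"input_ids": [11, 12, 13, 14, 15], "attention_mask": [0, 0, 0, 1, 1]}
print(keep_row(masked_out))     # False -> row is discarded
print(keep_row(partly_masked))  # True  -> row is kept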

@@ -902,6 +922,18 @@ def __init__(
# prompt/completion format masking support
disable_prompt_completion_mask: bool = False,

# ----------------------------
# Selective loss training
# ----------------------------

# Prefix token masking
#
# The rationale behind this is that the first X tokens of any new training record
# should not be backpropagated, as it is unfair to expect the model (or a human)
# to make any reasonable guess at that stage. This setting therefore masks the first
# X tokens out of the loss calculation, so they are not backpropagated.
data_prefix_skip_mask: int = 0,

# ----------------------------
# dataset packing support
# ----------------------------
@@ -1022,4 +1054,4 @@ def val_dataloader(self):
batch_size=1,
# Pinned in GPU memory
pin_memory=True
)
)
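To show what the zeroed prefix means for training, a rough sketch under the assumption that the trainer uses the attention mask as a per-token loss multiplier (the tensor values below are illustrative):

import torch

# Illustrative per-token cross-entropy values for a 5 token sample
token_loss = torch.tensor([2.1, 1.8, 1.5, 1.2, 0.9])
# Mask with the first 2 tokens prefix-skipped (data_prefix_skip_mask = 2)
mask = torch.tensor([0.0, 0.0, 1.0, 1.0, 1.0])

# Only the unmasked tokens contribute to the averaged loss
loss = (token_loss * mask).sum() / mask.sum()
print(loss)  # tensor(1.2000)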