
Eugene/train test split #87

Open · wants to merge 13 commits into base: rwkv-x-h100
127 changes: 91 additions & 36 deletions RWKV-v5/src/data.py
@@ -9,7 +9,7 @@
from datasets import load_from_disk, load_dataset, concatenate_datasets, Dataset, Features, Value, Sequence
from transformers import PreTrainedTokenizerFast, AutoTokenizer
from multiprocessing import cpu_count
import gc, yaml
import gc, yaml, json

num_cpus = cpu_count()
num_workers = cpu_count() if cpu_count() < 8 else 8
@@ -45,7 +45,9 @@ def prepare_data_static(
source_dataset_params: dict = None,
# Source dataset split to use
source_dataset_split: str = "train",
# Test split of source data, if it was not already done
# Test dataset split (if any)
test_dataset_split: str = "test",
# Test split of source data, if the test_dataset_split was not found
test_split: float = 0.01,
test_split_shuffle: bool = False,
# Text rechunking size
@@ -298,14 +300,29 @@ def gen():
# Load the dataset
src_dataset = load_dataset(**load_dataset_params)

# If for some reason the dataset is a "test" only split, and missing a "train" split, we remap it as a "train" split
# If for some reason the dataset is missing the configured source split, we throw accordingly
if source_dataset_split not in src_dataset.keys():
raise ValueError('Dataset missing split: ' + source_dataset_split)

if source_dataset_split != "train":
src_dataset["train"] = src_dataset[source_dataset_split]
del src_dataset[source_dataset_split]

# If the configured test split exists under a name other than "test", move it to "test";
# if it is not found, clear any existing "test" split. This allows the test_split fallback to work
if test_dataset_split != "test":
if test_dataset_split in src_dataset.keys():
src_dataset["test"] = src_dataset[test_dataset_split]
del src_dataset[test_dataset_split]
elif "test" in src_dataset.keys():
del src_dataset["test"]

# Remove all splits that are not "train" or "test"
src_dataset_keys = list(src_dataset.keys())
for key in src_dataset_keys:
if key not in ["train", "test"]:
del src_dataset[key]
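
For context, here is a minimal, self-contained sketch of the split-normalization flow these hunks implement, assuming a Hugging Face `DatasetDict`. The `train_test_split` fallback itself sits outside the visible hunk, so its exact placement below is an assumption based on the surrounding comments and the existing `test_split` / `test_split_shuffle` parameters:

```python
from datasets import DatasetDict

def normalize_splits(src_dataset: DatasetDict,
                     source_dataset_split="train",
                     test_dataset_split="test",
                     test_split=0.01,
                     test_split_shuffle=False) -> DatasetDict:
    # The configured source split must exist, otherwise fail fast
    if source_dataset_split not in src_dataset:
        raise ValueError('Dataset missing split: ' + source_dataset_split)

    # Remap the source split to "train"
    if source_dataset_split != "train":
        src_dataset["train"] = src_dataset[source_dataset_split]
        del src_dataset[source_dataset_split]

    # Remap the configured test split to "test", or drop a stale "test"
    # so the fallback below can regenerate it
    if test_dataset_split != "test":
        if test_dataset_split in src_dataset:
            src_dataset["test"] = src_dataset[test_dataset_split]
            del src_dataset[test_dataset_split]
        elif "test" in src_dataset:
            del src_dataset["test"]

    # Drop every split that is not "train" or "test"
    for key in list(src_dataset.keys()):
        if key not in ("train", "test"):
            del src_dataset[key]

    # Fallback: carve a test split out of "train" if none survived
    if "test" not in src_dataset:
        split = src_dataset["train"].train_test_split(
            test_size=test_split, shuffle=test_split_shuffle)
        src_dataset["train"], src_dataset["test"] = split["train"], split["test"]
    return src_dataset
```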

# If an int value is used, it is interpreted as a document count
# If a float value (<1.0) is used, it is interpreted as a percentage of the dataset
if kargs["dataset_offset"] > 0 or kargs["dataset_length"] > 0:
@@ -367,28 +384,66 @@ def gen():

# Function used to tokenize the dataset as per the HF tokenizer format;
# given textual data, it returns the tokenized data
def encodeTokens(x):
def encodeTokens(x, enforceSingleItem = False):
if world_tokenizer is True:

# Empty / Null string handling
if x is None:
return {
'input_ids': [],
'token_type_ids': [],
'attention_mask': [],
}

# If x is an array of strings, we encode them separately, and consolidate the result
if isinstance(x, list):
id_arr = []
type_arr = []
mask_arr = []
for i in range(len(x)):
enc_str = world_tokenizer_encode(x[i], world_add_endoftext_token=world_add_endoftext_token)
id_arr.append(enc_str)
type_arr.append([0] * len(enc_str))
mask_arr.append([1] * len(enc_str))

# Consolidate the result
if enforceSingleItem:
# Convert the list to a JSON string
x = json.dumps(x)
else:

# Handle it as an array of strings that needs conversion
id_arr = []
type_arr = []
mask_arr = []
for i in range(len(x)):
enc_str = world_tokenizer_encode(str(x[i]), world_add_endoftext_token=world_add_endoftext_token)
id_arr.append(enc_str)
type_arr.append([0] * len(enc_str))
mask_arr.append([1] * len(enc_str))

# Consolidate the result
return {
'input_ids': id_arr,
'token_type_ids': type_arr,
'attention_mask': mask_arr
}

# Converting from dictionary
if isinstance(x, dict):
# Dictionary to JSON string
x = json.dumps(x)

# Converting from boolean
if isinstance(x, bool):
if x:
x = "true"
else:
x = "false"

# Enforce string type
x = str(x)

# Empty / Null string handling
if len(x) == 0:
return {
'input_ids': id_arr,
'token_type_ids': type_arr,
'attention_mask': mask_arr
'input_ids': [],
'token_type_ids': [],
'attention_mask': [],
}

# Else we encode the string and return it following the HF tokenizer format
enc_str = world_tokenizer_encode(x, world_add_endoftext_token=world_add_endoftext_token)
enc_str = world_tokenizer_encode(str(x), world_add_endoftext_token=world_add_endoftext_token)
return {
'input_ids': enc_str,
'token_type_ids': [0] * len(enc_str),
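
Taken together, the new coercion rules in encodeTokens can be sketched standalone as follows. `fake_encode` is a stand-in for `world_tokenizer_encode` (not shown in this diff), and the snake_case names are illustrative, not the file's actual identifiers:

```python
import json

def fake_encode(s: str):
    # Stand-in tokenizer: one "token" per UTF-8 byte
    return list(s.encode("utf-8"))

def encode_tokens(x, enforce_single_item=False):
    if x is None:
        return {'input_ids': [], 'token_type_ids': [], 'attention_mask': []}
    if isinstance(x, list):
        if enforce_single_item:
            x = json.dumps(x)  # treat the whole list as one JSON string
        else:
            # Encode each item separately and return batched arrays
            encs = [fake_encode(str(item)) for item in x]
            return {'input_ids': encs,
                    'token_type_ids': [[0] * len(e) for e in encs],
                    'attention_mask': [[1] * len(e) for e in encs]}
    if isinstance(x, dict):
        x = json.dumps(x)      # dictionary to JSON string
    if isinstance(x, bool):
        x = "true" if x else "false"
    x = str(x)                 # enforce string type
    if len(x) == 0:
        return {'input_ids': [], 'token_type_ids': [], 'attention_mask': []}
    enc = fake_encode(x)
    return {'input_ids': enc,
            'token_type_ids': [0] * len(enc),
            'attention_mask': [1] * len(enc)}
```

For example, `encode_tokens({'a': 1}, enforce_single_item=True)` tokenizes the JSON text `{"a": 1}` instead of failing on a non-string column.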
@@ -424,34 +479,34 @@ def encodeTokens(x):
# Tokenize the multi column strings
for i in range(len(multi_column_keys)):
if multi_column_prefix is not None and multi_column_prefix[i] is not None:
multi_column_prefix_encodings.append(encodeTokens(multi_column_prefix[i]))
multi_column_prefix_encodings.append(encodeTokens(multi_column_prefix[i], enforceSingleItem=True))
if multi_column_suffix is not None and multi_column_suffix[i] is not None:
multi_column_suffix_encodings.append(encodeTokens(multi_column_suffix[i]))
multi_column_suffix_encodings.append(encodeTokens(multi_column_suffix[i], enforceSingleItem=True))

# Tokenize the multi column separator
if multi_column_separator is not None and len(multi_column_separator) > 0:
multi_column_separator_encodings = encodeTokens(multi_column_separator)
multi_column_separator_encodings = encodeTokens(multi_column_separator, enforceSingleItem=True)

conversation_prefix_encoding_map = {}
conversation_suffix_encoding_map = {}
conversation_end_of_conversation_token = encodeTokens(kargs["conversation_end_of_conversation"]) if kargs["conversation_end_of_conversation"] is not None else None
conversation_end_of_conversation_token = encodeTokens(kargs["conversation_end_of_conversation"], enforceSingleItem=True) if kargs["conversation_end_of_conversation"] is not None else None
conversation_enabled = False
if 'conversation_format' in kargs and kargs["conversation_format"] is not None:
if kargs["conversation_format"] == "iopairs":
# preencode all prefixes (keyed by the input key)
for key, prefix in kargs['conversation_input_key_prefix_map'].items():
conversation_prefix_encoding_map[key] = encodeTokens(prefix)
conversation_prefix_encoding_map[key] = encodeTokens(prefix, enforceSingleItem=True)
conversation_enabled = True
elif kargs["conversation_format"] == "sender":
# preencode all prefixes (keyed by the sender value)
for key, relabel in kargs['conversation_sender_value_map'].items():
for input_key, value in kargs['conversation_input_key_map'].items():
if input_key not in conversation_prefix_encoding_map:
conversation_prefix_encoding_map[input_key] = {}
conversation_prefix_encoding_map[input_key][key] = encodeTokens(value.replace('{sender}', relabel))
conversation_prefix_encoding_map[input_key][key] = encodeTokens(value.replace('{sender}', relabel), enforceSingleItem=True)

for key, suffix in kargs['conversation_sender_suffix'].items():
conversation_suffix_encoding_map[key] = encodeTokens(suffix)
conversation_suffix_encoding_map[key] = encodeTokens(suffix, enforceSingleItem=True)
# example conversation_prefix_encoding_map['message']['user'] = encodeTokens('\n\nUser:')

conversation_enabled = True
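
As a concrete illustration of the "sender" pre-encoding above (the config values here are made up, and `encode()` stands in for `encodeTokens(..., enforceSingleItem=True)`):

```python
def encode(s):
    return {'input_ids': list(s.encode('utf-8'))}  # stand-in tokenizer

conversation_sender_value_map = {"user": "User", "bot": "Assistant"}
conversation_input_key_map = {"message": "\n\n{sender}: "}
conversation_sender_suffix = {"user": "", "bot": "\n"}

prefix_map, suffix_map = {}, {}
for sender, relabel in conversation_sender_value_map.items():
    for input_key, template in conversation_input_key_map.items():
        prefix_map.setdefault(input_key, {})[sender] = encode(
            template.replace("{sender}", relabel))
for sender, suffix in conversation_sender_suffix.items():
    suffix_map[sender] = encode(suffix)

# prefix_map["message"]["user"] now holds the encoding of "\n\nUser: "
```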
@@ -471,7 +526,7 @@ def map_tokenizer(x):
# Custom text column support
if kargs["custom_text_key"] is not None:
if kargs["custom_text_key"] in x:
return encodeTokens(x[kargs["custom_text_key"]])
return encodeTokens(x[kargs["custom_text_key"]], enforceSingleItem=True)

if conversation_enabled:
conv_key = kargs['conversation_key'] if 'conversation_key' in kargs else None
@@ -499,7 +554,7 @@ def map_tokenizer(x):
attention_mask += prefix['attention_mask']

# Tokenize the column
column_encodings = encodeTokens(value)
column_encodings = encodeTokens(value, enforceSingleItem=True)

# Add the column
input_ids += column_encodings['input_ids']
@@ -537,7 +592,7 @@ def map_tokenizer(x):
attention_mask += prefix['attention_mask']

# Tokenize the column
column_encodings = encodeTokens(turn[key])
column_encodings = encodeTokens(turn[key], enforceSingleItem=True)

# Add the column
input_ids += column_encodings['input_ids']
@@ -574,7 +629,7 @@ def map_tokenizer(x):
# that have data in them
num_columns = 0
for i in range(len(multi_column_keys)):
if multi_column_keys[i] in x and x[multi_column_keys[i]] is not None and len(x[multi_column_keys[i]]) > 0:
if multi_column_keys[i] in x and x[multi_column_keys[i]] is not None and len(str(x[multi_column_keys[i]])) > 0:
num_columns += 1
# If we have more than 1 column, we will have to merge them
if num_columns > 1:
@@ -589,21 +644,21 @@ def map_tokenizer(x):
# Lets loop through each column
for i in range(len(multi_column_keys)):
# And process the column if it has data
if multi_column_keys[i] in x and x[multi_column_keys[i]] is not None and len(x[multi_column_keys[i]]) > 0:
if multi_column_keys[i] in x and x[multi_column_keys[i]] is not None and len(str(x[multi_column_keys[i]])) > 0:
# Add the separator if this is not the first item
if not is_first_item and multi_column_separator_encodings is not None:
input_ids += multi_column_separator_encodings['input_ids']
token_type_ids += multi_column_separator_encodings['token_type_ids']
attention_mask += multi_column_separator_encodings['attention_mask']
attention_mask += ([0] * len(multi_column_separator_encodings['input_ids']))

# Add the prefix
if len(multi_column_prefix_encodings) > i and multi_column_prefix_encodings[i] is not None:
input_ids += multi_column_prefix_encodings[i]['input_ids']
token_type_ids += multi_column_prefix_encodings[i]['token_type_ids']
attention_mask += multi_column_prefix_encodings[i]['attention_mask']
attention_mask += ([0] * len(multi_column_prefix_encodings[i]['input_ids']))

# Tokenize the column
column_encodings = encodeTokens(x[multi_column_keys[i]])
column_encodings = encodeTokens(x[multi_column_keys[i]], enforceSingleItem=True)

# Add the column
input_ids += column_encodings['input_ids']
@@ -624,7 +679,7 @@ def map_tokenizer(x):
if len(multi_column_suffix_encodings) > i and multi_column_suffix_encodings[i] is not None:
input_ids += multi_column_suffix_encodings[i]['input_ids']
token_type_ids += multi_column_suffix_encodings[i]['token_type_ids']
attention_mask += multi_column_suffix_encodings[i]['attention_mask']
attention_mask += ([0] * len(multi_column_suffix_encodings[i]['input_ids']))

# Set the first item flag to false
is_first_item = False
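
The repeated attention_mask changes in these hunks share one pattern: structural tokens (separator, prefix, suffix) now receive a zeroed attention_mask instead of the encoder's all-ones mask, presumably so they are excluded from the training loss downstream. A hedged sketch with a hypothetical helper:

```python
def append_segment(segment_enc, input_ids, token_type_ids, attention_mask,
                   is_content: bool):
    # segment_enc follows the encodeTokens dict shape
    input_ids += segment_enc['input_ids']
    token_type_ids += segment_enc['token_type_ids']
    if is_content:
        attention_mask += segment_enc['attention_mask']        # unchanged
    else:
        attention_mask += [0] * len(segment_enc['input_ids'])  # this PR

ids, types, mask = [], [], []
sep = {'input_ids': [11, 12], 'token_type_ids': [0, 0], 'attention_mask': [1, 1]}
append_segment(sep, ids, types, mask, is_content=False)
assert mask == [0, 0]  # separator tokens are masked out
```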
@@ -645,8 +700,8 @@ def map_tokenizer(x):

# Tokenize both prompt and completion
# Note that the tokenizer will process and return the input_ids in batches
prompt_encodings = encodeTokens(x['prompt'])
completion_encodings = encodeTokens(x['completion'])
prompt_encodings = encodeTokens(x['prompt'], enforceSingleItem=True)
completion_encodings = encodeTokens(x['completion'], enforceSingleItem=True)

# Join the two input_ids lists
input_ids = prompt_encodings['input_ids'] + completion_encodings['input_ids']
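
A minimal sketch of this prompt/completion path; the joins past the fold (token_type_ids, attention_mask) are assumed to follow the same concatenation pattern, and `encode()` again stands in for `encodeTokens(..., enforceSingleItem=True)`:

```python
def encode(s):
    return {'input_ids': list(s.encode('utf-8'))}  # stand-in tokenizer

x = {'prompt': 'Q: What is 2+2?\n', 'completion': 'A: 4'}
prompt_encodings = encode(x['prompt'])
completion_encodings = encode(x['completion'])
input_ids = prompt_encodings['input_ids'] + completion_encodings['input_ids']
```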
36 changes: 32 additions & 4 deletions RWKV-v5/src/model.py
@@ -1526,7 +1526,7 @@ class SimpleRWKV():
def __init__(
self,
model_path: str,
ctx_len:int = 1024,
ctx_len:int = 256,
device:str = "cuda",
dtype:str = "fp32"
):
@@ -1610,10 +1610,38 @@ def _forward(
# The all_logits array, if requested
all_logits_arr = None

# For each token, process the state, in batches up to ctx_len
for i in range(0, token_len, self.ctx_len):
# Number of full ctx_len batches we can do
full_len_chunk = token_len // self.ctx_len
full_len_remain = token_len % self.ctx_len

# # For each token, we can process in full ctx_len batches
# for i in range(0, full_len_chunk * self.ctx_len, self.ctx_len):
# # Token set
# token_set = tokens[i:i+self.ctx_len]

# # Check if tokens are already tensors
# batch_tokens = torch.tensor(
# token_set,
# dtype=torch.long, device=self.device
# ).unsqueeze(0)

# # Compute the logits and state
# logits_arr, shift_states, wkv_states = self.model.forward(
# batch_tokens, shift_states, wkv_states
# )

# # Build the all_logits array
# if all_logits:
# if all_logits_arr is None:
# all_logits_arr = logits_arr[0]
# else:
# all_logits_arr = torch.cat([all_logits_arr, logits_arr[0]], dim=0)

# Process every token one at a time (the batched ctx_len path above is
# commented out, so this loop starts at 0 rather than full_len_chunk * self.ctx_len)
for i in range(0, token_len, 1):
# Token set
token_set = tokens[i:i+self.ctx_len]
token_set = tokens[i:i+1]

# Check if tokens are already tensors
batch_tokens = torch.tensor(
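
For reference, here is a runnable sketch of the per-token loop this hunk converges on: the batched ctx_len path is left commented out, and every token is fed individually while the recurrent states are carried between calls. The model.forward signature is taken from the commented-out block above and is an assumption here:

```python
import torch

def forward_all_tokens(model, tokens, device="cuda", all_logits=False):
    shift_states, wkv_states = None, None
    all_logits_arr = None
    # Mirrors `for i in range(0, token_len, 1)` above: one token per step
    for i in range(len(tokens)):
        batch_tokens = torch.tensor(
            tokens[i:i+1], dtype=torch.long, device=device
        ).unsqueeze(0)
        logits_arr, shift_states, wkv_states = model.forward(
            batch_tokens, shift_states, wkv_states
        )
        if all_logits:
            all_logits_arr = (logits_arr[0] if all_logits_arr is None
                              else torch.cat([all_logits_arr, logits_arr[0]], dim=0))
    return all_logits_arr, shift_states, wkv_states
```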