Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: don't add eot token if add_eot_token knob is False #834

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 20 additions & 17 deletions applications/DeepSpeed-Chat/dschat/utils/data/data_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,7 @@ def __getitem__(self, idx):


def create_dataset_split(current_dataset, raw_dataset, train_phase, tokenizer,
end_of_conversation_token, max_seq_len):
end_of_conversation_token, max_seq_len, add_eot_token=True):
prompt_dataset = []
chosen_dataset = []
reject_dataset = []
Expand All @@ -176,7 +176,8 @@ def create_dataset_split(current_dataset, raw_dataset, train_phase, tokenizer,
chosen_sentence = raw_dataset.get_prompt_and_chosen(
tmp_data) # the accept response
if chosen_sentence is not None:
chosen_sentence += end_of_conversation_token
if add_eot_token is True:
chosen_sentence += end_of_conversation_token
chosen_token = tokenizer(chosen_sentence,
max_length=max_seq_len,
padding="max_length",
Expand All @@ -199,8 +200,9 @@ def create_dataset_split(current_dataset, raw_dataset, train_phase, tokenizer,
reject_sentence = raw_dataset.get_prompt_and_rejected(
tmp_data) # the accept response
if chosen_sentence is not None and reject_sentence is not None:
chosen_sentence += end_of_conversation_token # the accept response
reject_sentence += end_of_conversation_token
if add_eot_token is True:
chosen_sentence += end_of_conversation_token # the accept response
reject_sentence += end_of_conversation_token
chosen_token = tokenizer(chosen_sentence,
max_length=max_seq_len,
padding="max_length",
Expand All @@ -211,12 +213,7 @@ def create_dataset_split(current_dataset, raw_dataset, train_phase, tokenizer,
padding="max_length",
truncation=True,
return_tensors="pt")
chosen_token["input_ids"] = chosen_token["input_ids"]
chosen_token["attention_mask"] = chosen_token["attention_mask"]
chosen_dataset.append(chosen_token)

reject_token["input_ids"] = reject_token["input_ids"]
reject_token["attention_mask"] = reject_token["attention_mask"]
reject_dataset.append(reject_token)
print(
f'Creating dataset {raw_dataset.dataset_name_clean} for {train_phase=} size={len(chosen_dataset)}'
Expand Down Expand Up @@ -245,7 +242,7 @@ def create_dataset_split(current_dataset, raw_dataset, train_phase, tokenizer,

def create_dataset(local_rank, dataset_name, data_split, output_path,
train_phase, seed, tokenizer, end_of_conversation_token,
max_seq_len, rebuild):
max_seq_len, rebuild, add_eot_token=True):
raw_dataset = get_raw_dataset(dataset_name, output_path, seed, local_rank)
train_dataset = raw_dataset.get_train_data()
train_index = get_raw_dataset_split_index(local_rank, output_path,
Expand All @@ -257,7 +254,7 @@ def create_dataset(local_rank, dataset_name, data_split, output_path,
train_dataset = create_dataset_split(train_dataset, raw_dataset,
train_phase, tokenizer,
end_of_conversation_token,
max_seq_len)
max_seq_len, add_eot_token=add_eot_token)

eval_dataset = raw_dataset.get_eval_data()
eval_index = get_raw_dataset_split_index(local_rank, output_path,
Expand All @@ -268,7 +265,7 @@ def create_dataset(local_rank, dataset_name, data_split, output_path,
eval_dataset = Subset(eval_dataset, eval_index)
eval_dataset = create_dataset_split(eval_dataset, raw_dataset, train_phase,
tokenizer, end_of_conversation_token,
max_seq_len)
max_seq_len, add_eot_token=add_eot_token)
return train_dataset, eval_dataset


Expand All @@ -281,11 +278,14 @@ def create_prompt_dataset(local_rank,
tokenizer,
max_seq_len,
end_of_conversation_token="<|endoftext|>",
sft_only_data_path=[],
reload=False):
sft_only_data_path=None,
reload=False,
add_eot_token=True):
"""
Creates the prompt dataset
"""
if sft_only_data_path is None:
sft_only_data_path = []
os.makedirs(output_path, exist_ok=True)
fname = "_".join(data_path)
sft_cache_key = "_".join(sft_only_data_path)
Expand Down Expand Up @@ -315,7 +315,8 @@ def create_prompt_dataset(local_rank,
tokenizer,
end_of_conversation_token,
max_seq_len,
rebuild=reload)
rebuild=reload,
add_eot_token=add_eot_token)
else: # Blending datasets.
train_datasets = []
eval_datasets = []
Expand All @@ -332,7 +333,8 @@ def create_prompt_dataset(local_rank,
tokenizer,
end_of_conversation_token,
max_seq_len,
rebuild=reload)
rebuild=reload,
add_eot_token=add_eot_token)
train_datasets.append(train_dataset)
eval_datasets.append(eval_dataset)
train_size += len(train_dataset)
Expand Down Expand Up @@ -361,7 +363,8 @@ def create_prompt_dataset(local_rank,
tokenizer,
end_of_conversation_token,
max_seq_len,
rebuild=reload)
rebuild=reload,
add_eot_token=add_eot_token)
sft_train_datasets.append(sft_train_dataset)
sft_eval_datasets.append(sft_eval_dataset)
sft_train_size += len(sft_train_dataset)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -286,7 +286,7 @@ def main():
train_dataset, eval_dataset = create_prompt_dataset(
args.local_rank, args.data_path, args.data_split,
args.data_output_path, train_phase, args.seed, tokenizer,
args.max_seq_len)
args.max_seq_len, add_eot_token=args.add_eot_token)

# DataLoaders creation:
data_collator = DataCollatorReward()
Expand Down