diff --git a/examples/summarize_daily_cnn/t5_summarize_daily_cnn.py b/examples/summarize_daily_cnn/t5_summarize_daily_cnn.py
index 4c3a56758..97f5633fe 100755
--- a/examples/summarize_daily_cnn/t5_summarize_daily_cnn.py
+++ b/examples/summarize_daily_cnn/t5_summarize_daily_cnn.py
@@ -1,8 +1,6 @@
 from typing import List

 from datasets import load_dataset
-from tqdm import tqdm
-from transformers import AutoTokenizer

 import trlx
 from trlx.data.configs import (
@@ -93,11 +91,10 @@ if __name__ == "__main__":

-    def reward_fn(samples: List[str], prompts: List[str], outputs: List[str]):
-        original_summaries = [prompt_label[prompt.strip()] for prompt in prompts]
+    def reward_fn(samples: List[str], prompts: List[str], outputs: List[str], original_summaries: List[str], **kwargs):
         scores = [
-            meteor.compute(predictions=[output.strip()], references=[original])["meteor"]
-            for (original, output) in zip(original_summaries, outputs)
+            meteor.compute(predictions=[output.strip()], references=[original_summary])["meteor"]
+            for (original_summary, output) in zip(original_summaries, outputs)
         ]
         return scores

@@ -112,31 +109,11 @@ def reward_fn(samples: List[str], prompts: List[str], outputs: List[str]):
     val_prompts = ["Summarize: " + prompt for prompt in dataset["validation"]["article"][0:1000]]
     val_summaries = dataset["validation"]["highlights"][0:1000]

-    # make dictionary of prompts and labels to use for reward function
-    tokenizer = AutoTokenizer.from_pretrained(config.model.model_path)
-    tokenizer.padding_side = "left"
-    tokenizer.truncation_side = "right"
-    tokenizer.sep_token = "<sep>"
-    prompt_label = {}
-    max_length = config.train.seq_length - config.method.gen_kwargs["max_new_tokens"]
-
-    for i in tqdm(range(len(prompts))):
-        key = tokenizer.decode(
-            tokenizer(prompts[i], truncation=True, max_length=max_length, add_special_tokens=False)["input_ids"],
-            skip_special_tokens=True,
-        )  # get prompt like trlx's prompt
-        prompt_label[key.strip()] = summaries[i]
-
-    for i in tqdm(range(len(val_prompts))):
-        key = tokenizer.decode(
-            tokenizer(val_prompts[i], truncation=True, max_length=max_length, add_special_tokens=False)["input_ids"],
-            skip_special_tokens=True,
-        )  # get prompt like trlx's prompt
-        prompt_label[key.strip()] = val_summaries[i]
-
     trlx.train(
         reward_fn=reward_fn,
-        prompts=prompts,
-        eval_prompts=val_prompts,
+        prompts=[{"prompt": prompt, "original_summaries": summary} for prompt, summary in zip(prompts, summaries)],
+        eval_prompts=[
+            {"prompt": prompt, "original_summaries": summary} for prompt, summary in zip(val_prompts, val_summaries)
+        ],
         config=config,
     )
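
Note on the new pattern (not part of the diff): the old code rebuilt a prompt-to-reference-summary lookup table by re-tokenizing every prompt the same way trlx would, then looked the reference up inside reward_fn. The new code instead attaches each reference summary to its prompt dict under an "original_summaries" key, and reward_fn receives it as a keyword argument. Below is a minimal, self-contained sketch of that flow, assuming (as the new call site implies) that trlx collects every non-"prompt" key of the prompt dicts into per-batch keyword arguments for reward_fn; fake_trlx_train and the token-overlap score are illustrative stand-ins, not trlx APIs or the METEOR metric.

# Illustrative sketch only -- fake_trlx_train is a toy stand-in for trlx.train,
# and the token-overlap score is a toy stand-in for METEOR.
from typing import Dict, List


def reward_fn(samples: List[str], prompts: List[str], outputs: List[str], original_summaries: List[str], **kwargs):
    # Same signature as in the diff: the reference summaries arrive as a keyword
    # argument aligned with `outputs`, so no prompt-to-label lookup is needed.
    scores = []
    for reference, output in zip(original_summaries, outputs):
        ref_tokens = reference.lower().split()
        out_tokens = set(output.lower().split())
        scores.append(sum(tok in out_tokens for tok in ref_tokens) / max(len(ref_tokens), 1))
    return scores


def fake_trlx_train(reward_fn, prompts: List[Dict[str, str]]):
    # Assumed behaviour: every key other than "prompt" is collected across the
    # batch and forwarded to reward_fn as a keyword argument (one list per key).
    outputs = ["the cat sat on the mat"] * len(prompts)
    samples = [p["prompt"] + " " + o for p, o in zip(prompts, outputs)]
    metadata = {k: [p[k] for p in prompts] for k in prompts[0] if k != "prompt"}
    return reward_fn(samples=samples, prompts=[p["prompt"] for p in prompts], outputs=outputs, **metadata)


if __name__ == "__main__":
    demo_prompts = [
        {"prompt": "Summarize: A cat sat on a mat all day.", "original_summaries": "the cat sat on a mat"},
        {"prompt": "Summarize: It rained heavily in Paris.", "original_summaries": "heavy rain hit paris"},
    ]
    print(fake_trlx_train(reward_fn, demo_prompts))  # -> [0.8333..., 0.0]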