From e6d4986642b8f1c30fc552349ddef18e9bc0fa4b Mon Sep 17 00:00:00 2001
From: Lili Yu
Date: Fri, 21 Jul 2023 19:11:55 -0700
Subject: [PATCH] enable text eval with non-cm3

Non-train splits now always use the "complete" break mode and never set
no_break_image; only the train split keeps the configured values. Splits
whose name contains one of TEXT_DATA_EVALSETS skip the has_cm3 branch and
are loaded with DocumentToSequenceDataset instead.

---
 metaseq/tasks/streaming_language_modeling.py | 17 +++++++++++++----
 1 file changed, 13 insertions(+), 4 deletions(-)

diff --git a/metaseq/tasks/streaming_language_modeling.py b/metaseq/tasks/streaming_language_modeling.py
index 3751aa09b..4afcc7261 100644
--- a/metaseq/tasks/streaming_language_modeling.py
+++ b/metaseq/tasks/streaming_language_modeling.py
@@ -55,6 +55,8 @@
     "DatasetWithShardInformation", ["dataset", "is_sharded", "shard_id", "num_shards"]
 )
 
+# Substrings of a split name that mark it as a text-only evaluation set.
+TEXT_DATA_EVALSETS = ["llama", "text_eval", "marmot"]
 IMAGE_PREFIX = "IMGIMG"
 
 
@@ -679,8 +681,15 @@ def load_dataset(
 
         dataset = torch.utils.data.ConcatDataset(datasets)
 
+        # Non-train splits always use the "complete" break mode and disable
+        # no_break_image; the train split keeps the configured values.
+        break_mode = "complete" if split != "train" else self.args.sample_break_mode
+        no_break_image = False if split != "train" else self.args.no_break_image
+        # Splits named after a text-only eval set skip the has_cm3 branch below.
+        is_text = any(subset in split for subset in TEXT_DATA_EVALSETS)
+
         # chunk into blocks of tokens
-        if self.has_cm3:
+        if self.has_cm3 and not is_text:
             # We chose not to use compositional inheritance because there's a
             # lot of downstream code that has isinstance checks.
             # So just to be safe and not change anything we use proper inheritance.
@@ -695,19 +704,19 @@ def load_dataset(
                 # We generate blocks with one extra token, so that we have a target
                 # for the final input token. This results in slight data loss.
                 block_size=self.args.tokens_per_sample + 1,
-                break_mode=self.args.sample_break_mode,
+                break_mode=break_mode,
                 # we drop the remainder block during training
                 drop_last=(split == "train"),
                 padding_idx=self.source_dictionary.pad(),
                 seed=self.args.seed,
                 percent_full_document_rotation=self.args.cm3_percent_full_document_rotation,
-                no_break_image=self.args.no_break_image,
+                no_break_image=no_break_image,
             )
         else:
             self.datasets[split] = DocumentToSequenceDataset(
                 dataset,
                 block_size=self.args.tokens_per_sample + 1,
-                break_mode=self.args.sample_break_mode,
+                break_mode=break_mode,
                 drop_last=(split == "train"),
                 padding_idx=self.source_dictionary.pad(),
                 seed=self.args.seed,