Skip to content
This repository has been archived by the owner on Nov 1, 2024. It is now read-only.

Commit

Permalink
enable text eval with non-cm3
Browse files: browse the repository at this point in the history
  • Loading branch information
Lili Yu committed Jul 22, 2023
1 parent 7452544 commit e6d4986
Showing 1 changed file with 9 additions and 4 deletions.
13 changes: 9 additions & 4 deletions metaseq/tasks/streaming_language_modeling.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -55,6 +55,7 @@
"DatasetWithShardInformation", ["dataset", "is_sharded", "shard_id", "num_shards"]
)

TEXT_DATA_EVALSETS = ["llama", "text_eval", "marmot"]
IMAGE_PREFIX = "IMGIMG"


Expand Down Expand Up @@ -679,8 +680,12 @@ def load_dataset(

dataset = torch.utils.data.ConcatDataset(datasets)

break_mode = "complete" if split != "train" else self.args.sample_break_mode
no_image_break = False if split != "train" else self.args.no_image_break
is_text = any([subset in split for subset in TEXT_DATA_EVALSETS])

# chunk into blocks of tokens
if self.has_cm3:
if self.has_cm3 and not is_text:
# We chose not to use compositional inheritance because there's a
# lot of downstream code that has isinstance checks.
# So just to be safe and not change anything we use proper inheritance.
Expand All @@ -695,19 +700,19 @@ def load_dataset(
# We generate blocks with one extra token, so that we have a target
# for the final input token. This results in slight data loss.
block_size=self.args.tokens_per_sample + 1,
break_mode=self.args.sample_break_mode,
break_mode=break_mode,
# we drop the remainder block during training
drop_last=(split == "train"),
padding_idx=self.source_dictionary.pad(),
seed=self.args.seed,
percent_full_document_rotation=self.args.cm3_percent_full_document_rotation,
no_break_image=self.args.no_break_image,
no_break_image=no_image_break,
)
else:
self.datasets[split] = DocumentToSequenceDataset(
dataset,
block_size=self.args.tokens_per_sample + 1,
break_mode=self.args.sample_break_mode,
break_mode=break_mode,
drop_last=(split == "train"),
padding_idx=self.source_dictionary.pad(),
seed=self.args.seed,
Expand Down

0 comments on commit e6d4986

Please sign in to comment.