From e908303e72148c41d9049e450bbf4631c6e06762 Mon Sep 17 00:00:00 2001
From: Toolkit User
Date: Tue, 28 Jan 2025 22:33:38 +0000
Subject: [PATCH] fix fim tests

---
 tests/test_dataset.py | 109 ++++++++++++++++++++++++++++++++++++------
 1 file changed, 94 insertions(+), 15 deletions(-)

diff --git a/tests/test_dataset.py b/tests/test_dataset.py
index c693b6c..6c5bb58 100644
--- a/tests/test_dataset.py
+++ b/tests/test_dataset.py
@@ -71,6 +71,7 @@ def get_test_data_and_samples(
     seed: int = 54983,
     cache_directory: pathlib.Path | None = None,
     sequence_length: int = 512,
+    consumed_samples: int = 0,
     vocab_size=TEST_VOCAB_SIZE,
 ):
     distributed_config = DistributedConfig(seed=seed)
@@ -82,7 +83,7 @@ def get_test_data_and_samples(
     batch_config.setup(distributed_config)
     batch_config.validate()
     samples = {
-        phase: list(data.get_iterator(batch_config, phase, consumed_samples=0, num_workers=0))
+        phase: list(data.get_iterator(batch_config, phase, consumed_samples=consumed_samples, num_workers=0))
         for phase, n_samples in samples_per_phase.items()
     }
     return data, samples
@@ -818,6 +819,8 @@ def test_gpt_blended_mixed_data():
     [],
 ]
 
+GPT_FIM_VALID_IDS = [2, 3, 4, 6, 7]
+
 
 def test_gpt_fim():
     # Make sure the FIM wrapper works in a simple case and check for unintended changes in behavior.
@@ -845,17 +848,54 @@ def test_gpt_fim():
     Assert.eq(len(sampled), 8)
     # TODO: Does this output make sense?
     Assert.all_equal(
-        np.stack([sampled[i].ids for i in range(8)]),
-        np.array(GPT_FIM_EXPECTED_SAMPLES_IDS),
+        np.stack([sampled[i].ids for i in GPT_FIM_VALID_IDS]),
+        np.array([GPT_FIM_EXPECTED_SAMPLES_IDS[i] for i in GPT_FIM_VALID_IDS]),
     )
     Assert.all_equal(
-        np.vstack([sampled[i].spans for i in range(8)]),
-        np.vstack([np.array(x, dtype=sampled[0].spans.dtype).reshape(-1, 2) for x in GPT_FIM_EXPECTED_SAMPLES_SPANS]),
+        np.vstack([sampled[i].spans for i in GPT_FIM_VALID_IDS]),
+        np.vstack(
+            [
+                np.array(x, dtype=sampled[GPT_FIM_VALID_IDS[0]].spans.dtype).reshape(-1, 2)
+                for x in [GPT_FIM_EXPECTED_SAMPLES_SPANS[i] for i in GPT_FIM_VALID_IDS]
+            ]
+        ),
     )
 
 
 def test_gpt_fim_data():
-    _, samples = get_test_data_and_samples(
+    _, samples1 = get_test_data_and_samples(
+        {
+            "datasets": {
+                "Training": {
+                    "type": "fim",
+                    "dataset": {"type": "memmap", "path": DATASET_PREFIX},
+                    "rate": 0.5,
+                    "prefix_token": "w",
+                    "middle_token": "x",
+                    "pad_token": "y",
+                    "suffix_token": "z",
+                }
+            },
+            "tokenizer": {"path": TOKENIZER_PATH, "special_tokens_mode": "tokenizer_default"},
+        },
+        {PhaseType.training: 5},
+        sequence_length=5,
+        consumed_samples=2,
+    )
+    Assert.all_equal(
+        np.stack([batch.ids[0] for batch in samples1[PhaseType.training]]),
+        np.array([GPT_FIM_EXPECTED_SAMPLES_IDS[i] for i in GPT_FIM_VALID_IDS[:3]]),
+    )
+    Assert.all_equal(
+        np.vstack([batch.spans[0] for batch in samples1[PhaseType.training]]),
+        np.vstack(
+            [
+                np.array(x, dtype=np.int32).reshape(-1, 2)
+                for x in [GPT_FIM_EXPECTED_SAMPLES_SPANS[i] for i in GPT_FIM_VALID_IDS[:3]]
+            ]
+        ),
+    )
+    _, samples2 = get_test_data_and_samples(
         {
             "datasets": {
                 "Training": {
@@ -872,19 +912,52 @@ def test_gpt_fim_data():
         },
         {PhaseType.training: 8},
         sequence_length=5,
+        consumed_samples=6,
     )
     Assert.all_equal(
-        np.stack([batch.ids[0] for batch in samples[PhaseType.training]]),
-        np.array(GPT_FIM_EXPECTED_SAMPLES_IDS),
+        np.stack([batch.ids[0] for batch in samples2[PhaseType.training]]),
+        np.array([GPT_FIM_EXPECTED_SAMPLES_IDS[i] for i in GPT_FIM_VALID_IDS[3:]]),
     )
     Assert.all_equal(
-        np.vstack([batch.spans[0] for batch in samples[PhaseType.training]]),
-        np.vstack([np.array(x, dtype=np.int32).reshape(-1, 2) for x in GPT_FIM_EXPECTED_SAMPLES_SPANS]),
+        np.vstack([batch.spans[0] for batch in samples2[PhaseType.training]]),
+        np.vstack(
+            [
+                np.array(x, dtype=np.int32).reshape(-1, 2)
+                for x in [GPT_FIM_EXPECTED_SAMPLES_SPANS[i] for i in GPT_FIM_VALID_IDS[3:]]
+            ]
+        ),
     )
 
 
 def test_gpt_fim_data_legacy():
-    _, samples = get_test_data_and_samples(
+    _, samples1 = get_test_data_and_samples(
+        {
+            "format": "list",
+            "path": [str(DATASET_PREFIX)],
+            "fim": {"rate": 0.5, "prefix_token": "w", "middle_token": "x", "pad_token": "y", "suffix_token": "z"},
+            "tokenizer": {"path": TOKENIZER_PATH, "special_tokens_mode": "tokenizer_default"},
+            "split": [1, 0, 0],
+        },
+        {PhaseType.training: 5},
+        sequence_length=5,
+        consumed_samples=2,
+    )
+    Assert.all_equal(
+        np.stack([batch.ids[0] for batch in samples1[PhaseType.training]]),
+        np.array(
+            [GPT_FIM_EXPECTED_SAMPLES_IDS[i] for i in GPT_FIM_VALID_IDS[:3]],
+        ),
+    )
+    Assert.all_equal(
+        np.vstack([batch.spans[0] for batch in samples1[PhaseType.training]]),
+        np.vstack(
+            [
+                np.array(x, dtype=np.int32).reshape(-1, 2)
+                for x in [GPT_FIM_EXPECTED_SAMPLES_SPANS[i] for i in GPT_FIM_VALID_IDS[:3]]
+            ]
+        ),
+    )
+    _, samples2 = get_test_data_and_samples(
         {
             "format": "list",
             "path": [str(DATASET_PREFIX)],
@@ -894,12 +967,18 @@ def test_gpt_fim_data_legacy():
         },
         {PhaseType.training: 8},
         sequence_length=5,
+        consumed_samples=6,
     )
     Assert.all_equal(
-        np.stack([batch.ids[0] for batch in samples[PhaseType.training]]),
-        np.array(GPT_FIM_EXPECTED_SAMPLES_IDS),
+        np.stack([batch.ids[0] for batch in samples2[PhaseType.training]]),
+        np.array([GPT_FIM_EXPECTED_SAMPLES_IDS[i] for i in GPT_FIM_VALID_IDS[3:]]),
     )
     Assert.all_equal(
-        np.vstack([batch.spans[0] for batch in samples[PhaseType.training]]),
-        np.vstack([np.array(x, dtype=np.int32).reshape(-1, 2) for x in GPT_FIM_EXPECTED_SAMPLES_SPANS]),
+        np.vstack([batch.spans[0] for batch in samples2[PhaseType.training]]),
+        np.vstack(
+            [
+                np.array(x, dtype=np.int32).reshape(-1, 2)
+                for x in [GPT_FIM_EXPECTED_SAMPLES_SPANS[i] for i in GPT_FIM_VALID_IDS[3:]]
+            ]
+        ),
     )