Skip to content

Commit

Permalink
fix fim tests
Browse files (browse the repository at this point in the history)
  • Loading branch information
sohamparikh committed Jan 28, 2025
1 parent 40a80f6 commit e908303
Showing 1 changed file with 94 additions and 15 deletions.
109 changes: 94 additions & 15 deletions tests/test_dataset.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -71,6 +71,7 @@ def get_test_data_and_samples(
seed: int = 54983,
cache_directory: pathlib.Path | None = None,
sequence_length: int = 512,
consumed_samples: int = 0,
vocab_size=TEST_VOCAB_SIZE,
):
distributed_config = DistributedConfig(seed=seed)
Expand All @@ -82,7 +83,7 @@ def get_test_data_and_samples(
batch_config.setup(distributed_config)
batch_config.validate()
samples = {
phase: list(data.get_iterator(batch_config, phase, consumed_samples=0, num_workers=0))
phase: list(data.get_iterator(batch_config, phase, consumed_samples=consumed_samples, num_workers=0))
for phase, n_samples in samples_per_phase.items()
}
return data, samples
Expand Down Expand Up @@ -818,6 +819,8 @@ def test_gpt_blended_mixed_data():
[],
]

GPT_FIM_VALID_IDS = [2, 3, 4, 6, 7]


def test_gpt_fim():
# Make sure the FIM wrapper works in a simple case and check for unintended changes in behavior.
Expand Down Expand Up @@ -845,17 +848,54 @@ def test_gpt_fim():
Assert.eq(len(sampled), 8)
# TODO: Does this output make sense?
Assert.all_equal(
np.stack([sampled[i].ids for i in range(8)]),
np.array(GPT_FIM_EXPECTED_SAMPLES_IDS),
np.stack([sampled[i].ids for i in GPT_FIM_VALID_IDS]),
np.array([GPT_FIM_EXPECTED_SAMPLES_IDS[i] for i in GPT_FIM_VALID_IDS]),
)
Assert.all_equal(
np.vstack([sampled[i].spans for i in range(8)]),
np.vstack([np.array(x, dtype=sampled[0].spans.dtype).reshape(-1, 2) for x in GPT_FIM_EXPECTED_SAMPLES_SPANS]),
np.vstack([sampled[i].spans for i in GPT_FIM_VALID_IDS]),
np.vstack(
[
np.array(x, dtype=sampled[GPT_FIM_VALID_IDS[0]].spans.dtype).reshape(-1, 2)
for x in [GPT_FIM_EXPECTED_SAMPLES_SPANS[i] for i in GPT_FIM_VALID_IDS]
]
),
)


def test_gpt_fim_data():
_, samples = get_test_data_and_samples(
_, samples1 = get_test_data_and_samples(
{
"datasets": {
"Training": {
"type": "fim",
"dataset": {"type": "memmap", "path": DATASET_PREFIX},
"rate": 0.5,
"prefix_token": "w",
"middle_token": "x",
"pad_token": "y",
"suffix_token": "z",
}
},
"tokenizer": {"path": TOKENIZER_PATH, "special_tokens_mode": "tokenizer_default"},
},
{PhaseType.training: 5},
sequence_length=5,
consumed_samples=2,
)
Assert.all_equal(
np.stack([batch.ids[0] for batch in samples1[PhaseType.training]]),
np.array([GPT_FIM_EXPECTED_SAMPLES_IDS[i] for i in GPT_FIM_VALID_IDS[:3]]),
)
Assert.all_equal(
np.vstack([batch.spans[0] for batch in samples1[PhaseType.training]]),
np.vstack(
[
np.array(x, dtype=np.int32).reshape(-1, 2)
for x in [GPT_FIM_EXPECTED_SAMPLES_SPANS[i] for i in GPT_FIM_VALID_IDS[:3]]
]
),
)
_, samples2 = get_test_data_and_samples(
{
"datasets": {
"Training": {
Expand All @@ -872,19 +912,52 @@ def test_gpt_fim_data():
},
{PhaseType.training: 8},
sequence_length=5,
consumed_samples=6,
)
Assert.all_equal(
np.stack([batch.ids[0] for batch in samples[PhaseType.training]]),
np.array(GPT_FIM_EXPECTED_SAMPLES_IDS),
np.stack([batch.ids[0] for batch in samples2[PhaseType.training]]),
np.array([GPT_FIM_EXPECTED_SAMPLES_IDS[i] for i in GPT_FIM_VALID_IDS[3:]]),
)
Assert.all_equal(
np.vstack([batch.spans[0] for batch in samples[PhaseType.training]]),
np.vstack([np.array(x, dtype=np.int32).reshape(-1, 2) for x in GPT_FIM_EXPECTED_SAMPLES_SPANS]),
np.vstack([batch.spans[0] for batch in samples2[PhaseType.training]]),
np.vstack(
[
np.array(x, dtype=np.int32).reshape(-1, 2)
for x in [GPT_FIM_EXPECTED_SAMPLES_SPANS[i] for i in GPT_FIM_VALID_IDS[3:]]
]
),
)


def test_gpt_fim_data_legacy():
_, samples = get_test_data_and_samples(
_, samples1 = get_test_data_and_samples(
{
"format": "list",
"path": [str(DATASET_PREFIX)],
"fim": {"rate": 0.5, "prefix_token": "w", "middle_token": "x", "pad_token": "y", "suffix_token": "z"},
"tokenizer": {"path": TOKENIZER_PATH, "special_tokens_mode": "tokenizer_default"},
"split": [1, 0, 0],
},
{PhaseType.training: 5},
sequence_length=5,
consumed_samples=2,
)
Assert.all_equal(
np.stack([batch.ids[0] for batch in samples1[PhaseType.training]]),
np.array(
[GPT_FIM_EXPECTED_SAMPLES_IDS[i] for i in GPT_FIM_VALID_IDS[:3]],
),
)
Assert.all_equal(
np.vstack([batch.spans[0] for batch in samples1[PhaseType.training]]),
np.vstack(
[
np.array(x, dtype=np.int32).reshape(-1, 2)
for x in [GPT_FIM_EXPECTED_SAMPLES_SPANS[i] for i in GPT_FIM_VALID_IDS[:3]]
]
),
)
_, samples2 = get_test_data_and_samples(
{
"format": "list",
"path": [str(DATASET_PREFIX)],
Expand All @@ -894,12 +967,18 @@ def test_gpt_fim_data_legacy():
},
{PhaseType.training: 8},
sequence_length=5,
consumed_samples=6,
)
Assert.all_equal(
np.stack([batch.ids[0] for batch in samples[PhaseType.training]]),
np.array(GPT_FIM_EXPECTED_SAMPLES_IDS),
np.stack([batch.ids[0] for batch in samples2[PhaseType.training]]),
np.array([GPT_FIM_EXPECTED_SAMPLES_IDS[i] for i in GPT_FIM_VALID_IDS[3:]]),
)
Assert.all_equal(
np.vstack([batch.spans[0] for batch in samples[PhaseType.training]]),
np.vstack([np.array(x, dtype=np.int32).reshape(-1, 2) for x in GPT_FIM_EXPECTED_SAMPLES_SPANS]),
np.vstack([batch.spans[0] for batch in samples2[PhaseType.training]]),
np.vstack(
[
np.array(x, dtype=np.int32).reshape(-1, 2)
for x in [GPT_FIM_EXPECTED_SAMPLES_SPANS[i] for i in GPT_FIM_VALID_IDS[3:]]
]
),
)

0 comments on commit e908303

Please sign in to comment.