Skip to content

Commit

Permalink
Fix merge
Browse files Browse the repository at this point in the history
  • Loading branch information
jlamypoirier committed Jan 22, 2025
1 parent 9dbbcf9 commit 6dea63e
Showing 1 changed file with 6 additions and 3 deletions.
9 changes: 6 additions & 3 deletions tests/test_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,7 @@ def test_gpt_sampled():


def test_gpt_sampled_data():
get_test_dataset()
_, samples = get_test_data_and_samples(
{
"datasets": {
Expand All @@ -204,7 +205,9 @@ def test_gpt_sampled_data():

def test_gpt_sampled_data_legacy():
_, samples = get_test_data_and_samples(
{"format": "list", "path": [DATASET_PREFIX], "split": [1, 0, 0]}, {PhaseType.training: 8}, sequence_length=5
{"format": "list", "path": [str(DATASET_PREFIX)], "split": [1, 0, 0]},
{PhaseType.training: 8},
sequence_length=5,
)
Assert.all_equal(
np.stack(samples[PhaseType.training]),
Expand Down Expand Up @@ -414,7 +417,7 @@ def test_gpt_blended_data():
sequence_length=5,
)
Assert.all_equal(
np.stack(samples[PhaseType.validation]),
np.stack(samples[PhaseType.training]),
np.array(GPT_BLENDED_EXPECTED_SAMPLES),
)

Expand All @@ -425,7 +428,7 @@ def test_gpt_blended_data_legacy():
_, samples = get_test_data_and_samples(
{
"format": "list",
"path": [0.75, str(DATASET_PREFIX), 0.25, str(DATASET_PREFIX_MIX_1)],
"path": ["0.75", str(DATASET_PREFIX), "0.25", str(DATASET_PREFIX_MIX_1)],
"split": [1, 0, 0],
},
{PhaseType.training: 8},
Expand Down

0 comments on commit 6dea63e

Please sign in to comment.