
Commit

Fix merge
jlamypoirier committed Jan 22, 2025
1 parent a0aae75 commit 9dbbcf9
Showing 5 changed files with 36 additions and 16 deletions.
fast_llm/data/dataset/gpt/config.py (6 changes: 3 additions & 3 deletions)
@@ -48,9 +48,9 @@ class GPTSampledDatasetConfig(SampledDatasetConfig):
     )

     def _validate(self) -> None:
-        if self.type is not None:
-            # Should be handled in `from_dict`, but can fail if instantiating directly.
-            Assert.eq(self.type, self.type_)
+        if self.type is None:
+            self.type = self.type_
+        # Should be handled in `from_dict`, but can fail if instantiating directly.
+        Assert.eq(self.type, self.__class__.type_)
         super()._validate()

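The change above makes `_validate` default a missing `type` from the class-level `type_` before asserting the two match. A minimal standalone sketch of that pattern, with illustrative names rather than the Fast-LLM API:

    class DatasetConfig:
        type_: str | None = None  # each subclass declares its expected type

        def __init__(self, type: str | None = None):
            self.type = type

        def _validate(self) -> None:
            if self.type is None:
                # Fall back to the subclass's declared type.
                self.type = self.type_
            # Whether given or defaulted, the type must match the class constant.
            assert self.type == self.__class__.type_


    class MemmapDatasetConfig(DatasetConfig):
        type_ = "memmap"


    MemmapDatasetConfig()._validate()  # `type` is filled in as "memmap"
    MemmapDatasetConfig(type="memmap")._validate()  # an explicit matching type also passes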
fast_llm/data/dataset/gpt/fim.py (2 changes: 2 additions & 0 deletions)
@@ -21,6 +21,8 @@ def __init__(
         self._dataset = dataset
         self._sampling_config = sampling_config
         self._tokenizer = sampling_config.tokenizer
+        if self._tokenizer is None:
+            raise ValueError("Fim requires a tokenizer")
         self._suffix_tok_id, self._prefix_tok_id, self._middle_tok_id, self._pad_tok_id = (
             self._tokenizer.vocab[tok]
             for tok in [config.suffix_token, config.prefix_token, config.middle_token, config.pad_token]
fast_llm/data/dataset/gpt/random.py (4 changes: 2 additions & 2 deletions)
@@ -13,15 +13,15 @@ def __init__(self, name: str):
         self._name = name

     def sample(self, config: GPTSamplingConfig) -> "GPTRandomSampledDataset":
-        return GPTRandomSampledDataset(f"{self.name}_sampled", config)
+        return GPTRandomSampledDataset(config, f"{self.name}_sampled")

     @property
     def name(self) -> str:
         return self._name


 class GPTRandomSampledDataset(SampledDataset):
-    def __init__(self, name: str, config: GPTSamplingConfig):
+    def __init__(self, config: GPTSamplingConfig, name: str):
         self._name = name
         self._config = config
fast_llm/data/dataset/gpt/sampled.py (2 changes: 1 addition & 1 deletion)
@@ -138,7 +138,7 @@ def __getstate__(
         if hasattr(self, "_doc_idx_filename"):
             return (
                 self._indexed_dataset,
-                self._doc_idx,
+                self._doc_idx_filename,
                 self._sample_idx_filename,
                 self._shuffle_idx_filename,
             )
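The one-line change above sits inside `__getstate__`, so it affects pickling: the state now carries the `_doc_idx` filename, matching the other two index filenames, rather than the array itself, presumably so the index can be reopened from disk on unpickling. A minimal sketch of that pattern, using a hypothetical class rather than the Fast-LLM one:

    import numpy as np


    class MemmapIndex:
        def __init__(self, doc_idx_filename: str):
            self._doc_idx_filename = doc_idx_filename
            # Large index array kept on disk, memory-mapped read-only.
            self._doc_idx = np.load(doc_idx_filename, mmap_mode="r")

        def __getstate__(self):
            # Pickle only the path, never the mapped array.
            return (self._doc_idx_filename,)

        def __setstate__(self, state):
            # Reopen the memmap from the stored filename.
            self.__init__(*state)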
tests/test_dataset.py (38 changes: 28 additions & 10 deletions)
@@ -198,17 +198,17 @@ def test_gpt_sampled_data():
     )
     Assert.all_equal(
         np.stack(samples[PhaseType.training]),
-        np.array(RANDOM_DATASET_EXPECTED_SAMPLES),
+        np.array(GPT_SAMPLED_EXPECTED_SAMPLES),
     )


 def test_gpt_sampled_data_legacy():
     _, samples = get_test_data_and_samples(
-        {"format": "list", "path": [DATASET_PREFIX], "split": [1]}, {PhaseType.training: 8}, sequence_length=5
+        {"format": "list", "path": [DATASET_PREFIX], "split": [1, 0, 0]}, {PhaseType.training: 8}, sequence_length=5
     )
     Assert.all_equal(
         np.stack(samples[PhaseType.training]),
-        np.array(RANDOM_DATASET_EXPECTED_SAMPLES),
+        np.array(GPT_SAMPLED_EXPECTED_SAMPLES),
     )


@@ -238,7 +238,7 @@ def test_gpt_concatenate():
         begin = i * MEMMAP_DATASET_EXPECTED_LENGTH
         Assert.all_equal([len(dataset.get(begin + i)) for i in range(100)], sizes[begin : begin + 100])
         for i, sample in MEMMAP_DATASET_EXPECTED_SAMPLES.items():
-            Assert.all_equal(dataset.get(begin + i), np.array(sample, dtype=numpy.uint16))
+            Assert.all_equal(dataset.get(begin + i), np.array(sample, dtype=np.uint16))
     sampled = dataset.sample(get_sampling_config(8, sequence_length=5))
     Assert.eq(len(sampled), 8)
     Assert.all_equal(
@@ -262,7 +262,7 @@ def test_gpt_concatenate_data():
     )
     Assert.all_equal(
         np.stack(samples[PhaseType.training]),
-        np.array(RANDOM_DATASET_EXPECTED_SAMPLES),
+        np.array(GPT_CONCATENATED_EXPECTED_SAMPLES),
     )


@@ -344,6 +344,7 @@ def test_gpt_slice_data():


 def test_gpt_slice_data_legacy():
+    get_test_dataset()
     _, samples = get_test_data_and_samples(
         {"format": "list", "path": [str(DATASET_PREFIX)], "split": [0.0015, 0.0015, 0.997]},
         {PhaseType.training: 4, PhaseType.validation: 8, PhaseType.test: 5},
@@ -374,10 +375,14 @@ def test_gpt_slice_data_legacy():
 def test_gpt_blended():
     # Make sure dataset blending works and check for unintended changes in behavior.
     get_test_dataset()
+    get_test_dataset_1()
     sampled = _get_dataset_config(
         {
             "type": "blended",
-            "datasets": [{"type": "memmap", "path": DATASET_PREFIX} for _ in range(2)],
+            "datasets": [
+                {"type": "memmap", "path": DATASET_PREFIX},
+                {"type": "memmap", "path": DATASET_PREFIX_MIX_1},
+            ],
             "weights": [0.75, 0.25],
         },
         GPTBlendedDatasetConfig,
@@ -390,12 +395,17 @@


 def test_gpt_blended_data():
+    get_test_dataset()
+    get_test_dataset_1()
     _, samples = get_test_data_and_samples(
         {
             "datasets": {
                 "Training": {
                     "type": "blended",
-                    "datasets": [{"type": "memmap", "path": DATASET_PREFIX} for _ in range(2)],
+                    "datasets": [
+                        {"type": "memmap", "path": DATASET_PREFIX},
+                        {"type": "memmap", "path": DATASET_PREFIX_MIX_1},
+                    ],
                     "weights": [0.75, 0.25],
                 }
             }
@@ -409,9 +419,15 @@
     )


-def test_gpt_blended_legacy_data():
+def test_gpt_blended_data_legacy():
+    get_test_dataset()
+    get_test_dataset_1()
     _, samples = get_test_data_and_samples(
-        {"format": "list", "path": [0.75, DATASET_PREFIX, 0.25, DATASET_PREFIX]},
+        {
+            "format": "list",
+            "path": [0.75, str(DATASET_PREFIX), 0.25, str(DATASET_PREFIX_MIX_1)],
+            "split": [1, 0, 0],
+        },
         {PhaseType.training: 8},
         sequence_length=5,
     )
@@ -526,7 +542,8 @@ def test_gpt_fim_data():
                     "pad_token": "y",
                     "suffix_token": "z",
                 }
-            }
+            },
+            "tokenizer": {"path": TOKENIZER_PATH},
         },
         {PhaseType.training: 8},
         sequence_length=5,
@@ -544,6 +561,7 @@ def test_gpt_fim_data_legacy():
             "path": [str(DATASET_PREFIX)],
             "fim": {"rate": 0.5, "prefix_token": "w", "middle_token": "x", "pad_token": "y", "suffix_token": "z"},
             "tokenizer": {"path": TOKENIZER_PATH},
+            "split": [1, 0, 0],
         },
         {PhaseType.training: 8},
         sequence_length=5,
