Merge branch 'main' into soham/loss-masking-spans
sohamparikh authored Jan 28, 2025
2 parents 4f955ff + 6dc77a0 commit 70e40e8
Showing 23 changed files with 383 additions and 130 deletions.
21 changes: 17 additions & 4 deletions fast_llm/core/distributed.py
@@ -7,6 +7,7 @@
"""

import contextlib
import datetime
import logging
import typing

@@ -25,12 +26,21 @@
logger = logging.getLogger(__name__)


def broadcast(tensor: torch.Tensor, src: int, group: ProcessGroup, async_op=False) -> Work | None:
def add_ephemeral_timeout(group: ProcessGroup, timeout: float | None = None) -> None:
if group is not None and timeout is not None:
# TODO: Only works for nccl?
group._add_ephemeral_timeout(datetime.timedelta(seconds=timeout))


def broadcast(
tensor: torch.Tensor, src: int, group: ProcessGroup, async_op=False, timeout: float | None = None
) -> Work | None:
"""Same as torch.distributed.broadcast, but without the complication of going through the global rank."""
assert group is not None
opts = torch.distributed.BroadcastOptions()
opts.rootRank = src
opts.rootTensor = 0
add_ephemeral_timeout(group, timeout)
work = group.broadcast([tensor], opts)
if async_op:
return work
@@ -53,10 +63,10 @@ def check_parallel_match(tensor: torch.Tensor, group: ProcessGroup | None, name:
)


def safe_barrier(group: ProcessGroup | None, value: int | str = 1) -> None:
def safe_barrier(group: ProcessGroup | None, value: int | str = 1, timeout: float | None = None) -> None:
if group:
hashed = hash(value) % 2**32
out = allreduce_scalar(hashed, dtype=torch.int64, group=group)
out = allreduce_scalar(hashed, dtype=torch.int64, group=group, timeout=timeout)
if out != hashed * group.size():
raise RuntimeError(f"Desync detected for barrier {value} ({out}!={hashed*group.size()})")

@@ -66,9 +76,11 @@ def allreduce_scalar(
dtype: torch.dtype = torch.float64,
group: torch.distributed.ProcessGroup | None = None,
op=ReduceOp.SUM,
timeout: float | None = None,
) -> float | int:
if group:
value = torch.full([1], value, dtype=dtype, device=torch.cuda.current_device())
add_ephemeral_timeout(group, timeout)
torch.distributed.all_reduce(value, op=op, group=group)
return value.item()
else:
@@ -80,13 +92,14 @@ def broadcast_scalar(
dtype: torch.dtype = torch.float64,
group: torch.distributed.ProcessGroup | None = None,
src: int = 0,
timeout: float | None = None,
) -> float | int:
if not group:
return value
tensor = torch.empty([1], dtype=dtype, device=torch.device(torch.cuda.current_device()))
if group.rank() == src:
tensor.fill_(value)
broadcast(tensor, src, group)
broadcast(tensor, src, group, timeout=timeout)
return tensor.item()


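For context, a minimal sketch of how the new timeout argument could be threaded through these collectives; the helper below and the 60-second value are illustrative, not part of this commit:

    import torch
    import torch.distributed

    from fast_llm.core.distributed import broadcast_scalar, safe_barrier


    def sync_seed_with_timeout(group: "torch.distributed.ProcessGroup | None", seed: int) -> int:
        # Hypothetical helper: broadcast a seed from rank 0, granting the collective
        # an extra 60 seconds beyond the group's default timeout before it aborts.
        seed = int(broadcast_scalar(seed, dtype=torch.int64, group=group, src=0, timeout=60.0))
        # Barrier with the same ephemeral timeout, so one slow rank does not
        # immediately trip a desync error on the others.
        safe_barrier(group, "seed_sync", timeout=60.0)
        return seed
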
1 change: 1 addition & 0 deletions fast_llm/data/data/abstract.py
@@ -26,6 +26,7 @@ def setup(
distributed: "Distributed",
samples_per_phase: dict[PhaseType, int],
cache_directory: pathlib.Path,
timeout: float | None = None,
) -> None:
self._distributed = distributed
self._samples_per_phase = samples_per_phase
4 changes: 4 additions & 0 deletions fast_llm/data/data/gpt/data.py
@@ -8,6 +8,7 @@
import torch
import torch.utils.data

from fast_llm.core.distributed import safe_barrier
from fast_llm.data.data.abstract import Data
from fast_llm.data.data.gpt.config import GPTDataConfig
from fast_llm.data.dataset.abstract import SampledDataset
@@ -69,6 +70,7 @@ def setup(
distributed: "Distributed",
samples_per_phase: dict[PhaseType, int],
cache_directory: pathlib.Path,
timeout: float | None = None,
) -> None:
"""
Load the datasets, and prepare or load the samplings.
@@ -99,6 +101,8 @@
)
dataset = self._config.datasets[phase].build_and_sample(sampling_config)
self._datasets[phase] = DatasetMonitor(dataset, self._config.data_sample_warn_time_ms)

safe_barrier(self._distributed.world_group, "data_preparation", timeout)
self._is_setup = True

@property
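A rough sketch of how a caller might pass a preparation timeout into setup() so that the "data_preparation" barrier above does not abort while rank 0 is still building indices; the helper name and the one-hour value are invented for illustration:

    import pathlib

    from fast_llm.data.data.abstract import Data


    def setup_data_with_timeout(data: Data, distributed, samples_per_phase: dict, cache_dir: pathlib.Path) -> None:
        # Give rank 0 up to an hour to build and cache the sampled datasets before
        # the barrier at the end of setup() times out on the waiting ranks.
        data.setup(distributed, samples_per_phase, cache_dir, timeout=3600.0)
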
23 changes: 8 additions & 15 deletions fast_llm/data/dataset/blended.py
@@ -4,10 +4,8 @@

import numpy as np

from fast_llm.core.distributed import safe_barrier
from fast_llm.data.dataset.abstract import SampledDataset
from fast_llm.data.dataset.config import SamplingConfig
from fast_llm.engine.config_utils.run import log_main_rank
from fast_llm.utils import Assert, normalize_probabilities

try:
@@ -44,7 +42,7 @@ def __init__(

if sampling_config.cache_directory is None:
self._dataset_idx_filename, self._sample_idx_filename = None, None
self._dataset_index, self._sample_index = self._build_blending_indices(len(self._datasets) <= 20)
self._dataset_index, self._sample_index = self._build_blending_indices()
else:
group = sampling_config.distributed.world_group
self._dataset_idx_filename = sampling_config.cache_directory / (self._name + "_blending_dataset_idx.npy")
@@ -55,14 +53,11 @@
if (group is None or group.rank() == 0) and not (
self._dataset_idx_filename.is_file() and self._sample_idx_filename.is_file()
):
dataset_index, sample_index = self._build_blending_indices(len(self._datasets) <= 20)
dataset_index, sample_index = self._build_blending_indices()
sampling_config.cache_directory.mkdir(exist_ok=True, parents=True)
np.save(self._dataset_idx_filename, dataset_index)
np.save(self._sample_idx_filename, sample_index)

safe_barrier(group, self._name)
self._load_mappings(True)

def __getstate__(self) -> tuple[typing.Any, ...]:
return (
self._datasets,
@@ -84,23 +79,20 @@ def __setstate__(self, state: tuple[typing.Any, ...]):
) = state
if isinstance(dataset_index, pathlib.Path):
self._dataset_idx_filename, self._sample_idx_filename = dataset_index, sample_index
self._load_mappings(False)
else:
self._dataset_idx_filename, self._sample_idx_filename = None, None
self._dataset_index, self._sample_index = dataset_index, sample_index

def _load_mappings(self, verbose: bool) -> None:
if verbose:
log_main_rank(lambda: f" > loading blending dataset index mapping from {self._dataset_idx_filename}")
def _load_mappings(self) -> None:
if hasattr(self, "_dataset_index") and hasattr(self, "_sample_index"):
return
self._dataset_index = np.load(self._dataset_idx_filename, mmap_mode="r")
if verbose:
log_main_rank(lambda: f" > loading blending dataset index mapping from {self._sample_idx_filename}")
self._sample_index = np.load(self._sample_idx_filename, mmap_mode="r")

def __len__(self) -> int:
return self._num_samples

def _build_blending_indices(self, verbose: bool) -> tuple[np.ndarray, np.ndarray]:
def _build_blending_indices(self) -> tuple[np.ndarray, np.ndarray]:
assert _extension_available, (
"The C++ extension for dataset blending is missing." " Please make sure Fast-LLM is installed correctly."
)
@@ -113,7 +105,7 @@ def _build_blending_indices(self, verbose: bool) -> tuple[np.ndarray, np.ndarray
self._weights,
len(self._datasets),
self._num_samples,
verbose,
True, # Verbose
)
available_samples_per_dataset = np.array([len(dataset) for dataset in self._datasets])
sampled_per_dataset = np.bincount(dataset_index)
@@ -133,6 +125,7 @@ def _build_blending_indices(self, verbose: bool) -> tuple[np.ndarray, np.ndarray
return dataset_index, dataset_sample_index

def __getitem__(self, idx: int) -> typing.Any:
self._load_mappings()
return self._datasets[self._dataset_index[idx]][self._sample_index[idx].item()]

@property
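The lazy-loading pattern introduced above — defer np.load(..., mmap_mode="r") until the first __getitem__, guarded by hasattr — shown in isolation; the class and file names below are invented for the sketch:

    import pathlib
    import typing

    import numpy as np


    class LazyIndexedView:
        """Toy stand-in for the blended dataset's deferred index loading."""

        def __init__(self, index_path: pathlib.Path):
            self._index_path = index_path  # the mapping is not loaded yet

        def _load_mapping(self) -> None:
            # Load once, on first access. Memory-mapping keeps dataloader worker
            # startup cheap and avoids copying the whole index into every process.
            if hasattr(self, "_index"):
                return
            self._index = np.load(self._index_path, mmap_mode="r")

        def __getitem__(self, idx: int) -> typing.Any:
            self._load_mapping()
            return int(self._index[idx])
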
65 changes: 64 additions & 1 deletion fast_llm/data/dataset/gpt/config.py
@@ -2,7 +2,9 @@
import enum
import json
import pathlib
import time
import typing
import warnings

from fast_llm.config import Config, Field, FieldHint, FieldUpdate, check_field, config_class, skip_valid_if_none
from fast_llm.data.dataset.abstract import SampledDataset
@@ -21,7 +23,7 @@
if typing.TYPE_CHECKING:
from fast_llm.data.dataset.gpt.indexed import GPTConcatenatedDataset, GPTDatasetSlice, GPTIndexedDataset
from fast_llm.data.dataset.gpt.memmap import GPTMemmapDataset
from fast_llm.data.dataset.gpt.random import GPTRandomDataset
from fast_llm.data.dataset.gpt.random import GPTRandomDataset, GPTRandomSampledDataset
from fast_llm.data.tokenizer import Tokenizer


@@ -65,6 +67,10 @@ def _from_dict(
if type_ is None:
actual_cls = cls
else:
if type_ not in cls._registry:
raise ValueError(
f"Unknown {cls._registry.name} type {type_}." f" Available types: {list(cls._registry.keys())}"
)
actual_cls = cls._registry[type_]
Assert.custom(issubclass, actual_cls, cls)
if actual_cls == cls:
@@ -160,6 +166,41 @@ class GPTBlendedDatasetConfig(BlendedDatasetConfig, GPTSampledDatasetConfig):
datasets: list[GPTSampledDatasetConfig] = FieldUpdate()


@config_class()
class GPTConcatenatedMemmapConfig(GPTIndexedDatasetConfig):
_abstract: typing.ClassVar[bool] = False
type_: typing.ClassVar[str | None] = "concatenated_memmap"
path: pathlib.Path = Field(
default=None,
desc="The path to a dataset directory.",
hint=FieldHint.core,
)

def build(self) -> "GPTConcatenatedDataset":
pass

assert self.path.is_dir()
index_path = self.path / "index.txt"

if index_path.is_file():
prefixes = [self.path / line.strip() for line in index_path.open("r").readlines()]
else:
warnings.warn(
f"The dataset path {self.path} points to a directory."
" The dataset will be indexed automatically, which may be unsafe."
" We recommend using an index file instead."
)
prefixes = [
path.with_suffix("")
for path in self.path.iterdir()
if path.suffix == ".idx" and path.is_file() and path.with_suffix(".bin").is_file()
]
dataset_config = GPTConcatenatedDatasetConfig.from_dict(
{"datasets": [{"type": "memmap", "path": prefix} for prefix in prefixes]}
)
return dataset_config.build()
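
Assuming an index.txt listing "shard_000" and "shard_001", build() above effectively expands to the following; the paths are illustrative and the shard .bin/.idx files must exist on disk:

    from fast_llm.data.dataset.gpt.config import GPTConcatenatedDatasetConfig

    # Equivalent expansion of a concatenated_memmap config pointing at /data/my_corpus,
    # when /data/my_corpus/index.txt contains the two shard prefixes:
    equivalent_config = GPTConcatenatedDatasetConfig.from_dict(
        {
            "datasets": [
                {"type": "memmap", "path": "/data/my_corpus/shard_000"},
                {"type": "memmap", "path": "/data/my_corpus/shard_001"},
            ]
        }
    )
    dataset = equivalent_config.build()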


@config_class()
class FimConfig(Config):
"""
@@ -364,3 +405,25 @@ def build_and_sample(self, config: GPTSamplingConfig) -> SampledDataset:
)

return dataset_config.build_and_sample(config)


@config_class()
class GPTTestSlowDatasetConfig(GPTSampledDatasetConfig):
"""
A mock dataset that mimics a slow dataset creation on one rank, which may trigger a timeout.
"""

# TODO: This belongs to a testing plugin.
_abstract: typing.ClassVar[bool] = False
type_: typing.ClassVar[str | None] = "test_slow"
sleep: float = Field(
default=1,
desc="Sleep time during build, in seconds.",
hint=FieldHint.core,
)

def build_and_sample(self, config: SamplingConfig) -> "GPTRandomSampledDataset":
assert config.distributed.config.world_size > 1
if config.distributed.config.rank == 0:
time.sleep(self.sleep)
return GPTRandomDatasetConfig().build_and_sample(config)
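
In a multi-rank timeout test, this mock could be selected through the dataset registry like any other type; the values below are illustrative:

    # Illustrative dataset config fragment: rank 0 sleeps for 5 seconds during
    # build_and_sample, so the other ranks' data-preparation barrier needs a
    # timeout larger than this to pass (or a smaller one to reproduce a failure).
    slow_dataset_config = {"type": "test_slow", "sleep": 5.0}
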
10 changes: 6 additions & 4 deletions fast_llm/data/dataset/gpt/memmap.py
@@ -98,10 +98,12 @@ def __setstate__(self, state: tuple[str, pathlib.Path]):
self._init(*state)

def __del__(self):
self._bin_buffer_mmap._mmap.close() # noqa
del self._bin_buffer_mmap
self._index_bin_buffer_mmap._mmap.close() # noqa
del self._index_bin_buffer_mmap
if hasattr(self, "_bin_buffer_mmap"):
self._bin_buffer_mmap._mmap.close() # noqa
del self._bin_buffer_mmap
if hasattr(self, "_index_bin_buffer"):
self._index_bin_buffer_mmap._mmap.close() # noqa
del self._index_bin_buffer_mmap

def get(self, idx, offset=0, length=None) -> GPTMemmapSample:
ids = np.frombuffer(
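The defensive pattern applied above — guard __del__ with hasattr so that an exception partway through initialization does not trigger a second AttributeError at garbage collection — shown on a toy class with invented names:

    import mmap


    class SafeMmapHolder:
        def __init__(self, path: str):
            self._file = open(path, "rb")
            # If the next line raises (e.g. for an empty file), self._mmap is never
            # set, and __del__ must not assume that it exists.
            self._mmap = mmap.mmap(self._file.fileno(), 0, access=mmap.ACCESS_READ)

        def __del__(self):
            if hasattr(self, "_mmap"):
                self._mmap.close()
            if hasattr(self, "_file"):
                self._file.close()
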
(The remaining changed files in this commit are not shown here.)