From e5e9d719206439157672f8afb990a704ad32e89c Mon Sep 17 00:00:00 2001 From: Joshua Deng Date: Thu, 29 Feb 2024 12:31:16 -0800 Subject: [PATCH] Add staged train pipeline to torchrec and refactor train pipeline organization (#1624) Summary: Pull Request resolved: https://github.com/pytorch/torchrec/pull/1624 # train pipeline code structure - moves train_pipeline.py under /train_pipeline dir - moves helper functions, classes to utils.py - keeps imports under` torchrec.train_pipeline` # staged train pipeline Variable stage train pipeline that can support arbitrary number of stages. Note that this is a pre-forward scheduling pipeline, this means the forward is expected executed after the last stage in the pipeline and the last stage in the pipeline has to be the start SDD stage. This is because ShardedModule forward (overwritten by pipeline) has to consume the input from pipeline context post input dist. The design is illustrated in the figure below. A pipeline is composed of K stages, each one depends on its precedence. Different stage may be executed in different or the same streams, with multiple batches concurrently executed in the same iteration. For example, in the image, batch[0] is the oldest batch that has passed H2D, preproc, and SDD stages, and will be running through comp in the current iteration. Batch[1], on the other hand, is the second oldest batch that will execute SDD in the current iteration. Similarly for other batches. {F1150156522} with this, 4 batches will be handled together in the same iteration, while each of them is under a different stage. When an iteration is done, there will be a advance step to copy newer data (slots larger index) to older data (slots with smaller index) so that they will be handled by the next stage in the next iteration. For SDD handling, we currently wrap on top of existing torchrec utilities. This part could potentially be improved in the future. Some additional things on top of this: * Adding fill callback for start sdd * Adding await sdd as a callback * Modifying progress to be walrus'able Reviewed By: sarckk Differential Revision: D51182804 --- .../distributed/train_pipeline/__init__.py | 28 + .../train_pipeline/staged_train_pipeline.py | 273 +++++++ .../tests/test_staged_train_pipeline.py | 216 ++++++ .../tests/test_train_pipeline.py | 7 +- .../train_pipeline/train_pipeline.py | 635 ++++++++++++++++ .../utils.py} | 704 ++---------------- 6 files changed, 1233 insertions(+), 630 deletions(-) create mode 100644 torchrec/distributed/train_pipeline/__init__.py create mode 100644 torchrec/distributed/train_pipeline/staged_train_pipeline.py create mode 100644 torchrec/distributed/train_pipeline/tests/test_staged_train_pipeline.py rename torchrec/distributed/{ => train_pipeline}/tests/test_train_pipeline.py (99%) create mode 100644 torchrec/distributed/train_pipeline/train_pipeline.py rename torchrec/distributed/{train_pipeline.py => train_pipeline/utils.py} (56%) diff --git a/torchrec/distributed/train_pipeline/__init__.py b/torchrec/distributed/train_pipeline/__init__.py new file mode 100644 index 000000000..676b9943b --- /dev/null +++ b/torchrec/distributed/train_pipeline/__init__.py @@ -0,0 +1,28 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + + +from torchrec.distributed.train_pipeline.train_pipeline import ( # noqa + DataLoadingThread, # noqa + EvalPipelineSparseDist, # noqa + PrefetchTrainPipelineSparseDist, # noqa + TrainPipeline, # noqa + TrainPipelineBase, # noqa + TrainPipelineSparseDist, # noqa +) +from torchrec.distributed.train_pipeline.utils import ( # noqa + _override_input_dist_forwards, # noqa + _rewrite_model, # noqa + _start_data_dist, # noqa + _to_device, # noqa + _wait_for_batch, # noqa + ArgInfo, # noqa + In, # noqa + Out, # noqa + Tracer, # noqa + TrainPipelineContext, # noqa +) diff --git a/torchrec/distributed/train_pipeline/staged_train_pipeline.py b/torchrec/distributed/train_pipeline/staged_train_pipeline.py new file mode 100644 index 000000000..2be7a73aa --- /dev/null +++ b/torchrec/distributed/train_pipeline/staged_train_pipeline.py @@ -0,0 +1,273 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +#!/usr/bin/env python3 + +import logging +from dataclasses import dataclass + +from typing import Callable, cast, Generic, Iterator, List, Optional, Tuple, TypeVar + +import torch + +from torch.profiler import record_function +from torchrec.distributed.train_pipeline.utils import In +from torchrec.distributed.utils import none_throws +from torchrec.streamable import Pipelineable + +logger: logging.Logger = logging.getLogger(__name__) + +StageOut = TypeVar("StageOut", bound=Pipelineable) +RunnableType = Callable[..., StageOut] +StageOutputWithEvent = Tuple[Optional[StageOut], Optional[torch.cuda.Event]] + + +def get_h2d_func(batch: Pipelineable, device: torch.device) -> Pipelineable: + return batch.to(device, non_blocking=True) + + +@dataclass +class PipelineStage: + """ + A pipeline stage represents a transform to an input that is independent of the + backwards() of the model. Examples include batch H2D transfer, GPU preproc, or + gradient-less model processing. + + Args: + name (str): Name of the stage. + runnable (Callable[In, Out]): Function that performs a gradient-less + transform. + stream (torch.cuda.streams.Stream): Stream to run on. Often each stage has a + unique stream, but having different pipelines share a stream provides more + synchronization semantics. + """ + + name: str + runnable: RunnableType + stream: torch.cuda.streams.Stream + fill_callback: Optional[Callable[[], None]] = None + + +class StagedTrainPipeline(Generic[In, StageOut]): + """ + StagedTrainPipeline orchestrates the pipelined execution of its constitutent stages + from inputs of `data_iter`. Namely scheduling the execution of stages before model + forward. + + NOTE: the SDD stage needs to be the final stage of the pipeline so that the + ShardedModule forward can properly consume the SDD output. + + Calling progress on a StagedTrainPipeline provides an output that is equivalent to + calling each of the pipeline stages in order. + + In the example below a fully synchronous will expose the `data_copy` and + `gpu_preproc` calls. After pipelining, the `data_copy` of batch i+2 can be + overlapped with the `gpu_preproc` of batch i+1 and the main model processing of + batch i. + + Args: + data_iter (Optional[Iterator[In]]): An iterator that produces the inputs to the + pipeline. + pipeline_stages (List[PipelineStage]): A list of stages to execute. + debug_mode (bool): Whether to enable debug mode. + + Example:: + train_pipeline = StagedTrainPipeline( + data_iter=data_iter, + pipeline=[ + PipelineStage( + name="data_copy", + runnable=get_h2d_func("cuda"), + stream=torch.cuda.Stream(), + ), + PipelineStage( + name="gpu_preproc", + runnable=gpu_preproc, + stream=torch.cuda.Stream(), + ), + ] + ) + + while batch_for_forward := train_pipeline.progress(): + optimizer.zero_grad() + loss, pred = model(batch_for_forward) + loss.backward() + optimizer.step() + """ + + def __init__( + self, + data_iter: Optional[Iterator[In]], + pipeline_stages: List[PipelineStage], + debug_mode: bool = False, + ) -> None: + self._data_iter = data_iter + self._pipeline_stages = pipeline_stages + self._debug_mode = debug_mode + self._stage_outputs: List[Optional[StageOutputWithEvent]] = cast( + List[Optional[StageOutputWithEvent]], [None] * len(self._pipeline_stages) + ) + self._initialized = False + self._num_steps = 0 + + @property + def num_stages(self) -> int: + return len(self._pipeline_stages) + + def _next_batch(self) -> Optional[In]: + batch = next(none_throws(self._data_iter, "`data_iter` cannot be none"), None) + return batch + + def _advance(self) -> Optional[StageOut]: + # left shifts all batch results. + out = self._stage_outputs[0] + for idx in range(self.num_stages - 1): + self._stage_outputs[idx] = self._stage_outputs[idx + 1] + self._stage_outputs[-1] = None + if out is None: + return out + return out[0] + + def _run_with_event( + self, + runnable: RunnableType, + event: Optional[torch.cuda.Event], + inputs: Optional[In], + stream: torch.cuda.streams.Stream, + ) -> StageOutputWithEvent: + if inputs is None: + return (None, None) + with torch.cuda.stream(stream): + # If there is no previous event, data is entering the pipeline + if event is not None: + event.wait(stream) + inputs.record_stream(stream) + + output = runnable(inputs) + new_event = torch.cuda.Event() + new_event.record(stream) + return (output, new_event) + + def _run_stage( + self, + batch_offset: int, + stage_idx: int, + fill: bool = False, + ) -> StageOutputWithEvent: + """ + Each stage of the pipeline MUST have an input and output. + If the input is None, it means there is no more data to process. + It will short circuit and NOT execute the runnable. + """ + stage = self._pipeline_stages[stage_idx] + + with record_function( + f"## Pipeline Stage {stage_idx} : {stage.name} for batch {batch_offset + self._num_steps} ##" + ): + if stage_idx == 0: + batch = self._next_batch() + batch_to_wait, event = batch, None + else: + batch_to_wait_with_event = self._stage_outputs[batch_offset] + assert batch_to_wait_with_event is not None + batch_to_wait, event = batch_to_wait_with_event + + new_result = self._run_with_event( + runnable=stage.runnable, + event=event, + inputs=batch_to_wait, + stream=stage.stream, + ) + + self._stage_outputs[batch_offset] = new_result + if self._debug_mode: + logger.info( + f"Running ## Pipeline Stage {stage_idx} : {stage.name} for batch {batch_offset + self._num_steps} ##", + ) + + if fill and (fill_callback := stage.fill_callback) is not None: + if self._debug_mode: + logger.info(f"Finished callback for {stage.name}") + fill_callback() + + return new_result + + def _fill_pipeline(self) -> None: + """ + There should always be `self.num_stages` batches in flight. This function + initializes the pipeline by filling it with `self.num_stages` batches. + Intuitively, it does all the stages before the model forward. + + For a 5 stage pipeline during `_fill_pipeline`: + batch 0: stages 0, 1, 2, 3, 4 will be run + batch 1: stages 0, 1, 2, 3 will be run + batch 2: stages 0, 1, 2 will be run + batch 3: stages 0, 1 will be run + batch 4: stage 0 will be run + batch 5: will start in `progress()` + + In the initial `progress()` + batch 0: model forward will be run + batch 1: stage 4 will be run + batch 2: stage 3 will be run + batch 3: stage 2 will be run + batch 4: stage 1 will be run + batch 5: stage 1 will be run + """ + for batch_offset in range(self.num_stages): + stages_to_run = self.num_stages - batch_offset + for stage_idx in range(stages_to_run): + self._run_stage( + batch_offset=batch_offset, stage_idx=stage_idx, fill=True + ) + + self._initialized = True + if self._debug_mode: + logger.info("Finished fill pipeline") + + def progress( + self, + data_iter: Optional[Iterator[In]] = None, + run_stage_order: Optional[List[int]] = None, + ) -> Optional[StageOut]: + """ + The stages process data in reverse order, so stage_0 processes the newest data. + Stage order can be modified through the `run_stage_order` arg. This is useful in + achieving better overlap for different stages. + + NOTE: if SDD is enabled it must be the last stage in the pipeline. + + Args: + data_iter (Optional[Iterator[In]]): An iterator that produces the inputs to + the pipeline. + run_stage_order (Optional[List[int]]): Specifies the order of running the + stages. If `None`, the pipeline will run stages in the original order, + i.e. stage_0 -> stage_1 -> ... -> stage_n. + + Returns: + Optional[StageOut]: Output of the final stage. `None` signifies that the + dataloader iterator is depleted. + """ + if self._data_iter is None: + self._data_iter = none_throws(data_iter, "`data_iter` cannot be none") + + if not self._initialized: + self._fill_pipeline() + + output = self._advance() + self._num_steps += 1 + + if not run_stage_order: + run_stage_order = list(range(self.num_stages)) + for stage_idx in run_stage_order: + stage_output_idx = self.num_stages - 1 - stage_idx + self._run_stage( + batch_offset=stage_output_idx, + stage_idx=stage_idx, + ) + + return output diff --git a/torchrec/distributed/train_pipeline/tests/test_staged_train_pipeline.py b/torchrec/distributed/train_pipeline/tests/test_staged_train_pipeline.py new file mode 100644 index 000000000..44f49bb64 --- /dev/null +++ b/torchrec/distributed/train_pipeline/tests/test_staged_train_pipeline.py @@ -0,0 +1,216 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import copy +import unittest +from functools import partial +from typing import cast, Tuple + +import torch +from torch import nn, optim +from torchrec.distributed import DistributedModelParallel +from torchrec.distributed.fp_embeddingbag import ( + FeatureProcessedEmbeddingBagCollectionSharder, +) +from torchrec.distributed.sharding_plan import ( + construct_module_sharding_plan, + table_wise, +) +from torchrec.distributed.test_utils.multi_process import MultiProcessTestBase +from torchrec.distributed.test_utils.test_model import ModelInput +from torchrec.distributed.test_utils.test_sharding import copy_state_dict + +from torchrec.distributed.tests.test_fp_embeddingbag_utils import ( + create_module_and_freeze, +) +from torchrec.distributed.train_pipeline.staged_train_pipeline import ( + get_h2d_func, + PipelineStage, + StagedTrainPipeline, +) +from torchrec.distributed.train_pipeline.utils import SparseDataDistUtil +from torchrec.distributed.types import ModuleSharder, ShardingEnv, ShardingPlan +from torchrec.modules.embedding_configs import EmbeddingBagConfig +from torchrec.test_utils import init_distributed_single_host + + +class TrainPipelineSparseDistTest(MultiProcessTestBase): + # pyre-ignore + @unittest.skipIf( + not torch.cuda.is_available(), + "Not enough GPUs, this test requires at least one GPU", + ) + def test_train_pipeline(self) -> None: + device = torch.device("cuda:0") + pg = init_distributed_single_host(backend="nccl", rank=0, world_size=1) + + embedding_bag_configs = [ + EmbeddingBagConfig( + name="table_0", + feature_names=["feature_0"], + embedding_dim=4, + num_embeddings=16, + ), + EmbeddingBagConfig( + name="table_1", + feature_names=["feature_1"], + embedding_dim=4, + num_embeddings=16, + ), + EmbeddingBagConfig( + name="table_2", + feature_names=["feature_2"], + embedding_dim=4, + num_embeddings=16, + ), + EmbeddingBagConfig( + name="table_3", + feature_names=["feature_3"], + embedding_dim=4, + num_embeddings=16, + ), + ] + + sharder = cast( + ModuleSharder[nn.Module], FeatureProcessedEmbeddingBagCollectionSharder() + ) + + class DummyWrapper(nn.Module): + def __init__(self, sparse_arch): + super().__init__() + self.m = sparse_arch + + def forward(self, model_input) -> Tuple[torch.Tensor, torch.Tensor]: + return self.m(model_input.idlist_features) + + sparse_arch = DummyWrapper( + create_module_and_freeze( + tables=embedding_bag_configs, + device=device, + use_fp_collection=False, + ) + ) + module_sharding_plan = construct_module_sharding_plan( + sparse_arch.m._fp_ebc, + per_param_sharding={ + "table_0": table_wise(rank=0), + "table_1": table_wise(rank=0), + "table_2": table_wise(rank=0), + "table_3": table_wise(rank=0), + }, + local_size=1, + world_size=1, + device_type=device.type, + sharder=sharder, + ) + sharded_sparse_arch_no_pipeline = DistributedModelParallel( + module=copy.deepcopy(sparse_arch), + plan=ShardingPlan({"m._fp_ebc": module_sharding_plan}), + env=ShardingEnv.from_process_group(pg), + sharders=[sharder], + device=device, + ) + + sharded_sparse_arch_pipeline = DistributedModelParallel( + module=copy.deepcopy(sparse_arch), + plan=ShardingPlan({"m._fp_ebc": module_sharding_plan}), + env=ShardingEnv.from_process_group(pg), + sharders=[sharder], + device=device, + ) + + copy_state_dict( + sharded_sparse_arch_no_pipeline.state_dict(), + sharded_sparse_arch_pipeline.state_dict(), + ) + + data = [ + ModelInput.generate( + tables=embedding_bag_configs, + weighted_tables=[], + batch_size=10, + world_size=1, + num_float_features=0, + pooling_avg=5, + )[0] + for i in range(10) + ] + + optimizer_no_pipeline = optim.SGD( + sharded_sparse_arch_no_pipeline.parameters(), lr=0.1 + ) + optimizer_pipeline = optim.SGD( + sharded_sparse_arch_pipeline.parameters(), lr=0.1 + ) + + non_pipelined_outputs = [] + for batch in data: + batch = batch.to(device) + optimizer_no_pipeline.zero_grad() + loss, pred = sharded_sparse_arch_no_pipeline(batch) + loss.backward() + optimizer_no_pipeline.step() + non_pipelined_outputs.append(pred) + + h2d_stream = torch.cuda.Stream() + + # pyre-ignore + def gpu_preproc(x): + return x + + sdd = SparseDataDistUtil[ModelInput]( + model=sharded_sparse_arch_pipeline, + stream=torch.cuda.Stream(), + apply_jit=False, + ) + + pipeline_stages = [ + PipelineStage( + name="data_copy", + runnable=partial(get_h2d_func, device=device), + stream=h2d_stream, + ), + PipelineStage( + name="gpu_preproc", + runnable=gpu_preproc, + stream=torch.cuda.Stream(), + ), + PipelineStage( + name="gpu_preproc_1", + runnable=gpu_preproc, + stream=torch.cuda.Stream(), + ), + PipelineStage( + name="gpu_preproc_2", + runnable=gpu_preproc, + stream=torch.cuda.Stream(), + ), + PipelineStage( + name="start_sparse_data_dist", + runnable=sdd.start_sparse_data_dist, + stream=sdd.stream, + fill_callback=sdd.wait_sparse_data_dist, + ), + ] + pipeline = StagedTrainPipeline( + data_iter=iter(data), + pipeline_stages=pipeline_stages, + debug_mode=True, + ) + + pipelined_out = [] + while model_in := pipeline.progress(): + optimizer_pipeline.zero_grad() + loss, pred = sharded_sparse_arch_pipeline(model_in) + loss.backward() + optimizer_pipeline.step() + pipelined_out.append(pred) + + self.assertEqual(len(pipelined_out), len(non_pipelined_outputs)) + + for out, ref_out in zip(pipelined_out, non_pipelined_outputs): + torch.testing.assert_close(out, ref_out) diff --git a/torchrec/distributed/tests/test_train_pipeline.py b/torchrec/distributed/train_pipeline/tests/test_train_pipeline.py similarity index 99% rename from torchrec/distributed/tests/test_train_pipeline.py rename to torchrec/distributed/train_pipeline/tests/test_train_pipeline.py index 1dd0b1209..766f5f5e5 100644 --- a/torchrec/distributed/tests/test_train_pipeline.py +++ b/torchrec/distributed/train_pipeline/tests/test_train_pipeline.py @@ -39,7 +39,10 @@ TestSparseNN, ) from torchrec.distributed.test_utils.test_sharding import copy_state_dict -from torchrec.distributed.train_pipeline import ( +from torchrec.distributed.tests.test_fp_embeddingbag_utils import ( + create_module_and_freeze, +) +from torchrec.distributed.train_pipeline.train_pipeline import ( DataLoadingThread, EvalPipelineSparseDist, PrefetchTrainPipelineSparseDist, @@ -63,8 +66,6 @@ from torchrec.streamable import Pipelineable from torchrec.test_utils import get_free_port, init_distributed_single_host -from .test_fp_embeddingbag_utils import create_module_and_freeze - class TestShardedEmbeddingBagCollection(ShardedEmbeddingBagCollection): def input_dist( diff --git a/torchrec/distributed/train_pipeline/train_pipeline.py b/torchrec/distributed/train_pipeline/train_pipeline.py new file mode 100644 index 000000000..ded91db03 --- /dev/null +++ b/torchrec/distributed/train_pipeline/train_pipeline.py @@ -0,0 +1,635 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +""" +NOTE: Due to an internal packaging issue, `train_pipeline.py` must be compatible with +older versions of TorchRec. Importing new modules from other files may break model +publishing flows. +""" +import abc +import logging +from threading import Event, Thread +from typing import cast, Generic, Iterator, List, Optional, Tuple + +import torch +from torch.autograd.profiler import record_function +from torchrec.distributed.model_parallel import ShardedModule +from torchrec.distributed.train_pipeline.utils import ( + _override_input_dist_forwards, + _rewrite_model, + _start_data_dist, + _to_device, + _wait_for_batch, + In, + Out, + PrefetchPipelinedForward, + PrefetchTrainPipelineContext, + TrainPipelineContext, +) +from torchrec.distributed.types import Awaitable +from torchrec.streamable import Multistreamable + +logger: logging.Logger = logging.getLogger(__name__) + + +class TrainPipeline(abc.ABC, Generic[In, Out]): + @abc.abstractmethod + def progress(self, dataloader_iter: Iterator[In]) -> Out: + pass + + +class TrainPipelineBase(TrainPipeline[In, Out]): + """ + This class runs training iterations using a pipeline of two stages, each as a CUDA + stream, namely, the current (default) stream and `self._memcpy_stream`. For each + iteration, `self._memcpy_stream` moves the input from host (CPU) memory to GPU + memory, and the default stream runs forward, backward, and optimization. + """ + + def __init__( + self, + model: torch.nn.Module, + optimizer: torch.optim.Optimizer, + device: torch.device, + ) -> None: + self._model = model + self._optimizer = optimizer + self._device = device + self._memcpy_stream: Optional[torch.cuda.streams.Stream] = ( + torch.cuda.Stream() if device.type == "cuda" else None + ) + self._cur_batch: Optional[In] = None + self._connected = False + + def _connect(self, dataloader_iter: Iterator[In]) -> None: + cur_batch = next(dataloader_iter) + self._cur_batch = cur_batch + with torch.cuda.stream(self._memcpy_stream): + self._cur_batch = _to_device(cur_batch, self._device, non_blocking=True) + self._connected = True + + def progress(self, dataloader_iter: Iterator[In]) -> Out: + if not self._connected: + self._connect(dataloader_iter) + + # Fetch next batch + with record_function("## next_batch ##"): + next_batch = next(dataloader_iter) + cur_batch = self._cur_batch + assert cur_batch is not None + + if self._model.training: + with record_function("## zero_grad ##"): + self._optimizer.zero_grad() + + with record_function("## _wait_for_batch ##"): + _wait_for_batch(cur_batch, self._memcpy_stream) + + with record_function("## forward ##"): + (losses, output) = self._model(cur_batch) + + if self._model.training: + with record_function("## backward ##"): + torch.sum(losses, dim=0).backward() + + # Copy the next batch to GPU + self._cur_batch = cur_batch = next_batch + with record_function("## copy_batch_to_gpu ##"): + with torch.cuda.stream(self._memcpy_stream): + self._cur_batch = _to_device(cur_batch, self._device, non_blocking=True) + + # Update + if self._model.training: + with record_function("## optimizer ##"): + self._optimizer.step() + + return output + + +class TrainPipelineSparseDist(TrainPipeline[In, Out]): + """ + This pipeline overlaps device transfer, and `ShardedModule.input_dist()` with + forward and backward. This helps hide the all2all latency while preserving the + training forward / backward ordering. + + stage 3: forward, backward - uses default CUDA stream + stage 2: ShardedModule.input_dist() - uses data_dist CUDA stream + stage 1: device transfer - uses memcpy CUDA stream + + `ShardedModule.input_dist()` is only done for top-level modules in the call graph. + To be considered a top-level module, a module can only depend on 'getattr' calls on + input. + + Input model must be symbolically traceable with the exception of `ShardedModule` and + `DistributedDataParallel` modules. + + Args: + model (torch.nn.Module): model to pipeline. + optimizer (torch.optim.Optimizer): optimizer to use. + device (torch.device): device where device transfer, sparse data dist, and + forward/backward pass will happen. + execute_all_batches (bool): executes remaining batches in pipeline after + exhausting dataloader iterator. + apply_jit (bool): apply torch.jit.script to non-pipelined (unsharded) modules. + """ + + def __init__( + self, + model: torch.nn.Module, + optimizer: torch.optim.Optimizer, + device: torch.device, + execute_all_batches: bool = True, + apply_jit: bool = False, + ) -> None: + self._model = model + self._optimizer = optimizer + self._device = device + self._execute_all_batches = execute_all_batches + self._apply_jit = apply_jit + # use two data streams to support two concurrent batches + if device.type == "cuda": + self._memcpy_stream: Optional[ + torch.cuda.streams.Stream + ] = torch.cuda.Stream(priority=-1) + self._data_dist_stream: Optional[ + torch.cuda.streams.Stream + ] = torch.cuda.Stream(priority=-1) + else: + self._memcpy_stream: Optional[torch.cuda.streams.Stream] = None + self._data_dist_stream: Optional[torch.cuda.streams.Stream] = None + self._batch_i: Optional[In] = None + self._batch_ip1: Optional[In] = None + self._batch_ip2: Optional[In] = None + self._context = TrainPipelineContext() + self._pipelined_modules: List[ShardedModule] = [] + + def _fill_pipeline(self, dataloader_iter: Iterator[In]) -> None: + # pipeline is already filled + if self._batch_i and self._batch_ip1: + return + # executes last batch in pipeline + if self._batch_i and self._execute_all_batches: + return + + # batch 1 + self._batch_i = self._copy_batch_to_gpu(dataloader_iter) + if self._batch_i is None: + raise StopIteration + + self._init_pipelined_modules(self._batch_i) + self._start_sparse_data_dist(self._batch_i) + self._wait_sparse_data_dist() + + # batch 2 + self._batch_ip1 = self._copy_batch_to_gpu(dataloader_iter) + + def progress(self, dataloader_iter: Iterator[In]) -> Out: + self._fill_pipeline(dataloader_iter) + + if self._model.training: + with record_function("## zero_grad ##"): + self._optimizer.zero_grad() + + with record_function("## _wait_for_batch ##"): + _wait_for_batch(cast(In, self._batch_i), self._data_dist_stream) + + self._start_sparse_data_dist(self._batch_ip1) + + self._batch_ip2 = self._copy_batch_to_gpu(dataloader_iter) + + # forward + with record_function("## forward ##"): + losses, output = cast(Tuple[torch.Tensor, Out], self._model(self._batch_i)) + + self._wait_sparse_data_dist() + + if self._model.training: + # backward + with record_function("## backward ##"): + torch.sum(losses, dim=0).backward() + + # update + with record_function("## optimizer ##"): + self._optimizer.step() + + self._batch_i = self._batch_ip1 + self._batch_ip1 = self._batch_ip2 + + return output + + def _init_pipelined_modules(self, batch: In) -> None: + """ + Retrieves the pipelined modules after overriding their forwards, initializes the + modules' input dists, and overrides the input dist forwards to support fusing + the splits collective in the input dist. + """ + if self._pipelined_modules: + return + self._pipelined_modules, self._model = _rewrite_model( + model=self._model, + context=self._context, + dist_stream=self._data_dist_stream, + batch=self._batch_i, + apply_jit=self._apply_jit, + ) + # initializes input dist, so we can override input dist forwards + self._start_sparse_data_dist(self._batch_i) + _override_input_dist_forwards(self._pipelined_modules) + + def _copy_batch_to_gpu(self, dataloader_iter: Iterator[In]) -> Optional[In]: + """ + Retrieves batch from dataloader and moves it to the provided device. + + Raises: + StopIteration: if the dataloader iterator is exhausted; unless + `self._execute_all_batches=True`, then returns None. + """ + with record_function("## copy_batch_to_gpu ##"): + with torch.cuda.stream(self._memcpy_stream): + batch = next(dataloader_iter, None) + if batch is not None: + batch = _to_device(batch, self._device, non_blocking=True) + elif not self._execute_all_batches: + raise StopIteration + return batch + + def _start_sparse_data_dist(self, batch: Optional[In]) -> None: + """ + Waits for batch to finish getting copied to GPU, then starts the input dist. + """ + if batch is None: + return + with record_function("## start_sparse_data_dist ##"): + with torch.cuda.stream(self._data_dist_stream): + _wait_for_batch(batch, self._memcpy_stream) + _start_data_dist(self._pipelined_modules, batch, self._context) + + def _wait_sparse_data_dist(self) -> None: + """ + Waits on the input dist splits requests to get the input dist tensors requests, + and populates the context with them. + """ + with record_function("## wait_sparse_data_dist ##"): + with torch.cuda.stream(self._data_dist_stream): + self._context.module_contexts = ( + self._context.module_contexts_next_batch.copy() + ) + self._context.input_dist_tensors_requests.clear() + for names, awaitable in self._context.fused_splits_awaitables: + for name, request in zip(names, awaitable.wait()): + self._context.input_dist_tensors_requests[name] = request + + +class PrefetchTrainPipelineSparseDist(TrainPipelineSparseDist[In, Out]): + """ + This pipeline overlaps device transfer, `ShardedModule.input_dist()`, and cache + prefetching with forward and backward. This helps hide the all2all latency while + preserving the training forward / backward ordering. + + stage 4: forward, backward - uses default CUDA stream + stage 3: prefetch - uses prefetch CUDA stream + stage 2: ShardedModule.input_dist() - uses data_dist CUDA stream + stage 1: device transfer - uses memcpy CUDA stream + + `ShardedModule.input_dist()` is only done for top-level modules in the call graph. + To be considered a top-level module, a module can only depend on 'getattr' calls on + input. + + Input model must be symbolically traceable with the exception of `ShardedModule` and + `DistributedDataParallel` modules. + + Args: + model (torch.nn.Module): model to pipeline. + optimizer (torch.optim.Optimizer): optimizer to use. + device (torch.device): device where device transfer, sparse data dist, prefetch, + and forward/backward pass will happen. + execute_all_batches (bool): executes remaining batches in pipeline after + exhausting dataloader iterator. + apply_jit (bool): apply torch.jit.script to non-pipelined (unsharded) modules. + """ + + def __init__( + self, + model: torch.nn.Module, + optimizer: torch.optim.Optimizer, + device: torch.device, + execute_all_batches: bool = True, + apply_jit: bool = False, + ) -> None: + super().__init__( + model=model, + optimizer=optimizer, + device=device, + execute_all_batches=execute_all_batches, + apply_jit=apply_jit, + ) + self._context = PrefetchTrainPipelineContext() + if self._device.type == "cuda": + self._prefetch_stream: Optional[ + torch.cuda.streams.Stream + ] = torch.cuda.Stream() + self._default_stream: Optional[ + torch.cuda.streams.Stream + ] = torch.cuda.current_stream() + else: + self._prefetch_stream: Optional[torch.cuda.streams.Stream] = None + self._default_stream: Optional[torch.cuda.streams.Stream] = None + self._batch_ip3: Optional[In] = None + + def _fill_pipeline(self, dataloader_iter: Iterator[In]) -> None: + # pipeline is already filled + if self._batch_i and self._batch_ip1 and self._batch_ip2: + return + # executes last batch in pipeline + if self._execute_all_batches and (self._batch_i or self._batch_ip1): + return + + # batch 1 + self._batch_i = self._copy_batch_to_gpu(dataloader_iter) + if self._batch_i is None: + raise StopIteration + + self._init_pipelined_modules(self._batch_i) + self._start_sparse_data_dist(self._batch_i) + self._wait_sparse_data_dist() + self._prefetch(self._batch_i) + + # batch 2 + self._batch_ip1 = self._copy_batch_to_gpu(dataloader_iter) + self._start_sparse_data_dist(self._batch_ip1) + self._wait_sparse_data_dist() + + # batch 3 + self._batch_ip2 = self._copy_batch_to_gpu(dataloader_iter) + + def progress(self, dataloader_iter: Iterator[In]) -> Out: + self._fill_pipeline(dataloader_iter) + + if self._model.training: + with record_function("## zero_grad ##"): + self._optimizer.zero_grad() + + with record_function("## _wait_for_batch ##"): + _wait_for_batch(cast(In, self._batch_i), self._prefetch_stream) + + self._start_sparse_data_dist(self._batch_ip2) + + self._batch_ip3 = self._copy_batch_to_gpu(dataloader_iter) + + # forward + with record_function("## forward ##"): + losses, output = cast(Tuple[torch.Tensor, Out], self._model(self._batch_i)) + + self._prefetch(self._batch_ip1) + + self._wait_sparse_data_dist() + + if self._model.training: + # backward + with record_function("## backward ##"): + torch.sum(losses, dim=0).backward() + + # update + with record_function("## optimizer ##"): + self._optimizer.step() + + self._batch_i = self._batch_ip1 + self._batch_ip1 = self._batch_ip2 + self._batch_ip2 = self._batch_ip3 + + return output + + def _init_pipelined_modules(self, batch: In) -> None: + """ + Retrieves the pipelined modules after overriding their forwards, initializes the + modules' input dists, and overrides the input dist forwards to support fusing + the splits collective in the input dist. + """ + if self._pipelined_modules: + return + self._pipelined_modules, self._model = _rewrite_model( + model=self._model, + context=self._context, + dist_stream=self._data_dist_stream, + batch=self._batch_i, + apply_jit=self._apply_jit, + pipelined_forward=PrefetchPipelinedForward, + ) + + # initializes input dist, so we can override input dist forwards + self._start_sparse_data_dist(self._batch_i) + _override_input_dist_forwards(self._pipelined_modules) + + def _prefetch(self, batch: Optional[In]) -> None: + """ + Waits for input dist to finish, then prefetches data. + """ + if batch is None: + return + self._context.module_input_post_prefetch.clear() + self._context.module_contexts_post_prefetch.clear() + + with record_function("## sharded_module_prefetch ##"): + with torch.cuda.stream(self._prefetch_stream): + batch.record_stream(torch.cuda.current_stream()) + for sharded_module in self._pipelined_modules: + forward = sharded_module.forward + assert isinstance(forward, PrefetchPipelinedForward) + + assert forward._name in self._context.input_dist_tensors_requests + request = self._context.input_dist_tensors_requests[forward._name] + assert isinstance(request, Awaitable) + with record_function("## wait_sparse_data_dist ##"): + # Finish waiting on the dist_stream, + # in case some delayed stream scheduling happens during the wait() call. + with torch.cuda.stream(self._data_dist_stream): + data = request.wait() + + # Make sure that both result of input_dist and context + # are properly transferred to the current stream. + if self._data_dist_stream is not None: + torch.cuda.current_stream().wait_stream(self._data_dist_stream) + cur_stream = torch.cuda.current_stream() + + assert isinstance( + data, (torch.Tensor, Multistreamable) + ), f"{type(data)} must implement Multistreamable interface" + data.record_stream(cur_stream) + data.record_stream(self._default_stream) + + ctx = self._context.module_contexts[forward._name] + ctx.record_stream(cur_stream) + ctx.record_stream(self._default_stream) + + sharded_module.prefetch( + dist_input=data, forward_stream=self._default_stream + ) + self._context.module_input_post_prefetch[forward._name] = data + self._context.module_contexts_post_prefetch[ + forward._name + ] = self._context.module_contexts[forward._name] + + +class DataLoadingThread(Thread, Generic[In]): + def __init__( + self, + device: torch.device, + dataloader_iter: Iterator[In], + to_device_non_blocking: bool, + memcpy_stream_priority: int = 0, + ) -> None: + super().__init__() + self._stop: bool = False + self._dataloader_iter = dataloader_iter + self._buffer_empty_event: Event = Event() + self._buffer_filled_event: Event = Event() + self._memcpy_stream: Optional[torch.cuda.streams.Stream] = ( + torch.cuda.Stream(priority=memcpy_stream_priority) + if device.type == "cuda" + else None + ) + self._device = device + self.to_device_non_blocking = to_device_non_blocking + self._buffered: Optional[In] = None + self._buffer_empty_event.set() + + def run(self) -> None: + while not self._stop: + self._buffer_empty_event.wait() + # Set the filled event to unblock progress() and return. + if self._stop: + self._buffer_filled_event.set() + return + with record_function("## load_batch ##"): + try: + batch = next(self._dataloader_iter) + except StopIteration: + self._stop = True + self._buffer_filled_event.set() + return + with record_function("## copy_batch_to_gpu ##"): + with torch.cuda.stream(self._memcpy_stream): + self._buffered = cast( + In, + batch.to( + self._device, non_blocking=self.to_device_non_blocking + ), + ) + self._buffer_empty_event.clear() + self._buffer_filled_event.set() + + def stop(self) -> None: + logger.info("Stopping data loading thread...") + self._stop = True + # Unblock any thread that are waiting for these events. + self._buffer_filled_event.set() + self._buffer_empty_event.set() + logger.info("Data loading thread stopped.") + + def get_next_batch(self, none_throws: bool = False) -> Optional[In]: + """ + Get the next batch from the buffer if threading is enabled, otherwise + call load_next_batch directly. + + This function is not thread safe. We assume this is only invoked from + the main thread in the training loop. + """ + self._buffer_filled_event.wait() + batch = self._buffered + if batch is None: + if none_throws: + raise StopIteration + return None + self._buffered = None + self._buffer_filled_event.clear() + self._buffer_empty_event.set() + return batch + + +class EvalPipelineSparseDist(TrainPipelineSparseDist[In, Out]): + """ + This pipeline overlaps device transfer, and `ShardedModule.input_dist()` with + forward. This helps hide the all2all latency. We use a background thread to + perform device transfer to further reduce latency. + + stage 2: forward- uses default CUDA stream + stage 1: ShardedModule.input_dist() - uses data_dist CUDA stream + background: device transfer - uses memcpy CUDA stream + + `ShardedModule.input_dist()` is only done for top-level modules in the call graph. + To be considered a top-level module, a module can only depend on 'getattr' calls on + input. + + Input model must be symbolically traceable with the exception of `ShardedModule` and + `DistributedDataParallel` modules. + + Args: + model (torch.nn.Module): model to pipeline. + optimizer (torch.optim.Optimizer): optimizer to use. + device (torch.device): device where device transfer, sparse data dist, and + forward/backward pass will happen. + apply_jit (bool): apply torch.jit.script to non-pipelined (unsharded) modules. + """ + + def __init__( + self, + model: torch.nn.Module, + optimizer: torch.optim.Optimizer, + device: torch.device, + apply_jit: bool = False, + ) -> None: + super().__init__(model, optimizer, device, True, apply_jit) + self._batch_loader: Optional[DataLoadingThread[In]] = None + + def __del__(self) -> None: + if self._batch_loader is not None: + self._batch_loader.stop() + + def progress(self, dataloader_iter: Iterator[In]) -> Out: + if not self._batch_loader: + self._batch_loader = DataLoadingThread( + device=self._device, + dataloader_iter=dataloader_iter, + to_device_non_blocking=True, + memcpy_stream_priority=-1, + ) + self._batch_loader.start() + + batch_loader = self._batch_loader + assert batch_loader is not None + + # batch 1 + self._batch_i = batch_loader.get_next_batch() + assert self._batch_i is not None + + self._init_pipelined_modules(self._batch_i) + self._start_sparse_data_dist(self._batch_i) + self._wait_sparse_data_dist() + + # batch 2 + self._batch_ip1 = batch_loader.get_next_batch() + + if self._batch_i is None: + raise StopIteration + + batch_loader = self._batch_loader + assert batch_loader is not None + with record_function("## _wait_for_batch ##"): + _wait_for_batch(cast(In, self._batch_i), self._data_dist_stream) + + self._start_sparse_data_dist(self._batch_ip1) + + # forward + with record_function("## forward ##"): + losses, output = cast(Tuple[torch.Tensor, Out], self._model(self._batch_i)) + + self._wait_sparse_data_dist() + + self._batch_i = self._batch_ip1 + self._batch_ip1 = batch_loader.get_next_batch() + + return output diff --git a/torchrec/distributed/train_pipeline.py b/torchrec/distributed/train_pipeline/utils.py similarity index 56% rename from torchrec/distributed/train_pipeline.py rename to torchrec/distributed/train_pipeline/utils.py index 3892fa734..497b76e4f 100644 --- a/torchrec/distributed/train_pipeline.py +++ b/torchrec/distributed/train_pipeline/utils.py @@ -5,24 +5,17 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -""" -NOTE: Due to an internal packaging issue, `train_pipeline.py` must be compatible with -older versions of TorchRec. Importing new modules from other files may break model -publishing flows. -""" -import abc +#!/usr/bin/env python3 import copy import itertools import logging from collections import defaultdict from dataclasses import dataclass, field -from threading import Event, Thread from typing import ( Any, cast, Dict, Generic, - Iterator, List, Optional, Set, @@ -32,209 +25,31 @@ Union, ) -import torch from torch import distributed as dist -from torch.autograd.profiler import record_function from torch.distributed.fsdp import FullyShardedDataParallel as FSDP from torch.fx.node import Node +from torch.profiler import record_function from torchrec.distributed.dist_data import KJTAllToAll, KJTAllToAllTensorsAwaitable from torchrec.distributed.embedding_sharding import ( KJTListAwaitable, KJTListSplitsAwaitable, ) from torchrec.distributed.model_parallel import DistributedModelParallel, ShardedModule + from torchrec.distributed.types import Awaitable + from torchrec.sparse.jagged_tensor import KeyedJaggedTensor from torchrec.streamable import Multistreamable, Pipelineable + logger: logging.Logger = logging.getLogger(__name__) +import torch In = TypeVar("In", bound=Pipelineable) Out = TypeVar("Out") -class DataLoadingThread(Thread, Generic[In]): - def __init__( - self, - device: torch.device, - dataloader_iter: Iterator[In], - to_device_non_blocking: bool, - memcpy_stream_priority: int = 0, - ) -> None: - super().__init__() - self._stop: bool = False - self._dataloader_iter = dataloader_iter - self._buffer_empty_event: Event = Event() - self._buffer_filled_event: Event = Event() - self._memcpy_stream: Optional[torch.cuda.streams.Stream] = ( - torch.cuda.Stream(priority=memcpy_stream_priority) - if device.type == "cuda" - else None - ) - self._device = device - self._to_device_non_blocking = to_device_non_blocking - self._buffered: Optional[In] = None - self._buffer_empty_event.set() - - def run(self) -> None: - while not self._stop: - self._buffer_empty_event.wait() - # Set the filled event to unblock progress() and return. - if self._stop: - self._buffer_filled_event.set() - return - with record_function("## load_batch ##"): - try: - batch = next(self._dataloader_iter) - except StopIteration: - self._stop = True - self._buffer_filled_event.set() - return - with record_function("## copy_batch_to_gpu ##"): - with torch.cuda.stream(self._memcpy_stream): - self._buffered = cast( - In, - batch.to( - self._device, non_blocking=self._to_device_non_blocking - ), - ) - self._buffer_empty_event.clear() - self._buffer_filled_event.set() - - def stop(self) -> None: - logger.info("Stopping data loading thread...") - self._stop = True - # Unblock any thread that are waiting for these events. - self._buffer_filled_event.set() - self._buffer_empty_event.set() - logger.info("Data loading thread stopped.") - - def get_next_batch(self, none_throws: bool = False) -> Optional[In]: - """ - Get the next batch from the buffer if threading is enabled, otherwise - call load_next_batch directly. - - This function is not thread safe. We assume this is only invoked from - the main thread in the training loop. - """ - self._buffer_filled_event.wait() - batch = self._buffered - if batch is None: - if none_throws: - raise StopIteration - return None - self._buffered = None - self._buffer_filled_event.clear() - self._buffer_empty_event.set() - return batch - - -class TrainPipeline(abc.ABC, Generic[In, Out]): - @abc.abstractmethod - def progress(self, dataloader_iter: Iterator[In]) -> Out: - pass - - -def _to_device(batch: In, device: torch.device, non_blocking: bool) -> In: - assert isinstance( - batch, (torch.Tensor, Pipelineable) - ), f"{type(batch)} must implement Pipelineable interface" - return cast(In, batch.to(device=device, non_blocking=non_blocking)) - - -def _wait_for_batch(batch: In, stream: Optional[torch.cuda.streams.Stream]) -> None: - if stream is None: - return - torch.cuda.current_stream().wait_stream(stream) - """ - As mentioned in - https://pytorch.org/docs/stable/generated/torch.Tensor.record_stream.html, PyTorch - uses the "caching allocator" for memory allocation for tensors. When a tensor is - freed, its memory is likely to be reused by newly constructed tenosrs. By default, - this allocator traces whether a tensor is still in use by only the CUDA stream where - it was created. When a tensor is used by additional CUDA streams, we need to call - `record_stream` to tell the allocator about these streams. Otherwise, the allocator - might free the underlying memory of the tensor once it is no longer used by the - creator stream. This is a notable programming trick when we write programs using - multiple CUDA streams. - """ - - cur_stream = torch.cuda.current_stream() - assert isinstance( - batch, (torch.Tensor, Multistreamable) - ), f"{type(batch)} must implement Multistreamable interface" - batch.record_stream(cur_stream) - - -class TrainPipelineBase(TrainPipeline[In, Out]): - """ - This class runs training iterations using a pipeline of two stages, each as a CUDA - stream, namely, the current (default) stream and `self._memcpy_stream`. For each - iteration, `self._memcpy_stream` moves the input from host (CPU) memory to GPU - memory, and the default stream runs forward, backward, and optimization. - """ - - def __init__( - self, - model: torch.nn.Module, - optimizer: torch.optim.Optimizer, - device: torch.device, - ) -> None: - self._model = model - self._optimizer = optimizer - self._device = device - self._memcpy_stream: Optional[torch.cuda.streams.Stream] = ( - torch.cuda.Stream() if device.type == "cuda" else None - ) - self._cur_batch: Optional[In] = None - self._connected = False - - def _connect(self, dataloader_iter: Iterator[In]) -> None: - cur_batch = next(dataloader_iter) - self._cur_batch = cur_batch - with torch.cuda.stream(self._memcpy_stream): - self._cur_batch = _to_device(cur_batch, self._device, non_blocking=True) - self._connected = True - - def progress(self, dataloader_iter: Iterator[In]) -> Out: - if not self._connected: - self._connect(dataloader_iter) - - # Fetch next batch - with record_function("## next_batch ##"): - next_batch = next(dataloader_iter) - cur_batch = self._cur_batch - assert cur_batch is not None - - if self._model.training: - with record_function("## zero_grad ##"): - self._optimizer.zero_grad() - - with record_function("## wait_for_batch ##"): - _wait_for_batch(cur_batch, self._memcpy_stream) - - with record_function("## forward ##"): - (losses, output) = self._model(cur_batch) - - if self._model.training: - with record_function("## backward ##"): - torch.sum(losses, dim=0).backward() - - # Copy the next batch to GPU - self._cur_batch = cur_batch = next_batch - with record_function("## copy_batch_to_gpu ##"): - with torch.cuda.stream(self._memcpy_stream): - self._cur_batch = _to_device(cur_batch, self._device, non_blocking=True) - - # Update - if self._model.training: - with record_function("## optimizer ##"): - self._optimizer.step() - - return output - - class Tracer(torch.fx.Tracer): """ Disables proxying buffers during tracing. Ideally, proxying buffers would be @@ -614,6 +429,37 @@ def __call__(self, input: KeyedJaggedTensor) -> KJTSplitsAllToAllMeta: ) +def _to_device(batch: In, device: torch.device, non_blocking: bool) -> In: + assert isinstance( + batch, (torch.Tensor, Pipelineable) + ), f"{type(batch)} must implement Pipelineable interface" + return cast(In, batch.to(device=device, non_blocking=non_blocking)) + + +def _wait_for_batch(batch: In, stream: Optional[torch.cuda.streams.Stream]) -> None: + if stream is None: + return + torch.cuda.current_stream().wait_stream(stream) + """ + As mentioned in + https://pytorch.org/docs/stable/generated/torch.Tensor.record_stream.html, PyTorch + uses the "caching allocator" for memory allocation for tensors. When a tensor is + freed, its memory is likely to be reused by newly constructed tenosrs. By default, + this allocator traces whether a tensor is still in use by only the CUDA stream where + it was created. When a tensor is used by additional CUDA streams, we need to call + `record_stream` to tell the allocator about these streams. Otherwise, the allocator + might free the underlying memory of the tensor once it is no longer used by the + creator stream. This is a notable programming trick when we write programs using + multiple CUDA streams. + """ + + cur_stream = torch.cuda.current_stream() + assert isinstance( + batch, (torch.Tensor, Multistreamable) + ), f"{type(batch)} must implement Multistreamable interface" + batch.record_stream(cur_stream) + + def _start_data_dist( pipelined_modules: List[ShardedModule], batch: In, @@ -956,450 +802,54 @@ def _override_input_dist_forwards(pipelined_modules: List[ShardedModule]) -> Non ) -class TrainPipelineSparseDist(TrainPipeline[In, Out]): - """ - This pipeline overlaps device transfer, and `ShardedModule.input_dist()` with - forward and backward. This helps hide the all2all latency while preserving the - training forward / backward ordering. - - stage 3: forward, backward - uses default CUDA stream - stage 2: ShardedModule.input_dist() - uses data_dist CUDA stream - stage 1: device transfer - uses memcpy CUDA stream - - `ShardedModule.input_dist()` is only done for top-level modules in the call graph. - To be considered a top-level module, a module can only depend on 'getattr' calls on - input. - - Input model must be symbolically traceable with the exception of `ShardedModule` and - `DistributedDataParallel` modules. - - Args: - model (torch.nn.Module): model to pipeline. - optimizer (torch.optim.Optimizer): optimizer to use. - device (torch.device): device where device transfer, sparse data dist, and - forward/backward pass will happen. - execute_all_batches (bool): executes remaining batches in pipeline after - exhausting dataloader iterator. - apply_jit (bool): apply torch.jit.script to non-pipelined (unsharded) modules. - """ - +class SparseDataDistUtil(Generic[In]): def __init__( self, model: torch.nn.Module, - optimizer: torch.optim.Optimizer, - device: torch.device, - execute_all_batches: bool = True, + stream: torch.cuda.streams.Stream, apply_jit: bool = False, ) -> None: - self._model = model - self._optimizer = optimizer - self._device = device - self._execute_all_batches = execute_all_batches - self._apply_jit = apply_jit - # use two data streams to support two concurrent batches - if device.type == "cuda": - self._memcpy_stream: Optional[ - torch.cuda.streams.Stream - ] = torch.cuda.Stream(priority=-1) - self._data_dist_stream: Optional[ - torch.cuda.streams.Stream - ] = torch.cuda.Stream(priority=-1) - else: - self._memcpy_stream: Optional[torch.cuda.streams.Stream] = None - self._data_dist_stream: Optional[torch.cuda.streams.Stream] = None - self._batch_i: Optional[In] = None - self._batch_ip1: Optional[In] = None - self._batch_ip2: Optional[In] = None - self._context = TrainPipelineContext() + super().__init__() + self.model = model + self.stream = stream + self.apply_jit = apply_jit + self.context = TrainPipelineContext() + self.initialized = False self._pipelined_modules: List[ShardedModule] = [] - def _fill_pipeline(self, dataloader_iter: Iterator[In]) -> None: - # pipeline is already filled - if self._batch_i and self._batch_ip1: - return - # executes last batch in pipeline - if self._batch_i and self._execute_all_batches: - return - - # batch 1 - self._batch_i = self._copy_batch_to_gpu(dataloader_iter) - if self._batch_i is None: - raise StopIteration - - self._init_pipelined_modules(self._batch_i) - self._start_sparse_data_dist(self._batch_i) - self._wait_sparse_data_dist() - - # batch 2 - self._batch_ip1 = self._copy_batch_to_gpu(dataloader_iter) - - def progress(self, dataloader_iter: Iterator[In]) -> Out: - self._fill_pipeline(dataloader_iter) - - if self._model.training: - with record_function("## zero_grad ##"): - self._optimizer.zero_grad() - - with record_function("## wait_for_batch ##"): - _wait_for_batch(cast(In, self._batch_i), self._data_dist_stream) - - self._start_sparse_data_dist(self._batch_ip1) - - self._batch_ip2 = self._copy_batch_to_gpu(dataloader_iter) - - # forward - with record_function("## forward ##"): - losses, output = cast(Tuple[torch.Tensor, Out], self._model(self._batch_i)) - - self._wait_sparse_data_dist() - - if self._model.training: - # backward - with record_function("## backward ##"): - torch.sum(losses, dim=0).backward() - - # update - with record_function("## optimizer ##"): - self._optimizer.step() - - self._batch_i = self._batch_ip1 - self._batch_ip1 = self._batch_ip2 - - return output - - def _init_pipelined_modules(self, batch: In) -> None: - """ - Retrieves the pipelined modules after overriding their forwards, initializes the - modules' input dists, and overrides the input dist forwards to support fusing - the splits collective in the input dist. - """ - if self._pipelined_modules: - return - self._pipelined_modules, self._model = _rewrite_model( - model=self._model, - context=self._context, - dist_stream=self._data_dist_stream, - batch=self._batch_i, - apply_jit=self._apply_jit, - ) - # initializes input dist, so we can override input dist forwards - self._start_sparse_data_dist(self._batch_i) - _override_input_dist_forwards(self._pipelined_modules) - - def _copy_batch_to_gpu(self, dataloader_iter: Iterator[In]) -> Optional[In]: - """ - Retrieves batch from dataloader and moves it to the provided device. - - Raises: - StopIteration: if the dataloader iterator is exhausted; unless - `self._execute_all_batches=True`, then returns None. - """ - with record_function("## copy_batch_to_gpu ##"): - with torch.cuda.stream(self._memcpy_stream): - batch = next(dataloader_iter, None) - if batch is not None: - batch = _to_device(batch, self._device, non_blocking=True) - elif not self._execute_all_batches: - raise StopIteration - return batch - - def _start_sparse_data_dist(self, batch: Optional[In]) -> None: - """ - Waits for batch to finish getting copied to GPU, then starts the input dist. - """ - if batch is None: - return - with record_function("## start_sparse_data_dist ##"): - with torch.cuda.stream(self._data_dist_stream): - _wait_for_batch(batch, self._memcpy_stream) - _start_data_dist(self._pipelined_modules, batch, self._context) - - def _wait_sparse_data_dist(self) -> None: - """ - Waits on the input dist splits requests to get the input dist tensors requests, - and populates the context with them. - """ - with record_function("## wait_sparse_data_dist ##"): - with torch.cuda.stream(self._data_dist_stream): - self._context.module_contexts = ( - self._context.module_contexts_next_batch.copy() - ) - self._context.input_dist_tensors_requests.clear() - for names, awaitable in self._context.fused_splits_awaitables: - for name, request in zip(names, awaitable.wait()): - self._context.input_dist_tensors_requests[name] = request - - -class EvalPipelineSparseDist(TrainPipelineSparseDist[In, Out]): - """ - This pipeline overlaps device transfer, and `ShardedModule.input_dist()` with - forward. This helps hide the all2all latency. We use a background thread to - perform device transfer to further reduce latency. - - stage 2: forward- uses default CUDA stream - stage 1: ShardedModule.input_dist() - uses data_dist CUDA stream - background: device transfer - uses memcpy CUDA stream - - `ShardedModule.input_dist()` is only done for top-level modules in the call graph. - To be considered a top-level module, a module can only depend on 'getattr' calls on - input. - - Input model must be symbolically traceable with the exception of `ShardedModule` and - `DistributedDataParallel` modules. - - Args: - model (torch.nn.Module): model to pipeline. - optimizer (torch.optim.Optimizer): optimizer to use. - device (torch.device): device where device transfer, sparse data dist, and - forward/backward pass will happen. - apply_jit (bool): apply torch.jit.script to non-pipelined (unsharded) modules. - """ - - def __init__( - self, - model: torch.nn.Module, - optimizer: torch.optim.Optimizer, - device: torch.device, - apply_jit: bool = False, - ) -> None: - super().__init__(model, optimizer, device, True, apply_jit) - self._batch_loader: Optional[DataLoadingThread[In]] = None - - def __del__(self) -> None: - if self._batch_loader is not None: - self._batch_loader.stop() - - def progress(self, dataloader_iter: Iterator[In]) -> Out: - if not self._batch_loader: - self._batch_loader = DataLoadingThread( - device=self._device, - dataloader_iter=dataloader_iter, - to_device_non_blocking=True, - memcpy_stream_priority=-1, + # pyre-ignore + self.original_forward = self.model.forward + + def forward_hook( + module: torch.nn.Module, + input: Union[torch.Tensor, Tuple[torch.Tensor]], + output: Union[torch.Tensor, Tuple[torch.Tensor]], + ) -> None: + self.wait_sparse_data_dist() + + self.model.register_forward_hook(forward_hook) + + def start_sparse_data_dist(self, batch: In) -> In: + if not self.initialized: + self._pipelined_modules, self.model = _rewrite_model( + model=self.model, + context=self.context, + dist_stream=self.stream, + batch=batch, + apply_jit=self.apply_jit, ) - self._batch_loader.start() - - batch_loader = self._batch_loader - assert batch_loader is not None - - # batch 1 - self._batch_i = batch_loader.get_next_batch() - assert self._batch_i is not None + # initializes input dist, so we can override input dist forwards + _start_data_dist(self._pipelined_modules, batch, self.context) + _override_input_dist_forwards(self._pipelined_modules) + self.initialized = True - self._init_pipelined_modules(self._batch_i) - self._start_sparse_data_dist(self._batch_i) - self._wait_sparse_data_dist() + _start_data_dist(self._pipelined_modules, batch, self.context) - # batch 2 - self._batch_ip1 = batch_loader.get_next_batch() - - if self._batch_i is None: - raise StopIteration - - batch_loader = self._batch_loader - assert batch_loader is not None - with record_function("## wait_for_batch ##"): - _wait_for_batch(cast(In, self._batch_i), self._data_dist_stream) - - self._start_sparse_data_dist(self._batch_ip1) - - # forward - with record_function("## forward ##"): - losses, output = cast(Tuple[torch.Tensor, Out], self._model(self._batch_i)) - - self._wait_sparse_data_dist() - - self._batch_i = self._batch_ip1 - self._batch_ip1 = batch_loader.get_next_batch() - - return output - - -class PrefetchTrainPipelineSparseDist(TrainPipelineSparseDist[In, Out]): - """ - This pipeline overlaps device transfer, `ShardedModule.input_dist()`, and cache - prefetching with forward and backward. This helps hide the all2all latency while - preserving the training forward / backward ordering. - - stage 4: forward, backward - uses default CUDA stream - stage 3: prefetch - uses prefetch CUDA stream - stage 2: ShardedModule.input_dist() - uses data_dist CUDA stream - stage 1: device transfer - uses memcpy CUDA stream - - `ShardedModule.input_dist()` is only done for top-level modules in the call graph. - To be considered a top-level module, a module can only depend on 'getattr' calls on - input. - - Input model must be symbolically traceable with the exception of `ShardedModule` and - `DistributedDataParallel` modules. - - Args: - model (torch.nn.Module): model to pipeline. - optimizer (torch.optim.Optimizer): optimizer to use. - device (torch.device): device where device transfer, sparse data dist, prefetch, - and forward/backward pass will happen. - execute_all_batches (bool): executes remaining batches in pipeline after - exhausting dataloader iterator. - apply_jit (bool): apply torch.jit.script to non-pipelined (unsharded) modules. - """ - - def __init__( - self, - model: torch.nn.Module, - optimizer: torch.optim.Optimizer, - device: torch.device, - execute_all_batches: bool = True, - apply_jit: bool = False, - ) -> None: - super().__init__( - model=model, - optimizer=optimizer, - device=device, - execute_all_batches=execute_all_batches, - apply_jit=apply_jit, - ) - self._context = PrefetchTrainPipelineContext() - if self._device.type == "cuda": - self._prefetch_stream: Optional[ - torch.cuda.streams.Stream - ] = torch.cuda.Stream() - self._default_stream: Optional[ - torch.cuda.streams.Stream - ] = torch.cuda.current_stream() - else: - self._prefetch_stream: Optional[torch.cuda.streams.Stream] = None - self._default_stream: Optional[torch.cuda.streams.Stream] = None - self._batch_ip3: Optional[In] = None - - def _fill_pipeline(self, dataloader_iter: Iterator[In]) -> None: - # pipeline is already filled - if self._batch_i and self._batch_ip1 and self._batch_ip2: - return - # executes last batch in pipeline - if self._execute_all_batches and (self._batch_i or self._batch_ip1): - return - - # batch 1 - self._batch_i = self._copy_batch_to_gpu(dataloader_iter) - if self._batch_i is None: - raise StopIteration - - self._init_pipelined_modules(self._batch_i) - self._start_sparse_data_dist(self._batch_i) - self._wait_sparse_data_dist() - self._prefetch(self._batch_i) - - # batch 2 - self._batch_ip1 = self._copy_batch_to_gpu(dataloader_iter) - self._start_sparse_data_dist(self._batch_ip1) - self._wait_sparse_data_dist() - - # batch 3 - self._batch_ip2 = self._copy_batch_to_gpu(dataloader_iter) - - def progress(self, dataloader_iter: Iterator[In]) -> Out: - self._fill_pipeline(dataloader_iter) - - if self._model.training: - with record_function("## zero_grad ##"): - self._optimizer.zero_grad() - - with record_function("## wait_for_batch ##"): - _wait_for_batch(cast(In, self._batch_i), self._prefetch_stream) - - self._start_sparse_data_dist(self._batch_ip2) - - self._batch_ip3 = self._copy_batch_to_gpu(dataloader_iter) - - # forward - with record_function("## forward ##"): - losses, output = cast(Tuple[torch.Tensor, Out], self._model(self._batch_i)) - - self._prefetch(self._batch_ip1) - - self._wait_sparse_data_dist() - - if self._model.training: - # backward - with record_function("## backward ##"): - torch.sum(losses, dim=0).backward() - - # update - with record_function("## optimizer ##"): - self._optimizer.step() - - self._batch_i = self._batch_ip1 - self._batch_ip1 = self._batch_ip2 - self._batch_ip2 = self._batch_ip3 - - return output - - def _init_pipelined_modules(self, batch: In) -> None: - """ - Retrieves the pipelined modules after overriding their forwards, initializes the - modules' input dists, and overrides the input dist forwards to support fusing - the splits collective in the input dist. - """ - if self._pipelined_modules: - return - self._pipelined_modules, self._model = _rewrite_model( - model=self._model, - context=self._context, - dist_stream=self._data_dist_stream, - batch=self._batch_i, - apply_jit=self._apply_jit, - pipelined_forward=PrefetchPipelinedForward, - ) + return batch - # initializes input dist, so we can override input dist forwards - self._start_sparse_data_dist(self._batch_i) - _override_input_dist_forwards(self._pipelined_modules) - - def _prefetch(self, batch: Optional[In]) -> None: - """ - Waits for input dist to finish, then prefetches data. - """ - if batch is None: - return - self._context.module_input_post_prefetch.clear() - self._context.module_contexts_post_prefetch.clear() - - with record_function("## sharded_module_prefetch ##"): - with torch.cuda.stream(self._prefetch_stream): - batch.record_stream(torch.cuda.current_stream()) - for sharded_module in self._pipelined_modules: - forward = sharded_module.forward - assert isinstance(forward, PrefetchPipelinedForward) - - assert forward._name in self._context.input_dist_tensors_requests - request = self._context.input_dist_tensors_requests[forward._name] - assert isinstance(request, Awaitable) - with record_function("## wait_sparse_data_dist ##"): - # Finish waiting on the dist_stream, - # in case some delayed stream scheduling happens during the wait() call. - with torch.cuda.stream(self._data_dist_stream): - data = request.wait() - - # Make sure that both result of input_dist and context - # are properly transferred to the current stream. - if self._data_dist_stream is not None: - torch.cuda.current_stream().wait_stream(self._data_dist_stream) - cur_stream = torch.cuda.current_stream() - - assert isinstance( - data, (torch.Tensor, Multistreamable) - ), f"{type(data)} must implement Multistreamable interface" - data.record_stream(cur_stream) - data.record_stream(self._default_stream) - - ctx = self._context.module_contexts[forward._name] - ctx.record_stream(cur_stream) - ctx.record_stream(self._default_stream) - - sharded_module.prefetch( - dist_input=data, forward_stream=self._default_stream - ) - self._context.module_input_post_prefetch[forward._name] = data - self._context.module_contexts_post_prefetch[ - forward._name - ] = self._context.module_contexts[forward._name] + def wait_sparse_data_dist(self) -> None: + self.context.module_contexts = self.context.module_contexts_next_batch.copy() + self.context.input_dist_tensors_requests.clear() + for names, awaitable in self.context.fused_splits_awaitables: + for name, request in zip(names, awaitable.wait()): + self.context.input_dist_tensors_requests[name] = request