diff --git a/torchrec/distributed/train_pipeline/tests/test_train_pipelines.py b/torchrec/distributed/train_pipeline/tests/test_train_pipelines.py
index f60672833..0da971457 100644
--- a/torchrec/distributed/train_pipeline/tests/test_train_pipelines.py
+++ b/torchrec/distributed/train_pipeline/tests/test_train_pipelines.py
@@ -1667,7 +1667,7 @@ def gpu_preproc(x: StageOut) -> StageOut:
 
         sdd = SparseDataDistUtil[ModelInput](
             model=sharded_model_pipelined,
-            stream=torch.cuda.Stream(),
+            data_dist_stream=torch.cuda.Stream(),
             apply_jit=False,
         )
 
@@ -1695,7 +1695,7 @@ def gpu_preproc(x: StageOut) -> StageOut:
             PipelineStage(
                 name="start_sparse_data_dist",
                 runnable=sdd.start_sparse_data_dist,
-                stream=sdd.stream,
+                stream=sdd.data_dist_stream,
                 fill_callback=sdd.wait_sparse_data_dist,
             ),
         ]
@@ -1744,7 +1744,7 @@ def gpu_preproc(x: StageOut) -> StageOut:
 
         sdd = SparseDataDistUtil[ModelInput](
             model=sharded_model_pipelined,
-            stream=torch.cuda.Stream(),
+            data_dist_stream=torch.cuda.Stream(),
             apply_jit=False,
         )
 
@@ -1762,7 +1762,7 @@ def gpu_preproc(x: StageOut) -> StageOut:
             PipelineStage(
                 name="start_sparse_data_dist",
                 runnable=sdd.start_sparse_data_dist,
-                stream=sdd.stream,
+                stream=sdd.data_dist_stream,
                 fill_callback=sdd.wait_sparse_data_dist,
             ),
         ]
@@ -1860,7 +1860,7 @@ def test_model_detach(self) -> None:
 
         sdd = SparseDataDistUtil[ModelInput](
             model=sharded_model_pipelined,
-            stream=torch.cuda.Stream(),
+            data_dist_stream=torch.cuda.Stream(),
             apply_jit=False,
         )
 
@@ -1873,7 +1873,7 @@ def test_model_detach(self) -> None:
             PipelineStage(
                 name="start_sparse_data_dist",
                 runnable=sdd.start_sparse_data_dist,
-                stream=sdd.stream,
+                stream=sdd.data_dist_stream,
                 fill_callback=sdd.wait_sparse_data_dist,
             ),
         ]
@@ -1964,3 +1964,133 @@ def test_model_detach(self) -> None:
         # Check pipeline exhausted
         preproc_input = pipeline.progress(dataloader)
         self.assertIsNone(preproc_input)
+
+    @unittest.skipIf(
+        not torch.cuda.is_available(),
+        "Not enough GPUs, this test requires at least one GPU",
+    )
+    @settings(max_examples=4, deadline=None)
+    # pyre-ignore[56]
+    @given(
+        sharding_type=st.sampled_from(
+            [
+                ShardingType.TABLE_WISE.value,
+            ]
+        ),
+        kernel_type=st.sampled_from(
+            [
+                EmbeddingComputeKernel.FUSED_UVM_CACHING.value,
+            ]
+        ),
+        cache_precision=st.sampled_from(
+            [
+                DataType.FP16,
+                DataType.FP32,
+            ]
+        ),
+        load_factor=st.sampled_from(
+            [
+                0.2,
+                0.4,
+            ]
+        ),
+    )
+    def test_pipelining_prefetch(
+        self,
+        sharding_type: str,
+        kernel_type: str,
+        cache_precision: DataType,
+        load_factor: float,
+    ) -> None:
+        model = self._setup_model()
+
+        fused_params = {
+            "cache_load_factor": load_factor,
+            "cache_precision": cache_precision,
+            "stochastic_rounding": False,  # disable non-deterministic behavior when converting fp32<->fp16
+        }
+        fused_params_pipelined = {
+            **fused_params,
+            "prefetch_pipeline": True,
+        }
+
+        sharded_model, optim = self._generate_sharded_model_and_optimizer(
+            model, sharding_type, kernel_type, fused_params
+        )
+        (
+            sharded_model_pipelined,
+            optim_pipelined,
+        ) = self._generate_sharded_model_and_optimizer(
+            model, sharding_type, kernel_type, fused_params_pipelined
+        )
+
+        copy_state_dict(
+            sharded_model.state_dict(), sharded_model_pipelined.state_dict()
+        )
+
+        num_batches = 12
+        data = self._generate_data(
+            num_batches=num_batches,
+            batch_size=32,
+        )
+
+        non_pipelined_outputs = []
+        for batch in data:
+            batch = batch.to(self.device)
+            optim.zero_grad()
+            loss, pred = sharded_model(batch)
+            loss.backward()
+            optim.step()
+            non_pipelined_outputs.append(pred)
+
+        def gpu_preproc(x: StageOut) -> StageOut:
+            return x
+
+        sdd = SparseDataDistUtil[ModelInput](
+            model=sharded_model_pipelined,
+            data_dist_stream=torch.cuda.Stream(),
+            apply_jit=False,
+            prefetch_stream=torch.cuda.Stream(),
+        )
+
+        pipeline_stages = [
+            PipelineStage(
+                name="data_copy",
+                runnable=partial(get_h2d_func, device=self.device),
+                stream=torch.cuda.Stream(),
+            ),
+            PipelineStage(
+                name="start_sparse_data_dist",
+                runnable=sdd.start_sparse_data_dist,
+                stream=sdd.data_dist_stream,
+                fill_callback=sdd.wait_sparse_data_dist,
+            ),
+            PipelineStage(
+                name="prefetch",
+                runnable=sdd.prefetch,
+                # pyre-ignore
+                stream=sdd.prefetch_stream,
+                fill_callback=sdd.load_prefetch,
+            ),
+        ]
+        pipeline = StagedTrainPipeline(
+            pipeline_stages=pipeline_stages, compute_stream=torch.cuda.current_stream()
+        )
+        dataloader = iter(data)
+
+        pipelined_out = []
+        num_batches_processed = 0
+
+        while model_in := pipeline.progress(dataloader):
+            num_batches_processed += 1
+            optim_pipelined.zero_grad()
+            loss, pred = sharded_model_pipelined(model_in)
+            loss.backward()
+            optim_pipelined.step()
+            pipelined_out.append(pred)
+
+        self.assertEqual(num_batches_processed, num_batches)
+
+        self.assertEqual(len(pipelined_out), len(non_pipelined_outputs))
+        for out, ref_out in zip(pipelined_out, non_pipelined_outputs):
+            torch.testing.assert_close(out, ref_out)
diff --git a/torchrec/distributed/train_pipeline/train_pipelines.py b/torchrec/distributed/train_pipeline/train_pipelines.py
index 63e71cda9..8e59858c1 100644
--- a/torchrec/distributed/train_pipeline/train_pipelines.py
+++ b/torchrec/distributed/train_pipeline/train_pipelines.py
@@ -33,6 +33,7 @@
 from torchrec.distributed.train_pipeline.utils import (
     _override_input_dist_forwards,
     _pipeline_detach_model,
+    _prefetch_embeddings,
     _rewrite_model,
     _start_data_dist,
     _start_embedding_lookup,
@@ -1101,46 +1102,18 @@ def _prefetch(self, batch: Optional[In]) -> None:
                 batch.record_stream(
                     torch.get_device_module(self._device).current_stream()
                 )
+            data_per_pipelined_module = _prefetch_embeddings(
+                batch,
+                self._context,
+                self._pipelined_modules,
+                self._device,
+                self._stream_context,
+                self._data_dist_stream,
+                self._default_stream,
+            )
             for sharded_module in self._pipelined_modules:
                 forward = sharded_module.forward
-                assert isinstance(forward, PrefetchPipelinedForward)
-
-                assert forward._name in self._context.input_dist_tensors_requests
-                request = self._context.input_dist_tensors_requests.pop(
-                    forward._name
-                )
-                assert isinstance(request, Awaitable)
-                with record_function("## wait_sparse_data_dist ##"):
-                    # Finish waiting on the dist_stream,
-                    # in case some delayed stream scheduling happens during the wait() call.
-                    with self._stream_context(self._data_dist_stream):
-                        data = request.wait()
-
-                # Make sure that both result of input_dist and context
-                # are properly transferred to the current stream.
-                module_context = self._context.module_contexts[forward._name]
-                if self._data_dist_stream is not None:
-                    torch.get_device_module(
-                        self._device
-                    ).current_stream().wait_stream(self._data_dist_stream)
-                cur_stream = torch.get_device_module(
-                    self._device
-                ).current_stream()
-
-                assert isinstance(
-                    data, (torch.Tensor, Multistreamable)
-                ), f"{type(data)} must implement Multistreamable interface"
-                data.record_stream(cur_stream)
-                data.record_stream(self._default_stream)
-
-                module_context.record_stream(cur_stream)
-                module_context.record_stream(self._default_stream)
-
-                sharded_module.prefetch(
-                    ctx=module_context,
-                    dist_input=data,
-                    forward_stream=self._default_stream,
-                )
+                data = data_per_pipelined_module[forward._name]
                 self._context.module_input_post_prefetch[forward._name] = data
                 self._context.module_contexts_post_prefetch[forward._name] = (
                     self._context.module_contexts.pop(forward._name)
diff --git a/torchrec/distributed/train_pipeline/utils.py b/torchrec/distributed/train_pipeline/utils.py
index 2972112cc..00c8a4295 100644
--- a/torchrec/distributed/train_pipeline/utils.py
+++ b/torchrec/distributed/train_pipeline/utils.py
@@ -41,6 +41,7 @@
     KJTListSplitsAwaitable,
     KJTSplitsAllToAllMeta,
 )
+from torchrec.distributed.embedding_types import KJTList
 from torchrec.distributed.model_parallel import DistributedModelParallel, ShardedModule
 from torchrec.distributed.types import Awaitable, LazyNoWait
 
@@ -109,6 +110,12 @@ class PrefetchTrainPipelineContext(TrainPipelineContext):
     module_contexts_post_prefetch: Dict[str, Multistreamable] = field(
         default_factory=dict
     )
+    module_input_post_prefetch_next_batch: Dict[str, Multistreamable] = field(
+        default_factory=dict
+    )
+    module_contexts_post_prefetch_next_batch: Dict[str, Multistreamable] = field(
+        default_factory=dict
+    )
 
 
 @dataclass
@@ -1212,23 +1219,121 @@ def get_next_batch(self, none_throws: bool = False) -> Optional[In]:
         return batch
 
 
+def _prefetch_embeddings(
+    batch: In,
+    context: PrefetchTrainPipelineContext,
+    pipelined_modules: List[ShardedModule],
+    device: torch.device,
+    stream_context: torch.Stream,
+    data_dist_stream: Optional[torch.Stream],
+    default_stream: Optional[torch.Stream],
+) -> Dict[str, KJTList]:
+    data_per_sharded_module = {}
+    for sharded_module in pipelined_modules:
+        forward = sharded_module.forward
+        assert isinstance(forward, PrefetchPipelinedForward)
+
+        assert forward._name in context.input_dist_tensors_requests
+        request = context.input_dist_tensors_requests.pop(forward._name)
+        assert isinstance(request, Awaitable)
+        with record_function("## wait_sparse_data_dist ##"):
+            # Finish waiting on the dist_stream,
+            # in case some delayed stream scheduling happens during the wait() call.
+            with stream_context(data_dist_stream):
+                data = request.wait()
+
+        # Make sure that both result of input_dist and context
+        # are properly transferred to the current stream.
+        module_context = context.module_contexts[forward._name]
+        if data_dist_stream is not None:
+            torch.get_device_module(device).current_stream().wait_stream(
+                data_dist_stream
+            )
+        cur_stream = torch.get_device_module(device).current_stream()
+
+        assert isinstance(
+            data, (torch.Tensor, Multistreamable)
+        ), f"{type(data)} must implement Multistreamable interface"
+        data.record_stream(cur_stream)
+        data.record_stream(default_stream)
+
+        module_context.record_stream(cur_stream)
+        module_context.record_stream(default_stream)
+
+        sharded_module.prefetch(
+            ctx=module_context,
+            dist_input=data,
+            forward_stream=default_stream,
+        )
+        data_per_sharded_module[forward._name] = data
+    return data_per_sharded_module
+
+
 class SparseDataDistUtil(Generic[In]):
+    """
+    Helper class exposing methods for sparse data dist and prefetch pipelining.
+    Currently used for `StagedTrainPipeline` pipeline stages.
+
+    Args:
+        model (torch.nn.Module): Model to pipeline
+        data_dist_stream (torch.cuda.Stream): Stream on which to run sparse data dist.
+        apply_jit (bool): apply torch.jit.script to non-pipelined (unsharded) modules.
+        prefetch_stream (Optional[torch.cuda.Stream]): Stream on which model prefetch runs.
+            Defaults to `None`. This needs to be passed in to enable prefetch pipelining.
+
+    Example::
+        sdd = SparseDataDistUtil(
+            model=model,
+            data_dist_stream=torch.cuda.Stream(),
+            prefetch_stream=torch.cuda.Stream(),  # required to enable prefetch pipelining
+        )
+        pipeline = [
+            PipelineStage(
+                name="data_copy",
+                runnable=lambda batch, context: batch.to(
+                    self._device, non_blocking=True
+                ),
+                stream=torch.cuda.Stream(),
+            ),
+            PipelineStage(
+                name="start_sparse_data_dist",
+                runnable=sdd.start_sparse_data_dist,
+                stream=sdd.data_dist_stream,
+                fill_callback=sdd.wait_sparse_data_dist,
+            ),
+            PipelineStage(
+                name="prefetch",
+                runnable=sdd.prefetch,
+                stream=sdd.prefetch_stream,
+                fill_callback=sdd.load_prefetch,
+            ),
+        ]
+
+        return StagedTrainPipeline(pipeline_stages=pipeline)
+    """
+
     def __init__(
         self,
         model: torch.nn.Module,
-        stream: torch.Stream,
+        data_dist_stream: torch.Stream,
         apply_jit: bool = False,
+        prefetch_stream: Optional[torch.Stream] = None,
     ) -> None:
         super().__init__()
         self.model = model
-        self.stream = stream
+        self.data_dist_stream = data_dist_stream
+        self.prefetch_stream = prefetch_stream
         self.apply_jit = apply_jit
-        self.context = TrainPipelineContext(version=0)
+        self.context = (
+            PrefetchTrainPipelineContext(version=0)
+            if prefetch_stream
+            else TrainPipelineContext(version=0)
+        )
         self.initialized = False
         self._pipelined_modules: List[ShardedModule] = []
         # pyre-ignore
         self.fwd_hook = None
-        self._device: torch.device = stream.device
+        self._device: torch.device = data_dist_stream.device
 
         # pyre-ignore
         self._original_forwards: List[Callable[..., Any]] = []
@@ -1236,6 +1341,16 @@ def __init__(
             Callable[[KeyedJaggedTensor], Awaitable[KJTAllToAllTensorsAwaitable]]
         ] = []
 
+        self._pipelined_forward = (
+            PrefetchPipelinedForward if prefetch_stream else PipelinedForward
+        )
+
+        self._default_stream: Optional[torch.Stream] = (
+            (torch.get_device_module(self._device).Stream())
+            if self._device.type in ["cuda", "mtia"]
+            else None
+        )
+
     def detach(self) -> torch.nn.Module:
         """
         Removes sparse data dist (SDD) pipelining from model forward and input dist.
@@ -1270,9 +1385,10 @@ def start_sparse_data_dist(self, batch: In) -> In:
                 _rewrite_model(
                     model=self.model,
                     context=self.context,
-                    dist_stream=self.stream,
+                    dist_stream=self.data_dist_stream,
                     batch=batch,
                     apply_jit=self.apply_jit,
+                    pipelined_forward=self._pipelined_forward,
                 )
             )
             # initializes input dist, so we can override input dist forwards
@@ -1287,6 +1403,9 @@ def forward_hook(
             input: Union[torch.Tensor, Tuple[torch.Tensor]],
             output: Union[torch.Tensor, Tuple[torch.Tensor]],
         ) -> None:
+            if self.prefetch_stream is not None:
+                # Need to load prefetch before wait_sparse_data_dist
+                self.load_prefetch()
             self.wait_sparse_data_dist()
 
         self.fwd_hook = self.model.register_forward_hook(forward_hook)
@@ -1299,7 +1418,7 @@ def forward_hook(
 
     def wait_sparse_data_dist(self) -> None:
         with record_function("## wait_sparse_data_dist ##"):
-            with torch.get_device_module(self._device).stream(self.stream):
+            with torch.get_device_module(self._device).stream(self.data_dist_stream):
                 self.context.module_contexts = (
                     self.context.module_contexts_next_batch.copy()
                 )
@@ -1307,3 +1426,58 @@ def wait_sparse_data_dist(self) -> None:
             for names, awaitable in self.context.fused_splits_awaitables:
                 for name, request in zip(names, awaitable.wait()):
                     self.context.input_dist_tensors_requests[name] = request
+
+    def prefetch(self, batch: In) -> In:
+        """
+        Waits for input dist to finish, then prefetches data.
+        """
+        assert isinstance(
+            self.context, PrefetchTrainPipelineContext
+        ), "Pass prefetch_stream into SparseDataDistUtil to use prefetch() as a stage"
+        self.context.module_input_post_prefetch_next_batch.clear()
+        # pyre-ignore
+        self.context.module_contexts_post_prefetch_next_batch.clear()
+
+        data_per_pipelined_module = _prefetch_embeddings(
+            batch,
+            # pyre-ignore
+            self.context,
+            self._pipelined_modules,
+            self._device,
+            torch.get_device_module(self._device).stream,
+            self.data_dist_stream,
+            self._default_stream,
+        )
+        for sharded_module in self._pipelined_modules:
+            forward = sharded_module.forward
+            data = data_per_pipelined_module[forward._name]
+            # pyre-ignore [16]
+            self.context.module_input_post_prefetch_next_batch[forward._name] = data
+            self.context.module_contexts_post_prefetch_next_batch[forward._name] = (
+                self.context.module_contexts.pop(forward._name)
+            )
+        return batch
+
+    def load_prefetch(self) -> None:
+        assert isinstance(
+            self.context, PrefetchTrainPipelineContext
+        ), "Pass prefetch_stream into SparseDataDistUtil to use load_prefetch()"
+        self.context.module_input_post_prefetch.clear()
+        # pyre-ignore
+        self.context.module_contexts_post_prefetch.clear()
+
+        with record_function("## load_sharded_module_prefetch ##"):
+            with torch.get_device_module(self._device).stream(self.prefetch_stream):
+                for sharded_module in self._pipelined_modules:
+                    forward = sharded_module.forward
+                    assert isinstance(forward, PrefetchPipelinedForward)
+                    self.context.module_input_post_prefetch[forward._name] = (
+                        self.context.module_input_post_prefetch_next_batch[
+                            forward._name
+                        ]
+                    )
+                    self.context.module_contexts_post_prefetch[forward._name] = (
+                        self.context.module_contexts_post_prefetch_next_batch[
+                            forward._name
+                        ]
+                    )
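
Usage sketch (illustrative, not part of the patch): with `prefetch_stream` supplied, `SparseDataDistUtil` drives a three-stage `StagedTrainPipeline` (host-to-device copy, then sparse data dist, then embedding-cache prefetch) ahead of the compute stream. Below is a minimal training loop under the new API, mirroring the new test above; `sharded_model`, `optimizer`, `device`, and `data` are placeholder names, and the `PipelineStage` import location is assumed from this repo's layout.

```python
import torch

from torchrec.distributed.train_pipeline.train_pipelines import StagedTrainPipeline
from torchrec.distributed.train_pipeline.utils import PipelineStage, SparseDataDistUtil


def build_and_run(sharded_model, optimizer, device, data) -> None:
    # Placeholders: `sharded_model`, `optimizer`, `device`, and `data` stand in
    # for the sharded model, its optimizer, the CUDA device, and a list of batches.
    sdd = SparseDataDistUtil(
        model=sharded_model,
        data_dist_stream=torch.cuda.Stream(),
        apply_jit=False,
        prefetch_stream=torch.cuda.Stream(),  # required to enable the prefetch stage
    )
    pipeline = StagedTrainPipeline(
        pipeline_stages=[
            PipelineStage(
                name="data_copy",
                runnable=lambda batch, context: batch.to(device, non_blocking=True),
                stream=torch.cuda.Stream(),
            ),
            PipelineStage(
                name="start_sparse_data_dist",
                runnable=sdd.start_sparse_data_dist,
                stream=sdd.data_dist_stream,
                fill_callback=sdd.wait_sparse_data_dist,
            ),
            PipelineStage(
                name="prefetch",
                runnable=sdd.prefetch,
                stream=sdd.prefetch_stream,
                fill_callback=sdd.load_prefetch,
            ),
        ],
        compute_stream=torch.cuda.current_stream(),
    )

    dataloader = iter(data)
    while model_in := pipeline.progress(dataloader):
        optimizer.zero_grad()
        # The forward hook installed by start_sparse_data_dist runs
        # load_prefetch() and then wait_sparse_data_dist() for the next batch.
        loss, pred = sharded_model(model_in)
        loss.backward()
        optimizer.step()
```

If `prefetch_stream` is omitted, the `prefetch` stage must be dropped as well: the util then builds a plain `TrainPipelineContext` and uses `PipelinedForward`, preserving the two-stage behavior from before this change.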
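The cross-stream handling that `_prefetch_embeddings` factors out of `_prefetch` follows the standard CUDA caching-allocator pattern: the consuming stream calls `wait_stream` on the producing stream for ordering, and the produced data and module contexts are `record_stream`ed on every stream that will later touch them so their memory is not recycled prematurely. A self-contained sketch of the same pattern with plain tensors (names and shapes are illustrative, independent of TorchRec):

```python
import torch


def produce_then_consume() -> torch.Tensor:
    # A tensor is produced on a side stream and consumed on the current stream:
    # the same ordering/lifetime pattern _prefetch_embeddings applies to
    # input-dist outputs and module contexts.
    producer = torch.cuda.Stream()
    consumer = torch.cuda.current_stream()

    with torch.cuda.stream(producer):
        data = torch.randn(1024, device="cuda")  # allocated and written on `producer`

    # Ordering: work queued on `consumer` after this call waits for
    # `producer`'s pending kernels to finish.
    consumer.wait_stream(producer)
    # Lifetime: tell the caching allocator that `data` is also in use on
    # `consumer`, so its memory cannot be handed to a later allocation
    # while `consumer` still reads it.
    data.record_stream(consumer)

    return data * 2  # now safe to consume on the current stream
```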