diff --git a/torchrec/distributed/train_pipeline/tests/test_train_pipeline.py b/torchrec/distributed/train_pipeline/tests/test_train_pipeline.py index 0e22f4ef7..38cbde818 100644 --- a/torchrec/distributed/train_pipeline/tests/test_train_pipeline.py +++ b/torchrec/distributed/train_pipeline/tests/test_train_pipeline.py @@ -873,7 +873,9 @@ def gpu_preproc(x: StageOut) -> StageOut: fill_callback=sdd.wait_sparse_data_dist, ), ] - pipeline = StagedTrainPipeline(pipeline_stages=pipeline_stages) + pipeline = StagedTrainPipeline( + pipeline_stages=pipeline_stages, compute_stream=torch.cuda.current_stream() + ) dataloader = iter(data) pipelined_out = [] diff --git a/torchrec/distributed/train_pipeline/train_pipeline.py b/torchrec/distributed/train_pipeline/train_pipeline.py index a137f02a1..69309b8bb 100644 --- a/torchrec/distributed/train_pipeline/train_pipeline.py +++ b/torchrec/distributed/train_pipeline/train_pipeline.py @@ -582,6 +582,8 @@ class StagedTrainPipeline(TrainPipeline[In, Optional[StageOut]]): Args: pipeline_stages (List[PipelineStage]): A list of stages to execute. + compute_stream (Optional[torch.cuda.Stream]): The main compute stream in which model forward is run, + usually torch.cuda.default_stream(). Defaults to the current cuda stream. debug_mode (bool): Whether to enable debug mode. Example:: @@ -610,6 +612,7 @@ class StagedTrainPipeline(TrainPipeline[In, Optional[StageOut]]): def __init__( self, pipeline_stages: List[PipelineStage], + compute_stream: Optional[torch.cuda.Stream] = None, debug_mode: bool = False, ) -> None: self._pipeline_stages = pipeline_stages @@ -619,20 +622,20 @@ def __init__( ) self._initialized = False self._num_steps = 0 + self._data_depleted = False + self._compute_stream = compute_stream or torch.cuda.current_stream() @property def num_stages(self) -> int: return len(self._pipeline_stages) - def _advance(self) -> Optional[StageOut]: + def _advance(self) -> Optional[StageOutputWithEvent]: # left shifts all batch results. out = self._stage_outputs[0] for idx in range(self.num_stages - 1): self._stage_outputs[idx] = self._stage_outputs[idx + 1] self._stage_outputs[-1] = None - if out is None: - return out - return out[0] + return out def _run_with_event( self, @@ -654,6 +657,15 @@ def _run_with_event( new_event.record(stream) return (output, new_event) + def _next_batch(self, dataloader_iter: Iterator[In]) -> Optional[In]: + if self._data_depleted: + return None + + batch_to_wait = next(dataloader_iter, None) + if batch_to_wait is None: + self._data_depleted = True + return batch_to_wait + def _run_stage( self, batch_offset: int, @@ -672,7 +684,7 @@ def _run_stage( f"## Pipeline Stage {stage_idx} : {stage.name} for batch {batch_offset + self._num_steps} ##" ): if stage_idx == 0: - batch_to_wait = next(dataloader_iter, None) + batch_to_wait = self._next_batch(dataloader_iter) event = None else: batch_to_wait_with_event = self._stage_outputs[batch_offset] @@ -757,7 +769,12 @@ def progress( if not self._initialized: self._fill_pipeline(dataloader_iter) - output = self._advance() + output_with_event = self._advance() + + if output_with_event is None: + # All data consumed, exit early + return None + self._num_steps += 1 for stage_idx in range(self.num_stages): @@ -768,4 +785,11 @@ def progress( dataloader_iter=dataloader_iter, ) - return output + out, event = output_with_event + if event is not None: + # Since model forward() is expected to run outside the pipeline, + # we need to explicitly wait for the last stage to finish + event.wait(self._compute_stream) + out.record_stream(self._compute_stream) + + return out diff --git a/torchrec/distributed/train_pipeline/utils.py b/torchrec/distributed/train_pipeline/utils.py index ce995f2cd..a95fc1cb2 100644 --- a/torchrec/distributed/train_pipeline/utils.py +++ b/torchrec/distributed/train_pipeline/utils.py @@ -963,8 +963,12 @@ def start_sparse_data_dist(self, batch: In) -> In: return batch def wait_sparse_data_dist(self) -> None: - self.context.module_contexts = self.context.module_contexts_next_batch.copy() - self.context.input_dist_tensors_requests.clear() - for names, awaitable in self.context.fused_splits_awaitables: - for name, request in zip(names, awaitable.wait()): - self.context.input_dist_tensors_requests[name] = request + with record_function("## wait_sparse_data_dist ##"): + with torch.cuda.stream(self.stream): + self.context.module_contexts = ( + self.context.module_contexts_next_batch.copy() + ) + self.context.input_dist_tensors_requests.clear() + for names, awaitable in self.context.fused_splits_awaitables: + for name, request in zip(names, awaitable.wait()): + self.context.input_dist_tensors_requests[name] = request