Add missing event wait for last stage in StagedTrainPipeline
Summary:
StagedTrainPipeline expects the model forward() to run outside of the pipeline, so the main compute stream must wait for the last pre-forward stage to finish before progressing (a usage sketch follows the first changed file below).

Also changes `wait_sparse_data_dist` to run in the SDD stream instead of the main stream.

Reviewed By: dracifer

Differential Revision: D54685704
sarckk authored and facebook-github-bot committed Mar 8, 2024
1 parent 1df82cc commit 7e7dcf8
Showing 3 changed files with 43 additions and 13 deletions.
@@ -873,7 +873,9 @@ def gpu_preproc(x: StageOut) -> StageOut:
                 fill_callback=sdd.wait_sparse_data_dist,
             ),
         ]
-        pipeline = StagedTrainPipeline(pipeline_stages=pipeline_stages)
+        pipeline = StagedTrainPipeline(
+            pipeline_stages=pipeline_stages, compute_stream=torch.cuda.current_stream()
+        )
         dataloader = iter(data)
 
         pipelined_out = []
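The updated test above passes the current CUDA stream as `compute_stream` and still calls model forward() outside the pipeline. As a rough usage sketch of that pattern (not code from this commit: `pipeline_stages`, `model`, `optimizer`, the dataloader, and the (loss, output) return convention are illustrative assumptions), a training loop might look like:

# Hypothetical training loop; only StagedTrainPipeline(..., compute_stream=...)
# and progress() come from this commit, the rest is illustrative.
pipeline = StagedTrainPipeline(
    pipeline_stages=pipeline_stages,  # e.g. H2D copy + sparse data dist stages
    compute_stream=torch.cuda.current_stream(),
)
dataloader_iter = iter(dataloader)

while True:
    batch = pipeline.progress(dataloader_iter)
    if batch is None:  # pipeline drained, all input data consumed
        break
    # forward() runs outside the pipeline on the compute stream; progress()
    # has already made this stream wait on the last stage's CUDA event, so
    # `batch` is safe to consume here.
    loss, _ = model(batch)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()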
torchrec/distributed/train_pipeline/train_pipeline.py (31 additions, 7 deletions)
@@ -582,6 +582,8 @@ class StagedTrainPipeline(TrainPipeline[In, Optional[StageOut]]):
     Args:
         pipeline_stages (List[PipelineStage]): A list of stages to execute.
+        compute_stream (Optional[torch.cuda.Stream]): The main compute stream in which model forward is run,
+            usually torch.cuda.default_stream(). Defaults to the current cuda stream.
         debug_mode (bool): Whether to enable debug mode.
 
     Example::
@@ -610,6 +612,7 @@ class StagedTrainPipeline(TrainPipeline[In, Optional[StageOut]]):
     def __init__(
         self,
         pipeline_stages: List[PipelineStage],
+        compute_stream: Optional[torch.cuda.Stream] = None,
         debug_mode: bool = False,
     ) -> None:
         self._pipeline_stages = pipeline_stages
@@ -619,20 +622,20 @@ def __init__(
         )
         self._initialized = False
         self._num_steps = 0
+        self._data_depleted = False
+        self._compute_stream = compute_stream or torch.cuda.current_stream()
 
     @property
     def num_stages(self) -> int:
         return len(self._pipeline_stages)
 
-    def _advance(self) -> Optional[StageOut]:
+    def _advance(self) -> Optional[StageOutputWithEvent]:
         # left shifts all batch results.
         out = self._stage_outputs[0]
         for idx in range(self.num_stages - 1):
             self._stage_outputs[idx] = self._stage_outputs[idx + 1]
         self._stage_outputs[-1] = None
-        if out is None:
-            return out
-        return out[0]
+        return out
 
     def _run_with_event(
         self,
@@ -654,6 +657,15 @@ def _run_with_event(
             new_event.record(stream)
             return (output, new_event)
 
+    def _next_batch(self, dataloader_iter: Iterator[In]) -> Optional[In]:
+        if self._data_depleted:
+            return None
+
+        batch_to_wait = next(dataloader_iter, None)
+        if batch_to_wait is None:
+            self._data_depleted = True
+        return batch_to_wait
+
     def _run_stage(
         self,
         batch_offset: int,
@@ -672,7 +684,7 @@ def _run_stage(
             f"## Pipeline Stage {stage_idx} : {stage.name} for batch {batch_offset + self._num_steps} ##"
         ):
             if stage_idx == 0:
-                batch_to_wait = next(dataloader_iter, None)
+                batch_to_wait = self._next_batch(dataloader_iter)
                 event = None
             else:
                 batch_to_wait_with_event = self._stage_outputs[batch_offset]
@@ -757,7 +769,12 @@ def progress(
         if not self._initialized:
             self._fill_pipeline(dataloader_iter)
 
-        output = self._advance()
+        output_with_event = self._advance()
+
+        if output_with_event is None:
+            # All data consumed, exit early
+            return None
+
         self._num_steps += 1
 
         for stage_idx in range(self.num_stages):
@@ -768,4 +785,11 @@ def progress(
                 dataloader_iter=dataloader_iter,
             )
 
-        return output
+        out, event = output_with_event
+        if event is not None:
+            # Since model forward() is expected to run outside the pipeline,
+            # we need to explicitly wait for the last stage to finish
+            event.wait(self._compute_stream)
+            out.record_stream(self._compute_stream)
+
+        return out
torchrec/distributed/train_pipeline/utils.py (9 additions, 5 deletions)
@@ -963,8 +963,12 @@ def start_sparse_data_dist(self, batch: In) -> In:
         return batch
 
     def wait_sparse_data_dist(self) -> None:
-        self.context.module_contexts = self.context.module_contexts_next_batch.copy()
-        self.context.input_dist_tensors_requests.clear()
-        for names, awaitable in self.context.fused_splits_awaitables:
-            for name, request in zip(names, awaitable.wait()):
-                self.context.input_dist_tensors_requests[name] = request
+        with record_function("## wait_sparse_data_dist ##"):
+            with torch.cuda.stream(self.stream):
+                self.context.module_contexts = (
+                    self.context.module_contexts_next_batch.copy()
+                )
+                self.context.input_dist_tensors_requests.clear()
+                for names, awaitable in self.context.fused_splits_awaitables:
+                    for name, request in zip(names, awaitable.wait()):
+                        self.context.input_dist_tensors_requests[name] = request
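The utils.py change above queues `wait_sparse_data_dist` on the SDD stream, and progress() now synchronizes the compute stream with the last stage's CUDA event. Below is a minimal, standalone sketch of that stream/event pattern in plain PyTorch (assuming a CUDA device is available; the names and the stand-in workload are illustrative, not TorchRec code):

import torch

# Minimal sketch of the stream/event synchronization this commit adds for the
# last pre-forward stage; names here are illustrative, not TorchRec APIs.
compute_stream = torch.cuda.current_stream()
side_stream = torch.cuda.Stream()

with torch.cuda.stream(side_stream):
    # Stand-in for stage work, e.g. an H2D copy or sparse data dist for the next batch.
    out = torch.randn(1024, device="cuda") * 2
    event = torch.cuda.Event()
    event.record(side_stream)

# Before the compute stream consumes `out`, make it wait on the side stream's
# event (this is what progress() now does via event.wait(self._compute_stream)).
event.wait(compute_stream)
# Tell the caching allocator that `out` is also used on compute_stream, so its
# memory is not reclaimed or reused while that stream still needs it.
out.record_stream(compute_stream)

result = out.sum()  # safely ordered after the side-stream work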
