Skip to content

Commit

Permalink
Enable prefetch stage for StagedTrainPipeline (pytorch#2239)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: pytorch#2239

Add ability to run prefetch as a stage in `StagedTrainPipeline`

Recommended usage to run 3-stage pipeline with data copy, sparse dist and prefetch steps (changes required shown with arrows):
```
sdd = SparseDataDistUtil(
    model=self._model,
    data_dist_stream=torch.torch.cuda.Stream(),
    prefetch_stream=torch.torch.cuda.Stream(), <--- define prefetch stream
)

pipeline = [
    PipelineStage(
        name="data_copy",
        runnable=lambda batch, context: batch.to(
            self._device, non_blocking=True
        ),
        stream=torch.cuda.Stream(),
    ),
    PipelineStage(
        name="start_sparse_data_dist",
        runnable=sdd.start_sparse_data_dist,
        stream=sdd.data_dist_stream,
        fill_callback=sdd.wait_sparse_data_dist,
    ),
    PipelineStage(
        name="prefetch",
        runnable=sdd.prefetch, <--- add stage with runnable=sdd.prefetch
        stream=sdd.prefetch_stream,
        fill_callback=sdd.load_prefetch, <--- fill_callback of sdd.load_prefetch
    ),
]

return StagedTrainPipeline(pipeline_stages=pipeline)
```

Order of execution for above pipeline:

Iteration pytorch#1:

_fill_pipeline():
batch 0: memcpy, start_sdd, wait_sdd (callback), prefetch, load_prefetch (callback)
batch 1: memcpy, start_sdd, wait_sdd (callback)
batch 2: memcpy

progress():
batch 3: memcpy
batch 2: start_sdd
batch 1: prefetch

after pipeline progress():
model(batch 0)
load_prefetch (prepares for model fwd on batch 1)
wait_sdd (prepares for batch 2 prefetch)

Iteration pytorch#2:
progress():
batch 4: memcpy
batch 3: start_sdd
batch 2: prefetch

after pipeline progress():
model(batch 1)
load_prefetch (prepares for model fwd on batch 2)
wait_sdd (prepares for batch 3 prefetch)

Reviewed By: zzzwen, joshuadeng

Differential Revision: D59786807

fbshipit-source-id: 6261c07cd6823bc541463d24ff867ab0e43631ea
  • Loading branch information
sarckk authored and facebook-github-bot committed Jul 23, 2024
1 parent 09d1ff2 commit 9264186
Show file tree
Hide file tree
Showing 3 changed files with 327 additions and 50 deletions.
142 changes: 136 additions & 6 deletions torchrec/distributed/train_pipeline/tests/test_train_pipelines.py
Original file line number Diff line number Diff line change
Expand Up @@ -1667,7 +1667,7 @@ def gpu_preproc(x: StageOut) -> StageOut:

sdd = SparseDataDistUtil[ModelInput](
model=sharded_model_pipelined,
stream=torch.cuda.Stream(),
data_dist_stream=torch.cuda.Stream(),
apply_jit=False,
)

Expand Down Expand Up @@ -1695,7 +1695,7 @@ def gpu_preproc(x: StageOut) -> StageOut:
PipelineStage(
name="start_sparse_data_dist",
runnable=sdd.start_sparse_data_dist,
stream=sdd.stream,
stream=sdd.data_dist_stream,
fill_callback=sdd.wait_sparse_data_dist,
),
]
Expand Down Expand Up @@ -1744,7 +1744,7 @@ def gpu_preproc(x: StageOut) -> StageOut:

sdd = SparseDataDistUtil[ModelInput](
model=sharded_model_pipelined,
stream=torch.cuda.Stream(),
data_dist_stream=torch.cuda.Stream(),
apply_jit=False,
)

Expand All @@ -1762,7 +1762,7 @@ def gpu_preproc(x: StageOut) -> StageOut:
PipelineStage(
name="start_sparse_data_dist",
runnable=sdd.start_sparse_data_dist,
stream=sdd.stream,
stream=sdd.data_dist_stream,
fill_callback=sdd.wait_sparse_data_dist,
),
]
Expand Down Expand Up @@ -1860,7 +1860,7 @@ def test_model_detach(self) -> None:

sdd = SparseDataDistUtil[ModelInput](
model=sharded_model_pipelined,
stream=torch.cuda.Stream(),
data_dist_stream=torch.cuda.Stream(),
apply_jit=False,
)

Expand All @@ -1873,7 +1873,7 @@ def test_model_detach(self) -> None:
PipelineStage(
name="start_sparse_data_dist",
runnable=sdd.start_sparse_data_dist,
stream=sdd.stream,
stream=sdd.data_dist_stream,
fill_callback=sdd.wait_sparse_data_dist,
),
]
Expand Down Expand Up @@ -1964,3 +1964,133 @@ def test_model_detach(self) -> None:
# Check pipeline exhausted
preproc_input = pipeline.progress(dataloader)
self.assertIsNone(preproc_input)

@unittest.skipIf(
not torch.cuda.is_available(),
"Not enough GPUs, this test requires at least one GPU",
)
@settings(max_examples=4, deadline=None)
# pyre-ignore[56]
@given(
sharding_type=st.sampled_from(
[
ShardingType.TABLE_WISE.value,
]
),
kernel_type=st.sampled_from(
[
EmbeddingComputeKernel.FUSED_UVM_CACHING.value,
]
),
cache_precision=st.sampled_from(
[
DataType.FP16,
DataType.FP32,
]
),
load_factor=st.sampled_from(
[
0.2,
0.4,
]
),
)
def test_pipelining_prefetch(
self,
sharding_type: str,
kernel_type: str,
cache_precision: DataType,
load_factor: float,
) -> None:
model = self._setup_model()

fused_params = {
"cache_load_factor": load_factor,
"cache_precision": cache_precision,
"stochastic_rounding": False, # disable non-deterministic behavior when converting fp32<->fp16
}
fused_params_pipelined = {
**fused_params,
"prefetch_pipeline": True,
}

sharded_model, optim = self._generate_sharded_model_and_optimizer(
model, sharding_type, kernel_type, fused_params
)
(
sharded_model_pipelined,
optim_pipelined,
) = self._generate_sharded_model_and_optimizer(
model, sharding_type, kernel_type, fused_params_pipelined
)

copy_state_dict(
sharded_model.state_dict(), sharded_model_pipelined.state_dict()
)

num_batches = 12
data = self._generate_data(
num_batches=num_batches,
batch_size=32,
)

non_pipelined_outputs = []
for batch in data:
batch = batch.to(self.device)
optim.zero_grad()
loss, pred = sharded_model(batch)
loss.backward()
optim.step()
non_pipelined_outputs.append(pred)

def gpu_preproc(x: StageOut) -> StageOut:
return x

sdd = SparseDataDistUtil[ModelInput](
model=sharded_model_pipelined,
data_dist_stream=torch.cuda.Stream(),
apply_jit=False,
prefetch_stream=torch.cuda.Stream(),
)

pipeline_stages = [
PipelineStage(
name="data_copy",
runnable=partial(get_h2d_func, device=self.device),
stream=torch.cuda.Stream(),
),
PipelineStage(
name="start_sparse_data_dist",
runnable=sdd.start_sparse_data_dist,
stream=sdd.data_dist_stream,
fill_callback=sdd.wait_sparse_data_dist,
),
PipelineStage(
name="prefetch",
runnable=sdd.prefetch,
# pyre-ignore
stream=sdd.prefetch_stream,
fill_callback=sdd.load_prefetch,
),
]
pipeline = StagedTrainPipeline(
pipeline_stages=pipeline_stages, compute_stream=torch.cuda.current_stream()
)
dataloader = iter(data)

pipelined_out = []
num_batches_processed = 0

while model_in := pipeline.progress(dataloader):
num_batches_processed += 1
optim_pipelined.zero_grad()
loss, pred = sharded_model_pipelined(model_in)
loss.backward()
optim_pipelined.step()
pipelined_out.append(pred)

self.assertEqual(num_batches_processed, num_batches)

self.assertEqual(len(pipelined_out), len(non_pipelined_outputs))
for out, ref_out in zip(pipelined_out, non_pipelined_outputs):
torch.testing.assert_close(out, ref_out)
49 changes: 11 additions & 38 deletions torchrec/distributed/train_pipeline/train_pipelines.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
from torchrec.distributed.train_pipeline.utils import (
_override_input_dist_forwards,
_pipeline_detach_model,
_prefetch_embeddings,
_rewrite_model,
_start_data_dist,
_start_embedding_lookup,
Expand Down Expand Up @@ -1101,46 +1102,18 @@ def _prefetch(self, batch: Optional[In]) -> None:
batch.record_stream(
torch.get_device_module(self._device).current_stream()
)
data_per_pipelined_module = _prefetch_embeddings(
batch,
self._context,
self._pipelined_modules,
self._device,
self._stream_context,
self._data_dist_stream,
self._default_stream,
)
for sharded_module in self._pipelined_modules:
forward = sharded_module.forward
assert isinstance(forward, PrefetchPipelinedForward)

assert forward._name in self._context.input_dist_tensors_requests
request = self._context.input_dist_tensors_requests.pop(
forward._name
)
assert isinstance(request, Awaitable)
with record_function("## wait_sparse_data_dist ##"):
# Finish waiting on the dist_stream,
# in case some delayed stream scheduling happens during the wait() call.
with self._stream_context(self._data_dist_stream):
data = request.wait()

# Make sure that both result of input_dist and context
# are properly transferred to the current stream.
module_context = self._context.module_contexts[forward._name]
if self._data_dist_stream is not None:
torch.get_device_module(
self._device
).current_stream().wait_stream(self._data_dist_stream)
cur_stream = torch.get_device_module(
self._device
).current_stream()

assert isinstance(
data, (torch.Tensor, Multistreamable)
), f"{type(data)} must implement Multistreamable interface"
data.record_stream(cur_stream)
data.record_stream(self._default_stream)

module_context.record_stream(cur_stream)
module_context.record_stream(self._default_stream)

sharded_module.prefetch(
ctx=module_context,
dist_input=data,
forward_stream=self._default_stream,
)
data = data_per_pipelined_module[forward._name]
self._context.module_input_post_prefetch[forward._name] = data
self._context.module_contexts_post_prefetch[forward._name] = (
self._context.module_contexts.pop(forward._name)
Expand Down
Loading

0 comments on commit 9264186

Please sign in to comment.