Skip to content

Commit

Permalink
Fix semi-sync race conditions and optimize memory usage (#2598)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #2598

This PR introduces a set of fixes and optimizations for TorchRec's semi-synchronous training pipeline (where embeddings are fetched in advance and in parallel with backward):
- Fixes memory safety issues causing CUDA Illegal Memory Access errors and NaNs during training, especially at high memory utilisations. The fix involves using calling [`record_stream`](https://pytorch.org/docs/stable/generated/torch.Tensor.record_stream.html) for tensors allocated in one CUDA stream and used in another (an example is embedding tensors which are looked up in an embedding stream but later accessed in the main stream).
- Optimizes memory allocations to reduce memory fragmentation observed at high memory utilizations causing significant performance degradation due to expensive defragmentation calls (thanks to che-sh for initial observation and analysis). Optimizations include:
  -  Freeing context objects and cached module outputs as early as possible to save memory
  - Moving small tensor allocations earlier to minimize fragmentation
  - Using a single stream per embedding module (instead of even and odd streams) as memory allocations by PyTorch's CUDACachingAllocator are associated with streams, meaning freed memory blocks in one stream cannot be used by other streams. By using more streams, we effectively decrease memory available to each stream, making defrags more likely.

This is joint work with che-sh (optimizations to reduce memory fragmentation) and dstaay-fb (record_stream fixes for embeddings)

Reviewed By: che-sh

Differential Revision: D64220706

fbshipit-source-id: 1ac8b9a0855a002bbace4fd21827e15a1fbd17b5
  • Loading branch information
sarckk authored and facebook-github-bot committed Dec 12, 2024
1 parent e1b5edd commit 575e081
Show file tree
Hide file tree
Showing 5 changed files with 166 additions and 96 deletions.
2 changes: 2 additions & 0 deletions torchrec/distributed/comm_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,8 @@ def _wait_impl(self) -> W:
"""

ret = self.wait_function.apply(self.pg, self, self.dummy_tensor)
if isinstance(ret, torch.Tensor):
ret.record_stream(torch.get_device_module(ret.device).current_stream())
self.req = None
self.tensor = None
return ret
Expand Down
4 changes: 4 additions & 0 deletions torchrec/distributed/tests/test_sharding_plan.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,10 @@ def _test_sharding(

@skip_if_asan_class
class ConstructParameterShardingAndShardTest(MultiProcessTestBase):
@unittest.skipIf(
torch.cuda.device_count() <= 1,
"Not enough GPUs, this test requires at least two GPUs",
)
# pyre-fixme[56]
@given(
per_param_sharding=st.sampled_from(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,8 @@ def forward(self, x):
fqn="test_module",
args=[],
context=TrainPipelineContext(),
default_stream=MagicMock(),
dist_stream=MagicMock(),
)
# self-check - we want the state dict be the same between vanilla model and "rewritten model"
self.assertDictEqual(model.state_dict(), rewritten_model.state_dict())
Expand Down
115 changes: 51 additions & 64 deletions torchrec/distributed/train_pipeline/train_pipelines.py
Original file line number Diff line number Diff line change
Expand Up @@ -536,6 +536,7 @@ def _pipeline_model(
model=self._model,
context=context,
dist_stream=self._data_dist_stream,
default_stream=torch.get_device_module(self._device).current_stream(),
batch=batch,
apply_jit=self._apply_jit,
pipelined_forward=pipelined_forward,
Expand Down Expand Up @@ -576,15 +577,14 @@ def copy_batch_to_gpu(
StopIteration: if the dataloader iterator is exhausted; unless
`self._execute_all_batches=True`, then returns None.
"""
context = None
context = self._create_context()
with record_function(f"## copy_batch_to_gpu {self._next_index} ##"):
with self._stream_context(self._memcpy_stream):
batch = self._next_batch(dataloader_iter)
if batch is not None:
batch = _to_device(batch, self._device, non_blocking=True)
elif not self._execute_all_batches:
raise StopIteration
context = self._create_context()
return batch, context

def _next_batch(self, dataloader_iter: Iterator[In]) -> Optional[In]:
Expand Down Expand Up @@ -747,25 +747,9 @@ def __init__(
)
self._start_batch = start_batch
self._stash_gradients = stash_gradients
logger.debug(f"Starting semi-sync run at batch: {self._start_batch}")

# use two data streams to support two concurrent batches
self._embedding_odd_stream: Optional[torch.Stream] = (
(torch.get_device_module(self._device).Stream(priority=0))
if device.type in ["cuda", "mtia"]
else None
)
self._embedding_even_stream: Optional[torch.Stream] = (
(torch.get_device_module(self._device).Stream(priority=0))
if device.type in ["cuda", "mtia"]
else None
)
self._overarch_stream: Optional[torch.Stream] = (
(torch.get_device_module(self._device).Stream(priority=-1))
if device.type in ["cuda", "mtia"]
else None
)
self._embedding_odd_streams: List[Optional[torch.Stream]] = []
self._embedding_even_streams: List[Optional[torch.Stream]] = []
self._embedding_streams: List[Optional[torch.Stream]] = []
self._gradients: Dict[str, torch.Tensor] = {}

def _grad_swap(self) -> None:
Expand All @@ -778,12 +762,7 @@ def _grad_swap(self) -> None:
def _init_embedding_streams(self) -> None:

for _ in self._pipelined_modules:
self._embedding_odd_streams.append(
(torch.get_device_module(self._device).Stream(priority=0))
if self._device.type in ["cuda", "mtia"]
else None
)
self._embedding_even_streams.append(
self._embedding_streams.append(
(torch.get_device_module(self._device).Stream(priority=0))
if self._device.type in ["cuda", "mtia"]
else None
Expand Down Expand Up @@ -839,13 +818,9 @@ def is_semi_sync(self) -> bool:
return self.contexts[0].index >= self._start_batch
return False

def _mlp_optimizer_step(self) -> None:
def _mlp_optimizer_step(self, current_batch: int) -> None:
# special case: not all optimizers support optim.step() on null gradidents
if (
len(self.batches) >= 1
and self.contexts[0].index == self._start_batch
and self._stash_gradients
):
if current_batch == self._start_batch and self._stash_gradients:
return
self._optimizer.step()

Expand All @@ -860,42 +835,56 @@ def progress(self, dataloader_iter: Iterator[In]) -> Out:
self.contexts[2],
)

losses, output = self._mlp_forward(cast(In, self.batches[0]), self.contexts[0])
batch, context = self.batches[0], self.contexts[0]
is_semi_sync = context.index is not None and context.index >= self._start_batch
iteration: int = context.index or 0
losses, output = self._mlp_forward(cast(In, batch), context)

# After this point, pipelined preproc/module forward won't be called
# so we can advance their contexts to the context of the next batch already
# and also pop batch and context from self.batches and self.contexts
self.dequeue_batch()

# batch no longer needed - delete to free up memory
del batch

# cached preproc fwd results no longer needed - delete to free up memory
del context.preproc_fwd_results

# batch i+3
self.enqueue_batch(dataloader_iter)

if len(self.batches) >= 2 and self.is_semi_sync():
if len(self.batches) >= 1 and is_semi_sync:
# pyre-ignore [6]
self.start_embedding_lookup(self.batches[1], self.contexts[1])
self.start_embedding_lookup(self.batches[0], self.contexts[0])

if len(self.batches) >= 3:
self.wait_sparse_data_dist(self.contexts[2])
if len(self.batches) >= 2:
self.wait_sparse_data_dist(self.contexts[1])

if self._model.training:
with record_function(f"## backward {self.contexts[0].index} ##"):
with record_function(f"## backward {iteration} ##"):
torch.sum(losses, dim=0).backward()
# pyre-ignore [6]
self.embedding_backward(self.contexts[0])
with record_function(f"## emb_backward {iteration} ##"):
# pyre-ignore [6]
self.embedding_backward(context)

with record_function(
f"## optimizer {cast(int, self.contexts[0].index) - 1} ##"
):
if self.is_semi_sync() and self._stash_gradients:
del context # context is no longer needed, deleting to free up memory

with record_function(f"## optimizer {iteration - 1} ##"):
if is_semi_sync and self._stash_gradients:
self._grad_swap()
self._mlp_optimizer_step()
self._mlp_optimizer_step(iteration)

with record_function(
f"## zero_grad {cast(int, self.contexts[0].index) - 1} ##"
):
with record_function(f"## zero_grad {iteration - 1} ##"):
self._optimizer.zero_grad()
else:
del context

if len(self.batches) >= 2 and not self.is_semi_sync():
if len(self.batches) >= 1 and not is_semi_sync:
torch.cuda.synchronize() # needed to avoid race condition
# pyre-ignore [6]
self.start_embedding_lookup(self.batches[1], self.contexts[1])
self.start_embedding_lookup(self.batches[0], self.contexts[0])

self.dequeue_batch()
return output

def _mlp_forward(
Expand All @@ -909,14 +898,9 @@ def _mlp_forward(

def embedding_backward(self, context: EmbeddingTrainPipelineContext) -> None:
default_stream = torch.get_device_module(self._device).current_stream()
streams = (
self._embedding_even_streams
if cast(int, context.index) % 2 == 0
else self._embedding_odd_streams
)
assert len(context.embedding_features) == len(context.embedding_tensors)
for stream, emb_tensors, embedding_features, detached_emb_tensors in zip(
streams,
self._embedding_streams,
context.embedding_tensors,
context.embedding_features,
context.detached_embedding_tensors,
Expand All @@ -939,7 +923,9 @@ def embedding_backward(self, context: EmbeddingTrainPipelineContext) -> None:
embs_to_backprop.append(tensor)
grads_to_use.append(grad)
else:
if isinstance(features, Iterable):
if isinstance(features, str):
invalid_features.append(features)
elif isinstance(features, Iterable):
invalid_features.extend(features)
else:
invalid_features.append(features)
Expand Down Expand Up @@ -1012,13 +998,14 @@ def start_embedding_lookup(
batch, context, torch.get_device_module(self._device).current_stream()
)
for i, module in enumerate(self._pipelined_modules):
stream = (
self._embedding_even_streams[i]
if cast(int, context.index) % 2 == 0
else self._embedding_odd_streams[i]
)
stream = self._embedding_streams[i]
with self._stream_context(stream):
_start_embedding_lookup(module, context, stream)
_start_embedding_lookup(
module,
context,
source_stream=self._data_dist_stream,
target_stream=stream,
)
event = torch.get_device_module(self._device).Event()
event.record()
context.events.append(event)
Expand Down
Loading

0 comments on commit 575e081

Please sign in to comment.