From 92b903f98cff499edc60cfffa429691cb6ca39a3 Mon Sep 17 00:00:00 2001
From: Yong Hoon Shin
Date: Tue, 17 Dec 2024 19:05:58 -0800
Subject: [PATCH] Make streams device-agnostic (#2644)

Summary:
Pull Request resolved: https://github.com/pytorch/torchrec/pull/2644

https://github.com/pytorch/torchrec/pull/2598 (D64220706) causes failures on accelerators that do not support CUDA. This change makes the stream contexts hardware-agnostic.

Reviewed By: hpnhxxwn, iamzainhuda

Differential Revision: D67363141

fbshipit-source-id: fc2c6fec1dcbbe15f0385e299b666207d2d9a8f5
---
 .../train_pipeline/train_pipelines.py        |  1 +
 torchrec/distributed/train_pipeline/utils.py | 38 ++++++++++++++++---
 2 files changed, 33 insertions(+), 6 deletions(-)

diff --git a/torchrec/distributed/train_pipeline/train_pipelines.py b/torchrec/distributed/train_pipeline/train_pipelines.py
index d42a2e9ac..e747a6283 100644
--- a/torchrec/distributed/train_pipeline/train_pipelines.py
+++ b/torchrec/distributed/train_pipeline/train_pipelines.py
@@ -1002,6 +1002,7 @@ def start_embedding_lookup(
                 context,
                 source_stream=self._data_dist_stream,
                 target_stream=stream,
+                stream_context=self._stream_context,
             )
             event = torch.get_device_module(self._device).Event()
             event.record()
diff --git a/torchrec/distributed/train_pipeline/utils.py b/torchrec/distributed/train_pipeline/utils.py
index efc772a90..25aa9fe96 100644
--- a/torchrec/distributed/train_pipeline/utils.py
+++ b/torchrec/distributed/train_pipeline/utils.py
@@ -11,6 +11,7 @@
 import itertools
 import logging
 from collections import defaultdict, OrderedDict
+from contextlib import AbstractContextManager
 from dataclasses import dataclass, field
 
 from itertools import chain
@@ -248,6 +249,21 @@ def recursive_record_stream(
             recursive_record_stream(v, stream)
 
 
+class NoOpStream:
+    """No-op context manager that takes in a stream"""
+
+    def __init__(self, stream: Optional[torch.Stream]) -> None:
+        self._stream = stream
+
+    def __enter__(self) -> "NoOpStream":
+        """Return `self` upon entering the runtime context."""
+        return self
+
+    # pyre-ignore
+    def __exit__(self, exc_type, exc_value, traceback) -> None:
+        return None
+
+
 class PipelinedPreproc(torch.nn.Module):
     """
     Wrapper around preproc module found during model graph traversal for sparse data dist
@@ -297,6 +313,17 @@ def __init__(
                 f"Preproc module {fqn} has no dist stream. This may cause race conditions and NaNs during training!"
             )
 
+        if self._dist_stream:
+            device: torch.device = self._dist_stream.device
+            # pyre-ignore
+            self._stream_context = (
+                torch.get_device_module(device).stream
+                if device.type in ["cuda", "mtia"]
+                else torch.cuda.stream
+            )
+        else:
+            self._stream_context = NoOpStream
+
     @property
     def preproc_module(self) -> torch.nn.Module:
         return self._preproc_module
@@ -341,8 +368,7 @@ def forward(self, *input, **kwargs) -> Any:
 
         with record_function(f"## sdd_input_preproc {self._context.index} ##"):
             # should be no-op as we call this in dist stream
-            # pyre-ignore[6]: torch.cuda.Stream is a wrapper around torch.Stream
-            with torch.cuda.stream(self._dist_stream):
+            with self._stream_context(self._dist_stream):
                 res = self._preproc_module(*args, **kwargs)
 
             # Ensure preproc modules output is safe to use from default stream later
@@ -364,8 +390,7 @@ def forward(self, *input, **kwargs) -> Any:
                     f"Result of preproc module {self._fqn} is of type {type(res)}. We currently expect it to be a Tensor, Pipelineable, Iterable, or Dict to handle memory safety. If your output is not of this type, please add support for it above. Otherwise you might run into NaNs or CUDA Illegal Memory issues during training!"
                 )
 
-            # pyre-ignore[6]: torch.cuda.Stream is a wrapper around torch.Stream
-            with torch.cuda.stream(self._default_stream):
+            with self._stream_context(self._default_stream):
                 # Cache results, only during _start_data_dist
                 self._context.preproc_fwd_results[self._fqn] = res
 
@@ -760,10 +785,11 @@ def _start_embedding_lookup(
     context: EmbeddingTrainPipelineContext,
     source_stream: Optional[torch.Stream],
     target_stream: Optional[torch.Stream],
+    # pyre-ignore[2]
+    stream_context: Callable[..., AbstractContextManager[Any, Any]],
 ) -> None:
     module_context = context.module_contexts[module.forward.name]
-    # pyre-ignore[6]: torch.cuda.Stream is a wrapper around torch.Stream
-    with torch.cuda.stream(source_stream):
+    with stream_context(source_stream):
         kjt = context.input_dist_tensors_requests[module.forward.name].wait()
 
     if target_stream is not None:
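For reference, the selection logic this patch spreads across PipelinedPreproc.__init__ and the call sites is easiest to see in isolation. Below is a minimal, self-contained sketch of the same pattern, assuming PyTorch 2.x (torch.Stream, torch.get_device_module); pick_stream_context and the usage snippet at the end are hypothetical names for illustration, not part of the torchrec API.

    from typing import Any, Callable, ContextManager, Optional

    import torch


    class NoOpStream:
        """No-op context manager for devices without stream support."""

        def __init__(self, stream: Optional[torch.Stream]) -> None:
            self._stream = stream

        def __enter__(self) -> "NoOpStream":
            return self

        def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
            return None


    def pick_stream_context(
        stream: Optional[torch.Stream],
    ) -> Callable[[Optional[torch.Stream]], ContextManager[Any]]:
        """Hypothetical helper mirroring the patch: resolve the stream-context
        factory once, then reuse it for every `with ...:` block."""
        if stream is None:
            # Nothing to switch into (e.g. pipelining disabled): a no-op
            # context keeps call sites unconditional.
            return NoOpStream
        if stream.device.type in ["cuda", "mtia"]:
            # torch.get_device_module(device) resolves to torch.cuda or
            # torch.mtia; their `stream(...)` context manager makes the
            # given stream current inside the block.
            return torch.get_device_module(stream.device).stream
        # Fallback the patch uses for any other device type.
        return torch.cuda.stream

Call sites then stay device-agnostic, regardless of which factory was picked:

    stream_context = pick_stream_context(dist_stream)
    with stream_context(dist_stream):
        out = module(batch)  # runs with dist_stream current, if one exists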