diff --git a/torchrec/distributed/train_pipeline/train_pipelines.py b/torchrec/distributed/train_pipeline/train_pipelines.py
index b83e0ada8..63e71cda9 100644
--- a/torchrec/distributed/train_pipeline/train_pipelines.py
+++ b/torchrec/distributed/train_pipeline/train_pipelines.py
@@ -38,7 +38,7 @@
     _start_embedding_lookup,
     _to_device,
     _wait_for_batch,
-    _wait_for_event,
+    _wait_for_events,
     DataLoadingThread,
     EmbeddingPipelinedForward,
     EmbeddingTrainPipelineContext,
@@ -590,8 +590,6 @@ def start_sparse_data_dist(
             return
         with record_function(f"## start_sparse_data_dist {context.index} ##"):
             with self._stream_context(self._data_dist_stream):
-                if context.event is not None:
-                    context.event.wait()
                 _wait_for_batch(batch, self._memcpy_stream)
 
                 original_contexts = [p.get_context() for p in self._pipelined_preprocs]
@@ -737,11 +735,8 @@ def __init__(
             if device.type in ["cuda", "mtia"]
             else None
         )
-        self._bwd_sync_stream: Optional[torch.Stream] = (
-            (torch.get_device_module(self._device).Stream(priority=0))
-            if device.type in ["cuda", "mtia"]
-            else None
-        )
+        self._embedding_odd_streams: List[Optional[torch.Stream]] = []
+        self._embedding_even_streams: List[Optional[torch.Stream]] = []
         self._gradients: Dict[str, torch.Tensor] = {}
 
     def _grad_swap(self) -> None:
@@ -751,6 +746,29 @@ def _grad_swap(self) -> None:
                 self._gradients[name] = param.grad.clone()
             param.grad = grad
 
+    def _init_embedding_streams(self) -> None:
+
+        for _ in self._pipelined_modules:
+            self._embedding_odd_streams.append(
+                (torch.get_device_module(self._device).Stream(priority=0))
+                if self._device.type in ["cuda", "mtia"]
+                else None
+            )
+            self._embedding_even_streams.append(
+                (torch.get_device_module(self._device).Stream(priority=0))
+                if self._device.type in ["cuda", "mtia"]
+                else None
+            )
+
+    def _validate_optimizer(self) -> None:
+        for pipelined_module in self._pipelined_modules:
+            pipelined_params = set(pipelined_module.parameters())
+            for group in self._optimizer.param_groups:
+                if not set(group["params"]).isdisjoint(pipelined_params):
+                    logger.warning(
+                        f"SemiSync pipelined {type(pipelined_module)} and optimizer share parameters"
+                    )
+
     def fill_pipeline(self, dataloader_iter: Iterator[In]) -> None:
         # pipeline is already filled
         if len(self.batches) >= 3:
@@ -770,7 +788,9 @@ def fill_pipeline(self, dataloader_iter: Iterator[In]) -> None:
             # pyre-ignore [6]
             EmbeddingPipelinedForward,
         )
+        self._init_embedding_streams()
         self.wait_sparse_data_dist(self.contexts[0])
+        self._validate_optimizer()
         # pyre-ignore [6]
         self.start_embedding_lookup(self.batches[0], self.contexts[0])
 
@@ -824,26 +844,25 @@ def progress(self, dataloader_iter: Iterator[In]) -> Out:
             self.wait_sparse_data_dist(self.contexts[2])
 
         if self._model.training:
-            # backward would put an implicit sync point in stream called from, ideally
-            # this would different from optimizer so it could start earilier, but currently not safe to do so.
-            with self._stream_context(self._overarch_stream):
-                with record_function(f"## backward {self.contexts[0].index} ##"):
-                    torch.sum(losses, dim=0).backward()
-
-            with self._stream_context(self._overarch_stream):
-                with record_function(
-                    f"## optimizer {cast(int, self.contexts[0].index) - 1} ##"
-                ):
-                    if self.is_semi_sync() and self._stash_gradients:
-                        self._grad_swap()
-                    self._mlp_optimizer_step()
-
-                with record_function(
-                    f"## zero_grad {cast(int, self.contexts[0].index) - 1} ##"
-                ):
-                    self._optimizer.zero_grad()
+            with record_function(f"## backward {self.contexts[0].index} ##"):
+                torch.sum(losses, dim=0).backward()
+                # pyre-ignore [6]
+                self.embedding_backward(self.contexts[0])
+
+            with record_function(
+                f"## optimizer {cast(int, self.contexts[0].index) - 1} ##"
+            ):
+                if self.is_semi_sync() and self._stash_gradients:
+                    self._grad_swap()
+                self._mlp_optimizer_step()
+
+            with record_function(
+                f"## zero_grad {cast(int, self.contexts[0].index) - 1} ##"
+            ):
+                self._optimizer.zero_grad()
 
         if len(self.batches) >= 2 and not self.is_semi_sync():
+            torch.cuda.synchronize()  # needed to avoid race condition
             # pyre-ignore [6]
             self.start_embedding_lookup(self.batches[1], self.contexts[1])
 
@@ -854,10 +873,29 @@ def _mlp_forward(
         self, batch: In, context: TrainPipelineContext
     ) -> Tuple[torch.Tensor, Out]:
         with record_function(f"## forward {context.index} ##"):
-            with self._stream_context(self._overarch_stream):
-                _wait_for_event(batch, self._device, context.event)
-                context.event = None
-                return cast(Tuple[torch.Tensor, Out], self._model(batch))
+            _wait_for_events(
+                batch, context, torch.get_device_module(self._device).current_stream()
+            )
+            return cast(Tuple[torch.Tensor, Out], self._model(batch))
+
+    def embedding_backward(self, context: EmbeddingTrainPipelineContext) -> None:
+        default_stream = torch.get_device_module(self._device).current_stream()
+        streams = (
+            self._embedding_even_streams
+            if cast(int, context.index) % 2 == 0
+            else self._embedding_odd_streams
+        )
+        for stream, emb_tensors, detached_emb_tensors in zip(
+            streams,
+            context.embedding_tensors,
+            context.detached_embedding_tensors,
+        ):
+            with self._stream_context(stream):
+                grads = [tensor.grad for tensor in detached_emb_tensors]
+                if stream:
+                    stream.wait_stream(default_stream)
+                # pyre-ignore
+                torch.autograd.backward(emb_tensors, grads)
 
     def copy_batch_to_gpu(
         self,
@@ -870,8 +908,9 @@
             if batch is not None:
                 batch = _to_device(batch, self._device, non_blocking=True)
             context = self._create_context()
-            context.event = torch.get_device_module(self._device).Event()
-            context.event.record()
+            event = torch.get_device_module(self._device).Event()
+            event.record()
+            context.events.append(event)
         return batch, context
 
     def start_sparse_data_dist(
         self,
@@ -882,9 +921,25 @@
         """
        Waits for batch to finish getting copied to GPU, then starts the input dist. This is Event based version.
         """
-        super().start_sparse_data_dist(batch, context)
-        context.event = torch.get_device_module(self._device).Event()
-        context.event.record()
+        if batch is None:
+            return
+
+        # Temporarily set context for next iter to populate cache
+        original_contexts = [p.get_context() for p in self._pipelined_preprocs]
+        for preproc_mod in self._pipelined_preprocs:
+            preproc_mod.set_context(context)
+
+        with record_function(f"## start_sparse_data_dist {context.index} ##"):
+            with self._stream_context(self._data_dist_stream):
+                _wait_for_events(batch, context, self._data_dist_stream)
+                _start_data_dist(self._pipelined_modules, batch, context)
+                event = torch.get_device_module(self._device).Event()
+                event.record()
+                context.events.append(event)
+
+        # Restore context for model forward
+        for module, context in zip(self._pipelined_preprocs, original_contexts):
+            module.set_context(context)
 
     def start_embedding_lookup(
         self,
@@ -897,17 +952,20 @@
         if batch is None:
             return
         with record_function(f"## start_embedding_lookup {context.index} ##"):
-            with self._stream_context(
-                self._embedding_even_stream
-                if cast(int, context.index) % 2 == 0
-                else self._embedding_odd_stream
-            ):
-                _wait_for_event(batch, self._device, context.event)
-                _start_embedding_lookup(
-                    self._pipelined_modules, batch, context, self._device
+            _wait_for_events(
+                batch, context, torch.get_device_module(self._device).current_stream()
+            )
+            for i, module in enumerate(self._pipelined_modules):
+                stream = (
+                    self._embedding_even_streams[i]
+                    if cast(int, context.index) % 2 == 0
+                    else self._embedding_odd_streams[i]
                 )
-                context.event = torch.get_device_module(self._device).Event()
-                context.event.record()
+                with self._stream_context(stream):
+                    _start_embedding_lookup(module, batch, context, stream)
+                    event = torch.get_device_module(self._device).Event()
+                    event.record()
+                    context.events.append(event)
 
 
 class PrefetchTrainPipelineSparseDist(TrainPipelineSparseDist[In, Out]):
diff --git a/torchrec/distributed/train_pipeline/utils.py b/torchrec/distributed/train_pipeline/utils.py
index e0e823689..2972112cc 100644
--- a/torchrec/distributed/train_pipeline/utils.py
+++ b/torchrec/distributed/train_pipeline/utils.py
@@ -43,9 +43,9 @@
 )
 
 from torchrec.distributed.model_parallel import DistributedModelParallel, ShardedModule
-from torchrec.distributed.types import Awaitable
+from torchrec.distributed.types import Awaitable, LazyNoWait
 
-from torchrec.sparse.jagged_tensor import KeyedJaggedTensor
+from torchrec.sparse.jagged_tensor import JaggedTensor, KeyedJaggedTensor, KeyedTensor
 from torchrec.streamable import Multistreamable, Pipelineable
 
 logger: logging.Logger = logging.getLogger(__name__)
@@ -95,10 +95,8 @@ class TrainPipelineContext:
     fused_splits_awaitables: List[Tuple[List[str], FusedKJTListSplitsAwaitable]] = (
         field(default_factory=list)
     )
-    event: Optional[torch.Event] = None
-
+    events: List[torch.Event] = field(default_factory=list)
     preproc_fwd_results: Dict[str, Any] = field(default_factory=dict)
-
     index: Optional[int] = None
     version: int = (
         0  # 1 is current version, 0 is deprecated but supported for backward compatibility
@@ -116,6 +114,8 @@ class PrefetchTrainPipelineContext(TrainPipelineContext):
 @dataclass
 class EmbeddingTrainPipelineContext(TrainPipelineContext):
     embedding_a2a_requests: Dict[str, Multistreamable] = field(default_factory=dict)
+    embedding_tensors: List[List[torch.Tensor]] = field(default_factory=list)
+    detached_embedding_tensors: List[List[torch.Tensor]] = field(default_factory=list)
 
 
 @dataclass
@@ -369,7 +369,35 @@ def __call__(self, *input, **kwargs) -> Awaitable:
         )
         cur_stream = torch.get_device_module(self._device).current_stream()
         ctx.record_stream(cur_stream)
-        return self._context.embedding_a2a_requests.pop(self._name)
+        awaitable = self._context.embedding_a2a_requests.pop(self._name)
+        embeddings = awaitable.wait()  # trigger awaitable manually for type checking
+        tensors = []
+        detached_tensors = []
+        if isinstance(embeddings, Dict):
+            for jt in embeddings.values():
+                assert isinstance(jt, JaggedTensor)
+                tensor = jt.values()
+                detached_tensor = tensor.detach().requires_grad_()
+                detached_tensor.retain_grad()
+                jt._values = detached_tensor
+                tensors.append(tensor)
+                detached_tensors.append(detached_tensor)
+            # pyre-ignore [16]
+            self._context.embedding_tensors.append(tensors)
+            # pyre-ignore [16]
+            self._context.detached_embedding_tensors.append(detached_tensors)
+        else:
+            assert isinstance(embeddings, KeyedTensor)
+            tensor = embeddings.values()
+            detached_tensor = tensor.detach().requires_grad_()
+            detached_tensor.retain_grad()
+            embeddings._values = detached_tensor
+            tensors.append(tensor)
+            detached_tensors.append(detached_tensor)
+            self._context.embedding_tensors.append(tensors)
+            self._context.detached_embedding_tensors.append(detached_tensors)
+
+        return LazyNoWait(embeddings)
 
 
 class PrefetchPipelinedForward(BaseForward):
@@ -513,22 +541,23 @@ def _wait_for_batch(batch: In, stream: Optional[torch.Stream]) -> None:
     batch.record_stream(cur_stream)
 
 
-def _wait_for_event(
+def _wait_for_events(
     batch: In,
-    device: torch.device,
-    event: Optional[torch.Event],
+    context: TrainPipelineContext,
+    stream: Optional[torch.Stream],
 ) -> None:
     """
-    Wait for event
+    Wait for any outstanding events for a given context
     """
-    if event is not None:
-        event.wait()
-    cur_stream = torch.get_device_module(device).current_stream()
-    assert isinstance(
-        batch, (torch.Tensor, Multistreamable)
-    ), f"{type(batch)} must implement Multistreamable interface"
-    batch.record_stream(cur_stream)
+    for event in context.events:
+        event.wait()
+    context.events.clear()
+    if stream:
+        assert isinstance(
+            batch, (torch.Tensor, Multistreamable)
+        ), f"{type(batch)} must implement Multistreamable interface"
+        batch.record_stream(stream)
 
 
 def _start_data_dist(
@@ -569,25 +598,19 @@ def _start_data_dist(
 
 
 def _start_embedding_lookup(
-    pipelined_modules: List[ShardedModule],
+    module: ShardedModule,
     batch: In,  # not used in this function
     context: EmbeddingTrainPipelineContext,
-    device: torch.device,
+    stream: Optional[torch.Stream],
 ) -> None:
-    cur_stream = torch.get_device_module(device).current_stream()
-    kjts_per_module = []
-    for module in pipelined_modules:
-        kjts = context.input_dist_tensors_requests[module.forward.name].wait()
-        kjts.record_stream(cur_stream)
-        kjts_per_module.append(kjts)
-
-    for module, kjts in zip(pipelined_modules, kjts_per_module):
-        module_name = module.forward.name
-        module_context = context.module_contexts[module.forward.name]
-        module_context.record_stream(cur_stream)
-        a2a_awaitable = module.compute_and_output_dist(module_context, kjts)
-        # pyre-ignore[6]
-        context.embedding_a2a_requests[module_name] = a2a_awaitable
+    kjt = context.input_dist_tensors_requests[module.forward.name].wait()
+    module_context = context.module_contexts[module.forward.name]
+    if stream:
+        kjt.record_stream(stream)
+        module_context.record_stream(stream)
+    a2a_awaitable = module.compute_and_output_dist(module_context, kjt)
+    # pyre-ignore[6]
+    context.embedding_a2a_requests[module.forward.name] = a2a_awaitable
 
 
 def _fuse_input_dist_splits(context: TrainPipelineContext) -> None:
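Note on the decoupled embedding backward (reviewer sketch, not part of the patch): EmbeddingPipelinedForward now hands the dense arch a detached copy of each embedding tensor, so torch.sum(losses).backward() stops at the splice point, and embedding_backward later replays the saved gradients into the embedding graph via torch.autograd.backward(emb_tensors, grads) on a separate stream. Below is a minimal, single-stream sketch of that detach-and-replay pattern; emb_weight, dense, and ids are illustrative stand-ins, not torchrec APIs.

    import torch

    # Illustrative stand-ins for the embedding table and the dense "overarch".
    emb_weight = torch.randn(8, 4, requires_grad=True)
    dense = torch.nn.Linear(4, 1)
    ids = torch.tensor([1, 3, 5])

    # Embedding "lookup" (runs on an embedding stream in the real pipeline).
    emb_out = emb_weight[ids]

    # Splice point, mirroring EmbeddingPipelinedForward: the dense graph only
    # ever sees a detached leaf that requires grad.
    detached = emb_out.detach().requires_grad_()
    detached.retain_grad()

    # Dense forward + backward (default stream); autograd stops at `detached`.
    loss = dense(detached).sum()
    loss.backward()
    assert emb_weight.grad is None and detached.grad is not None

    # Decoupled embedding backward, mirroring embedding_backward(): replay the
    # saved grad into the embedding graph with torch.autograd.backward(tensors, grads).
    torch.autograd.backward([emb_out], [detached.grad])
    assert emb_weight.grad is not None

In the patch, this replay is done per pipelined module on alternating _embedding_even_streams / _embedding_odd_streams, with stream.wait_stream(default_stream) ordering each replay after the dense backward that produced the grads, which lets the embedding backward of one iteration overlap work issued on the default stream.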