Overlap comms on backward pass (#2117)
Summary:
Pull Request resolved: #2117

Resolves issues around CUDA streams / NCCL deadlock with autograd.

Basically, create separate streams per pipelined embedding arch (see the sketch below).

Reviewed By: sarckk

Differential Revision: D58220332

fbshipit-source-id: e203acad4a92702b94a42e2106d6de4f5d89e112
dstaay-fb authored and facebook-github-bot committed Jun 29, 2024
1 parent 6850941 commit 7e4ef94
Showing 2 changed files with 159 additions and 78 deletions.
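
To make the approach concrete before the diff, here is a minimal, self-contained sketch of the pattern this commit moves to: detach each pipelined embedding module's output during the forward pass, run the dense backward on the default stream, then replay each embedding backward on a per-module stream chosen by iteration parity so it can overlap with other work. This is not the TorchRec implementation itself; it assumes a CUDA device, and the module list, shapes, and names such as train_step are made up for illustration.

import torch

device = torch.device("cuda")
emb_modules = [torch.nn.EmbeddingBag(1000, 64).to(device) for _ in range(2)]
dense = torch.nn.Linear(128, 1).to(device)

# One stream per embedding module and per iteration parity, mirroring the
# _embedding_even_streams / _embedding_odd_streams lists added in the diff.
even_streams = [torch.cuda.Stream(device=device) for _ in emb_modules]
odd_streams = [torch.cuda.Stream(device=device) for _ in emb_modules]


def train_step(step: int, ids: torch.Tensor) -> None:
    emb_tensors, detached_tensors = [], []
    for mod in emb_modules:
        out = mod(ids)                            # embedding forward
        detached = out.detach().requires_grad_()  # cut the autograd graph here
        detached.retain_grad()
        emb_tensors.append(out)
        detached_tensors.append(detached)

    # The dense forward/backward runs on the default stream and only produces
    # gradients on the detached tensors, not on the embedding weights yet.
    loss = dense(torch.cat(detached_tensors, dim=1)).sum()
    loss.backward()

    # Replay each embedding backward on its own stream, alternating between
    # the even and odd stream sets by iteration index so consecutive
    # iterations do not serialize on the same stream.
    default_stream = torch.cuda.current_stream(device)
    streams = even_streams if step % 2 == 0 else odd_streams
    for stream, out, detached in zip(streams, emb_tensors, detached_tensors):
        with torch.cuda.stream(stream):
            stream.wait_stream(default_stream)  # detached.grad must be ready
            torch.autograd.backward([out], [detached.grad])


ids = torch.randint(0, 1000, (8, 4), device=device)
train_step(0, ids)
torch.cuda.synchronize()

The real pipeline does the same thing through _init_embedding_streams, the pipelined embedding forward in utils.py (which detaches the embedding outputs and retains their gradients), and embedding_backward in the diff below.
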
148 changes: 103 additions & 45 deletions torchrec/distributed/train_pipeline/train_pipelines.py
@@ -38,7 +38,7 @@
_start_embedding_lookup,
_to_device,
_wait_for_batch,
_wait_for_event,
_wait_for_events,
DataLoadingThread,
EmbeddingPipelinedForward,
EmbeddingTrainPipelineContext,
@@ -590,8 +590,6 @@ def start_sparse_data_dist(
return
with record_function(f"## start_sparse_data_dist {context.index} ##"):
with self._stream_context(self._data_dist_stream):
if context.event is not None:
context.event.wait()
_wait_for_batch(batch, self._memcpy_stream)

original_contexts = [p.get_context() for p in self._pipelined_preprocs]
@@ -737,11 +735,8 @@ def __init__(
if device.type in ["cuda", "mtia"]
else None
)
self._bwd_sync_stream: Optional[torch.Stream] = (
(torch.get_device_module(self._device).Stream(priority=0))
if device.type in ["cuda", "mtia"]
else None
)
self._embedding_odd_streams: List[Optional[torch.Stream]] = []
self._embedding_even_streams: List[Optional[torch.Stream]] = []
self._gradients: Dict[str, torch.Tensor] = {}

def _grad_swap(self) -> None:
@@ -751,6 +746,29 @@ def _grad_swap(self) -> None:
self._gradients[name] = param.grad.clone()
param.grad = grad

def _init_embedding_streams(self) -> None:

for _ in self._pipelined_modules:
self._embedding_odd_streams.append(
(torch.get_device_module(self._device).Stream(priority=0))
if self._device.type in ["cuda", "mtia"]
else None
)
self._embedding_even_streams.append(
(torch.get_device_module(self._device).Stream(priority=0))
if self._device.type in ["cuda", "mtia"]
else None
)

def _validate_optimizer(self) -> None:
for pipelined_module in self._pipelined_modules:
pipelined_params = set(pipelined_module.parameters())
for group in self._optimizer.param_groups:
if not set(group["params"]).isdisjoint(pipelined_params):
logger.warning(
f"SemiSync pipelined {type(pipelined_module)} and optimizer share parameters"
)

def fill_pipeline(self, dataloader_iter: Iterator[In]) -> None:
# pipeline is already filled
if len(self.batches) >= 3:
@@ -770,7 +788,9 @@ def fill_pipeline(self, dataloader_iter: Iterator[In]) -> None:
# pyre-ignore [6]
EmbeddingPipelinedForward,
)
self._init_embedding_streams()
self.wait_sparse_data_dist(self.contexts[0])
self._validate_optimizer()
# pyre-ignore [6]
self.start_embedding_lookup(self.batches[0], self.contexts[0])

@@ -824,26 +844,25 @@ def progress(self, dataloader_iter: Iterator[In]) -> Out:
self.wait_sparse_data_dist(self.contexts[2])

if self._model.training:
# backward would put an implicit sync point in the stream it is called from; ideally
# this would be different from the optimizer stream so the optimizer could start earlier, but currently it is not safe to do so.
with self._stream_context(self._overarch_stream):
with record_function(f"## backward {self.contexts[0].index} ##"):
torch.sum(losses, dim=0).backward()

with self._stream_context(self._overarch_stream):
with record_function(
f"## optimizer {cast(int, self.contexts[0].index) - 1} ##"
):
if self.is_semi_sync() and self._stash_gradients:
self._grad_swap()
self._mlp_optimizer_step()

with record_function(
f"## zero_grad {cast(int, self.contexts[0].index) - 1} ##"
):
self._optimizer.zero_grad()
with record_function(f"## backward {self.contexts[0].index} ##"):
torch.sum(losses, dim=0).backward()
# pyre-ignore [6]
self.embedding_backward(self.contexts[0])

with record_function(
f"## optimizer {cast(int, self.contexts[0].index) - 1} ##"
):
if self.is_semi_sync() and self._stash_gradients:
self._grad_swap()
self._mlp_optimizer_step()

with record_function(
f"## zero_grad {cast(int, self.contexts[0].index) - 1} ##"
):
self._optimizer.zero_grad()

if len(self.batches) >= 2 and not self.is_semi_sync():
torch.cuda.synchronize() # needed to avoid race condition
# pyre-ignore [6]
self.start_embedding_lookup(self.batches[1], self.contexts[1])

@@ -854,10 +873,29 @@ def _mlp_forward(
self, batch: In, context: TrainPipelineContext
) -> Tuple[torch.Tensor, Out]:
with record_function(f"## forward {context.index} ##"):
with self._stream_context(self._overarch_stream):
_wait_for_event(batch, self._device, context.event)
context.event = None
return cast(Tuple[torch.Tensor, Out], self._model(batch))
_wait_for_events(
batch, context, torch.get_device_module(self._device).current_stream()
)
return cast(Tuple[torch.Tensor, Out], self._model(batch))

def embedding_backward(self, context: EmbeddingTrainPipelineContext) -> None:
default_stream = torch.get_device_module(self._device).current_stream()
streams = (
self._embedding_even_streams
if cast(int, context.index) % 2 == 0
else self._embedding_odd_streams
)
for stream, emb_tensors, detached_emb_tensors in zip(
streams,
context.embedding_tensors,
context.detached_embedding_tensors,
):
with self._stream_context(stream):
grads = [tensor.grad for tensor in detached_emb_tensors]
if stream:
stream.wait_stream(default_stream)
# pyre-ignore
torch.autograd.backward(emb_tensors, grads)

def copy_batch_to_gpu(
self,
@@ -870,8 +908,9 @@ def copy_batch_to_gpu(
if batch is not None:
batch = _to_device(batch, self._device, non_blocking=True)
context = self._create_context()
context.event = torch.get_device_module(self._device).Event()
context.event.record()
event = torch.get_device_module(self._device).Event()
event.record()
context.events.append(event)
return batch, context

def start_sparse_data_dist(
@@ -882,9 +921,25 @@ def start_sparse_data_dist(
"""
Waits for the batch to finish copying to the GPU, then starts the input dist. This is the event-based version.
"""
super().start_sparse_data_dist(batch, context)
context.event = torch.get_device_module(self._device).Event()
context.event.record()
if batch is None:
return

# Temporarily set context for next iter to populate cache
original_contexts = [p.get_context() for p in self._pipelined_preprocs]
for preproc_mod in self._pipelined_preprocs:
preproc_mod.set_context(context)

with record_function(f"## start_sparse_data_dist {context.index} ##"):
with self._stream_context(self._data_dist_stream):
_wait_for_events(batch, context, self._data_dist_stream)
_start_data_dist(self._pipelined_modules, batch, context)
event = torch.get_device_module(self._device).Event()
event.record()
context.events.append(event)

# Restore context for model forward
for module, context in zip(self._pipelined_preprocs, original_contexts):
module.set_context(context)

def start_embedding_lookup(
self,
@@ -897,17 +952,20 @@ def start_embedding_lookup(
if batch is None:
return
with record_function(f"## start_embedding_lookup {context.index} ##"):
with self._stream_context(
self._embedding_even_stream
if cast(int, context.index) % 2 == 0
else self._embedding_odd_stream
):
_wait_for_event(batch, self._device, context.event)
_start_embedding_lookup(
self._pipelined_modules, batch, context, self._device
_wait_for_events(
batch, context, torch.get_device_module(self._device).current_stream()
)
for i, module in enumerate(self._pipelined_modules):
stream = (
self._embedding_even_streams[i]
if cast(int, context.index) % 2 == 0
else self._embedding_odd_streams[i]
)
context.event = torch.get_device_module(self._device).Event()
context.event.record()
with self._stream_context(stream):
_start_embedding_lookup(module, batch, context, stream)
event = torch.get_device_module(self._device).Event()
event.record()
context.events.append(event)


class PrefetchTrainPipelineSparseDist(TrainPipelineSparseDist[In, Out]):
89 changes: 56 additions & 33 deletions torchrec/distributed/train_pipeline/utils.py
@@ -43,9 +43,9 @@
)
from torchrec.distributed.model_parallel import DistributedModelParallel, ShardedModule

from torchrec.distributed.types import Awaitable
from torchrec.distributed.types import Awaitable, LazyNoWait

from torchrec.sparse.jagged_tensor import KeyedJaggedTensor
from torchrec.sparse.jagged_tensor import JaggedTensor, KeyedJaggedTensor, KeyedTensor
from torchrec.streamable import Multistreamable, Pipelineable

logger: logging.Logger = logging.getLogger(__name__)
@@ -95,10 +95,8 @@ class TrainPipelineContext:
fused_splits_awaitables: List[Tuple[List[str], FusedKJTListSplitsAwaitable]] = (
field(default_factory=list)
)
event: Optional[torch.Event] = None

events: List[torch.Event] = field(default_factory=list)
preproc_fwd_results: Dict[str, Any] = field(default_factory=dict)

index: Optional[int] = None
version: int = (
0 # 1 is current version, 0 is deprecated but supported for backward compatibility
@@ -116,6 +114,8 @@ class PrefetchTrainPipelineContext(TrainPipelineContext):
@dataclass
class EmbeddingTrainPipelineContext(TrainPipelineContext):
embedding_a2a_requests: Dict[str, Multistreamable] = field(default_factory=dict)
embedding_tensors: List[List[torch.Tensor]] = field(default_factory=list)
detached_embedding_tensors: List[List[torch.Tensor]] = field(default_factory=list)


@dataclass
@@ -369,7 +369,35 @@ def __call__(self, *input, **kwargs) -> Awaitable:
)
cur_stream = torch.get_device_module(self._device).current_stream()
ctx.record_stream(cur_stream)
return self._context.embedding_a2a_requests.pop(self._name)
awaitable = self._context.embedding_a2a_requests.pop(self._name)
embeddings = awaitable.wait() # trigger awaitable manually for type checking
tensors = []
detached_tensors = []
if isinstance(embeddings, Dict):
for jt in embeddings.values():
assert isinstance(jt, JaggedTensor)
tensor = jt.values()
detached_tensor = tensor.detach().requires_grad_()
detached_tensor.retain_grad()
jt._values = detached_tensor
tensors.append(tensor)
detached_tensors.append(detached_tensor)
# pyre-ignore [16]
self._context.embedding_tensors.append(tensors)
# pyre-ignore [16]
self._context.detached_embedding_tensors.append(detached_tensors)
else:
assert isinstance(embeddings, KeyedTensor)
tensor = embeddings.values()
detached_tensor = tensor.detach().requires_grad_()
detached_tensor.retain_grad()
embeddings._values = detached_tensor
tensors.append(tensor)
detached_tensors.append(detached_tensor)
self._context.embedding_tensors.append(tensors)
self._context.detached_embedding_tensors.append(detached_tensors)

return LazyNoWait(embeddings)


class PrefetchPipelinedForward(BaseForward):
@@ -513,22 +541,23 @@ def _wait_for_batch(batch: In, stream: Optional[torch.Stream]) -> None:
batch.record_stream(cur_stream)


def _wait_for_event(
def _wait_for_events(
batch: In,
device: torch.device,
event: Optional[torch.Event],
context: TrainPipelineContext,
stream: Optional[torch.Stream],
) -> None:
"""
Wait for event
Wait for any outstanding events for a given context
"""
if event is not None:
event.wait()
cur_stream = torch.get_device_module(device).current_stream()

assert isinstance(
batch, (torch.Tensor, Multistreamable)
), f"{type(batch)} must implement Multistreamable interface"
batch.record_stream(cur_stream)
for event in context.events:
event.wait()
context.events.clear()
if stream:
assert isinstance(
batch, (torch.Tensor, Multistreamable)
), f"{type(batch)} must implement Multistreamable interface"
batch.record_stream(stream)


def _start_data_dist(
@@ -569,25 +598,19 @@ def _start_data_dist(


def _start_embedding_lookup(
pipelined_modules: List[ShardedModule],
module: ShardedModule,
batch: In, # not used in this function
context: EmbeddingTrainPipelineContext,
device: torch.device,
stream: Optional[torch.Stream],
) -> None:
cur_stream = torch.get_device_module(device).current_stream()
kjts_per_module = []
for module in pipelined_modules:
kjts = context.input_dist_tensors_requests[module.forward.name].wait()
kjts.record_stream(cur_stream)
kjts_per_module.append(kjts)

for module, kjts in zip(pipelined_modules, kjts_per_module):
module_name = module.forward.name
module_context = context.module_contexts[module.forward.name]
module_context.record_stream(cur_stream)
a2a_awaitable = module.compute_and_output_dist(module_context, kjts)
# pyre-ignore[6]
context.embedding_a2a_requests[module_name] = a2a_awaitable
kjt = context.input_dist_tensors_requests[module.forward.name].wait()
module_context = context.module_contexts[module.forward.name]
if stream:
kjt.record_stream(stream)
module_context.record_stream(stream)
a2a_awaitable = module.compute_and_output_dist(module_context, kjt)
# pyre-ignore[6]
context.embedding_a2a_requests[module.forward.name] = a2a_awaitable


def _fuse_input_dist_splits(context: TrainPipelineContext) -> None:
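
The other half of the change is the move from a single context.event to a list of context.events consumed by _wait_for_events in utils.py above. The sketch below shows that pattern in isolation; it is illustrative rather than the library API, assumes a CUDA device, and the Context dataclass and helper names are stand-ins. Each asynchronous stage records a CUDA event on the stream it ran on and appends it to the per-iteration context; the next stage waits on every outstanding event before it starts.

import dataclasses
from typing import List

import torch


@dataclasses.dataclass
class Context:
    # Accumulates one event per completed asynchronous stage of this iteration.
    events: List[torch.cuda.Event] = dataclasses.field(default_factory=list)


def finish_stage(stream: torch.cuda.Stream, ctx: Context) -> None:
    # Mark the end of a stage by recording an event on the stream it used.
    with torch.cuda.stream(stream):
        event = torch.cuda.Event()
        event.record()
    ctx.events.append(event)


def wait_for_events(ctx: Context) -> None:
    # Queue a wait on the current stream for every outstanding event, then
    # clear the list so later stages start from a clean slate.
    for event in ctx.events:
        event.wait()
    ctx.events.clear()

Keeping a list rather than a single event lets the batch-copy, data-dist, and per-module embedding-lookup stages, which now run on different streams, each publish their own completion point for the same iteration without overwriting one another.
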
