Add CompiledAutograd pipeline (pytorch#2310)
Summary:
Pull Request resolved: pytorch#2310

Add new pipeline for CompiledAutograd development.
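
As a rough usage sketch of the new pipeline (illustrative only: sharded_model, dense_optimizer, and dataloader below are placeholders, not part of this change):

import torch
from torchrec.distributed.train_pipeline import TrainPipelineSparseDistCompAutograd

# sharded_model, dense_optimizer and dataloader are assumed to exist; constructing
# them (DistributedModelParallel, optimizer, data loading) is out of scope here.
pipeline = TrainPipelineSparseDistCompAutograd(
    model=sharded_model,
    optimizer=dense_optimizer,
    device=torch.device("cuda:0"),
)

dataloader_iter = iter(dataloader)
while True:
    try:
        out = pipeline.progress(dataloader_iter)
    except StopIteration:
        break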

Reviewed By: dstaay-fb, xmfan, yf225

Differential Revision: D61403499

fbshipit-source-id: 7bf0720e0c1078815315278fffd79c2d7470882f
flaviotruzzi authored and facebook-github-bot committed Aug 22, 2024
1 parent 9418355 commit f7e444d
Showing 5 changed files with 128 additions and 11 deletions.
3 changes: 3 additions & 0 deletions torchrec/distributed/embedding_lookup.py
@@ -355,6 +355,9 @@ def purge(self) -> None:


class CommOpGradientScaling(torch.autograd.Function):
# user override: this autograd.Function is safe to trace since it only performs tensor mutations and touches no global state
_compiled_autograd_should_lift = False

@staticmethod
# pyre-ignore
def forward(
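
For context, _compiled_autograd_should_lift is a class-level flag read by compiled autograd: when set to False, the Function's backward is traced inline rather than lifted into an opaque call. A minimal standalone sketch of the same pattern (ScaleByTwo is illustrative, not part of this diff):

import torch

class ScaleByTwo(torch.autograd.Function):
    # Ask compiled autograd to trace this Function's backward inline
    # rather than lifting it into an opaque call (same flag as above).
    _compiled_autograd_should_lift = False

    @staticmethod
    def forward(ctx, x: torch.Tensor) -> torch.Tensor:
        return x * 2

    @staticmethod
    def backward(ctx, grad_out: torch.Tensor) -> torch.Tensor:
        return grad_out * 2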
1 change: 1 addition & 0 deletions torchrec/distributed/train_pipeline/__init__.py
@@ -16,6 +16,7 @@
TrainPipelineBase, # noqa
TrainPipelinePT2, # noqa
TrainPipelineSparseDist, # noqa
TrainPipelineSparseDistCompAutograd, # noqa
)
from torchrec.distributed.train_pipeline.utils import ( # noqa
_override_input_dist_forwards, # noqa
43 changes: 33 additions & 10 deletions torchrec/distributed/train_pipeline/tests/test_train_pipelines.py
@@ -19,6 +19,7 @@
from hypothesis import given, settings, strategies as st, Verbosity
from torch import nn, optim
from torch._dynamo.testing import reduce_to_scalar_loss
from torch._dynamo.utils import counters
from torchrec.distributed import DistributedModelParallel
from torchrec.distributed.embedding_types import EmbeddingComputeKernel
from torchrec.distributed.embeddingbag import EmbeddingBagCollectionSharder
@@ -53,6 +54,7 @@
TrainPipelinePT2,
TrainPipelineSemiSync,
TrainPipelineSparseDist,
TrainPipelineSparseDistCompAutograd,
)
from torchrec.distributed.train_pipeline.utils import (
DataLoadingThread,
@@ -393,7 +395,7 @@ def forward(self, model_input) -> Tuple[torch.Tensor, torch.Tensor]:
sharded_sparse_arch_pipeline.parameters(), lr=0.1
)

pipeline = TrainPipelineSparseDist(
pipeline = self.pipeline_class(
sharded_sparse_arch_pipeline,
optimizer_pipeline,
self.device,
@@ -441,7 +443,7 @@ def _setup_pipeline(
dict(in_backward_optimizer_filter(distributed_model.named_parameters())),
lambda params: optim.SGD(params, lr=0.1),
)
return TrainPipelineSparseDist(
return self.pipeline_class(
model=distributed_model,
optimizer=optimizer_distributed,
device=self.device,
@@ -508,7 +510,7 @@ def test_equal_to_non_pipelined(
sharded_model.state_dict(), sharded_model_pipelined.state_dict()
)

pipeline = TrainPipelineSparseDist(
pipeline = self.pipeline_class(
model=sharded_model_pipelined,
optimizer=optim_pipelined,
device=self.device,
@@ -621,7 +623,7 @@ def test_model_detach_during_train(self) -> None:
sharded_model.state_dict(), sharded_model_pipelined.state_dict()
)

pipeline = TrainPipelineSparseDist(
pipeline = self.pipeline_class(
model=sharded_model_pipelined,
optimizer=optim_pipelined,
device=self.device,
@@ -719,7 +721,7 @@ def test_model_detach_after_train(self) -> None:
sharded_model.state_dict(), sharded_model_pipelined.state_dict()
)

pipeline = TrainPipelineSparseDist(
pipeline = self.pipeline_class(
model=sharded_model_pipelined,
optimizer=optim_pipelined,
device=self.device,
@@ -862,7 +864,7 @@ def _check_output_equal(
sharded_model.state_dict(), sharded_model_pipelined.state_dict()
)

pipeline = TrainPipelineSparseDist(
pipeline = self.pipeline_class(
model=sharded_model_pipelined,
optimizer=optim_pipelined,
device=self.device,
@@ -1116,7 +1118,7 @@ def test_pipeline_invalid_preproc_inputs_has_trainable_params(self) -> None:
model, self.sharding_type, self.kernel_type, self.fused_params
)

pipeline = TrainPipelineSparseDist(
pipeline = self.pipeline_class(
model=sharded_model_pipelined,
optimizer=optim_pipelined,
device=self.device,
@@ -1171,7 +1173,7 @@ def test_pipeline_invalid_preproc_trainable_params_recursive(
model, self.sharding_type, self.kernel_type, self.fused_params
)

pipeline = TrainPipelineSparseDist(
pipeline = self.pipeline_class(
model=sharded_model_pipelined,
optimizer=optim_pipelined,
device=self.device,
@@ -1217,7 +1219,7 @@ def test_pipeline_invalid_preproc_inputs_modify_kjt_recursive(self) -> None:
model, self.sharding_type, self.kernel_type, self.fused_params
)

pipeline = TrainPipelineSparseDist(
pipeline = self.pipeline_class(
model=sharded_model_pipelined,
optimizer=optim_pipelined,
device=self.device,
@@ -1280,7 +1282,7 @@ def test_pipeline_preproc_fwd_values_cached(self) -> None:
model, self.sharding_type, self.kernel_type, self.fused_params
)

pipeline = TrainPipelineSparseDist(
pipeline = self.pipeline_class(
model=sharded_model_pipelined,
optimizer=optim_pipelined,
device=self.device,
@@ -2100,3 +2102,24 @@ def gpu_preproc(x: StageOut) -> StageOut:
self.assertEqual(len(pipelined_out), len(non_pipelined_outputs))
for out, ref_out in zip(pipelined_out, non_pipelined_outputs):
torch.testing.assert_close(out, ref_out)


class TrainPipelineSparseDistCompAutogradTest(TrainPipelineSparseDistTest):
def setUp(self) -> None:
super().setUp()
self.pipeline_class = TrainPipelineSparseDistCompAutograd
torch._dynamo.reset()
counters["compiled_autograd"].clear()
# Compiled Autograd does not work with Anomaly Mode
torch.autograd.set_detect_anomaly(False)

def tearDown(self) -> None:
# Every single test has two captures, one for forward and one for backward
self.assertEqual(counters["compiled_autograd"]["captures"], 2)
return super().tearDown()

@unittest.skip("Dynamo only supports FSDP with use_orig_params=True")
# pyre-ignore[56]
@given(execute_all_batches=st.booleans())
def test_pipelining_fsdp_pre_trace(self, execute_all_batches: bool) -> None:
super().test_pipelining_fsdp_pre_trace()
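
For reference, counters["compiled_autograd"]["captures"] tracks how many graphs compiled autograd has captured between resets; a minimal standalone illustration of reading it (not part of this diff; the toy model and eager backend are assumptions):

import torch
import torch._dynamo.compiled_autograd
from torch._dynamo.utils import counters

torch._dynamo.reset()
counters["compiled_autograd"].clear()

model = torch.nn.Linear(4, 4)
x = torch.randn(2, 4)

# run a single backward pass with compiled autograd enabled
with torch._dynamo.compiled_autograd.enable(torch.compile(backend="eager")):
    model(x).sum().backward()

# expected to report one capture for the backward graph above
print(counters["compiled_autograd"]["captures"])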
@@ -23,6 +23,7 @@
TestEBCSharder,
TestSparseNN,
)
from torchrec.distributed.train_pipeline.train_pipelines import TrainPipelineSparseDist
from torchrec.distributed.types import ModuleSharder, ShardingEnv
from torchrec.modules.embedding_configs import DataType, EmbeddingBagConfig
from torchrec.test_utils import get_free_port, init_distributed_single_host
@@ -59,6 +60,7 @@ def setUp(self) -> None:
]

self.device = torch.device("cuda:0")
self.pipeline_class = TrainPipelineSparseDist

def tearDown(self) -> None:
super().tearDown()
90 changes: 89 additions & 1 deletion torchrec/distributed/train_pipeline/train_pipelines.py
@@ -8,13 +8,16 @@
# pyre-strict

import abc
import contextlib
import logging
from collections import deque
from contextlib import contextmanager
from dataclasses import dataclass
from typing import (
Any,
Callable,
cast,
ContextManager,
Deque,
Dict,
Generic,
@@ -27,6 +30,7 @@
)

import torch
import torchrec.distributed.comm_ops
from torch.autograd.profiler import record_function
from torchrec.distributed.dist_data import KJTAllToAllTensorsAwaitable
from torchrec.distributed.model_parallel import ShardedModule
@@ -59,7 +63,6 @@
from torchrec.pt2.checks import is_torchdynamo_compiling
from torchrec.pt2.utils import default_pipeline_input_transformer
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor
from torchrec.streamable import Multistreamable

logger: logging.Logger = logging.getLogger(__name__)

@@ -1506,3 +1509,88 @@ def progress(
return self.progress(dataloader_iter)

return out


class TrainPipelineSparseDistCompAutograd(TrainPipelineSparseDist[In, Out]):
"""
This pipeline clones TrainPipelineSparseDist, but executes the progress
method within a compiled autograd context.
"""

def __init__(
self,
model: torch.nn.Module,
optimizer: torch.optim.Optimizer,
device: torch.device,
execute_all_batches: bool = True,
apply_jit: bool = False,
context_type: Type[TrainPipelineContext] = TrainPipelineContext,
pipeline_preproc: bool = False,
custom_model_fwd: Optional[
Callable[[In], Tuple[torch.Tensor, List[torch.Tensor]]]
] = None,
) -> None:
super().__init__(
model,
optimizer,
device,
execute_all_batches,
apply_jit,
context_type,
pipeline_preproc,
custom_model_fwd,
)

# check this attribute on the model to allow injecting a configuration
# other than the default one.
self.compiled_autograd_options: Dict[str, Union[str, bool]] = getattr(
model,
"_compiled_autograd_options",
{
"backend": "inductor",
"dynamic": True,
"fullgraph": True,
},
)

torch._dynamo.config.optimize_ddp = "python_reducer"
torch._dynamo.config.inline_inbuilt_nn_modules = True
torch._dynamo.config.skip_fsdp_hooks = False
torch._functorch.config.recompute_views = True
torch._functorch.config.cse = False
torch._inductor.config.reorder_for_compute_comm_overlap = True
torch._inductor.config.reorder_for_compute_comm_overlap_passes = [
"sink_waits",
"raise_comms",
"reorder_compute_for_overlap",
]
self.initialized = False

def get_compiled_autograd_ctx(
self,
) -> ContextManager:
# on the first call, return a null context: this allows pipelining to
# avoid doing a sum on None when the pipeline is empty
if not self.initialized:
self.initialized = True
return contextlib.nullcontext()

return torch._dynamo.compiled_autograd.enable(
# pyre-ignore
torch.compile(**self.compiled_autograd_options)
)

@contextmanager
def sync_collectives_ctx(self) -> Iterator[None]:
try:
if is_torchdynamo_compiling():
torchrec.distributed.comm_ops.set_use_sync_collectives(True)
yield
finally:
torchrec.distributed.comm_ops.set_use_sync_collectives(False)

def progress(self, dataloader_iter: Iterator[In]) -> Out:

with self.get_compiled_autograd_ctx(), self.sync_collectives_ctx():
return super().progress(dataloader_iter)
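
Since the constructor above reads an optional _compiled_autograd_options attribute off the model via getattr, a model can override the default torch.compile settings; a hedged sketch of such an override (MyModel and its option values are illustrative, not part of this diff):

import torch

class MyModel(torch.nn.Module):
    # Picked up by TrainPipelineSparseDistCompAutograd.__init__ via getattr;
    # the specific values here are just an example, not a recommendation.
    _compiled_autograd_options = {
        "backend": "inductor",
        "dynamic": False,
        "fullgraph": False,
    }

    def __init__(self) -> None:
        super().__init__()
        self.linear = torch.nn.Linear(8, 1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.linear(x)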
