Add methods to detach model from sparse data dist staged pipeline (#2049)

Summary:
Pull Request resolved: #2049

Sparse data dist (SDD) pipelining causes the forward of sharded trec modules to be replaced with `PipelinedForward` variants that use the pipeline context to fetch data for the current rank.

However, there are use cases where we want to perform a simple forward on the trec sharded modules without using a pipeline (e.g. for simple local debug evals during training). In such cases, it is useful to have a way to detach the model from SDD pipelining.

Add a `detach()` API to detach the model from SDD pipelining. The model is re-attached the next time `pipeline.progress()` is called.
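
As a rough usage sketch (not part of this commit's code): the variables `sharded_model_pipelined`, `dataloader`, `eval_batch`, and `device` are assumed to be set up as in the test below, and the import paths are assumed from the files touched in this diff.

from functools import partial

import torch
from torchrec.distributed.train_pipeline.train_pipelines import StagedTrainPipeline
from torchrec.distributed.train_pipeline.utils import (
    get_h2d_func,
    PipelineStage,
    SparseDataDistUtil,
)

sdd = SparseDataDistUtil(
    model=sharded_model_pipelined,
    stream=torch.cuda.Stream(),
    apply_jit=False,
)
pipeline = StagedTrainPipeline(
    pipeline_stages=[
        PipelineStage(
            name="data_copy",
            runnable=partial(get_h2d_func, device=device),
            stream=torch.cuda.Stream(),
        ),
        PipelineStage(
            name="start_sparse_data_dist",
            runnable=sdd.start_sparse_data_dist,
            stream=sdd.stream,
            fill_callback=sdd.wait_sparse_data_dist,
        ),
    ],
    compute_stream=torch.cuda.current_stream(),
)

# Normal pipelined training step.
model_in = pipeline.progress(dataloader)
loss, pred = sharded_model_pipelined(model_in)

# Detach for a simple local debug eval without SDD pipelining.
detached_model = sdd.detach()
with torch.no_grad():
    _, eval_pred = detached_model(eval_batch)

# The next progress() call re-attaches the model and resumes pipelining,
# keeping any inflight batches.
model_in = pipeline.progress(dataloader)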

Reviewed By: zzzwen

Differential Revision: D57688338

fbshipit-source-id: f40d7824162bf286737b6b172d4ab3c9b40c80dc
sarckk authored and facebook-github-bot committed Jun 5, 2024
1 parent f699979 commit da49f44
Showing 3 changed files with 230 additions and 20 deletions.
134 changes: 134 additions & 0 deletions torchrec/distributed/train_pipeline/tests/test_train_pipelines.py
@@ -983,3 +983,137 @@ def on_flush_end() -> None:

# Flush end not called this time
self.assertEqual(flush_end_called, 1)

# pyre-ignore
@unittest.skipIf(
not torch.cuda.is_available(),
"Not enough GPUs, this test requires at least one GPU",
)
def test_model_detach(self) -> None:
model = self._setup_model()

sharding_type = ShardingType.TABLE_WISE.value
fused_params = {}
kernel_type = EmbeddingComputeKernel.FUSED.value

sharded_model, optim = self._generate_sharded_model_and_optimizer(
model, sharding_type, kernel_type, fused_params
)

(
sharded_model_pipelined,
optim_pipelined,
) = self._generate_sharded_model_and_optimizer(
model, sharding_type, kernel_type
)

copy_state_dict(
sharded_model.state_dict(), sharded_model_pipelined.state_dict()
)

sdd = SparseDataDistUtil[ModelInput](
model=sharded_model_pipelined,
stream=torch.cuda.Stream(),
apply_jit=False,
)

pipeline_stages = [
PipelineStage(
name="data_copy",
runnable=partial(get_h2d_func, device=self.device),
stream=torch.cuda.Stream(),
),
PipelineStage(
name="start_sparse_data_dist",
runnable=sdd.start_sparse_data_dist,
stream=sdd.stream,
fill_callback=sdd.wait_sparse_data_dist,
),
]

pipeline = StagedTrainPipeline(
pipeline_stages=pipeline_stages,
compute_stream=torch.cuda.current_stream(),
)

data = self._generate_data(
num_batches=12,
batch_size=32,
)
dataloader = iter(data)

for i in range(5):
batch = data[i]
# Forward + backward w/o pipelining
batch = batch.to(self.device)
optim.zero_grad()
loss, pred = sharded_model(batch)
loss.backward()
optim.step()

model_in = pipeline.progress(dataloader)
optim_pipelined.zero_grad()
loss_pred, pred_pipelined = sharded_model_pipelined(model_in)
loss_pred.backward()
optim_pipelined.step()

self.assertTrue(torch.equal(pred, pred_pipelined))

# Check internal states
ebcs = [
sharded_model_pipelined.module.sparse.ebc,
sharded_model_pipelined.module.sparse.weighted_ebc,
]
for ebc in ebcs:
self.assertIsInstance(ebc.forward, PipelinedForward)
self.assertEqual(len(sharded_model_pipelined._forward_hooks.items()), 1)

detached_model = sdd.detach()

# Check internal states
for ebc in ebcs:
self.assertNotIsInstance(ebc.forward, PipelinedForward)
self.assertEqual(len(sharded_model_pipelined._forward_hooks.items()), 0)

        # Check we can run backward and optimizer on the detached model
batch = data[5].to(self.device)
loss_detached, detached_out = detached_model(batch)
loss_sharded, out = sharded_model(batch)
self.assertTrue(torch.equal(detached_out, out))
loss_detached.backward()
loss_sharded.backward()
optim.step()
optim_pipelined.step()

# Check fwd of detached model is same as non-pipelined model
with torch.no_grad():
batch = data[6].to(self.device)
_, detached_out = detached_model(batch)
_, out = sharded_model(batch)
self.assertTrue(torch.equal(detached_out, out))

# Check that pipeline re-attaches the model again without issues
for i in range(5, 12):
batch = data[i]
# Forward + backward w/o pipelining
batch = batch.to(self.device)
optim.zero_grad()
loss, pred = sharded_model(batch)
loss.backward()
optim.step()

model_in = pipeline.progress(dataloader)
optim_pipelined.zero_grad()
loss_pred, pred_pipelined = sharded_model_pipelined(model_in)
loss_pred.backward()
optim_pipelined.step()

self.assertTrue(torch.equal(pred, pred_pipelined))

for ebc in ebcs:
self.assertIsInstance(ebc.forward, PipelinedForward)
self.assertEqual(len(sharded_model_pipelined._forward_hooks.items()), 1)

# Check pipeline exhausted
preproc_input = pipeline.progress(dataloader)
self.assertIsNone(preproc_input)
2 changes: 1 addition & 1 deletion torchrec/distributed/train_pipeline/train_pipelines.py
@@ -309,7 +309,7 @@ def _init_pipelined_modules(
self.start_sparse_data_dist(batch, context)
return

self._pipelined_modules, self._model, _ = _rewrite_model(
model=self._model,
context=context,
dist_stream=self._data_dist_stream,
114 changes: 95 additions & 19 deletions torchrec/distributed/train_pipeline/utils.py
@@ -33,7 +33,7 @@
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.fx.node import Node
from torch.profiler import record_function
from torchrec.distributed.dist_data import KJTAllToAll, KJTAllToAllTensorsAwaitable
from torchrec.distributed.embedding_sharding import (
FusedKJTListSplitsAwaitable,
KJTListSplitsAwaitable,
@@ -663,14 +663,45 @@ def _jit_modules(module: torch.nn.Module, path: str, optional: bool = True) -> b
return len(sharded_children) > 0


def _pipeline_detach_model(
pipelined_modules: List[ShardedModule],
# pyre-ignore[2]
original_forwards: List[Callable[..., Any]],
original_kjt_dist_forwards: List[
Callable[[KeyedJaggedTensor], Awaitable[KJTAllToAllTensorsAwaitable]]
],
) -> None:
kjt_dists = []
for mod, original_fwd in zip(pipelined_modules, original_forwards):
# pyre-ignore
mod.forward = original_fwd

for _, child_module in mod.named_modules():
if not hasattr(child_module, "_input_dists"):
continue
for input_dist in child_module._input_dists:
if hasattr(input_dist, "_dist"):
kjt_dists.append(input_dist._dist)
assert len(kjt_dists) == len(
original_kjt_dist_forwards
), f"Number of KJT dists ({len(kjt_dists)}) does not match number of kjt dist forwards provided ({len(original_kjt_dist_forwards)})"

for kjt_dist, original_kjt_dist_fwd in zip(
kjt_dists,
original_kjt_dist_forwards,
):
kjt_dist.forward = original_kjt_dist_fwd


# pyre-ignore[3]
def _rewrite_model( # noqa C901
model: torch.nn.Module,
context: TrainPipelineContext,
dist_stream: Optional[torch.cuda.streams.Stream],
batch: Optional[In] = None,
apply_jit: bool = False,
pipelined_forward: Type[BaseForward] = PipelinedForward,
) -> Tuple[List[ShardedModule], torch.nn.Module, List[Callable[..., Any]]]:
input_model = model
# Get underlying nn.Module
if isinstance(model, DistributedModelParallel):
@@ -706,6 +737,7 @@ def _rewrite_model( # noqa C901
# Select sharded modules, which are top-level in the forward call graph,
# i.e. don't have input transformations, i.e. rely only on 'builtins.getattr'.
pipelined_forwards = []
original_forwards = []
for node in graph.nodes:
if node.op == "call_module" and node.target in sharded_modules:
total_num_args = len(node.args) + len(node.kwargs)
@@ -716,6 +748,7 @@ def _rewrite_model( # noqa C901
if num_found == total_num_args:
logger.info(f"Module '{node.target}'' will be pipelined")
child = sharded_modules[node.target]
original_forwards.append(child.forward)
child.forward = pipelined_forward(
node.target,
arg_info_list,
@@ -736,14 +769,17 @@ def _rewrite_model( # noqa C901
if isinstance(input_model, DistributedModelParallel):
input_model.module = graph_model

return pipelined_forwards, input_model, original_forwards


def _override_input_dist_forwards(
pipelined_modules: List[ShardedModule],
) -> List[Callable[[KeyedJaggedTensor], Awaitable[KJTAllToAllTensorsAwaitable]]]:
"""
Overrides each input dist forward to support fusing the splits collective.
NOTE: this can only be called after the input dists are initialized.
"""
original_kjt_dist_forwards = []
for module in pipelined_modules:
for child_fqn, child_module in module.named_modules():
if hasattr(child_module, "_has_uninitialized_input_dist"):
@@ -757,11 +793,13 @@ def _override_input_dist_forwards(pipelined_modules: List[ShardedModule]) -> Non
for input_dist in child_module._input_dists:
if hasattr(input_dist, "_dist"):
assert isinstance(input_dist._dist, KJTAllToAll)
original_kjt_dist_forwards.append(input_dist._dist.forward)
input_dist._dist.forward = KJTAllToAllForward(
pg=input_dist._dist._pg,
splits=input_dist._dist._splits,
stagger=input_dist._dist._stagger,
)
return original_kjt_dist_forwards


def get_h2d_func(batch: In, device: torch.device) -> Pipelineable:
@@ -862,31 +900,69 @@ def __init__(
self.context = TrainPipelineContext(version=0)
self.initialized = False
self._pipelined_modules: List[ShardedModule] = []
# pyre-ignore
self.fwd_hook = None

# pyre-ignore
self.original_forward = self.model.forward
self._original_forwards: List[Callable[..., Any]] = []
self._original_kjt_dist_forwards: List[
Callable[[KeyedJaggedTensor], Awaitable[KJTAllToAllTensorsAwaitable]]
] = []

def detach(self) -> torch.nn.Module:
"""
Removes sparse data dist (SDD) pipelining from model forward and input dist.
Modifies existing model in place and returns the model.
detach() can be called at any point, and inflight batches do not need to be
flushed before calling it. Calling pipeline.progress() will re-attach the model
to the pipeline and the pipeline will progress normally from the point it was detached (i.e. inflight batches will be kept when calling detach).
While the model is detached, it is equivalent to the model before passing to
the pipeline, so forward and backward passes, and optimizer updates can be
carried out normally.
"""
if self.initialized:
assert self.fwd_hook is not None
self.fwd_hook.remove()

_pipeline_detach_model(
pipelined_modules=self._pipelined_modules,
original_forwards=self._original_forwards,
original_kjt_dist_forwards=self._original_kjt_dist_forwards,
)

self.initialized = False
return self.model

def start_sparse_data_dist(self, batch: In) -> In:
if not self.initialized:
# Step 1: Pipeline input dist in trec sharded modules
self._pipelined_modules, self.model, self._original_forwards = (
_rewrite_model(
model=self.model,
context=self.context,
dist_stream=self.stream,
batch=batch,
apply_jit=self.apply_jit,
)
)
# initializes input dist, so we can override input dist forwards
_start_data_dist(self._pipelined_modules, batch, self.context)
self._original_kjt_dist_forwards = _override_input_dist_forwards(
self._pipelined_modules
)

# Step 2: Register post-forward hook to wait SDD
def forward_hook(
module: torch.nn.Module,
input: Union[torch.Tensor, Tuple[torch.Tensor]],
output: Union[torch.Tensor, Tuple[torch.Tensor]],
) -> None:
self.wait_sparse_data_dist()

self.fwd_hook = self.model.register_forward_hook(forward_hook)

self.initialized = True

_start_data_dist(self._pipelined_modules, batch, self.context)
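
The attach/detach mechanics above come down to two standard PyTorch patterns: saving a module's original `forward` before replacing it (what `_rewrite_model` and `_pipeline_detach_model` do via `_original_forwards`), and keeping the handle returned by `register_forward_hook` so the hook can be removed later (the `fwd_hook` used by `detach()`). Below is a minimal, self-contained sketch of just those two patterns, using a toy `nn.Linear` as an illustrative stand-in rather than a sharded trec module.

import torch
from torch import nn

model = nn.Linear(4, 4)  # toy stand-in for a sharded trec module

# "Attach": remember the original forward, then replace it on the instance.
original_forward = model.forward

def pipelined_forward(x: torch.Tensor) -> torch.Tensor:
    # A real PipelinedForward would pull pre-distributed inputs out of the
    # pipeline context instead of just delegating.
    return original_forward(x)

model.forward = pipelined_forward

def post_forward_hook(module: nn.Module, inputs, output) -> None:
    # Stands in for SparseDataDistUtil.wait_sparse_data_dist().
    pass

fwd_hook = model.register_forward_hook(post_forward_hook)

out = model(torch.randn(2, 4))  # runs pipelined_forward, then the hook

# "Detach": restore the original forward and remove the hook, mirroring
# _pipeline_detach_model() and SparseDataDistUtil.detach().
model.forward = original_forward
fwd_hook.remove()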
