Add custom model fwd in train pipelines (#2324)
Summary:
Pull Request resolved: #2324

Add missing pipeline_preproc and custom_model_fwd args.

Reviewed By: chrisxcai

Differential Revision: D61564467

fbshipit-source-id: 280f0e83de13e10ff2901bbda611d7ba76c8ac68
sarckk authored and facebook-github-bot committed Aug 23, 2024
1 parent 6f0ea08 commit 97585b8
Showing 2 changed files with 65 additions and 11 deletions.
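
For orientation, a rough sketch of how the two new arguments might be passed to TrainPipelineSparseDist (not part of this commit; sharded_model, optimizer, and batches are hypothetical stand-ins for a sharded TorchRec model, its optimizer, and a batch iterable built elsewhere):

    # Hypothetical usage sketch; the exact signatures are in the diff below.
    from typing import Optional, Tuple

    import torch
    from torchrec.distributed.train_pipeline import TrainPipelineSparseDist


    def my_model_fwd(batch) -> Tuple[torch.Tensor, torch.Tensor]:
        # Wrap the usual forward call to post-process predictions
        # before the pipeline returns them.
        loss, pred = sharded_model(batch)
        return loss, torch.sigmoid(pred)


    pipeline = TrainPipelineSparseDist(
        model=sharded_model,
        optimizer=optimizer,
        device=torch.device("cuda"),
        pipeline_preproc=True,  # also pipeline preproc modules, per this commit
        custom_model_fwd=my_model_fwd,  # used instead of calling model(batch) directly
    )
    output = pipeline.progress(iter(batches))

As the diff shows, the same two arguments are also threaded through the subclasses that use EmbeddingTrainPipelineContext and PrefetchTrainPipelineContext.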
48 changes: 48 additions & 0 deletions torchrec/distributed/train_pipeline/tests/test_train_pipelines.py
@@ -827,6 +827,54 @@ def test_multi_dataloader_pipelining(self) -> None:
             )
         )

+    # pyre-ignore
+    @unittest.skipIf(
+        not torch.cuda.is_available(),
+        "Not enough GPUs, this test requires at least one GPU",
+    )
+    def test_custom_fwd(
+        self,
+    ) -> None:
+        data = self._generate_data(
+            num_batches=4,
+            batch_size=32,
+        )
+        dataloader = iter(data)
+
+        fused_params_pipelined = {}
+        sharding_type = ShardingType.ROW_WISE.value
+        kernel_type = EmbeddingComputeKernel.FUSED.value
+        sharded_model_pipelined: torch.nn.Module
+
+        model = self._setup_model()
+
+        (
+            sharded_model_pipelined,
+            optim_pipelined,
+        ) = self._generate_sharded_model_and_optimizer(
+            model, sharding_type, kernel_type, fused_params_pipelined
+        )
+
+        def custom_model_fwd(
+            input: Optional[ModelInput],
+        ) -> Tuple[torch.Tensor, torch.Tensor]:
+            loss, pred = sharded_model_pipelined(input)
+            batch_size = pred.size(0)
+            return loss, pred.expand(batch_size * 2, -1)
+
+        pipeline = TrainPipelineSparseDist(
+            model=sharded_model_pipelined,
+            optimizer=optim_pipelined,
+            device=self.device,
+            execute_all_batches=True,
+            custom_model_fwd=custom_model_fwd,
+        )
+
+        for _ in data:
+            # Forward + backward w/ pipelining
+            pred_pipeline = pipeline.progress(dataloader)
+            self.assertEqual(pred_pipeline.size(0), 64)
+

 class TrainPipelinePreprocTest(TrainPipelineSparseDistTestBase):
     def setUp(self) -> None:
28 changes: 17 additions & 11 deletions torchrec/distributed/train_pipeline/train_pipelines.py
@@ -312,7 +312,7 @@ def __init__(
         context_type: Type[TrainPipelineContext] = TrainPipelineContext,
         pipeline_preproc: bool = False,
         custom_model_fwd: Optional[
-            Callable[[In], Tuple[torch.Tensor, List[torch.Tensor]]]
+            Callable[[Optional[In]], Tuple[torch.Tensor, Out]]
         ] = None,
     ) -> None:
         self._model = model
@@ -366,6 +366,10 @@ def __init__(
         self._dataloader_exhausted: bool = False
         self._context_type: Type[TrainPipelineContext] = context_type

+        self._model_fwd: Callable[[Optional[In]], Tuple[torch.Tensor, Out]] = (
+            custom_model_fwd if custom_model_fwd else model
+        )
+
         # DEPRECATED FIELDS
         self._batch_i: Optional[In] = None
         self._batch_ip1: Optional[In] = None
@@ -483,9 +487,7 @@ def progress(self, dataloader_iter: Iterator[In]) -> Out:

         # forward
         with record_function("## forward ##"):
-            losses, output = cast(
-                Tuple[torch.Tensor, Out], self._model(self.batches[0])
-            )
+            losses, output = self._model_fwd(self.batches[0])

         if len(self.batches) >= 2:
             self.wait_sparse_data_dist(self.contexts[1])
@@ -718,7 +720,7 @@ def __init__(
         stash_gradients: bool = False,
         pipeline_preproc: bool = False,
         custom_model_fwd: Optional[
-            Callable[[In], Tuple[torch.Tensor, List[torch.Tensor]]]
+            Callable[[Optional[In]], Tuple[torch.Tensor, Out]]
         ] = None,
     ) -> None:
         super().__init__(
@@ -729,6 +731,7 @@ def __init__(
             apply_jit=apply_jit,
             context_type=EmbeddingTrainPipelineContext,
             pipeline_preproc=pipeline_preproc,
+            custom_model_fwd=custom_model_fwd,
         )
         self._start_batch = start_batch
         self._stash_gradients = stash_gradients
@@ -752,9 +755,6 @@ def __init__(
         self._embedding_odd_streams: List[Optional[torch.Stream]] = []
         self._embedding_even_streams: List[Optional[torch.Stream]] = []
         self._gradients: Dict[str, torch.Tensor] = {}
-        self._model_fwd: Union[
-            torch.nn.Module, Callable[[In], Tuple[torch.Tensor, List[torch.Tensor]]]
-        ] = (custom_model_fwd if custom_model_fwd is not None else model)

     def _grad_swap(self) -> None:
         for name, param in self._model.named_parameters():
@@ -893,7 +893,7 @@ def _mlp_forward(
         _wait_for_events(
             batch, context, torch.get_device_module(self._device).current_stream()
         )
-        return cast(Tuple[torch.Tensor, Out], self._model_fwd(batch))
+        return self._model_fwd(batch)

     def embedding_backward(self, context: EmbeddingTrainPipelineContext) -> None:
         default_stream = torch.get_device_module(self._device).current_stream()
@@ -1020,6 +1020,10 @@ def __init__(
         device: torch.device,
         execute_all_batches: bool = True,
         apply_jit: bool = False,
+        pipeline_preproc: bool = False,
+        custom_model_fwd: Optional[
+            Callable[[Optional[In]], Tuple[torch.Tensor, Out]]
+        ] = None,
     ) -> None:
         super().__init__(
             model=model,
@@ -1028,6 +1032,8 @@
             execute_all_batches=execute_all_batches,
             apply_jit=apply_jit,
             context_type=PrefetchTrainPipelineContext,
+            pipeline_preproc=pipeline_preproc,
+            custom_model_fwd=custom_model_fwd,
         )
         self._context = PrefetchTrainPipelineContext(version=0)
         self._prefetch_stream: Optional[torch.Stream] = (
@@ -1084,7 +1090,7 @@ def progress(self, dataloader_iter: Iterator[In]) -> Out:
             self._wait_sparse_data_dist()
         # forward
         with record_function("## forward ##"):
-            losses, output = cast(Tuple[torch.Tensor, Out], self._model(self._batch_i))
+            losses, output = self._model_fwd(self._batch_i)

         self._prefetch(self._batch_ip1)

@@ -1527,7 +1533,7 @@ def __init__(
         context_type: Type[TrainPipelineContext] = TrainPipelineContext,
         pipeline_preproc: bool = False,
         custom_model_fwd: Optional[
-            Callable[[In], Tuple[torch.Tensor, List[torch.Tensor]]]
+            Callable[[Optional[In]], Tuple[torch.Tensor, Out]]
         ] = None,
     ) -> None:
         super().__init__(