diff --git a/torchrec/distributed/train_pipeline/tests/test_train_pipelines.py b/torchrec/distributed/train_pipeline/tests/test_train_pipelines.py index 6c7878189..ad46d9d69 100644 --- a/torchrec/distributed/train_pipeline/tests/test_train_pipelines.py +++ b/torchrec/distributed/train_pipeline/tests/test_train_pipelines.py @@ -827,6 +827,54 @@ def test_multi_dataloader_pipelining(self) -> None: ) ) + # pyre-ignore + @unittest.skipIf( + not torch.cuda.is_available(), + "Not enough GPUs, this test requires at least one GPU", + ) + def test_custom_fwd( + self, + ) -> None: + data = self._generate_data( + num_batches=4, + batch_size=32, + ) + dataloader = iter(data) + + fused_params_pipelined = {} + sharding_type = ShardingType.ROW_WISE.value + kernel_type = EmbeddingComputeKernel.FUSED.value + sharded_model_pipelined: torch.nn.Module + + model = self._setup_model() + + ( + sharded_model_pipelined, + optim_pipelined, + ) = self._generate_sharded_model_and_optimizer( + model, sharding_type, kernel_type, fused_params_pipelined + ) + + def custom_model_fwd( + input: Optional[ModelInput], + ) -> Tuple[torch.Tensor, torch.Tensor]: + loss, pred = sharded_model_pipelined(input) + batch_size = pred.size(0) + return loss, pred.expand(batch_size * 2, -1) + + pipeline = TrainPipelineSparseDist( + model=sharded_model_pipelined, + optimizer=optim_pipelined, + device=self.device, + execute_all_batches=True, + custom_model_fwd=custom_model_fwd, + ) + + for _ in data: + # Forward + backward w/ pipelining + pred_pipeline = pipeline.progress(dataloader) + self.assertEqual(pred_pipeline.size(0), 64) + class TrainPipelinePreprocTest(TrainPipelineSparseDistTestBase): def setUp(self) -> None: diff --git a/torchrec/distributed/train_pipeline/train_pipelines.py b/torchrec/distributed/train_pipeline/train_pipelines.py index ae0aecd69..bab055c94 100644 --- a/torchrec/distributed/train_pipeline/train_pipelines.py +++ b/torchrec/distributed/train_pipeline/train_pipelines.py @@ -312,7 +312,7 @@ def __init__( context_type: Type[TrainPipelineContext] = TrainPipelineContext, pipeline_preproc: bool = False, custom_model_fwd: Optional[ - Callable[[In], Tuple[torch.Tensor, List[torch.Tensor]]] + Callable[[Optional[In]], Tuple[torch.Tensor, Out]] ] = None, ) -> None: self._model = model @@ -366,6 +366,10 @@ def __init__( self._dataloader_exhausted: bool = False self._context_type: Type[TrainPipelineContext] = context_type + self._model_fwd: Callable[[Optional[In]], Tuple[torch.Tensor, Out]] = ( + custom_model_fwd if custom_model_fwd else model + ) + # DEPRECATED FIELDS self._batch_i: Optional[In] = None self._batch_ip1: Optional[In] = None @@ -483,9 +487,7 @@ def progress(self, dataloader_iter: Iterator[In]) -> Out: # forward with record_function("## forward ##"): - losses, output = cast( - Tuple[torch.Tensor, Out], self._model(self.batches[0]) - ) + losses, output = self._model_fwd(self.batches[0]) if len(self.batches) >= 2: self.wait_sparse_data_dist(self.contexts[1]) @@ -718,7 +720,7 @@ def __init__( stash_gradients: bool = False, pipeline_preproc: bool = False, custom_model_fwd: Optional[ - Callable[[In], Tuple[torch.Tensor, List[torch.Tensor]]] + Callable[[Optional[In]], Tuple[torch.Tensor, Out]] ] = None, ) -> None: super().__init__( @@ -729,6 +731,7 @@ def __init__( apply_jit=apply_jit, context_type=EmbeddingTrainPipelineContext, pipeline_preproc=pipeline_preproc, + custom_model_fwd=custom_model_fwd, ) self._start_batch = start_batch self._stash_gradients = stash_gradients @@ -752,9 +755,6 @@ def __init__( self._embedding_odd_streams: List[Optional[torch.Stream]] = [] self._embedding_even_streams: List[Optional[torch.Stream]] = [] self._gradients: Dict[str, torch.Tensor] = {} - self._model_fwd: Union[ - torch.nn.Module, Callable[[In], Tuple[torch.Tensor, List[torch.Tensor]]] - ] = (custom_model_fwd if custom_model_fwd is not None else model) def _grad_swap(self) -> None: for name, param in self._model.named_parameters(): @@ -893,7 +893,7 @@ def _mlp_forward( _wait_for_events( batch, context, torch.get_device_module(self._device).current_stream() ) - return cast(Tuple[torch.Tensor, Out], self._model_fwd(batch)) + return self._model_fwd(batch) def embedding_backward(self, context: EmbeddingTrainPipelineContext) -> None: default_stream = torch.get_device_module(self._device).current_stream() @@ -1020,6 +1020,10 @@ def __init__( device: torch.device, execute_all_batches: bool = True, apply_jit: bool = False, + pipeline_preproc: bool = False, + custom_model_fwd: Optional[ + Callable[[Optional[In]], Tuple[torch.Tensor, Out]] + ] = None, ) -> None: super().__init__( model=model, @@ -1028,6 +1032,8 @@ def __init__( execute_all_batches=execute_all_batches, apply_jit=apply_jit, context_type=PrefetchTrainPipelineContext, + pipeline_preproc=pipeline_preproc, + custom_model_fwd=custom_model_fwd, ) self._context = PrefetchTrainPipelineContext(version=0) self._prefetch_stream: Optional[torch.Stream] = ( @@ -1084,7 +1090,7 @@ def progress(self, dataloader_iter: Iterator[In]) -> Out: self._wait_sparse_data_dist() # forward with record_function("## forward ##"): - losses, output = cast(Tuple[torch.Tensor, Out], self._model(self._batch_i)) + losses, output = self._model_fwd(self._batch_i) self._prefetch(self._batch_ip1) @@ -1527,7 +1533,7 @@ def __init__( context_type: Type[TrainPipelineContext] = TrainPipelineContext, pipeline_preproc: bool = False, custom_model_fwd: Optional[ - Callable[[In], Tuple[torch.Tensor, List[torch.Tensor]]] + Callable[[Optional[In]], Tuple[torch.Tensor, Out]] ] = None, ) -> None: super().__init__(