diff --git a/torchrec/distributed/train_pipeline/train_pipelines.py b/torchrec/distributed/train_pipeline/train_pipelines.py index 2886b23eb..50cbc1ffe 100644 --- a/torchrec/distributed/train_pipeline/train_pipelines.py +++ b/torchrec/distributed/train_pipeline/train_pipelines.py @@ -376,7 +376,6 @@ def __init__( self._dataloader_iter: Optional[Iterator[In]] = None self._dataloader_exhausted: bool = False self._context_type: Type[TrainPipelineContext] = context_type - self._model_fwd: Callable[[Optional[In]], Tuple[torch.Tensor, Out]] = ( custom_model_fwd if custom_model_fwd else model )