From 54047a5c20afd754fd332783dec410a3389c8cb0 Mon Sep 17 00:00:00 2001
From: Yong Hoon Shin <yhshin@meta.com>
Date: Mon, 24 Jun 2024 10:58:45 -0700
Subject: [PATCH] Add ability to specify pipelineable preproc modules to ignore
 during SDD model rewrite (#2149)

Summary:
Pull Request resolved: https://github.com/pytorch/torchrec/pull/2149

Make TorchRec automatically pipeline any preproc modules without trainable params during sparse data dist pipelining.

tl;dr: with some traversal logic changes, the TorchRec sparse data dist pipeline can support arbitrary input transformations at the input dist stage, as long as they are composed of either nn.Module calls or currently supported ops (mainly getattr and getitem).
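
For reference, a minimal usage sketch of the new flag (hedged: `sharded_model`, `optim`, `device`, and `dataloader` below are placeholder names; it assumes the model's preproc module has no trainable params and its inputs are derived from the batch via supported ops):

    from torchrec.distributed.train_pipeline.train_pipelines import (
        TrainPipelineSparseDist,
    )

    # Assumed to already exist: sharded_model, optim, device, dataloader.
    pipeline = TrainPipelineSparseDist(
        model=sharded_model,
        optimizer=optim,
        device=device,
        execute_all_batches=True,
        pipeline_preproc=True,  # opt in; default is False
    )

    # Eligible preproc modules are swapped with PipelinedPreproc during the model
    # rewrite; they run at the sparse data dist stage and their outputs are cached
    # in the per-batch TrainPipelineContext for reuse during model forward.
    dataloader_iter = iter(dataloader)
    pred = pipeline.progress(dataloader_iter)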

Differential Revision: D57944338
---
 torchrec/distributed/test_utils/test_model.py | 270 +++++++++-
 .../tests/test_train_pipelines.py             | 490 ++++++++++++++++++
 .../tests/test_train_pipelines_base.py        |   2 +
 .../tests/test_train_pipelines_utils.py       |  88 +++-
 .../train_pipeline/train_pipelines.py         |  42 +-
 torchrec/distributed/train_pipeline/utils.py  | 304 ++++++++++-
 6 files changed, 1153 insertions(+), 43 deletions(-)

diff --git a/torchrec/distributed/test_utils/test_model.py b/torchrec/distributed/test_utils/test_model.py
index ff0f7a740..e4ec76ee3 100644
--- a/torchrec/distributed/test_utils/test_model.py
+++ b/torchrec/distributed/test_utils/test_model.py
@@ -7,6 +7,7 @@
 
 # pyre-strict
 
+import copy
 import random
 from dataclasses import dataclass
 from typing import Any, cast, Dict, List, Optional, Tuple, Type, Union
@@ -239,10 +240,16 @@ def _validate_pooling_factor(
             else None
         )
 
-        global_float = torch.rand(
-            (batch_size * world_size, num_float_features), device=device
-        )
-        global_label = torch.rand(batch_size * world_size, device=device)
+        if randomize_indices:
+            global_float = torch.rand(
+                (batch_size * world_size, num_float_features), device=device
+            )
+            global_label = torch.rand(batch_size * world_size, device=device)
+        else:
+            global_float = torch.zeros(
+                (batch_size * world_size, num_float_features), device=device
+            )
+            global_label = torch.zeros(batch_size * world_size, device=device)
 
         # Split global batch into local batches.
         local_inputs = []
@@ -939,6 +946,7 @@ def __init__(
         max_feature_lengths_list: Optional[List[Dict[str, int]]] = None,
         feature_processor_modules: Optional[Dict[str, torch.nn.Module]] = None,
         over_arch_clazz: Type[nn.Module] = TestOverArch,
+        preproc_module: Optional[nn.Module] = None,
     ) -> None:
         super().__init__(
             tables=cast(List[BaseEmbeddingConfig], tables),
@@ -960,6 +968,14 @@ def __init__(
         embedding_names = (
             list(embedding_groups.values())[0] if embedding_groups else None
         )
+        self._embedding_names: List[str] = (
+            embedding_names
+            if embedding_names
+            else [feature for table in tables for feature in table.feature_names]
+        )
+        self._weighted_features: List[str] = [
+            feature for table in weighted_tables for feature in table.feature_names
+        ]
         self.over: nn.Module = over_arch_clazz(
             tables, weighted_tables, embedding_names, dense_device
         )
@@ -967,6 +983,7 @@ def __init__(
             "dummy_ones",
             torch.ones(1, device=dense_device),
         )
+        self.preproc_module = preproc_module
 
     def sparse_forward(self, input: ModelInput) -> KeyedTensor:
         return self.sparse(
@@ -993,6 +1010,8 @@ def forward(
         self,
         input: ModelInput,
     ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
+        if self.preproc_module:
+            input = self.preproc_module(input)
         return self.dense_forward(input, self.sparse_forward(input))
 
 
@@ -1409,3 +1428,246 @@ def _post_ebc_test_wrap_function(kt: KeyedTensor) -> KeyedTensor:
         continue
 
     return kt
+
+
+class TestPreprocNonWeighted(nn.Module):
+    """
+    Basic test preproc module that selects a subset of non-weighted features from a KJT
+
+    Args: None
+    Examples:
+        >>> TestPreprocNonWeighted()
+    Returns:
+        List[KeyedJaggedTensor]
+    """
+
+    def forward(self, kjt: KeyedJaggedTensor) -> List[KeyedJaggedTensor]:
+        """
+        Selects 3 features from a specific KJT
+        """
+        # split
+        jt_0 = kjt["feature_0"]
+        jt_1 = kjt["feature_1"]
+        jt_2 = kjt["feature_2"]
+
+        # merge only features 0,1,2, removing feature 3
+        return [
+            KeyedJaggedTensor.from_jt_dict(
+                {
+                    "feature_0": jt_0,
+                    "feature_1": jt_1,
+                    "feature_2": jt_2,
+                }
+            )
+        ]
+
+
+class TestPreprocWeighted(nn.Module):
+    """
+    Basic test preproc module that selects a single weighted feature from a KJT
+
+    Args: None
+    Examples:
+        >>> TestPreprocWeighted()
+    Returns:
+        List[KeyedJaggedTensor]
+    """
+
+    def forward(self, kjt: KeyedJaggedTensor) -> List[KeyedJaggedTensor]:
+        """
+        Selects 1 feature from a specific weighted KJT
+        """
+
+        # split
+        jt_0 = kjt["weighted_feature_0"]
+
+        # keep only weighted_feature_0
+        return [
+            KeyedJaggedTensor.from_jt_dict(
+                {
+                    "weighted_feature_0": jt_0,
+                }
+            )
+        ]
+
+
+class TestModelWithPreproc(nn.Module):
+    """
+    Basic module with up to 3 preproc modules:
+    - preproc on idlist_features for non-weighted EBC
+    - preproc on idscore_features for weighted EBC
+    - optional preproc on model input shared by both EBCs
+
+    Args:
+        tables: configs for the non-weighted EBC
+        weighted_tables: configs for the weighted EBC
+        device: device for the dense modules
+        preproc_module: optional preproc module applied to the whole model input
+        num_float_features: number of dense float features
+        run_preproc_inline: if True, run an inline (non-module) preproc transformation in forward
+
+    Example:
+        >>> TestModelWithPreproc(tables, weighted_tables, device)
+
+    Returns:
+        Tuple[torch.Tensor, torch.Tensor]
+    """
+
+    def __init__(
+        self,
+        tables: List[EmbeddingBagConfig],
+        weighted_tables: List[EmbeddingBagConfig],
+        device: torch.device,
+        preproc_module: Optional[nn.Module] = None,
+        num_float_features: int = 10,
+        run_preproc_inline: bool = False,
+    ) -> None:
+        super().__init__()
+        self.dense = TestDenseArch(num_float_features, device)
+
+        self.ebc: EmbeddingBagCollection = EmbeddingBagCollection(
+            tables=tables,
+            device=device,
+        )
+        self.weighted_ebc = EmbeddingBagCollection(
+            tables=weighted_tables,
+            is_weighted=True,
+            device=device,
+        )
+        self.preproc_nonweighted = TestPreprocNonWeighted()
+        self.preproc_weighted = TestPreprocWeighted()
+        self._preproc_module = preproc_module
+        self._run_preproc_inline = run_preproc_inline
+
+    def forward(
+        self,
+        input: ModelInput,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """
+        Runs preproc for EBC and weighted EBC, optionally runs preproc on the input
+
+        Args:
+            input
+        Returns:
+            Tuple[torch.Tensor, torch.Tensor]
+        """
+        modified_input = input
+
+        if self._preproc_module is not None:
+            modified_input = self._preproc_module(modified_input)
+        elif self._run_preproc_inline:
+            modified_input.idlist_features = KeyedJaggedTensor.from_lengths_sync(
+                modified_input.idlist_features.keys(),
+                modified_input.idlist_features.values(),
+                modified_input.idlist_features.lengths(),
+            )
+
+        modified_idlist_features = self.preproc_nonweighted(
+            modified_input.idlist_features
+        )
+        modified_idscore_features = self.preproc_weighted(
+            modified_input.idscore_features
+        )
+        ebc_out = self.ebc(modified_idlist_features[0])
+        weighted_ebc_out = self.weighted_ebc(modified_idscore_features[0])
+
+        pred = torch.cat([ebc_out.values(), weighted_ebc_out.values()], dim=1)
+        return pred.sum(), pred
+
+
+class TestNegSamplingModule(torch.nn.Module):
+    """
+    Basic module to simulate feature augmentation preproc (e.g. neg sampling) for testing
+
+    Args:
+        extra_input: extra ModelInput whose features are appended to each incoming batch
+        has_params: if True, the module holds trainable params (so it cannot be pipelined)
+
+    Example:
+        >>> preproc = TestNegSamplingModule(extra_input)
+        >>> out = preproc(in)
+
+    Returns:
+        ModelInput
+    """
+
+    def __init__(
+        self,
+        extra_input: ModelInput,
+        has_params: bool = False,
+    ) -> None:
+        super().__init__()
+        self._extra_input = extra_input
+        if has_params:
+            self._linear: nn.Module = nn.Linear(30, 30)
+
+    def forward(self, input: ModelInput) -> ModelInput:
+        """
+        Appends extra features to model input
+
+        Args:
+            input
+        Returns:
+            ModelInput
+        """
+
+        # merge extra input
+        modified_input = copy.deepcopy(input)
+
+        # dim=0 (batch dimension) increases by self._extra_input.float_features.shape[0]
+        modified_input.float_features = torch.concat(
+            (modified_input.float_features, self._extra_input.float_features), dim=0
+        )
+
+        # stride will be the same but features will be joined
+        modified_input.idlist_features = KeyedJaggedTensor.concat(
+            [modified_input.idlist_features, self._extra_input.idlist_features]
+        )
+        if self._extra_input.idscore_features is not None:
+            # stride will be the same but features will be joined
+            modified_input.idscore_features = KeyedJaggedTensor.concat(
+                # pyre-ignore
+                [modified_input.idscore_features, self._extra_input.idscore_features]
+            )
+
+        # dim=0 (batch dimension) increases by self._extra_input.label.shape[0]
+        modified_input.label = torch.concat(
+            (modified_input.label, self._extra_input.label), dim=0
+        )
+
+        return modified_input
+
+
+class TestPositionWeightedPreprocModule(torch.nn.Module):
+    """
+    Basic test preproc module that wraps PositionWeightedProcessor (has trainable params)
+
+    Args: None
+    Example:
+        >>> preproc = TestPositionWeightedPreprocModule(max_feature_lengths, device)
+        >>> out = preproc(in)
+    Returns:
+        ModelInput
+    """
+
+    def __init__(
+        self, max_feature_lengths: Dict[str, int], device: torch.device
+    ) -> None:
+        super().__init__()
+        self.fp_proc = PositionWeightedProcessor(
+            max_feature_lengths=max_feature_lengths,
+            device=device,
+        )
+
+    def forward(self, input: ModelInput) -> ModelInput:
+        """
+        Runs PositionWeightedProcessor
+
+        Args:
+            input
+        Returns:
+            ModelInput
+        """
+        modified_input = copy.deepcopy(input)
+        modified_input.idlist_features = self.fp_proc(modified_input.idlist_features)
+        return modified_input
diff --git a/torchrec/distributed/train_pipeline/tests/test_train_pipelines.py b/torchrec/distributed/train_pipeline/tests/test_train_pipelines.py
index 8aef77d76..f60672833 100644
--- a/torchrec/distributed/train_pipeline/tests/test_train_pipelines.py
+++ b/torchrec/distributed/train_pipeline/tests/test_train_pipelines.py
@@ -8,6 +8,7 @@
 # pyre-strict
 
 import copy
+
 import unittest
 from dataclasses import dataclass
 from functools import partial
@@ -32,6 +33,9 @@
 from torchrec.distributed.test_utils.test_model import (
     ModelInput,
     TestEBCSharder,
+    TestModelWithPreproc,
+    TestNegSamplingModule,
+    TestPositionWeightedPreprocModule,
     TestSparseNN,
 )
 from torchrec.distributed.test_utils.test_sharding import copy_state_dict
@@ -54,6 +58,7 @@
     DataLoadingThread,
     get_h2d_func,
     PipelinedForward,
+    PipelinedPreproc,
     PipelineStage,
     SparseDataDistUtil,
     StageOut,
@@ -821,6 +826,491 @@ def test_multi_dataloader_pipelining(self) -> None:
         )
 
 
+class TrainPipelinePreprocTest(TrainPipelineSparseDistTestBase):
+    def setUp(self) -> None:
+        super().setUp()
+        self.num_batches = 10
+        self.batch_size = 32
+        self.sharding_type = ShardingType.TABLE_WISE.value
+        self.kernel_type = EmbeddingComputeKernel.FUSED.value
+        self.fused_params = {}
+
+    def _check_output_equal(
+        self,
+        model: torch.nn.Module,
+        sharding_type: str,
+        max_feature_lengths: Optional[List[int]] = None,
+    ) -> Tuple[nn.Module, TrainPipelineSparseDist[ModelInput, torch.Tensor]]:
+        data = self._generate_data(
+            num_batches=self.num_batches,
+            batch_size=self.batch_size,
+            max_feature_lengths=max_feature_lengths,
+        )
+        dataloader = iter(data)
+
+        sharded_model, optim = self._generate_sharded_model_and_optimizer(
+            model, sharding_type, self.kernel_type, self.fused_params
+        )
+
+        (
+            sharded_model_pipelined,
+            optim_pipelined,
+        ) = self._generate_sharded_model_and_optimizer(
+            model, sharding_type, self.kernel_type, self.fused_params
+        )
+        copy_state_dict(
+            sharded_model.state_dict(), sharded_model_pipelined.state_dict()
+        )
+
+        pipeline = TrainPipelineSparseDist(
+            model=sharded_model_pipelined,
+            optimizer=optim_pipelined,
+            device=self.device,
+            execute_all_batches=True,
+            pipeline_preproc=True,
+        )
+
+        for i in range(self.num_batches):
+            batch = data[i]
+            # Forward + backward w/o pipelining
+            batch = batch.to(self.device)
+            optim.zero_grad()
+            loss, pred = sharded_model(batch)
+            loss.backward()
+            optim.step()
+
+            pred_pipelined = pipeline.progress(dataloader)
+            self.assertTrue(torch.equal(pred, pred_pipelined))
+
+        return sharded_model_pipelined, pipeline
+
+    # pyre-ignore
+    @unittest.skipIf(
+        not torch.cuda.is_available(),
+        "Not enough GPUs, this test requires at least one GPU",
+    )
+    def test_pipeline_modules_share_preproc(self) -> None:
+        """
+        Setup: the preproc module takes in an input batch and returns a modified
+        input batch. The EBC and weighted EBC inside the model's sparse arch
+        subsequently use this modified KJT.
+
+        Test case where a single preproc module is shared by multiple sharded modules
+        and the output of the preproc module needs to be transformed in the SAME way.
+        """
+
+        extra_input = ModelInput.generate(
+            tables=self.tables,
+            weighted_tables=self.weighted_tables,
+            batch_size=self.batch_size,
+            world_size=1,
+            num_float_features=10,
+            randomize_indices=False,
+        )[0].to(self.device)
+
+        preproc_module = TestNegSamplingModule(
+            extra_input=extra_input,
+        )
+        model = self._setup_model(preproc_module=preproc_module)
+
+        pipelined_model, pipeline = self._check_output_equal(
+            model,
+            self.sharding_type,
+        )
+
+        # Check that both EBC and weighted EBC are pipelined
+        self.assertEqual(len(pipeline._pipelined_modules), 2)
+        self.assertEqual(len(pipeline._pipelined_preprocs), 1)
+
+    # pyre-ignore
+    @unittest.skipIf(
+        not torch.cuda.is_available(),
+        "Not enough GPUs, this test requires at least one GPU",
+    )
+    def test_pipeline_preproc_not_shared_with_arg_transform(self) -> None:
+        """
+        Test case where the arguments to the preproc module are a non-modifying
+        transformation of the input batch (no nested preproc modules) AND the
+        arguments to multiple sharded modules are derived from the outputs of
+        different preproc modules (i.e. the preproc modules are not shared).
+        """
+        model = TestModelWithPreproc(
+            tables=self.tables[:-1],  # ignore last table as preproc will remove
+            weighted_tables=self.weighted_tables[:-1],  # ignore last table
+            device=self.device,
+        )
+
+        pipelined_model, pipeline = self._check_output_equal(
+            model,
+            self.sharding_type,
+        )
+
+        # Check that both EBC and weighted EBC pipelined
+        self.assertEqual(len(pipeline._pipelined_modules), 2)
+
+        pipelined_ebc = pipeline._pipelined_modules[0]
+        pipelined_weighted_ebc = pipeline._pipelined_modules[1]
+
+        # Check pipelined args
+        for ebc in [pipelined_ebc, pipelined_weighted_ebc]:
+            self.assertEqual(len(ebc.forward._args), 1)
+            self.assertEqual(ebc.forward._args[0].input_attrs, ["", 0])
+            self.assertEqual(ebc.forward._args[0].is_getitems, [False, True])
+            self.assertEqual(len(ebc.forward._args[0].preproc_modules), 2)
+            self.assertIsInstance(
+                ebc.forward._args[0].preproc_modules[0], PipelinedPreproc
+            )
+            self.assertEqual(ebc.forward._args[0].preproc_modules[1], None)
+
+        self.assertEqual(
+            pipelined_ebc.forward._args[0].preproc_modules[0],
+            pipelined_model.module.preproc_nonweighted,
+        )
+        self.assertEqual(
+            pipelined_weighted_ebc.forward._args[0].preproc_modules[0],
+            pipelined_model.module.preproc_weighted,
+        )
+
+        # preproc args
+        self.assertEqual(len(pipeline._pipelined_preprocs), 2)
+        for i, input_attr_name in [(0, "idlist_features"), (1, "idscore_features")]:
+            preproc_mod = pipeline._pipelined_preprocs[i]
+            self.assertEqual(len(preproc_mod._args), 1)
+            self.assertEqual(preproc_mod._args[0].input_attrs, ["", input_attr_name])
+            self.assertEqual(preproc_mod._args[0].is_getitems, [False, False])
+            # no parent preproc module in FX graph
+            self.assertEqual(preproc_mod._args[0].preproc_modules, [None, None])
+
+    # pyre-ignore
+    @unittest.skipIf(
+        not torch.cuda.is_available(),
+        "Not enough GPUs, this test requires at least one GPU",
+    )
+    def test_pipeline_preproc_recursive(self) -> None:
+        """
+        Test recursive case where multiple arguments to the preproc module are derived
+        from the output of another preproc module. For example,
+
+        out_a, out_b, out_c = preproc_1(input)
+        out_d = preproc_2(out_a, out_b)
+        # do something with out_c
+        out = ebc(out_d)
+        """
+        extra_input = ModelInput.generate(
+            tables=self.tables[:-1],
+            weighted_tables=self.weighted_tables[:-1],
+            batch_size=self.batch_size,
+            world_size=1,
+            num_float_features=10,
+            randomize_indices=False,
+        )[0].to(self.device)
+
+        preproc_module = TestNegSamplingModule(
+            extra_input=extra_input,
+        )
+
+        model = TestModelWithPreproc(
+            tables=self.tables[:-1],
+            weighted_tables=self.weighted_tables[:-1],
+            device=self.device,
+            preproc_module=preproc_module,
+        )
+
+        pipelined_model, pipeline = self._check_output_equal(model, self.sharding_type)
+
+        # Check that both EBC and weighted EBC pipelined
+        self.assertEqual(len(pipeline._pipelined_modules), 2)
+
+        pipelined_ebc = pipeline._pipelined_modules[0]
+        pipelined_weighted_ebc = pipeline._pipelined_modules[1]
+
+        # Check pipelined args
+        for ebc in [pipelined_ebc, pipelined_weighted_ebc]:
+            self.assertEqual(len(ebc.forward._args), 1)
+            self.assertEqual(ebc.forward._args[0].input_attrs, ["", 0])
+            self.assertEqual(ebc.forward._args[0].is_getitems, [False, True])
+            self.assertEqual(len(ebc.forward._args[0].preproc_modules), 2)
+            self.assertIsInstance(
+                ebc.forward._args[0].preproc_modules[0], PipelinedPreproc
+            )
+            self.assertEqual(ebc.forward._args[0].preproc_modules[1], None)
+
+        self.assertEqual(
+            pipelined_ebc.forward._args[0].preproc_modules[0],
+            pipelined_model.module.preproc_nonweighted,
+        )
+        self.assertEqual(
+            pipelined_weighted_ebc.forward._args[0].preproc_modules[0],
+            pipelined_model.module.preproc_weighted,
+        )
+
+        # preproc args
+        self.assertEqual(len(pipeline._pipelined_preprocs), 3)
+
+        parent_preproc_mod = pipelined_model.module._preproc_module
+
+        for preproc_mod in pipeline._pipelined_preprocs:
+            if preproc_mod == pipelined_model.module.preproc_nonweighted:
+                self.assertEqual(len(preproc_mod._args), 1)
+                args = preproc_mod._args[0]
+                self.assertEqual(args.input_attrs, ["", "idlist_features"])
+                self.assertEqual(args.is_getitems, [False, False])
+                self.assertEqual(len(args.preproc_modules), 2)
+                self.assertEqual(
+                    args.preproc_modules[0],
+                    parent_preproc_mod,
+                )
+                self.assertEqual(args.preproc_modules[1], None)
+            elif preproc_mod == pipelined_model.module.preproc_weighted:
+                self.assertEqual(len(preproc_mod._args), 1)
+                args = preproc_mod._args[0]
+                self.assertEqual(args.input_attrs, ["", "idscore_features"])
+                self.assertEqual(args.is_getitems, [False, False])
+                self.assertEqual(len(args.preproc_modules), 2)
+                self.assertEqual(
+                    args.preproc_modules[0],
+                    parent_preproc_mod,
+                )
+                self.assertEqual(args.preproc_modules[1], None)
+            elif preproc_mod == parent_preproc_mod:
+                self.assertEqual(len(preproc_mod._args), 1)
+                args = preproc_mod._args[0]
+                self.assertEqual(args.input_attrs, [""])
+                self.assertEqual(args.is_getitems, [False])
+                self.assertEqual(args.preproc_modules, [None])
+
+    # pyre-ignore
+    @unittest.skipIf(
+        not torch.cuda.is_available(),
+        "Not enough GPUs, this test requires at least one GPU",
+    )
+    def test_pipeline_invalid_preproc_inputs_has_trainable_params(self) -> None:
+        """
+        Test case where a preproc module sits in front of a sharded module but cannot be
+        safely pipelined, as it contains trainable params in its child modules.
+        """
+        max_feature_lengths = {
+            "feature_0": 10,
+            "feature_1": 10,
+            "feature_2": 10,
+            "feature_3": 10,
+        }
+
+        preproc_module = TestPositionWeightedPreprocModule(
+            max_feature_lengths=max_feature_lengths,
+            device=self.device,
+        )
+
+        model = self._setup_model(preproc_module=preproc_module)
+
+        (
+            sharded_model_pipelined,
+            optim_pipelined,
+        ) = self._generate_sharded_model_and_optimizer(
+            model, self.sharding_type, self.kernel_type, self.fused_params
+        )
+
+        pipeline = TrainPipelineSparseDist(
+            model=sharded_model_pipelined,
+            optimizer=optim_pipelined,
+            device=self.device,
+            execute_all_batches=True,
+            pipeline_preproc=True,
+        )
+
+        data = self._generate_data(
+            num_batches=self.num_batches,
+            batch_size=self.batch_size,
+            max_feature_lengths=list(max_feature_lengths.values()),
+        )
+        dataloader = iter(data)
+
+        pipeline.progress(dataloader)
+
+        # Check that no modules are pipelined
+        self.assertEqual(len(pipeline._pipelined_modules), 0)
+        self.assertEqual(len(pipeline._pipelined_preprocs), 0)
+
+    # pyre-ignore
+    @unittest.skipIf(
+        not torch.cuda.is_available(),
+        "Not enough GPUs, this test requires at least one GPU",
+    )
+    def test_pipeline_invalid_preproc_trainable_params_recursive(
+        self,
+    ) -> None:
+        max_feature_lengths = {
+            "feature_0": 10,
+            "feature_1": 10,
+            "feature_2": 10,
+            "feature_3": 10,
+        }
+
+        preproc_module = TestPositionWeightedPreprocModule(
+            max_feature_lengths=max_feature_lengths,
+            device=self.device,
+        )
+
+        model = TestModelWithPreproc(
+            tables=self.tables[:-1],
+            weighted_tables=self.weighted_tables[:-1],
+            device=self.device,
+            preproc_module=preproc_module,
+        )
+
+        (
+            sharded_model_pipelined,
+            optim_pipelined,
+        ) = self._generate_sharded_model_and_optimizer(
+            model, self.sharding_type, self.kernel_type, self.fused_params
+        )
+
+        pipeline = TrainPipelineSparseDist(
+            model=sharded_model_pipelined,
+            optimizer=optim_pipelined,
+            device=self.device,
+            execute_all_batches=True,
+            pipeline_preproc=True,
+        )
+
+        data = self._generate_data(
+            num_batches=self.num_batches,
+            batch_size=self.batch_size,
+            max_feature_lengths=list(max_feature_lengths.values()),
+        )
+        dataloader = iter(data)
+        pipeline.progress(dataloader)
+
+        # Check that no modules are pipelined
+        self.assertEqual(len(pipeline._pipelined_modules), 0)
+        self.assertEqual(len(pipeline._pipelined_preprocs), 0)
+
+    # pyre-ignore
+    @unittest.skipIf(
+        not torch.cuda.is_available(),
+        "Not enough GPUs, this test requires at least one GPU",
+    )
+    def test_pipeline_invalid_preproc_inputs_modify_kjt_recursive(self) -> None:
+        """
+        Test case where a preproc module cannot be pipelined because at least one of its
+        args is derived from the output of another preproc module whose arg(s) cannot be
+        derived from the input batch (i.e. it applies modifying transformations).
+        """
+        model = TestModelWithPreproc(
+            tables=self.tables[:-1],
+            weighted_tables=self.weighted_tables[:-1],
+            device=self.device,
+            preproc_module=None,
+            run_preproc_inline=True,  # run preproc inline, outside a module
+        )
+
+        (
+            sharded_model_pipelined,
+            optim_pipelined,
+        ) = self._generate_sharded_model_and_optimizer(
+            model, self.sharding_type, self.kernel_type, self.fused_params
+        )
+
+        pipeline = TrainPipelineSparseDist(
+            model=sharded_model_pipelined,
+            optimizer=optim_pipelined,
+            device=self.device,
+            execute_all_batches=True,
+            pipeline_preproc=True,
+        )
+
+        data = self._generate_data(
+            num_batches=self.num_batches,
+            batch_size=self.batch_size,
+        )
+        dataloader = iter(data)
+        pipeline.progress(dataloader)
+
+        # Check that only weighted EBC is pipelined
+        self.assertEqual(len(pipeline._pipelined_modules), 1)
+        self.assertEqual(len(pipeline._pipelined_preprocs), 1)
+        self.assertEqual(pipeline._pipelined_modules[0]._is_weighted, True)
+        self.assertEqual(
+            pipeline._pipelined_preprocs[0],
+            sharded_model_pipelined.module.preproc_weighted,
+        )
+
+    # pyre-ignore
+    @unittest.skipIf(
+        not torch.cuda.is_available(),
+        "Not enough GPUs, this test requires at least one GPU",
+    )
+    def test_pipeline_preproc_fwd_values_cached(self) -> None:
+        """
+        Test that during model forward, the pipelined preproc module uses the cached
+        result from the previous iteration's _start_data_dist and doesn't perform
+        duplicate work. Also check that fqns for ALL preproc modules are populated
+        in the right train pipeline context.
+        """
+        extra_input = ModelInput.generate(
+            tables=self.tables[:-1],
+            weighted_tables=self.weighted_tables[:-1],
+            batch_size=self.batch_size,
+            world_size=1,
+            num_float_features=10,
+            randomize_indices=False,
+        )[0].to(self.device)
+
+        preproc_module = TestNegSamplingModule(
+            extra_input=extra_input,
+        )
+
+        model = TestModelWithPreproc(
+            tables=self.tables[:-1],
+            weighted_tables=self.weighted_tables[:-1],
+            device=self.device,
+            preproc_module=preproc_module,
+        )
+
+        (
+            sharded_model_pipelined,
+            optim_pipelined,
+        ) = self._generate_sharded_model_and_optimizer(
+            model, self.sharding_type, self.kernel_type, self.fused_params
+        )
+
+        pipeline = TrainPipelineSparseDist(
+            model=sharded_model_pipelined,
+            optimizer=optim_pipelined,
+            device=self.device,
+            execute_all_batches=True,
+            pipeline_preproc=True,
+        )
+
+        data = self._generate_data(
+            num_batches=self.num_batches,
+            batch_size=self.batch_size,
+        )
+        dataloader = iter(data)
+
+        pipeline.progress(dataloader)
+
+        # This was the second context that was appended
+        current_context = pipeline.contexts[0]
+        cached_results = current_context.preproc_fwd_results
+        self.assertEqual(
+            list(cached_results.keys()),
+            ["_preproc_module", "preproc_nonweighted", "preproc_weighted"],
+        )
+
+        # next context cached results should be empty
+        next_context = pipeline.contexts[1]
+        next_cached_results = next_context.preproc_fwd_results
+        self.assertEqual(len(next_cached_results), 0)
+
+        # After progress, next_context should be populated
+        pipeline.progress(dataloader)
+        self.assertEqual(
+            list(next_cached_results.keys()),
+            ["_preproc_module", "preproc_nonweighted", "preproc_weighted"],
+        )
+
+
 class EmbeddingTrainPipelineTest(TrainPipelineSparseDistTestBase):
     @unittest.skipIf(
         not torch.cuda.is_available(),
diff --git a/torchrec/distributed/train_pipeline/tests/test_train_pipelines_base.py b/torchrec/distributed/train_pipeline/tests/test_train_pipelines_base.py
index 8317f2354..6ca45371a 100644
--- a/torchrec/distributed/train_pipeline/tests/test_train_pipelines_base.py
+++ b/torchrec/distributed/train_pipeline/tests/test_train_pipelines_base.py
@@ -93,12 +93,14 @@ def _setup_model(
         self,
         model_type: Type[nn.Module] = TestSparseNN,
         enable_fsdp: bool = False,
+        preproc_module: Optional[nn.Module] = None,
     ) -> nn.Module:
         unsharded_model = model_type(
             tables=self.tables,
             weighted_tables=self.weighted_tables,
             dense_device=self.device,
             sparse_device=torch.device("meta"),
+            preproc_module=preproc_module,
         )
         if enable_fsdp:
             unsharded_model.over.dhn_arch.linear0 = FSDP(
diff --git a/torchrec/distributed/train_pipeline/tests/test_train_pipelines_utils.py b/torchrec/distributed/train_pipeline/tests/test_train_pipelines_utils.py
index 854423385..f8dcf08fb 100644
--- a/torchrec/distributed/train_pipeline/tests/test_train_pipelines_utils.py
+++ b/torchrec/distributed/train_pipeline/tests/test_train_pipelines_utils.py
@@ -8,14 +8,94 @@
 # pyre-strict
 
 import unittest
+from unittest.mock import MagicMock
 
 import torch
-from torchrec.distributed.train_pipeline.utils import _get_node_args
+from torchrec.distributed.embedding_types import EmbeddingComputeKernel
+from torchrec.distributed.test_utils.test_model import ModelInput, TestNegSamplingModule
+
+from torchrec.distributed.train_pipeline.tests.test_train_pipelines_base import (
+    TrainPipelineSparseDistTestBase,
+)
+from torchrec.distributed.train_pipeline.utils import (
+    _get_node_args,
+    _rewrite_model,
+    PipelinedForward,
+    TrainPipelineContext,
+)
+from torchrec.distributed.types import ShardingType
 
 from torchrec.sparse.jagged_tensor import KeyedJaggedTensor
 
 
-class TestTrainPipelineUtils(unittest.TestCase):
+class TrainPipelineUtilsTest(TrainPipelineSparseDistTestBase):
+    # pyre-fixme[56]: Pyre was not able to infer the type of argument
+    @unittest.skipIf(
+        not torch.cuda.is_available(),
+        "Not enough GPUs, this test requires at least one GPU",
+    )
+    def test_rewrite_model(self) -> None:
+        sharding_type = ShardingType.TABLE_WISE.value
+        kernel_type = EmbeddingComputeKernel.FUSED.value
+        fused_params = {}
+
+        extra_input = ModelInput.generate(
+            tables=self.tables,
+            weighted_tables=self.weighted_tables,
+            batch_size=10,
+            world_size=1,
+            num_float_features=10,
+            randomize_indices=False,
+        )[0].to(self.device)
+
+        preproc_module = TestNegSamplingModule(
+            extra_input=extra_input,
+        )
+        model = self._setup_model(preproc_module=preproc_module)
+
+        sharded_model, optim = self._generate_sharded_model_and_optimizer(
+            model, sharding_type, kernel_type, fused_params
+        )
+
+        # Try to rewrite model without pipeline_preproc enabled; EBC forwards are not overwritten to PipelinedForward since the preproc module modifies the KJT input
+        _rewrite_model(
+            model=sharded_model,
+            batch=None,
+            context=TrainPipelineContext(),
+            dist_stream=None,
+        )
+        self.assertNotIsInstance(
+            sharded_model.module.sparse.ebc.forward, PipelinedForward
+        )
+        self.assertNotIsInstance(
+            sharded_model.module.sparse.weighted_ebc.forward, PipelinedForward
+        )
+
+        # Now enable preproc module pipelining explicitly
+        _rewrite_model(
+            model=sharded_model,
+            batch=None,
+            context=TrainPipelineContext(),
+            dist_stream=None,
+            pipeline_preproc=True,
+        )
+        self.assertIsInstance(sharded_model.module.sparse.ebc.forward, PipelinedForward)
+        self.assertIsInstance(
+            sharded_model.module.sparse.weighted_ebc.forward, PipelinedForward
+        )
+        self.assertEqual(
+            sharded_model.module.sparse.ebc.forward._args[0].preproc_modules[0],
+            sharded_model.module.preproc_module,
+        )
+        self.assertEqual(
+            sharded_model.module.sparse.weighted_ebc.forward._args[0].preproc_modules[
+                0
+            ],
+            sharded_model.module.preproc_module,
+        )
+
+
+class TestUtils(unittest.TestCase):
     def test_get_node_args_helper_call_module_kjt(self) -> None:
         graph = torch.fx.Graph()
         kjt_args = []
@@ -42,7 +122,9 @@ def test_get_node_args_helper_call_module_kjt(self) -> None:
         )
 
         num_found = 0
-        _, num_found = _get_node_args(kjt_node)
+        _, num_found = _get_node_args(
+            MagicMock(), kjt_node, set(), TrainPipelineContext(), False
+        )
 
         # Weights is call_module node, so we should only find 2 args unmodified
         self.assertEqual(num_found, len(kjt_args) - 1)
diff --git a/torchrec/distributed/train_pipeline/train_pipelines.py b/torchrec/distributed/train_pipeline/train_pipelines.py
index 21d2855ce..dc1f6b55d 100644
--- a/torchrec/distributed/train_pipeline/train_pipelines.py
+++ b/torchrec/distributed/train_pipeline/train_pipelines.py
@@ -46,6 +46,7 @@
     In,
     Out,
     PipelinedForward,
+    PipelinedPreproc,
     PipelineStage,
     PrefetchPipelinedForward,
     PrefetchTrainPipelineContext,
@@ -301,6 +302,7 @@ def __init__(
         execute_all_batches: bool = True,
         apply_jit: bool = False,
         context_type: Type[TrainPipelineContext] = TrainPipelineContext,
+        pipeline_preproc: bool = False,
     ) -> None:
         self._model = model
         self._optimizer = optimizer
@@ -342,10 +344,12 @@ def __init__(
         ] = []
 
         self._model_attached = True
+        self._pipeline_preproc = pipeline_preproc
 
         self._next_index: int = 0
         self.contexts: Deque[TrainPipelineContext] = deque()
         self._pipelined_modules: List[ShardedModule] = []
+        self._pipelined_preprocs: List[PipelinedPreproc] = []
         self.batches: Deque[Optional[In]] = deque()
         self._dataloader_iter: Optional[Iterator[In]] = None
         self._dataloader_exhausted: bool = False
@@ -397,6 +401,10 @@ def _set_module_context(self, context: TrainPipelineContext) -> None:
         for module in self._pipelined_modules:
             module.forward.set_context(context)
 
+        for preproc_module in self._pipelined_preprocs:
+            # This ensures that next iter model fwd uses cached results
+            preproc_module.set_context(context)
+
     def enqueue_batch(self, dataloader_iter: Iterator[In]) -> bool:
         batch, context = self.copy_batch_to_gpu(dataloader_iter)
         if batch is None:
@@ -494,13 +502,19 @@ def _pipeline_model(
         context: TrainPipelineContext,
         pipelined_forward: Type[PipelinedForward] = PipelinedForward,
     ) -> None:
-        self._pipelined_modules, self._model, self._original_forwards = _rewrite_model(
+        (
+            self._pipelined_modules,
+            self._model,
+            self._original_forwards,
+            self._pipelined_preprocs,
+        ) = _rewrite_model(
             model=self._model,
             context=context,
             dist_stream=self._data_dist_stream,
             batch=batch,
             apply_jit=self._apply_jit,
             pipelined_forward=pipelined_forward,
+            pipeline_preproc=self._pipeline_preproc,
         )
         # initializes input dist, so we can override input dist forwards
         self.start_sparse_data_dist(batch, context)
@@ -576,9 +590,22 @@ def start_sparse_data_dist(
             return
         with record_function(f"## start_sparse_data_dist {context.index} ##"):
             with self._stream_context(self._data_dist_stream):
+                if context.event is not None:
+                    context.event.wait()
                 _wait_for_batch(batch, self._memcpy_stream)
+
+                original_contexts = [p.get_context() for p in self._pipelined_preprocs]
+
+                # Temporarily set context for next iter to populate cache
+                for preproc_mod in self._pipelined_preprocs:
+                    preproc_mod.set_context(context)
+
                 _start_data_dist(self._pipelined_modules, batch, context)
 
+                # Restore contexts for model fwd
+                for mod, orig_ctx in zip(self._pipelined_preprocs, original_contexts):
+                    mod.set_context(orig_ctx)
+
     def wait_sparse_data_dist(self, context: TrainPipelineContext) -> None:
         """
         Waits on the input dist splits requests to get the input dist tensors requests,
@@ -680,6 +707,7 @@ def __init__(
         apply_jit: bool = False,
         start_batch: int = 900,
         stash_gradients: bool = False,
+        pipeline_preproc: bool = False,
     ) -> None:
         super().__init__(
             model=model,
@@ -688,6 +716,7 @@ def __init__(
             execute_all_batches=execute_all_batches,
             apply_jit=apply_jit,
             context_type=EmbeddingTrainPipelineContext,
+            pipeline_preproc=pipeline_preproc,
         )
         self._start_batch = start_batch
         self._stash_gradients = stash_gradients
@@ -853,14 +882,9 @@ def start_sparse_data_dist(
         """
         Waits for batch to finish getting copied to GPU, then starts the input dist.  This is Event based version.
         """
-        if batch is None:
-            return
-        with record_function(f"## start_sparse_data_dist {context.index} ##"):
-            with self._stream_context(self._data_dist_stream):
-                _wait_for_event(batch, self._device, context.event)
-                _start_data_dist(self._pipelined_modules, batch, context)
-                context.event = torch.get_device_module(self._device).Event()
-                context.event.record()
+        super().start_sparse_data_dist(batch, context)
+        context.event = torch.get_device_module(self._device).Event()
+        context.event.record()
 
     def start_embedding_lookup(
         self,
diff --git a/torchrec/distributed/train_pipeline/utils.py b/torchrec/distributed/train_pipeline/utils.py
index eca8f84a0..0e9631793 100644
--- a/torchrec/distributed/train_pipeline/utils.py
+++ b/torchrec/distributed/train_pipeline/utils.py
@@ -48,7 +48,6 @@
 from torchrec.sparse.jagged_tensor import KeyedJaggedTensor
 from torchrec.streamable import Multistreamable, Pipelineable
 
-
 logger: logging.Logger = logging.getLogger(__name__)
 
 import torch
@@ -97,6 +96,9 @@ class TrainPipelineContext:
         field(default_factory=list)
     )
     event: Optional[torch.Event] = None
+
+    preproc_fwd_results: Dict[str, Any] = field(default_factory=dict)
+
     index: Optional[int] = None
     version: int = (
         0  # 1 is current version, 0 is deprecated but supported for backward compatibility
@@ -153,9 +155,140 @@ class ArgInfo:
 
     input_attrs: List[str]
     is_getitems: List[bool]
+    # recursive dataclass: a PipelinedPreproc's args are ArgInfos whose preproc_modules may point to further PipelinedPreprocs, and so on
+    preproc_modules: List[Optional["PipelinedPreproc"]]
     name: Optional[str]
 
 
+# pyre-ignore
+def _build_args_kwargs(
+    # pyre-ignore
+    initial_input: Any,
+    fwd_args: List[ArgInfo],
+) -> Tuple[List[Any], Dict[str, Any]]:
+    args = []
+    kwargs = {}
+    for arg_info in fwd_args:
+        if arg_info.input_attrs:
+            arg = initial_input
+            for attr, is_getitem, preproc_mod in zip(
+                arg_info.input_attrs, arg_info.is_getitems, arg_info.preproc_modules
+            ):
+                if preproc_mod is not None:
+                    # preproc will internally run the same logic recursively
+                    # if its args are derived from other preproc modules
+                    # we can get all inputs to preproc mod based on its recorded args_info + arg passed to it
+                    arg = preproc_mod(arg)
+                else:
+                    if is_getitem:
+                        arg = arg[attr]
+                    elif attr != "":
+                        arg = getattr(arg, attr)
+                    else:
+                        # neither is_getitem nor valid attr, no-op
+                        arg = arg
+            if arg_info.name:
+                kwargs[arg_info.name] = arg
+            else:
+                args.append(arg)
+        else:
+            args.append(None)
+    return args, kwargs
+
+
+class PipelinedPreproc(torch.nn.Module):
+    """
+    Wrapper around preproc module found during model graph traversal for sparse data dist
+    pipelining. In addition to the original module, it encapsulates information needed for
+    execution such as list of ArgInfo and the current training pipeline context.
+
+    Args:
+        preproc_module (torch.nn.Module): preproc module to run
+        fqn (str): fqn of the preproc module in the model being pipelined
+        args (List[ArgInfo]): list of ArgInfo for the preproc module
+        context (TrainPipelineContext): Training context for the next iteration / batch
+
+    Returns:
+        Any
+
+    Example:
+        preproc = PipelinedPreproc(preproc_module, fqn, args, context)
+        # module-swap with pipelined preproc
+        setattr(model, fqn, preproc)
+    """
+
+    def __init__(
+        self,
+        preproc_module: torch.nn.Module,
+        fqn: str,
+        args: List[ArgInfo],
+        context: TrainPipelineContext,
+    ) -> None:
+        super().__init__()
+        self._preproc_module = preproc_module
+        self._fqn = fqn
+        self._args = args
+        self._context = context
+
+    @property
+    def preproc_module(self) -> torch.nn.Module:
+        return self._preproc_module
+
+    @property
+    def fqn(self) -> str:
+        return self._fqn
+
+    # pyre-ignore
+    def forward(self, *input, **kwargs) -> Any:
+        """
+        Args:
+            Any args and kwargs during model fwd
+            During _start_data_dist, input[0] contains the current data
+        Returns:
+            Any
+        """
+        if self._fqn in self._context.preproc_fwd_results:
+            # This should only be hit in two cases:
+            # 1) During model forward: avoid duplicate work
+            # by returning the result cached for this batch
+            # during the previous iteration's
+            # _start_data_dist
+            # 2) During _start_data_dist when preproc module is
+            # shared by more than one arg, e.g. if we have
+            # preproc_out_a = preproc_a(input)
+            # preproc_out_b = preproc_b(preproc_out_a) <- preproc_a shared
+            # preproc_out_c = preproc_c(preproc_out_a) <-^
+            # When processing preproc_b, we cache value of preproc_a(input)
+            # so when processing preproc_c, we can reuse preproc_a(input)
+            res = self._context.preproc_fwd_results[self._fqn]
+            return res
+
+        # Everything below should only be called during _start_data_dist stage
+
+        # Build up arg and kwargs from recursive call to pass to preproc module
+        # Arguments to preproc module can also be a derived product
+        # of another preproc module call, as long as module is pipelineable
+
+        # Use input[0] as _start_data_dist only passes 1 arg
+        args, kwargs = _build_args_kwargs(input[0], self._args)
+
+        with record_function(f"## sdd_input_preproc {self._context.index} ##"):
+            res = self._preproc_module(*args, **kwargs)
+            # Cache results, only during _start_data_dist
+            self._context.preproc_fwd_results[self._fqn] = res
+            return res
+
+    @property
+    def args(self) -> List[ArgInfo]:
+        return self._args
+
+    def set_context(self, context: TrainPipelineContext) -> None:
+        self._context = context
+
+    def get_context(self) -> TrainPipelineContext:
+        return self._context
+
+
 class BaseForward:
     def __init__(
         self,
@@ -258,7 +391,6 @@ def __init__(
 
     # pyre-ignore [2, 24]
     def __call__(self, *input, **kwargs) -> Awaitable:
-
         assert (
             self._name
             # pyre-ignore [16]
@@ -422,22 +554,8 @@ def _start_data_dist(
         # False means this argument is getting while getattr
         # and this info was done in the _rewrite_model by tracing the
         # entire model to get the arg_info_list
-        args = []
-        kwargs = {}
-        for arg_info in forward.args:
-            if arg_info.input_attrs:
-                arg = batch
-                for attr, is_getitem in zip(arg_info.input_attrs, arg_info.is_getitems):
-                    if is_getitem:
-                        arg = arg[attr]
-                    else:
-                        arg = getattr(arg, attr)
-                if arg_info.name:
-                    kwargs[arg_info.name] = arg
-                else:
-                    args.append(arg)
-            else:
-                args.append(None)
+        args, kwargs = _build_args_kwargs(batch, forward.args)
+
         # Start input distribution.
         module_ctx = module.create_context()
         if context.version == 0:
@@ -522,17 +640,32 @@ def _check_args_for_call_module(
     return False
 
 
+def _check_preproc_pipelineable(
+    module: torch.nn.Module,
+) -> bool:
+    for _, _ in module.named_parameters(recurse=True):
+        # Cannot have any trainable params for it to be pipelined
+        logger.warning(
+            f"Module {module} cannot be pipelined as it has trainable parameters"
+        )
+        return False
+    return True
+
+
 def _get_node_args_helper(
+    model: torch.nn.Module,
     # pyre-ignore
     arguments,
     num_found: int,
+    pipelined_preprocs: Set[PipelinedPreproc],
+    context: TrainPipelineContext,
+    pipeline_preproc: bool,
 ) -> Tuple[List[ArgInfo], int]:
     """
     Goes through the args/kwargs of a node and arranges them into a list of `ArgInfo`s.
     It also counts the number of (args + kwargs) found.
     """
-
-    arg_info_list = [ArgInfo([], [], None) for _ in range(len(arguments))]
+    arg_info_list = [ArgInfo([], [], [], None) for _ in range(len(arguments))]
     for arg, arg_info in zip(arguments, arg_info_list):
         if arg is None:
             num_found += 1
@@ -547,6 +680,13 @@ def _get_node_args_helper(
                     # pyre-ignore[16]
                     arg_info.input_attrs.insert(0, child_node.ph_key)
                     arg_info.is_getitems.insert(0, False)
+                    arg_info.preproc_modules.insert(0, None)
+                else:
+                    # no-op
+                    arg_info.input_attrs.insert(0, "")
+                    arg_info.is_getitems.insert(0, False)
+                    arg_info.preproc_modules.insert(0, None)
+
                 num_found += 1
                 break
             elif (
@@ -561,6 +701,7 @@ def _get_node_args_helper(
                 #  memory_format, Tensor, typing.Tuple[typing.Any, ...]]`.
                 arg_info.input_attrs.insert(0, child_node.args[1])
                 arg_info.is_getitems.insert(0, False)
+                arg_info.preproc_modules.insert(0, None)
                 arg = child_node.args[0]
             elif (
                 child_node.op == "call_function"
@@ -574,6 +715,7 @@ def _get_node_args_helper(
                 #  memory_format, Tensor, typing.Tuple[typing.Any, ...]]`.
                 arg_info.input_attrs.insert(0, child_node.args[1])
                 arg_info.is_getitems.insert(0, True)
+                arg_info.preproc_modules.insert(0, None)
                 arg = child_node.args[0]
             elif (
                 child_node.op == "call_function"
@@ -610,18 +752,109 @@ def _get_node_args_helper(
                     arg = child_node.kwargs["values"]
                 else:
                     arg = child_node.args[1]
+            elif child_node.op == "call_module":
+                preproc_module_fqn = str(child_node.target)
+                preproc_module = getattr(model, preproc_module_fqn, None)
+
+                if not pipeline_preproc:
+                    logger.warning(
+                        f"Found module {preproc_module} that potentially modifies KJT. Train pipeline initialized with `pipeline_preproc=False` (default), so we assume KJT input modification. To allow torchrec to check if this module can be safely pipelined, please set `pipeline_preproc=True`"
+                    )
+                    break
+
+                if not preproc_module:
+                    # Could not find such module, should not happen
+                    break
+
+                if isinstance(preproc_module, PipelinedPreproc):
+                    # Already did module swap and registered args, early exit
+                    arg_info.input_attrs.insert(0, "")  # dummy value
+                    arg_info.is_getitems.insert(0, False)
+                    pipelined_preprocs.add(preproc_module)
+                    arg_info.preproc_modules.insert(0, preproc_module)
+                    num_found += 1
+                    break
+
+                if not isinstance(preproc_module, torch.nn.Module):
+                    logger.warning(
+                        f"Expected preproc_module to be nn.Module but was {type(preproc_module)}"
+                    )
+                    break
+
+                # check if module is safe to pipeline, i.e. it has no trainable params
+                if not _check_preproc_pipelineable(preproc_module):
+                    break
+
+                # For module calls, `self` isn't counted
+                total_num_args = len(child_node.args) + len(child_node.kwargs)
+                if total_num_args == 0:
+                    # module call without any args, assume KJT modified
+                    break
+
+                # recursive call to check that all inputs to this preproc module
+                # are either outputs of other preproc modules or non-modifying
+                # transformations of the train batch input
+                preproc_args, num_found_safe_preproc_args = _get_node_args(
+                    model, child_node, pipelined_preprocs, context, pipeline_preproc
+                )
+                if num_found_safe_preproc_args == total_num_args:
+                    logger.info(
+                        f"""Module {preproc_module} is a valid preproc module (no
+                        trainable params and inputs can be derived from the train batch
+                        input via a series of either valid preproc modules or
+                        non-modifying transformations) and will be applied during the
+                        sparse data dist stage"""
+                    )
+
+                    pipelined_preproc_module = PipelinedPreproc(
+                        preproc_module,
+                        preproc_module_fqn,
+                        preproc_args,
+                        context,
+                    )
+
+                    # module swap
+                    setattr(model, preproc_module_fqn, pipelined_preproc_module)
+
+                    arg_info.input_attrs.insert(0, "")  # dummy value
+                    arg_info.is_getitems.insert(0, False)
+                    pipelined_preprocs.add(pipelined_preproc_module)
+                    arg_info.preproc_modules.insert(0, pipelined_preproc_module)
+
+                    num_found += 1
+
+                # we cannot set any other `arg` value here
+                # break to avoid infinite loop
+                break
             else:
                 break
     return arg_info_list, num_found
 
 
 def _get_node_args(
+    model: torch.nn.Module,
     node: Node,
+    pipelined_preprocs: Set[PipelinedPreproc],
+    context: TrainPipelineContext,
+    pipeline_preproc: bool,
 ) -> Tuple[List[ArgInfo], int]:
     num_found = 0
-    pos_arg_info_list, num_found = _get_node_args_helper(node.args, num_found)
+
+    pos_arg_info_list, num_found = _get_node_args_helper(
+        model,
+        node.args,
+        num_found,
+        pipelined_preprocs,
+        context,
+        pipeline_preproc,
+    )
     kwargs_arg_info_list, num_found = _get_node_args_helper(
-        node.kwargs.values(), num_found
+        model,
+        node.kwargs.values(),
+        num_found,
+        pipelined_preprocs,
+        context,
+        pipeline_preproc,
     )
 
     # Replace with proper names for kwargs
@@ -629,7 +862,8 @@ def _get_node_args(
         arg_info_list.name = name
 
     arg_info_list = pos_arg_info_list + kwargs_arg_info_list
-    return arg_info_list, num_found
+
+    return (arg_info_list, num_found)
 
 
 def _get_leaf_module_names_helper(
@@ -744,7 +978,13 @@ def _rewrite_model(  # noqa C901
     batch: Optional[In] = None,
     apply_jit: bool = False,
     pipelined_forward: Type[BaseForward] = PipelinedForward,
-) -> Tuple[List[ShardedModule], torch.nn.Module, List[Callable[..., Any]]]:
+    pipeline_preproc: bool = False,
+) -> Tuple[
+    List[ShardedModule],
+    torch.nn.Module,
+    List[Callable[..., Any]],
+    List[PipelinedPreproc],
+]:
     input_model = model
     # Get underlying nn.Module
     if isinstance(model, DistributedModelParallel):
@@ -781,12 +1021,21 @@ def _rewrite_model(  # noqa C901
     # i.e. don't have input transformations, i.e. rely only on 'builtins.getattr'.
     pipelined_forwards = []
     original_forwards = []
+
+    pipelined_preprocs: Set[PipelinedPreproc] = set()
+
     for node in graph.nodes:
         if node.op == "call_module" and node.target in sharded_modules:
             total_num_args = len(node.args) + len(node.kwargs)
             if total_num_args == 0:
                 continue
-            arg_info_list, num_found = _get_node_args(node)
+            arg_info_list, num_found = _get_node_args(
+                model,
+                node,
+                pipelined_preprocs,
+                context,
+                pipeline_preproc,
+            )
 
             if num_found == total_num_args:
                 logger.info(f"Module '{node.target}'' will be pipelined")
@@ -812,7 +1061,7 @@ def _rewrite_model(  # noqa C901
         if isinstance(input_model, DistributedModelParallel):
             input_model.module = graph_model
 
-    return pipelined_forwards, input_model, original_forwards
+    return pipelined_forwards, input_model, original_forwards, list(pipelined_preprocs)
 
 
 def _override_input_dist_forwards(
@@ -985,7 +1234,8 @@ def detach(self) -> torch.nn.Module:
     def start_sparse_data_dist(self, batch: In) -> In:
         if not self.initialized:
             # Step 1: Pipeline input dist in trec sharded modules
-            self._pipelined_modules, self.model, self._original_forwards = (
+            # TODO (yhshin): support preproc modules for `StagedTrainPipeline`
+            self._pipelined_modules, self.model, self._original_forwards, _ = (
                 _rewrite_model(
                     model=self.model,
                     context=self.context,