diff --git a/torchrec/distributed/test_utils/test_model.py b/torchrec/distributed/test_utils/test_model.py index ff0f7a740..e4ec76ee3 100644 --- a/torchrec/distributed/test_utils/test_model.py +++ b/torchrec/distributed/test_utils/test_model.py @@ -7,6 +7,7 @@ # pyre-strict +import copy import random from dataclasses import dataclass from typing import Any, cast, Dict, List, Optional, Tuple, Type, Union @@ -239,10 +240,16 @@ def _validate_pooling_factor( else None ) - global_float = torch.rand( - (batch_size * world_size, num_float_features), device=device - ) - global_label = torch.rand(batch_size * world_size, device=device) + if randomize_indices: + global_float = torch.rand( + (batch_size * world_size, num_float_features), device=device + ) + global_label = torch.rand(batch_size * world_size, device=device) + else: + global_float = torch.zeros( + (batch_size * world_size, num_float_features), device=device + ) + global_label = torch.zeros(batch_size * world_size, device=device) # Split global batch into local batches. local_inputs = [] @@ -939,6 +946,7 @@ def __init__( max_feature_lengths_list: Optional[List[Dict[str, int]]] = None, feature_processor_modules: Optional[Dict[str, torch.nn.Module]] = None, over_arch_clazz: Type[nn.Module] = TestOverArch, + preproc_module: Optional[nn.Module] = None, ) -> None: super().__init__( tables=cast(List[BaseEmbeddingConfig], tables), @@ -960,6 +968,14 @@ def __init__( embedding_names = ( list(embedding_groups.values())[0] if embedding_groups else None ) + self._embedding_names: List[str] = ( + embedding_names + if embedding_names + else [feature for table in tables for feature in table.feature_names] + ) + self._weighted_features: List[str] = [ + feature for table in weighted_tables for feature in table.feature_names + ] self.over: nn.Module = over_arch_clazz( tables, weighted_tables, embedding_names, dense_device ) @@ -967,6 +983,7 @@ def __init__( "dummy_ones", torch.ones(1, device=dense_device), ) + self.preproc_module = preproc_module def sparse_forward(self, input: ModelInput) -> KeyedTensor: return self.sparse( @@ -993,6 +1010,8 @@ def forward( self, input: ModelInput, ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: + if self.preproc_module: + input = self.preproc_module(input) return self.dense_forward(input, self.sparse_forward(input)) @@ -1409,3 +1428,246 @@ def _post_ebc_test_wrap_function(kt: KeyedTensor) -> KeyedTensor: continue return kt + + +class TestPreprocNonWeighted(nn.Module): + """ + Basic module for testing + + Args: None + Examples: + >>> TestPreprocNonWeighted() + Returns: + List[KeyedJaggedTensor] + """ + + def forward(self, kjt: KeyedJaggedTensor) -> List[KeyedJaggedTensor]: + """ + Selects 3 features from a specific KJT + """ + # split + jt_0 = kjt["feature_0"] + jt_1 = kjt["feature_1"] + jt_2 = kjt["feature_2"] + + # merge only features 0,1,2, removing feature 3 + return [ + KeyedJaggedTensor.from_jt_dict( + { + "feature_0": jt_0, + "feature_1": jt_1, + "feature_2": jt_2, + } + ) + ] + + +class TestPreprocWeighted(nn.Module): + """ + Basic module for testing + + Args: None + Examples: + >>> TestPreprocWeighted() + Returns: + List[KeyedJaggedTensor] + """ + + def forward(self, kjt: KeyedJaggedTensor) -> List[KeyedJaggedTensor]: + """ + Selects 1 feature from specific weighted KJT + """ + + # split + jt_0 = kjt["weighted_feature_0"] + + # keep only weighted_feature_0 + return [ + KeyedJaggedTensor.from_jt_dict( + { + "weighted_feature_0": jt_0, + } + ) + ] + + +class TestModelWithPreproc(nn.Module): + """ + 
Basic module with up to 3 preproc modules: + - preproc on idlist_features for non-weighted EBC + - preproc on idscore_features for weighted EBC + - optional preproc on model input shared by both EBCs + + Args: + tables, + weighted_tables, + device, + preproc_module, + num_float_features, + run_preproc_inline, + + Example: + >>> TestModelWithPreproc(tables, weighted_tables, device) + + Returns: + Tuple[torch.Tensor, torch.Tensor] + """ + + def __init__( + self, + tables: List[EmbeddingBagConfig], + weighted_tables: List[EmbeddingBagConfig], + device: torch.device, + preproc_module: Optional[nn.Module] = None, + num_float_features: int = 10, + run_preproc_inline: bool = False, + ) -> None: + super().__init__() + self.dense = TestDenseArch(num_float_features, device) + + self.ebc: EmbeddingBagCollection = EmbeddingBagCollection( + tables=tables, + device=device, + ) + self.weighted_ebc = EmbeddingBagCollection( + tables=weighted_tables, + is_weighted=True, + device=device, + ) + self.preproc_nonweighted = TestPreprocNonWeighted() + self.preproc_weighted = TestPreprocWeighted() + self._preproc_module = preproc_module + self._run_preproc_inline = run_preproc_inline + + def forward( + self, + input: ModelInput, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Runs preprco for EBC and weighted EBC, optionally runs preproc for input + + Args: + input + Returns: + Tuple[torch.Tensor, torch.Tensor] + """ + modified_input = input + + if self._preproc_module is not None: + modified_input = self._preproc_module(modified_input) + elif self._run_preproc_inline: + modified_input.idlist_features = KeyedJaggedTensor.from_lengths_sync( + modified_input.idlist_features.keys(), + modified_input.idlist_features.values(), + modified_input.idlist_features.lengths(), + ) + + modified_idlist_features = self.preproc_nonweighted( + modified_input.idlist_features + ) + modified_idscore_features = self.preproc_weighted( + modified_input.idscore_features + ) + ebc_out = self.ebc(modified_idlist_features[0]) + weighted_ebc_out = self.weighted_ebc(modified_idscore_features[0]) + + pred = torch.cat([ebc_out.values(), weighted_ebc_out.values()], dim=1) + return pred.sum(), pred + + +class TestNegSamplingModule(torch.nn.Module): + """ + Basic module to simulate feature augmentation preproc (e.g. 
neg sampling) for testing + + Args: + extra_input + has_params + + Example: + >>> preproc = TestNegSamplingModule(extra_input) + >>> out = preproc(in) + + Returns: + ModelInput + """ + + def __init__( + self, + extra_input: ModelInput, + has_params: bool = False, + ) -> None: + super().__init__() + self._extra_input = extra_input + if has_params: + self._linear: nn.Module = nn.Linear(30, 30) + + def forward(self, input: ModelInput) -> ModelInput: + """ + Appends extra features to model input + + Args: + input + Returns: + ModelInput + """ + + # merge extra input + modified_input = copy.deepcopy(input) + + # dim=0 (batch dimensions) increases by self._extra_input.float_features.shape[0] + modified_input.float_features = torch.concat( + (modified_input.float_features, self._extra_input.float_features), dim=0 + ) + + # stride will be same but features will be joined + modified_input.idlist_features = KeyedJaggedTensor.concat( + [modified_input.idlist_features, self._extra_input.idlist_features] + ) + if self._extra_input.idscore_features is not None: + # stride will be smae but features will be joined + modified_input.idscore_features = KeyedJaggedTensor.concat( + # pyre-ignore + [modified_input.idscore_features, self._extra_input.idscore_features] + ) + + # dim=0 (batch dimensions) increases by self._extra_input.input_label.shape[0] + modified_input.label = torch.concat( + (modified_input.label, self._extra_input.label), dim=0 + ) + + return modified_input + + +class TestPositionWeightedPreprocModule(torch.nn.Module): + """ + Basic module for testing + + Args: None + Example: + >>> preproc = TestPositionWeightedPreprocModule(max_feature_lengths, device) + >>> out = preproc(in) + Returns: + ModelInput + """ + + def __init__( + self, max_feature_lengths: Dict[str, int], device: torch.device + ) -> None: + super().__init__() + self.fp_proc = PositionWeightedProcessor( + max_feature_lengths=max_feature_lengths, + device=device, + ) + + def forward(self, input: ModelInput) -> ModelInput: + """ + Runs PositionWeightedProcessor + + Args: + input + Returns: + ModelInput + """ + modified_input = copy.deepcopy(input) + modified_input.idlist_features = self.fp_proc(modified_input.idlist_features) + return modified_input diff --git a/torchrec/distributed/train_pipeline/tests/test_train_pipelines.py b/torchrec/distributed/train_pipeline/tests/test_train_pipelines.py index 8aef77d76..f60672833 100644 --- a/torchrec/distributed/train_pipeline/tests/test_train_pipelines.py +++ b/torchrec/distributed/train_pipeline/tests/test_train_pipelines.py @@ -8,6 +8,7 @@ # pyre-strict import copy + import unittest from dataclasses import dataclass from functools import partial @@ -32,6 +33,9 @@ from torchrec.distributed.test_utils.test_model import ( ModelInput, TestEBCSharder, + TestModelWithPreproc, + TestNegSamplingModule, + TestPositionWeightedPreprocModule, TestSparseNN, ) from torchrec.distributed.test_utils.test_sharding import copy_state_dict @@ -54,6 +58,7 @@ DataLoadingThread, get_h2d_func, PipelinedForward, + PipelinedPreproc, PipelineStage, SparseDataDistUtil, StageOut, @@ -821,6 +826,491 @@ def test_multi_dataloader_pipelining(self) -> None: ) +class TrainPipelinePreprocTest(TrainPipelineSparseDistTestBase): + def setUp(self) -> None: + super().setUp() + self.num_batches = 10 + self.batch_size = 32 + self.sharding_type = ShardingType.TABLE_WISE.value + self.kernel_type = EmbeddingComputeKernel.FUSED.value + self.fused_params = {} + + def _check_output_equal( + self, + model: torch.nn.Module, + 
sharding_type: str, + max_feature_lengths: Optional[List[int]] = None, + ) -> Tuple[nn.Module, TrainPipelineSparseDist[ModelInput, torch.Tensor]]: + data = self._generate_data( + num_batches=self.num_batches, + batch_size=self.batch_size, + max_feature_lengths=max_feature_lengths, + ) + dataloader = iter(data) + + sharded_model, optim = self._generate_sharded_model_and_optimizer( + model, sharding_type, self.kernel_type, self.fused_params + ) + + ( + sharded_model_pipelined, + optim_pipelined, + ) = self._generate_sharded_model_and_optimizer( + model, sharding_type, self.kernel_type, self.fused_params + ) + copy_state_dict( + sharded_model.state_dict(), sharded_model_pipelined.state_dict() + ) + + pipeline = TrainPipelineSparseDist( + model=sharded_model_pipelined, + optimizer=optim_pipelined, + device=self.device, + execute_all_batches=True, + pipeline_preproc=True, + ) + + for i in range(self.num_batches): + batch = data[i] + # Forward + backward w/o pipelining + batch = batch.to(self.device) + optim.zero_grad() + loss, pred = sharded_model(batch) + loss.backward() + optim.step() + + pred_pipelined = pipeline.progress(dataloader) + self.assertTrue(torch.equal(pred, pred_pipelined)) + + return sharded_model_pipelined, pipeline + + # pyre-ignore + @unittest.skipIf( + not torch.cuda.is_available(), + "Not enough GPUs, this test requires at least one GPU", + ) + def test_pipeline_modules_share_preproc(self) -> None: + """ + Setup: preproc module takes in input batch and returns modified + input batch. EBC and weighted EBC inside model sparse arch subsequently + uses this modified KJT. + + Test case where single preproc module is shared by multiple sharded modules + and output of preproc module needs to be transformed in the SAME way + """ + + extra_input = ModelInput.generate( + tables=self.tables, + weighted_tables=self.weighted_tables, + batch_size=self.batch_size, + world_size=1, + num_float_features=10, + randomize_indices=False, + )[0].to(self.device) + + preproc_module = TestNegSamplingModule( + extra_input=extra_input, + ) + model = self._setup_model(preproc_module=preproc_module) + + pipelined_model, pipeline = self._check_output_equal( + model, + self.sharding_type, + ) + + # Check that both EC and EBC pipelined + self.assertEqual(len(pipeline._pipelined_modules), 2) + self.assertEqual(len(pipeline._pipelined_preprocs), 1) + + # pyre-ignore + @unittest.skipIf( + not torch.cuda.is_available(), + "Not enough GPUs, this test requires at least one GPU", + ) + def test_pipeline_preproc_not_shared_with_arg_transform(self) -> None: + """ + Test case where arguments to preproc module is some non-modifying + transformation of the input batch (no nested preproc modules) AND + arguments to multiple sharded modules can be derived from the output + of different preproc modules (i.e. preproc modules not shared). 
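+
+        Sketch of the model under test (see TestModelWithPreproc.forward above), for reference:
+            ebc(preproc_nonweighted(input.idlist_features)[0])
+            weighted_ebc(preproc_weighted(input.idscore_features)[0])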
+ """ + model = TestModelWithPreproc( + tables=self.tables[:-1], # ignore last table as preproc will remove + weighted_tables=self.weighted_tables[:-1], # ignore last table + device=self.device, + ) + + pipelined_model, pipeline = self._check_output_equal( + model, + self.sharding_type, + ) + + # Check that both EBC and weighted EBC pipelined + self.assertEqual(len(pipeline._pipelined_modules), 2) + + pipelined_ebc = pipeline._pipelined_modules[0] + pipelined_weighted_ebc = pipeline._pipelined_modules[1] + + # Check pipelined args + for ebc in [pipelined_ebc, pipelined_weighted_ebc]: + self.assertEqual(len(ebc.forward._args), 1) + self.assertEqual(ebc.forward._args[0].input_attrs, ["", 0]) + self.assertEqual(ebc.forward._args[0].is_getitems, [False, True]) + self.assertEqual(len(ebc.forward._args[0].preproc_modules), 2) + self.assertIsInstance( + ebc.forward._args[0].preproc_modules[0], PipelinedPreproc + ) + self.assertEqual(ebc.forward._args[0].preproc_modules[1], None) + + self.assertEqual( + pipelined_ebc.forward._args[0].preproc_modules[0], + pipelined_model.module.preproc_nonweighted, + ) + self.assertEqual( + pipelined_weighted_ebc.forward._args[0].preproc_modules[0], + pipelined_model.module.preproc_weighted, + ) + + # preproc args + self.assertEqual(len(pipeline._pipelined_preprocs), 2) + for i, input_attr_name in [(0, "idlist_features"), (1, "idscore_features")]: + preproc_mod = pipeline._pipelined_preprocs[i] + self.assertEqual(len(preproc_mod._args), 1) + self.assertEqual(preproc_mod._args[0].input_attrs, ["", input_attr_name]) + self.assertEqual(preproc_mod._args[0].is_getitems, [False, False]) + # no parent preproc module in FX graph + self.assertEqual(preproc_mod._args[0].preproc_modules, [None, None]) + + # pyre-ignore + @unittest.skipIf( + not torch.cuda.is_available(), + "Not enough GPUs, this test requires at least one GPU", + ) + def test_pipeline_preproc_recursive(self) -> None: + """ + Test recursive case where multiple arguments to preproc module is derived + from output of another preproc module. 
For example, + + out_a, out_b, out_c = preproc_1(input) + out_d = preproc_2(out_a, out_b) + # do something with out_c + out = ebc(out_d) + """ + extra_input = ModelInput.generate( + tables=self.tables[:-1], + weighted_tables=self.weighted_tables[:-1], + batch_size=self.batch_size, + world_size=1, + num_float_features=10, + randomize_indices=False, + )[0].to(self.device) + + preproc_module = TestNegSamplingModule( + extra_input=extra_input, + ) + + model = TestModelWithPreproc( + tables=self.tables[:-1], + weighted_tables=self.weighted_tables[:-1], + device=self.device, + preproc_module=preproc_module, + ) + + pipelined_model, pipeline = self._check_output_equal(model, self.sharding_type) + + # Check that both EBC and weighted EBC pipelined + self.assertEqual(len(pipeline._pipelined_modules), 2) + + pipelined_ebc = pipeline._pipelined_modules[0] + pipelined_weighted_ebc = pipeline._pipelined_modules[1] + + # Check pipelined args + for ebc in [pipelined_ebc, pipelined_weighted_ebc]: + self.assertEqual(len(ebc.forward._args), 1) + self.assertEqual(ebc.forward._args[0].input_attrs, ["", 0]) + self.assertEqual(ebc.forward._args[0].is_getitems, [False, True]) + self.assertEqual(len(ebc.forward._args[0].preproc_modules), 2) + self.assertIsInstance( + ebc.forward._args[0].preproc_modules[0], PipelinedPreproc + ) + self.assertEqual(ebc.forward._args[0].preproc_modules[1], None) + + self.assertEqual( + pipelined_ebc.forward._args[0].preproc_modules[0], + pipelined_model.module.preproc_nonweighted, + ) + self.assertEqual( + pipelined_weighted_ebc.forward._args[0].preproc_modules[0], + pipelined_model.module.preproc_weighted, + ) + + # preproc args + self.assertEqual(len(pipeline._pipelined_preprocs), 3) + + parent_preproc_mod = pipelined_model.module._preproc_module + + for preproc_mod in pipeline._pipelined_preprocs: + if preproc_mod == pipelined_model.module.preproc_nonweighted: + self.assertEqual(len(preproc_mod._args), 1) + args = preproc_mod._args[0] + self.assertEqual(args.input_attrs, ["", "idlist_features"]) + self.assertEqual(args.is_getitems, [False, False]) + self.assertEqual(len(args.preproc_modules), 2) + self.assertEqual( + args.preproc_modules[0], + parent_preproc_mod, + ) + self.assertEqual(args.preproc_modules[1], None) + elif preproc_mod == pipelined_model.module.preproc_weighted: + self.assertEqual(len(preproc_mod._args), 1) + args = preproc_mod._args[0] + self.assertEqual(args.input_attrs, ["", "idscore_features"]) + self.assertEqual(args.is_getitems, [False, False]) + self.assertEqual(len(args.preproc_modules), 2) + self.assertEqual( + args.preproc_modules[0], + parent_preproc_mod, + ) + self.assertEqual(args.preproc_modules[1], None) + elif preproc_mod == parent_preproc_mod: + self.assertEqual(len(preproc_mod._args), 1) + args = preproc_mod._args[0] + self.assertEqual(args.input_attrs, [""]) + self.assertEqual(args.is_getitems, [False]) + self.assertEqual(args.preproc_modules, [None]) + + # pyre-ignore + @unittest.skipIf( + not torch.cuda.is_available(), + "Not enough GPUs, this test requires at least one GPU", + ) + def test_pipeline_invalid_preproc_inputs_has_trainable_params(self) -> None: + """ + Test case where preproc module sits in front of sharded module but this cannot be + safely pipelined as it contains trainable params in its child modules + """ + max_feature_lengths = { + "feature_0": 10, + "feature_1": 10, + "feature_2": 10, + "feature_3": 10, + } + + preproc_module = TestPositionWeightedPreprocModule( + max_feature_lengths=max_feature_lengths, + 
device=self.device, + ) + + model = self._setup_model(preproc_module=preproc_module) + + ( + sharded_model_pipelined, + optim_pipelined, + ) = self._generate_sharded_model_and_optimizer( + model, self.sharding_type, self.kernel_type, self.fused_params + ) + + pipeline = TrainPipelineSparseDist( + model=sharded_model_pipelined, + optimizer=optim_pipelined, + device=self.device, + execute_all_batches=True, + pipeline_preproc=True, + ) + + data = self._generate_data( + num_batches=self.num_batches, + batch_size=self.batch_size, + max_feature_lengths=list(max_feature_lengths.values()), + ) + dataloader = iter(data) + + pipeline.progress(dataloader) + + # Check that no modules are pipelined + self.assertEqual(len(pipeline._pipelined_modules), 0) + self.assertEqual(len(pipeline._pipelined_preprocs), 0) + + # pyre-ignore + @unittest.skipIf( + not torch.cuda.is_available(), + "Not enough GPUs, this test requires at least one GPU", + ) + def test_pipeline_invalid_preproc_trainable_params_recursive( + self, + ) -> None: + max_feature_lengths = { + "feature_0": 10, + "feature_1": 10, + "feature_2": 10, + "feature_3": 10, + } + + preproc_module = TestPositionWeightedPreprocModule( + max_feature_lengths=max_feature_lengths, + device=self.device, + ) + + model = TestModelWithPreproc( + tables=self.tables[:-1], + weighted_tables=self.weighted_tables[:-1], + device=self.device, + preproc_module=preproc_module, + ) + + ( + sharded_model_pipelined, + optim_pipelined, + ) = self._generate_sharded_model_and_optimizer( + model, self.sharding_type, self.kernel_type, self.fused_params + ) + + pipeline = TrainPipelineSparseDist( + model=sharded_model_pipelined, + optimizer=optim_pipelined, + device=self.device, + execute_all_batches=True, + pipeline_preproc=True, + ) + + data = self._generate_data( + num_batches=self.num_batches, + batch_size=self.batch_size, + max_feature_lengths=list(max_feature_lengths.values()), + ) + dataloader = iter(data) + pipeline.progress(dataloader) + + # Check that no modules are pipelined + self.assertEqual(len(pipeline._pipelined_modules), 0) + self.assertEqual(len(pipeline._pipelined_preprocs), 0) + + # pyre-ignore + @unittest.skipIf( + not torch.cuda.is_available(), + "Not enough GPUs, this test requires at least one GPU", + ) + def test_pipeline_invalid_preproc_inputs_modify_kjt_recursive(self) -> None: + """ + Test case where preproc module cannot be pipelined because at least one of args + is derived from output of another preproc module whose arg(s) cannot be derived + from input batch (i.e. 
it has modifying transformations) + """ + model = TestModelWithPreproc( + tables=self.tables[:-1], + weighted_tables=self.weighted_tables[:-1], + device=self.device, + preproc_module=None, + run_preproc_inline=True, # run preproc inline, outside a module + ) + + ( + sharded_model_pipelined, + optim_pipelined, + ) = self._generate_sharded_model_and_optimizer( + model, self.sharding_type, self.kernel_type, self.fused_params + ) + + pipeline = TrainPipelineSparseDist( + model=sharded_model_pipelined, + optimizer=optim_pipelined, + device=self.device, + execute_all_batches=True, + pipeline_preproc=True, + ) + + data = self._generate_data( + num_batches=self.num_batches, + batch_size=self.batch_size, + ) + dataloader = iter(data) + pipeline.progress(dataloader) + + # Check that only weighted EBC is pipelined + self.assertEqual(len(pipeline._pipelined_modules), 1) + self.assertEqual(len(pipeline._pipelined_preprocs), 1) + self.assertEqual(pipeline._pipelined_modules[0]._is_weighted, True) + self.assertEqual( + pipeline._pipelined_preprocs[0], + sharded_model_pipelined.module.preproc_weighted, + ) + + # pyre-ignore + @unittest.skipIf( + not torch.cuda.is_available(), + "Not enough GPUs, this test requires at least one GPU", + ) + def test_pipeline_preproc_fwd_values_cached(self) -> None: + """ + Test to check that during model forward, the preproc module pipelined uses the + saved result from previous iteration(s) and doesn't perform duplicate work + check that fqns for ALL preproc modules are populated in the right train pipeline + context. + """ + extra_input = ModelInput.generate( + tables=self.tables[:-1], + weighted_tables=self.weighted_tables[:-1], + batch_size=self.batch_size, + world_size=1, + num_float_features=10, + randomize_indices=False, + )[0].to(self.device) + + preproc_module = TestNegSamplingModule( + extra_input=extra_input, + ) + + model = TestModelWithPreproc( + tables=self.tables[:-1], + weighted_tables=self.weighted_tables[:-1], + device=self.device, + preproc_module=preproc_module, + ) + + ( + sharded_model_pipelined, + optim_pipelined, + ) = self._generate_sharded_model_and_optimizer( + model, self.sharding_type, self.kernel_type, self.fused_params + ) + + pipeline = TrainPipelineSparseDist( + model=sharded_model_pipelined, + optimizer=optim_pipelined, + device=self.device, + execute_all_batches=True, + pipeline_preproc=True, + ) + + data = self._generate_data( + num_batches=self.num_batches, + batch_size=self.batch_size, + ) + dataloader = iter(data) + + pipeline.progress(dataloader) + + # This was second context that was appended + current_context = pipeline.contexts[0] + cached_results = current_context.preproc_fwd_results + self.assertEqual( + list(cached_results.keys()), + ["_preproc_module", "preproc_nonweighted", "preproc_weighted"], + ) + + # next context cached results should be empty + next_context = pipeline.contexts[1] + next_cached_results = next_context.preproc_fwd_results + self.assertEqual(len(next_cached_results), 0) + + # After progress, next_context should be populated + pipeline.progress(dataloader) + self.assertEqual( + list(next_cached_results.keys()), + ["_preproc_module", "preproc_nonweighted", "preproc_weighted"], + ) + + class EmbeddingTrainPipelineTest(TrainPipelineSparseDistTestBase): @unittest.skipIf( not torch.cuda.is_available(), diff --git a/torchrec/distributed/train_pipeline/tests/test_train_pipelines_base.py b/torchrec/distributed/train_pipeline/tests/test_train_pipelines_base.py index 8317f2354..6ca45371a 100644 --- 
a/torchrec/distributed/train_pipeline/tests/test_train_pipelines_base.py +++ b/torchrec/distributed/train_pipeline/tests/test_train_pipelines_base.py @@ -93,12 +93,14 @@ def _setup_model( self, model_type: Type[nn.Module] = TestSparseNN, enable_fsdp: bool = False, + preproc_module: Optional[nn.Module] = None, ) -> nn.Module: unsharded_model = model_type( tables=self.tables, weighted_tables=self.weighted_tables, dense_device=self.device, sparse_device=torch.device("meta"), + preproc_module=preproc_module, ) if enable_fsdp: unsharded_model.over.dhn_arch.linear0 = FSDP( diff --git a/torchrec/distributed/train_pipeline/tests/test_train_pipelines_utils.py b/torchrec/distributed/train_pipeline/tests/test_train_pipelines_utils.py index 854423385..f8dcf08fb 100644 --- a/torchrec/distributed/train_pipeline/tests/test_train_pipelines_utils.py +++ b/torchrec/distributed/train_pipeline/tests/test_train_pipelines_utils.py @@ -8,14 +8,94 @@ # pyre-strict import unittest +from unittest.mock import MagicMock import torch -from torchrec.distributed.train_pipeline.utils import _get_node_args +from torchrec.distributed.embedding_types import EmbeddingComputeKernel +from torchrec.distributed.test_utils.test_model import ModelInput, TestNegSamplingModule + +from torchrec.distributed.train_pipeline.tests.test_train_pipelines_base import ( + TrainPipelineSparseDistTestBase, +) +from torchrec.distributed.train_pipeline.utils import ( + _get_node_args, + _rewrite_model, + PipelinedForward, + TrainPipelineContext, +) +from torchrec.distributed.types import ShardingType from torchrec.sparse.jagged_tensor import KeyedJaggedTensor -class TestTrainPipelineUtils(unittest.TestCase): +class TrainPipelineUtilsTest(TrainPipelineSparseDistTestBase): + # pyre-fixme[56]: Pyre was not able to infer the type of argument + @unittest.skipIf( + not torch.cuda.is_available(), + "Not enough GPUs, this test requires at least one GPU", + ) + def test_rewrite_model(self) -> None: + sharding_type = ShardingType.TABLE_WISE.value + kernel_type = EmbeddingComputeKernel.FUSED.value + fused_params = {} + + extra_input = ModelInput.generate( + tables=self.tables, + weighted_tables=self.weighted_tables, + batch_size=10, + world_size=1, + num_float_features=10, + randomize_indices=False, + )[0].to(self.device) + + preproc_module = TestNegSamplingModule( + extra_input=extra_input, + ) + model = self._setup_model(preproc_module=preproc_module) + + sharded_model, optim = self._generate_sharded_model_and_optimizer( + model, sharding_type, kernel_type, fused_params + ) + + # Try to rewrite model without ignored_preproc_modules defined, EBC forwards not overwritten to PipelinedForward due to KJT modification + _rewrite_model( + model=sharded_model, + batch=None, + context=TrainPipelineContext(), + dist_stream=None, + ) + self.assertNotIsInstance( + sharded_model.module.sparse.ebc.forward, PipelinedForward + ) + self.assertNotIsInstance( + sharded_model.module.sparse.weighted_ebc.forward, PipelinedForward + ) + + # Now provide preproc module explicitly + _rewrite_model( + model=sharded_model, + batch=None, + context=TrainPipelineContext(), + dist_stream=None, + pipeline_preproc=True, + ) + self.assertIsInstance(sharded_model.module.sparse.ebc.forward, PipelinedForward) + self.assertIsInstance( + sharded_model.module.sparse.weighted_ebc.forward, PipelinedForward + ) + self.assertEqual( + sharded_model.module.sparse.ebc.forward._args[0].preproc_modules[0], + sharded_model.module.preproc_module, + ) + self.assertEqual( + 
sharded_model.module.sparse.weighted_ebc.forward._args[0].preproc_modules[ + 0 + ], + sharded_model.module.preproc_module, + ) + + +class TestUtils(unittest.TestCase): def test_get_node_args_helper_call_module_kjt(self) -> None: graph = torch.fx.Graph() kjt_args = [] @@ -42,7 +122,9 @@ def test_get_node_args_helper_call_module_kjt(self) -> None: ) num_found = 0 - _, num_found = _get_node_args(kjt_node) + _, num_found = _get_node_args( + MagicMock(), kjt_node, set(), TrainPipelineContext(), False + ) # Weights is call_module node, so we should only find 2 args unmodified self.assertEqual(num_found, len(kjt_args) - 1) diff --git a/torchrec/distributed/train_pipeline/train_pipelines.py b/torchrec/distributed/train_pipeline/train_pipelines.py index 4983c7b4c..b83e0ada8 100644 --- a/torchrec/distributed/train_pipeline/train_pipelines.py +++ b/torchrec/distributed/train_pipeline/train_pipelines.py @@ -45,6 +45,7 @@ In, Out, PipelinedForward, + PipelinedPreproc, PipelineStage, PrefetchPipelinedForward, PrefetchTrainPipelineContext, @@ -301,6 +302,7 @@ def __init__( execute_all_batches: bool = True, apply_jit: bool = False, context_type: Type[TrainPipelineContext] = TrainPipelineContext, + pipeline_preproc: bool = False, ) -> None: self._model = model self._optimizer = optimizer @@ -342,10 +344,12 @@ def __init__( ] = [] self._model_attached = True + self._pipeline_preproc = pipeline_preproc self._next_index: int = 0 self.contexts: Deque[TrainPipelineContext] = deque() self._pipelined_modules: List[ShardedModule] = [] + self._pipelined_preprocs: List[PipelinedPreproc] = [] self.batches: Deque[Optional[In]] = deque() self._dataloader_iter: Optional[Iterator[In]] = None self._dataloader_exhausted: bool = False @@ -397,6 +401,10 @@ def _set_module_context(self, context: TrainPipelineContext) -> None: for module in self._pipelined_modules: module.forward.set_context(context) + for preproc_module in self._pipelined_preprocs: + # This ensures that next iter model fwd uses cached results + preproc_module.set_context(context) + def enqueue_batch(self, dataloader_iter: Iterator[In]) -> bool: batch, context = self.copy_batch_to_gpu(dataloader_iter) if batch is None: @@ -494,13 +502,19 @@ def _pipeline_model( context: TrainPipelineContext, pipelined_forward: Type[PipelinedForward] = PipelinedForward, ) -> None: - self._pipelined_modules, self._model, self._original_forwards = _rewrite_model( + ( + self._pipelined_modules, + self._model, + self._original_forwards, + self._pipelined_preprocs, + ) = _rewrite_model( model=self._model, context=context, dist_stream=self._data_dist_stream, batch=batch, apply_jit=self._apply_jit, pipelined_forward=pipelined_forward, + pipeline_preproc=self._pipeline_preproc, ) # initializes input dist, so we can override input dist forwards self.start_sparse_data_dist(batch, context) @@ -576,9 +590,22 @@ def start_sparse_data_dist( return with record_function(f"## start_sparse_data_dist {context.index} ##"): with self._stream_context(self._data_dist_stream): + if context.event is not None: + context.event.wait() _wait_for_batch(batch, self._memcpy_stream) + + original_contexts = [p.get_context() for p in self._pipelined_preprocs] + + # Temporarily set context for next iter to populate cache + for preproc_mod in self._pipelined_preprocs: + preproc_mod.set_context(context) + _start_data_dist(self._pipelined_modules, batch, context) + # Restore context for model fwd + for module, context in zip(self._pipelined_preprocs, original_contexts): + module.set_context(context) + def 
wait_sparse_data_dist(self, context: TrainPipelineContext) -> None: """ Waits on the input dist splits requests to get the input dist tensors requests, @@ -680,6 +707,7 @@ def __init__( apply_jit: bool = False, start_batch: int = 900, stash_gradients: bool = False, + pipeline_preproc: bool = False, ) -> None: super().__init__( model=model, @@ -688,6 +716,7 @@ def __init__( execute_all_batches=execute_all_batches, apply_jit=apply_jit, context_type=EmbeddingTrainPipelineContext, + pipeline_preproc=pipeline_preproc, ) self._start_batch = start_batch self._stash_gradients = stash_gradients @@ -853,14 +882,9 @@ def start_sparse_data_dist( """ Waits for batch to finish getting copied to GPU, then starts the input dist. This is Event based version. """ - if batch is None: - return - with record_function(f"## start_sparse_data_dist {context.index} ##"): - with self._stream_context(self._data_dist_stream): - _wait_for_event(batch, self._device, context.event) - _start_data_dist(self._pipelined_modules, batch, context) - context.event = torch.get_device_module(self._device).Event() - context.event.record() + super().start_sparse_data_dist(batch, context) + context.event = torch.get_device_module(self._device).Event() + context.event.record() def start_embedding_lookup( self, diff --git a/torchrec/distributed/train_pipeline/utils.py b/torchrec/distributed/train_pipeline/utils.py index eca8f84a0..0e9631793 100644 --- a/torchrec/distributed/train_pipeline/utils.py +++ b/torchrec/distributed/train_pipeline/utils.py @@ -48,7 +48,6 @@ from torchrec.sparse.jagged_tensor import KeyedJaggedTensor from torchrec.streamable import Multistreamable, Pipelineable - logger: logging.Logger = logging.getLogger(__name__) import torch @@ -97,6 +96,9 @@ class TrainPipelineContext: field(default_factory=list) ) event: Optional[torch.Event] = None + + preproc_fwd_results: Dict[str, Any] = field(default_factory=dict) + index: Optional[int] = None version: int = ( 0 # 1 is current version, 0 is deprecated but supported for backward compatibility @@ -153,9 +155,140 @@ class ArgInfo: input_attrs: List[str] is_getitems: List[bool] + # recursive dataclass as preproc_modules.args -> arginfo.preproc_modules -> so on + preproc_modules: List[Optional["PipelinedPreproc"]] name: Optional[str] +# pyre-ignore +def _build_args_kwargs( + # pyre-ignore + initial_input: Any, + fwd_args: List[ArgInfo], +) -> Tuple[List[Any], Dict[str, Any]]: + args = [] + kwargs = {} + for arg_info in fwd_args: + if arg_info.input_attrs: + arg = initial_input + for attr, is_getitem, preproc_mod in zip( + arg_info.input_attrs, arg_info.is_getitems, arg_info.preproc_modules + ): + if preproc_mod is not None: + # preproc will internally run the same logic recursively + # if its args are derived from other preproc modules + # we can get all inputs to preproc mod based on its recorded args_info + arg passed to it + arg = preproc_mod(arg) + else: + if is_getitem: + arg = arg[attr] + elif attr != "": + arg = getattr(arg, attr) + else: + # neither is_getitem nor valid attr, no-op + arg = arg + if arg_info.name: + kwargs[arg_info.name] = arg + else: + args.append(arg) + else: + args.append(None) + return args, kwargs + + +class PipelinedPreproc(torch.nn.Module): + """ + Wrapper around preproc module found during model graph traversal for sparse data dist + pipelining. In addition to the original module, it encapsulates information needed for + execution such as list of ArgInfo and the current training pipeline context. 
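+    During `_start_data_dist` the wrapped module is run and its output is cached in
+    `context.preproc_fwd_results[fqn]`; the model forward for the same batch then
+    returns this cached result instead of re-running the preproc module.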
+ + Args: + preproc_module (torch.nn.Module): preproc module to run + fqn (str): fqn of the preproc module in the model being pipelined + args (List[ArgInfo]): list of ArgInfo for the preproc module + context (TrainPipelineContext): Training context for the next iteration / batch + + Returns: + Any + + Example: + preproc = PipelinedPreproc(preproc_module, fqn, args, context) + # module-swap with pipeliend preproc + setattr(model, fqn, preproc) + """ + + def __init__( + self, + preproc_module: torch.nn.Module, + fqn: str, + args: List[ArgInfo], + context: TrainPipelineContext, + ) -> None: + super().__init__() + self._preproc_module = preproc_module + self._fqn = fqn + self._args = args + self._context = context + + @property + def preproc_module(self) -> torch.nn.Module: + return self._preproc_module + + @property + def fqn(self) -> str: + return self._fqn + + # pyre-ignore + def forward(self, *input, **kwargs) -> Any: + """ + Args: + Any args and kwargs during model fwd + During _start_data_dist, input[0] contains the current data + Returns: + Any + """ + if self._fqn in self._context.preproc_fwd_results: + # This should only be hit in two cases: + # 1) During model forward + # During model forward, avoid duplicate work + # by returning the cached result from previous + # iteration's _start_data_dist + # 2) During _start_data_dist when preproc module is + # shared by more than one args. e.g. if we have + # preproc_out_a = preproc_a(input) + # preproc_out_b = preproc_b(preproc_out_a) <- preproc_a shared + # preproc_out_c = preproc_c(preproc_out_a) <-^ + # When processing preproc_b, we cache value of preproc_a(input) + # so when processing preproc_c, we can reuse preproc_a(input) + res = self._context.preproc_fwd_results[self._fqn] + return res + + # Everything below should only be called during _start_data_dist stage + + # Build up arg and kwargs from recursive call to pass to preproc module + # Arguments to preproc module can be also be a derived product + # of another preproc module call, as long as module is pipelineable + + # Use input[0] as _start_data_dist only passes 1 arg + args, kwargs = _build_args_kwargs(input[0], self._args) + + with record_function(f"## sdd_input_preproc {self._context.index} ##"): + res = self._preproc_module(*args, **kwargs) + # Cache results, only during _start_data_dist + self._context.preproc_fwd_results[self._fqn] = res + return res + + @property + def args(self) -> List[ArgInfo]: + return self._args + + def set_context(self, context: TrainPipelineContext) -> None: + self._context = context + + def get_context(self) -> TrainPipelineContext: + return self._context + + class BaseForward: def __init__( self, @@ -258,7 +391,6 @@ def __init__( # pyre-ignore [2, 24] def __call__(self, *input, **kwargs) -> Awaitable: - assert ( self._name # pyre-ignore [16] @@ -422,22 +554,8 @@ def _start_data_dist( # False means this argument is getting while getattr # and this info was done in the _rewrite_model by tracing the # entire model to get the arg_info_list - args = [] - kwargs = {} - for arg_info in forward.args: - if arg_info.input_attrs: - arg = batch - for attr, is_getitem in zip(arg_info.input_attrs, arg_info.is_getitems): - if is_getitem: - arg = arg[attr] - else: - arg = getattr(arg, attr) - if arg_info.name: - kwargs[arg_info.name] = arg - else: - args.append(arg) - else: - args.append(None) + args, kwargs = _build_args_kwargs(batch, forward.args) + # Start input distribution. 
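+        # Illustrative example of the call above: an ArgInfo with
+        # input_attrs=["", "idlist_features"], is_getitems=[False, False],
+        # preproc_modules=[None, None] resolves to `batch.idlist_features`;
+        # a PipelinedPreproc entry instead feeds the current value through that
+        # preproc module (reusing its cached output when present) before the
+        # remaining attrs/getitems are applied.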
module_ctx = module.create_context() if context.version == 0: @@ -522,17 +640,32 @@ def _check_args_for_call_module( return False +def _check_preproc_pipelineable( + module: torch.nn.Module, +) -> bool: + for _, _ in module.named_parameters(recurse=True): + # Cannot have any trainable params for it to be pipelined + logger.warning( + f"Module {module} cannot be pipelined as it has trainable parameters" + ) + return False + return True + + def _get_node_args_helper( + model: torch.nn.Module, # pyre-ignore arguments, num_found: int, + pipelined_preprocs: Set[PipelinedPreproc], + context: TrainPipelineContext, + pipeline_preproc: bool, ) -> Tuple[List[ArgInfo], int]: """ Goes through the args/kwargs of a node and arranges them into a list of `ArgInfo`s. It also counts the number of (args + kwargs) found. """ - - arg_info_list = [ArgInfo([], [], None) for _ in range(len(arguments))] + arg_info_list = [ArgInfo([], [], [], None) for _ in range(len(arguments))] for arg, arg_info in zip(arguments, arg_info_list): if arg is None: num_found += 1 @@ -547,6 +680,13 @@ def _get_node_args_helper( # pyre-ignore[16] arg_info.input_attrs.insert(0, child_node.ph_key) arg_info.is_getitems.insert(0, False) + arg_info.preproc_modules.insert(0, None) + else: + # no-op + arg_info.input_attrs.insert(0, "") + arg_info.is_getitems.insert(0, False) + arg_info.preproc_modules.insert(0, None) + num_found += 1 break elif ( @@ -561,6 +701,7 @@ def _get_node_args_helper( # memory_format, Tensor, typing.Tuple[typing.Any, ...]]`. arg_info.input_attrs.insert(0, child_node.args[1]) arg_info.is_getitems.insert(0, False) + arg_info.preproc_modules.insert(0, None) arg = child_node.args[0] elif ( child_node.op == "call_function" @@ -574,6 +715,7 @@ def _get_node_args_helper( # memory_format, Tensor, typing.Tuple[typing.Any, ...]]`. arg_info.input_attrs.insert(0, child_node.args[1]) arg_info.is_getitems.insert(0, True) + arg_info.preproc_modules.insert(0, None) arg = child_node.args[0] elif ( child_node.op == "call_function" @@ -610,18 +752,109 @@ def _get_node_args_helper( arg = child_node.kwargs["values"] else: arg = child_node.args[1] + elif child_node.op == "call_module": + preproc_module_fqn = str(child_node.target) + preproc_module = getattr(model, preproc_module_fqn, None) + + if not pipeline_preproc: + logger.warning( + f"Found module {preproc_module} that potentially modifies KJ. Train pipeline initialized with `pipeline_preproc=False` (default), so we assume KJT input modification. 
To allow torchrec to check if this module can be safely pipelined, please set `pipeline_preproc=True`" + ) + break + + if not preproc_module: + # Could not find such module, should not happen + break + + if isinstance(preproc_module, PipelinedPreproc): + # Already did module swap and registered args, early exit + arg_info.input_attrs.insert(0, "") # dummy value + arg_info.is_getitems.insert(0, False) + pipelined_preprocs.add(preproc_module) + arg_info.preproc_modules.insert(0, preproc_module) + num_found += 1 + break + + if not isinstance(preproc_module, torch.nn.Module): + logger.warning( + f"Expected preproc_module to be nn.Module but was {type(preproc_module)}" + ) + break + + # check if module is safe to pipeline i.e.no trainable param + if not _check_preproc_pipelineable(preproc_module): + break + + # For module calls, `self` isn't counted + total_num_args = len(child_node.args) + len(child_node.kwargs) + if total_num_args == 0: + # module call without any args, assume KJT modified + break + + # recursive call to check that all inputs to this preproc module + # is either made of preproc module or non-modifying train batch input + # transformations + preproc_args, num_found_safe_preproc_args = _get_node_args( + model, child_node, pipelined_preprocs, context, pipeline_preproc + ) + if num_found_safe_preproc_args == total_num_args: + logger.info( + f"""Module {preproc_module} is a valid preproc module (no + trainable params and inputs can be derived from train batch input + via a series of either valid preproc modules or non-modifying + transformations) and will be applied during sparse data dist + stage""" + ) + + pipelined_preproc_module = PipelinedPreproc( + preproc_module, + preproc_module_fqn, + preproc_args, + context, + ) + + # module swap + setattr(model, preproc_module_fqn, pipelined_preproc_module) + + arg_info.input_attrs.insert(0, "") # dummy value + arg_info.is_getitems.insert(0, False) + pipelined_preprocs.add(pipelined_preproc_module) + arg_info.preproc_modules.insert(0, pipelined_preproc_module) + + num_found += 1 + + # we cannot set any other `arg` value here + # break to avoid infinite loop + break else: break return arg_info_list, num_found def _get_node_args( + model: torch.nn.Module, node: Node, + pipelined_preprocs: Set[PipelinedPreproc], + context: TrainPipelineContext, + pipeline_preproc: bool, ) -> Tuple[List[ArgInfo], int]: num_found = 0 - pos_arg_info_list, num_found = _get_node_args_helper(node.args, num_found) + + pos_arg_info_list, num_found = _get_node_args_helper( + model, + node.args, + num_found, + pipelined_preprocs, + context, + pipeline_preproc, + ) kwargs_arg_info_list, num_found = _get_node_args_helper( - node.kwargs.values(), num_found + model, + node.kwargs.values(), + num_found, + pipelined_preprocs, + context, + pipeline_preproc, ) # Replace with proper names for kwargs @@ -629,7 +862,8 @@ def _get_node_args( arg_info_list.name = name arg_info_list = pos_arg_info_list + kwargs_arg_info_list - return arg_info_list, num_found + + return (arg_info_list, num_found) def _get_leaf_module_names_helper( @@ -744,7 +978,13 @@ def _rewrite_model( # noqa C901 batch: Optional[In] = None, apply_jit: bool = False, pipelined_forward: Type[BaseForward] = PipelinedForward, -) -> Tuple[List[ShardedModule], torch.nn.Module, List[Callable[..., Any]]]: + pipeline_preproc: bool = False, +) -> Tuple[ + List[ShardedModule], + torch.nn.Module, + List[Callable[..., Any]], + List[PipelinedPreproc], +]: input_model = model # Get underlying nn.Module if isinstance(model, 
DistributedModelParallel): @@ -781,12 +1021,21 @@ def _rewrite_model( # noqa C901 # i.e. don't have input transformations, i.e. rely only on 'builtins.getattr'. pipelined_forwards = [] original_forwards = [] + + pipelined_preprocs: Set[PipelinedPreproc] = set() + for node in graph.nodes: if node.op == "call_module" and node.target in sharded_modules: total_num_args = len(node.args) + len(node.kwargs) if total_num_args == 0: continue - arg_info_list, num_found = _get_node_args(node) + arg_info_list, num_found = _get_node_args( + model, + node, + pipelined_preprocs, + context, + pipeline_preproc, + ) if num_found == total_num_args: logger.info(f"Module '{node.target}'' will be pipelined") @@ -812,7 +1061,7 @@ def _rewrite_model( # noqa C901 if isinstance(input_model, DistributedModelParallel): input_model.module = graph_model - return pipelined_forwards, input_model, original_forwards + return pipelined_forwards, input_model, original_forwards, list(pipelined_preprocs) def _override_input_dist_forwards( @@ -985,7 +1234,8 @@ def detach(self) -> torch.nn.Module: def start_sparse_data_dist(self, batch: In) -> In: if not self.initialized: # Step 1: Pipeline input dist in trec sharded modules - self._pipelined_modules, self.model, self._original_forwards = ( + # TODO (yhshin): support preproc modules for `StagedTrainPipeline` + self._pipelined_modules, self.model, self._original_forwards, _ = ( _rewrite_model( model=self.model, context=self.context,