diff --git a/torchrec/distributed/tests/test_train_pipeline.py b/torchrec/distributed/tests/test_train_pipeline.py
index 3a09b39a7..1dd0b1209 100644
--- a/torchrec/distributed/tests/test_train_pipeline.py
+++ b/torchrec/distributed/tests/test_train_pipeline.py
@@ -9,7 +9,7 @@
 import os
 import unittest
 from dataclasses import dataclass
-from typing import cast, Dict, List, Optional, Tuple
+from typing import Any, cast, Dict, List, Optional, Tuple
 from unittest.mock import MagicMock
 
 import torch
@@ -17,6 +17,7 @@
 from hypothesis import given, settings, strategies as st, Verbosity
 from torch import nn, optim
 from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
+from torch.optim import Optimizer
 from torchrec.distributed import DistributedModelParallel
 from torchrec.distributed.embedding_types import EmbeddingComputeKernel, KJTList
 from torchrec.distributed.embeddingbag import (
@@ -41,6 +42,7 @@
 from torchrec.distributed.train_pipeline import (
     DataLoadingThread,
     EvalPipelineSparseDist,
+    PrefetchTrainPipelineSparseDist,
     TrainPipelineBase,
     TrainPipelineSparseDist,
 )
@@ -52,7 +54,7 @@
     ShardingPlan,
     ShardingType,
 )
-from torchrec.modules.embedding_configs import EmbeddingBagConfig
+from torchrec.modules.embedding_configs import DataType, EmbeddingBagConfig
 from torchrec.modules.embedding_modules import EmbeddingBagCollection
 from torchrec.optim.keyed import KeyedOptimizerWrapper
 
@@ -177,10 +179,12 @@ def test_equal_to_non_pipelined(self) -> None:
             pred_gpu = pipeline.progress(dataloader)
 
             self.assertEqual(pred_gpu.device, self.device)
+            # Results will be close but not exactly equal as one model is on CPU and other on GPU
+            # If both were on GPU, the results will be exactly the same
             self.assertTrue(torch.isclose(pred_gpu.cpu(), pred))
 
 
-class TrainPipelineSparseDistTest(unittest.TestCase):
+class TrainPipelineSparseDistTestBase(unittest.TestCase):
     def setUp(self) -> None:
         os.environ["MASTER_ADDR"] = str("localhost")
         os.environ["MASTER_PORT"] = str(get_free_port())
@@ -191,7 +195,6 @@ def setUp(self) -> None:
 
         num_features = 4
         num_weighted_features = 2
-
         self.tables = [
             EmbeddingBagConfig(
                 num_embeddings=(i + 1) * 100,
@@ -217,6 +220,33 @@ def tearDown(self) -> None:
         super().tearDown()
         dist.destroy_process_group(self.pg)
 
+    def _generate_data(
+        self,
+        num_batches: int = 5,
+        batch_size: int = 1,
+    ) -> List[ModelInput]:
+        """Build ``num_batches`` inputs via ``ModelInput.generate`` (world_size=1)."""
+        return [
+            ModelInput.generate(
+                tables=self.tables,
+                weighted_tables=self.weighted_tables,
+                batch_size=batch_size,
+                world_size=1,
+                num_float_features=10,
+            )[0]
+            for _ in range(num_batches)
+        ]
+
+    def _set_table_weights_precision(self, dtype: DataType) -> None:
+        """Set ``data_type`` on every config in ``tables`` and ``weighted_tables``."""
+        for table in self.tables:
+            table.data_type = dtype
+
+        for weighted_table in self.weighted_tables:
+            weighted_table.data_type = dtype
+
+
+class TrainPipelineSparseDistTest(TrainPipelineSparseDistTestBase):
     # pyre-fixme[56]: Pyre was not able to infer the type of argument
     @unittest.skipIf(
         not torch.cuda.is_available(),
@@ -336,7 +366,7 @@ def forward(self, model_input) -> Tuple[torch.Tensor, torch.Tensor]:
             optimizer_no_pipeline.step()
 
             pred_pipeline = pipeline.progress(dataloader)
-            torch.testing.assert_close(pred_pipeline.cpu(), pred.cpu())
+            torch.testing.assert_close(pred_pipeline.cpu(), pred.cpu(), rtol=0, atol=0)
 
         self.assertEqual(len(pipeline._pipelined_modules), 1)
         self.assertIsInstance(
@@ -405,43 +435,6 @@ def _setup_cpu_model_and_opt(self) -> Tuple[TestSparseNN, optim.SGD]:
         cpu_optimizer = optim.SGD(cpu_model.parameters(), lr=0.1)
         return cpu_model, cpu_optimizer
 
-    def _generate_data(self, num_batches: int = 5) -> List[ModelInput]:
-        return [
-            ModelInput.generate(
-                tables=self.tables,
-                weighted_tables=self.weighted_tables,
-                batch_size=1,
-                world_size=1,
-                num_float_features=10,
-            )[0]
-            for i in range(num_batches)
-        ]
-
-    def _test_pipelining(
-        self,
-        sharder: EmbeddingBagCollectionSharder,
-        execute_all_batches: bool,
-    ) -> None:
-        pipeline = self._setup_pipeline(sharder, execute_all_batches)
-        cpu_model, cpu_optimizer = self._setup_cpu_model_and_opt()
-        data = self._generate_data()
-
-        dataloader = iter(data)
-        if not execute_all_batches:
-            data = data[:-2]
-
-        for batch in data:
-            cpu_optimizer.zero_grad()
-            loss, pred = cpu_model(batch)
-            loss.backward()
-            cpu_optimizer.step()
-
-            pred_gpu = pipeline.progress(dataloader)
-
-            self.assertEqual(len(pipeline._pipelined_modules), 2)
-            self.assertEqual(pred_gpu.device, self.device)
-            self.assertEqual(pred_gpu.cpu().size(), pred.size())
-
     @unittest.skipIf(
         not torch.cuda.is_available(),
         "Not enough GPUs, this test requires at least one GPU",
@@ -459,7 +452,6 @@ def test_pipelining(self, execute_all_batches: bool) -> None:
         )
         cpu_model, cpu_optimizer = self._setup_cpu_model_and_opt()
         data = self._generate_data()
-
         dataloader = iter(data)
         if not execute_all_batches:
             data = data[:-2]
@@ -574,6 +566,153 @@ def test_multi_dataloader_pipelining(self) -> None:
         )
 
 
+class PrefetchTrainPipelineSparseDistTest(TrainPipelineSparseDistTestBase):
+    def generate_sharded_model_and_optimizer(
+        self,
+        model: nn.Module,
+        sharding_type: str,
+        kernel_type: str,
+        fused_params: Optional[Dict[str, Any]] = None,
+    ) -> Tuple[nn.Module, Optimizer]:
+        """Shard a deep copy of ``model`` and pair it with an SGD optimizer."""
+        sharder = TestEBCSharder(
+            sharding_type=sharding_type,
+            kernel_type=kernel_type,
+            fused_params=fused_params,
+        )
+        sharded_model = DistributedModelParallel(
+            module=copy.deepcopy(model),
+            env=ShardingEnv.from_process_group(self.pg),
+            init_data_parallel=False,
+            device=self.device,
+            sharders=[
+                cast(
+                    ModuleSharder[nn.Module],
+                    sharder,
+                )
+            ],
+        )
+        optimizer = optim.SGD(sharded_model.parameters(), lr=0.1)
+        return sharded_model, optimizer
+
+    @unittest.skipIf(
+        not torch.cuda.is_available(),
+        "Not enough GPUs, this test requires at least one GPU",
+    )
+    # pyre-ignore[56]
+    @given(
+        execute_all_batches=st.booleans(),
+        weight_precision=st.sampled_from(
+            [
+                DataType.FP16,
+                DataType.FP32,
+            ]
+        ),
+        cache_precision=st.sampled_from(
+            [
+                DataType.FP16,
+                DataType.FP32,
+            ]
+        ),
+        load_factor=st.sampled_from(
+            [
+                0.2,
+                0.4,
+                0.6,
+            ]
+        ),
+        sharding_type=st.sampled_from(
+            [
+                ShardingType.TABLE_WISE.value,
+                ShardingType.ROW_WISE.value,
+                ShardingType.COLUMN_WISE.value,
+            ]
+        ),
+        kernel_type=st.sampled_from(
+            [
+                EmbeddingComputeKernel.FUSED_UVM_CACHING.value,
+            ]
+        ),
+    )
+    @settings(verbosity=Verbosity.verbose, max_examples=4, deadline=None)
+    def test_equal_to_non_pipelined(
+        self,
+        execute_all_batches: bool,
+        weight_precision: DataType,
+        cache_precision: DataType,
+        load_factor: float,
+        sharding_type: str,
+        kernel_type: str,
+    ) -> None:
+        """
+        Checks that pipelined training is equivalent to non-pipelined training.
+        """
+        mixed_precision: bool = weight_precision != cache_precision
+        self._set_table_weights_precision(weight_precision)
+        data = self._generate_data(
+            num_batches=12,
+            batch_size=32,
+        )
+        dataloader = iter(data)
+
+        fused_params = {
+            "cache_load_factor": load_factor,
+            "cache_precision": cache_precision,
+            "stochastic_rounding": False,  # disable non-deterministic behavior when converting fp32<->fp16
+        }
+        fused_params_pipelined = {
+            **fused_params,
+            "prefetch_pipeline": True,
+        }
+
+        model = TestSparseNN(
+            tables=self.tables,
+            weighted_tables=self.weighted_tables,
+            dense_device=self.device,
+            sparse_device=torch.device("meta"),
+        )
+        sharded_model, optimizer = self.generate_sharded_model_and_optimizer(
+            model, sharding_type, kernel_type, fused_params
+        )
+
+        (
+            sharded_model_pipelined,
+            optimizer_pipelined,
+        ) = self.generate_sharded_model_and_optimizer(
+            model, sharding_type, kernel_type, fused_params_pipelined
+        )
+        copy_state_dict(
+            sharded_model.state_dict(), sharded_model_pipelined.state_dict()
+        )
+
+        pipeline = PrefetchTrainPipelineSparseDist(
+            model=sharded_model_pipelined,
+            optimizer=optimizer_pipelined,
+            device=self.device,
+            execute_all_batches=execute_all_batches,
+        )
+
+        if not execute_all_batches:
+            data = data[:-3]
+
+        for batch in data:
+            # Forward + backward w/o pipelining
+            batch = batch.to(self.device)
+            optimizer.zero_grad()
+            loss, pred = sharded_model(batch)
+            loss.backward()
+            optimizer.step()
+
+            # Forward + backward w/ pipelining
+            pred_pipeline = pipeline.progress(dataloader)
+
+            if not mixed_precision:
+                torch.testing.assert_close(pred, pred_pipeline, rtol=0, atol=0)
+            else:
+                # Rounding error is expected when using different precisions for weights and cache
+                torch.testing.assert_close(pred, pred_pipeline)
+
+
 class DataLoadingThreadTest(unittest.TestCase):
     def test_fetch_data(self) -> None:
         data = []