diff --git a/torchrec/distributed/test_utils/test_model_parallel.py b/torchrec/distributed/test_utils/test_model_parallel.py
index bc5c41ba7..bef0ff9af 100644
--- a/torchrec/distributed/test_utils/test_model_parallel.py
+++ b/torchrec/distributed/test_utils/test_model_parallel.py
@@ -139,6 +139,13 @@ def setUp(self, backend: str = "nccl") -> None:
                 SharderType.EMBEDDING_BAG_COLLECTION.value,
             ]
         ),
+        kernel_type=st.sampled_from(
+            [
+                EmbeddingComputeKernel.FUSED.value,
+                EmbeddingComputeKernel.FUSED_UVM_CACHING.value,
+                EmbeddingComputeKernel.FUSED_UVM.value,
+            ],
+        ),
         qcomms_config=st.sampled_from(
             [
                 None,
@@ -162,6 +169,7 @@ def setUp(self, backend: str = "nccl") -> None:
     def test_sharding_rw(
         self,
         sharder_type: str,
+        kernel_type: str,
         qcomms_config: Optional[QCommsConfig],
         apply_optimizer_in_backward_config: Optional[
             Dict[str, Tuple[Type[torch.optim.Optimizer], Dict[str, Any]]]
@@ -174,7 +182,6 @@ def test_sharding_rw(
             )
 
         sharding_type = ShardingType.ROW_WISE.value
-        kernel_type = EmbeddingComputeKernel.FUSED.value
         assume(
             sharder_type == SharderType.EMBEDDING_BAG_COLLECTION.value
             or not variable_batch_size
@@ -206,6 +213,11 @@ def test_sharding_rw(
                 SharderType.EMBEDDING_BAG_COLLECTION.value,
             ]
         ),
+        kernel_type=st.sampled_from(
+            [
+                EmbeddingComputeKernel.DENSE.value,
+            ],
+        ),
         apply_optimizer_in_backward_config=st.sampled_from([None]),
         # TODO - need to enable optimizer overlapped behavior for data_parallel tables
     )
@@ -213,12 +225,12 @@ def test_sharding_rw(
     def test_sharding_dp(
         self,
         sharder_type: str,
+        kernel_type: str,
         apply_optimizer_in_backward_config: Optional[
             Dict[str, Tuple[Type[torch.optim.Optimizer], Dict[str, Any]]]
         ],
     ) -> None:
         sharding_type = ShardingType.DATA_PARALLEL.value
-        kernel_type = EmbeddingComputeKernel.DENSE.value
         self._test_sharding(
             # pyre-ignore[6]
             sharders=[
@@ -236,6 +248,13 @@ def test_sharding_dp(
                 SharderType.EMBEDDING_BAG_COLLECTION.value,
             ]
         ),
+        kernel_type=st.sampled_from(
+            [
+                EmbeddingComputeKernel.FUSED.value,
+                EmbeddingComputeKernel.FUSED_UVM_CACHING.value,
+                EmbeddingComputeKernel.FUSED_UVM.value,
+            ],
+        ),
         qcomms_config=st.sampled_from(
             [
                 None,
@@ -259,14 +278,20 @@ def test_sharding_dp(
     def test_sharding_cw(
         self,
         sharder_type: str,
+        kernel_type: str,
         qcomms_config: Optional[QCommsConfig],
         apply_optimizer_in_backward_config: Optional[
             Dict[str, Tuple[Type[torch.optim.Optimizer], Dict[str, Any]]]
         ],
         variable_batch_size: bool,
     ) -> None:
+        if (
+            self.device == torch.device("cpu")
+            and kernel_type != EmbeddingComputeKernel.FUSED.value
+        ):
+            self.skipTest("CPU does not support uvm.")
+
         sharding_type = ShardingType.COLUMN_WISE.value
-        kernel_type = EmbeddingComputeKernel.FUSED.value
         assume(
             sharder_type == SharderType.EMBEDDING_BAG_COLLECTION.value
             or not variable_batch_size
@@ -300,6 +325,90 @@ def test_sharding_cw(
                 SharderType.EMBEDDING_BAG_COLLECTION.value,
             ]
         ),
+        kernel_type=st.sampled_from(
+            [
+                EmbeddingComputeKernel.FUSED.value,
+                EmbeddingComputeKernel.FUSED_UVM_CACHING.value,
+                EmbeddingComputeKernel.FUSED_UVM.value,
+            ],
+        ),
+        qcomms_config=st.sampled_from(
+            [
+                None,
+                QCommsConfig(
+                    forward_precision=CommType.FP16, backward_precision=CommType.BF16
+                ),
+            ]
+        ),
+        apply_optimizer_in_backward_config=st.sampled_from(
+            [
+                None,
+                {
+                    "embeddingbags": (torch.optim.SGD, {"lr": 0.01}),
+                    "embeddings": (torch.optim.SGD, {"lr": 0.2}),
+                },
+            ]
+        ),
+        variable_batch_size=st.booleans(),
+    )
+    @settings(verbosity=Verbosity.verbose, max_examples=3, deadline=None)
+    def test_sharding_twcw(
+        self,
+        sharder_type: str,
+        kernel_type: str,
+        qcomms_config: Optional[QCommsConfig],
+        apply_optimizer_in_backward_config: Optional[
+            Dict[str, Tuple[Type[torch.optim.Optimizer], Dict[str, Any]]]
+        ],
+        variable_batch_size: bool,
+    ) -> None:
+        if (
+            self.device == torch.device("cpu")
+            and kernel_type != EmbeddingComputeKernel.FUSED.value
+        ):
+            self.skipTest("CPU does not support uvm.")
+
+        sharding_type = ShardingType.TABLE_COLUMN_WISE.value
+        assume(
+            sharder_type == SharderType.EMBEDDING_BAG_COLLECTION.value
+            or not variable_batch_size
+        )
+        self._test_sharding(
+            # pyre-ignore[6]
+            sharders=[
+                create_test_sharder(
+                    sharder_type,
+                    sharding_type,
+                    kernel_type,
+                    qcomms_config=qcomms_config,
+                    device=self.device,
+                ),
+            ],
+            backend=self.backend,
+            qcomms_config=qcomms_config,
+            constraints={
+                table.name: ParameterConstraints(min_partition=4)
+                for table in self.tables
+            },
+            apply_optimizer_in_backward_config=apply_optimizer_in_backward_config,
+            variable_batch_size=variable_batch_size,
+        )
+
+    # pyre-fixme[56]
+    @given(
+        sharder_type=st.sampled_from(
+            [
+                # SharderType.EMBEDDING_BAG.value,
+                SharderType.EMBEDDING_BAG_COLLECTION.value,
+            ]
+        ),
+        kernel_type=st.sampled_from(
+            [
+                EmbeddingComputeKernel.FUSED.value,
+                EmbeddingComputeKernel.FUSED_UVM_CACHING.value,
+                EmbeddingComputeKernel.FUSED_UVM.value,
+            ],
+        ),
         qcomms_config=st.sampled_from(
             [
                 # None,
@@ -324,14 +433,97 @@ def test_sharding_cw(
     def test_sharding_tw(
         self,
         sharder_type: str,
+        kernel_type: str,
         qcomms_config: Optional[QCommsConfig],
         apply_optimizer_in_backward_config: Optional[
             Dict[str, Tuple[Type[torch.optim.Optimizer], Dict[str, Any]]]
         ],
         variable_batch_size: bool,
     ) -> None:
+        if (
+            self.device == torch.device("cpu")
+            and kernel_type != EmbeddingComputeKernel.FUSED.value
+        ):
+            self.skipTest("CPU does not support uvm.")
+
         sharding_type = ShardingType.TABLE_WISE.value
-        kernel_type = EmbeddingComputeKernel.FUSED.value
+        assume(
+            sharder_type == SharderType.EMBEDDING_BAG_COLLECTION.value
+            or not variable_batch_size
+        )
+        self._test_sharding(
+            # pyre-ignore[6]
+            sharders=[
+                create_test_sharder(
+                    sharder_type,
+                    sharding_type,
+                    kernel_type,
+                    qcomms_config=qcomms_config,
+                    device=self.device,
+                ),
+            ],
+            backend=self.backend,
+            qcomms_config=qcomms_config,
+            apply_optimizer_in_backward_config=apply_optimizer_in_backward_config,
+            variable_batch_size=variable_batch_size,
+        )
+
+    @unittest.skipIf(
+        torch.cuda.device_count() <= 1,
+        "Not enough GPUs, this test requires at least two GPUs",
+    )
+    # pyre-fixme[56]
+    @given(
+        sharder_type=st.sampled_from(
+            [
+                # SharderType.EMBEDDING_BAG.value,
+                SharderType.EMBEDDING_BAG_COLLECTION.value,
+            ]
+        ),
+        kernel_type=st.sampled_from(
+            [
+                EmbeddingComputeKernel.FUSED.value,
+                EmbeddingComputeKernel.FUSED_UVM_CACHING.value,
+                EmbeddingComputeKernel.FUSED_UVM.value,
+            ],
+        ),
+        qcomms_config=st.sampled_from(
+            [
+                # None,
+                QCommsConfig(
+                    forward_precision=CommType.FP16,
+                    backward_precision=CommType.BF16,
+                ),
+            ]
+        ),
+        apply_optimizer_in_backward_config=st.sampled_from(
+            [
+                None,
+                {
+                    "embeddingbags": (torch.optim.SGD, {"lr": 0.01}),
+                    "embeddings": (torch.optim.SGD, {"lr": 0.2}),
+                },
+            ]
+        ),
+        variable_batch_size=st.booleans(),
+    )
+    @settings(verbosity=Verbosity.verbose, max_examples=3, deadline=None)
+    def test_sharding_twrw(
+        self,
+        sharder_type: str,
+        kernel_type: str,
+        qcomms_config: Optional[QCommsConfig],
+        apply_optimizer_in_backward_config: Optional[
+            Dict[str, Tuple[Type[torch.optim.Optimizer], Dict[str, Any]]]
+        ],
+        variable_batch_size: bool,
+    ) -> None:
+        if self.backend == "gloo":
+            self.skipTest(
+                "Gloo reduce_scatter_base fallback not supported with async_op=True"
+            )
+
+        sharding_type = ShardingType.TABLE_ROW_WISE.value
         assume(
             sharder_type == SharderType.EMBEDDING_BAG_COLLECTION.value
             or not variable_batch_size
@@ -364,6 +556,8 @@ def test_sharding_tw(
                 ShardingType.TABLE_WISE.value,
                 ShardingType.COLUMN_WISE.value,
                 ShardingType.ROW_WISE.value,
+                ShardingType.TABLE_ROW_WISE.value,
+                ShardingType.TABLE_COLUMN_WISE.value,
             ]
         ),
         global_constant_batch=st.booleans(),
diff --git a/torchrec/distributed/test_utils/test_model_parallel_base.py b/torchrec/distributed/test_utils/test_model_parallel_base.py
index a27035e26..9437a7ee6 100644
--- a/torchrec/distributed/test_utils/test_model_parallel_base.py
+++ b/torchrec/distributed/test_utils/test_model_parallel_base.py
@@ -11,7 +11,6 @@
 from typing import Any, Callable, cast, Dict, List, Optional, OrderedDict, Tuple
 
 import numpy as np
-
 import torch
 import torch.nn as nn
 from fbgemm_gpu.split_embedding_configs import EmbOptimType
@@ -23,10 +22,7 @@
     EmbeddingComputeKernel,
     EmbeddingTableConfig,
 )
-from torchrec.distributed.embeddingbag import (
-    EmbeddingBagCollectionSharder,
-    ShardedEmbeddingBagCollection,
-)
+from torchrec.distributed.embeddingbag import ShardedEmbeddingBagCollection
 from torchrec.distributed.fused_embeddingbag import ShardedFusedEmbeddingBagCollection
 from torchrec.distributed.model_parallel import DistributedModelParallel
 from torchrec.distributed.planner import (
@@ -469,19 +465,29 @@ def test_meta_device_dmp_state_dict(self) -> None:
         ),
         sharding_type=st.sampled_from(
             [
+                ShardingType.TABLE_WISE.value,
                 ShardingType.COLUMN_WISE.value,
+                ShardingType.ROW_WISE.value,
+                ShardingType.TABLE_ROW_WISE.value,
+                ShardingType.TABLE_COLUMN_WISE.value,
             ]
         ),
         kernel_type=st.sampled_from(
             [
                 EmbeddingComputeKernel.FUSED.value,
                 EmbeddingComputeKernel.FUSED_UVM_CACHING.value,
+                EmbeddingComputeKernel.FUSED_UVM.value,
             ]
         ),
+        is_training=st.booleans(),
     )
     @settings(verbosity=Verbosity.verbose, max_examples=4, deadline=None)
     def test_load_state_dict(
-        self, sharder_type: str, sharding_type: str, kernel_type: str
+        self,
+        sharder_type: str,
+        sharding_type: str,
+        kernel_type: str,
+        is_training: bool,
     ) -> None:
         if (
             self.device == torch.device("cpu")
@@ -506,11 +512,20 @@ def test_load_state_dict(
         m2.load_state_dict(cast("OrderedDict[str, torch.Tensor]", m1.state_dict()))
 
         # validate the models are equivalent
-        with torch.no_grad():
-            loss1, pred1 = m1(batch)
-            loss2, pred2 = m2(batch)
-            self.assertTrue(torch.equal(loss1, loss2))
-            self.assertTrue(torch.equal(pred1, pred2))
+        if is_training:
+            for _ in range(2):
+                loss1, pred1 = m1(batch)
+                loss2, pred2 = m2(batch)
+                loss1.backward()
+                loss2.backward()
+                self.assertTrue(torch.equal(loss1, loss2))
+                self.assertTrue(torch.equal(pred1, pred2))
+        else:
+            with torch.no_grad():
+                loss1, pred1 = m1(batch)
+                loss2, pred2 = m2(batch)
+                self.assertTrue(torch.equal(loss1, loss2))
+                self.assertTrue(torch.equal(pred1, pred2))
         sd1 = m1.state_dict()
         for key, value in m2.state_dict().items():
             v2 = sd1[key]
@@ -543,10 +558,11 @@ def test_load_state_dict(
                 EmbeddingComputeKernel.DENSE.value,
             ]
         ),
+        is_training=st.booleans(),
     )
-    @settings(verbosity=Verbosity.verbose, max_examples=2, deadline=None)
+    @settings(verbosity=Verbosity.verbose, max_examples=4, deadline=None)
     def test_load_state_dict_dp(
-        self, sharder_type: str, sharding_type: str, kernel_type: str
+        self, sharder_type: str, sharding_type: str, kernel_type: str, is_training: bool
     ) -> None:
         sharders = [
             cast(
@@ -565,11 +581,20 @@ def test_load_state_dict_dp(
         m2.load_state_dict(cast("OrderedDict[str, torch.Tensor]", m1.state_dict()))
 
         # validate the models are equivalent
-        with torch.no_grad():
-            loss1, pred1 = m1(batch)
-            loss2, pred2 = m2(batch)
-            self.assertTrue(torch.equal(loss1, loss2))
-            self.assertTrue(torch.equal(pred1, pred2))
+        if is_training:
+            for _ in range(2):
+                loss1, pred1 = m1(batch)
+                loss2, pred2 = m2(batch)
+                loss1.backward()
+                loss2.backward()
+                self.assertTrue(torch.equal(loss1, loss2))
+                self.assertTrue(torch.equal(pred1, pred2))
+        else:
+            with torch.no_grad():
+                loss1, pred1 = m1(batch)
+                loss2, pred2 = m2(batch)
+                self.assertTrue(torch.equal(loss1, loss2))
+                self.assertTrue(torch.equal(pred1, pred2))
         sd1 = m1.state_dict()
         for key, value in m2.state_dict().items():
             v2 = sd1[key]
@@ -587,17 +612,45 @@ def test_load_state_dict_dp(
 
     # pyre-ignore[56]
     @given(
-        sharders=st.sampled_from(
+        sharder_type=st.sampled_from(
+            [
+                SharderType.EMBEDDING_BAG_COLLECTION.value,
+            ]
+        ),
+        sharding_type=st.sampled_from(
             [
-                [EmbeddingBagCollectionSharder()],
-                # [EmbeddingBagSharder()],
+                ShardingType.TABLE_WISE.value,
+                ShardingType.COLUMN_WISE.value,
+                ShardingType.ROW_WISE.value,
+                ShardingType.TABLE_ROW_WISE.value,
+                ShardingType.TABLE_COLUMN_WISE.value,
             ]
         ),
+        kernel_type=st.sampled_from(
+            [
+                EmbeddingComputeKernel.FUSED.value,
+                EmbeddingComputeKernel.FUSED_UVM_CACHING.value,
+                EmbeddingComputeKernel.FUSED_UVM.value,
+            ]
+        ),
+        is_training=st.booleans(),
     )
-    @settings(verbosity=Verbosity.verbose, max_examples=2, deadline=None)
+    @settings(verbosity=Verbosity.verbose, max_examples=4, deadline=None)
     def test_load_state_dict_prefix(
-        self, sharders: List[ModuleSharder[nn.Module]]
+        self, sharder_type: str, sharding_type: str, kernel_type: str, is_training: bool
     ) -> None:
+        if (
+            self.device == torch.device("cpu")
+            and kernel_type != EmbeddingComputeKernel.FUSED.value
+        ):
+            self.skipTest("CPU does not support uvm.")
+
+        sharders = [
+            cast(
+                ModuleSharder[nn.Module],
+                create_test_sharder(sharder_type, sharding_type, kernel_type),
+            ),
+        ]
         (m1, m2), batch = self._generate_dmps_and_batch(sharders)
 
         # load the second's (m2's) with the first (m1's) state_dict
@@ -607,6 +660,21 @@ def test_load_state_dict_prefix(
         )
 
         # validate the models are equivalent
+        if is_training:
+            for _ in range(2):
+                loss1, pred1 = m1(batch)
+                loss2, pred2 = m2(batch)
+                loss1.backward()
+                loss2.backward()
+                self.assertTrue(torch.equal(loss1, loss2))
+                self.assertTrue(torch.equal(pred1, pred2))
+        else:
+            with torch.no_grad():
+                loss1, pred1 = m1(batch)
+                loss2, pred2 = m2(batch)
+                self.assertTrue(torch.equal(loss1, loss2))
+                self.assertTrue(torch.equal(pred1, pred2))
+
         sd1 = m1.state_dict()
         for key, value in m2.state_dict().items():
             v2 = sd1[key]
@@ -632,19 +700,31 @@ def test_load_state_dict_prefix(
         sharding_type=st.sampled_from(
             [
                 ShardingType.TABLE_WISE.value,
+                ShardingType.COLUMN_WISE.value,
+                ShardingType.ROW_WISE.value,
+                ShardingType.TABLE_ROW_WISE.value,
+                ShardingType.TABLE_COLUMN_WISE.value,
             ]
        ),
         kernel_type=st.sampled_from(
             [
                 # EmbeddingComputeKernel.DENSE.value,
                 EmbeddingComputeKernel.FUSED.value,
+                EmbeddingComputeKernel.FUSED_UVM_CACHING.value,
+                EmbeddingComputeKernel.FUSED_UVM.value,
             ]
         ),
     )
-    @settings(verbosity=Verbosity.verbose, max_examples=10, deadline=None)
+    @settings(verbosity=Verbosity.verbose, max_examples=4, deadline=None)
     def test_params_and_buffers(
         self, sharder_type: str, sharding_type: str, kernel_type: str
     ) -> None:
+        if (
+            self.device == torch.device("cpu")
+            and kernel_type != EmbeddingComputeKernel.FUSED.value
+        ):
+            self.skipTest("CPU does not support uvm.")
+
         sharders = [
             create_test_sharder(sharder_type, sharding_type, kernel_type),
         ]
@@ -671,13 +751,22 @@ def test_params_and_buffers(
         kernel_type=st.sampled_from(
             [
                 EmbeddingComputeKernel.FUSED.value,
+                EmbeddingComputeKernel.FUSED_UVM_CACHING.value,
+                EmbeddingComputeKernel.FUSED_UVM.value,
             ]
         ),
+        is_training=st.booleans(),
     )
-    @settings(verbosity=Verbosity.verbose, max_examples=8, deadline=None)
+    @settings(verbosity=Verbosity.verbose, max_examples=4, deadline=None)
    def test_load_state_dict_cw_multiple_shards(
-        self, sharder_type: str, sharding_type: str, kernel_type: str
+        self, sharder_type: str, sharding_type: str, kernel_type: str, is_training: bool
     ) -> None:
+        if (
+            self.device == torch.device("cpu")
+            and kernel_type != EmbeddingComputeKernel.FUSED.value
+        ):
+            self.skipTest("CPU does not support uvm.")
+
         sharders = [
             cast(
                 ModuleSharder[nn.Module],
@@ -717,10 +806,20 @@ def test_load_state_dict_cw_multiple_shards(
         m2.fused_optimizer.load_state_dict(src_optimizer_state_dict)
 
         # validate the models are equivalent
-        loss1, pred1 = m1(batch)
-        loss2, pred2 = m2(batch)
-        self.assertTrue(torch.equal(loss1, loss2))
-        self.assertTrue(torch.equal(pred1, pred2))
+        if is_training:
+            for _ in range(2):
+                loss1, pred1 = m1(batch)
+                loss2, pred2 = m2(batch)
+                loss1.backward()
+                loss2.backward()
+                self.assertTrue(torch.equal(loss1, loss2))
+                self.assertTrue(torch.equal(pred1, pred2))
+        else:
+            with torch.no_grad():
+                loss1, pred1 = m1(batch)
+                loss2, pred2 = m2(batch)
+                self.assertTrue(torch.equal(loss1, loss2))
+                self.assertTrue(torch.equal(pred1, pred2))
         sd1 = m1.state_dict()
         for key, value in m2.state_dict().items():